Mirror of https://github.com/tcsenpai/src.git, synced 2025-06-03 18:00:25 +00:00

First commit
This commit is contained in: commit d2cea458d4

2 .gitignore (vendored) Normal file
@@ -0,0 +1,2 @@
__pycache__
*.db

5 pyflaredb/__init__.py Normal file
@@ -0,0 +1,5 @@
from .core import PyFlareDB
from .table import Table, Column
from .versioning import Version, VersionStore

__all__ = ['PyFlareDB', 'Table', 'Column', 'Version', 'VersionStore']

125 pyflaredb/benchmark/suite.py Normal file
@@ -0,0 +1,125 @@
import time
from typing import List, Dict, Any
import random
import string
from ..core import PyFlareDB


class BenchmarkSuite:
    def __init__(self, db: PyFlareDB):
        self.db = db

    def run_benchmark(self, num_records: int = 10000):
        """Run comprehensive benchmark"""
        results = {
            "insert": self._benchmark_insert(num_records),
            "select": self._benchmark_select(num_records),
            "index": self._benchmark_index_performance(num_records),
            "complex_query": self._benchmark_complex_queries(num_records),
        }
        return results

    def _benchmark_insert(self, num_records: int) -> Dict[str, float]:
        start_time = time.time()
        batch_times = []

        for i in range(0, num_records, 1000):
            batch_start = time.time()
            self._insert_batch(min(1000, num_records - i))
            batch_times.append(time.time() - batch_start)

        total_time = time.time() - start_time
        return {
            "total_time": total_time,
            "records_per_second": num_records / total_time,
            "avg_batch_time": sum(batch_times) / len(batch_times),
        }

    def _insert_batch(self, size: int):
        """Insert a batch of random records"""
        tx_id = None
        try:
            tx_id = self.db.transaction_manager.begin_transaction()

            for _ in range(size):
                query = (
                    "INSERT INTO users (id, username, email, age) "
                    f"VALUES ('{self._random_string(10)}', "
                    f"'{self._random_string(8)}', "
                    f"'{self._random_string(8)}@example.com', "
                    f"{random.randint(18, 80)})"
                )
                try:
                    self.db.execute(query)
                except Exception as e:
                    print(f"Failed to insert record: {e}")
                    print(f"Query was: {query}")
                    raise

            self.db.transaction_manager.commit(tx_id)

        except Exception as e:
            if tx_id is not None:
                try:
                    self.db.transaction_manager.rollback(tx_id)
                except ValueError:
                    pass
            raise

    def _benchmark_select(self, num_records: int) -> Dict[str, float]:
        """Benchmark SELECT queries"""
        queries = [
            "SELECT * FROM users WHERE age > 30",
            "SELECT username, email FROM users WHERE age < 25",
            "SELECT COUNT(*) FROM users",
        ]

        results = {}
        for query in queries:
            start_time = time.time()
            try:
                self.db.execute(query)
                query_time = time.time() - start_time
                results[query] = query_time
            except Exception as e:
                results[query] = f"Error: {e}"

        return results

    def _benchmark_index_performance(self, num_records: int) -> Dict[str, float]:
        """Benchmark index performance"""
        # TODO: Implement index benchmarking
        return {"index_creation_time": 0.0}

    def _benchmark_complex_queries(self, num_records: int) -> Dict[str, float]:
        """Benchmark complex queries"""
        complex_queries = [
            """
            SELECT username, COUNT(*) as count
            FROM users
            GROUP BY username
            """,
            """
            SELECT * FROM users
            WHERE age > 30
            ORDER BY username DESC
            LIMIT 10
            """,
        ]

        results = {}
        for query in complex_queries:
            start_time = time.time()
            try:
                self.db.execute(query)
                query_time = time.time() - start_time
                results[query.strip()] = query_time
            except Exception as e:
                results[query.strip()] = f"Error: {e}"

        return results

    @staticmethod
    def _random_string(length: int) -> str:
        """Generate a random string of specified length"""
        return "".join(random.choices(string.ascii_letters + string.digits, k=length))

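A minimal sketch of how the suite might be driven. It assumes a database containing a `users` table whose columns match the INSERT issued by `_insert_batch` above (id, username, email, age); the database path is illustrative only, since the current engine keeps data in memory.

from pyflaredb.core import PyFlareDB
from pyflaredb.table import Table, Column
from pyflaredb.benchmark.suite import BenchmarkSuite

db = PyFlareDB("bench.db")  # illustrative path; data is kept in memory
db.create_table(Table("users", [
    Column("id", "string", nullable=False, primary_key=True),
    Column("username", "string", nullable=False),
    Column("email", "string", nullable=False),
    Column("age", "integer"),
]))

results = BenchmarkSuite(db).run_benchmark(num_records=2000)
print(results["insert"]["records_per_second"])
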
39 pyflaredb/cache/query_cache.py (vendored) Normal file
@@ -0,0 +1,39 @@
from typing import Dict, Any, Optional
import time
import hashlib
from collections import OrderedDict


class QueryCache:
    def __init__(self, capacity: int = 1000, ttl: int = 300):
        self.capacity = capacity
        self.ttl = ttl
        self.cache = OrderedDict()

    def get(self, query: str) -> Optional[Any]:
        """Get cached query result"""
        query_hash = self._hash_query(query)
        if query_hash in self.cache:
            entry = self.cache[query_hash]
            if time.time() - entry['timestamp'] < self.ttl:
                self.cache.move_to_end(query_hash)
                return entry['result']
            else:
                del self.cache[query_hash]
        return None

    def set(self, query: str, result: Any):
        """Cache query result"""
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

        query_hash = self._hash_query(query)
        self.cache[query_hash] = {
            'result': result,
            'timestamp': time.time()
        }

    def _hash_query(self, query: str) -> str:
        return hashlib.sha256(query.encode()).hexdigest()

    def clear(self):
        self.cache.clear()

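The cache is a standalone LRU-with-TTL helper; note that PyFlareDB.execute currently uses a plain dict rather than this class. A small usage sketch:

from pyflaredb.cache.query_cache import QueryCache

cache = QueryCache(capacity=2, ttl=60)
cache.set("SELECT * FROM users", [{"id": 1}])
print(cache.get("SELECT * FROM users"))  # [{'id': 1}] while within the TTL
cache.set("q2", [])
cache.set("q3", [])                       # capacity reached: evicts the least recently used entry
print(cache.get("SELECT * FROM users"))  # None: the oldest entry was evicted
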
90 pyflaredb/core.py Normal file
@@ -0,0 +1,90 @@
from typing import Dict, List, Any, Optional
from .table import Table
from .sql.parser import SQLParser, SelectStatement, InsertStatement
from .sql.executor import QueryExecutor
from .sql.optimizer import QueryOptimizer
from .sql.statistics import TableStatistics
from .transaction import TransactionManager


class PyFlareDB:
    def __init__(self, db_path: str):
        """Initialize the database"""
        self.db_path = db_path
        self.tables: Dict[str, Table] = {}
        self.parser = SQLParser()
        self.statistics = TableStatistics()
        self.optimizer = QueryOptimizer(self.tables, self.statistics)
        self.executor = QueryExecutor(self.tables)
        self.transaction_manager = TransactionManager()
        self._query_cache = {}

    def begin_transaction(self) -> str:
        """Begin a new transaction"""
        return self.transaction_manager.begin_transaction()

    def commit_transaction(self, tx_id: str) -> bool:
        """Commit a transaction"""
        return self.transaction_manager.commit(tx_id)

    def rollback_transaction(self, tx_id: str) -> bool:
        """Rollback a transaction"""
        return self.transaction_manager.rollback(tx_id)

    def create_table(self, table: Table) -> None:
        """Create a new table"""
        if table.name in self.tables:
            raise ValueError(f"Table {table.name} already exists")
        self.tables[table.name] = table
        self.statistics.collect_statistics(table)

    def drop_table(self, table_name: str) -> None:
        """Drop a table"""
        if table_name not in self.tables:
            raise ValueError(f"Table {table_name} does not exist")
        del self.tables[table_name]

    def execute(
        self, sql: str, tx_id: Optional[str] = None
    ) -> Optional[List[Dict[str, Any]]]:
        """Execute a SQL query"""
        try:
            # Check query cache for non-transactional SELECT queries
            if tx_id is None and sql in self._query_cache:
                return self._query_cache[sql]

            # Parse SQL
            if sql.strip().upper().startswith("SELECT"):
                statement = self.parser.parse_select(sql)
            elif sql.strip().upper().startswith("INSERT"):
                statement = self.parser.parse_insert(sql)
            else:
                raise ValueError("Unsupported SQL statement type")

            # Get transaction if provided
            tx = None
            if tx_id:
                tx = self.transaction_manager.get_transaction(tx_id)
                if not tx:
                    raise ValueError(f"Transaction {tx_id} does not exist")

            # Optimize query plan
            optimized_plan = self.optimizer.optimize(statement)

            # Execute query
            result = self.executor.execute(optimized_plan, transaction=tx)

            # Cache SELECT results for non-transactional queries
            if tx_id is None and isinstance(statement, SelectStatement):
                self._query_cache[sql] = result

            return result

        except Exception as e:
            # Clear cache on error
            self._query_cache.clear()
            raise e

    def clear_cache(self) -> None:
        """Clear the query cache"""
        self._query_cache.clear()

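An end-to-end sketch of the public API above (table creation, INSERT, cached SELECT); the database path is illustrative and column types follow Table's valid_types:

from pyflaredb.core import PyFlareDB
from pyflaredb.table import Table, Column

db = PyFlareDB("example.db")  # illustrative path
db.create_table(Table("users", [
    Column("id", "string", nullable=False, primary_key=True),
    Column("age", "integer"),
]))
db.execute("INSERT INTO users (id, age) VALUES ('u1', 42)")
print(db.execute("SELECT * FROM users WHERE age > 30"))  # first run parses, optimizes, executes
print(db.execute("SELECT * FROM users WHERE age > 30"))  # second run is served from _query_cache
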
137 pyflaredb/indexing/btree.py Normal file
@@ -0,0 +1,137 @@
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from collections import deque


@dataclass
class Node:
    keys: List[Any]
    values: List[List[int]]  # List of row IDs for each key (handling duplicates)
    children: List['Node']
    is_leaf: bool = True


class BTreeIndex:
    def __init__(self, order: int = 100):
        self.root = Node([], [], [])
        self.order = order  # Maximum number of children per node

    def insert(self, key: Any, row_id: int) -> None:
        """Insert a key-value pair into the B-tree"""
        if len(self.root.keys) == (2 * self.order) - 1:
            # Split root if full
            new_root = Node([], [], [], False)
            new_root.children.append(self.root)
            self._split_child(new_root, 0)
            self.root = new_root
        self._insert_non_full(self.root, key, row_id)

    def search(self, key: Any) -> List[int]:
        """Search for a key and return all matching row IDs"""
        return self._search_node(self.root, key)

    def range_search(self, start_key: Any, end_key: Any) -> List[int]:
        """Perform a range search and return all matching row IDs"""
        result = []
        self._range_search_node(self.root, start_key, end_key, result)
        return result

    def _split_child(self, parent: Node, child_index: int) -> None:
        """Split a full child node"""
        order = self.order
        child = parent.children[child_index]
        new_node = Node([], [], [], child.is_leaf)

        # Move the median key to the parent
        median = order - 1
        parent.keys.insert(child_index, child.keys[median])
        parent.values.insert(child_index, child.values[median])
        parent.children.insert(child_index + 1, new_node)

        # Move half of the keys to the new node
        new_node.keys = child.keys[median + 1:]
        new_node.values = child.values[median + 1:]
        child.keys = child.keys[:median]
        child.values = child.values[:median]

        # Move children if not a leaf
        if not child.is_leaf:
            new_node.children = child.children[median + 1:]
            child.children = child.children[:median + 1]

    def _insert_non_full(self, node: Node, key: Any, row_id: int) -> None:
        """Insert into a non-full node"""
        i = len(node.keys) - 1

        if node.is_leaf:
            # Insert into leaf node
            while i >= 0 and self._compare_keys(key, node.keys[i]) < 0:
                i -= 1
            i += 1

            # Handle duplicate keys
            if i > 0 and self._compare_keys(key, node.keys[i-1]) == 0:
                node.values[i-1].append(row_id)
            else:
                node.keys.insert(i, key)
                node.values.insert(i, [row_id])
        else:
            # Find the child to insert into
            while i >= 0 and self._compare_keys(key, node.keys[i]) < 0:
                i -= 1
            i += 1

            if len(node.children[i].keys) == (2 * self.order) - 1:
                self._split_child(node, i)
                if self._compare_keys(key, node.keys[i]) > 0:
                    i += 1

            self._insert_non_full(node.children[i], key, row_id)

    def _search_node(self, node: Node, key: Any) -> List[int]:
        """Search for a key in a node"""
        i = 0
        while i < len(node.keys) and self._compare_keys(key, node.keys[i]) > 0:
            i += 1

        if i < len(node.keys) and self._compare_keys(key, node.keys[i]) == 0:
            return node.values[i]
        elif node.is_leaf:
            return []
        else:
            return self._search_node(node.children[i], key)

    def _range_search_node(self, node: Node, start_key: Any, end_key: Any, result: List[int]) -> None:
        """Perform range search on a node"""
        i = 0
        while i < len(node.keys) and self._compare_keys(start_key, node.keys[i]) > 0:
            i += 1

        if node.is_leaf:
            while i < len(node.keys) and self._compare_keys(node.keys[i], end_key) <= 0:
                result.extend(node.values[i])
                i += 1
        else:
            if i < len(node.keys):
                self._range_search_node(node.children[i], start_key, end_key, result)
            while i < len(node.keys) and self._compare_keys(node.keys[i], end_key) <= 0:
                result.extend(node.values[i])
                i += 1
                if i < len(node.children):
                    self._range_search_node(node.children[i], start_key, end_key, result)

    @staticmethod
    def _compare_keys(key1: Any, key2: Any) -> int:
        """Compare two keys, handling different types"""
        if key1 is None or key2 is None:
            if key1 is None and key2 is None:
                return 0
            return -1 if key1 is None else 1

        try:
            if key1 < key2:
                return -1
            elif key1 > key2:
                return 1
            return 0
        except TypeError:
            # Handle incomparable types
            return 0

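A small sketch of the index in isolation: keys map to lists of row IDs, and duplicate keys accumulate IDs in the same entry.

from pyflaredb.indexing.btree import BTreeIndex

idx = BTreeIndex(order=4)
for row_id, age in enumerate([25, 30, 25, 41, 37]):
    idx.insert(age, row_id)

print(idx.search(25))            # [0, 2] -- both rows whose key is 25
print(idx.range_search(30, 41))  # row IDs whose key falls in [30, 41]
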
33 pyflaredb/monitoring/metrics.py Normal file
@@ -0,0 +1,33 @@
from typing import Dict, List
import time
from collections import deque
import threading


class PerformanceMetrics:
    def __init__(self, window_size: int = 1000):
        self.window_size = window_size
        self.query_times: Dict[str, deque] = {}
        self.lock = threading.Lock()

    def record_query(self, query_type: str, execution_time: float):
        """Record query execution time"""
        with self.lock:
            if query_type not in self.query_times:
                self.query_times[query_type] = deque(maxlen=self.window_size)
            self.query_times[query_type].append(execution_time)

    def get_metrics(self) -> Dict[str, Dict[str, float]]:
        """Get performance metrics"""
        metrics = {}
        with self.lock:
            for query_type, times in self.query_times.items():
                if not times:
                    continue
                metrics[query_type] = {
                    "avg_time": sum(times) / len(times),
                    "max_time": max(times),
                    "min_time": min(times),
                    "count": len(times),
                }
        return metrics

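The metrics collector is not yet wired into the executor; a standalone sketch of recording and reading timings:

import time
from pyflaredb.monitoring.metrics import PerformanceMetrics

metrics = PerformanceMetrics(window_size=100)
start = time.time()
# ... run a query here ...
metrics.record_query("select", time.time() - start)
print(metrics.get_metrics())  # {'select': {'avg_time': ..., 'max_time': ..., 'min_time': ..., 'count': 1}}
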
245 pyflaredb/sql/executor.py Normal file
@@ -0,0 +1,245 @@
from typing import List, Dict, Any, Callable, Tuple, Optional
import operator
from ..table import Table
from .parser import SelectStatement, InsertStatement
from ..transaction import Transaction


class QueryExecutor:
    def __init__(self, tables: Dict[str, Table]):
        self.tables = tables
        self._compiled_conditions = {}
        self._comparison_ops = {
            '>': operator.gt,
            '<': operator.lt,
            '>=': operator.ge,
            '<=': operator.le,
            '=': operator.eq,
            '!=': operator.ne
        }

    def _parse_where_clause(self, where_clause: str) -> List[Tuple[str, str, str]]:
        """Parse WHERE clause into list of (field, operator, value) tuples"""
        conditions = []
        # Split on AND if present
        subclauses = [c.strip() for c in where_clause.split(' AND ')]

        for subclause in subclauses:
            # Find the operator
            operator_found = None
            for op in ['>=', '<=', '>', '<', '=', '!=']:
                if op in subclause:
                    operator_found = op
                    field, value = subclause.split(op)
                    conditions.append((field.strip(), op, value.strip()))
                    break

            if not operator_found:
                raise ValueError(f"Invalid condition: {subclause}")

        return conditions

    def execute(self, statement, transaction: Optional[Transaction] = None):
        """Execute a parsed SQL statement"""
        if isinstance(statement, SelectStatement):
            return self._execute_select(statement, transaction)
        elif isinstance(statement, InsertStatement):
            return self._execute_insert(statement, transaction)
        elif statement is None:
            raise ValueError("No statement to execute")
        else:
            raise ValueError(f"Unsupported statement type: {type(statement)}")

    def _execute_select(self, stmt: SelectStatement, transaction: Optional[Transaction] = None) -> List[Dict[str, Any]]:
        if stmt.table_name not in self.tables:
            raise ValueError(f"Table {stmt.table_name} does not exist")

        table = self.tables[stmt.table_name]

        # If in transaction, check for locks
        if transaction and table.name in transaction.locks:
            # Handle transaction isolation level logic here
            pass

        # Handle COUNT(*) separately
        if len(stmt.columns) == 1 and stmt.columns[0].lower() == "count(*)":
            return [{"count": len(table.data)}]

        # Try to use index for WHERE clause
        if stmt.where_clause:
            try:
                conditions = self._parse_where_clause(stmt.where_clause)

                # Check if we can use an index for any condition
                for field, op, value in conditions:
                    if field in table._indexes:
                        # Convert value to proper type
                        column = next((col for col in table.columns if col.name == field), None)
                        if column:
                            try:
                                if column.data_type == "integer":
                                    value = int(value)
                                elif column.data_type == "float":
                                    value = float(value)
                            except (ValueError, TypeError):
                                continue

                        # Use index for lookup
                        if op == '=':
                            results = table.find_by_index(field, value)
                        elif op in {'>', '>='}:
                            results = table.range_search(field, value, None)
                        elif op in {'<', '<='}:
                            results = table.range_search(field, None, value)
                        else:  # op == '!='
                            # For inequality, we still need to scan
                            results = table.data

                        # Apply remaining conditions
                        filtered_results = []
                        for row in results:
                            if self._matches_all_conditions(row, conditions):
                                filtered_results.append(row)

                        return self._process_results(filtered_results, stmt)
            except ValueError:
                # If WHERE clause parsing fails, fall back to table scan
                pass

        # Fall back to full table scan
        return self._table_scan(table, stmt)

    def _matches_all_conditions(self, row: Dict[str, Any], conditions: List[Tuple[str, str, str]]) -> bool:
        """Check if row matches all conditions"""
        for field, op, value in conditions:
            row_value = row.get(field)
            if row_value is None:
                return False

            # Convert value to proper type based on row_value
            try:
                if isinstance(row_value, int):
                    value = int(value)
                elif isinstance(row_value, float):
                    value = float(value)
            except (ValueError, TypeError):
                return False

            # Apply comparison
            op_func = self._comparison_ops[op]
            try:
                if not op_func(row_value, value):
                    return False
            except TypeError:
                return False

        return True

    def _table_scan(self, table: Table, stmt: SelectStatement) -> List[Dict[str, Any]]:
        """Perform a full table scan with filtering"""
        results = []

        # Parse WHERE conditions if present
        conditions = []
        if stmt.where_clause:
            try:
                conditions = self._parse_where_clause(stmt.where_clause)
            except ValueError:
                # If parsing fails, return empty result
                return []

        # Process rows
        for row in table.data:
            # Apply WHERE clause
            if conditions and not self._matches_all_conditions(row, conditions):
                continue

            # Select requested columns
            if "*" in stmt.columns:
                results.append(row.copy())
            else:
                filtered_row = {}
                for col in stmt.columns:
                    if "count(" in col.lower():
                        filtered_row[col] = len(results)
                    else:
                        filtered_row[col] = row.get(col)
                results.append(filtered_row)

        return self._process_results(results, stmt)

    def _process_results(self, rows: List[Dict[str, Any]], stmt: SelectStatement) -> List[Dict[str, Any]]:
        """Process result rows according to SELECT statement"""
        results = []
        for row in rows:
            if "*" in stmt.columns:
                results.append(row.copy())
            else:
                filtered_row = {}
                for col in stmt.columns:
                    if "count(" in col.lower():
                        filtered_row[col] = len(results)
                    else:
                        filtered_row[col] = row.get(col)
                results.append(filtered_row)

        # Handle ORDER BY
        if stmt.order_by:
            for order_clause in stmt.order_by:
                reverse = order_clause.direction.value == "DESC"
                results.sort(
                    key=lambda x: (x.get(order_clause.column) is None, x.get(order_clause.column)),
                    reverse=reverse
                )

        # Handle LIMIT
        if stmt.limit is not None:
            results = results[:stmt.limit]

        return results

    def _execute_insert(self, stmt: InsertStatement, transaction: Optional[Transaction] = None) -> bool:
        if stmt.table_name not in self.tables:
            raise ValueError(f"Table {stmt.table_name} does not exist")

        table = self.tables[stmt.table_name]

        # If in transaction, acquire lock and track changes
        if transaction:
            transaction.locks.add(table.name)
            # Track the changes for potential rollback
            transaction.changes.append({
                'type': 'INSERT',
                'table': table.name,
                'data': dict(zip(stmt.columns, stmt.values))
            })

        # Create dictionary of column-value pairs
        row_data = {}
        for col_name, value in zip(stmt.columns, stmt.values):
            # Find the column definition
            column = next((col for col in table.columns if col.name == col_name), None)
            if not column:
                raise ValueError(f"Column {col_name} does not exist")

            # Convert value based on column type
            if value is not None:
                try:
                    if column.data_type == "integer":
                        row_data[col_name] = int(value)
                    elif column.data_type == "float":
                        row_data[col_name] = float(value)
                    elif column.data_type == "boolean":
                        if isinstance(value, str):
                            row_data[col_name] = value.lower() == 'true'
                        else:
                            row_data[col_name] = bool(value)
                    else:  # string type
                        row_data[col_name] = str(value)
                except (ValueError, TypeError):
                    raise ValueError(f"Invalid value for column {column.name}: {value}")
            else:
                row_data[col_name] = None

        # Insert the data
        return table.insert(row_data)

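For reference, the WHERE parser above is purely string based: it splits on ' AND ' and on the first matching operator, so a clause such as "age >= 30 AND username = 'bob'" becomes [('age', '>=', '30'), ('username', '=', "'bob'")], and values are re-typed later against each row. A hedged sketch of driving the executor directly (normally PyFlareDB.execute does this):

from pyflaredb.table import Table, Column
from pyflaredb.sql.parser import SQLParser
from pyflaredb.sql.executor import QueryExecutor

users = Table("users", [Column("id", "string"), Column("age", "integer")])
users.insert({"id": "u1", "age": 42})
executor = QueryExecutor({"users": users})
stmt = SQLParser.parse_select("SELECT id FROM users WHERE age > 30")
print(executor.execute(stmt))  # [{'id': 'u1'}]
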
51 pyflaredb/sql/optimizer.py Normal file
@@ -0,0 +1,51 @@
from typing import List, Dict, Any, Union
from dataclasses import dataclass
from enum import Enum
from .parser import SelectStatement, InsertStatement

from pyflaredb.sql.statistics import TableStatistics
from pyflaredb.table import Table


class JoinStrategy(Enum):
    NESTED_LOOP = "nested_loop"
    HASH_JOIN = "hash_join"
    MERGE_JOIN = "merge_join"


class ScanType(Enum):
    SEQUENTIAL = "sequential"
    INDEX = "index"


@dataclass
class QueryPlan:
    operation: str
    strategy: Union[JoinStrategy, ScanType]
    estimated_cost: float
    children: List["QueryPlan"] = None


class QueryOptimizer:
    def __init__(self, tables: Dict[str, "Table"], statistics: "TableStatistics"):
        self.tables = tables
        self.statistics = statistics

    def optimize(self, statement) -> Any:
        """Generate an optimized query plan"""
        if isinstance(statement, SelectStatement):
            return self._optimize_select(statement)
        elif isinstance(statement, InsertStatement):
            return statement  # No optimization needed for simple inserts
        return statement  # Return original statement if no optimization is needed

    def _optimize_select(self, stmt: SelectStatement) -> SelectStatement:
        """Optimize SELECT query execution"""
        # For now, return the original statement
        # TODO: Implement actual optimization strategies
        return stmt

    def _estimate_cost(self, plan: QueryPlan) -> float:
        """Estimate the cost of a query plan"""
        # Implementation for cost estimation
        pass

176 pyflaredb/sql/parser.py Normal file
@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Optional, Any
from enum import Enum


class OrderDirection(Enum):
    ASC = "ASC"
    DESC = "DESC"


@dataclass
class OrderByClause:
    column: str
    direction: OrderDirection = OrderDirection.ASC


@dataclass
class SelectStatement:
    table_name: str
    columns: List[str]
    where_clause: Optional[str] = None
    group_by: Optional[List[str]] = None
    order_by: Optional[List[OrderByClause]] = None
    limit: Optional[int] = None


@dataclass
class InsertStatement:
    table_name: str
    columns: List[str]
    values: List[Any]


class SQLParser:
    @staticmethod
    def parse_insert(sql: str) -> InsertStatement:
        """Parse INSERT statement"""
        # Remove newlines and extra spaces
        sql = ' '.join(sql.split())

        # Extract table name
        table_start = sql.find("INTO") + 4
        table_end = sql.find("(", table_start)
        table_name = sql[table_start:table_end].strip()

        # Extract columns
        cols_start = sql.find("(", table_end) + 1
        cols_end = sql.find(")", cols_start)
        columns = [col.strip() for col in sql[cols_start:cols_end].split(",")]

        # Extract values
        values_start = sql.find("VALUES", cols_end) + 6
        values_start = sql.find("(", values_start) + 1
        values_end = sql.find(")", values_start)
        values_str = sql[values_start:values_end]

        # Parse values while respecting quotes
        values = []
        current_value = ""
        in_quotes = False
        quote_char = None

        for char in values_str:
            if char in ["'", '"']:
                if not in_quotes:
                    in_quotes = True
                    quote_char = char
                elif quote_char == char:
                    in_quotes = False
                    quote_char = None
                current_value += char
            elif char == ',' and not in_quotes:
                values.append(current_value.strip())
                current_value = ""
            else:
                current_value += char

        if current_value:
            values.append(current_value.strip())

        # Clean up values
        cleaned_values = []
        for value in values:
            value = value.strip()
            if value.startswith(("'", '"')) and value.endswith(("'", '"')):
                # String value - keep quotes
                cleaned_values.append(value)
            elif value.lower() == 'true':
                cleaned_values.append(True)
            elif value.lower() == 'false':
                cleaned_values.append(False)
            elif value.lower() == 'null':
                cleaned_values.append(None)
            else:
                try:
                    # Try to convert to number if possible
                    if '.' in value:
                        cleaned_values.append(float(value))
                    else:
                        cleaned_values.append(int(value))
                except ValueError:
                    # If not a number, keep as is
                    cleaned_values.append(value)

        if len(columns) != len(cleaned_values):
            raise ValueError(f"Column count ({len(columns)}) doesn't match value count ({len(cleaned_values)})")

        return InsertStatement(table_name=table_name, columns=columns, values=cleaned_values)

    @staticmethod
    def parse_select(sql: str) -> SelectStatement:
        """Parse SELECT statement"""
        # Remove newlines and extra spaces
        sql = ' '.join(sql.split())

        # Extract table name
        from_idx = sql.upper().find("FROM")
        if from_idx == -1:
            raise ValueError("Invalid SELECT statement: missing FROM clause")

        # Extract columns
        columns_str = sql[6:from_idx].strip()
        columns = [col.strip() for col in columns_str.split(",")]

        # Find all clause positions
        where_idx = sql.upper().find("WHERE")
        group_idx = sql.upper().find("GROUP BY")
        order_idx = sql.upper().find("ORDER BY")
        limit_idx = sql.upper().find("LIMIT")

        # Find table name end position
        table_end = min(x for x in [where_idx, group_idx, order_idx, limit_idx] if x != -1) if any(x != -1 for x in [where_idx, group_idx, order_idx, limit_idx]) else len(sql)
        table_name = sql[from_idx + 4:table_end].strip()

        # Parse WHERE clause
        where_clause = None
        if where_idx != -1:
            where_end = min(x for x in [group_idx, order_idx, limit_idx] if x != -1) if any(x != -1 for x in [group_idx, order_idx, limit_idx]) else len(sql)
            where_clause = sql[where_idx + 5:where_end].strip()

        # Parse GROUP BY clause
        group_by = None
        if group_idx != -1:
            group_end = min(x for x in [order_idx, limit_idx] if x != -1) if any(x != -1 for x in [order_idx, limit_idx]) else len(sql)
            group_by_str = sql[group_idx + 8:group_end].strip()
            group_by = [col.strip() for col in group_by_str.split(",")]

        # Parse ORDER BY clause
        order_by = None
        if order_idx != -1:
            order_end = limit_idx if limit_idx != -1 else len(sql)
            order_str = sql[order_idx + 8:order_end].strip()
            order_parts = order_str.split(",")
            order_by = []
            for part in order_parts:
                part = part.strip()
                if " DESC" in part.upper():
                    column = part[:part.upper().find(" DESC")].strip()
                    direction = OrderDirection.DESC
                else:
                    column = part.replace(" ASC", "").strip()
                    direction = OrderDirection.ASC
                order_by.append(OrderByClause(column=column, direction=direction))

        # Parse LIMIT clause
        limit = None
        if limit_idx != -1:
            limit_str = sql[limit_idx + 5:].strip()
            try:
                limit = int(limit_str)
            except ValueError:
                raise ValueError(f"Invalid LIMIT value: {limit_str}")

        return SelectStatement(
            table_name=table_name,
            columns=columns,
            where_clause=where_clause,
            group_by=group_by,
            order_by=order_by,
            limit=limit
        )

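A quick illustration of the two parser entry points and the dataclasses they produce (note that quoted string literals keep their quotes at this stage; the executor strips typing later):

from pyflaredb.sql.parser import SQLParser

ins = SQLParser.parse_insert("INSERT INTO users (id, age) VALUES ('u1', 42)")
print(ins.table_name, ins.columns, ins.values)  # users ['id', 'age'] ["'u1'", 42]

sel = SQLParser.parse_select("SELECT id, age FROM users WHERE age > 30 ORDER BY age DESC LIMIT 5")
print(sel.table_name, sel.where_clause, sel.order_by, sel.limit)
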
38 pyflaredb/sql/statistics.py Normal file
@@ -0,0 +1,38 @@
from typing import Dict, Any
import numpy as np

from pyflaredb.table import Table


class TableStatistics:
    def __init__(self):
        self.table_sizes: Dict[str, int] = {}
        self.column_stats: Dict[str, Dict[str, Any]] = {}

    def collect_statistics(self, table: "Table"):
        """Collect statistics for a table"""
        self.table_sizes[table.name] = len(table.data)

        for column in table.columns:
            values = [row[column.name] for row in table.data if column.name in row]

            if not values:
                continue

            stats = {
                "distinct_count": len(set(values)),
                "null_count": sum(1 for v in values if v is None),
                "min": min(values) if values and None not in values else None,
                "max": max(values) if values and None not in values else None,
            }

            if isinstance(values[0], (int, float)):
                stats.update(
                    {
                        "mean": np.mean(values),
                        "std_dev": np.std(values),
                        "histogram": np.histogram(values, bins=100),
                    }
                )

            self.column_stats[f"{table.name}.{column.name}"] = stats

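Statistics are collected eagerly by PyFlareDB.create_table; they can also be refreshed by hand, as in this sketch:

from pyflaredb.table import Table, Column
from pyflaredb.sql.statistics import TableStatistics

t = Table("users", [Column("id", "string"), Column("age", "integer")])
t.insert({"id": "u1", "age": 30})
t.insert({"id": "u2", "age": 40})

stats = TableStatistics()
stats.collect_statistics(t)
print(stats.table_sizes["users"])               # 2
print(stats.column_stats["users.age"]["mean"])  # 35.0 (numeric columns also get std_dev and a histogram)
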
191 pyflaredb/table.py Normal file
@@ -0,0 +1,191 @@
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict
from .indexing.btree import BTreeIndex


@dataclass
class Column:
    name: str
    data_type: str
    nullable: bool = True
    unique: bool = False
    primary_key: bool = False
    default: Any = None


class Table:
    def __init__(self, name: str, columns: List[Column]):
        self.name = name
        self.columns = columns
        self.data: List[Dict[str, Any]] = []
        self._unique_indexes: Dict[str, Dict[Any, int]] = defaultdict(dict)
        self._compiled_conditions = {}
        self._indexes: Dict[str, BTreeIndex] = {}

        # Validate column definitions
        self._validate_columns()

    def _validate_columns(self):
        """Validate column definitions"""
        # Ensure only one primary key
        primary_keys = [col for col in self.columns if col.primary_key]
        if len(primary_keys) > 1:
            raise ValueError("Table can only have one primary key")

        # Validate data types
        valid_types = {"string", "integer", "float", "boolean", "datetime"}
        for col in self.columns:
            if col.data_type.lower() not in valid_types:
                raise ValueError(f"Invalid data type for column {col.name}: {col.data_type}")

    def create_index(self, column_name: str) -> None:
        """Create a B-tree index for a column"""
        if column_name not in {col.name for col in self.columns}:
            raise ValueError(f"Column {column_name} does not exist")

        # Create new index
        index = BTreeIndex()

        # Build index from existing data
        for row_id, row in enumerate(self.data):
            if column_name in row:
                index.insert(row[column_name], row_id)

        self._indexes[column_name] = index

    def batch_insert(self, rows: List[Dict[str, Any]]) -> bool:
        """Efficiently insert multiple rows with index updates"""
        # Pre-validate all rows
        validated_rows = []
        unique_values = defaultdict(set)

        # Check unique constraints across all new rows
        for row in rows:
            converted_row = {}
            # Validate required columns and defaults
            for column in self.columns:
                if not column.nullable and column.name not in row and column.default is None:
                    raise ValueError(f"Required column {column.name} is missing")

                value = row.get(column.name, column.default)

                # Type conversion
                if value is not None:
                    try:
                        if column.data_type == "integer":
                            value = int(value)
                        elif column.data_type == "float":
                            value = float(value)
                        elif column.data_type == "boolean":
                            value = bool(value)
                        else:  # string and datetime
                            value = str(value)
                    except (ValueError, TypeError):
                        raise ValueError(f"Invalid value for column {column.name}: {value}")

                converted_row[column.name] = value

                # Track unique values
                if column.unique and value is not None:
                    if value in unique_values[column.name] or value in self._unique_indexes[column.name]:
                        raise ValueError(f"Unique constraint violated for column {column.name}")
                    unique_values[column.name].add(value)

            validated_rows.append(converted_row)

        # All rows validated, perform batch insert
        start_id = len(self.data)
        for i, row in enumerate(validated_rows):
            row_id = start_id + i

            # Update indexes
            for column_name, index in self._indexes.items():
                if column_name in row:
                    index.insert(row[column_name], row_id)

            # Update unique indexes
            for column in self.columns:
                if column.unique:
                    value = row.get(column.name)
                    if value is not None:
                        self._unique_indexes[column.name][value] = row_id

            self.data.append(row)

        return True

    def insert(self, row: Dict[str, Any]) -> bool:
        """Insert a single row (now uses batch_insert)"""
        return self.batch_insert([row])

    def to_dict(self) -> dict:
        """Convert table to dictionary for serialization"""
        return {
            "name": self.name,
            "columns": [
                {
                    "name": col.name,
                    "data_type": col.data_type,
                    "nullable": col.nullable,
                    "unique": col.unique,
                    "primary_key": col.primary_key
                }
                for col in self.columns
            ],
            "data": self.data
        }

    @classmethod
    def from_dict(cls, data: dict) -> 'Table':
        """Create table from dictionary"""
        columns = [
            Column(**col_data)
            for col_data in data["columns"]
        ]
        table = cls(data["name"], columns)
        table.data = data["data"]
        return table

    def _validate_type(self, value: Any, expected_type: str) -> bool:
        """Validate that a value matches the expected data type"""
        type_mapping = {
            "string": str,
            "integer": int,
            "float": float,
            "boolean": bool,
            "datetime": datetime,
        }

        if expected_type not in type_mapping:
            raise ValueError(f"Unsupported data type: {expected_type}")

        expected_python_type = type_mapping[expected_type]

        if not isinstance(value, expected_python_type):
            try:
                # Attempt to convert the value
                expected_python_type(value)
            except (ValueError, TypeError):
                raise ValueError(
                    f"Value {value} is not of expected type {expected_type}"
                )

        return True

    def find_by_index(self, column_name: str, value: Any) -> List[Dict[str, Any]]:
        """Find rows using an index"""
        if column_name not in self._indexes:
            raise ValueError(f"No index exists for column {column_name}")

        index = self._indexes[column_name]
        row_ids = index.search(value)
        return [self.data[row_id] for row_id in row_ids]

    def range_search(self, column_name: str, start_value: Any, end_value: Any) -> List[Dict[str, Any]]:
        """Perform a range search using an index"""
        if column_name not in self._indexes:
            raise ValueError(f"No index exists for column {column_name}")

        index = self._indexes[column_name]
        row_ids = index.range_search(start_value, end_value)
        return [self.data[row_id] for row_id in row_ids]

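A short sketch of the Table API above: schema validation, a secondary index, and indexed lookups.

from pyflaredb.table import Table, Column

users = Table("users", [
    Column("id", "string", nullable=False, primary_key=True),
    Column("username", "string", unique=True),
    Column("age", "integer"),
])
users.batch_insert([
    {"id": "u1", "username": "ada", "age": 36},
    {"id": "u2", "username": "grace", "age": 45},
])
users.create_index("age")
print(users.find_by_index("age", 45))     # [{'id': 'u2', 'username': 'grace', 'age': 45}]
print(users.range_search("age", 30, 40))  # rows whose age falls in [30, 40]
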
81 pyflaredb/transaction.py Normal file
@@ -0,0 +1,81 @@
from typing import Dict, Any, Optional, Set, List
from enum import Enum
import time
import uuid
import threading


class TransactionState(Enum):
    ACTIVE = "active"
    COMMITTED = "committed"
    ROLLED_BACK = "rolled_back"


class Transaction:
    def __init__(self, tx_id: str):
        self.id = tx_id
        self.state = TransactionState.ACTIVE
        self.start_time = time.time()
        self.locks: Set[str] = set()  # Set of table names that are locked
        self.changes: List[Dict[str, Any]] = (
            []
        )  # List of changes made during transaction


class TransactionManager:
    def __init__(self):
        self.transactions: Dict[str, Transaction] = {}
        self.lock = threading.Lock()

    def begin_transaction(self) -> str:
        """Start a new transaction"""
        with self.lock:
            tx_id = str(uuid.uuid4())
            self.transactions[tx_id] = Transaction(tx_id)
            return tx_id

    def commit(self, tx_id: str) -> bool:
        """Commit a transaction"""
        with self.lock:
            if tx_id not in self.transactions:
                raise ValueError(f"Transaction {tx_id} not found")

            tx = self.transactions[tx_id]
            if tx.state != TransactionState.ACTIVE:
                raise ValueError(f"Transaction {tx_id} is not active")

            # Apply changes
            tx.state = TransactionState.COMMITTED

            # Release locks
            tx.locks.clear()

            return True

    def rollback(self, tx_id: str) -> bool:
        """Rollback a transaction"""
        with self.lock:
            if tx_id not in self.transactions:
                raise ValueError(f"Transaction {tx_id} not found")

            tx = self.transactions[tx_id]
            if tx.state != TransactionState.ACTIVE:
                raise ValueError(f"Transaction {tx_id} is not active")

            # Revert changes
            tx.state = TransactionState.ROLLED_BACK
            tx.changes.clear()

            # Release locks
            tx.locks.clear()

            return True

    def get_transaction(self, tx_id: str) -> Optional[Transaction]:
        """Get transaction by ID"""
        return self.transactions.get(tx_id)

    def is_active(self, tx_id: str) -> bool:
        """Check if a transaction is active"""
        tx = self.get_transaction(tx_id)
        return tx is not None and tx.state == TransactionState.ACTIVE

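Commit and rollback in this module only flip the transaction state and release locks; changes are tracked by the executor but not yet replayed or reverted. A usage sketch:

from pyflaredb.transaction import TransactionManager

tm = TransactionManager()
tx_id = tm.begin_transaction()
print(tm.is_active(tx_id))  # True
tm.commit(tx_id)
print(tm.is_active(tx_id))  # False
tm.rollback(tx_id)          # raises ValueError: the transaction is no longer active
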
63 pyflaredb/transaction/manager.py Normal file
@@ -0,0 +1,63 @@
from typing import Dict, List, Any
from enum import Enum
import threading
from datetime import datetime


class TransactionState(Enum):
    ACTIVE = "ACTIVE"
    COMMITTED = "COMMITTED"
    ROLLED_BACK = "ROLLED_BACK"


class Transaction:
    def __init__(self, id: str):
        self.id = id
        self.state = TransactionState.ACTIVE
        self.changes: List[Dict[str, Any]] = []
        self.locks = set()
        self.timestamp = datetime.utcnow()


class TransactionManager:
    def __init__(self):
        self.transactions: Dict[str, Transaction] = {}
        self.lock = threading.Lock()

    def begin_transaction(self) -> str:
        """Start a new transaction"""
        with self.lock:
            tx_id = str(len(self.transactions) + 1)
            self.transactions[tx_id] = Transaction(tx_id)
            return tx_id

    def commit(self, tx_id: str):
        """Commit a transaction"""
        with self.lock:
            if tx_id not in self.transactions:
                raise ValueError(f"Transaction {tx_id} not found")

            tx = self.transactions[tx_id]
            if tx.state != TransactionState.ACTIVE:
                raise ValueError(f"Transaction {tx_id} is not active")

            # Apply changes
            self._apply_changes(tx)
            tx.state = TransactionState.COMMITTED

    def rollback(self, tx_id: str):
        """Rollback a transaction"""
        with self.lock:
            if tx_id not in self.transactions:
                raise ValueError(f"Transaction {tx_id} not found")

            tx = self.transactions[tx_id]
            if tx.state != TransactionState.ACTIVE:
                raise ValueError(f"Transaction {tx_id} is not active")

            # Discard changes
            tx.changes.clear()
            tx.state = TransactionState.ROLLED_BACK

    def _apply_changes(self, transaction: Transaction):
        """Apply transaction changes"""
        for change in transaction.changes:
            # Implementation of applying changes
            pass

50 pyflaredb/versioning.py Normal file
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional


@dataclass
class Version:
    timestamp: datetime
    operation: str  # 'INSERT', 'UPDATE', 'DELETE'
    table_name: str
    row_id: str
    data: Dict[str, Any]
    previous_version: Optional[str] = None  # Hash of previous version


class VersionStore:
    def __init__(self):
        self.versions: List[Version] = []
        self.current_version: str = None  # Hash of current version

    def add_version(self, version: Version):
        """Add a new version to the store"""
        version_hash = self._calculate_hash(version)
        self.versions.append(version)
        self.current_version = version_hash

    def get_state_at(self, timestamp: datetime) -> Dict[str, List[Dict[str, Any]]]:
        """Reconstruct database state at given timestamp"""
        state = {}
        relevant_versions = [v for v in self.versions if v.timestamp <= timestamp]

        for version in relevant_versions:
            if version.table_name not in state:
                state[version.table_name] = []

            if version.operation == "INSERT":
                state[version.table_name].append(version.data)
            elif version.operation == "DELETE":
                state[version.table_name] = [
                    row
                    for row in state[version.table_name]
                    if row["id"] != version.row_id
                ]
            elif version.operation == "UPDATE":
                state[version.table_name] = [
                    version.data if row["id"] == version.row_id else row
                    for row in state[version.table_name]
                ]

        return state

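Note that add_version calls a _calculate_hash helper that does not appear in this commit, so the sketch below appends Version records directly and only exercises get_state_at:

from datetime import datetime
from pyflaredb.versioning import Version, VersionStore

store = VersionStore()
store.versions.append(Version(datetime(2024, 1, 1), "INSERT", "users", "u1", {"id": "u1", "age": 30}))
store.versions.append(Version(datetime(2024, 2, 1), "UPDATE", "users", "u1", {"id": "u1", "age": 31}))

print(store.get_state_at(datetime(2024, 1, 15)))  # {'users': [{'id': 'u1', 'age': 30}]}
print(store.get_state_at(datetime(2024, 3, 1)))   # {'users': [{'id': 'u1', 'age': 31}]}
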
254 test.py Normal file
@@ -0,0 +1,254 @@
from pyflaredb.core import PyFlareDB
from pyflaredb.table import Column, Table
from pyflaredb.benchmark.suite import BenchmarkSuite
import time
from datetime import datetime
import random
import string
import json
from typing import List, Dict, Any


def generate_realistic_data(n: int) -> List[Dict[str, Any]]:
    """Generate realistic test data"""
    domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com']
    cities = ['New York', 'London', 'Tokyo', 'Paris', 'Berlin', 'Sydney', 'Toronto']

    data = []
    for i in range(n):
        # Generate realistic username
        username = f"{random.choice(string.ascii_lowercase)}{random.choice(string.ascii_lowercase)}"
        username += ''.join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(6, 12)))

        # Generate realistic email
        email = f"{username}@{random.choice(domains)}"

        # Generate JSON metadata
        metadata = {
            "city": random.choice(cities),
            "last_login": f"2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}",
            "preferences": {
                "theme": random.choice(["light", "dark", "system"]),
                "notifications": random.choice([True, False])
            }
        }

        data.append({
            "id": f"usr_{i:08d}",
            "username": username,
            "email": email,
            "age": random.randint(18, 80),
            "score": round(random.uniform(0, 100), 2),
            "is_active": random.random() > 0.1,  # 90% active users
            "login_count": random.randint(1, 1000),
            "metadata": json.dumps(metadata)
        })

    return data


def format_value(value):
    """Format value based on its type"""
    if isinstance(value, (float, int)):
        return f"{value:.4f}"
    return str(value)


def test_database_features():
    """Test all database features with realistic workloads"""
    print("\n=== Starting Realistic Database Tests ===")

    # Initialize database
    db = PyFlareDB("test.db")

    # 1. Create test table with realistic schema
    print("\n1. Setting up test environment...")
    users_table = Table(
        name="users",
        columns=[
            Column("id", "string", nullable=False, primary_key=True),
            Column("username", "string", nullable=False, unique=True),
            Column("email", "string", nullable=False),
            Column("age", "integer", nullable=True),
            Column("score", "float", nullable=True),
            Column("is_active", "boolean", nullable=True, default=True),
            Column("login_count", "integer", nullable=True, default=0),
            Column("metadata", "string", nullable=True)  # JSON string
        ],
    )
    db.tables["users"] = users_table

    # Create indexes for commonly queried fields
    users_table.create_index("age")
    users_table.create_index("score")
    users_table.create_index("login_count")

    # 2. Performance Tests with Realistic Data
    print("\n2. Running performance tests...")

    # Generate test data
    test_data = generate_realistic_data(1000)  # 1000 realistic records

    # Insert Performance (Single vs Batch)
    print("\nInsert Performance:")

    # Single Insert (OLTP-style)
    start_time = time.time()
    for record in test_data[:100]:  # Test with first 100 records
        # Properly escape the metadata string
        metadata_str = record['metadata'].replace("'", "''")

        # Format each value according to its type
        values = [
            f"'{record['id']}'",                # string
            f"'{record['username']}'",          # string
            f"'{record['email']}'",             # string
            str(record['age']),                 # integer
            str(record['score']),               # float
            str(record['is_active']).lower(),   # boolean
            str(record['login_count']),         # integer
            f"'{metadata_str}'"                 # string (JSON)
        ]

        query = f"""
        INSERT INTO users
        (id, username, email, age, score, is_active, login_count, metadata)
        VALUES
        ({', '.join(values)})
        """
        db.execute(query)
    single_insert_time = time.time() - start_time
    print(f"Single Insert (100 records, OLTP): {single_insert_time:.4f}s")

    # Batch Insert (OLAP-style)
    start_time = time.time()
    batch_data = test_data[100:200]  # Next 100 records
    users_table.batch_insert(batch_data)  # This should work as is
    batch_insert_time = time.time() - start_time
    print(f"Batch Insert (100 records, OLAP): {batch_insert_time:.4f}s")

    # 3. Query Performance Tests
    print("\nQuery Performance (OLTP vs OLAP):")

    # OLTP-style queries (point queries, simple filters)
    oltp_queries = [
        ("Single Record Lookup", "SELECT * FROM users WHERE id = 'usr_00000001'"),
        ("Simple Range Query", "SELECT * FROM users WHERE age > 30 LIMIT 10"),
        ("Active Users Count", "SELECT COUNT(*) FROM users WHERE is_active = true"),
        ("Recent Logins", "SELECT * FROM users WHERE login_count > 500 ORDER BY login_count DESC LIMIT 5")
    ]

    # OLAP-style queries (aggregations, complex filters)
    olap_queries = [
        ("Age Distribution", """
            SELECT
                CASE
                    WHEN age < 25 THEN 'Gen Z'
                    WHEN age < 40 THEN 'Millennial'
                    WHEN age < 55 THEN 'Gen X'
                    ELSE 'Boomer'
                END as generation,
                COUNT(*) as count
            FROM users
            GROUP BY generation
        """),
        ("User Engagement", """
            SELECT
                username,
                score,
                login_count
            FROM users
            WHERE score > 75
            AND login_count > 100
            ORDER BY score DESC
            LIMIT 10
        """),
        ("Complex Analytics", """
            SELECT
                COUNT(*) as total_users,
                AVG(score) as avg_score,
                SUM(CASE WHEN is_active THEN 1 ELSE 0 END) as active_users
            FROM users
            WHERE age BETWEEN 25 AND 45
        """)
    ]

    print("\nOLTP Query Performance:")
    for query_name, query in oltp_queries:
        # First run (cold)
        start_time = time.time()
        db.execute(query)
        cold_time = time.time() - start_time

        # Second run (warm/cached)
        start_time = time.time()
        db.execute(query)
        warm_time = time.time() - start_time

        print(f"\n{query_name}:")
        print(f"  Cold run: {cold_time:.4f}s")
        print(f"  Warm run: {warm_time:.4f}s")
        print(f"  Cache improvement: {((cold_time - warm_time) / cold_time * 100):.1f}%")

    print("\nOLAP Query Performance:")
    for query_name, query in olap_queries:
        start_time = time.time()
        db.execute(query)
        execution_time = time.time() - start_time
        print(f"\n{query_name}: {execution_time:.4f}s")

    # 4. Concurrent Operations Test
    print("\nConcurrent Operations Simulation:")
    start_time = time.time()
    # Simulate mixed workload
    for _ in range(100):
        if random.random() < 0.8:  # 80% reads
            query = random.choice(oltp_queries)[1]
        else:  # 20% writes
            record = generate_realistic_data(1)[0]
            query = f"""
            INSERT INTO users (id, username, email, age, score, is_active, login_count, metadata)
            VALUES (
                '{record['id']}',
                '{record['username']}',
                '{record['email']}',
                {record['age']},
                {record['score']},
                {str(record['is_active']).lower()},
                {record['login_count']},
                '{record['metadata']}'
            )
            """
        db.execute(query)
    mixed_workload_time = time.time() - start_time
    print(f"Mixed Workload (100 operations): {mixed_workload_time:.4f}s")

    # 5. Memory Usage Test
    print("\nMemory Usage:")
    import sys
    memory_size = sys.getsizeof(db.tables["users"].data) / 1024  # KB
    records_count = len(db.tables["users"].data)
    print(f"Memory per record: {(memory_size / records_count):.2f} KB")

    # 6. Run standard benchmark suite
    print("\n6. Running standard benchmark suite...")
    benchmark = BenchmarkSuite(db)
    results = benchmark.run_benchmark(num_records=10000)

    print("\nBenchmark Results:")
    for test_name, metrics in results.items():
        print(f"\n{test_name.upper()}:")
        for metric, value in metrics.items():
            print(f"  {metric}: {format_value(value)}")


def main():
    try:
        test_database_features()
    except Exception as e:
        print(f"Test failed: {e}")
        raise e


if __name__ == "__main__":
    main()