First commit

tcsenpai 2024-11-25 14:38:35 +01:00
commit d2cea458d4
16 changed files with 1580 additions and 0 deletions

2
.gitignore vendored Normal file

@@ -0,0 +1,2 @@
__pycache__
*.db

5
pyflaredb/__init__.py Normal file

@@ -0,0 +1,5 @@
from .core import PyFlareDB
from .table import Table, Column
from .versioning import Version, VersionStore
__all__ = ['PyFlareDB', 'Table', 'Column', 'Version', 'VersionStore']

125
pyflaredb/benchmark/suite.py Normal file

@@ -0,0 +1,125 @@
import time
from typing import Dict, Any
import random
import string
from ..core import PyFlareDB
class BenchmarkSuite:
def __init__(self, db: PyFlareDB):
self.db = db
def run_benchmark(self, num_records: int = 10000):
"""Run comprehensive benchmark"""
results = {
"insert": self._benchmark_insert(num_records),
"select": self._benchmark_select(num_records),
"index": self._benchmark_index_performance(num_records),
"complex_query": self._benchmark_complex_queries(num_records),
}
return results
def _benchmark_insert(self, num_records: int) -> Dict[str, float]:
start_time = time.time()
batch_times = []
for i in range(0, num_records, 1000):
batch_start = time.time()
self._insert_batch(min(1000, num_records - i))
batch_times.append(time.time() - batch_start)
total_time = time.time() - start_time
return {
"total_time": total_time,
"records_per_second": num_records / total_time,
"avg_batch_time": sum(batch_times) / len(batch_times),
}
def _insert_batch(self, size: int):
"""Insert a batch of random records"""
tx_id = None
try:
tx_id = self.db.transaction_manager.begin_transaction()
for _ in range(size):
query = (
"INSERT INTO users (id, username, email, age) "
f"VALUES ('{self._random_string(10)}', "
f"'{self._random_string(8)}', "
f"'{self._random_string(8)}@example.com', "
f"{random.randint(18, 80)})"
)
try:
self.db.execute(query)
except Exception as e:
print(f"Failed to insert record: {e}")
print(f"Query was: {query}")
raise
self.db.transaction_manager.commit(tx_id)
except Exception as e:
if tx_id is not None:
try:
self.db.transaction_manager.rollback(tx_id)
except ValueError:
pass
raise
def _benchmark_select(self, num_records: int) -> Dict[str, float]:
"""Benchmark SELECT queries"""
queries = [
"SELECT * FROM users WHERE age > 30",
"SELECT username, email FROM users WHERE age < 25",
"SELECT COUNT(*) FROM users",
]
results = {}
for query in queries:
start_time = time.time()
try:
self.db.execute(query)
query_time = time.time() - start_time
results[query] = query_time
except Exception as e:
results[query] = f"Error: {e}"
return results
def _benchmark_index_performance(self, num_records: int) -> Dict[str, float]:
"""Benchmark index performance"""
# TODO: Implement index benchmarking
return {"index_creation_time": 0.0}
def _benchmark_complex_queries(self, num_records: int) -> Dict[str, float]:
"""Benchmark complex queries"""
complex_queries = [
"""
SELECT username, COUNT(*) as count
FROM users
GROUP BY username
""",
"""
SELECT * FROM users
WHERE age > 30
ORDER BY username DESC
LIMIT 10
""",
]
results = {}
for query in complex_queries:
start_time = time.time()
try:
self.db.execute(query)
query_time = time.time() - start_time
results[query.strip()] = query_time
except Exception as e:
results[query.strip()] = f"Error: {e}"
return results
@staticmethod
def _random_string(length: int) -> str:
"""Generate a random string of specified length"""
return "".join(random.choices(string.ascii_letters + string.digits, k=length))

39
pyflaredb/cache/query_cache.py vendored Normal file

@@ -0,0 +1,39 @@
from typing import Any, Optional
import time
import hashlib
from collections import OrderedDict
class QueryCache:
def __init__(self, capacity: int = 1000, ttl: int = 300):
self.capacity = capacity
self.ttl = ttl
self.cache = OrderedDict()
def get(self, query: str) -> Optional[Any]:
"""Get cached query result"""
query_hash = self._hash_query(query)
if query_hash in self.cache:
entry = self.cache[query_hash]
if time.time() - entry['timestamp'] < self.ttl:
self.cache.move_to_end(query_hash)
return entry['result']
else:
del self.cache[query_hash]
return None
def set(self, query: str, result: Any):
"""Cache query result"""
if len(self.cache) >= self.capacity:
self.cache.popitem(last=False)
query_hash = self._hash_query(query)
self.cache[query_hash] = {
'result': result,
'timestamp': time.time()
}
def _hash_query(self, query: str) -> str:
return hashlib.sha256(query.encode()).hexdigest()
def clear(self):
self.cache.clear()
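
Example (illustrative, not in the commit): a short sketch of the LRU-plus-TTL behaviour.

from pyflaredb.cache.query_cache import QueryCache

cache = QueryCache(capacity=2, ttl=60)
cache.set("SELECT * FROM users", [{"id": "u1"}])
print(cache.get("SELECT * FROM users"))  # hit: [{'id': 'u1'}]
cache.set("q2", [])
cache.set("q3", [])  # at capacity: evicts the least recently used entry
print(cache.get("SELECT * FROM users"))  # None (evicted)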

90
pyflaredb/core.py Normal file

@@ -0,0 +1,90 @@
from typing import Dict, List, Any, Optional
from .table import Table
from .sql.parser import SQLParser, SelectStatement, InsertStatement
from .sql.executor import QueryExecutor
from .sql.optimizer import QueryOptimizer
from .sql.statistics import TableStatistics
from .transaction import TransactionManager
class PyFlareDB:
def __init__(self, db_path: str):
"""Initialize the database"""
self.db_path = db_path
self.tables: Dict[str, Table] = {}
self.parser = SQLParser()
self.statistics = TableStatistics()
self.optimizer = QueryOptimizer(self.tables, self.statistics)
self.executor = QueryExecutor(self.tables)
self.transaction_manager = TransactionManager()
self._query_cache = {}
def begin_transaction(self) -> str:
"""Begin a new transaction"""
return self.transaction_manager.begin_transaction()
def commit_transaction(self, tx_id: str) -> bool:
"""Commit a transaction"""
return self.transaction_manager.commit(tx_id)
def rollback_transaction(self, tx_id: str) -> bool:
"""Rollback a transaction"""
return self.transaction_manager.rollback(tx_id)
def create_table(self, table: Table) -> None:
"""Create a new table"""
if table.name in self.tables:
raise ValueError(f"Table {table.name} already exists")
self.tables[table.name] = table
self.statistics.collect_statistics(table)
def drop_table(self, table_name: str) -> None:
"""Drop a table"""
if table_name not in self.tables:
raise ValueError(f"Table {table_name} does not exist")
del self.tables[table_name]
def execute(
self, sql: str, tx_id: Optional[str] = None
) -> Optional[List[Dict[str, Any]]]:
"""Execute a SQL query"""
try:
            # Check the query cache (only non-transactional SELECT results are ever cached)
            if tx_id is None and sql in self._query_cache:
                return self._query_cache[sql]
# Parse SQL
if sql.strip().upper().startswith("SELECT"):
statement = self.parser.parse_select(sql)
elif sql.strip().upper().startswith("INSERT"):
statement = self.parser.parse_insert(sql)
else:
raise ValueError("Unsupported SQL statement type")
# Get transaction if provided
tx = None
if tx_id:
tx = self.transaction_manager.get_transaction(tx_id)
if not tx:
raise ValueError(f"Transaction {tx_id} does not exist")
# Optimize query plan
optimized_plan = self.optimizer.optimize(statement)
# Execute query
result = self.executor.execute(optimized_plan, transaction=tx)
            # Cache SELECT results for non-transactional queries;
            # writes invalidate any cached reads
            if tx_id is None and isinstance(statement, SelectStatement):
                self._query_cache[sql] = result
            elif isinstance(statement, InsertStatement):
                self._query_cache.clear()
            return result
        except Exception:
            # Clear cache on error
            self._query_cache.clear()
            raise
def clear_cache(self) -> None:
"""Clear the query cache"""
self._query_cache.clear()
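
Example (illustrative, not in the commit): an end-to-end sketch of the execute() path.

from pyflaredb.core import PyFlareDB
from pyflaredb.table import Table, Column

db = PyFlareDB("demo.db")
db.create_table(Table("users", [
    Column("id", "string", nullable=False, primary_key=True),
    Column("age", "integer"),
]))
db.execute("INSERT INTO users (id, age) VALUES ('u1', 42)")
print(db.execute("SELECT * FROM users WHERE age > 40"))  # [{'id': 'u1', 'age': 42}]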

137
pyflaredb/indexing/btree.py Normal file

@@ -0,0 +1,137 @@
from typing import Any, List
from dataclasses import dataclass
@dataclass
class Node:
keys: List[Any]
values: List[List[int]] # List of row IDs for each key (handling duplicates)
children: List['Node']
is_leaf: bool = True
class BTreeIndex:
def __init__(self, order: int = 100):
self.root = Node([], [], [])
self.order = order # Maximum number of children per node
def insert(self, key: Any, row_id: int) -> None:
"""Insert a key-value pair into the B-tree"""
if len(self.root.keys) == (2 * self.order) - 1:
# Split root if full
new_root = Node([], [], [], False)
new_root.children.append(self.root)
self._split_child(new_root, 0)
self.root = new_root
self._insert_non_full(self.root, key, row_id)
def search(self, key: Any) -> List[int]:
"""Search for a key and return all matching row IDs"""
return self._search_node(self.root, key)
def range_search(self, start_key: Any, end_key: Any) -> List[int]:
"""Perform a range search and return all matching row IDs"""
result = []
self._range_search_node(self.root, start_key, end_key, result)
return result
def _split_child(self, parent: Node, child_index: int) -> None:
"""Split a full child node"""
order = self.order
child = parent.children[child_index]
new_node = Node([], [], [], child.is_leaf)
# Move the median key to the parent
median = order - 1
parent.keys.insert(child_index, child.keys[median])
parent.values.insert(child_index, child.values[median])
parent.children.insert(child_index + 1, new_node)
# Move half of the keys to the new node
new_node.keys = child.keys[median + 1:]
new_node.values = child.values[median + 1:]
child.keys = child.keys[:median]
child.values = child.values[:median]
# Move children if not a leaf
if not child.is_leaf:
new_node.children = child.children[median + 1:]
child.children = child.children[:median + 1]
def _insert_non_full(self, node: Node, key: Any, row_id: int) -> None:
"""Insert into a non-full node"""
i = len(node.keys) - 1
if node.is_leaf:
# Insert into leaf node
while i >= 0 and self._compare_keys(key, node.keys[i]) < 0:
i -= 1
i += 1
# Handle duplicate keys
if i > 0 and self._compare_keys(key, node.keys[i-1]) == 0:
node.values[i-1].append(row_id)
else:
node.keys.insert(i, key)
node.values.insert(i, [row_id])
else:
# Find the child to insert into
while i >= 0 and self._compare_keys(key, node.keys[i]) < 0:
i -= 1
i += 1
if len(node.children[i].keys) == (2 * self.order) - 1:
self._split_child(node, i)
if self._compare_keys(key, node.keys[i]) > 0:
i += 1
self._insert_non_full(node.children[i], key, row_id)
def _search_node(self, node: Node, key: Any) -> List[int]:
"""Search for a key in a node"""
i = 0
while i < len(node.keys) and self._compare_keys(key, node.keys[i]) > 0:
i += 1
if i < len(node.keys) and self._compare_keys(key, node.keys[i]) == 0:
return node.values[i]
elif node.is_leaf:
return []
else:
return self._search_node(node.children[i], key)
    def _range_search_node(self, node: Node, start_key: Any, end_key: Any, result: List[int]) -> None:
        """Perform range search on a node; a None bound means unbounded"""
        i = 0
        while i < len(node.keys) and start_key is not None and self._compare_keys(start_key, node.keys[i]) > 0:
            i += 1
        if node.is_leaf:
            while i < len(node.keys) and (end_key is None or self._compare_keys(node.keys[i], end_key) <= 0):
                result.extend(node.values[i])
                i += 1
        else:
            # children[i] holds keys smaller than keys[i]; visit it first,
            # then every in-range key and the child to its right
            self._range_search_node(node.children[i], start_key, end_key, result)
            while i < len(node.keys) and (end_key is None or self._compare_keys(node.keys[i], end_key) <= 0):
                result.extend(node.values[i])
                self._range_search_node(node.children[i + 1], start_key, end_key, result)
                i += 1
@staticmethod
def _compare_keys(key1: Any, key2: Any) -> int:
"""Compare two keys, handling different types"""
if key1 is None or key2 is None:
if key1 is None and key2 is None:
return 0
return -1 if key1 is None else 1
try:
if key1 < key2:
return -1
elif key1 > key2:
return 1
return 0
except TypeError:
# Handle incomparable types
return 0
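
Example (illustrative, not in the commit): a standalone sketch of the index; duplicate keys share one entry, and range bounds are inclusive.

from pyflaredb.indexing.btree import BTreeIndex

idx = BTreeIndex(order=4)
for row_id, age in enumerate([31, 27, 45, 27, 60]):
    idx.insert(age, row_id)
print(idx.search(27))            # [1, 3] - both rows with key 27
print(idx.range_search(30, 50))  # [0, 2] - row ids with 30 <= key <= 50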


@@ -0,0 +1,33 @@
from typing import Dict
import time
from collections import deque
import threading
class PerformanceMetrics:
def __init__(self, window_size: int = 1000):
self.window_size = window_size
self.query_times: Dict[str, deque] = {}
self.lock = threading.Lock()
def record_query(self, query_type: str, execution_time: float):
"""Record query execution time"""
with self.lock:
if query_type not in self.query_times:
self.query_times[query_type] = deque(maxlen=self.window_size)
self.query_times[query_type].append(execution_time)
def get_metrics(self) -> Dict[str, Dict[str, float]]:
"""Get performance metrics"""
metrics = {}
with self.lock:
for query_type, times in self.query_times.items():
if not times:
continue
metrics[query_type] = {
"avg_time": sum(times) / len(times),
"max_time": max(times),
"min_time": min(times),
"count": len(times),
}
return metrics
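
Example (illustrative, not in the commit): a usage sketch. The import path is a guess, since this hunk's file header is missing from the page.

from time import perf_counter
from pyflaredb.metrics import PerformanceMetrics  # hypothetical module path

metrics = PerformanceMetrics(window_size=100)
start = perf_counter()
# ... execute a query here ...
metrics.record_query("select", perf_counter() - start)
print(metrics.get_metrics())  # {'select': {'avg_time': ..., 'max_time': ..., 'min_time': ..., 'count': 1}}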

245
pyflaredb/sql/executor.py Normal file

@@ -0,0 +1,245 @@
from typing import List, Dict, Any, Tuple, Optional
import operator
from ..table import Table
from .parser import SelectStatement, InsertStatement
from ..transaction import Transaction
class QueryExecutor:
def __init__(self, tables: Dict[str, Table]):
self.tables = tables
self._compiled_conditions = {}
self._comparison_ops = {
'>': operator.gt,
'<': operator.lt,
'>=': operator.ge,
'<=': operator.le,
'=': operator.eq,
'!=': operator.ne
}
    def _parse_where_clause(self, where_clause: str) -> List[Tuple[str, str, str]]:
        """Parse WHERE clause into list of (field, operator, value) tuples"""
        conditions = []
        # Split on AND if present
        subclauses = [c.strip() for c in where_clause.split(' AND ')]
        for subclause in subclauses:
            # Find the operator; check two-character operators first so
            # '!=' and '>=' are not mistaken for '=' or '>'
            operator_found = None
            for op in ['>=', '<=', '!=', '>', '<', '=']:
                if op in subclause:
                    operator_found = op
                    field, value = subclause.split(op, 1)
                    value = value.strip()
                    # Strip quotes from string literals so they compare
                    # against the unquoted values stored in rows
                    if value.startswith(("'", '"')) and value.endswith(("'", '"')):
                        value = value[1:-1]
                    conditions.append((field.strip(), op, value))
                    break
            if not operator_found:
                raise ValueError(f"Invalid condition: {subclause}")
        return conditions
def execute(self, statement, transaction: Optional[Transaction] = None):
"""Execute a parsed SQL statement"""
if isinstance(statement, SelectStatement):
return self._execute_select(statement, transaction)
elif isinstance(statement, InsertStatement):
return self._execute_insert(statement, transaction)
elif statement is None:
raise ValueError("No statement to execute")
else:
raise ValueError(f"Unsupported statement type: {type(statement)}")
def _execute_select(self, stmt: SelectStatement, transaction: Optional[Transaction] = None) -> List[Dict[str, Any]]:
if stmt.table_name not in self.tables:
raise ValueError(f"Table {stmt.table_name} does not exist")
table = self.tables[stmt.table_name]
# If in transaction, check for locks
if transaction and table.name in transaction.locks:
# Handle transaction isolation level logic here
pass
        # Handle COUNT(*) separately, honouring any WHERE clause
        if len(stmt.columns) == 1 and stmt.columns[0].lower() == "count(*)":
            if not stmt.where_clause:
                return [{"count": len(table.data)}]
            try:
                conditions = self._parse_where_clause(stmt.where_clause)
            except ValueError:
                return [{"count": 0}]
            matches = sum(1 for row in table.data if self._matches_all_conditions(row, conditions))
            return [{"count": matches}]
# Try to use index for WHERE clause
if stmt.where_clause:
try:
conditions = self._parse_where_clause(stmt.where_clause)
# Check if we can use an index for any condition
for field, op, value in conditions:
if field in table._indexes:
# Convert value to proper type
column = next((col for col in table.columns if col.name == field), None)
if column:
try:
if column.data_type == "integer":
value = int(value)
elif column.data_type == "float":
value = float(value)
except (ValueError, TypeError):
continue
# Use index for lookup
if op == '=':
results = table.find_by_index(field, value)
elif op in {'>', '>='}:
results = table.range_search(field, value, None)
elif op in {'<', '<='}:
results = table.range_search(field, None, value)
else: # op == '!='
# For inequality, we still need to scan
results = table.data
# Apply remaining conditions
filtered_results = []
for row in results:
if self._matches_all_conditions(row, conditions):
filtered_results.append(row)
return self._process_results(filtered_results, stmt)
except ValueError:
# If WHERE clause parsing fails, fall back to table scan
pass
# Fall back to full table scan
return self._table_scan(table, stmt)
    def _matches_all_conditions(self, row: Dict[str, Any], conditions: List[Tuple[str, str, str]]) -> bool:
        """Check if row matches all conditions"""
        for field, op, value in conditions:
            row_value = row.get(field)
            if row_value is None:
                return False
            # Convert value to the row value's type; check bool before int,
            # since bool is a subclass of int in Python
            try:
                if isinstance(row_value, bool):
                    if isinstance(value, str):
                        value = value.lower() == 'true'
                elif isinstance(row_value, int):
                    value = int(value)
                elif isinstance(row_value, float):
                    value = float(value)
            except (ValueError, TypeError):
                return False
            # Apply comparison
            op_func = self._comparison_ops[op]
            try:
                if not op_func(row_value, value):
                    return False
            except TypeError:
                return False
        return True
    def _table_scan(self, table: Table, stmt: SelectStatement) -> List[Dict[str, Any]]:
        """Perform a full table scan with filtering"""
        # Parse WHERE conditions if present
        conditions = []
        if stmt.where_clause:
            try:
                conditions = self._parse_where_clause(stmt.where_clause)
            except ValueError:
                # If parsing fails, return empty result
                return []
        # Collect matching rows; column projection is handled by _process_results
        results = [
            row for row in table.data
            if not conditions or self._matches_all_conditions(row, conditions)
        ]
        return self._process_results(results, stmt)
def _process_results(self, rows: List[Dict[str, Any]], stmt: SelectStatement) -> List[Dict[str, Any]]:
"""Process result rows according to SELECT statement"""
results = []
for row in rows:
if "*" in stmt.columns:
results.append(row.copy())
else:
filtered_row = {}
for col in stmt.columns:
if "count(" in col.lower():
filtered_row[col] = len(results)
else:
filtered_row[col] = row.get(col)
results.append(filtered_row)
        # Handle ORDER BY (apply clauses in reverse so the first listed
        # column becomes the primary key of Python's stable sort)
        if stmt.order_by:
            for order_clause in reversed(stmt.order_by):
                reverse = order_clause.direction.value == "DESC"
                results.sort(
                    key=lambda x: (x.get(order_clause.column) is None, x.get(order_clause.column)),
                    reverse=reverse
                )
# Handle LIMIT
if stmt.limit is not None:
results = results[:stmt.limit]
return results
def _execute_insert(self, stmt: InsertStatement, transaction: Optional[Transaction] = None) -> bool:
if stmt.table_name not in self.tables:
raise ValueError(f"Table {stmt.table_name} does not exist")
table = self.tables[stmt.table_name]
# If in transaction, acquire lock and track changes
if transaction:
transaction.locks.add(table.name)
# Track the changes for potential rollback
transaction.changes.append({
'type': 'INSERT',
'table': table.name,
'data': dict(zip(stmt.columns, stmt.values))
})
# Create dictionary of column-value pairs
row_data = {}
for col_name, value in zip(stmt.columns, stmt.values):
# Find the column definition
column = next((col for col in table.columns if col.name == col_name), None)
if not column:
raise ValueError(f"Column {col_name} does not exist")
# Convert value based on column type
if value is not None:
try:
if column.data_type == "integer":
row_data[col_name] = int(value)
elif column.data_type == "float":
row_data[col_name] = float(value)
elif column.data_type == "boolean":
if isinstance(value, str):
row_data[col_name] = value.lower() == 'true'
else:
row_data[col_name] = bool(value)
else: # string type
row_data[col_name] = str(value)
except (ValueError, TypeError):
raise ValueError(f"Invalid value for column {column.name}: {value}")
else:
row_data[col_name] = None
# Insert the data
return table.insert(row_data)
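
Example (illustrative, not in the commit): driving the executor directly, with parser and table wired up by hand.

from pyflaredb.sql.executor import QueryExecutor
from pyflaredb.sql.parser import SQLParser
from pyflaredb.table import Table, Column

table = Table("users", [Column("id", "string"), Column("age", "integer")])
table.insert({"id": "u1", "age": 42})
executor = QueryExecutor({"users": table})
stmt = SQLParser.parse_select("SELECT id FROM users WHERE age >= 40 AND age <= 50")
print(executor.execute(stmt))  # [{'id': 'u1'}]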

51
pyflaredb/sql/optimizer.py Normal file

@@ -0,0 +1,51 @@
from typing import List, Dict, Any, Optional, Union
from dataclasses import dataclass
from enum import Enum
from .parser import SelectStatement, InsertStatement
from pyflaredb.sql.statistics import TableStatistics
from pyflaredb.table import Table
class JoinStrategy(Enum):
NESTED_LOOP = "nested_loop"
HASH_JOIN = "hash_join"
MERGE_JOIN = "merge_join"
class ScanType(Enum):
SEQUENTIAL = "sequential"
INDEX = "index"
@dataclass
class QueryPlan:
operation: str
strategy: Union[JoinStrategy, ScanType]
estimated_cost: float
    children: Optional[List["QueryPlan"]] = None
class QueryOptimizer:
def __init__(self, tables: Dict[str, "Table"], statistics: "TableStatistics"):
self.tables = tables
self.statistics = statistics
def optimize(self, statement) -> Any:
"""Generate an optimized query plan"""
if isinstance(statement, SelectStatement):
return self._optimize_select(statement)
elif isinstance(statement, InsertStatement):
return statement # No optimization needed for simple inserts
return statement # Return original statement if no optimization is needed
def _optimize_select(self, stmt: SelectStatement) -> SelectStatement:
"""Optimize SELECT query execution"""
# For now, return the original statement
# TODO: Implement actual optimization strategies
return stmt
def _estimate_cost(self, plan: QueryPlan) -> float:
"""Estimate the cost of a query plan"""
# Implementation for cost estimation
pass

176
pyflaredb/sql/parser.py Normal file

@@ -0,0 +1,176 @@
from dataclasses import dataclass
from typing import List, Optional, Any
from enum import Enum
class OrderDirection(Enum):
ASC = "ASC"
DESC = "DESC"
@dataclass
class OrderByClause:
column: str
direction: OrderDirection = OrderDirection.ASC
@dataclass
class SelectStatement:
table_name: str
columns: List[str]
where_clause: Optional[str] = None
group_by: Optional[List[str]] = None
order_by: Optional[List[OrderByClause]] = None
limit: Optional[int] = None
@dataclass
class InsertStatement:
table_name: str
columns: List[str]
values: List[Any]
class SQLParser:
@staticmethod
def parse_insert(sql: str) -> InsertStatement:
"""Parse INSERT statement"""
        # Remove newlines and extra spaces
        sql = ' '.join(sql.split())
        upper_sql = sql.upper()
        # Extract table name (keyword search is case-insensitive)
        table_start = upper_sql.find("INTO") + 4
        table_end = sql.find("(", table_start)
        table_name = sql[table_start:table_end].strip()
        # Extract columns
        cols_start = sql.find("(", table_end) + 1
        cols_end = sql.find(")", cols_start)
        columns = [col.strip() for col in sql[cols_start:cols_end].split(",")]
        # Extract values
        values_start = upper_sql.find("VALUES", cols_end) + 6
        values_start = sql.find("(", values_start) + 1
        values_end = sql.find(")", values_start)
        values_str = sql[values_start:values_end]
# Parse values while respecting quotes
values = []
current_value = ""
in_quotes = False
quote_char = None
for char in values_str:
if char in ["'", '"']:
if not in_quotes:
in_quotes = True
quote_char = char
elif quote_char == char:
in_quotes = False
quote_char = None
current_value += char
elif char == ',' and not in_quotes:
values.append(current_value.strip())
current_value = ""
else:
current_value += char
if current_value:
values.append(current_value.strip())
        # Clean up values
        cleaned_values = []
        for value in values:
            value = value.strip()
            if value.startswith(("'", '"')) and value.endswith(("'", '"')):
                # String literal - strip the surrounding quotes so stored
                # values match rows inserted directly as dicts
                cleaned_values.append(value[1:-1])
elif value.lower() == 'true':
cleaned_values.append(True)
elif value.lower() == 'false':
cleaned_values.append(False)
elif value.lower() == 'null':
cleaned_values.append(None)
else:
try:
# Try to convert to number if possible
if '.' in value:
cleaned_values.append(float(value))
else:
cleaned_values.append(int(value))
except ValueError:
# If not a number, keep as is
cleaned_values.append(value)
if len(columns) != len(cleaned_values):
raise ValueError(f"Column count ({len(columns)}) doesn't match value count ({len(cleaned_values)})")
return InsertStatement(table_name=table_name, columns=columns, values=cleaned_values)
@staticmethod
def parse_select(sql: str) -> SelectStatement:
"""Parse SELECT statement"""
# Remove newlines and extra spaces
sql = ' '.join(sql.split())
# Extract table name
from_idx = sql.upper().find("FROM")
if from_idx == -1:
raise ValueError("Invalid SELECT statement: missing FROM clause")
# Extract columns
columns_str = sql[6:from_idx].strip()
columns = [col.strip() for col in columns_str.split(",")]
# Find all clause positions
where_idx = sql.upper().find("WHERE")
group_idx = sql.upper().find("GROUP BY")
order_idx = sql.upper().find("ORDER BY")
limit_idx = sql.upper().find("LIMIT")
        # Find table name end position
        clause_positions = [idx for idx in (where_idx, group_idx, order_idx, limit_idx) if idx != -1]
        table_end = min(clause_positions) if clause_positions else len(sql)
        table_name = sql[from_idx + 4:table_end].strip()
        # Parse WHERE clause
        where_clause = None
        if where_idx != -1:
            later = [idx for idx in (group_idx, order_idx, limit_idx) if idx != -1]
            where_end = min(later) if later else len(sql)
            where_clause = sql[where_idx + 5:where_end].strip()
        # Parse GROUP BY clause
        group_by = None
        if group_idx != -1:
            later = [idx for idx in (order_idx, limit_idx) if idx != -1]
            group_end = min(later) if later else len(sql)
            group_by_str = sql[group_idx + 8:group_end].strip()
            group_by = [col.strip() for col in group_by_str.split(",")]
# Parse ORDER BY clause
order_by = None
if order_idx != -1:
order_end = limit_idx if limit_idx != -1 else len(sql)
order_str = sql[order_idx + 8:order_end].strip()
order_parts = order_str.split(",")
order_by = []
for part in order_parts:
part = part.strip()
if " DESC" in part.upper():
column = part[:part.upper().find(" DESC")].strip()
direction = OrderDirection.DESC
else:
column = part.replace(" ASC", "").strip()
direction = OrderDirection.ASC
order_by.append(OrderByClause(column=column, direction=direction))
# Parse LIMIT clause
limit = None
if limit_idx != -1:
limit_str = sql[limit_idx + 5:].strip()
try:
limit = int(limit_str)
except ValueError:
raise ValueError(f"Invalid LIMIT value: {limit_str}")
return SelectStatement(
table_name=table_name,
columns=columns,
where_clause=where_clause,
group_by=group_by,
order_by=order_by,
limit=limit
)
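
Example (illustrative, not in the commit): what the parser produces for a typical query.

from pyflaredb.sql.parser import SQLParser

stmt = SQLParser.parse_select(
    "SELECT username, email FROM users WHERE age > 30 ORDER BY username DESC LIMIT 10"
)
print(stmt.table_name)    # users
print(stmt.columns)       # ['username', 'email']
print(stmt.where_clause)  # age > 30
print(stmt.order_by)      # [OrderByClause(column='username', direction=OrderDirection.DESC)]
print(stmt.limit)         # 10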

38
pyflaredb/sql/statistics.py Normal file

@@ -0,0 +1,38 @@
from typing import Dict, Any
import numpy as np
from pyflaredb.table import Table
class TableStatistics:
def __init__(self):
self.table_sizes: Dict[str, int] = {}
self.column_stats: Dict[str, Dict[str, Any]] = {}
def collect_statistics(self, table: "Table"):
"""Collect statistics for a table"""
self.table_sizes[table.name] = len(table.data)
for column in table.columns:
values = [row[column.name] for row in table.data if column.name in row]
if not values:
continue
stats = {
"distinct_count": len(set(values)),
"null_count": sum(1 for v in values if v is None),
"min": min(values) if values and None not in values else None,
"max": max(values) if values and None not in values else None,
}
if isinstance(values[0], (int, float)):
stats.update(
{
"mean": np.mean(values),
"std_dev": np.std(values),
"histogram": np.histogram(values, bins=100),
}
)
self.column_stats[f"{table.name}.{column.name}"] = stats
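
Example (illustrative, not in the commit): collecting statistics for a small table.

from pyflaredb.sql.statistics import TableStatistics
from pyflaredb.table import Table, Column

t = Table("users", [Column("age", "integer")])
t.batch_insert([{"age": 30}, {"age": 40}, {"age": 40}])
stats = TableStatistics()
stats.collect_statistics(t)
print(stats.table_sizes["users"])                         # 3
print(stats.column_stats["users.age"]["distinct_count"])  # 2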

191
pyflaredb/table.py Normal file

@@ -0,0 +1,191 @@
from typing import Dict, List, Any, Optional
from dataclasses import dataclass
from datetime import datetime
from collections import defaultdict
from .indexing.btree import BTreeIndex
@dataclass
class Column:
name: str
data_type: str
nullable: bool = True
unique: bool = False
primary_key: bool = False
default: Any = None
class Table:
def __init__(self, name: str, columns: List[Column]):
self.name = name
self.columns = columns
self.data: List[Dict[str, Any]] = []
self._unique_indexes: Dict[str, Dict[Any, int]] = defaultdict(dict)
self._compiled_conditions = {}
self._indexes: Dict[str, BTreeIndex] = {}
# Validate column definitions
self._validate_columns()
def _validate_columns(self):
"""Validate column definitions"""
# Ensure only one primary key
primary_keys = [col for col in self.columns if col.primary_key]
if len(primary_keys) > 1:
raise ValueError("Table can only have one primary key")
# Validate data types
valid_types = {"string", "integer", "float", "boolean", "datetime"}
for col in self.columns:
if col.data_type.lower() not in valid_types:
raise ValueError(f"Invalid data type for column {col.name}: {col.data_type}")
def create_index(self, column_name: str) -> None:
"""Create a B-tree index for a column"""
if column_name not in {col.name for col in self.columns}:
raise ValueError(f"Column {column_name} does not exist")
# Create new index
index = BTreeIndex()
# Build index from existing data
for row_id, row in enumerate(self.data):
if column_name in row:
index.insert(row[column_name], row_id)
self._indexes[column_name] = index
def batch_insert(self, rows: List[Dict[str, Any]]) -> bool:
"""Efficiently insert multiple rows with index updates"""
# Pre-validate all rows
validated_rows = []
unique_values = defaultdict(set)
# Check unique constraints across all new rows
for row in rows:
converted_row = {}
# Validate required columns and defaults
for column in self.columns:
if not column.nullable and column.name not in row and column.default is None:
raise ValueError(f"Required column {column.name} is missing")
value = row.get(column.name, column.default)
# Type conversion
                if value is not None:
                    try:
                        if column.data_type == "integer":
                            value = int(value)
                        elif column.data_type == "float":
                            value = float(value)
                        elif column.data_type == "boolean":
                            # bool("false") is True, so parse strings explicitly
                            if isinstance(value, str):
                                value = value.lower() == 'true'
                            else:
                                value = bool(value)
                        else:  # string and datetime
                            value = str(value)
                    except (ValueError, TypeError):
                        raise ValueError(f"Invalid value for column {column.name}: {value}")
converted_row[column.name] = value
# Track unique values
if column.unique and value is not None:
if value in unique_values[column.name] or value in self._unique_indexes[column.name]:
raise ValueError(f"Unique constraint violated for column {column.name}")
unique_values[column.name].add(value)
validated_rows.append(converted_row)
# All rows validated, perform batch insert
start_id = len(self.data)
for i, row in enumerate(validated_rows):
row_id = start_id + i
# Update indexes
for column_name, index in self._indexes.items():
if column_name in row:
index.insert(row[column_name], row_id)
# Update unique indexes
for column in self.columns:
if column.unique:
value = row.get(column.name)
if value is not None:
self._unique_indexes[column.name][value] = row_id
self.data.append(row)
return True
    def insert(self, row: Dict[str, Any]) -> bool:
        """Insert a single row (delegates to batch_insert)"""
        return self.batch_insert([row])
def to_dict(self) -> dict:
"""Convert table to dictionary for serialization"""
return {
"name": self.name,
"columns": [
{
"name": col.name,
"data_type": col.data_type,
"nullable": col.nullable,
"unique": col.unique,
"primary_key": col.primary_key
}
for col in self.columns
],
"data": self.data
}
@classmethod
def from_dict(cls, data: dict) -> 'Table':
"""Create table from dictionary"""
columns = [
Column(**col_data)
for col_data in data["columns"]
]
table = cls(data["name"], columns)
table.data = data["data"]
return table
def _validate_type(self, value: Any, expected_type: str) -> bool:
"""Validate that a value matches the expected data type"""
type_mapping = {
"string": str,
"integer": int,
"float": float,
"boolean": bool,
"datetime": datetime,
}
if expected_type not in type_mapping:
raise ValueError(f"Unsupported data type: {expected_type}")
expected_python_type = type_mapping[expected_type]
if not isinstance(value, expected_python_type):
try:
# Attempt to convert the value
expected_python_type(value)
except (ValueError, TypeError):
raise ValueError(
f"Value {value} is not of expected type {expected_type}"
)
return True
def find_by_index(self, column_name: str, value: Any) -> List[Dict[str, Any]]:
"""Find rows using an index"""
if column_name not in self._indexes:
raise ValueError(f"No index exists for column {column_name}")
index = self._indexes[column_name]
row_ids = index.search(value)
return [self.data[row_id] for row_id in row_ids]
def range_search(self, column_name: str, start_value: Any, end_value: Any) -> List[Dict[str, Any]]:
"""Perform a range search using an index"""
if column_name not in self._indexes:
raise ValueError(f"No index exists for column {column_name}")
index = self._indexes[column_name]
row_ids = index.range_search(start_value, end_value)
return [self.data[row_id] for row_id in row_ids]
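
Example (illustrative, not in the commit): inserts, an index, and index-backed lookups.

from pyflaredb.table import Table, Column

t = Table("users", [
    Column("id", "string", nullable=False, primary_key=True),
    Column("age", "integer"),
])
t.batch_insert([{"id": "u1", "age": 31}, {"id": "u2", "age": 27}])
t.create_index("age")
print(t.find_by_index("age", 31))     # [{'id': 'u1', 'age': 31}]
print(t.range_search("age", 25, 30))  # [{'id': 'u2', 'age': 27}]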

81
pyflaredb/transaction.py Normal file

@@ -0,0 +1,81 @@
from typing import Dict, Any, Optional, Set, List
from enum import Enum
import time
import uuid
import threading
class TransactionState(Enum):
ACTIVE = "active"
COMMITTED = "committed"
ROLLED_BACK = "rolled_back"
class Transaction:
def __init__(self, tx_id: str):
self.id = tx_id
self.state = TransactionState.ACTIVE
self.start_time = time.time()
self.locks: Set[str] = set() # Set of table names that are locked
        self.changes: List[Dict[str, Any]] = []  # Changes made during the transaction
class TransactionManager:
def __init__(self):
self.transactions: Dict[str, Transaction] = {}
self.lock = threading.Lock()
def begin_transaction(self) -> str:
"""Start a new transaction"""
with self.lock:
tx_id = str(uuid.uuid4())
self.transactions[tx_id] = Transaction(tx_id)
return tx_id
def commit(self, tx_id: str) -> bool:
"""Commit a transaction"""
with self.lock:
if tx_id not in self.transactions:
raise ValueError(f"Transaction {tx_id} not found")
tx = self.transactions[tx_id]
if tx.state != TransactionState.ACTIVE:
raise ValueError(f"Transaction {tx_id} is not active")
# Apply changes
tx.state = TransactionState.COMMITTED
# Release locks
tx.locks.clear()
return True
def rollback(self, tx_id: str) -> bool:
"""Rollback a transaction"""
with self.lock:
if tx_id not in self.transactions:
raise ValueError(f"Transaction {tx_id} not found")
tx = self.transactions[tx_id]
if tx.state != TransactionState.ACTIVE:
raise ValueError(f"Transaction {tx_id} is not active")
# Revert changes
tx.state = TransactionState.ROLLED_BACK
tx.changes.clear()
# Release locks
tx.locks.clear()
return True
def get_transaction(self, tx_id: str) -> Optional[Transaction]:
"""Get transaction by ID"""
return self.transactions.get(tx_id)
def is_active(self, tx_id: str) -> bool:
"""Check if a transaction is active"""
tx = self.get_transaction(tx_id)
return tx is not None and tx.state == TransactionState.ACTIVE
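
Example (illustrative, not in the commit): the transaction lifecycle, including the error raised on reuse.

from pyflaredb.transaction import TransactionManager

tm = TransactionManager()
tx_id = tm.begin_transaction()
print(tm.is_active(tx_id))  # True
tm.commit(tx_id)
print(tm.is_active(tx_id))  # False
try:
    tm.rollback(tx_id)
except ValueError as e:
    print(e)  # Transaction ... is not active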


@@ -0,0 +1,63 @@
from typing import Dict, List, Any
from enum import Enum
import threading
from datetime import datetime
class TransactionState(Enum):
ACTIVE = "ACTIVE"
COMMITTED = "COMMITTED"
ROLLED_BACK = "ROLLED_BACK"
class Transaction:
def __init__(self, id: str):
self.id = id
self.state = TransactionState.ACTIVE
self.changes: List[Dict[str, Any]] = []
self.locks = set()
self.timestamp = datetime.utcnow()
class TransactionManager:
def __init__(self):
self.transactions: Dict[str, Transaction] = {}
self.lock = threading.Lock()
def begin_transaction(self) -> str:
"""Start a new transaction"""
with self.lock:
tx_id = str(len(self.transactions) + 1)
self.transactions[tx_id] = Transaction(tx_id)
return tx_id
def commit(self, tx_id: str):
"""Commit a transaction"""
with self.lock:
if tx_id not in self.transactions:
raise ValueError(f"Transaction {tx_id} not found")
tx = self.transactions[tx_id]
if tx.state != TransactionState.ACTIVE:
raise ValueError(f"Transaction {tx_id} is not active")
# Apply changes
self._apply_changes(tx)
tx.state = TransactionState.COMMITTED
def rollback(self, tx_id: str):
"""Rollback a transaction"""
with self.lock:
if tx_id not in self.transactions:
raise ValueError(f"Transaction {tx_id} not found")
tx = self.transactions[tx_id]
if tx.state != TransactionState.ACTIVE:
raise ValueError(f"Transaction {tx_id} is not active")
# Discard changes
tx.changes.clear()
tx.state = TransactionState.ROLLED_BACK
def _apply_changes(self, transaction: Transaction):
"""Apply transaction changes"""
for change in transaction.changes:
# Implementation of applying changes
pass

50
pyflaredb/versioning.py Normal file

@@ -0,0 +1,50 @@
import hashlib
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, Optional
@dataclass
class Version:
timestamp: datetime
operation: str # 'INSERT', 'UPDATE', 'DELETE'
table_name: str
row_id: str
data: Dict[str, Any]
previous_version: Optional[str] = None # Hash of previous version
class VersionStore:
    def __init__(self):
        self.versions: List[Version] = []
        self.current_version: Optional[str] = None  # Hash of current version
    def add_version(self, version: Version):
        """Add a new version to the store"""
        version_hash = self._calculate_hash(version)
        self.versions.append(version)
        self.current_version = version_hash
    def _calculate_hash(self, version: Version) -> str:
        # Assumed implementation: the original calls this method but never defines it
        payload = f"{version.timestamp.isoformat()}|{version.operation}|{version.table_name}|{version.row_id}|{sorted(version.data.items())}|{version.previous_version}"
        return hashlib.sha256(payload.encode()).hexdigest()
def get_state_at(self, timestamp: datetime) -> Dict[str, List[Dict[str, Any]]]:
"""Reconstruct database state at given timestamp"""
state = {}
relevant_versions = [v for v in self.versions if v.timestamp <= timestamp]
for version in relevant_versions:
if version.table_name not in state:
state[version.table_name] = []
if version.operation == "INSERT":
state[version.table_name].append(version.data)
elif version.operation == "DELETE":
state[version.table_name] = [
row
for row in state[version.table_name]
if row["id"] != version.row_id
]
elif version.operation == "UPDATE":
state[version.table_name] = [
version.data if row["id"] == version.row_id else row
for row in state[version.table_name]
]
return state
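
Example (illustrative, not in the commit): reconstructing table state at two points in time.

from datetime import datetime
from pyflaredb.versioning import Version, VersionStore

store = VersionStore()
store.add_version(Version(datetime(2024, 11, 25, 12, 0), "INSERT", "users", "u1", {"id": "u1", "age": 30}))
store.add_version(Version(datetime(2024, 11, 25, 13, 0), "UPDATE", "users", "u1", {"id": "u1", "age": 31}))
print(store.get_state_at(datetime(2024, 11, 25, 12, 30)))  # {'users': [{'id': 'u1', 'age': 30}]}
print(store.get_state_at(datetime(2024, 11, 25, 14, 0)))   # {'users': [{'id': 'u1', 'age': 31}]}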

254
test.py Normal file

@@ -0,0 +1,254 @@
from pyflaredb.core import PyFlareDB
from pyflaredb.table import Column, Table
from pyflaredb.benchmark.suite import BenchmarkSuite
import time
import random
import string
import json
from typing import List, Dict, Any
def generate_realistic_data(n: int) -> List[Dict[str, Any]]:
"""Generate realistic test data"""
domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com']
cities = ['New York', 'London', 'Tokyo', 'Paris', 'Berlin', 'Sydney', 'Toronto']
data = []
for i in range(n):
# Generate realistic username
username = f"{random.choice(string.ascii_lowercase)}{random.choice(string.ascii_lowercase)}"
username += ''.join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(6, 12)))
# Generate realistic email
email = f"{username}@{random.choice(domains)}"
# Generate JSON metadata
metadata = {
"city": random.choice(cities),
"last_login": f"2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}",
"preferences": {
"theme": random.choice(["light", "dark", "system"]),
"notifications": random.choice([True, False])
}
}
data.append({
"id": f"usr_{i:08d}",
"username": username,
"email": email,
"age": random.randint(18, 80),
"score": round(random.uniform(0, 100), 2),
"is_active": random.random() > 0.1, # 90% active users
"login_count": random.randint(1, 1000),
"metadata": json.dumps(metadata)
})
return data
def format_value(value):
"""Format value based on its type"""
if isinstance(value, (float, int)):
return f"{value:.4f}"
return str(value)
def test_database_features():
"""Test all database features with realistic workloads"""
print("\n=== Starting Realistic Database Tests ===")
# Initialize database
db = PyFlareDB("test.db")
# 1. Create test table with realistic schema
print("\n1. Setting up test environment...")
users_table = Table(
name="users",
columns=[
Column("id", "string", nullable=False, primary_key=True),
Column("username", "string", nullable=False, unique=True),
Column("email", "string", nullable=False),
Column("age", "integer", nullable=True),
Column("score", "float", nullable=True),
Column("is_active", "boolean", nullable=True, default=True),
Column("login_count", "integer", nullable=True, default=0),
Column("metadata", "string", nullable=True) # JSON string
],
)
db.tables["users"] = users_table
# Create indexes for commonly queried fields
users_table.create_index("age")
users_table.create_index("score")
users_table.create_index("login_count")
# 2. Performance Tests with Realistic Data
print("\n2. Running performance tests...")
# Generate test data
test_data = generate_realistic_data(1000) # 1000 realistic records
# Insert Performance (Single vs Batch)
print("\nInsert Performance:")
# Single Insert (OLTP-style)
start_time = time.time()
for record in test_data[:100]: # Test with first 100 records
# Properly escape the metadata string
metadata_str = record['metadata'].replace("'", "''")
# Format each value according to its type
values = [
f"'{record['id']}'", # string
f"'{record['username']}'", # string
f"'{record['email']}'", # string
str(record['age']), # integer
str(record['score']), # float
str(record['is_active']).lower(), # boolean
str(record['login_count']), # integer
f"'{metadata_str}'" # string (JSON)
]
query = f"""
INSERT INTO users
(id, username, email, age, score, is_active, login_count, metadata)
VALUES
({', '.join(values)})
"""
db.execute(query)
single_insert_time = time.time() - start_time
print(f"Single Insert (100 records, OLTP): {single_insert_time:.4f}s")
# Batch Insert (OLAP-style)
start_time = time.time()
batch_data = test_data[100:200] # Next 100 records
users_table.batch_insert(batch_data) # This should work as is
batch_insert_time = time.time() - start_time
print(f"Batch Insert (100 records, OLAP): {batch_insert_time:.4f}s")
# 3. Query Performance Tests
print("\nQuery Performance (OLTP vs OLAP):")
# OLTP-style queries (point queries, simple filters)
oltp_queries = [
("Single Record Lookup", "SELECT * FROM users WHERE id = 'usr_00000001'"),
("Simple Range Query", "SELECT * FROM users WHERE age > 30 LIMIT 10"),
("Active Users Count", "SELECT COUNT(*) FROM users WHERE is_active = true"),
("Recent Logins", "SELECT * FROM users WHERE login_count > 500 ORDER BY login_count DESC LIMIT 5")
]
# OLAP-style queries (aggregations, complex filters)
olap_queries = [
("Age Distribution", """
SELECT
CASE
WHEN age < 25 THEN 'Gen Z'
WHEN age < 40 THEN 'Millennial'
WHEN age < 55 THEN 'Gen X'
ELSE 'Boomer'
END as generation,
COUNT(*) as count
FROM users
GROUP BY generation
"""),
("User Engagement", """
SELECT
username,
score,
login_count
FROM users
WHERE score > 75
AND login_count > 100
ORDER BY score DESC
LIMIT 10
"""),
("Complex Analytics", """
SELECT
COUNT(*) as total_users,
AVG(score) as avg_score,
SUM(CASE WHEN is_active THEN 1 ELSE 0 END) as active_users
FROM users
WHERE age BETWEEN 25 AND 45
""")
]
print("\nOLTP Query Performance:")
for query_name, query in oltp_queries:
# First run (cold)
start_time = time.time()
db.execute(query)
cold_time = time.time() - start_time
# Second run (warm/cached)
start_time = time.time()
db.execute(query)
warm_time = time.time() - start_time
print(f"\n{query_name}:")
print(f" Cold run: {cold_time:.4f}s")
print(f" Warm run: {warm_time:.4f}s")
print(f" Cache improvement: {((cold_time - warm_time) / cold_time * 100):.1f}%")
print("\nOLAP Query Performance:")
for query_name, query in olap_queries:
start_time = time.time()
db.execute(query)
execution_time = time.time() - start_time
print(f"\n{query_name}: {execution_time:.4f}s")
    # 4. Mixed Workload Simulation (reads and writes issued sequentially)
    print("\nMixed Workload Simulation:")
start_time = time.time()
# Simulate mixed workload
for _ in range(100):
if random.random() < 0.8: # 80% reads
query = random.choice(oltp_queries)[1]
else: # 20% writes
record = generate_realistic_data(1)[0]
query = f"""
INSERT INTO users (id, username, email, age, score, is_active, login_count, metadata)
VALUES (
'{record['id']}',
'{record['username']}',
'{record['email']}',
{record['age']},
{record['score']},
{str(record['is_active']).lower()},
{record['login_count']},
'{record['metadata']}'
)
"""
db.execute(query)
mixed_workload_time = time.time() - start_time
print(f"Mixed Workload (100 operations): {mixed_workload_time:.4f}s")
# 5. Memory Usage Test
print("\nMemory Usage:")
import sys
    memory_size = sys.getsizeof(db.tables["users"].data) / 1024  # KB; shallow size of the row list only
    records_count = len(db.tables["users"].data)
    print(f"List overhead per record: {(memory_size / records_count):.2f} KB")
# 6. Run standard benchmark suite
print("\n6. Running standard benchmark suite...")
benchmark = BenchmarkSuite(db)
results = benchmark.run_benchmark(num_records=10000)
print("\nBenchmark Results:")
for test_name, metrics in results.items():
print(f"\n{test_name.upper()}:")
for metric, value in metrics.items():
print(f" {metric}: {format_value(value)}")
def main():
try:
test_database_features()
    except Exception as e:
        print(f"Test failed: {e}")
        raise
if __name__ == "__main__":
main()