commit d2cea458d486be7c74a8aafd3602804ce3668064 Author: tcsenpai Date: Mon Nov 25 14:38:35 2024 +0100 First commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..fc18853 --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +__pycache__ +*.db diff --git a/pyflaredb/__init__.py b/pyflaredb/__init__.py new file mode 100644 index 0000000..6164172 --- /dev/null +++ b/pyflaredb/__init__.py @@ -0,0 +1,5 @@ +from .core import PyFlareDB +from .table import Table, Column +from .versioning import Version, VersionStore + +__all__ = ['PyFlareDB', 'Table', 'Column', 'Version', 'VersionStore'] \ No newline at end of file diff --git a/pyflaredb/benchmark/suite.py b/pyflaredb/benchmark/suite.py new file mode 100644 index 0000000..b2370d5 --- /dev/null +++ b/pyflaredb/benchmark/suite.py @@ -0,0 +1,125 @@ +import time +from typing import List, Dict, Any +import random +import string +from ..core import PyFlareDB + + +class BenchmarkSuite: + def __init__(self, db: PyFlareDB): + self.db = db + + def run_benchmark(self, num_records: int = 10000): + """Run comprehensive benchmark""" + results = { + "insert": self._benchmark_insert(num_records), + "select": self._benchmark_select(num_records), + "index": self._benchmark_index_performance(num_records), + "complex_query": self._benchmark_complex_queries(num_records), + } + return results + + def _benchmark_insert(self, num_records: int) -> Dict[str, float]: + start_time = time.time() + batch_times = [] + + for i in range(0, num_records, 1000): + batch_start = time.time() + self._insert_batch(min(1000, num_records - i)) + batch_times.append(time.time() - batch_start) + + total_time = time.time() - start_time + return { + "total_time": total_time, + "records_per_second": num_records / total_time, + "avg_batch_time": sum(batch_times) / len(batch_times), + } + + def _insert_batch(self, size: int): + """Insert a batch of random records""" + tx_id = None + try: + tx_id = self.db.transaction_manager.begin_transaction() + + for _ in range(size): + query = ( + "INSERT INTO users (id, username, email, age) " + f"VALUES ('{self._random_string(10)}', " + f"'{self._random_string(8)}', " + f"'{self._random_string(8)}@example.com', " + f"{random.randint(18, 80)})" + ) + try: + self.db.execute(query) + except Exception as e: + print(f"Failed to insert record: {e}") + print(f"Query was: {query}") + raise + + self.db.transaction_manager.commit(tx_id) + + except Exception as e: + if tx_id is not None: + try: + self.db.transaction_manager.rollback(tx_id) + except ValueError: + pass + raise + + def _benchmark_select(self, num_records: int) -> Dict[str, float]: + """Benchmark SELECT queries""" + queries = [ + "SELECT * FROM users WHERE age > 30", + "SELECT username, email FROM users WHERE age < 25", + "SELECT COUNT(*) FROM users", + ] + + results = {} + for query in queries: + start_time = time.time() + try: + self.db.execute(query) + query_time = time.time() - start_time + results[query] = query_time + except Exception as e: + results[query] = f"Error: {e}" + + return results + + def _benchmark_index_performance(self, num_records: int) -> Dict[str, float]: + """Benchmark index performance""" + # TODO: Implement index benchmarking + return {"index_creation_time": 0.0} + + def _benchmark_complex_queries(self, num_records: int) -> Dict[str, float]: + """Benchmark complex queries""" + complex_queries = [ + """ + SELECT username, COUNT(*) as count + FROM users + GROUP BY username + """, + """ + SELECT * FROM users + WHERE age > 30 + ORDER BY username DESC + LIMIT 10 + 
""", + ] + + results = {} + for query in complex_queries: + start_time = time.time() + try: + self.db.execute(query) + query_time = time.time() - start_time + results[query.strip()] = query_time + except Exception as e: + results[query.strip()] = f"Error: {e}" + + return results + + @staticmethod + def _random_string(length: int) -> str: + """Generate a random string of specified length""" + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) diff --git a/pyflaredb/cache/query_cache.py b/pyflaredb/cache/query_cache.py new file mode 100644 index 0000000..bbd1c8c --- /dev/null +++ b/pyflaredb/cache/query_cache.py @@ -0,0 +1,39 @@ +from typing import Dict, Any, Optional +import time +import hashlib +from collections import OrderedDict + +class QueryCache: + def __init__(self, capacity: int = 1000, ttl: int = 300): + self.capacity = capacity + self.ttl = ttl + self.cache = OrderedDict() + + def get(self, query: str) -> Optional[Any]: + """Get cached query result""" + query_hash = self._hash_query(query) + if query_hash in self.cache: + entry = self.cache[query_hash] + if time.time() - entry['timestamp'] < self.ttl: + self.cache.move_to_end(query_hash) + return entry['result'] + else: + del self.cache[query_hash] + return None + + def set(self, query: str, result: Any): + """Cache query result""" + if len(self.cache) >= self.capacity: + self.cache.popitem(last=False) + + query_hash = self._hash_query(query) + self.cache[query_hash] = { + 'result': result, + 'timestamp': time.time() + } + + def _hash_query(self, query: str) -> str: + return hashlib.sha256(query.encode()).hexdigest() + + def clear(self): + self.cache.clear() \ No newline at end of file diff --git a/pyflaredb/core.py b/pyflaredb/core.py new file mode 100644 index 0000000..fb7b292 --- /dev/null +++ b/pyflaredb/core.py @@ -0,0 +1,90 @@ +from typing import Dict, List, Any, Optional +from .table import Table +from .sql.parser import SQLParser, SelectStatement, InsertStatement +from .sql.executor import QueryExecutor +from .sql.optimizer import QueryOptimizer +from .sql.statistics import TableStatistics +from .transaction import TransactionManager + + +class PyFlareDB: + def __init__(self, db_path: str): + """Initialize the database""" + self.db_path = db_path + self.tables: Dict[str, Table] = {} + self.parser = SQLParser() + self.statistics = TableStatistics() + self.optimizer = QueryOptimizer(self.tables, self.statistics) + self.executor = QueryExecutor(self.tables) + self.transaction_manager = TransactionManager() + self._query_cache = {} + + def begin_transaction(self) -> str: + """Begin a new transaction""" + return self.transaction_manager.begin_transaction() + + def commit_transaction(self, tx_id: str) -> bool: + """Commit a transaction""" + return self.transaction_manager.commit(tx_id) + + def rollback_transaction(self, tx_id: str) -> bool: + """Rollback a transaction""" + return self.transaction_manager.rollback(tx_id) + + def create_table(self, table: Table) -> None: + """Create a new table""" + if table.name in self.tables: + raise ValueError(f"Table {table.name} already exists") + self.tables[table.name] = table + self.statistics.collect_statistics(table) + + def drop_table(self, table_name: str) -> None: + """Drop a table""" + if table_name not in self.tables: + raise ValueError(f"Table {table_name} does not exist") + del self.tables[table_name] + + def execute( + self, sql: str, tx_id: Optional[str] = None + ) -> Optional[List[Dict[str, Any]]]: + """Execute a SQL query""" + try: + # Check query 
cache for non-transactional SELECT queries + if tx_id is None and sql in self._query_cache: + return self._query_cache[sql] + + # Parse SQL + if sql.strip().upper().startswith("SELECT"): + statement = self.parser.parse_select(sql) + elif sql.strip().upper().startswith("INSERT"): + statement = self.parser.parse_insert(sql) + else: + raise ValueError("Unsupported SQL statement type") + + # Get transaction if provided + tx = None + if tx_id: + tx = self.transaction_manager.get_transaction(tx_id) + if not tx: + raise ValueError(f"Transaction {tx_id} does not exist") + + # Optimize query plan + optimized_plan = self.optimizer.optimize(statement) + + # Execute query + result = self.executor.execute(optimized_plan, transaction=tx) + + # Cache SELECT results for non-transactional queries + if tx_id is None and isinstance(statement, SelectStatement): + self._query_cache[sql] = result + + return result + + except Exception as e: + # Clear cache on error + self._query_cache.clear() + raise e + + def clear_cache(self) -> None: + """Clear the query cache""" + self._query_cache.clear() diff --git a/pyflaredb/indexing/btree.py b/pyflaredb/indexing/btree.py new file mode 100644 index 0000000..fbe207e --- /dev/null +++ b/pyflaredb/indexing/btree.py @@ -0,0 +1,137 @@ +from typing import Any, Dict, List, Optional, Tuple +from dataclasses import dataclass +from collections import deque + +@dataclass +class Node: + keys: List[Any] + values: List[List[int]] # List of row IDs for each key (handling duplicates) + children: List['Node'] + is_leaf: bool = True + +class BTreeIndex: + def __init__(self, order: int = 100): + self.root = Node([], [], []) + self.order = order # Maximum number of children per node + + def insert(self, key: Any, row_id: int) -> None: + """Insert a key-value pair into the B-tree""" + if len(self.root.keys) == (2 * self.order) - 1: + # Split root if full + new_root = Node([], [], [], False) + new_root.children.append(self.root) + self._split_child(new_root, 0) + self.root = new_root + self._insert_non_full(self.root, key, row_id) + + def search(self, key: Any) -> List[int]: + """Search for a key and return all matching row IDs""" + return self._search_node(self.root, key) + + def range_search(self, start_key: Any, end_key: Any) -> List[int]: + """Perform a range search and return all matching row IDs""" + result = [] + self._range_search_node(self.root, start_key, end_key, result) + return result + + def _split_child(self, parent: Node, child_index: int) -> None: + """Split a full child node""" + order = self.order + child = parent.children[child_index] + new_node = Node([], [], [], child.is_leaf) + + # Move the median key to the parent + median = order - 1 + parent.keys.insert(child_index, child.keys[median]) + parent.values.insert(child_index, child.values[median]) + parent.children.insert(child_index + 1, new_node) + + # Move half of the keys to the new node + new_node.keys = child.keys[median + 1:] + new_node.values = child.values[median + 1:] + child.keys = child.keys[:median] + child.values = child.values[:median] + + # Move children if not a leaf + if not child.is_leaf: + new_node.children = child.children[median + 1:] + child.children = child.children[:median + 1] + + def _insert_non_full(self, node: Node, key: Any, row_id: int) -> None: + """Insert into a non-full node""" + i = len(node.keys) - 1 + + if node.is_leaf: + # Insert into leaf node + while i >= 0 and self._compare_keys(key, node.keys[i]) < 0: + i -= 1 + i += 1 + + # Handle duplicate keys + if i > 0 and 
self._compare_keys(key, node.keys[i-1]) == 0: + node.values[i-1].append(row_id) + else: + node.keys.insert(i, key) + node.values.insert(i, [row_id]) + else: + # Find the child to insert into + while i >= 0 and self._compare_keys(key, node.keys[i]) < 0: + i -= 1 + i += 1 + + if len(node.children[i].keys) == (2 * self.order) - 1: + self._split_child(node, i) + if self._compare_keys(key, node.keys[i]) > 0: + i += 1 + + self._insert_non_full(node.children[i], key, row_id) + + def _search_node(self, node: Node, key: Any) -> List[int]: + """Search for a key in a node""" + i = 0 + while i < len(node.keys) and self._compare_keys(key, node.keys[i]) > 0: + i += 1 + + if i < len(node.keys) and self._compare_keys(key, node.keys[i]) == 0: + return node.values[i] + elif node.is_leaf: + return [] + else: + return self._search_node(node.children[i], key) + + def _range_search_node(self, node: Node, start_key: Any, end_key: Any, result: List[int]) -> None: + """Perform range search on a node""" + i = 0 + while i < len(node.keys) and self._compare_keys(start_key, node.keys[i]) > 0: + i += 1 + + if node.is_leaf: + while i < len(node.keys) and self._compare_keys(node.keys[i], end_key) <= 0: + result.extend(node.values[i]) + i += 1 + else: + if i < len(node.keys): + self._range_search_node(node.children[i], start_key, end_key, result) + while i < len(node.keys) and self._compare_keys(node.keys[i], end_key) <= 0: + result.extend(node.values[i]) + i += 1 + if i < len(node.children): + self._range_search_node(node.children[i], start_key, end_key, result) + + @staticmethod + def _compare_keys(key1: Any, key2: Any) -> int: + """Compare two keys, handling different types""" + if key1 is None or key2 is None: + if key1 is None and key2 is None: + return 0 + return -1 if key1 is None else 1 + + try: + if key1 < key2: + return -1 + elif key1 > key2: + return 1 + return 0 + except TypeError: + # Handle incomparable types + return 0 \ No newline at end of file diff --git a/pyflaredb/monitoring/metrics.py b/pyflaredb/monitoring/metrics.py new file mode 100644 index 0000000..e3f14bb --- /dev/null +++ b/pyflaredb/monitoring/metrics.py @@ -0,0 +1,33 @@ +from typing import Dict, List +import time +from collections import deque +import threading + + +class PerformanceMetrics: + def __init__(self, window_size: int = 1000): + self.window_size = window_size + self.query_times: Dict[str, deque] = {} + self.lock = threading.Lock() + + def record_query(self, query_type: str, execution_time: float): + """Record query execution time""" + with self.lock: + if query_type not in self.query_times: + self.query_times[query_type] = deque(maxlen=self.window_size) + self.query_times[query_type].append(execution_time) + + def get_metrics(self) -> Dict[str, Dict[str, float]]: + """Get performance metrics""" + metrics = {} + with self.lock: + for query_type, times in self.query_times.items(): + if not times: + continue + metrics[query_type] = { + "avg_time": sum(times) / len(times), + "max_time": max(times), + "min_time": min(times), + "count": len(times), + } + return metrics diff --git a/pyflaredb/sql/executor.py b/pyflaredb/sql/executor.py new file mode 100644 index 0000000..029f6c1 --- /dev/null +++ b/pyflaredb/sql/executor.py @@ -0,0 +1,245 @@ +from typing import List, Dict, Any, Callable, Tuple, Optional +import operator +from ..table import Table +from .parser import SelectStatement, InsertStatement +from ..transaction import Transaction + + +class QueryExecutor: + def __init__(self, tables: Dict[str, Table]): + self.tables = tables + 
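+        # _comparison_ops below maps SQL comparison operators onto Python's operator functions; _matches_all_conditions uses it when evaluating parsed WHERE conditions.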
self._compiled_conditions = {} + self._comparison_ops = { + '>': operator.gt, + '<': operator.lt, + '>=': operator.ge, + '<=': operator.le, + '=': operator.eq, + '!=': operator.ne + } + + def _parse_where_clause(self, where_clause: str) -> List[Tuple[str, str, str]]: + """Parse WHERE clause into list of (field, operator, value) tuples""" + conditions = [] + # Split on AND if present + subclauses = [c.strip() for c in where_clause.split(' AND ')] + + for subclause in subclauses: + # Find the operator + operator_found = None + for op in ['>=', '<=', '>', '<', '=', '!=']: + if op in subclause: + operator_found = op + field, value = subclause.split(op) + conditions.append((field.strip(), op, value.strip())) + break + + if not operator_found: + raise ValueError(f"Invalid condition: {subclause}") + + return conditions + + def execute(self, statement, transaction: Optional[Transaction] = None): + """Execute a parsed SQL statement""" + if isinstance(statement, SelectStatement): + return self._execute_select(statement, transaction) + elif isinstance(statement, InsertStatement): + return self._execute_insert(statement, transaction) + elif statement is None: + raise ValueError("No statement to execute") + else: + raise ValueError(f"Unsupported statement type: {type(statement)}") + + def _execute_select(self, stmt: SelectStatement, transaction: Optional[Transaction] = None) -> List[Dict[str, Any]]: + if stmt.table_name not in self.tables: + raise ValueError(f"Table {stmt.table_name} does not exist") + + table = self.tables[stmt.table_name] + + # If in transaction, check for locks + if transaction and table.name in transaction.locks: + # Handle transaction isolation level logic here + pass + + # Handle COUNT(*) separately + if len(stmt.columns) == 1 and stmt.columns[0].lower() == "count(*)": + return [{"count": len(table.data)}] + + # Try to use index for WHERE clause + if stmt.where_clause: + try: + conditions = self._parse_where_clause(stmt.where_clause) + + # Check if we can use an index for any condition + for field, op, value in conditions: + if field in table._indexes: + # Convert value to proper type + column = next((col for col in table.columns if col.name == field), None) + if column: + try: + if column.data_type == "integer": + value = int(value) + elif column.data_type == "float": + value = float(value) + except (ValueError, TypeError): + continue + + # Use index for lookup + if op == '=': + results = table.find_by_index(field, value) + elif op in {'>', '>='}: + results = table.range_search(field, value, None) + elif op in {'<', '<='}: + results = table.range_search(field, None, value) + else: # op == '!=' + # For inequality, we still need to scan + results = table.data + + # Apply remaining conditions + filtered_results = [] + for row in results: + if self._matches_all_conditions(row, conditions): + filtered_results.append(row) + + return self._process_results(filtered_results, stmt) + except ValueError: + # If WHERE clause parsing fails, fall back to table scan + pass + + # Fall back to full table scan + return self._table_scan(table, stmt) + + def _matches_all_conditions(self, row: Dict[str, Any], conditions: List[Tuple[str, str, str]]) -> bool: + """Check if row matches all conditions""" + for field, op, value in conditions: + row_value = row.get(field) + if row_value is None: + return False + + # Convert value to proper type based on row_value + try: + if isinstance(row_value, int): + value = int(value) + elif isinstance(row_value, float): + value = float(value) + except (ValueError, 
TypeError): + return False + + # Apply comparison + op_func = self._comparison_ops[op] + try: + if not op_func(row_value, value): + return False + except TypeError: + return False + + return True + + def _table_scan(self, table: Table, stmt: SelectStatement) -> List[Dict[str, Any]]: + """Perform a full table scan with filtering""" + results = [] + + # Parse WHERE conditions if present + conditions = [] + if stmt.where_clause: + try: + conditions = self._parse_where_clause(stmt.where_clause) + except ValueError: + # If parsing fails, return empty result + return [] + + # Process rows + for row in table.data: + # Apply WHERE clause + if conditions and not self._matches_all_conditions(row, conditions): + continue + + # Select requested columns + if "*" in stmt.columns: + results.append(row.copy()) + else: + filtered_row = {} + for col in stmt.columns: + if "count(" in col.lower(): + filtered_row[col] = len(results) + else: + filtered_row[col] = row.get(col) + results.append(filtered_row) + + return self._process_results(results, stmt) + + def _process_results(self, rows: List[Dict[str, Any]], stmt: SelectStatement) -> List[Dict[str, Any]]: + """Process result rows according to SELECT statement""" + results = [] + for row in rows: + if "*" in stmt.columns: + results.append(row.copy()) + else: + filtered_row = {} + for col in stmt.columns: + if "count(" in col.lower(): + filtered_row[col] = len(results) + else: + filtered_row[col] = row.get(col) + results.append(filtered_row) + + # Handle ORDER BY + if stmt.order_by: + for order_clause in stmt.order_by: + reverse = order_clause.direction.value == "DESC" + results.sort( + key=lambda x: (x.get(order_clause.column) is None, x.get(order_clause.column)), + reverse=reverse + ) + + # Handle LIMIT + if stmt.limit is not None: + results = results[:stmt.limit] + + return results + + def _execute_insert(self, stmt: InsertStatement, transaction: Optional[Transaction] = None) -> bool: + if stmt.table_name not in self.tables: + raise ValueError(f"Table {stmt.table_name} does not exist") + + table = self.tables[stmt.table_name] + + # If in transaction, acquire lock and track changes + if transaction: + transaction.locks.add(table.name) + # Track the changes for potential rollback + transaction.changes.append({ + 'type': 'INSERT', + 'table': table.name, + 'data': dict(zip(stmt.columns, stmt.values)) + }) + + # Create dictionary of column-value pairs + row_data = {} + for col_name, value in zip(stmt.columns, stmt.values): + # Find the column definition + column = next((col for col in table.columns if col.name == col_name), None) + if not column: + raise ValueError(f"Column {col_name} does not exist") + + # Convert value based on column type + if value is not None: + try: + if column.data_type == "integer": + row_data[col_name] = int(value) + elif column.data_type == "float": + row_data[col_name] = float(value) + elif column.data_type == "boolean": + if isinstance(value, str): + row_data[col_name] = value.lower() == 'true' + else: + row_data[col_name] = bool(value) + else: # string type + row_data[col_name] = str(value) + except (ValueError, TypeError): + raise ValueError(f"Invalid value for column {column.name}: {value}") + else: + row_data[col_name] = None + + # Insert the data + return table.insert(row_data) diff --git a/pyflaredb/sql/optimizer.py b/pyflaredb/sql/optimizer.py new file mode 100644 index 0000000..45928dc --- /dev/null +++ b/pyflaredb/sql/optimizer.py @@ -0,0 +1,51 @@ +from typing import List, Dict, Any, Union +from dataclasses import dataclass 
+from enum import Enum +from .parser import SelectStatement, InsertStatement + +from pyflaredb.sql.statistics import TableStatistics +from pyflaredb.table import Table + + +class JoinStrategy(Enum): + NESTED_LOOP = "nested_loop" + HASH_JOIN = "hash_join" + MERGE_JOIN = "merge_join" + + +class ScanType(Enum): + SEQUENTIAL = "sequential" + INDEX = "index" + + +@dataclass +class QueryPlan: + operation: str + strategy: Union[JoinStrategy, ScanType] + estimated_cost: float + children: List["QueryPlan"] = None + + +class QueryOptimizer: + def __init__(self, tables: Dict[str, "Table"], statistics: "TableStatistics"): + self.tables = tables + self.statistics = statistics + + def optimize(self, statement) -> Any: + """Generate an optimized query plan""" + if isinstance(statement, SelectStatement): + return self._optimize_select(statement) + elif isinstance(statement, InsertStatement): + return statement # No optimization needed for simple inserts + return statement # Return original statement if no optimization is needed + + def _optimize_select(self, stmt: SelectStatement) -> SelectStatement: + """Optimize SELECT query execution""" + # For now, return the original statement + # TODO: Implement actual optimization strategies + return stmt + + def _estimate_cost(self, plan: QueryPlan) -> float: + """Estimate the cost of a query plan""" + # Implementation for cost estimation + pass diff --git a/pyflaredb/sql/parser.py b/pyflaredb/sql/parser.py new file mode 100644 index 0000000..0763a83 --- /dev/null +++ b/pyflaredb/sql/parser.py @@ -0,0 +1,176 @@ +from dataclasses import dataclass +from typing import List, Optional, Any +from enum import Enum + +class OrderDirection(Enum): + ASC = "ASC" + DESC = "DESC" + +@dataclass +class OrderByClause: + column: str + direction: OrderDirection = OrderDirection.ASC + +@dataclass +class SelectStatement: + table_name: str + columns: List[str] + where_clause: Optional[str] = None + group_by: Optional[List[str]] = None + order_by: Optional[List[OrderByClause]] = None + limit: Optional[int] = None + +@dataclass +class InsertStatement: + table_name: str + columns: List[str] + values: List[Any] + +class SQLParser: + @staticmethod + def parse_insert(sql: str) -> InsertStatement: + """Parse INSERT statement""" + # Remove newlines and extra spaces + sql = ' '.join(sql.split()) + + # Extract table name + table_start = sql.find("INTO") + 4 + table_end = sql.find("(", table_start) + table_name = sql[table_start:table_end].strip() + + # Extract columns + cols_start = sql.find("(", table_end) + 1 + cols_end = sql.find(")", cols_start) + columns = [col.strip() for col in sql[cols_start:cols_end].split(",")] + + # Extract values + values_start = sql.find("VALUES", cols_end) + 6 + values_start = sql.find("(", values_start) + 1 + values_end = sql.find(")", values_start) + values_str = sql[values_start:values_end] + + # Parse values while respecting quotes + values = [] + current_value = "" + in_quotes = False + quote_char = None + + for char in values_str: + if char in ["'", '"']: + if not in_quotes: + in_quotes = True + quote_char = char + elif quote_char == char: + in_quotes = False + quote_char = None + current_value += char + elif char == ',' and not in_quotes: + values.append(current_value.strip()) + current_value = "" + else: + current_value += char + + if current_value: + values.append(current_value.strip()) + + # Clean up values + cleaned_values = [] + for value in values: + value = value.strip() + if value.startswith(("'", '"')) and value.endswith(("'", '"')): + # String 
value - keep quotes + cleaned_values.append(value) + elif value.lower() == 'true': + cleaned_values.append(True) + elif value.lower() == 'false': + cleaned_values.append(False) + elif value.lower() == 'null': + cleaned_values.append(None) + else: + try: + # Try to convert to number if possible + if '.' in value: + cleaned_values.append(float(value)) + else: + cleaned_values.append(int(value)) + except ValueError: + # If not a number, keep as is + cleaned_values.append(value) + + if len(columns) != len(cleaned_values): + raise ValueError(f"Column count ({len(columns)}) doesn't match value count ({len(cleaned_values)})") + + return InsertStatement(table_name=table_name, columns=columns, values=cleaned_values) + + @staticmethod + def parse_select(sql: str) -> SelectStatement: + """Parse SELECT statement""" + # Remove newlines and extra spaces + sql = ' '.join(sql.split()) + + # Extract table name + from_idx = sql.upper().find("FROM") + if from_idx == -1: + raise ValueError("Invalid SELECT statement: missing FROM clause") + + # Extract columns + columns_str = sql[6:from_idx].strip() + columns = [col.strip() for col in columns_str.split(",")] + + # Find all clause positions + where_idx = sql.upper().find("WHERE") + group_idx = sql.upper().find("GROUP BY") + order_idx = sql.upper().find("ORDER BY") + limit_idx = sql.upper().find("LIMIT") + + # Find table name end position + table_end = min(x for x in [where_idx, group_idx, order_idx, limit_idx] if x != -1) if any(x != -1 for x in [where_idx, group_idx, order_idx, limit_idx]) else len(sql) + table_name = sql[from_idx + 4:table_end].strip() + + # Parse WHERE clause + where_clause = None + if where_idx != -1: + where_end = min(x for x in [group_idx, order_idx, limit_idx] if x != -1) if any(x != -1 for x in [group_idx, order_idx, limit_idx]) else len(sql) + where_clause = sql[where_idx + 5:where_end].strip() + + # Parse GROUP BY clause + group_by = None + if group_idx != -1: + group_end = min(x for x in [order_idx, limit_idx] if x != -1) if any(x != -1 for x in [order_idx, limit_idx]) else len(sql) + group_by_str = sql[group_idx + 8:group_end].strip() + group_by = [col.strip() for col in group_by_str.split(",")] + + # Parse ORDER BY clause + order_by = None + if order_idx != -1: + order_end = limit_idx if limit_idx != -1 else len(sql) + order_str = sql[order_idx + 8:order_end].strip() + order_parts = order_str.split(",") + order_by = [] + for part in order_parts: + part = part.strip() + if " DESC" in part.upper(): + column = part[:part.upper().find(" DESC")].strip() + direction = OrderDirection.DESC + else: + column = part.replace(" ASC", "").strip() + direction = OrderDirection.ASC + order_by.append(OrderByClause(column=column, direction=direction)) + + # Parse LIMIT clause + limit = None + if limit_idx != -1: + limit_str = sql[limit_idx + 5:].strip() + try: + limit = int(limit_str) + except ValueError: + raise ValueError(f"Invalid LIMIT value: {limit_str}") + + return SelectStatement( + table_name=table_name, + columns=columns, + where_clause=where_clause, + group_by=group_by, + order_by=order_by, + limit=limit + ) \ No newline at end of file diff --git a/pyflaredb/sql/statistics.py b/pyflaredb/sql/statistics.py new file mode 100644 index 0000000..1f2bef3 --- /dev/null +++ b/pyflaredb/sql/statistics.py @@ -0,0 +1,38 @@ +from typing import Dict, Any +import numpy as np + +from pyflaredb.table import Table + + +class TableStatistics: + def __init__(self): + self.table_sizes: Dict[str, int] = {} + self.column_stats: Dict[str, Dict[str, Any]] = {} + + 
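+    # Statistics are recomputed from the table's full contents on each call; in this commit PyFlareDB.create_table invokes collect_statistics once, when a table is registered.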
def collect_statistics(self, table: "Table"): + """Collect statistics for a table""" + self.table_sizes[table.name] = len(table.data) + + for column in table.columns: + values = [row[column.name] for row in table.data if column.name in row] + + if not values: + continue + + stats = { + "distinct_count": len(set(values)), + "null_count": sum(1 for v in values if v is None), + "min": min(values) if values and None not in values else None, + "max": max(values) if values and None not in values else None, + } + + if isinstance(values[0], (int, float)): + stats.update( + { + "mean": np.mean(values), + "std_dev": np.std(values), + "histogram": np.histogram(values, bins=100), + } + ) + + self.column_stats[f"{table.name}.{column.name}"] = stats diff --git a/pyflaredb/table.py b/pyflaredb/table.py new file mode 100644 index 0000000..d0f166c --- /dev/null +++ b/pyflaredb/table.py @@ -0,0 +1,191 @@ +from typing import Dict, List, Any, Optional +from dataclasses import dataclass +from datetime import datetime +from collections import defaultdict +from .indexing.btree import BTreeIndex + +@dataclass +class Column: + name: str + data_type: str + nullable: bool = True + unique: bool = False + primary_key: bool = False + default: Any = None + +class Table: + def __init__(self, name: str, columns: List[Column]): + self.name = name + self.columns = columns + self.data: List[Dict[str, Any]] = [] + self._unique_indexes: Dict[str, Dict[Any, int]] = defaultdict(dict) + self._compiled_conditions = {} + self._indexes: Dict[str, BTreeIndex] = {} + + # Validate column definitions + self._validate_columns() + + def _validate_columns(self): + """Validate column definitions""" + # Ensure only one primary key + primary_keys = [col for col in self.columns if col.primary_key] + if len(primary_keys) > 1: + raise ValueError("Table can only have one primary key") + + # Validate data types + valid_types = {"string", "integer", "float", "boolean", "datetime"} + for col in self.columns: + if col.data_type.lower() not in valid_types: + raise ValueError(f"Invalid data type for column {col.name}: {col.data_type}") + + def create_index(self, column_name: str) -> None: + """Create a B-tree index for a column""" + if column_name not in {col.name for col in self.columns}: + raise ValueError(f"Column {column_name} does not exist") + + # Create new index + index = BTreeIndex() + + # Build index from existing data + for row_id, row in enumerate(self.data): + if column_name in row: + index.insert(row[column_name], row_id) + + self._indexes[column_name] = index + + def batch_insert(self, rows: List[Dict[str, Any]]) -> bool: + """Efficiently insert multiple rows with index updates""" + # Pre-validate all rows + validated_rows = [] + unique_values = defaultdict(set) + + # Check unique constraints across all new rows + for row in rows: + converted_row = {} + # Validate required columns and defaults + for column in self.columns: + if not column.nullable and column.name not in row and column.default is None: + raise ValueError(f"Required column {column.name} is missing") + + value = row.get(column.name, column.default) + + # Type conversion + if value is not None: + try: + if column.data_type == "integer": + value = int(value) + elif column.data_type == "float": + value = float(value) + elif column.data_type == "boolean": + value = bool(value) + else: # string and datetime + value = str(value) + except (ValueError, TypeError): + raise ValueError(f"Invalid value for column {column.name}: {value}") + + converted_row[column.name] = value + + # 
Track unique values + if column.unique and value is not None: + if value in unique_values[column.name] or value in self._unique_indexes[column.name]: + raise ValueError(f"Unique constraint violated for column {column.name}") + unique_values[column.name].add(value) + + validated_rows.append(converted_row) + + # All rows validated, perform batch insert + start_id = len(self.data) + for i, row in enumerate(validated_rows): + row_id = start_id + i + + # Update indexes + for column_name, index in self._indexes.items(): + if column_name in row: + index.insert(row[column_name], row_id) + + # Update unique indexes + for column in self.columns: + if column.unique: + value = row.get(column.name) + if value is not None: + self._unique_indexes[column.name][value] = row_id + + self.data.append(row) + + return True + + def insert(self, row: Dict[str, Any]) -> bool: + """Insert a single row (now uses batch_insert)""" + return self.batch_insert([row]) + + def to_dict(self) -> dict: + """Convert table to dictionary for serialization""" + return { + "name": self.name, + "columns": [ + { + "name": col.name, + "data_type": col.data_type, + "nullable": col.nullable, + "unique": col.unique, + "primary_key": col.primary_key + } + for col in self.columns + ], + "data": self.data + } + + @classmethod + def from_dict(cls, data: dict) -> 'Table': + """Create table from dictionary""" + columns = [ + Column(**col_data) + for col_data in data["columns"] + ] + table = cls(data["name"], columns) + table.data = data["data"] + return table + + def _validate_type(self, value: Any, expected_type: str) -> bool: + """Validate that a value matches the expected data type""" + type_mapping = { + "string": str, + "integer": int, + "float": float, + "boolean": bool, + "datetime": datetime, + } + + if expected_type not in type_mapping: + raise ValueError(f"Unsupported data type: {expected_type}") + + expected_python_type = type_mapping[expected_type] + + if not isinstance(value, expected_python_type): + try: + # Attempt to convert the value + expected_python_type(value) + except (ValueError, TypeError): + raise ValueError( + f"Value {value} is not of expected type {expected_type}" + ) + + return True + + def find_by_index(self, column_name: str, value: Any) -> List[Dict[str, Any]]: + """Find rows using an index""" + if column_name not in self._indexes: + raise ValueError(f"No index exists for column {column_name}") + + index = self._indexes[column_name] + row_ids = index.search(value) + return [self.data[row_id] for row_id in row_ids] + + def range_search(self, column_name: str, start_value: Any, end_value: Any) -> List[Dict[str, Any]]: + """Perform a range search using an index""" + if column_name not in self._indexes: + raise ValueError(f"No index exists for column {column_name}") + + index = self._indexes[column_name] + row_ids = index.range_search(start_value, end_value) + return [self.data[row_id] for row_id in row_ids] \ No newline at end of file diff --git a/pyflaredb/transaction.py b/pyflaredb/transaction.py new file mode 100644 index 0000000..8b78885 --- /dev/null +++ b/pyflaredb/transaction.py @@ -0,0 +1,81 @@ +from typing import Dict, Any, Optional, Set, List +from enum import Enum +import time +import uuid +import threading + + +class TransactionState(Enum): + ACTIVE = "active" + COMMITTED = "committed" + ROLLED_BACK = "rolled_back" + + +class Transaction: + def __init__(self, tx_id: str): + self.id = tx_id + self.state = TransactionState.ACTIVE + self.start_time = time.time() + self.locks: Set[str] = set() # Set of 
table names that are locked + self.changes: List[Dict[str, Any]] = ( + [] + ) # List of changes made during transaction + + +class TransactionManager: + def __init__(self): + self.transactions: Dict[str, Transaction] = {} + self.lock = threading.Lock() + + def begin_transaction(self) -> str: + """Start a new transaction""" + with self.lock: + tx_id = str(uuid.uuid4()) + self.transactions[tx_id] = Transaction(tx_id) + return tx_id + + def commit(self, tx_id: str) -> bool: + """Commit a transaction""" + with self.lock: + if tx_id not in self.transactions: + raise ValueError(f"Transaction {tx_id} not found") + + tx = self.transactions[tx_id] + if tx.state != TransactionState.ACTIVE: + raise ValueError(f"Transaction {tx_id} is not active") + + # Apply changes + tx.state = TransactionState.COMMITTED + + # Release locks + tx.locks.clear() + + return True + + def rollback(self, tx_id: str) -> bool: + """Rollback a transaction""" + with self.lock: + if tx_id not in self.transactions: + raise ValueError(f"Transaction {tx_id} not found") + + tx = self.transactions[tx_id] + if tx.state != TransactionState.ACTIVE: + raise ValueError(f"Transaction {tx_id} is not active") + + # Revert changes + tx.state = TransactionState.ROLLED_BACK + tx.changes.clear() + + # Release locks + tx.locks.clear() + + return True + + def get_transaction(self, tx_id: str) -> Optional[Transaction]: + """Get transaction by ID""" + return self.transactions.get(tx_id) + + def is_active(self, tx_id: str) -> bool: + """Check if a transaction is active""" + tx = self.get_transaction(tx_id) + return tx is not None and tx.state == TransactionState.ACTIVE diff --git a/pyflaredb/transaction/manager.py b/pyflaredb/transaction/manager.py new file mode 100644 index 0000000..49e99e9 --- /dev/null +++ b/pyflaredb/transaction/manager.py @@ -0,0 +1,63 @@ +from typing import Dict, List, Any +from enum import Enum +import threading +from datetime import datetime + +class TransactionState(Enum): + ACTIVE = "ACTIVE" + COMMITTED = "COMMITTED" + ROLLED_BACK = "ROLLED_BACK" + +class Transaction: + def __init__(self, id: str): + self.id = id + self.state = TransactionState.ACTIVE + self.changes: List[Dict[str, Any]] = [] + self.locks = set() + self.timestamp = datetime.utcnow() + +class TransactionManager: + def __init__(self): + self.transactions: Dict[str, Transaction] = {} + self.lock = threading.Lock() + + def begin_transaction(self) -> str: + """Start a new transaction""" + with self.lock: + tx_id = str(len(self.transactions) + 1) + self.transactions[tx_id] = Transaction(tx_id) + return tx_id + + def commit(self, tx_id: str): + """Commit a transaction""" + with self.lock: + if tx_id not in self.transactions: + raise ValueError(f"Transaction {tx_id} not found") + + tx = self.transactions[tx_id] + if tx.state != TransactionState.ACTIVE: + raise ValueError(f"Transaction {tx_id} is not active") + + # Apply changes + self._apply_changes(tx) + tx.state = TransactionState.COMMITTED + + def rollback(self, tx_id: str): + """Rollback a transaction""" + with self.lock: + if tx_id not in self.transactions: + raise ValueError(f"Transaction {tx_id} not found") + + tx = self.transactions[tx_id] + if tx.state != TransactionState.ACTIVE: + raise ValueError(f"Transaction {tx_id} is not active") + + # Discard changes + tx.changes.clear() + tx.state = TransactionState.ROLLED_BACK + + def _apply_changes(self, transaction: Transaction): + """Apply transaction changes""" + for change in transaction.changes: + # Implementation of applying changes + pass \ No newline at 
end of file diff --git a/pyflaredb/versioning.py b/pyflaredb/versioning.py new file mode 100644 index 0000000..9a40d39 --- /dev/null +++ b/pyflaredb/versioning.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, Optional + + +@dataclass +class Version: + timestamp: datetime + operation: str # 'INSERT', 'UPDATE', 'DELETE' + table_name: str + row_id: str + data: Dict[str, Any] + previous_version: Optional[str] = None # Hash of previous version + + +class VersionStore: + def __init__(self): + self.versions: List[Version] = [] + self.current_version: str = None # Hash of current version + + def add_version(self, version: Version): + """Add a new version to the store""" + version_hash = self._calculate_hash(version) + self.versions.append(version) + self.current_version = version_hash + + def get_state_at(self, timestamp: datetime) -> Dict[str, List[Dict[str, Any]]]: + """Reconstruct database state at given timestamp""" + state = {} + relevant_versions = [v for v in self.versions if v.timestamp <= timestamp] + + for version in relevant_versions: + if version.table_name not in state: + state[version.table_name] = [] + + if version.operation == "INSERT": + state[version.table_name].append(version.data) + elif version.operation == "DELETE": + state[version.table_name] = [ + row + for row in state[version.table_name] + if row["id"] != version.row_id + ] + elif version.operation == "UPDATE": + state[version.table_name] = [ + version.data if row["id"] == version.row_id else row + for row in state[version.table_name] + ] + + return state diff --git a/test.py b/test.py new file mode 100644 index 0000000..75b396c --- /dev/null +++ b/test.py @@ -0,0 +1,254 @@ +from pyflaredb.core import PyFlareDB +from pyflaredb.table import Column, Table +from pyflaredb.benchmark.suite import BenchmarkSuite +import time +from datetime import datetime +import random +import string +import json +from typing import List, Dict, Any + + +def generate_realistic_data(n: int) -> List[Dict[str, Any]]: + """Generate realistic test data""" + domains = ['gmail.com', 'yahoo.com', 'hotmail.com', 'outlook.com', 'company.com'] + cities = ['New York', 'London', 'Tokyo', 'Paris', 'Berlin', 'Sydney', 'Toronto'] + + data = [] + for i in range(n): + # Generate realistic username + username = f"{random.choice(string.ascii_lowercase)}{random.choice(string.ascii_lowercase)}" + username += ''.join(random.choices(string.ascii_lowercase + string.digits, k=random.randint(6, 12))) + + # Generate realistic email + email = f"{username}@{random.choice(domains)}" + + # Generate JSON metadata + metadata = { + "city": random.choice(cities), + "last_login": f"2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}", + "preferences": { + "theme": random.choice(["light", "dark", "system"]), + "notifications": random.choice([True, False]) + } + } + + data.append({ + "id": f"usr_{i:08d}", + "username": username, + "email": email, + "age": random.randint(18, 80), + "score": round(random.uniform(0, 100), 2), + "is_active": random.random() > 0.1, # 90% active users + "login_count": random.randint(1, 1000), + "metadata": json.dumps(metadata) + }) + + return data + + +def format_value(value): + """Format value based on its type""" + if isinstance(value, (float, int)): + return f"{value:.4f}" + return str(value) + + +def test_database_features(): + """Test all database features with realistic workloads""" + print("\n=== Starting Realistic Database Tests ===") + + # Initialize database + db = 
PyFlareDB("test.db") + + # 1. Create test table with realistic schema + print("\n1. Setting up test environment...") + users_table = Table( + name="users", + columns=[ + Column("id", "string", nullable=False, primary_key=True), + Column("username", "string", nullable=False, unique=True), + Column("email", "string", nullable=False), + Column("age", "integer", nullable=True), + Column("score", "float", nullable=True), + Column("is_active", "boolean", nullable=True, default=True), + Column("login_count", "integer", nullable=True, default=0), + Column("metadata", "string", nullable=True) # JSON string + ], + ) + db.tables["users"] = users_table + + # Create indexes for commonly queried fields + users_table.create_index("age") + users_table.create_index("score") + users_table.create_index("login_count") + + # 2. Performance Tests with Realistic Data + print("\n2. Running performance tests...") + + # Generate test data + test_data = generate_realistic_data(1000) # 1000 realistic records + + # Insert Performance (Single vs Batch) + print("\nInsert Performance:") + + # Single Insert (OLTP-style) + start_time = time.time() + for record in test_data[:100]: # Test with first 100 records + # Properly escape the metadata string + metadata_str = record['metadata'].replace("'", "''") + + # Format each value according to its type + values = [ + f"'{record['id']}'", # string + f"'{record['username']}'", # string + f"'{record['email']}'", # string + str(record['age']), # integer + str(record['score']), # float + str(record['is_active']).lower(), # boolean + str(record['login_count']), # integer + f"'{metadata_str}'" # string (JSON) + ] + + query = f""" + INSERT INTO users + (id, username, email, age, score, is_active, login_count, metadata) + VALUES + ({', '.join(values)}) + """ + db.execute(query) + single_insert_time = time.time() - start_time + print(f"Single Insert (100 records, OLTP): {single_insert_time:.4f}s") + + # Batch Insert (OLAP-style) + start_time = time.time() + batch_data = test_data[100:200] # Next 100 records + users_table.batch_insert(batch_data) # This should work as is + batch_insert_time = time.time() - start_time + print(f"Batch Insert (100 records, OLAP): {batch_insert_time:.4f}s") + + # 3. 
Query Performance Tests + print("\nQuery Performance (OLTP vs OLAP):") + + # OLTP-style queries (point queries, simple filters) + oltp_queries = [ + ("Single Record Lookup", "SELECT * FROM users WHERE id = 'usr_00000001'"), + ("Simple Range Query", "SELECT * FROM users WHERE age > 30 LIMIT 10"), + ("Active Users Count", "SELECT COUNT(*) FROM users WHERE is_active = true"), + ("Recent Logins", "SELECT * FROM users WHERE login_count > 500 ORDER BY login_count DESC LIMIT 5") + ] + + # OLAP-style queries (aggregations, complex filters) + olap_queries = [ + ("Age Distribution", """ + SELECT + CASE + WHEN age < 25 THEN 'Gen Z' + WHEN age < 40 THEN 'Millennial' + WHEN age < 55 THEN 'Gen X' + ELSE 'Boomer' + END as generation, + COUNT(*) as count + FROM users + GROUP BY generation + """), + ("User Engagement", """ + SELECT + username, + score, + login_count + FROM users + WHERE score > 75 + AND login_count > 100 + ORDER BY score DESC + LIMIT 10 + """), + ("Complex Analytics", """ + SELECT + COUNT(*) as total_users, + AVG(score) as avg_score, + SUM(CASE WHEN is_active THEN 1 ELSE 0 END) as active_users + FROM users + WHERE age BETWEEN 25 AND 45 + """) + ] + + print("\nOLTP Query Performance:") + for query_name, query in oltp_queries: + # First run (cold) + start_time = time.time() + db.execute(query) + cold_time = time.time() - start_time + + # Second run (warm/cached) + start_time = time.time() + db.execute(query) + warm_time = time.time() - start_time + + print(f"\n{query_name}:") + print(f" Cold run: {cold_time:.4f}s") + print(f" Warm run: {warm_time:.4f}s") + print(f" Cache improvement: {((cold_time - warm_time) / cold_time * 100):.1f}%") + + print("\nOLAP Query Performance:") + for query_name, query in olap_queries: + start_time = time.time() + db.execute(query) + execution_time = time.time() - start_time + print(f"\n{query_name}: {execution_time:.4f}s") + + # 4. Concurrent Operations Test + print("\nConcurrent Operations Simulation:") + start_time = time.time() + # Simulate mixed workload + for _ in range(100): + if random.random() < 0.8: # 80% reads + query = random.choice(oltp_queries)[1] + else: # 20% writes + record = generate_realistic_data(1)[0] + query = f""" + INSERT INTO users (id, username, email, age, score, is_active, login_count, metadata) + VALUES ( + '{record['id']}', + '{record['username']}', + '{record['email']}', + {record['age']}, + {record['score']}, + {str(record['is_active']).lower()}, + {record['login_count']}, + '{record['metadata']}' + ) + """ + db.execute(query) + mixed_workload_time = time.time() - start_time + print(f"Mixed Workload (100 operations): {mixed_workload_time:.4f}s") + + # 5. Memory Usage Test + print("\nMemory Usage:") + import sys + memory_size = sys.getsizeof(db.tables["users"].data) / 1024 # KB + records_count = len(db.tables["users"].data) + print(f"Memory per record: {(memory_size / records_count):.2f} KB") + + # 6. Run standard benchmark suite + print("\n6. Running standard benchmark suite...") + benchmark = BenchmarkSuite(db) + results = benchmark.run_benchmark(num_records=10000) + + print("\nBenchmark Results:") + for test_name, metrics in results.items(): + print(f"\n{test_name.upper()}:") + for metric, value in metrics.items(): + print(f" {metric}: {format_value(value)}") + + +def main(): + try: + test_database_features() + except Exception as e: + print(f"Test failed: {e}") + raise e + + +if __name__ == "__main__": + main()
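The commit ships no README or usage docs, so below is a minimal usage sketch pieced together from the modules above (core.PyFlareDB, table.Table/Column, the SQL parser and executor, the B-tree index, and the transaction manager). The database, table, and column names are illustrative, and only the statement forms the parser handles in this commit (simple INSERT ... VALUES and SELECT with WHERE/ORDER BY/LIMIT) are used.

from pyflaredb import PyFlareDB, Table, Column

# Open a database handle (db_path is stored, but nothing is written to disk yet in this commit).
db = PyFlareDB("example.db")

# Define and register a table; create_table also collects initial column statistics.
people = Table(
    name="people",
    columns=[
        Column("id", "string", nullable=False, primary_key=True),
        Column("name", "string", nullable=False),
        Column("age", "integer", nullable=True),
    ],
)
db.create_table(people)

# Autocommit-style inserts (only simple INSERT ... VALUES is supported by the parser).
# Note: the parser keeps the surrounding quotes on string literals, so stored strings
# include them in this commit.
db.execute("INSERT INTO people (id, name, age) VALUES ('p1', 'Ada', 36)")
db.execute("INSERT INTO people (id, name, age) VALUES ('p2', 'Bob', 24)")

# SELECT with projection and a WHERE clause (evaluated here by a table scan),
# plus the COUNT(*) fast path in the executor.
print(db.execute("SELECT name, age FROM people WHERE age > 30"))
print(db.execute("SELECT COUNT(*) FROM people"))

# B-tree index on a column; point and range lookups can then be done directly on the Table.
people.create_index("age")
print(people.find_by_index("age", 36))
print(people.range_search("age", 20, 40))

# Transactions: statements executed with a tx_id take table locks and record their changes.
tx_id = db.begin_transaction()
try:
    db.execute("INSERT INTO people (id, name, age) VALUES ('p3', 'Cleo', 51)", tx_id)
    db.commit_transaction(tx_id)
except Exception:
    db.rollback_transaction(tx_id)
    raise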
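A side note on pyflaredb/cache/query_cache.py: PyFlareDB.execute keeps its own plain-dict cache and does not import QueryCache in this commit, so the TTL/LRU cache is standalone for now. A small sketch of driving it directly follows; the cached rows are illustrative values, not output from the commit.

from pyflaredb.cache.query_cache import QueryCache

# LRU cache holding up to 100 query results, each entry valid for 60 seconds.
cache = QueryCache(capacity=100, ttl=60)

sql = "SELECT name, age FROM people WHERE age > 30"
if cache.get(sql) is None:  # miss, or the entry's TTL has expired
    cache.set(sql, [{"name": "Ada", "age": 36}])  # normally the executor's result rows
print(cache.get(sql))

# Entries are keyed by a hash of the SQL text only, so writes should be followed by clear().
cache.clear()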