Source code for quantumengine.backends.surrealdb

"""SurrealDB backend implementation for SurrealEngine."""

import uuid
from typing import Any, Dict, List, Optional, Type

from surrealdb import RecordID

from .base import BaseBackend


class SurrealDBBackend(BaseBackend):
    """SurrealDB backend implementation.

    This backend implements the BaseBackend interface for SurrealDB,
    providing all the core database operations using SurrealQL.

    Most operations assemble a SurrealQL string and send it through the
    connection's client ``query`` method; document inserts call the
    client's ``create``/``update``/``insert`` methods directly.
    """
def __init__(self, connection: Any) -> None:
    """Bind the backend to a SurrealEngine connection.

    Args:
        connection: SurrealEngine connection (async or sync). Its
            ``client`` attribute is stored on the backend and used for
            every database call.
    """
    super().__init__(connection)

    client = connection.client
    self.client = client

    # NOTE(review): despite its name, this flag only records that the client
    # exposes a ``create`` attribute; nothing here checks whether the client
    # is actually asynchronous.
    connection_has_client = hasattr(connection, 'client')
    self.is_async = connection_has_client and hasattr(client, 'create')
[docs] async def create_table(self, document_class: Type, **kwargs) -> None: """Create a table/collection for the document class. Args: document_class: The document class to create a table for **kwargs: Backend-specific options: - schemafull: Whether to create a schemafull table (default: True) """ table_name = document_class._meta.get('collection') schemafull = kwargs.get('schemafull', True) # Create table definition schema_type = "SCHEMAFULL" if schemafull else "SCHEMALESS" query = f"DEFINE TABLE {table_name} {schema_type}" await self._execute(query) # Define fields if schemafull if schemafull: for field_name, field in document_class._fields.items(): if field_name == document_class._meta.get('id_field', 'id'): continue # Skip ID field field_type = self.get_field_type(field) field_query = f"DEFINE FIELD {field.db_field} ON {table_name} TYPE {field_type}" if field.required: field_query += " ASSERT $value != NONE" await self._execute(field_query) # Create indexes indexes = document_class._meta.get('indexes', []) for index in indexes: if isinstance(index, str): # Simple field index index_query = f"DEFINE INDEX idx_{index} ON {table_name} COLUMNS {index}" elif isinstance(index, dict): # Complex index index_name = index.get('name', f"idx_{'_'.join(index['fields'])}") fields = ', '.join(index['fields']) index_query = f"DEFINE INDEX {index_name} ON {table_name} COLUMNS {fields}" if index.get('unique'): index_query += " UNIQUE" else: continue await self._execute(index_query)
[docs] async def insert(self, table_name: str, data: Dict[str, Any]) -> Dict[str, Any]: """Insert a single document. Args: table_name: The table name data: The document data to insert Returns: The inserted document with any generated fields """ # Format data for SurrealDB formatted_data = self._format_document_data(data) if 'id' in formatted_data and formatted_data['id']: # Use CREATE with specific ID for new records, UPDATE for existing ones record_id = formatted_data.pop('id') if not isinstance(record_id, RecordID): if ':' in str(record_id): # Split table and id parts parts = str(record_id).split(':', 1) # Check if the id part is numeric try: # If it's numeric, convert to int for proper RecordID format record_id = RecordID(parts[0], int(parts[1])) except ValueError: # If not numeric, keep as string record_id = RecordID(parts[0], parts[1]) else: # Check if the id is numeric try: record_id = RecordID(table_name, int(record_id)) except ValueError: record_id = RecordID(table_name, record_id) # Try CREATE first (for new records), fallback to UPDATE if it exists try: result = await self.client.create(record_id, formatted_data) except Exception as e: if 'already exists' in str(e): # Record exists, use UPDATE instead result = await self.client.update(record_id, formatted_data) else: raise e else: # Use CREATE without ID (auto-generate) result = await self.client.create(table_name, formatted_data) if result: # Result can be either a dict (single record) or a list (multiple records) if isinstance(result, list): return self._format_result_data(result[0]) if result else data else: return self._format_result_data(result) else: return data
[docs] async def insert_many(self, table_name: str, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Insert multiple documents efficiently. Args: table_name: The table name data: List of documents to insert Returns: List of inserted documents """ if not data: return [] results = [] # Group by documents with and without IDs docs_with_id = [] docs_without_id = [] for doc in data: formatted_doc = self._format_document_data(doc) if 'id' in formatted_doc and formatted_doc['id']: docs_with_id.append(formatted_doc) else: docs_without_id.append(formatted_doc) # Insert documents without IDs (bulk create) if docs_without_id: batch_results = await self.client.insert(table_name, docs_without_id) if batch_results: results.extend([self._format_result_data(r) for r in batch_results]) # Insert documents with IDs (individual creates) for doc in docs_with_id: record_id = doc.pop('id') if not isinstance(record_id, RecordID): if ':' in str(record_id): record_id = RecordID(record_id) else: record_id = RecordID(table_name, record_id) result = await self.client.create(record_id, doc) if result and len(result) > 0: results.append(self._format_result_data(result[0])) return results
[docs] async def select(self, table_name: str, conditions: List[str], fields: Optional[List[str]] = None, limit: Optional[int] = None, offset: Optional[int] = None, order_by: Optional[List[tuple[str, str]]] = None) -> List[Dict[str, Any]]: """Select documents from a table. Args: table_name: The table name conditions: List of condition strings fields: List of fields to return (None for all) limit: Maximum number of results offset: Number of results to skip (START in SurrealDB) order_by: List of (field, direction) tuples Returns: List of matching documents """ # Build SELECT clause if fields: select_clause = ", ".join(fields) else: select_clause = "*" query = f"SELECT {select_clause} FROM {table_name}" # Add WHERE clause if conditions: query += f" WHERE {' AND '.join(conditions)}" # Add ORDER BY clause if order_by: order_parts = [] for field, direction in order_by: order_parts.append(f"{field} {direction.upper()}") query += f" ORDER BY {', '.join(order_parts)}" # Add LIMIT clause if limit: query += f" LIMIT {limit}" # Add START clause (SurrealDB's equivalent to OFFSET) if offset: query += f" START {offset}" result = await self._query(query) if result: # The SurrealDB Python client returns SELECT results as a plain list of dicts if isinstance(result, list): # If it's a list of dicts (documents), format each one if result and isinstance(result[0], dict): return [self._format_result_data(doc) for doc in result] # Empty list case return [] return []
async def select_by_ids(self, table_name: str, ids: List[Any]) -> List[Dict[str, Any]]:
    """Fetch documents by ID using direct record access.

    Args:
        table_name: The table name, used to qualify IDs that do not
            already carry a ``table:`` prefix
        ids: List of IDs to select (RecordID objects, ``"table:key"``
            strings, or bare keys)

    Returns:
        List of matching documents
    """
    if not ids:
        return []

    def as_target(id_val: Any) -> str:
        # RecordIDs and already-qualified strings are used as they are;
        # everything else is qualified with the table name.
        if isinstance(id_val, RecordID):
            return str(id_val)
        if isinstance(id_val, str) and ':' in id_val:
            return id_val
        return f"{table_name}:{id_val}"

    targets = ', '.join(as_target(id_val) for id_val in ids)
    result = await self._query(f"SELECT * FROM {targets}")

    # The SurrealDB Python client returns SELECT results as a plain list of
    # dicts; any other shape is treated as "nothing found".
    is_document_list = (
        bool(result) and isinstance(result, list) and isinstance(result[0], dict)
    )
    if not is_document_list:
        return []
    return [self._format_result_data(doc) for doc in result]
[docs] async def count(self, table_name: str, conditions: List[str]) -> int: """Count documents matching conditions. Args: table_name: The table name conditions: List of condition strings Returns: Number of matching documents """ query = f"SELECT count() FROM {table_name}" if conditions: query += f" WHERE {' AND '.join(conditions)}" result = await self._query(query) if result and isinstance(result, list): # SurrealDB returns a list of dicts with count for each record # Sum all the counts total_count = sum(item.get('count', 0) for item in result if isinstance(item, dict)) return total_count return 0
[docs] async def update(self, table_name: str, conditions: List[str], data: Dict[str, Any]) -> List[Dict[str, Any]]: """Update documents matching conditions. Args: table_name: The table name conditions: List of condition strings data: The fields to update Returns: List of updated documents """ # Format update data formatted_data = self._format_document_data(data) # Build UPDATE query query = f"UPDATE {table_name}" if conditions: query += f" WHERE {' AND '.join(conditions)}" # Add SET clause set_parts = [] for key, value in formatted_data.items(): set_parts.append(f"{key} = {self.format_value(value)}") if set_parts: query += f" SET {', '.join(set_parts)}" result = await self._query(query) if result and len(result) > 0: return [self._format_result_data(doc) for doc in result[0]] return []
[docs] async def delete(self, table_name: str, conditions: List[str]) -> int: """Delete documents matching conditions. Args: table_name: The table name conditions: List of condition strings Returns: Number of deleted documents """ query = f"DELETE FROM {table_name}" if conditions: query += f" WHERE {' AND '.join(conditions)}" result = await self._query(query) if result and len(result) > 0: return len(result[0]) return 0
[docs] async def drop_table(self, table_name: str, if_exists: bool = True) -> None: """Drop a table using SurrealDB's REMOVE TABLE statement. Args: table_name: The table name to drop if_exists: Whether to use IF EXISTS clause to avoid errors if table doesn't exist """ if if_exists: query = f"REMOVE TABLE IF EXISTS {table_name}" else: query = f"REMOVE TABLE {table_name}" await self._execute(query)
[docs] async def execute_raw(self, query: str, params: Optional[Dict[str, Any]] = None) -> Any: """Execute a raw SurrealQL query. Args: query: The raw SurrealQL query string params: Optional query parameters (not used - SurrealDB handles this) Returns: Query result """ return await self._query(query)
[docs] def build_condition(self, field: str, operator: str, value: Any) -> str: """Build a condition string for SurrealQL. Args: field: The field name operator: The operator (=, !=, >, <, >=, <=, ~, !~, etc.) value: The value to compare against Returns: A condition string in SurrealQL """ # Special handling for 'id' field - convert string to RecordID if needed if field == 'id' and isinstance(value, str) and ':' in value: # Convert string ID like "users:abc123" to RecordID parts = value.split(':', 1) value = RecordID(parts[0], parts[1]) formatted_value = self.format_value(value) if operator == '=': return f"{field} = {formatted_value}" elif operator == '!=': return f"{field} != {formatted_value}" elif operator in ['>', '<', '>=', '<=']: return f"{field} {operator} {formatted_value}" elif operator == 'in': return f"{field} INSIDE {formatted_value}" elif operator == 'not in': return f"{field} NOT INSIDE {formatted_value}" elif operator == 'contains': return f"{field} CONTAINS {formatted_value}" elif operator == 'containsnot': return f"{field} CONTAINSNOT {formatted_value}" elif operator == 'containsall': return f"{field} CONTAINSALL {formatted_value}" elif operator == 'containsany': return f"{field} CONTAINSANY {formatted_value}" elif operator == 'containsnone': return f"{field} CONTAINSNONE {formatted_value}" elif operator == '~': return f"{field} ~ {formatted_value}" elif operator == '!~': return f"{field} !~ {formatted_value}" elif operator == 'is null': return f"{field} IS NULL" elif operator == 'is not null': return f"{field} IS NOT NULL" else: return f"{field} {operator} {formatted_value}"
def get_field_type(self, field: Any) -> str:
    """Map a SurrealEngine field instance to a SurrealDB type name.

    Args:
        field: A SurrealEngine field instance

    Returns:
        The corresponding SurrealDB field type, or ``"any"`` when the
        field is none of the recognised classes
    """
    # Imported here to avoid circular imports.
    from ..fields import (
        StringField, IntField, FloatField, BooleanField,
        DateTimeField, UUIDField, DictField, DecimalField
    )

    # Checked top to bottom; the first matching class wins.
    type_table = (
        (StringField, "string"),
        (IntField, "int"),
        (FloatField, "float"),
        (BooleanField, "bool"),
        (DateTimeField, "datetime"),
        (UUIDField, "uuid"),
        (DictField, "object"),
        (DecimalField, "decimal"),
    )
    for field_class, surreal_type in type_table:
        if isinstance(field, field_class):
            return surreal_type
    return "any"
def format_value(self, value: Any, field_type: Optional[str] = None) -> str:
    """Format a Python value as a SurrealQL literal.

    Args:
        value: The value to format
        field_type: Optional field type hint (currently unused)

    Returns:
        The formatted value as a string for SurrealQL
    """
    if value is None:
        return "NONE"
    if isinstance(value, str):
        # Escape backslashes first, then quotes: otherwise a value ending in
        # a backslash would escape the closing quote of the literal.
        escaped = value.replace('\\', '\\\\').replace('"', '\\"')
        return f'"{escaped}"'
    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        return "true" if value else "false"
    if isinstance(value, (int, float)):
        return str(value)
    if isinstance(value, RecordID):
        # Record IDs are rendered bare (table:key), not as quoted strings.
        return str(value)
    if isinstance(value, list):
        items = ', '.join(self.format_value(item) for item in value)
        return f"[{items}]"
    if isinstance(value, dict):
        pairs = ', '.join(f"{k}: {self.format_value(v)}" for k, v in value.items())
        return f"{{{pairs}}}"
    # Everything else, including uuid.UUID, is rendered as a quoted string.
    return f'"{str(value)}"'
# Transaction support
[docs] async def begin_transaction(self) -> Any: """Begin a transaction. Returns: Transaction object (SurrealDB client for now) """ # SurrealDB doesn't have explicit transaction syntax like BEGIN # Transactions are implicit within query batches return self.client
async def commit_transaction(self, transaction: Any) -> None:
    """Accept a commit request without doing anything.

    Args:
        transaction: The transaction object (ignored)
    """
    # Intentionally a no-op, kept so the backend interface is complete:
    # this backend sends each statement on its own and holds nothing to commit.
    return None
async def rollback_transaction(self, transaction: Any) -> None:
    """Accept a rollback request without doing anything.

    Args:
        transaction: The transaction object (ignored)
    """
    # Intentionally a no-op: nothing is undone here, so callers that need a
    # real rollback must implement it at the application level.
    return None
def supports_transactions(self) -> bool:
    """Report transaction support (SurrealDB: within query batches).

    NOTE(review): this backend's begin/commit/rollback methods are
    placeholders that send nothing, so callers should not rely on this
    flag for atomicity.
    """
    transactions_supported = True
    return transactions_supported
def supports_references(self) -> bool:
    """Report that references between records are supported."""
    references_supported = True
    return references_supported
def supports_graph_relations(self) -> bool:
    """Report that native graph relations are supported."""
    graph_relations_supported = True
    return graph_relations_supported
def supports_direct_record_access(self) -> bool:
    """Report that direct record access syntax (``table:id``) is supported."""
    direct_access_supported = True
    return direct_access_supported
def supports_explain(self) -> bool:
    """Report that EXPLAIN queries are supported."""
    explain_supported = True
    return explain_supported
def supports_indexes(self) -> bool:
    """Report that indexes are supported."""
    indexes_supported = True
    return indexes_supported
def supports_bulk_operations(self) -> bool:
    """Report that bulk operations are supported."""
    bulk_supported = True
    return bulk_supported
def get_optimized_methods(self) -> Dict[str, str]:
    """Return example snippets for SurrealDB-specific optimizations.

    Returns:
        Mapping from optimization name to an illustrative SurrealQL
        snippet or function list.
    """
    examples = (
        ('direct_record_access', 'SELECT * FROM user:1, user:2, user:3'),
        ('range_access', 'SELECT * FROM user:1..=100'),
        ('graph_traversal', 'SELECT * FROM user:1->likes->post'),
        ('string_functions', 'string::contains(), string::starts_with()'),
    )
    return dict(examples)
# Graph/Relation implementations
async def create_relation(self, from_table: str, from_id: str, relation_name: str, to_table: str, to_id: str, attributes: Optional[Dict[str, Any]] = None) -> Optional[Dict[str, Any]]:
    """Create a graph edge with SurrealDB's RELATE statement.

    Args:
        from_table: Source table name
        from_id: Source document ID
        relation_name: Name of the relation
        to_table: Target table name
        to_id: Target document ID
        attributes: Optional attributes for the relation; values are
            serialised with ``json.dumps``

    Returns:
        The created relation record, or None when the client returns
        nothing usable
    """
    source = RecordID(from_table, from_id)
    target = RecordID(to_table, to_id)
    statement = f"RELATE {source}->{relation_name}->{target}"

    if attributes:
        import json
        pairs = [f"{name}: {json.dumps(val)}" for name, val in attributes.items()]
        content = ", ".join(pairs)
        statement += f" CONTENT {{ {content} }}"

    result = await self._query(statement)

    # Only a non-empty list counts; its first element is the new relation.
    if not result or not isinstance(result, list):
        return None
    return self._format_result_data(result[0])
async def delete_relation(self, from_table: str, from_id: str, relation_name: str, to_table: Optional[str] = None, to_id: Optional[str] = None) -> int:
    """Delete graph edges using SurrealDB's DELETE statement.

    Args:
        from_table: Source table name
        from_id: Source document ID
        relation_name: Name of the relation
        to_table: Target table name (optional)
        to_id: Target document ID (optional)

    Returns:
        Number of relations deleted
    """
    from_record = RecordID(from_table, from_id)

    # Target the edge records leaving the source document.
    query = f"DELETE {from_record}->{relation_name}"
    if to_table and to_id:
        # Narrow to the edge pointing at one specific target: an edge's
        # ``out`` field holds the record it points to.
        to_record = RecordID(to_table, to_id)
        query += f" WHERE out = {to_record}"
    # DELETE returns nothing by default; request the removed edges so they
    # can be counted.
    query += " RETURN BEFORE"

    result = await self._query(query)
    if not result:
        return 0
    if not isinstance(result, list):
        return 1
    # Count rows whether the client returns them flat or nested in one
    # per-statement list.
    rows = result[0] if isinstance(result[0], list) else result
    return len(rows)
async def query_relations(self, from_table: str, from_id: str, relation_name: str, direction: str = 'out') -> List[Dict[str, Any]]:
    """Query relations using SurrealDB's graph traversal.

    Args:
        from_table: Source table name
        from_id: Source document ID
        relation_name: Name of the relation
        direction: Direction of relations ('out', 'in', 'both')

    Returns:
        List of related documents

    Raises:
        ValueError: If ``direction`` is not 'out', 'in' or 'both'
    """
    arrows = {'out': '->', 'in': '<-', 'both': '<->'}
    # Validate before doing any work so a bad argument never reaches the database.
    if direction not in arrows:
        raise ValueError(f"Invalid direction: {direction}. Must be 'out', 'in', or 'both'")

    from_record = RecordID(from_table, from_id)
    query = f"SELECT * FROM {from_record}{arrows[direction]}{relation_name}"

    result = await self._query(query)
    if not result:
        return []

    # The client returns SELECT rows as a flat list of dicts; also accept a
    # single record or one nested per-statement list.
    if isinstance(result, dict):
        rows = [result]
    elif isinstance(result[0], list):
        rows = result[0]
    else:
        rows = result
    return [self._format_result_data(row) for row in rows]
# Helper methods
async def _execute(self, query: str) -> None:
    """Send a statement through the client and discard its result."""
    await self.client.query(query)

async def _query(self, query: str) -> Any:
    """Send a statement through the client and return its raw result."""
    return await self.client.query(query)

def _format_document_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``data`` with values converted for SurrealDB storage.

    Values exposing ``to_db()`` are replaced by its result, ``Decimal``
    values are converted to ``float``, and everything else is passed
    through unchanged. Only top-level values are converted.
    """
    from decimal import Decimal

    def convert(value: Any) -> Any:
        if hasattr(value, 'to_db'):
            return value.to_db()
        if isinstance(value, Decimal):
            # Decimal is stored as float, which may lose precision.
            return float(value)
        return value

    return {key: convert(value) for key, value in data.items()}

def _convert_query_to_surrealdb(self, query: str) -> str:
    """Convert ClickHouse-style query syntax to SurrealDB.

    Args:
        query: The ClickHouse-style query

    Returns:
        The query with known function names rewritten to SurrealDB ones
    """
    import re

    # NOTE(review): COUNT(DISTINCT x) is replaced by the literal 1, which
    # discards the field entirely. This is a placeholder, not an
    # equivalent aggregate.
    converted = re.sub(
        r'COUNT\(DISTINCT\s+([^)]+)\)', '1', query, flags=re.IGNORECASE
    )

    # Plain, case-sensitive textual replacements, applied in this order.
    replacements = (
        ('toDate(', 'time::day('),
        ('toYYYYMM(', 'time::format('),
        ('COUNT(*)', 'count()'),
        ('SUM(', 'math::sum('),
        ('AVG(', 'math::mean('),
        ('MIN(', 'math::min('),
        ('MAX(', 'math::max('),
    )
    for clickhouse_func, surrealdb_func in replacements:
        converted = converted.replace(clickhouse_func, surrealdb_func)

    # toYYYYMM(x) became time::format(x); supply the format argument. This
    # relies on the column being aliased exactly "AS year_month".
    if 'time::format(' in converted:
        converted = converted.replace(
            ') AS year_month', ', "%Y%m") AS year_month'
        )

    return converted

def _format_result_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of a result row with RecordID values turned into strings.

    Non-dict input is returned unchanged. Only top-level values are converted.
    """
    if not isinstance(data, dict):
        return data
    return {
        key: str(value) if isinstance(value, RecordID) else value
        for key, value in data.items()
    }

# Materialized view support
[docs] async def create_materialized_view(self, materialized_document_class: Type) -> None: """Create a SurrealDB materialized view using DEFINE TABLE ... AS SELECT. Args: materialized_document_class: The MaterializedDocument class """ view_name = materialized_document_class._meta.get('view_name') or \ materialized_document_class._meta.get('table_name') or \ materialized_document_class.__name__.lower() # Build the source query source_query = materialized_document_class._build_source_query() # Convert ClickHouse-specific functions to SurrealDB equivalents source_query = self._convert_query_to_surrealdb(source_query) # SurrealDB materialized view syntax query = f"DEFINE TABLE {view_name} AS {source_query}" # Debug: Print the generated query print("Generated SurrealDB Materialized View SQL:") print(query) print("=" * 60) await self._execute(query)
[docs] async def drop_materialized_view(self, materialized_document_class: Type) -> None: """Drop a SurrealDB materialized view. Args: materialized_document_class: The MaterializedDocument class """ view_name = materialized_document_class._meta.get('view_name') or \ materialized_document_class._meta.get('table_name') or \ materialized_document_class.__name__.lower() query = f"REMOVE TABLE {view_name}" await self._execute(query)
async def refresh_materialized_view(self, materialized_document_class: Type) -> None:
    """Accept a refresh request without doing anything.

    Note: SurrealDB materialized views update automatically when data
    changes, so no statement is sent.

    Args:
        materialized_document_class: The MaterializedDocument class (ignored)
    """
    # Nothing to do: the view is kept current by the database itself.
    return None