Source code for surrealengine.query.base

import asyncio
import json
import logging
from typing import Any, Dict, List, Optional, Tuple, Type, Union, cast

from surrealdb import RecordID

from ..base_query import BaseQuerySet
from ..exceptions import DoesNotExist, MultipleObjectsReturned
from ..fields import ReferenceField
from ..surrealql import escape_literal

# Set up logging
logger = logging.getLogger(__name__)


class QuerySet(BaseQuerySet):
    """Query builder for SurrealDB.

    This class provides a query builder for document classes with a
    predefined schema. It extends BaseQuerySet with methods for querying
    and manipulating documents of a specific document class.

    Attributes:
        document_class: The document class to query
        connection: The database connection to use for queries
    """
    def __init__(self, document_class: Type, connection: Any) -> None:
        """Initialize a new QuerySet.

        Args:
            document_class: The document class to query
            connection: The database connection to use for queries
        """
        super().__init__(connection)
        self.document_class = document_class
    def traverse(self, path: str, max_depth: Optional[int] = None,
                 unique: bool = True) -> 'QuerySet':
        """Configure a graph traversal for this query.

        Args:
            path: Arrow path segment(s), e.g. "->likes->user" or "<-follows".
            max_depth: Optional depth bound. For simple single-edge paths the
                path is repeated up to max_depth times; for complex paths this
                is ignored and the path is used as-is. This is a pragmatic
                workaround until SurrealQL exposes native depth quantifiers
                in arrow paths.
            unique: When True, deduplicate results via GROUP BY id to avoid
                duplicate rows.

        Returns:
            A cloned QuerySet configured with traversal.
        """
        clone = self._clone()
        # Store as-is; _build_query applies simple bounded expansion
        clone._traversal_path = path
        clone._traversal_unique = bool(unique)
        clone._traversal_max_depth = (
            max_depth if (isinstance(max_depth, int) and max_depth > 0) else None
        )
        return clone
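    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). A single-edge path is repeated up to
    # max_depth times, so "->knows" with max_depth=2 expands to
    # "->knows->knows":
    #
    #     rows = await User.objects.traverse("->knows", max_depth=2).all()
    #     for row in rows:
    #         print(row["traversed"])  # raw traversal output from _build_query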
    def shortest_path(self, src: Union[str, RecordID], dst: Union[str, RecordID],
                      edge: str) -> 'QuerySet':
        """Helper for shortest-path queries (if supported by SurrealDB).

        Note:
            SurrealDB does not currently expose a stable built-in shortest-path
            function in SurrealQL, so this method raises NotImplementedError to
            document the limitation. If SurrealDB adds support, it can be
            updated to emit the proper function call.
        """
        raise NotImplementedError(
            "Shortest path is not supported via SurrealQL in this version. "
            "Track SurrealDB updates for native support."
        )
    async def live(self, where: Optional[Union["Q", dict]] = None,
                   action: Optional[Union[str, List[str]]] = None, *,
                   retry_limit: int = 3, initial_delay: float = 0.5,
                   backoff: float = 2.0):
        """Subscribe to changes on this table via LIVE queries as an async generator.

        This method provides real-time updates for table changes using
        SurrealDB's LIVE query functionality. It yields a LiveEvent object for
        each change (CREATE, UPDATE, DELETE) that occurs on the table.

        The underlying implementation uses the surrealdb async (websocket)
        client. If the current connection uses a connection pool client, which
        does not support LIVE, a NotImplementedError is raised.

        Args:
            where: Optional filter (Q or dict) applied client-side to incoming
                events. Only events matching this filter will be yielded.
            action: Optional action filter ('CREATE', 'UPDATE', 'DELETE') or a
                list of actions. Use this to subscribe to specific event types
                only.
            retry_limit: Number of times to retry the subscription on
                transient errors (default: 3).
            initial_delay: Initial backoff delay in seconds (default: 0.5).
            backoff: Multiplier for exponential backoff (default: 2.0).

        Yields:
            LiveEvent: Typed event objects with the following attributes:

            - action: Event type (CREATE, UPDATE, DELETE)
            - data: Dictionary containing the document fields
            - ts: Optional timestamp of the event
            - id: Optional RecordID of the affected document

        Raises:
            NotImplementedError: If the active connection does not support
                LIVE queries (e.g., when using connection pooling).

        Example::

            # Subscribe to all user creation events
            async for evt in User.objects.live(action="CREATE"):
                print(f"New user: {evt.id}")
                print(f"Data: {evt.data}")

            # Filter for specific conditions
            async for evt in User.objects.live(where={"status": "active"},
                                               action=["CREATE", "UPDATE"]):
                if evt.is_create:
                    print(f"Active user created: {evt.id}")
                elif evt.is_update:
                    print(f"Active user updated: {evt.id}")
        """
        # Import LiveEvent locally to avoid circular imports during module load
        from ..events import LiveEvent

        # Normalize the action filter
        allowed_actions = None
        if action:
            if isinstance(action, str):
                allowed_actions = {action.upper()}
            elif isinstance(action, (list, tuple, set)):
                allowed_actions = {a.upper() for a in action}

        # Ensure an async client exposing the LIVE API is available
        client = getattr(self.connection, 'client', None)
        if (client is None or not hasattr(client, 'live')
                or not hasattr(client, 'subscribe_live')
                or not hasattr(client, 'kill')):
            raise NotImplementedError(
                "LIVE queries require an async websocket client; connection "
                "pooling is not supported for LIVE in this version."
            )

        table = self.document_class._get_collection_name()

        # Prepare an optional predicate for client-side filtering
        predicate = None
        if where is not None:
            try:
                from ..query_expressions import Q
                if isinstance(where, dict):
                    q = Q(**where)
                elif isinstance(where, Q):
                    q = where
                else:
                    raise ValueError("where must be a Q or dict")
                # Build a simple predicate using Q.to_conditions semantics
                conditions = q.to_conditions()

                def _eval(record):
                    for field, op, value in conditions:
                        if field == '__raw__':
                            # Raw conditions cannot be evaluated here; accept all
                            continue
                        lhs = record.get(field)
                        if op == '=' and lhs != value:
                            return False
                        if op == '!=' and lhs == value:
                            return False
                        if op == '>' and not (lhs is not None and lhs > value):
                            return False
                        if op == '<' and not (lhs is not None and lhs < value):
                            return False
                        if op == '>=' and not (lhs is not None and lhs >= value):
                            return False
                        if op == '<=' and not (lhs is not None and lhs <= value):
                            return False
                        if (op == 'INSIDE' and isinstance(value, (list, tuple, set))
                                and lhs not in value):
                            return False
                        if (op == 'NOT INSIDE' and isinstance(value, (list, tuple, set))
                                and lhs in value):
                            return False
                        if op == 'CONTAINS':
                            if isinstance(lhs, str) and isinstance(value, str):
                                if value not in lhs:
                                    return False
                            elif isinstance(lhs, (list, tuple, set)):
                                if value not in lhs:
                                    return False
                            else:
                                return False
                    return True

                predicate = _eval
            except Exception:
                # If anything goes wrong, fall back to no filtering
                predicate = None

        attempt = 0
        delay = initial_delay

        async def _start_live():
            # Returns (uuid, ('queue', queue, under)) when we can read the full
            # event envelope, or (uuid, ('agen', agen, None)) for the SDK
            # generator fallback
            qid = await client.live(table)
            # Try to reach the underlying connection's live_queues to get full
            # event payloads (action/time/result)
            candidate_attrs = ('connection', '_connection', 'conn', '_conn', 'ws')
            under = None
            for attr in candidate_attrs:
                obj = getattr(client, attr, None)
                if obj is not None and hasattr(obj, 'live_queues'):
                    under = obj
                    break
            # Some SDK variants may expose live_queues on the client itself
            if under is None and hasattr(client, 'live_queues'):
                under = client  # type: ignore
            if under is not None and hasattr(under, 'live_queues'):
                full_queue: asyncio.Queue = asyncio.Queue()
                under.live_queues[str(qid)].append(full_queue)
                # Log limited client attributes for diagnosis at debug level
                try:
                    attrs = [a for a in dir(client)
                             if 'live' in a.lower() or 'conn' in a.lower()]
                    logger.debug('Client attributes (filtered): %s', attrs)
                except Exception:
                    pass
                return qid, ('queue', full_queue, under)
            # Fall back to the SDK generator, which yields only the inner 'result'
            agen = await client.subscribe_live(qid)
            # Log limited client attributes for diagnosis at debug level
            try:
                attrs = [a for a in dir(client)
                         if 'live' in a.lower() or 'conn' in a.lower()]
                logger.debug('Client attributes (filtered): %s', attrs)
            except Exception:
                pass
            return qid, ('agen', agen, None)

        qid = None
        agen = None
        extra = None
        try:
            while True:
                try:
                    if agen is None:
                        qid, packed = await _start_live()
                        kind, source, under = packed
                        agen = (kind, source)
                        extra = under
                        attempt = 0
                        delay = initial_delay
                    kind, source = agen
                    if kind == 'queue':
                        # Consume full payloads with action/time/result
                        while True:
                            msg = await source.get()
                            # Log the full live envelope at debug level to help
                            # locate fields
                            try:
                                logger.debug("Live envelope: %s", msg)
                            except Exception:
                                pass
                            action_str = (msg.get('action') or msg.get('event')
                                          or 'UNKNOWN')
                            action_upper = str(action_str).upper()
                            # Filter by action if requested
                            if allowed_actions and action_upper not in allowed_actions:
                                continue
                            data = (msg.get('result') or msg.get('record')
                                    or msg.get('data') or msg)
                            ts = msg.get('time') or msg.get('ts')
                            if predicate is None or (isinstance(data, dict)
                                                     and predicate(data)):
                                # Pass the timestamp through as-is
                                ts_val = ts
                                # Convert the ID to a RecordID if possible
                                id_val = None
                                if isinstance(data, dict) and 'id' in data:
                                    try:
                                        id_val = RecordID(str(data['id']))
                                    except Exception:
                                        pass
                                yield LiveEvent(action=action_upper, data=data,
                                                ts=ts_val, id=id_val)
                    else:
                        # agen path: yields only the inner result; no metadata
                        # is available
                        async for msg in source:
                            # Log the inner message yielded by subscribe_live
                            try:
                                logger.debug("Live inner message: %s", msg)
                            except Exception:
                                pass
                            data = (msg.get('result') or msg.get('record')
                                    or msg.get('data') or msg)
                            # Heuristic: a payload with only 'id' is a DELETE;
                            # anything else is an UPSERT (CREATE/UPDATE)
                            inferred_action = 'UNKNOWN'
                            if isinstance(data, dict):
                                keys = list(data.keys())
                                if len(keys) == 1 and keys[0] == 'id':
                                    inferred_action = 'DELETE'
                                else:
                                    inferred_action = 'UPSERT'
                            # Filter by action if requested
                            if allowed_actions:
                                # Start with an exact match on the inferred action
                                match = inferred_action in allowed_actions
                                # Allow UPSERT when CREATE or UPDATE is requested
                                if (not match and inferred_action == 'UPSERT'
                                        and ('CREATE' in allowed_actions
                                             or 'UPDATE' in allowed_actions)):
                                    match = True
                                if not match:
                                    continue
                            if predicate is None or (isinstance(data, dict)
                                                     and predicate(data)):
                                # Convert the ID to a RecordID if possible
                                id_val = None
                                if isinstance(data, dict) and 'id' in data:
                                    try:
                                        id_val = RecordID(str(data['id']))
                                    except Exception:
                                        pass
                                yield LiveEvent(action=inferred_action, data=data,
                                                ts=None, id=id_val)
                        # If the generator ends, restart the subscription
                        agen = None
                except asyncio.CancelledError:
                    # Graceful cancellation; cleanup happens in the finally block
                    raise
                except Exception:
                    attempt += 1
                    if attempt > retry_limit:
                        raise
                    await asyncio.sleep(delay)
                    delay = min(delay * backoff, 30.0)
                    # Clean up the old subscription
                    if qid is not None:
                        try:
                            await client.kill(qid)
                        except Exception:
                            pass
                    # If we registered our own queue, try to remove it
                    if extra is not None and hasattr(extra, 'live_queues'):
                        try:
                            lst = extra.live_queues.get(str(qid)) or []
                            # Best-effort removal of any queue we appended;
                            # an identity check is fine here
                            for i, q in enumerate(list(lst)):
                                pass
                        except Exception:
                            pass
                    agen = None
                    qid = None
        finally:
            # Ensure the live query is killed and the queue detached if used
            if qid is not None:
                try:
                    await client.kill(qid)
                except Exception:
                    pass
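    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). The live() generator is typically consumed
    # in a background task; cancelling the task triggers the finally block
    # above, which issues KILL for the live query id:
    #
    #     async def watch_users():
    #         async for evt in User.objects.live(action=["CREATE", "UPDATE"]):
    #             print(evt.action, evt.id, evt.data)
    #
    #     task = asyncio.create_task(watch_users())
    #     ...
    #     task.cancel()  # unsubscribes and kills the live query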
    async def join(self, field_name: str, target_fields: Optional[List[str]] = None,
                   dereference: bool = True, dereference_depth: int = 1) -> List[Any]:
        """Perform a JOIN-like operation on a reference field using FETCH.

        This method performs a JOIN-like operation on a reference field by
        using SurrealDB's FETCH clause to efficiently resolve references in a
        single query.

        Args:
            field_name: The name of the reference field to join on
            target_fields: Optional list of fields to select from the target document
            dereference: Whether to dereference references in the joined documents
                (default: True)
            dereference_depth: Maximum depth of reference resolution (default: 1)

        Returns:
            List of documents with joined data

        Raises:
            ValueError: If the field is not a ReferenceField
        """
        # Ensure field_name refers to a ReferenceField
        field = self.document_class._fields.get(field_name)
        if not field or not isinstance(field, ReferenceField):
            raise ValueError(f"{field_name} is not a ReferenceField")

        if not dereference:
            # If no dereferencing is needed, just return regular results
            return await self.all()

        # Use FETCH to join in a single query
        queryset = self._clone()
        queryset.fetch_fields.append(field_name)
        try:
            documents = await queryset.all()
            # If dereference_depth > 1, recursively resolve deeper references
            if dereference_depth > 1:
                for doc in documents:
                    referenced_doc = getattr(doc, field_name, None)
                    if referenced_doc and hasattr(referenced_doc, 'resolve_references'):
                        await referenced_doc.resolve_references(
                            depth=dereference_depth - 1)
            return documents
        except Exception:
            # Fall back to manual resolution if FETCH fails
            documents = await self.all()
            target_document_class = field.document_type
            for doc in documents:
                if getattr(doc, field_name, None):
                    ref_value = getattr(doc, field_name)
                    ref_id = None
                    if isinstance(ref_value, str) and ':' in ref_value:
                        ref_id = ref_value
                    elif hasattr(ref_value, 'id'):
                        ref_id = ref_value.id
                    if ref_id:
                        referenced_doc = await target_document_class.get(
                            id=ref_id, dereference=dereference,
                            dereference_depth=dereference_depth)
                        setattr(doc, field_name, referenced_doc)
            return documents
    def join_sync(self, field_name: str, target_fields: Optional[List[str]] = None,
                  dereference: bool = True, dereference_depth: int = 1) -> List[Any]:
        """Perform a JOIN-like operation on a reference field synchronously using FETCH.

        This method performs a JOIN-like operation on a reference field by
        using SurrealDB's FETCH clause to efficiently resolve references in a
        single query.

        Args:
            field_name: The name of the reference field to join on
            target_fields: Optional list of fields to select from the target document
            dereference: Whether to dereference references in the joined documents
                (default: True)
            dereference_depth: Maximum depth of reference resolution (default: 1)

        Returns:
            List of documents with joined data

        Raises:
            ValueError: If the field is not a ReferenceField
        """
        # Ensure field_name refers to a ReferenceField
        field = self.document_class._fields.get(field_name)
        if not field or not isinstance(field, ReferenceField):
            raise ValueError(f"{field_name} is not a ReferenceField")

        if not dereference:
            # If no dereferencing is needed, just return regular results
            return self.all_sync()

        # Use FETCH to join in a single query
        queryset = self._clone()
        queryset.fetch_fields.append(field_name)
        try:
            documents = queryset.all_sync()
            # If dereference_depth > 1, recursively resolve deeper references
            if dereference_depth > 1:
                for doc in documents:
                    referenced_doc = getattr(doc, field_name, None)
                    if referenced_doc and hasattr(referenced_doc, 'resolve_references_sync'):
                        referenced_doc.resolve_references_sync(
                            depth=dereference_depth - 1)
            return documents
        except Exception:
            # Fall back to manual resolution if FETCH fails
            documents = self.all_sync()
            target_document_class = field.document_type
            for doc in documents:
                if getattr(doc, field_name, None):
                    ref_value = getattr(doc, field_name)
                    ref_id = None
                    if isinstance(ref_value, str) and ':' in ref_value:
                        ref_id = ref_value
                    elif hasattr(ref_value, 'id'):
                        ref_id = ref_value.id
                    if ref_id:
                        referenced_doc = target_document_class.get_sync(
                            id=ref_id, dereference=dereference,
                            dereference_depth=dereference_depth)
                        setattr(doc, field_name, referenced_doc)
            return documents
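    # Illustrative usage sketch (not part of the original module; `Post` is a
    # hypothetical document class with an `author` ReferenceField). The join
    # resolves the reference via FETCH in a single round trip:
    #
    #     posts = await Post.objects.filter(published=True).join("author")
    #     print(posts[0].author.name)  # a resolved document, not a record id
    #
    #     posts = Post.objects.join_sync("author", dereference_depth=2)  # sync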
    def _build_query(self) -> str:
        """Build the query string with performance optimizations.

        This method builds the query string for the document class query.
        It automatically uses optimized direct record access when possible.

        Returns:
            The optimized query string
        """
        # Try to build an optimized direct record access query first
        optimized_query = self._build_direct_record_query()
        if optimized_query:
            return optimized_query

        # Fall back to regular query building.
        # SurrealQL does not support SQL-style SELECT DISTINCT for full rows.
        # When traversal uniqueness is requested, we deduplicate by grouping on id.
        from_part = self.document_class._get_collection_name()

        # If traversal is configured, render it in the SELECT projection, not in FROM
        traversal = getattr(self, "_traversal_path", None)
        if traversal:
            max_depth = getattr(self, "_traversal_max_depth", None)
            if max_depth and max_depth > 1:
                simple = traversal.strip()
                if simple.count("->") + simple.count("<-") == 1 and " " not in simple:
                    traversal_to_use = simple * max_depth
                else:
                    traversal_to_use = simple
            else:
                traversal_to_use = traversal.strip()
            select_keyword = f"SELECT {traversal_to_use} AS traversed"
        elif self.select_fields:
            select_keyword = f"SELECT {', '.join(self.select_fields)}"
        else:
            select_keyword = "SELECT *"

        # Build the OMIT clause
        if self.omit_fields:
            select_keyword += f" OMIT {', '.join(self.omit_fields)}"

        select_query = f"{select_keyword} FROM {from_part}"
        if self.query_parts:
            conditions = self._build_conditions()
            select_query += f" WHERE {' AND '.join(conditions)}"

        # Add the other clauses from _build_clauses.
        # Note: GROUP BY id for traversal deduplication can change result shapes
        # in SurrealDB. To keep traversal results straightforward, we do not
        # auto-inject GROUP BY here.
        clauses = self._build_clauses()
        for clause_name, clause_sql in clauses.items():
            if clause_name != 'WHERE':  # The WHERE clause is already handled
                select_query += f" {clause_sql}"
        return select_query
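    # Sketch of the SurrealQL this builds (illustrative; assumes a collection
    # named "user" and that no direct-record optimization applies):
    #
    #     User.objects._build_query()
    #     # SELECT * FROM user
    #
    #     User.objects.traverse("->knows", max_depth=2)._build_query()
    #     # SELECT ->knows->knows AS traversed FROM user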
    async def all(self, dereference: bool = False) -> List[Any]:
        """Execute the query and return all results asynchronously.

        This method builds and executes the query, then converts the results
        to instances of the document class.

        Args:
            dereference: Whether to dereference references (default: False)

        Returns:
            List of document instances
        """
        query = self._build_query()
        results = await self.connection.client.query(query)
        if not results:
            return []

        # Extract rows: handle both a single SELECT (list[dict]) and
        # multi-statement results (a list of result sets)
        rows = None
        if isinstance(results, list):
            if results and isinstance(results[0], dict):
                rows = results
            else:
                for part in reversed(results):
                    if isinstance(part, list):
                        rows = part
                        break
        else:
            rows = results
        if not rows:
            return []
        if isinstance(rows, dict):
            rows = [rows]

        # If this is a traversal query, return raw rows (the shape may not
        # match the document schema)
        if getattr(self, "_traversal_path", None):
            return rows

        is_partial = self.select_fields is not None
        processed_results = [
            self.document_class.from_db(doc, dereference=dereference, partial=is_partial)
            for doc in rows
        ]
        return processed_results
    def all_sync(self, dereference: bool = False) -> List[Any]:
        """Execute the query and return all results synchronously.

        This method builds and executes the query, then converts the results
        to instances of the document class.

        Args:
            dereference: Whether to dereference references (default: False)

        Returns:
            List of document instances
        """
        query = self._build_query()
        results = self.connection.client.query(query)
        if not results:
            return []

        # Extract rows: handle both a single SELECT (list[dict]) and
        # multi-statement results (a list of result sets)
        rows = None
        if isinstance(results, list):
            if results and isinstance(results[0], dict):
                rows = results
            else:
                for part in reversed(results):
                    if isinstance(part, list):
                        rows = part
                        break
        else:
            rows = results
        if not rows:
            return []
        if isinstance(rows, dict):
            rows = [rows]

        # If this is a traversal query, return raw rows (the shape may not
        # match the document schema)
        if getattr(self, "_traversal_path", None):
            return rows

        is_partial = self.select_fields is not None
        processed_results = [
            self.document_class.from_db(doc, dereference=dereference, partial=is_partial)
            for doc in rows
        ]
        return processed_results
    async def count(self) -> int:
        """Count documents matching the query asynchronously.

        This method builds and executes a count query to count the number
        of documents matching the query.

        Returns:
            Number of matching documents
        """
        count_query = f"SELECT count() FROM {self.document_class._get_collection_name()}"
        if self.query_parts:
            conditions = self._build_conditions()
            count_query += f" WHERE {' AND '.join(conditions)}"
        result = await self.connection.client.query(count_query)
        if not result or not result[0]:
            return 0
        # Mirror all(): the client may return a flat list of per-record count
        # rows or a list of per-statement result sets, so handle both shapes.
        if isinstance(result[0], list):
            return len(result[0])
        return len(result)
    def count_sync(self) -> int:
        """Count documents matching the query synchronously.

        This method builds and executes a count query to count the number
        of documents matching the query.

        Returns:
            Number of matching documents
        """
        count_query = f"SELECT count() FROM {self.document_class._get_collection_name()}"
        if self.query_parts:
            conditions = self._build_conditions()
            count_query += f" WHERE {' AND '.join(conditions)}"
        result = self.connection.client.query(count_query)
        if not result or not result[0]:
            return 0
        # Mirror all_sync(): handle both flat and nested result shapes.
        if isinstance(result[0], list):
            return len(result[0])
        return len(result)
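    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). count() issues its own query rather than
    # fetching full rows:
    #
    #     minors = await User.objects.filter(age__lt=18).count()
    #     minors = User.objects.filter(age__lt=18).count_sync()  # sync variant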
    async def get(self, dereference: bool = False, **kwargs: Any) -> Any:
        """Get a single document matching the query asynchronously.

        This method applies filters and ensures that exactly one document
        is returned.

        Args:
            dereference: Whether to dereference references (default: False)
            **kwargs: Field names and values to filter by

        Returns:
            The matching document

        Raises:
            DoesNotExist: If no matching document is found
            MultipleObjectsReturned: If multiple matching documents are found
        """
        queryset = self.filter(**kwargs)
        queryset.limit_value = 2  # Fetch two to detect multiple matches
        results = await queryset.all(dereference=dereference)
        if not results:
            raise DoesNotExist(
                f"{self.document_class.__name__} matching query does not exist.")
        if len(results) > 1:
            raise MultipleObjectsReturned(
                f"Multiple {self.document_class.__name__} objects returned instead of one")
        return results[0]
    def get_sync(self, dereference: bool = False, **kwargs: Any) -> Any:
        """Get a single document matching the query synchronously.

        This method applies filters and ensures that exactly one document
        is returned.

        Args:
            dereference: Whether to dereference references (default: False)
            **kwargs: Field names and values to filter by

        Returns:
            The matching document

        Raises:
            DoesNotExist: If no matching document is found
            MultipleObjectsReturned: If multiple matching documents are found
        """
        queryset = self.filter(**kwargs)
        queryset.limit_value = 2  # Fetch two to detect multiple matches
        results = queryset.all_sync(dereference=dereference)
        if not results:
            raise DoesNotExist(
                f"{self.document_class.__name__} matching query does not exist.")
        if len(results) > 1:
            raise MultipleObjectsReturned(
                f"Multiple {self.document_class.__name__} objects returned instead of one")
        return results[0]
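    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). get()/get_sync() enforce exactly-one
    # semantics using the exceptions imported at the top of this module:
    #
    #     try:
    #         user = User.objects.get_sync(email="a@example.com")
    #     except DoesNotExist:
    #         user = None
    #     except MultipleObjectsReturned:
    #         raise  # the filter matched more than one record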
    async def create(self, **kwargs: Any) -> Any:
        """Create a new document asynchronously.

        This method creates a new document with the given field values.

        Args:
            **kwargs: Field names and values for the new document

        Returns:
            The created document
        """
        document = self.document_class(**kwargs)
        return await document.save(self.connection)
    def create_sync(self, **kwargs: Any) -> Any:
        """Create a new document synchronously.

        This method creates a new document with the given field values.

        Args:
            **kwargs: Field names and values for the new document

        Returns:
            The created document
        """
        document = self.document_class(**kwargs)
        return document.save_sync(self.connection)
    async def update(self, returning: Optional[str] = None, **kwargs: Any) -> List[Any]:
        """Update documents matching the query asynchronously with performance optimizations.

        This method updates documents matching the query with the given field
        values. It uses direct record access for bulk ID operations for better
        performance.

        Args:
            returning: Optional RETURN mode ('before', 'after', or 'diff')
            **kwargs: Field names and values to update

        Returns:
            List of updated documents
        """
        # PERFORMANCE OPTIMIZATION: use direct record access for bulk operations
        if self._bulk_id_selection or self._id_range_selection:
            # For bulk operations, use a subquery with direct record access
            optimized_query = self._build_direct_record_query()
            if optimized_query:
                # Convert the SELECT into a subquery for UPDATE
                subquery = optimized_query.replace("SELECT *", "SELECT id")
                set_clause = ', '.join(
                    f'{k} = {escape_literal(v)}' for k, v in kwargs.items())
                update_query = f"UPDATE ({subquery}) SET {set_clause}"
                if returning in ("before", "after", "diff"):
                    update_query += f" RETURN {returning.upper()}"
                result = await self.connection.client.query(update_query)
                if not result:
                    return []
                # Handle different result structures
                if isinstance(result[0], dict):
                    # Subquery UPDATE case: result is a flat list of documents
                    return [self.document_class.from_db(doc) for doc in result]
                elif isinstance(result[0], list):
                    # Normal case: result[0] is a list of document dictionaries
                    return [self.document_class.from_db(doc) for doc in result[0]]
                else:
                    return []

        # Fall back to a regular update query.
        # SurrealQL requires SET before WHERE, so build the clauses in that order.
        update_query = f"UPDATE {self.document_class._get_collection_name()}"
        set_clause = ', '.join(f'{k} = {escape_literal(v)}' for k, v in kwargs.items())
        update_query += f" SET {set_clause}"
        if self.query_parts:
            conditions = self._build_conditions()
            update_query += f" WHERE {' AND '.join(conditions)}"
        if returning in ("before", "after", "diff"):
            update_query += f" RETURN {returning.upper()}"
        result = await self.connection.client.query(update_query)
        if not result or not result[0]:
            return []
        return [self.document_class.from_db(doc) for doc in result[0]]
    def update_sync(self, returning: Optional[str] = None, **kwargs: Any) -> List[Any]:
        """Update documents matching the query synchronously with performance optimizations.

        This method updates documents matching the query with the given field
        values. It uses direct record access for bulk ID operations for better
        performance.

        Args:
            returning: Optional RETURN mode ('before', 'after', or 'diff')
            **kwargs: Field names and values to update

        Returns:
            List of updated documents
        """
        # PERFORMANCE OPTIMIZATION: use direct record access for bulk operations
        if self._bulk_id_selection or self._id_range_selection:
            # For bulk operations, use a subquery with direct record access
            optimized_query = self._build_direct_record_query()
            if optimized_query:
                # Convert the SELECT into a subquery for UPDATE
                subquery = optimized_query.replace("SELECT *", "SELECT id")
                set_clause = ', '.join(
                    f'{k} = {escape_literal(v)}' for k, v in kwargs.items())
                update_query = f"UPDATE ({subquery}) SET {set_clause}"
                if returning in ("before", "after", "diff"):
                    update_query += f" RETURN {returning.upper()}"
                result = self.connection.client.query(update_query)
                if not result:
                    return []
                # Handle different result structures
                if isinstance(result[0], dict):
                    # Subquery UPDATE case: result is a flat list of documents
                    return [self.document_class.from_db(doc) for doc in result]
                elif isinstance(result[0], list):
                    # Normal case: result[0] is a list of document dictionaries
                    return [self.document_class.from_db(doc) for doc in result[0]]
                else:
                    return []

        # Fall back to a regular update query.
        # SurrealQL requires SET before WHERE, so build the clauses in that order.
        update_query = f"UPDATE {self.document_class._get_collection_name()}"
        set_clause = ', '.join(f'{k} = {escape_literal(v)}' for k, v in kwargs.items())
        update_query += f" SET {set_clause}"
        if self.query_parts:
            conditions = self._build_conditions()
            update_query += f" WHERE {' AND '.join(conditions)}"
        if returning in ("before", "after", "diff"):
            update_query += f" RETURN {returning.upper()}"
        result = self.connection.client.query(update_query)
        if not result or not result[0]:
            return []
        return [self.document_class.from_db(doc) for doc in result[0]]
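    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). The `returning` argument maps to
    # SurrealQL's RETURN BEFORE/AFTER/DIFF clause, so the emitted statement
    # looks roughly like "UPDATE user SET ... WHERE ... RETURN AFTER":
    #
    #     updated = await User.objects.filter(status="trial").update(
    #         returning="after", status="active")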
    async def delete(self) -> int:
        """Delete documents matching the query asynchronously with performance optimizations.

        This method deletes documents matching the query. It uses direct
        record access for bulk ID operations for better performance.

        Returns:
            Number of deleted documents
        """
        # PERFORMANCE OPTIMIZATION: use direct record access for bulk operations
        if self._bulk_id_selection:
            # Use direct record deletion syntax for bulk ID operations
            record_ids = [self._format_record_id(id_val)
                          for id_val in self._bulk_id_selection]
            delete_query = f"DELETE {', '.join(record_ids)}"
            result = await self.connection.client.query(delete_query)
            # Direct record deletion returns an empty list on success, so
            # report the number of IDs we attempted to delete
            return len(record_ids)
        elif self._id_range_selection:
            # For range operations, use an optimized query with a subquery
            optimized_query = self._build_direct_record_query()
            if optimized_query:
                # Convert the SELECT into a subquery for DELETE
                subquery = optimized_query.replace("SELECT *", "SELECT id")
                delete_query = f"DELETE ({subquery})"
                result = await self.connection.client.query(delete_query)
                if not result or not result[0]:
                    return 0
                return len(result[0])

        # Fall back to a regular delete query
        delete_query = f"DELETE FROM {self.document_class._get_collection_name()}"
        if self.query_parts:
            conditions = self._build_conditions()
            delete_query += f" WHERE {' AND '.join(conditions)}"
        result = await self.connection.client.query(delete_query)
        if not result or not result[0]:
            return 0
        return len(result[0])
    def delete_sync(self) -> int:
        """Delete documents matching the query synchronously with performance optimizations.

        This method deletes documents matching the query. It uses direct
        record access for bulk ID operations for better performance.

        Returns:
            Number of deleted documents
        """
        # PERFORMANCE OPTIMIZATION: use direct record access for bulk operations
        if self._bulk_id_selection:
            # Use direct record deletion syntax for bulk ID operations
            record_ids = [self._format_record_id(id_val)
                          for id_val in self._bulk_id_selection]
            delete_query = f"DELETE {', '.join(record_ids)}"
            result = self.connection.client.query(delete_query)
            # Direct record deletion returns an empty list on success, so
            # report the number of IDs we attempted to delete
            return len(record_ids)
        elif self._id_range_selection:
            # For range operations, use an optimized query with a subquery
            optimized_query = self._build_direct_record_query()
            if optimized_query:
                # Convert the SELECT into a subquery for DELETE
                subquery = optimized_query.replace("SELECT *", "SELECT id")
                delete_query = f"DELETE ({subquery})"
                result = self.connection.client.query(delete_query)
                if not result or not result[0]:
                    return 0
                return len(result[0])

        # Fall back to a regular delete query
        delete_query = f"DELETE FROM {self.document_class._get_collection_name()}"
        if self.query_parts:
            conditions = self._build_conditions()
            delete_query += f" WHERE {' AND '.join(conditions)}"
        result = self.connection.client.query(delete_query)
        if not result or not result[0]:
            return 0
        return len(result[0])
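    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). Bulk ID selections take the direct-record
    # fast path above; everything else goes through DELETE ... WHERE:
    #
    #     n = await User.objects.filter(active=False).delete()
    #     n = User.objects.filter(active=False).delete_sync()  # sync variant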
    async def bulk_create(self, documents: List[Any], batch_size: int = 1000,
                          validate: bool = True,
                          return_documents: bool = True) -> Union[List[Any], int]:
        """Create multiple documents in a single operation asynchronously.

        This method creates multiple documents in a single operation,
        processing them in batches for better performance. It can optionally
        validate the documents and return the created documents.

        Args:
            documents: List of Document instances to create
            batch_size: Number of documents per batch (default: 1000)
            validate: Whether to validate documents (default: True)
            return_documents: Whether to return created documents (default: True)

        Returns:
            List of created documents with their IDs set if
            return_documents=True, otherwise the count of created documents
        """
        if not documents:
            return [] if return_documents else 0

        collection = self.document_class._get_collection_name()
        total_created = 0
        created_docs = [] if return_documents else None

        # Process in batches
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]

            # Validate the batch if required (validate() is synchronous)
            if validate:
                for doc in batch:
                    doc.validate()

            # Separate documents with and without explicit IDs
            docs_without_ids = []
            docs_with_ids = []
            for doc in batch:
                if doc.id:
                    docs_with_ids.append(doc)
                else:
                    docs_without_ids.append(doc)

            # Handle documents without IDs using a bulk INSERT
            if docs_without_ids:
                data = [doc.to_db() for doc in docs_without_ids]
                from ..document import serialize_http_safe
                data = [serialize_http_safe(d) for d in data]
                query = f"INSERT INTO {collection} {json.dumps(data)};"
                try:
                    result = await self.connection.client.query(query)
                    if return_documents and result and result[0]:
                        batch_docs = [self.document_class.from_db(doc_data)
                                      for doc_data in result[0]]
                        created_docs.extend(batch_docs)
                        total_created += len(batch_docs)
                    elif result and result[0]:
                        total_created += len(result[0])
                except Exception as e:
                    logger.error(f"Error in bulk create batch (no IDs): {str(e)}")

            # Handle documents with explicit IDs using individual upserts
            for doc in docs_with_ids:
                try:
                    data = doc.to_db()
                    # Remove the ID from the payload and extract the ID part
                    if 'id' in data:
                        del data['id']
                    # Split only on the first ':' so IDs containing colons survive
                    id_part = str(doc.id).split(':', 1)[1]
                    result = await self.connection.client.upsert(
                        RecordID(collection,
                                 int(id_part) if id_part.isdigit() else id_part),
                        data,
                    )
                    if return_documents and result:
                        if isinstance(result, list) and result:
                            doc_data = result[0]
                        else:
                            doc_data = result
                        if isinstance(doc_data, dict):
                            if created_docs is not None:
                                created_docs.append(
                                    self.document_class.from_db(doc_data))
                    total_created += 1
                except Exception as e:
                    logger.error(f"Error creating document with ID {doc.id}: {str(e)}")
                    continue

        return created_docs if return_documents else total_created
    def bulk_create_sync(self, documents: List[Any], batch_size: int = 1000,
                         validate: bool = True,
                         return_documents: bool = True) -> Union[List[Any], int]:
        """Create multiple documents in a single operation synchronously.

        This method creates multiple documents in a single operation,
        processing them in batches for better performance. It can optionally
        validate the documents and return the created documents.

        Args:
            documents: List of Document instances to create
            batch_size: Number of documents per batch (default: 1000)
            validate: Whether to validate documents (default: True)
            return_documents: Whether to return created documents (default: True)

        Returns:
            List of created documents with their IDs set if
            return_documents=True, otherwise the count of created documents
        """
        if not documents:
            return [] if return_documents else 0

        collection = self.document_class._get_collection_name()
        total_created = 0
        created_docs = [] if return_documents else None

        # Process in batches
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]

            # Validate the batch if required
            if validate:
                for doc in batch:
                    doc.validate()

            # Convert the batch to its DB representation
            data = [doc.to_db() for doc in batch]
            from ..document import serialize_http_safe
            data = [serialize_http_safe(d) for d in data]

            # Construct an optimized bulk insert query
            query = f"INSERT INTO {collection} {json.dumps(data)};"

            # Execute the batch insert
            try:
                result = self.connection.client.query(query)
                if return_documents and result and result[0]:
                    # Process results if needed
                    batch_docs = [self.document_class.from_db(doc_data)
                                  for doc_data in result[0]]
                    created_docs.extend(batch_docs)
                    total_created += len(batch_docs)
                elif result and result[0]:
                    total_created += len(result[0])
            except Exception as e:
                # Log the error and continue with the next batch
                logger.error(f"Error in bulk create batch: {str(e)}")
                continue

        return created_docs if return_documents else total_created
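    # Illustrative usage sketch (not part of the original module; `User` is a
    # hypothetical document class). Documents are inserted in batches of
    # `batch_size`; pass return_documents=False to get just a count back:
    #
    #     users = [User(name=f"user{i}") for i in range(5000)]
    #     created = await User.objects.bulk_create(users, batch_size=1000)
    #     count = User.objects.bulk_create_sync(users, return_documents=False)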
    async def explain(self, full: bool = False) -> List[Dict[str, Any]]:
        """Get the query execution plan for performance analysis.

        This method appends EXPLAIN to the query to show how SurrealDB will
        execute it, helping identify performance bottlenecks.

        Args:
            full: Whether to include the full explanation, including the
                execution trace (default: False)

        Returns:
            List of execution plan steps with details

        Example:
            plan = await User.objects.filter(age__lt=18).explain()
            print(f"Query will use: {plan[0]['operation']}")
        """
        # If with_explain() was called, the built query may already contain
        # EXPLAIN; the checks below avoid emitting a duplicate clause.
        query = self._build_query()
        if "EXPLAIN" not in query:
            query += " EXPLAIN FULL" if full else " EXPLAIN"
        elif full and "EXPLAIN FULL" not in query:
            query = query.replace("EXPLAIN", "EXPLAIN FULL")
        result = await self.connection.client.query(query)
        return result[0] if result and result[0] else []
    def explain_sync(self, full: bool = False) -> List[Dict[str, Any]]:
        """Get the query execution plan for performance analysis synchronously.

        Args:
            full: Whether to include the full explanation, including the
                execution trace (default: False)

        Returns:
            List of execution plan steps with details
        """
        query = self._build_query()
        if "EXPLAIN" not in query:
            query += " EXPLAIN FULL" if full else " EXPLAIN"
        elif full and "EXPLAIN FULL" not in query:
            query = query.replace("EXPLAIN", "EXPLAIN FULL")
        result = self.connection.client.query(query)
        return result[0] if result and result[0] else []
    def suggest_indexes(self) -> List[str]:
        """Suggest indexes based on current query patterns.

        Analyzes the current query conditions and suggests indexes that could
        improve performance.

        Returns:
            List of suggested DEFINE INDEX statements

        Example::

            suggestions = User.objects.filter(age__lt=18, city="NYC").suggest_indexes()
            for suggestion in suggestions:
                print(f"Consider: {suggestion}")
        """
        suggestions = []
        collection_name = self.document_class._get_collection_name()

        # Analyze filter conditions
        analyzed_fields = set()
        for field, op, value in self.query_parts:
            if field != 'id' and field not in analyzed_fields:  # id doesn't need indexing
                analyzed_fields.add(field)
                if op in ('=', '!=', '>', '<', '>=', '<=', 'INSIDE', 'NOT INSIDE'):
                    suggestions.append(
                        f"DEFINE INDEX idx_{collection_name}_{field} "
                        f"ON {collection_name} FIELDS {field}"
                    )

        # Suggest a compound index for multiple conditions
        if len(analyzed_fields) > 1:
            field_list = ', '.join(sorted(analyzed_fields))
            suggestions.append(
                f"DEFINE INDEX idx_{collection_name}_compound "
                f"ON {collection_name} FIELDS {field_list}"
            )

        # Suggest an ORDER BY index
        if self.order_by_value:
            order_field, _ = self.order_by_value
            if order_field not in analyzed_fields:
                suggestions.append(
                    f"DEFINE INDEX idx_{collection_name}_{order_field} "
                    f"ON {collection_name} FIELDS {order_field}"
                )

        return list(set(suggestions))  # Remove duplicates