Source code for quantumengine.document

import json
import datetime
import logging
from dataclasses import dataclass, field as dataclass_field, make_dataclass
from typing import Any, Dict, List, Optional, Type, Union, ClassVar, TypeVar, Generic
from .query import QuerySet, RelationQuerySet, QuerySetDescriptor
from .fields import Field, RecordIDField, ReferenceField, DictField
from .connection import ConnectionRegistry, SurrealEngineAsyncConnection, SurrealEngineSyncConnection
from .types import IdType, DatabaseValue
from surrealdb import RecordID
from .signals import (
    pre_init, post_init, pre_save, pre_save_post_validation, post_save,
    pre_delete, post_delete, pre_bulk_insert, post_bulk_insert, SIGNAL_SUPPORT
)
from .materialized_view import MaterializedView

# Type variable for Document classes
T = TypeVar('T', bound='Document')

# Set up logging
logger = logging.getLogger(__name__)


class DocumentMetaclass(type):
    """Metaclass for Document classes.

    This metaclass processes field attributes in Document classes to create
    a structured schema. It handles field inheritance, field naming, and
    metadata configuration.

    Attributes:
        _meta: Dictionary of metadata for the document class
        _fields: Dictionary of fields for the document class
        _fields_ordered: List of field names in order of definition
    """

    def __new__(mcs, name: str, bases: tuple, attrs: Dict[str, Any]) -> Type:
        """Create a new Document class.

        This method processes the class attributes to create a structured schema.
        It handles field inheritance, field naming, and metadata configuration.

        Args:
            name: Name of the class being created
            bases: Tuple of base classes
            attrs: Dictionary of class attributes

        Returns:
            The new Document class
        """
        # Skip processing for the base Document class
        if name == 'Document' and attrs.get('__module__') == __name__:
            return super().__new__(mcs, name, bases, attrs)

        # Get or create _meta
        meta = attrs.get('Meta', type('Meta', (), {}))
        attrs['_meta'] = {
            'collection': getattr(meta, 'collection', name.lower()),
            'table_name': getattr(meta, 'table_name', getattr(meta, 'collection', name.lower())),
            'backend': getattr(meta, 'backend', 'surrealdb'),
            'indexes': getattr(meta, 'indexes', []),
            'id_field': getattr(meta, 'id_field', 'id'),
            'strict': getattr(meta, 'strict', True),
            # ClickHouse-specific attributes
            'engine': getattr(meta, 'engine', None),
            'engine_params': getattr(meta, 'engine_params', None),
            'partition_by': getattr(meta, 'partition_by', None),
            'order_by': getattr(meta, 'order_by', None),
            'primary_key': getattr(meta, 'primary_key', None),
            'ttl': getattr(meta, 'ttl', None),
            'settings': getattr(meta, 'settings', None),
            # MaterializedDocument-specific attributes
            'view_name': getattr(meta, 'view_name', None),
        }

        # Process fields
        fields: Dict[str, Field] = {}
        fields_ordered: List[str] = []

        # Inherit fields from parent classes
        for base in bases:
            if hasattr(base, '_fields'):
                fields.update(base._fields)
                fields_ordered.extend(base._fields_ordered)

        # Add fields from current class
        for attr_name, attr_value in list(attrs.items()):
            if isinstance(attr_value, Field):
                fields[attr_name] = attr_value
                fields_ordered.append(attr_name)

                # Set field name
                attr_value.name = attr_name

                # Set db_field if not set
                if not attr_value.db_field:
                    attr_value.db_field = attr_name

                # Remove the field from attrs so it doesn't become a class attribute
                del attrs[attr_name]

        attrs['_fields'] = fields
        attrs['_fields_ordered'] = fields_ordered

        # Create the new class
        new_class = super().__new__(mcs, name, bases, attrs)

        # Assign owner document to fields
        for field_name, field in new_class._fields.items():
            field.owner_document = new_class

        return new_class
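
# A hedged sketch of what the metaclass produces for a subclass (hypothetical
# ``User`` model; field classes come from quantumengine.fields):
#
#     class User(Document):
#         username = StringField(required=True)
#
#         class Meta:
#             collection = "users"
#
#     User._meta['collection']    # "users"
#     User._meta['backend']       # "surrealdb" (the default)
#     User._fields['username']    # the StringField instance, with name/db_field set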


class Document(metaclass=DocumentMetaclass):
    """Base class for all documents.

    This class provides the foundation for all document models in the ORM.
    It includes methods for CRUD operations, validation, and serialization.

    The Document class uses a Meta inner class to configure behavior:

    Example:
        >>> class User(Document):
        ...     username = StringField(required=True)
        ...     email = EmailField(required=True)
        ...     age = IntField(min_value=0)
        ...
        ...     class Meta:
        ...         collection = "users"  # Collection/table name
        ...         backend = "surrealdb"  # Backend to use
        ...         indexes = [  # Index definitions
        ...             {"name": "idx_username", "fields": ["username"], "unique": True},
        ...             {"name": "idx_email", "fields": ["email"], "unique": True}
        ...         ]
        ...         strict = True  # Strict field validation

    Meta Options:
        collection (str): The name of the collection/table in the database.
            Defaults to the lowercase class name.
        table_name (str): Alternative to 'collection', used by some backends.
            Defaults to the value of 'collection'.
        backend (str): The database backend to use ("surrealdb" or "clickhouse").
            Defaults to "surrealdb".
        indexes (list): List of index definitions. Each index is a dict with:
            - name (str): Index name
            - fields (list): List of field names to index
            - unique (bool): Whether the index is unique (optional)
            - type (str): Index type for backend-specific indexes (optional)
        id_field (str): Name of the ID field. Defaults to "id".
        strict (bool): Whether to raise errors for unknown fields. Defaults to True.

    Attributes:
        objects: QuerySetDescriptor for querying documents of this class
        _data: Dictionary of field values
        _changed_fields: List of field names that have been changed
        _fields: Dictionary of fields for this document class (class attribute)
        _fields_ordered: List of field names in order of definition (class attribute)
        _meta: Dictionary of metadata for this document class (class attribute)
    """

    objects = QuerySetDescriptor()
    id = RecordIDField()

    def __init__(self, **values: Any) -> None:
        """Initialize a new Document.

        Args:
            **values: Field values to set on the document

        Raises:
            AttributeError: If strict mode is enabled and an unknown field is provided
        """
        if 'id' not in self._fields:
            self._fields['id'] = RecordIDField()

        # Trigger pre_init signal
        if SIGNAL_SUPPORT:
            pre_init.send(self.__class__, document=self, values=values)

        self._data: Dict[str, Any] = {}
        self._changed_fields: List[str] = []

        # Set default values
        for field_name, field in self._fields.items():
            value = field.default
            if callable(value):
                value = value()
            self._data[field_name] = value

        # Set values from kwargs
        for key, value in values.items():
            if key in self._fields:
                setattr(self, key, value)
            elif self._meta.get('strict', True):
                raise AttributeError(f"Unknown field: {key}")

        # Trigger post_init signal
        if SIGNAL_SUPPORT:
            post_init.send(self.__class__, document=self)
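
    # Usage sketch (hypothetical ``User`` model): defaults are applied first,
    # keyword values are validated via ``__setattr__``, and unknown keys raise
    # AttributeError when Meta.strict is True (the default):
    #
    #     user = User(username="alice")   # ok, validated on assignment
    #     user = User(nickname="al")      # AttributeError: Unknown field: nickname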

    def __getattr__(self, name: str) -> Any:
        """Get a field value.

        This method is called when an attribute is not found through normal lookup.
        It checks if the attribute is a field and returns its value if it is.

        Args:
            name: Name of the attribute to get

        Returns:
            The field value

        Raises:
            AttributeError: If the attribute is not a field
        """
        if name in self._fields:
            # Return the value directly from _data instead of the field instance
            return self._data.get(name)
        raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

    def __setattr__(self, name: str, value: Any) -> None:
        """Set a field value.

        This method is called when an attribute is set. It checks if the
        attribute is a field and validates the value if it is.

        Args:
            name: Name of the attribute to set
            value: Value to set
        """
        if name.startswith('_'):
            super().__setattr__(name, value)
        elif name in self._fields:
            field = self._fields[name]
            self._data[name] = field.validate(value)
            if name not in self._changed_fields:
                self._changed_fields.append(name)
        else:
            super().__setattr__(name, value)

    @property
    def id(self) -> Optional[IdType]:
        """Get the document ID.

        Returns:
            The document ID (string, RecordID, or None)
        """
        return self._data.get('id')

    @id.setter
    def id(self, value: Optional[IdType]) -> None:
        """Set the document ID.

        Args:
            value: The document ID to set
        """
        if 'id' in self._fields:
            field = self._fields['id']
            self._data['id'] = field.validate(value)
            if 'id' not in self._changed_fields:
                self._changed_fields.append('id')
        else:
            self._data['id'] = value

    @classmethod
    def _get_backend(cls):
        """Get the backend instance for this document class.

        Returns:
            Backend instance configured for this document
        """
        backend_name = cls._meta.get('backend', 'surrealdb')
        from .backends import BackendRegistry
        from .connection import ConnectionRegistry

        # Get the backend class
        backend_class = BackendRegistry.get_backend(backend_name)

        # Get the connection for this backend
        connection = ConnectionRegistry.get_default_connection(backend_name)

        # Return backend instance
        return backend_class(connection)

    @classmethod
    def _get_table_name(cls) -> str:
        """Return the table name for this document.

        Returns:
            The table name
        """
        return cls._meta.get('table_name', cls._meta.get('collection'))

    @classmethod
    def _get_collection_name(cls) -> str:
        """Return the collection name for this document.

        Returns:
            The collection name
        """
        return cls._meta.get('collection')

    def validate(self) -> None:
        """Validate all fields.

        This method validates all fields in the document against their
        validation rules.

        Raises:
            ValidationError: If a field fails validation
        """
        for field_name, field in self._fields.items():
            value = self._data.get(field_name)
            field.validate(value)

    def to_dict(self) -> Dict[str, DatabaseValue]:
        """Convert the document to a dictionary.

        This method converts the document to a dictionary containing all field
        values including the document ID. It ensures that RecordID objects are
        properly converted to strings for JSON serialization. It also recursively
        converts embedded documents to dictionaries.

        Returns:
            Dictionary of field values including ID
        """
        # Start with the ID if it exists
        result = {}
        if self.id is not None:
            # Convert RecordID to string if needed
            result['id'] = str(self.id) if isinstance(self.id, RecordID) else self.id

        # Add all other fields with proper conversion
        for k, v in self._data.items():
            if k in self._fields:
                # Convert RecordID objects to strings
                if isinstance(v, RecordID):
                    result[k] = str(v)
                # Handle embedded documents by recursively calling to_dict()
                elif hasattr(v, 'to_dict') and callable(v.to_dict):
                    result[k] = v.to_dict()
                # Handle lists that might contain RecordIDs or embedded documents
                elif isinstance(v, list):
                    result[k] = [
                        item.to_dict() if hasattr(item, 'to_dict') and callable(item.to_dict)
                        else str(item) if isinstance(item, RecordID)
                        else item
                        for item in v
                    ]
                # Handle dicts that might contain RecordIDs or embedded documents
                elif isinstance(v, dict):
                    result[k] = {
                        key: val.to_dict() if hasattr(val, 'to_dict') and callable(val.to_dict)
                        else str(val) if isinstance(val, RecordID)
                        else val
                        for key, val in v.items()
                    }
                else:
                    result[k] = v

        return result

    def to_db(self) -> Dict[str, DatabaseValue]:
        """Convert the document to a database-friendly dictionary.

        This method converts the document to a dictionary suitable for storage
        in the database. It applies field-specific conversions and includes only
        non-None values unless the field is required.

        Returns:
            Dictionary of field values for the database
        """
        result = {}
        backend_name = self._meta.get('backend', 'surrealdb')

        for field_name, field in self._fields.items():
            value = self._data.get(field_name)
            if value is not None or field.required:
                db_field = field.db_field or field_name
                # Pass backend parameter to field.to_db if supported
                if 'backend' in field.to_db.__code__.co_varnames:
                    result[db_field] = field.to_db(value, backend=backend_name)
                else:
                    result[db_field] = field.to_db(value)

        return result

    @classmethod
    def from_db(cls, data: Any, dereference: bool = False) -> 'Document':
        """Create a document instance from database data.

        Args:
            data: Data from the database (dictionary, string, RecordID, etc.)
            dereference: Whether to dereference references (default: False)

        Returns:
            A new document instance
        """
        # Create an empty instance without triggering signals
        instance = cls.__new__(cls)

        # Initialize _data and _changed_fields
        instance._data = {}
        instance._changed_fields = []

        # Add id field if not present
        if 'id' not in instance._fields:
            instance._fields['id'] = RecordIDField()

        # Set default values
        for field_name, field in instance._fields.items():
            value = field.default
            if callable(value):
                value = value()
            instance._data[field_name] = value

        # If data is a dictionary, update with database values
        if isinstance(data, dict):
            backend_name = cls._meta.get('backend', 'surrealdb')

            # First, handle fields with db_field mapping
            for field_name, field in instance._fields.items():
                db_field = field.db_field or field_name
                if db_field in data:
                    # Pass the dereference and backend parameters to from_db if supported
                    co_varnames = field.from_db.__code__.co_varnames
                    kwargs = {}
                    if 'dereference' in co_varnames:
                        kwargs['dereference'] = dereference
                    if 'backend' in co_varnames:
                        kwargs['backend'] = backend_name
                    if kwargs:
                        instance._data[field_name] = field.from_db(data[db_field], **kwargs)
                    else:
                        instance._data[field_name] = field.from_db(data[db_field])

            # Then, handle fields without db_field mapping (for backward compatibility)
            for key, value in data.items():
                if key in instance._fields:
                    field = instance._fields[key]
                    # Pass the dereference and backend parameters to from_db if supported
                    co_varnames = field.from_db.__code__.co_varnames
                    kwargs = {}
                    if 'dereference' in co_varnames:
                        kwargs['dereference'] = dereference
                    if 'backend' in co_varnames:
                        kwargs['backend'] = backend_name
                    if kwargs:
                        instance._data[key] = field.from_db(value, **kwargs)
                    else:
                        instance._data[key] = field.from_db(value)
        # If data is a RecordID or string, set it as the ID
        elif isinstance(data, (RecordID, str)):
            instance._data['id'] = data
        # For other types, try to convert to string and set as ID
        else:
            try:
                instance._data['id'] = str(data)
            except (TypeError, ValueError):
                # If conversion fails, just use the data as is
                pass

        return instance
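
    # Round-trip sketch (hypothetical ``User`` model; assumes a SurrealDB row dict):
    #
    #     row = {"id": "user:alice", "username": "alice"}
    #     user = User.from_db(row)   # no signals fired, no changed fields tracked
    #     user.to_dict()             # {"id": "user:alice", "username": "alice"}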

    async def resolve_references(self, depth: int = 1) -> 'Document':
        """Resolve all references in this document using FETCH.

        This method uses SurrealDB's FETCH clause to efficiently resolve
        references instead of making individual queries for each reference.

        Args:
            depth: Maximum depth of reference resolution (default: 1)

        Returns:
            The document instance with resolved references
        """
        if depth <= 0 or not self.id:
            return self

        # Build FETCH clause for all reference fields
        fetch_fields = []
        for field_name, field in self._fields.items():
            if isinstance(field, ReferenceField) and getattr(self, field_name):
                fetch_fields.append(field_name)

        if not fetch_fields:
            return self

        # Use FETCH to resolve references in a single query
        connection = ConnectionRegistry.get_default_connection()
        query = f"SELECT * FROM `{self.id}` FETCH {', '.join(fetch_fields)}"

        try:
            # Use FETCH with a WHERE clause instead of selecting from a specific record
            fetch_query = (
                f"SELECT * FROM {self.__class__._get_collection_name()} "
                f"WHERE id = {self.id} FETCH {', '.join(fetch_fields)}"
            )
            result = await connection.client.query(fetch_query)

            if result and result[0]:
                # Update this document with fetched data
                fetched_data = result[0][0]
                updated_doc = self.from_db(fetched_data)

                # Copy the resolved references to this instance
                for field_name in fetch_fields:
                    if hasattr(updated_doc, field_name):
                        setattr(self, field_name, getattr(updated_doc, field_name))

                # If depth > 1, recursively resolve references in fetched documents
                if depth > 1:
                    for field_name in fetch_fields:
                        referenced_doc = getattr(self, field_name, None)
                        if referenced_doc and hasattr(referenced_doc, 'resolve_references'):
                            await referenced_doc.resolve_references(depth=depth - 1)
        except Exception:
            # Fall back to manual resolution if FETCH fails
            for field_name, field in self._fields.items():
                if isinstance(field, ReferenceField) and getattr(self, field_name):
                    ref_id = getattr(self, field_name)
                    if isinstance(ref_id, str) and ':' in ref_id:
                        referenced_doc = await field.document_type.get(id=ref_id, dereference=True)
                        if referenced_doc and depth > 1:
                            await referenced_doc.resolve_references(depth=depth - 1)
                        setattr(self, field_name, referenced_doc)
                    elif isinstance(ref_id, RecordID):
                        ref_id_str = str(ref_id)
                        referenced_doc = await field.document_type.get(id=ref_id_str, dereference=True)
                        if referenced_doc and depth > 1:
                            await referenced_doc.resolve_references(depth=depth - 1)
                        setattr(self, field_name, referenced_doc)

        return self

    def resolve_references_sync(self, depth: int = 1) -> 'Document':
        """Resolve all references in this document synchronously using FETCH.

        This method uses SurrealDB's FETCH clause to efficiently resolve
        references instead of making individual queries for each reference.

        Args:
            depth: Maximum depth of reference resolution (default: 1)

        Returns:
            The document instance with resolved references
        """
        if depth <= 0 or not self.id:
            return self

        # Build FETCH clause for all reference fields
        fetch_fields = []
        for field_name, field in self._fields.items():
            if isinstance(field, ReferenceField) and getattr(self, field_name):
                fetch_fields.append(field_name)

        if not fetch_fields:
            return self

        # Use FETCH to resolve references in a single query
        connection = ConnectionRegistry.get_default_connection()
        query = f"SELECT * FROM `{self.id}` FETCH {', '.join(fetch_fields)}"

        try:
            # Use FETCH with a WHERE clause instead of selecting from a specific record
            fetch_query = (
                f"SELECT * FROM {self.__class__._get_collection_name()} "
                f"WHERE id = {self.id} FETCH {', '.join(fetch_fields)}"
            )
            result = connection.client.query(fetch_query)

            if result and result[0]:
                # Update this document with fetched data
                fetched_data = result[0][0]
                updated_doc = self.from_db(fetched_data)

                # Copy the resolved references to this instance
                for field_name in fetch_fields:
                    if hasattr(updated_doc, field_name):
                        setattr(self, field_name, getattr(updated_doc, field_name))

                # If depth > 1, recursively resolve references in fetched documents
                if depth > 1:
                    for field_name in fetch_fields:
                        referenced_doc = getattr(self, field_name, None)
                        if referenced_doc and hasattr(referenced_doc, 'resolve_references_sync'):
                            referenced_doc.resolve_references_sync(depth=depth - 1)
        except Exception:
            # Fall back to manual resolution if FETCH fails
            for field_name, field in self._fields.items():
                if isinstance(field, ReferenceField) and getattr(self, field_name):
                    ref_id = getattr(self, field_name)
                    if isinstance(ref_id, str) and ':' in ref_id:
                        referenced_doc = field.document_type.get_sync(id=ref_id, dereference=True)
                        if referenced_doc and depth > 1:
                            referenced_doc.resolve_references_sync(depth=depth - 1)
                        setattr(self, field_name, referenced_doc)
                    elif isinstance(ref_id, RecordID):
                        ref_id_str = str(ref_id)
                        referenced_doc = field.document_type.get_sync(id=ref_id_str, dereference=True)
                        if referenced_doc and depth > 1:
                            referenced_doc.resolve_references_sync(depth=depth - 1)
                        setattr(self, field_name, referenced_doc)

        return self
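
    # Dereferencing sketch (hypothetical ``Book`` model with an ``author``
    # ReferenceField; assumes a registered default connection):
    #
    #     book = await Book.objects.get(id="book:dune")
    #     await book.resolve_references(depth=2)   # FETCH author (and its refs)
    #     book.author                              # resolved document, not a RecordID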

    @classmethod
    async def get(cls: Type[T], id: IdType, dereference: bool = False,
                  dereference_depth: int = 1, **kwargs: Any) -> T:
        """Get a document by ID with optional dereferencing using FETCH.

        This method retrieves a document by ID and optionally resolves references
        using SurrealDB's FETCH clause for efficient reference resolution.

        Args:
            id: The ID of the document to retrieve
            dereference: Whether to resolve references (default: False)
            dereference_depth: Maximum depth of reference resolution (default: 1)
            **kwargs: Additional arguments to pass to the get method

        Returns:
            The document instance with optionally resolved references
        """
        if not dereference:
            # No dereferencing needed, use regular get
            return await cls.objects.get(id=id, **kwargs)

        # Build FETCH clause for reference fields
        fetch_fields = []
        for field_name, field in cls._fields.items():
            if isinstance(field, ReferenceField):
                fetch_fields.append(field_name)

        if fetch_fields:
            # Use FETCH to resolve references in the initial query
            connection = ConnectionRegistry.get_default_connection()

            # Handle ID format - both strings and RecordID objects
            if (isinstance(id, str) and ':' in id) or isinstance(id, RecordID):
                record_id = str(id)  # Convert RecordID to string
            else:
                record_id = f"{cls._get_collection_name()}:{id}"

            try:
                # Use FETCH on the entire collection, then filter
                fetch_query = f"SELECT * FROM {cls._get_collection_name()} FETCH {', '.join(fetch_fields)}"
                result = await connection.client.query(fetch_query)

                if not result or not result[0]:
                    from .exceptions import DoesNotExist
                    raise DoesNotExist(f"Object with ID '{id}' does not exist.")

                # Handle both a single document and a list of documents
                documents = result[0]
                target_doc = None

                # If documents is a single dict, wrap it in a list
                if isinstance(documents, dict):
                    documents = [documents]

                # Find the document with the matching ID
                for doc_data in documents:
                    if isinstance(doc_data, dict) and str(doc_data.get('id')) == record_id:
                        target_doc = doc_data
                        break

                if not target_doc:
                    from .exceptions import DoesNotExist
                    raise DoesNotExist(f"Object with ID '{id}' does not exist.")

                document = cls.from_db(target_doc)

                # If dereference_depth > 1, recursively resolve deeper references
                if dereference_depth > 1:
                    await document.resolve_references(depth=dereference_depth)

                return document
            except Exception:
                # Fall back to regular get with manual dereferencing
                pass

        # Fallback to original method
        document = await cls.objects.get(id=id, **kwargs)
        if dereference and dereference_depth > 1 and document:
            await document.resolve_references(depth=dereference_depth)
        return document

    @classmethod
    def get_sync(cls: Type[T], id: IdType, dereference: bool = False,
                 dereference_depth: int = 1, **kwargs: Any) -> T:
        """Get a document by ID with optional dereferencing synchronously using FETCH.

        This method retrieves a document by ID and optionally resolves references
        using SurrealDB's FETCH clause for efficient reference resolution.

        Args:
            id: The ID of the document to retrieve
            dereference: Whether to resolve references (default: False)
            dereference_depth: Maximum depth of reference resolution (default: 1)
            **kwargs: Additional arguments to pass to the get method

        Returns:
            The document instance with optionally resolved references
        """
        if not dereference:
            # No dereferencing needed, use regular get
            return cls.objects.get_sync(id=id, **kwargs)

        # Build FETCH clause for reference fields
        fetch_fields = []
        for field_name, field in cls._fields.items():
            if isinstance(field, ReferenceField):
                fetch_fields.append(field_name)

        if fetch_fields:
            # Use FETCH to resolve references in the initial query
            connection = ConnectionRegistry.get_default_connection()

            # Handle ID format - both strings and RecordID objects
            if (isinstance(id, str) and ':' in id) or isinstance(id, RecordID):
                record_id = str(id)  # Convert RecordID to string
            else:
                record_id = f"{cls._get_collection_name()}:{id}"

            try:
                # Use FETCH on the entire collection, then filter
                fetch_query = f"SELECT * FROM {cls._get_collection_name()} FETCH {', '.join(fetch_fields)}"
                result = connection.client.query(fetch_query)

                if not result or not result[0]:
                    from .exceptions import DoesNotExist
                    raise DoesNotExist(f"Object with ID '{id}' does not exist.")

                # Handle both a single document and a list of documents
                documents = result[0]
                target_doc = None

                # If documents is a single dict, wrap it in a list
                if isinstance(documents, dict):
                    documents = [documents]

                # Find the document with the matching ID
                for doc_data in documents:
                    if isinstance(doc_data, dict) and str(doc_data.get('id')) == record_id:
                        target_doc = doc_data
                        break

                if not target_doc:
                    from .exceptions import DoesNotExist
                    raise DoesNotExist(f"Object with ID '{id}' does not exist.")

                document = cls.from_db(target_doc)

                # If dereference_depth > 1, recursively resolve deeper references
                if dereference_depth > 1:
                    document.resolve_references_sync(depth=dereference_depth)

                return document
            except Exception:
                # Fall back to regular get with manual dereferencing
                pass

        # Fallback to original method
        document = cls.objects.get_sync(id=id, **kwargs)
        if dereference and dereference_depth > 1 and document:
            document.resolve_references_sync(depth=dereference_depth)
        return document
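
    # Usage sketch: ``get`` with dereference=True resolves references via a
    # single FETCH query (hypothetical ``Book`` model; assumes a registered
    # default connection):
    #
    #     book = await Book.get("book:dune", dereference=True)
    #     book2 = Book.get_sync("book:dune", dereference=True, dereference_depth=2)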

    async def save(self: T, connection: Optional[Any] = None) -> T:
        """Save the document to the database asynchronously.

        This method saves the document to the database, either creating a new
        document or updating an existing one based on whether the document has an ID.

        Args:
            connection: The database connection to use (optional, deprecated for multi-backend)

        Returns:
            The saved document instance

        Raises:
            ValidationError: If the document fails validation
        """
        # Trigger pre_save signal
        if SIGNAL_SUPPORT:
            pre_save.send(self.__class__, document=self)

        # Note: connection parameter is deprecated for multi-backend support
        # The backend is determined by _get_backend() which handles multi-backend connections

        self.validate()
        data = self.to_db()

        # Trigger pre_save_post_validation signal
        if SIGNAL_SUPPORT:
            pre_save_post_validation.send(self.__class__, document=self)

        is_new = not self.id

        # Get backend for this document
        backend = self._get_backend()
        table_name = self._get_collection_name()

        if self.id:
            # Update existing document - for now, use backend insert with ID
            # Most backends will treat this as an upsert
            result_data = await backend.insert(table_name, data)
        else:
            # Create new document
            result_data = await backend.insert(table_name, data)

        # Convert result to list format for consistency
        if result_data and not isinstance(result_data, list):
            result = [result_data]
        else:
            result = result_data or []

        # Update the current instance with the returned data
        if result:
            if isinstance(result, list) and result:
                doc_data = result[0]
            else:
                doc_data = result

            # Update the instance's _data with the returned document
            if isinstance(doc_data, dict):
                # First update the raw data
                self._data.update(doc_data)

                # Make sure to capture the ID if it's a new document
                if 'id' in doc_data:
                    self._data['id'] = doc_data['id']

                # Then properly convert each field using its from_db method
                for field_name, field in self._fields.items():
                    if field_name in doc_data:
                        self._data[field_name] = field.from_db(doc_data[field_name])

        # Trigger post_save signal
        if SIGNAL_SUPPORT:
            post_save.send(self.__class__, document=self, created=is_new)

        return self

    def save_sync(self: T, connection: Optional[Any] = None) -> T:
        """Save the document to the database synchronously.

        This method saves the document to the database, either creating a new
        document or updating an existing one based on whether the document has an ID.

        Args:
            connection: The database connection to use (optional)

        Returns:
            The saved document instance

        Raises:
            ValidationError: If the document fails validation
        """
        # Trigger pre_save signal
        if SIGNAL_SUPPORT:
            pre_save.send(self.__class__, document=self)

        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        self.validate()
        data = self.to_db()

        # Trigger pre_save_post_validation signal
        if SIGNAL_SUPPORT:
            pre_save_post_validation.send(self.__class__, document=self)

        is_new = not self.id

        # For sync operations, fall back to direct client for now
        # TODO: Add sync backend methods or convert to async
        if self.id:
            del data['id']
            id_part = str(self.id).split(':')[1]
            result = connection.client.upsert(
                RecordID(self._get_collection_name(), int(id_part) if id_part.isdigit() else id_part),
                data
            )
        else:
            # Create new document
            result = connection.client.create(
                self._get_collection_name(),
                data
            )

        # Update the current instance with the returned data
        if result:
            if isinstance(result, list) and result:
                doc_data = result[0]
            else:
                doc_data = result

            # Update the instance's _data with the returned document
            if isinstance(doc_data, dict):
                # First update the raw data
                self._data.update(doc_data)

                # Make sure to capture the ID if it's a new document
                if 'id' in doc_data:
                    self._data['id'] = doc_data['id']

                # Then properly convert each field using its from_db method
                for field_name, field in self._fields.items():
                    if field_name in doc_data:
                        self._data[field_name] = field.from_db(doc_data[field_name])

        # Trigger post_save signal
        if SIGNAL_SUPPORT:
            post_save.send(self.__class__, document=self, created=is_new)

        return self
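
    # Save sketch (hypothetical ``User`` model): a document without an ID is
    # created; one with an ID is treated as an upsert by most backends.
    # Signals fire before and after validation:
    #
    #     user = User(username="alice")
    #     await user.save()      # async path goes through the backend
    #     user.id                # populated from the returned record
    #     user.save_sync()       # sync path uses the client directly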

    async def delete(self, connection: Optional[Any] = None) -> bool:
        """Delete the document from the database asynchronously.

        This method deletes the document from the database.

        Args:
            connection: The database connection to use (optional)

        Returns:
            True if the document was deleted

        Raises:
            ValueError: If the document doesn't have an ID
        """
        # Trigger pre_delete signal
        if SIGNAL_SUPPORT:
            pre_delete.send(self.__class__, document=self)

        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError("Cannot delete a document without an ID")

        # Get backend for this document
        backend = self._get_backend()
        table_name = self._get_collection_name()

        # Build condition to delete by ID
        id_condition = backend.build_condition('id', '=', self.id)
        deleted_count = await backend.delete(table_name, [id_condition])

        # Trigger post_delete signal
        if SIGNAL_SUPPORT:
            post_delete.send(self.__class__, document=self)

        return True

    def delete_sync(self, connection: Optional[Any] = None) -> bool:
        """Delete the document from the database synchronously.

        This method deletes the document from the database.

        Args:
            connection: The database connection to use (optional)

        Returns:
            True if the document was deleted

        Raises:
            ValueError: If the document doesn't have an ID
        """
        # Trigger pre_delete signal
        if SIGNAL_SUPPORT:
            pre_delete.send(self.__class__, document=self)

        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError("Cannot delete a document without an ID")

        # For sync operations, fall back to direct client for now
        # TODO: Add sync backend methods or convert to async
        connection.client.delete(f"{self.id}")

        # Trigger post_delete signal
        if SIGNAL_SUPPORT:
            post_delete.send(self.__class__, document=self)

        return True

    async def refresh(self, connection: Optional[Any] = None) -> 'Document':
        """Refresh the document from the database asynchronously.

        This method refreshes the document's data from the database.

        Args:
            connection: The database connection to use (optional)

        Returns:
            The refreshed document instance

        Raises:
            ValueError: If the document doesn't have an ID
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError("Cannot refresh a document without an ID")

        result = await connection.client.select(f"{self.id}")
        if result:
            if isinstance(result, list) and result:
                doc = result[0]
            else:
                doc = result
            for field_name, field in self._fields.items():
                db_field = field.db_field or field_name
                if db_field in doc:
                    self._data[field_name] = field.from_db(doc[db_field])
            self._changed_fields = []

        return self

    def refresh_sync(self, connection: Optional[Any] = None) -> 'Document':
        """Refresh the document from the database synchronously.

        This method refreshes the document's data from the database.

        Args:
            connection: The database connection to use (optional)

        Returns:
            The refreshed document instance

        Raises:
            ValueError: If the document doesn't have an ID
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError("Cannot refresh a document without an ID")

        result = connection.client.select(f"{self.id}")
        if result:
            if isinstance(result, list) and result:
                doc = result[0]
            else:
                doc = result
            for field_name, field in self._fields.items():
                db_field = field.db_field or field_name
                if db_field in doc:
                    self._data[field_name] = field.from_db(doc[db_field])
            self._changed_fields = []

        return self

    @classmethod
    def relates(cls, relation_name: str) -> callable:
        """Get a RelationQuerySet for a specific relation.

        This method returns a function that creates a RelationQuerySet for the
        specified relation name. The function can be called with an optional
        connection parameter.

        Args:
            relation_name: Name of the relation

        Returns:
            Function that creates a RelationQuerySet
        """
        def relation_query_builder(connection: Optional[Any] = None) -> RelationQuerySet:
            """Create a RelationQuerySet for the specified relation.

            Args:
                connection: The database connection to use (optional)

            Returns:
                A RelationQuerySet for the relation
            """
            if connection is None:
                connection = ConnectionRegistry.get_default_connection()
            return RelationQuerySet(cls, connection, relation=relation_name)

        return relation_query_builder

    async def fetch_relation(self, relation_name: str, target_document: Optional[Type] = None,
                             relation_document: Optional[Type] = None,
                             connection: Optional[Any] = None, **filters: Any) -> List[Any]:
        """Fetch related documents asynchronously.

        This method fetches documents related to this document through the
        specified relation.

        Args:
            relation_name: Name of the relation
            target_document: The document class of the target documents (optional)
            relation_document: The document class representing the relation (optional)
            connection: The database connection to use (optional)
            **filters: Filters to apply to the related documents

        Returns:
            List of related documents, relation documents, or relation records
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        result = await relation_query.get_related(self, target_document, **filters)

        # If relation_document is specified, convert the relation records to RelationDocument instances
        if relation_document and not target_document:
            return [relation_document.from_db(record) for record in result]

        return result

    def fetch_relation_sync(self, relation_name: str, target_document: Optional[Type] = None,
                            relation_document: Optional[Type] = None,
                            connection: Optional[Any] = None, **filters: Any) -> List[Any]:
        """Fetch related documents synchronously.

        This method fetches documents related to this document through the
        specified relation.

        Args:
            relation_name: Name of the relation
            target_document: The document class of the target documents (optional)
            relation_document: The document class representing the relation (optional)
            connection: The database connection to use (optional)
            **filters: Filters to apply to the related documents

        Returns:
            List of related documents, relation documents, or relation records
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        result = relation_query.get_related_sync(self, target_document, **filters)

        # If relation_document is specified, convert the relation records to RelationDocument instances
        if relation_document and not target_document:
            return [relation_document.from_db(record) for record in result]

        return result
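
    # Relation fetch sketch (hypothetical ``Person``/``Movie`` models with a
    # ``watched`` relation; assumes a registered default connection):
    #
    #     movies = await person.fetch_relation("watched", target_document=Movie)
    #     raw = person.fetch_relation_sync("watched")   # raw relation records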

    async def resolve_relation(self, relation_name: str, target_document_class: Optional[Type] = None,
                               relation_document: Optional[Type] = None,
                               connection: Optional[Any] = None) -> List[Any]:
        """Resolve related documents from a relation fetch result asynchronously.

        This method resolves related documents from a relation fetch result.
        It fetches the relation data and then resolves each related document.

        Args:
            relation_name: Name of the relation to resolve
            target_document_class: Class of the target document (optional)
            relation_document: The document class representing the relation (optional)
            connection: Database connection to use (optional)

        Returns:
            List of resolved document instances
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        # If relation_document is specified, convert the relation records to RelationDocument instances
        if relation_document and not target_document_class:
            return await self.fetch_relation(relation_name, relation_document=relation_document,
                                             connection=connection)

        # First fetch the relation data
        relation_data = await self.fetch_relation(relation_name, connection=connection)
        if not relation_data:
            return []

        resolved_documents = []
        if isinstance(relation_data, dict) and 'related' in relation_data and isinstance(relation_data['related'], list):
            for related_id in relation_data['related']:
                if isinstance(related_id, RecordID):
                    collection = related_id.table_name
                    record_id = related_id.id

                    # Fetch the actual document
                    try:
                        result = await connection.client.select(related_id)
                        if result and isinstance(result, list):
                            doc = result[0]
                        else:
                            doc = result
                        if doc:
                            resolved_documents.append(doc)
                    except Exception as e:
                        logger.error(f"Error resolving document {collection}:{record_id}: {str(e)}")

        return resolved_documents

    def resolve_relation_sync(self, relation_name: str, target_document_class: Optional[Type] = None,
                              relation_document: Optional[Type] = None,
                              connection: Optional[Any] = None) -> List[Any]:
        """Resolve related documents from a relation fetch result synchronously.

        This method resolves related documents from a relation fetch result.
        It fetches the relation data and then resolves each related document.

        Args:
            relation_name: Name of the relation to resolve
            target_document_class: Class of the target document (optional)
            relation_document: The document class representing the relation (optional)
            connection: Database connection to use (optional)

        Returns:
            List of resolved document instances
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        # If relation_document is specified, convert the relation records to RelationDocument instances
        if relation_document and not target_document_class:
            return self.fetch_relation_sync(relation_name, relation_document=relation_document,
                                            connection=connection)

        # First fetch the relation data
        relation_data = self.fetch_relation_sync(relation_name, connection=connection)
        if not relation_data:
            return []

        resolved_documents = []
        if isinstance(relation_data, dict) and 'related' in relation_data and isinstance(relation_data['related'], list):
            for related_id in relation_data['related']:
                if isinstance(related_id, RecordID):
                    collection = related_id.table_name
                    record_id = related_id.id

                    # Fetch the actual document
                    try:
                        result = connection.client.select(related_id)
                        if result and isinstance(result, list):
                            doc = result[0]
                        else:
                            doc = result
                        if doc:
                            resolved_documents.append(doc)
                    except Exception as e:
                        logger.error(f"Error resolving document {collection}:{record_id}: {str(e)}")

        return resolved_documents

    async def relate_to(self, relation_name: str, target_instance: Any,
                        connection: Optional[Any] = None, **attrs: Any) -> Optional[Any]:
        """Create a relation to another document asynchronously.

        This method creates a relation from this document to another document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance to relate to
            connection: The database connection to use (optional)
            **attrs: Attributes to set on the relation

        Returns:
            The created relation record or None if creation failed
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return await relation_query.relate(self, target_instance, **attrs)

    def relate_to_sync(self, relation_name: str, target_instance: Any,
                       connection: Optional[Any] = None, **attrs: Any) -> Optional[Any]:
        """Create a relation to another document synchronously.

        This method creates a relation from this document to another document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance to relate to
            connection: The database connection to use (optional)
            **attrs: Attributes to set on the relation

        Returns:
            The created relation record or None if creation failed
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return relation_query.relate_sync(self, target_instance, **attrs)
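
    # Relation creation sketch (same hypothetical ``Person``/``Movie`` models
    # as above; extra keyword arguments become attributes on the relation):
    #
    #     await person.relate_to("watched", movie, rating=5)
    #     person.relate_to_sync("watched", movie, rating=4)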

    async def update_relation_to(self, relation_name: str, target_instance: Any,
                                 connection: Optional[Any] = None, **attrs: Any) -> Optional[Any]:
        """Update a relation to another document asynchronously.

        This method updates a relation from this document to another document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance the relation is to
            connection: The database connection to use (optional)
            **attrs: Attributes to update on the relation

        Returns:
            The updated relation record or None if update failed
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return await relation_query.update_relation(self, target_instance, **attrs)

    def update_relation_to_sync(self, relation_name: str, target_instance: Any,
                                connection: Optional[Any] = None, **attrs: Any) -> Optional[Any]:
        """Update a relation to another document synchronously.

        This method updates a relation from this document to another document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance the relation is to
            connection: The database connection to use (optional)
            **attrs: Attributes to update on the relation

        Returns:
            The updated relation record or None if update failed
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return relation_query.update_relation_sync(self, target_instance, **attrs)

    async def delete_relation_to(self, relation_name: str, target_instance: Optional[Any] = None,
                                 connection: Optional[Any] = None) -> int:
        """Delete a relation to another document asynchronously.

        This method deletes a relation from this document to another document.
        If target_instance is not provided, it deletes all relations with the
        specified name from this document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance the relation is to (optional)
            connection: The database connection to use (optional)

        Returns:
            Number of deleted relations
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return await relation_query.delete_relation(self, target_instance)

    def delete_relation_to_sync(self, relation_name: str, target_instance: Optional[Any] = None,
                                connection: Optional[Any] = None) -> int:
        """Delete a relation to another document synchronously.

        This method deletes a relation from this document to another document.
        If target_instance is not provided, it deletes all relations with the
        specified name from this document.

        Args:
            relation_name: Name of the relation
            target_instance: The document instance the relation is to (optional)
            connection: The database connection to use (optional)

        Returns:
            Number of deleted relations
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        relation_query = RelationQuerySet(self.__class__, connection, relation=relation_name)
        return relation_query.delete_relation_sync(self, target_instance)

    async def traverse_path(self, path_spec: str, target_document: Optional[Type] = None,
                            connection: Optional[Any] = None, **filters: Any) -> List[Any]:
        """Traverse a path in the graph asynchronously.

        This method traverses a path in the graph starting from this document.
        The path_spec is a string like "->[watched]->->[acted_in]->" which
        describes a path through the graph.

        Args:
            path_spec: String describing the path to traverse
            target_document: The document class to return instances of (optional)
            connection: The database connection to use (optional)
            **filters: Filters to apply to the results

        Returns:
            List of documents or path results

        Raises:
            ValueError: If the document is not saved
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError(f"Cannot traverse from unsaved {self.__class__.__name__}")

        start_id = f"{self.__class__._get_collection_name()}:{self.id}"

        if target_document:
            end_collection = target_document._get_collection_name()
            query = f"SELECT * FROM {end_collection} WHERE {path_spec}{start_id}"
        else:
            query = f"SELECT {path_spec} as path FROM {start_id}"

        # Add additional filters if provided
        if filters:
            conditions = []
            for field, value in filters.items():
                conditions.append(f"{field} = {json.dumps(value)}")
            if target_document:
                query += f" AND {' AND '.join(conditions)}"
            else:
                query += f" WHERE {' AND '.join(conditions)}"

        result = await connection.client.query(query)

        if not result or not result[0]:
            return []

        # Process results based on query type
        if target_document:
            # Return list of related document instances
            return [target_document.from_db(doc) for doc in result[0]]
        else:
            # Return raw path results
            return result[0]

    def traverse_path_sync(self, path_spec: str, target_document: Optional[Type] = None,
                           connection: Optional[Any] = None, **filters: Any) -> List[Any]:
        """Traverse a path in the graph synchronously.

        This method traverses a path in the graph starting from this document.
        The path_spec is a string like "->[watched]->->[acted_in]->" which
        describes a path through the graph.

        Args:
            path_spec: String describing the path to traverse
            target_document: The document class to return instances of (optional)
            connection: The database connection to use (optional)
            **filters: Filters to apply to the results

        Returns:
            List of documents or path results

        Raises:
            ValueError: If the document is not saved
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        if not self.id:
            raise ValueError(f"Cannot traverse from unsaved {self.__class__.__name__}")

        start_id = f"{self.__class__._get_collection_name()}:{self.id}"

        if target_document:
            end_collection = target_document._get_collection_name()
            query = f"SELECT * FROM {end_collection} WHERE {path_spec}{start_id}"
        else:
            query = f"SELECT {path_spec} as path FROM {start_id}"

        # Add additional filters if provided
        if filters:
            conditions = []
            for field, value in filters.items():
                conditions.append(f"{field} = {json.dumps(value)}")
            if target_document:
                query += f" AND {' AND '.join(conditions)}"
            else:
                query += f" WHERE {' AND '.join(conditions)}"

        result = connection.client.query(query)

        if not result or not result[0]:
            return []

        # Process results based on query type
        if target_document:
            # Return list of related document instances
            return [target_document.from_db(doc) for doc in result[0]]
        else:
            # Return raw path results
            return result[0]
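
    # Path traversal sketch (hypothetical models; the path_spec format follows
    # the docstring above):
    #
    #     results = await person.traverse_path("->[watched]->->[acted_in]->")
    #     # or, returning typed instances:
    #     # actors = person.traverse_path_sync("->[watched]->->[acted_in]->",
    #     #                                    target_document=Actor)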

    @classmethod
    async def bulk_create(cls, documents: List[Any], batch_size: int = 1000,
                          validate: bool = True, return_documents: bool = True,
                          connection: Optional[Any] = None) -> Union[List[Any], int]:
        """Create multiple documents in batches.

        Args:
            documents: List of documents to create
            batch_size: Number of documents per batch
            validate: Whether to validate documents before creation
            return_documents: Whether to return created documents
            connection: The database connection to use (optional)

        Returns:
            List of created documents if return_documents=True,
            else count of created documents
        """
        results = []
        total_count = 0

        # Process documents in batches
        for i in range(0, len(documents), batch_size):
            batch = documents[i:i + batch_size]

            if validate:
                # Perform validation without using asyncio.gather since validate is not async
                for doc in batch:
                    doc.validate()

            # Convert batch to DB representation
            data = [doc.to_db() for doc in batch]

            # Create the documents in the database using backend
            collection = batch[0]._get_collection_name()
            backend = batch[0]._get_backend()
            created = await backend.insert_many(collection, data)

            if created:
                if return_documents:
                    # Convert created records back to documents
                    for record in created:
                        doc = cls.from_db(record)
                        results.append(doc)
                total_count += len(created)

        return results if return_documents else total_count

    @classmethod
    def bulk_create_sync(cls, documents: List[Any], batch_size: int = 1000,
                         validate: bool = True, return_documents: bool = True,
                         connection: Optional[Any] = None) -> Union[List[Any], int]:
        """Create multiple documents in a single operation synchronously.

        This method creates multiple documents in a single operation, processing
        them in batches for better performance. It can optionally validate the
        documents and return the created documents.

        Args:
            documents: List of Document instances to create
            batch_size: Number of documents per batch (default: 1000)
            validate: Whether to validate documents (default: True)
            return_documents: Whether to return created documents (default: True)
            connection: The database connection to use (optional)

        Returns:
            List of created documents with their IDs set if return_documents=True,
            otherwise returns the count of created documents
        """
        # Trigger pre_bulk_insert signal
        if SIGNAL_SUPPORT:
            pre_bulk_insert.send(cls, documents=documents)

        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        result = cls.objects(connection).bulk_create_sync(
            documents,
            batch_size=batch_size,
            validate=validate,
            return_documents=return_documents
        )

        # Trigger post_bulk_insert signal
        if SIGNAL_SUPPORT:
            post_bulk_insert.send(cls, documents=documents, loaded=return_documents)

        return result
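
    # Bulk creation sketch (hypothetical ``User`` model):
    #
    #     users = [User(username=f"user{i}") for i in range(5000)]
    #     created = await User.bulk_create(users, batch_size=1000)      # documents
    #     count = User.bulk_create_sync(users, return_documents=False)  # int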

    @classmethod
    async def create_index(cls, index_name: str, fields: List[str], unique: bool = False,
                           search: bool = False, analyzer: Optional[str] = None,
                           comment: Optional[str] = None, connection: Optional[Any] = None) -> None:
        """Create an index on the document's collection asynchronously.

        Args:
            index_name: Name of the index
            fields: List of field names to include in the index
            unique: Whether the index should enforce uniqueness
            search: Whether the index is a search index
            analyzer: Analyzer to use for search indexes
            comment: Optional comment for the index
            connection: Optional connection to use
        """
        if connection is None:
            from .connection import ConnectionRegistry
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()
        fields_str = ", ".join(fields)

        # Build the index definition
        query = f"DEFINE INDEX {index_name} ON {collection_name} FIELDS {fields_str}"

        # Add index type
        if unique:
            query += " UNIQUE"
        elif search and analyzer:
            query += f" SEARCH ANALYZER {analyzer}"

        # Add comment if provided
        if comment:
            query += f" COMMENT '{comment}'"

        # Execute the query
        await connection.client.query(query)

    @classmethod
    def create_index_sync(cls, index_name: str, fields: List[str], unique: bool = False,
                          search: bool = False, analyzer: Optional[str] = None,
                          comment: Optional[str] = None, connection: Optional[Any] = None) -> None:
        """Create an index on the document's collection synchronously.

        Args:
            index_name: Name of the index
            fields: List of field names to include in the index
            unique: Whether the index should enforce uniqueness
            search: Whether the index is a search index
            analyzer: Analyzer to use for search indexes
            comment: Optional comment for the index
            connection: Optional connection to use
        """
        if connection is None:
            from .connection import ConnectionRegistry
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()
        fields_str = ", ".join(fields)

        # Build the index definition
        query = f"DEFINE INDEX {index_name} ON {collection_name} FIELDS {fields_str}"

        # Add index type
        if unique:
            query += " UNIQUE"
        elif search and analyzer:
            query += f" SEARCH ANALYZER {analyzer}"

        # Add comment if provided
        if comment:
            query += f" COMMENT '{comment}'"

        # Execute the query
        connection.client.query(query)
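
    # Index sketch: per the query-building logic above, the call below would
    # emit roughly this SurrealQL (assuming a ``users`` collection):
    #
    #     await User.create_index("idx_username", ["username"], unique=True)
    #     # DEFINE INDEX idx_username ON users FIELDS username UNIQUE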

    @classmethod
    async def create_indexes(cls, connection: Optional[Any] = None) -> None:
        """Create all indexes defined for this document class asynchronously.

        This method creates indexes defined in the Meta class and also creates
        indexes for fields marked as indexed.

        Args:
            connection: Optional connection to use
        """
        connection = connection or ConnectionRegistry.get_default_connection()

        # Track processed multi-field indexes to avoid duplicates
        processed_multi_field_indexes = set()

        # Create indexes defined in Meta.indexes
        if hasattr(cls, '_meta') and 'indexes' in cls._meta and cls._meta['indexes']:
            for index_def in cls._meta['indexes']:
                # Handle different index definition formats
                if isinstance(index_def, dict):
                    # Dictionary format with options
                    index_name = index_def.get('name')
                    fields = index_def.get('fields', [])
                    unique = index_def.get('unique', False)
                    search = index_def.get('search', False)
                    analyzer = index_def.get('analyzer')
                    comment = index_def.get('comment')
                elif isinstance(index_def, tuple) and len(index_def) >= 2:
                    # Tuple format (name, fields, [unique])
                    index_name = index_def[0]
                    fields = index_def[1] if isinstance(index_def[1], list) else [index_def[1]]
                    unique = index_def[2] if len(index_def) > 2 else False
                    search = False
                    analyzer = None
                    comment = None
                else:
                    # Skip invalid index definitions
                    continue

                await cls.create_index(
                    index_name=index_name,
                    fields=fields,
                    unique=unique,
                    search=search,
                    analyzer=analyzer,
                    comment=comment,
                    connection=connection
                )

                # Mark this index as processed to avoid duplicates
                if fields:
                    processed_multi_field_indexes.add(tuple(sorted(fields)))

        # Create indexes for fields marked as indexed
        for field_name, field_obj in cls._fields.items():
            if getattr(field_obj, 'indexed', False):
                db_field_name = field_obj.db_field or field_name

                # Check if this is a multi-field index
                index_with = getattr(field_obj, 'index_with', None)
                if index_with and isinstance(index_with, list) and len(index_with) > 0:
                    # Get the actual field names for the index_with fields
                    index_with_fields = []
                    for with_field_name in index_with:
                        if with_field_name in cls._fields:
                            with_field_obj = cls._fields[with_field_name]
                            with_db_field_name = with_field_obj.db_field or with_field_name
                            index_with_fields.append(with_db_field_name)
                        else:
                            # If the field doesn't exist, use the name as is
                            index_with_fields.append(with_field_name)

                    # Generate a unique identifier for this multi-field index
                    # Sort fields to ensure consistent ordering
                    all_fields = sorted([db_field_name] + index_with_fields)
                    index_key = tuple(all_fields)

                    # Skip if we've already processed this combination
                    if index_key in processed_multi_field_indexes:
                        continue

                    # Mark this combination as processed
                    processed_multi_field_indexes.add(index_key)

                    # Generate a default index name
                    index_name = f"{cls._get_collection_name()}_{'_'.join(all_fields)}_idx"

                    # Get index options
                    unique = getattr(field_obj, 'unique', False)
                    search = getattr(field_obj, 'search', False)
                    analyzer = getattr(field_obj, 'analyzer', None)

                    # Create the multi-field index
                    await cls.create_index(
                        index_name=index_name,
                        fields=all_fields,
                        unique=unique,
                        search=search,
                        analyzer=analyzer,
                        connection=connection
                    )
                else:
                    # Create a single-field index
                    # Skip if we've already processed this field
                    if (db_field_name,) in processed_multi_field_indexes:
                        continue

                    # Mark this field as processed
                    processed_multi_field_indexes.add((db_field_name,))

                    # Generate a default index name
                    index_name = f"{cls._get_collection_name()}_{field_name}_idx"

                    # Get index options
                    unique = getattr(field_obj, 'unique', False)
                    search = getattr(field_obj, 'search', False)
                    analyzer = getattr(field_obj, 'analyzer', None)

                    # Create the single-field index
                    await cls.create_index(
                        index_name=index_name,
                        fields=[db_field_name],
                        unique=unique,
                        search=search,
                        analyzer=analyzer,
                        connection=connection
                    )

    @classmethod
    def create_indexes_sync(cls, connection: Optional[Any] = None) -> None:
        """Create all indexes defined for this document class synchronously.

        This method creates indexes defined in the Meta class and also
        creates indexes for fields marked as indexed.

        Args:
            connection: Optional connection to use
        """
        connection = connection or ConnectionRegistry.get_default_connection()

        # Track processed multi-field indexes to avoid duplicates
        processed_multi_field_indexes = set()

        # Create indexes defined in Meta.indexes
        if hasattr(cls, '_meta') and 'indexes' in cls._meta and cls._meta['indexes']:
            for index_def in cls._meta['indexes']:
                # Handle different index definition formats
                if isinstance(index_def, dict):
                    # Dictionary format with options
                    index_name = index_def.get('name')
                    fields = index_def.get('fields', [])
                    unique = index_def.get('unique', False)
                    search = index_def.get('search', False)
                    analyzer = index_def.get('analyzer')
                    comment = index_def.get('comment')
                elif isinstance(index_def, tuple) and len(index_def) >= 2:
                    # Tuple format (name, fields, [unique])
                    index_name = index_def[0]
                    fields = index_def[1] if isinstance(index_def[1], list) else [index_def[1]]
                    unique = index_def[2] if len(index_def) > 2 else False
                    search = False
                    analyzer = None
                    comment = None
                else:
                    # Skip invalid index definitions
                    continue

                cls.create_index_sync(
                    index_name=index_name,
                    fields=fields,
                    unique=unique,
                    search=search,
                    analyzer=analyzer,
                    comment=comment,
                    connection=connection
                )

                # Mark this index as processed to avoid duplicates
                if fields:
                    processed_multi_field_indexes.add(tuple(sorted(fields)))

        # Create indexes for fields marked as indexed
        for field_name, field_obj in cls._fields.items():
            if getattr(field_obj, 'indexed', False):
                db_field_name = field_obj.db_field or field_name

                # Check if this is a multi-field index
                index_with = getattr(field_obj, 'index_with', None)
                if index_with and isinstance(index_with, list) and len(index_with) > 0:
                    # Get the actual field names for the index_with fields
                    index_with_fields = []
                    for with_field_name in index_with:
                        if with_field_name in cls._fields:
                            with_field_obj = cls._fields[with_field_name]
                            with_db_field_name = with_field_obj.db_field or with_field_name
                            index_with_fields.append(with_db_field_name)
                        else:
                            # If the field doesn't exist, use the name as is
                            index_with_fields.append(with_field_name)

                    # Generate a unique identifier for this multi-field index
                    # Sort fields to ensure consistent ordering
                    all_fields = sorted([db_field_name] + index_with_fields)
                    index_key = tuple(all_fields)

                    # Skip if we've already processed this combination
                    if index_key in processed_multi_field_indexes:
                        continue

                    # Mark this combination as processed
                    processed_multi_field_indexes.add(index_key)

                    # Generate a default index name
                    index_name = f"{cls._get_collection_name()}_{'_'.join(all_fields)}_idx"

                    # Get index options
                    unique = getattr(field_obj, 'unique', False)
                    search = getattr(field_obj, 'search', False)
                    analyzer = getattr(field_obj, 'analyzer', None)

                    # Create the multi-field index
                    cls.create_index_sync(
                        index_name=index_name,
                        fields=all_fields,
                        unique=unique,
                        search=search,
                        analyzer=analyzer,
                        connection=connection
                    )
                else:
                    # Create a single-field index
                    # Skip if we've already processed this field
                    if (db_field_name,) in processed_multi_field_indexes:
                        continue

                    # Mark this field as processed
                    processed_multi_field_indexes.add((db_field_name,))

                    # Generate a default index name
                    index_name = f"{cls._get_collection_name()}_{field_name}_idx"

                    # Get index options
                    unique = getattr(field_obj, 'unique', False)
                    search = getattr(field_obj, 'search', False)
                    analyzer = getattr(field_obj, 'analyzer', None)

                    # Create the single-field index
                    cls.create_index_sync(
                        index_name=index_name,
                        fields=[db_field_name],
                        unique=unique,
                        search=search,
                        analyzer=analyzer,
                        connection=connection
                    )
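
    # Usage sketch for the supported index definition formats. The model and
    # field names below are hypothetical, not part of this module:
    #
    #     class User(Document):
    #         email = StringField(indexed=True, unique=True)
    #         first_name = StringField(indexed=True, index_with=['last_name'])
    #         last_name = StringField()
    #
    #         class Meta:
    #             collection = 'user'
    #             indexes = [
    #                 # dict format with options
    #                 {'name': 'user_email_idx', 'fields': ['email'], 'unique': True},
    #                 # tuple format: (name, fields, [unique])
    #                 ('user_name_idx', ['first_name', 'last_name']),
    #             ]
    #
    #     User.create_indexes_sync()    # or: await User.create_indexes()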

    @classmethod
    def _get_field_type_for_surreal(cls, field: Field) -> str:
        """Get the SurrealDB type for a field.

        Args:
            field: The field to get the type for

        Returns:
            The SurrealDB type as a string
        """
        from .fields import (
            StringField, IntField, FloatField, BooleanField, DateTimeField,
            ListField, DictField, ReferenceField, GeometryField, RelationField,
            DecimalField, DurationField, BytesField, RegexField, OptionField,
            FutureField, UUIDField, TableField, RecordIDField
        )

        if isinstance(field, StringField):
            return "string"
        elif isinstance(field, IntField):
            return "int"
        elif isinstance(field, (FloatField, DecimalField)):
            return "float"
        elif isinstance(field, BooleanField):
            return "bool"
        elif isinstance(field, DateTimeField):
            return "datetime"
        elif isinstance(field, DurationField):
            return "duration"
        elif isinstance(field, ListField):
            if field.field_type:
                inner_type = cls._get_field_type_for_surreal(field.field_type)
                return f"array<{inner_type}>"
            return "array"
        elif isinstance(field, DictField):
            return "object"
        elif isinstance(field, ReferenceField):
            # Get the target collection name
            target_cls = field.document_type
            target_collection = target_cls._get_collection_name()
            return f"record<{target_collection}>"
        elif isinstance(field, RelationField):
            # Get the target collection name
            target_cls = field.to_document
            target_collection = target_cls._get_collection_name()
            return f"record<{target_collection}>"
        elif isinstance(field, GeometryField):
            return "geometry"
        elif isinstance(field, BytesField):
            return "bytes"
        elif isinstance(field, RegexField):
            return "regex"
        elif isinstance(field, OptionField):
            if field.field_type:
                inner_type = cls._get_field_type_for_surreal(field.field_type)
                return f"option<{inner_type}>"
            return "option"
        elif isinstance(field, UUIDField):
            return "uuid"
        elif isinstance(field, TableField):
            return "table"
        elif isinstance(field, RecordIDField):
            return "record"
        elif isinstance(field, FutureField):
            return "any"  # Future fields are computed at query time

        # Default to any type if we can't determine a specific type
        return "any"
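
    # Mapping sketch for _get_field_type_for_surreal (illustrative; assumes
    # the field classes accept these constructor arguments):
    #
    #     ListField(field_type=IntField())      -> "array<int>"
    #     OptionField(field_type=StringField()) -> "option<string>"
    #     ReferenceField(User)                  -> "record<user>"  (collection 'user')
    #     DictField()                           -> "object"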

    @classmethod
    async def create_table(cls, connection: Optional[Any] = None, schemafull: bool = True) -> None:
        """Create the table for this document class asynchronously.

        Args:
            connection: Optional connection to use
            schemafull: Whether to create a SCHEMAFULL table (default: True)
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()

        # Create the table
        schema_type = "SCHEMAFULL" if schemafull else "SCHEMALESS"
        query = f"DEFINE TABLE {collection_name} {schema_type}"

        # Check if this is a time series table
        is_time_series = False
        time_field = None

        # Check if the Meta class has time_series and time_field attributes
        if hasattr(cls, '_meta'):
            is_time_series = cls._meta.get('time_series', False)
            time_field = cls._meta.get('time_field')

            # If time_series is True but time_field is not specified, try to find a TimeSeriesField
            if is_time_series and not time_field:
                for field_name, field in cls._fields.items():
                    if field.__class__.__name__ == 'TimeSeriesField':
                        time_field = field.db_field
                        break

        # Add time series configuration if applicable
        if is_time_series and time_field:
            query += f" TYPE TIMESTAMP TIMEFIELD {time_field}"

        # Add comment if available
        if hasattr(cls, '__doc__') and cls.__doc__:
            # Clean up docstring: remove newlines, extra spaces, and escape quotes
            doc = ' '.join(cls.__doc__.strip().split())
            doc = doc.replace("'", "''")
            if doc:
                query += f" COMMENT '{doc}'"

        await connection.client.query(query)

        # Create fields if schemafull or if field is marked with define_schema=True
        for field_name, field in cls._fields.items():
            # Skip id field as it's handled by SurrealDB
            if field_name == cls._meta.get('id_field', 'id'):
                continue

            # Only define fields if schemafull or if field is explicitly marked for schema definition
            if schemafull or field.define_schema:
                field_type = cls._get_field_type_for_surreal(field)
                field_query = f"DEFINE FIELD {field.db_field} ON {collection_name} TYPE {field_type}"

                # Add constraints
                if field.required:
                    field_query += " ASSERT $value != NONE"

                await connection.client.query(field_query)

                # Handle nested fields for DictField
                if isinstance(field, DictField) and schemafull:
                    if field.db_field == 'settings':
                        nested_field_query = f"DEFINE FIELD {field.db_field}.theme ON {collection_name} TYPE string"
                        await connection.client.query(nested_field_query)

    @classmethod
    def create_table_sync(cls, connection: Optional[Any] = None, schemafull: bool = True) -> None:
        """Create the table for this document class synchronously."""
        if connection is None:
            from .connection import ConnectionRegistry
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()

        # Create the table
        schema_type = "SCHEMAFULL" if schemafull else "SCHEMALESS"
        query = f"DEFINE TABLE {collection_name} {schema_type}"

        # Check if this is a time series table
        is_time_series = False
        time_field = None

        # Check if the Meta class has time_series and time_field attributes
        if hasattr(cls, '_meta'):
            is_time_series = cls._meta.get('time_series', False)
            time_field = cls._meta.get('time_field')

            # If time_series is True but time_field is not specified, try to find a TimeSeriesField
            if is_time_series and not time_field:
                for field_name, field in cls._fields.items():
                    if field.__class__.__name__ == 'TimeSeriesField':
                        time_field = field.db_field
                        break

        # Add time series configuration if applicable
        if is_time_series and time_field:
            query += f" TYPE TIMESTAMP TIMEFIELD {time_field}"

        # Add comment if available
        if hasattr(cls, '__doc__') and cls.__doc__:
            # Clean up docstring: remove newlines, extra spaces, and escape quotes
            doc = ' '.join(cls.__doc__.strip().split())
            doc = doc.replace("'", "''")
            if doc:
                query += f" COMMENT '{doc}'"

        connection.client.query(query)

        # Create fields if schemafull or if field is marked with define_schema=True
        for field_name, field in cls._fields.items():
            # Skip id field as it's handled by SurrealDB
            if field_name == cls._meta.get('id_field', 'id'):
                continue

            # Only define fields if schemafull or if field is explicitly marked for schema definition
            if schemafull or field.define_schema:
                field_type = cls._get_field_type_for_surreal(field)
                field_query = f"DEFINE FIELD {field.db_field} ON {collection_name} TYPE {field_type}"

                # Add constraints
                if field.required:
                    field_query += " ASSERT $value != NONE"

                # Add comment if available
                if hasattr(field, '__doc__') and field.__doc__:
                    # Clean up docstring: remove newlines, extra spaces, and escape quotes
                    doc = ' '.join(field.__doc__.strip().split())
                    doc = doc.replace("'", "''")
                    if doc:
                        field_query += f" COMMENT '{doc}'"

                connection.client.query(field_query)

                # Handle nested fields for DictField
                if isinstance(field, DictField) and schemafull:
                    if field.db_field == 'settings':
                        nested_field_query = f"DEFINE FIELD {field.db_field}.theme ON {collection_name} TYPE string"
                        connection.client.query(nested_field_query)
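
    # Usage sketch (hypothetical ``Event`` model). With schemafull=True the
    # call emits roughly the statements shown in the trailing comments:
    #
    #     class Event(Document):
    #         """Application event log."""
    #         name = StringField(required=True)
    #         payload = DictField()
    #
    #     Event.create_table_sync(schemafull=True)
    #     # DEFINE TABLE event SCHEMAFULL COMMENT 'Application event log.'
    #     # DEFINE FIELD name ON event TYPE string ASSERT $value != NONE
    #     # DEFINE FIELD payload ON event TYPE object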

    @classmethod
    async def drop_table(cls, connection: Optional[Any] = None, if_exists: bool = True) -> None:
        """Drop the table for this document class asynchronously.

        Args:
            connection: Optional connection to use
            if_exists: Whether to use IF EXISTS clause to avoid errors if table doesn't exist
        """
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()

        # Use the backend's drop_table method
        await connection.backend.drop_table(collection_name, if_exists=if_exists)

    @classmethod
    def drop_table_sync(cls, connection: Optional[Any] = None, if_exists: bool = True) -> None:
        """Drop the table for this document class synchronously.

        Args:
            connection: Optional connection to use
            if_exists: Whether to use IF EXISTS clause to avoid errors if table doesn't exist
        """
        if connection is None:
            from .connection import ConnectionRegistry
            connection = ConnectionRegistry.get_default_connection()

        collection_name = cls._get_collection_name()

        # Use the backend's drop_table method synchronously
        import asyncio
        if hasattr(connection.backend, 'drop_table'):
            asyncio.run(connection.backend.drop_table(collection_name, if_exists=if_exists))
        else:
            # Fallback for older backends
            if if_exists:
                query = f"REMOVE TABLE IF EXISTS {collection_name}"
            else:
                query = f"REMOVE TABLE {collection_name}"
            connection.client.query(query)
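
    # Teardown sketch (continuing the hypothetical ``Event`` model above):
    #
    #     Event.drop_table_sync(if_exists=True)
    #     # backend path:  backend.drop_table('event', if_exists=True)
    #     # fallback path: REMOVE TABLE IF EXISTS event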

    @classmethod
    def to_dataclass(cls):
        """Convert the document class to a dataclass.

        This method creates a dataclass based on the document's fields.
        It uses the field names, types, and whether they are required.
        Required fields have no default value, making them required during
        initialization. Non-required fields use None as default if they
        don't define one. A __post_init__ method is added to validate all
        fields after initialization.

        Returns:
            A dataclass type based on the document's fields
        """
        fields = [('id', Optional[str], dataclass_field(default=None))]

        # Process fields
        for field_name, field_obj in cls._fields.items():
            # Skip id field as it's handled separately
            if field_name == cls._meta.get('id_field', 'id'):
                continue

            # For required fields, don't provide a default value
            if field_obj.required:
                fields.insert(0, (field_name, field_obj.py_type))
            # For fields with a non-callable default, use that default
            elif field_obj.default is not None and not callable(field_obj.default):
                fields.append((field_name, field_obj.py_type, dataclass_field(default=field_obj.default)))
            # For other fields, use None as default
            else:
                fields.append((field_name, field_obj.py_type, dataclass_field(default=None)))

        # Define the __post_init__ method to validate fields
        def post_init(self):
            """Validate all fields after initialization."""
            for field_name, field_obj in cls._fields.items():
                value = getattr(self, field_name, None)
                field_obj.validate(value)

        # Create the dataclass using make_dataclass
        return make_dataclass(
            cls_name=f"{cls.__name__}_Dataclass",
            fields=fields,
            namespace={"__post_init__": post_init}
        )
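
    # Usage sketch (hypothetical ``User`` model with a required ``email``
    # field). Required fields are inserted at position 0 so the generated
    # dataclass puts no-default fields before the optional ``id``:
    #
    #     UserData = User.to_dataclass()
    #     ok = UserData(email='a@example.com')   # id defaults to None
    #     UserData(email=123)                    # raises in __post_init__,
    #                                            # assuming StringField.validate
    #                                            # rejects non-strings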

    @classmethod
    def create_materialized_view(cls, name: str, query: QuerySet, refresh_interval: Optional[str] = None,
                                 aggregations=None, select_fields=None, **kwargs):
        """Create a materialized view based on a query.

        This method creates a materialized view in SurrealDB based on a query.
        Materialized views are precomputed views of data that can be used to
        improve query performance for frequently accessed aggregated data.

        Args:
            name: The name of the materialized view
            query: The query that defines the materialized view
            refresh_interval: The interval at which the view is refreshed (e.g., "1h", "30m")
            aggregations: Dictionary of field names and aggregation functions
            select_fields: List of fields to select (if None, selects all fields)
            **kwargs: Additional keyword arguments to pass to the MaterializedView constructor

        Returns:
            A MaterializedView instance
        """
        from .materialized_view import MaterializedView, Count, Mean, Sum, Min, Max, ArrayCollect

        # Process aggregations if provided as keyword arguments
        if aggregations is None:
            aggregations = {}

        # Check for aggregation functions in kwargs; the kwarg value names the
        # source field for the aggregation
        for key, value in kwargs.items():
            if key.startswith('count_'):
                field_name = key[6:]  # Remove 'count_' prefix
                aggregations[field_name] = Count()
            elif key.startswith('mean_'):
                field_name = key[5:]  # Remove 'mean_' prefix
                aggregations[field_name] = Mean(value)
            elif key.startswith('sum_'):
                field_name = key[4:]  # Remove 'sum_' prefix
                aggregations[field_name] = Sum(value)
            elif key.startswith('min_'):
                field_name = key[4:]  # Remove 'min_' prefix
                aggregations[field_name] = Min(value)
            elif key.startswith('max_'):
                field_name = key[4:]  # Remove 'max_' prefix
                aggregations[field_name] = Max(value)
            elif key.startswith('collect_'):
                field_name = key[8:]  # Remove 'collect_' prefix
                aggregations[field_name] = ArrayCollect(value)

        return MaterializedView(name, query, refresh_interval, cls, aggregations, select_fields)
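
    # Usage sketch (hypothetical ``Order`` model; ``Order.objects`` assumed to
    # be the usual QuerySet entry point). Aggregation kwargs map by prefix:
    #
    #     daily_totals = Order.create_materialized_view(
    #         'daily_order_totals',
    #         Order.objects.filter(status='complete'),
    #         refresh_interval='1h',
    #         count_orders=True,       # -> aggregations['orders'] = Count()
    #         sum_revenue='amount',    # -> aggregations['revenue'] = Sum('amount')
    #         mean_basket='amount',    # -> aggregations['basket'] = Mean('amount')
    #     )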

    @classmethod
    def _get_document_class_for_collection(cls, collection_name: str) -> Optional[Type['Document']]:
        """Get the document class for a collection name.

        This method looks up the document class for a given collection name
        in the document registry. If no class is found, it returns None.

        Args:
            collection_name: The name of the collection

        Returns:
            The document class for the collection, or None if not found
        """
        # Initialize the document registry if it doesn't exist
        if not hasattr(cls, '_document_registry'):
            cls._document_registry = {}

        # Populate the registry with all existing document classes
        def register_subclasses(doc_class):
            for subclass in doc_class.__subclasses__():
                if hasattr(subclass, '_meta') and not subclass._meta.get('abstract', False):
                    collection = subclass._meta.get('collection')
                    if collection:
                        cls._document_registry[collection] = subclass
                register_subclasses(subclass)

        # Start with Document subclasses
        register_subclasses(cls)

        # Handle RecordID objects
        if isinstance(collection_name, RecordID):
            collection_name = collection_name.table_name
        # Handle string IDs in the format "collection:id"
        elif isinstance(collection_name, str) and ':' in collection_name:
            collection_name = collection_name.split(':', 1)[0]

        # Look up the document class in the registry
        return cls._document_registry.get(collection_name)

class RelationDocument(Document):
    """A Document that represents a relationship between two documents.

    RelationDocuments should be used to model relationships with additional
    attributes. They can be used with Document.relates(),
    Document.fetch_relation(), and Document.resolve_relation().
    """

    class Meta:
        """Meta options for RelationDocument."""
        abstract = True

    in_document = ReferenceField(Document, required=True, db_field="in")
    out_document = ReferenceField(Document, required=True, db_field="out")

    @classmethod
    def get_relation_name(cls) -> str:
        """Get the name of the relation.

        By default, this is the lowercase name of the class.
        Override this method to customize the relation name.

        Returns:
            The name of the relation
        """
        return cls._meta.get('collection')

    @classmethod
    def relates(cls, from_document: Optional[Type] = None, to_document: Optional[Type] = None) -> callable:
        """Get a RelationQuerySet for this relation.

        This method returns a function that creates a RelationQuerySet
        for this relation. The function can be called with an optional
        connection parameter.

        Args:
            from_document: The document class the relation is from (optional)
            to_document: The document class the relation is to (optional)

        Returns:
            Function that creates a RelationQuerySet
        """
        relation_name = cls.get_relation_name()

        def relation_query_builder(connection: Optional[Any] = None) -> 'RelationQuerySet':
            """Create a RelationQuerySet for this relation.

            Args:
                connection: The database connection to use (optional)

            Returns:
                A RelationQuerySet for the relation
            """
            if connection is None:
                connection = ConnectionRegistry.get_default_connection()
            return RelationQuerySet(from_document or Document, connection, relation=relation_name)

        return relation_query_builder
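
    # Usage sketch (hypothetical ``AuthoredBy`` relation between ``Author``
    # and ``Book`` models):
    #
    #     build_qs = AuthoredBy.relates(from_document=Author)
    #     qs = build_qs()    # RelationQuerySet bound to the default connection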

    @classmethod
    async def create_relation(cls, from_instance: Any, to_instance: Any, **attrs: Any) -> 'RelationDocument':
        """Create a relation between two instances asynchronously.

        This method creates a relation between two document instances and
        returns a RelationDocument instance representing the relationship.

        Args:
            from_instance: The instance to create the relation from
            to_instance: The instance to create the relation to
            **attrs: Attributes to set on the relation

        Returns:
            A RelationDocument instance representing the relationship

        Raises:
            ValueError: If either instance is not saved
        """
        if not from_instance.id:
            raise ValueError(f"Cannot create relation from unsaved {from_instance.__class__.__name__}")
        if not to_instance.id:
            raise ValueError(f"Cannot create relation to unsaved {to_instance.__class__.__name__}")

        # Create the relation using Document.relate_to
        relation = await from_instance.relate_to(cls.get_relation_name(), to_instance, **attrs)

        # Create a RelationDocument instance from the relation data
        relation_doc = cls(
            in_document=from_instance,
            out_document=to_instance,
            **attrs
        )

        # Set the ID from the relation
        if relation and 'id' in relation:
            relation_doc.id = relation['id']

        return relation_doc

    @classmethod
    def create_relation_sync(cls, from_instance: Any, to_instance: Any, **attrs: Any) -> 'RelationDocument':
        """Create a relation between two instances synchronously.

        This method creates a relation between two document instances and
        returns a RelationDocument instance representing the relationship.

        Args:
            from_instance: The instance to create the relation from
            to_instance: The instance to create the relation to
            **attrs: Attributes to set on the relation

        Returns:
            A RelationDocument instance representing the relationship

        Raises:
            ValueError: If either instance is not saved
        """
        if not from_instance.id:
            raise ValueError(f"Cannot create relation from unsaved {from_instance.__class__.__name__}")
        if not to_instance.id:
            raise ValueError(f"Cannot create relation to unsaved {to_instance.__class__.__name__}")

        # Create the relation using Document.relate_to_sync
        relation = from_instance.relate_to_sync(cls.get_relation_name(), to_instance, **attrs)

        # Create a RelationDocument instance from the relation data
        relation_doc = cls(
            in_document=from_instance,
            out_document=to_instance,
            **attrs
        )

        # Set the ID from the relation
        if relation and 'id' in relation:
            relation_doc.id = relation['id']

        return relation_doc
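
    # Usage sketch (hypothetical ``AuthoredBy`` relation carrying a ``year``
    # attribute; ``author`` and ``book`` are saved instances):
    #
    #     class AuthoredBy(RelationDocument):
    #         year = IntField()
    #
    #         class Meta:
    #             collection = 'authored_by'
    #
    #     rel = AuthoredBy.create_relation_sync(author, book, year=2021)
    #     rel.in_document, rel.out_document    # -> (author, book)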

    @classmethod
    def find_by_in_document(cls, in_doc, **additional_filters):
        """Query RelationDocument by in_document field.

        Args:
            in_doc: The document instance or ID to filter by
            **additional_filters: Additional filters to apply

        Returns:
            QuerySet filtered by in_document
        """
        # Get the default connection
        connection = ConnectionRegistry.get_default_connection()
        queryset = QuerySet(cls, connection)

        # Apply the in_document filter and any additional filters
        filters = {'in': in_doc, **additional_filters}
        return queryset.filter(**filters)

    @classmethod
    def find_by_in_document_sync(cls, in_doc, **additional_filters):
        """Query RelationDocument by in_document field synchronously.

        Args:
            in_doc: The document instance or ID to filter by
            **additional_filters: Additional filters to apply

        Returns:
            QuerySet filtered by in_document
        """
        # Get the default connection
        connection = ConnectionRegistry.get_default_connection()
        queryset = QuerySet(cls, connection)

        # Apply the in_document filter and any additional filters
        filters = {'in': in_doc, **additional_filters}
        return queryset.filter(**filters)
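
    # Usage sketch (continuing the hypothetical ``AuthoredBy`` relation;
    # all_sync() is an assumed QuerySet method, not defined in this module):
    #
    #     qs = AuthoredBy.find_by_in_document(author, year=2021)
    #     relations = qs.all_sync()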

    async def resolve_out(self, connection=None):
        """Resolve the out_document field asynchronously.

        This method resolves the out_document field if it's currently just
        an ID reference. If the out_document is already a document instance,
        it returns it directly.

        Args:
            connection: Database connection to use (optional)

        Returns:
            The resolved out_document instance
        """
        # If out_document is already a document instance, return it
        if isinstance(self.out_document, Document):
            return self.out_document

        # Get the connection if not provided
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        # If out_document is an ID (a "collection:id" string or a RecordID), fetch the document
        if isinstance(self.out_document, RecordID) or (
                isinstance(self.out_document, str) and ':' in self.out_document):
            try:
                result = await connection.client.select(self.out_document)

                # Process the result
                if result:
                    return result[0] if isinstance(result, list) else result
            except Exception as e:
                logger.error(f"Error resolving out_document {self.out_document}: {str(e)}")

        # Return the current value if resolution failed
        return self.out_document

    def resolve_out_sync(self, connection=None):
        """Resolve the out_document field synchronously.

        This method resolves the out_document field if it's currently just
        an ID reference. If the out_document is already a document instance,
        it returns it directly.

        Args:
            connection: Database connection to use (optional)

        Returns:
            The resolved out_document instance
        """
        # If out_document is already a document instance, return it
        if isinstance(self.out_document, Document):
            return self.out_document

        # Get the connection if not provided
        if connection is None:
            connection = ConnectionRegistry.get_default_connection()

        # If out_document is an ID (a "collection:id" string or a RecordID), fetch the document
        if isinstance(self.out_document, RecordID) or (
                isinstance(self.out_document, str) and ':' in self.out_document):
            try:
                result = connection.client.select(self.out_document)

                # Process the result
                if result:
                    return result[0] if isinstance(result, list) else result
            except Exception as e:
                logger.error(f"Error resolving out_document {self.out_document}: {str(e)}")

        # Return the current value if resolution failed
        return self.out_document
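
    # Usage sketch (continuing the hypothetical ``AuthoredBy`` relation;
    # first_sync() is an assumed QuerySet method, not defined in this module):
    #
    #     rel = AuthoredBy.find_by_in_document_sync(author).first_sync()
    #     book_data = rel.resolve_out_sync()   # fetches 'book:...' if still an ID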