import json
from typing import Any, Dict, List, Optional, Tuple, Union, Type, cast
from .exceptions import MultipleObjectsReturned, DoesNotExist
from surrealdb import RecordID
from .pagination import PaginationResult
from .record_id_utils import RecordIdUtils
from .surrealql import escape_literal
# Import these at runtime to avoid circular imports
def _get_connection_classes():
    """Import and return the connection classes lazily.

    Deferring the import avoids a circular dependency between this
    module and ``.connection``.

    Returns:
        Tuple of (SurrealEngineAsyncConnection, SurrealEngineSyncConnection).
    """
    from .connection import (
        SurrealEngineAsyncConnection,
        SurrealEngineSyncConnection,
    )
    return SurrealEngineAsyncConnection, SurrealEngineSyncConnection
[docs]
class BaseQuerySet:
"""Base query builder for SurrealDB.
This class provides the foundation for building queries in SurrealDB.
It includes methods for filtering, limiting, ordering, and retrieving results.
Subclasses must implement specific methods like _build_query, all, and count.
Attributes:
connection: The database connection to use for queries
query_parts: List of query conditions (field, operator, value)
limit_value: Maximum number of results to return
start_value: Number of results to skip (for pagination)
order_by_value: Field and direction to order results by
group_by_fields: Fields to group results by
split_fields: Fields to split results by
fetch_fields: Fields to fetch related records for
with_index: Index to use for the query
"""
[docs]
def __init__(self, connection: Any) -> None:
"""Initialize a new BaseQuerySet.
Args:
connection: The database connection to use for queries
"""
self.connection = connection
self.query_parts: List[Tuple[str, str, Any]] = []
self.limit_value: Optional[int] = None
self.start_value: Optional[int] = None
self.order_by_value: Optional[Tuple[str, str]] = None
self.group_by_fields: List[str] = []
self.split_fields: List[str] = []
self.fetch_fields: List[str] = []
self.with_index: Optional[str] = None
self.with_index: Optional[str] = None
self.select_fields: Optional[List[str]] = None
self.omit_fields: List[str] = []
self.timeout_value: Optional[str] = None
self.tempfiles_value: bool = False
self.explain_value: bool = False
self.explain_full_value: bool = False
self.group_by_all: bool = False
# Graph traversal state
self._traversal_path: Optional[str] = None
self._traversal_unique: bool = True
self._traversal_max_depth: Optional[int] = None
# Performance optimization attributes
self._bulk_id_selection: Optional[List[Any]] = None
self._id_range_selection: Optional[Tuple[Any, Any, bool]] = None
self._prefer_direct_access: bool = False
[docs]
def is_async_connection(self) -> bool:
    """Report whether this queryset's connection is asynchronous.

    Returns:
        True if the connection is asynchronous, False otherwise
    """
    async_cls, _sync_cls = _get_connection_classes()
    return isinstance(self.connection, async_cls)
[docs]
def filter(self, query=None, **kwargs) -> 'BaseQuerySet':
"""Add filter conditions to the query with automatic ID optimization.
This method supports both Q objects and Django-style field lookups with double-underscore operators:
- field__gt: Greater than
- field__lt: Less than
- field__gte: Greater than or equal
- field__lte: Less than or equal
- field__ne: Not equal
- field__in: Inside (for arrays) - optimized for ID fields
- field__nin: Not inside (for arrays)
- field__contains: Contains (for strings or arrays)
- field__startswith: Starts with (for strings)
- field__endswith: Ends with (for strings)
- field__regex: Matches regex pattern (for strings)
PERFORMANCE OPTIMIZATIONS:
- id__in automatically uses direct record access syntax
- ID range queries (id__gte + id__lte) use range syntax
Args:
query: Q object or QueryExpression for complex queries
**kwargs: Field names and values to filter by
Returns:
A new queryset instance for method chaining
Raises:
ValueError: If an unknown operator is provided
"""
# Clone first to avoid mutating the original queryset
result = self if (query is None and not kwargs) else self._clone()
# Handle Q objects and QueryExpressions
if query is not None:
# Import here to avoid circular imports
try:
from .query_expressions import Q, QueryExpression
if isinstance(query, Q):
# Use to_where_clause() to properly handle both simple and compound Q objects
where_clause = query.to_where_clause()
if where_clause:
result.query_parts.append(('__raw__', '=', where_clause))
# Don't return early - continue to process kwargs if provided
elif isinstance(query, QueryExpression):
# Apply QueryExpression to this queryset
return query.apply_to_queryset(result)
else:
raise ValueError(f"Unsupported query type: {type(query)}")
except ImportError:
raise ValueError("Query expressions not available")
# Process kwargs (either standalone or combined with a Q object)
if not kwargs:
return result
# Continue with existing kwargs processing
# PERFORMANCE OPTIMIZATION: Check for bulk ID operations
if len(kwargs) == 1 and 'id__in' in kwargs:
result._bulk_id_selection = kwargs['id__in']
return result
# PERFORMANCE OPTIMIZATION: Check for ID range operations
id_range_keys = {k for k in kwargs.keys() if k.startswith('id__') and k.endswith(('gte', 'lte', 'gt', 'lt'))}
if len(kwargs) == 2 and len(id_range_keys) == 2:
if 'id__gte' in kwargs and 'id__lte' in kwargs:
result._id_range_selection = (kwargs['id__gte'], kwargs['id__lte'], True) # inclusive
return result
elif 'id__gt' in kwargs and 'id__lt' in kwargs:
result._id_range_selection = (kwargs['id__gt'], kwargs['id__lt'], False) # exclusive
return result
# Fall back to regular filtering for non-optimizable queries
for k, v in kwargs.items():
if k == 'id':
# Use RecordIdUtils for comprehensive ID handling
table_name = None
if hasattr(self, 'document_class') and self.document_class:
table_name = self.document_class._get_collection_name()
normalized_id = RecordIdUtils.normalize_record_id(v, table_name)
if normalized_id:
result.query_parts.append((k, '=', normalized_id))
else:
# Fall back to original value if normalization fails
result.query_parts.append((k, '=', str(v)))
continue
# Special handling for URL fields - mark them with a special tag
if k == 'url' or (isinstance(v, str) and (v.startswith('http://') or v.startswith('https://'))):
# Add a special tag to indicate this is a URL that needs quoting
result.query_parts.append((k, '=', {'__url_value__': v}))
continue
parts = k.split('__')
field = parts[0]
# Handle operators
if len(parts) > 1:
op = parts[1]
if op == 'gt':
result.query_parts.append((field, '>', v))
elif op == 'lt':
result.query_parts.append((field, '<', v))
elif op == 'gte':
result.query_parts.append((field, '>=', v))
elif op == 'lte':
result.query_parts.append((field, '<=', v))
elif op == 'ne':
result.query_parts.append((field, '!=', v))
elif op == 'in':
# Note: id__in is handled by optimization above
result.query_parts.append((field, 'INSIDE', v))
elif op == 'nin':
result.query_parts.append((field, 'NOT INSIDE', v))
elif op == 'contains':
if isinstance(v, str):
result.query_parts.append((f"string::contains({field}, '{v}')", '=', True))
else:
result.query_parts.append((field, 'CONTAINS', v))
elif op == 'startswith':
result.query_parts.append((f"string::starts_with({field}, '{v}')", '=', True))
elif op == 'endswith':
result.query_parts.append((f"string::ends_with({field}, '{v}')", '=', True))
elif op == 'regex':
result.query_parts.append((f"string::matches({field}, r'{v}')", '=', True))
# New operators for contains/inside variants
elif op == 'contains_any':
result.query_parts.append((field, 'CONTAINSANY', v))
elif op == 'contains_all':
result.query_parts.append((field, 'CONTAINSALL', v))
elif op == 'contains_none':
result.query_parts.append((field, 'CONTAINSNONE', v))
elif op == 'inside':
result.query_parts.append((field, 'INSIDE', v))
elif op == 'not_inside':
result.query_parts.append((field, 'NOT INSIDE', v))
elif op == 'all_inside':
result.query_parts.append((field, 'ALLINSIDE', v))
elif op == 'any_inside':
result.query_parts.append((field, 'ANYINSIDE', v))
elif op == 'none_inside':
result.query_parts.append((field, 'NONEINSIDE', v))
else:
# Handle nested field access for DictFields
document_class = getattr(self, 'document_class', None)
if document_class and hasattr(document_class, '_fields'):
if field in document_class._fields:
from .fields import DictField
if isinstance(document_class._fields[field], DictField):
nested_field = f"{field}.{op}"
result.query_parts.append((nested_field, '=', v))
continue
# If we get here, it's an unknown operator
raise ValueError(f"Unknown operator: {op}")
else:
# Simple equality
result.query_parts.append((field, '=', v))
return result
[docs]
def only(self, *fields: str) -> 'BaseQuerySet':
"""Select only the specified fields.
This method sets the fields to be selected in the query.
It automatically includes the 'id' field.
Args:
*fields: Field names to select
Returns:
The query set instance for method chaining
"""
clone = self._clone()
select_fields = list(fields)
if 'id' not in select_fields:
select_fields.append('id')
clone.select_fields = select_fields
clone.select_fields = select_fields
return clone
[docs]
def omit(self, *fields: str) -> 'BaseQuerySet':
"""Exclude specific fields from the results.
Args:
*fields: Field names to exclude
Returns:
The query set instance for method chaining
"""
clone = self._clone()
clone.omit_fields.extend(fields)
return clone
[docs]
def limit(self, value: int) -> 'BaseQuerySet':
"""Set the maximum number of results to return.
Args:
value: Maximum number of results
Returns:
The query set instance for method chaining
"""
self.limit_value = value
return self
[docs]
def start(self, value: int) -> 'BaseQuerySet':
"""Set the number of results to skip (for pagination).
Args:
value: Number of results to skip
Returns:
The query set instance for method chaining
"""
self.start_value = value
return self
[docs]
def order_by(self, field: str, direction: str = 'ASC') -> 'BaseQuerySet':
"""Set the field and direction to order results by.
Args:
field: Field name to order by
direction: Direction to order by ('ASC' or 'DESC')
Returns:
The query set instance for method chaining
"""
self.order_by_value = (field, direction)
return self
return self
[docs]
def group_by(self, *fields: str, all: bool = False) -> 'BaseQuerySet':
"""Group the results by the specified fields or group all.
This method sets the fields to group the results by using the GROUP BY clause.
Args:
*fields: Field names to group by
all: If True, use GROUP ALL (SurrealDB v2.0.0+)
Returns:
The query set instance for method chaining
"""
self.group_by_fields.extend(fields)
self.group_by_all = all
return self
[docs]
def split(self, *fields: str) -> 'BaseQuerySet':
"""Split the results by the specified fields.
This method sets the fields to split the results by using the SPLIT clause.
Args:
*fields: Field names to split by
Returns:
The query set instance for method chaining
"""
self.split_fields.extend(fields)
return self
[docs]
def fetch(self, *fields: str) -> 'BaseQuerySet':
"""Fetch related records for the specified fields.
This method sets the fields to fetch related records for using the FETCH clause.
Args:
*fields: Field names to fetch related records for
Returns:
The query set instance for method chaining
"""
self.fetch_fields.extend(fields)
return self
[docs]
def get_many(self, ids: List[Union[str, Any]]) -> 'BaseQuerySet':
"""Get multiple records by IDs using optimized direct record access.
This method uses SurrealDB's direct record selection syntax for better
performance compared to WHERE clause filtering.
Args:
ids: List of record IDs (can be strings or other ID types)
Returns:
The query set instance configured for direct record access
Example:
# Efficient: SELECT * FROM users:1, users:2, users:3
users = await User.objects.get_many([1, 2, 3]).all()
users = await User.objects.get_many(['users:1', 'users:2']).all()
"""
clone = self._clone()
clone._bulk_id_selection = ids
return clone
[docs]
def get_range(self, start_id: Union[str, Any], end_id: Union[str, Any],
inclusive: bool = True) -> 'BaseQuerySet':
"""Get a range of records by ID using optimized range syntax.
This method uses SurrealDB's range selection syntax for better
performance compared to WHERE clause filtering.
Args:
start_id: Starting ID of the range
end_id: Ending ID of the range
inclusive: Whether the range is inclusive (default: True)
Returns:
The query set instance configured for range access
Example:
# Efficient: SELECT * FROM users:100..=200
users = await User.objects.get_range(100, 200).all()
users = await User.objects.get_range('users:100', 'users:200', inclusive=False).all()
"""
clone = self._clone()
clone._id_range_selection = (start_id, end_id, inclusive)
return clone
[docs]
def with_index(self, index: str) -> 'BaseQuerySet':
    """Use the specified index for the query (WITH INDEX clause).

    NOTE(review): ``__init__`` assigns ``self.with_index = None``, so on
    instances this method is shadowed by that attribute -- ``qs.with_index``
    resolves to the stored value, not this function, making the method
    apparently uncallable on instances. Confirm intent; renaming the
    attribute (e.g. ``with_index_value``) would fix it but requires
    coordinated changes elsewhere.

    Args:
        index: Name of the index to use

    Returns:
        The query set instance for method chaining
    """
    self.with_index = index
    return self
[docs]
def no_index(self) -> 'BaseQuerySet':
    """Disable index usage for the query (WITH NOINDEX clause).

    NOTE(review): this stores the sentinel "NOINDEX" in ``self.with_index``,
    but ``_build_clauses`` renders it as ``WITH INDEX NOINDEX`` rather than
    ``WITH NOINDEX`` -- confirm against SurrealQL's WITH clause syntax.

    Returns:
        The query set instance for method chaining
    """
    self.with_index = "NOINDEX"
    return self
[docs]
def timeout(self, duration: str) -> 'BaseQuerySet':
"""Set a timeout for the query execution.
Args:
duration: Duration string (e.g. "5s", "1m")
Returns:
The query set instance for method chaining
"""
self.timeout_value = duration
return self
[docs]
def tempfiles(self, value: bool = True) -> 'BaseQuerySet':
"""Enable or disable using temporary files for large queries.
Args:
value: Whether to use tempfiles (default: True)
Returns:
The query set instance for method chaining
"""
self.tempfiles_value = value
return self
[docs]
def with_explain(self, full: bool = False) -> 'BaseQuerySet':
"""Explain the query execution plan (builder pattern).
Args:
full: Whether to include full explanation including execution trace (default: False)
Returns:
The query set instance for method chaining
"""
self.explain_value = True
self.explain_full_value = full
return self
[docs]
def use_direct_access(self) -> 'BaseQuerySet':
"""Mark this queryset to prefer direct record access when possible.
This method sets a preference for using direct record access patterns
over WHERE clause filtering for better performance.
Returns:
The query set instance for method chaining
"""
clone = self._clone()
clone._prefer_direct_access = True
return clone
def _build_query(self) -> str:
    """Build the base query string.

    Abstract hook: subclasses generate the statement appropriate for
    their specific target (document class or schemaless table).

    Returns:
        The query string

    Raises:
        NotImplementedError: always, on this base class
    """
    raise NotImplementedError("Subclasses must implement _build_query")
def _build_conditions(self) -> List[str]:
    """Build query condition strings from ``query_parts``.

    Each (field, operator, value) triple is rendered into a SurrealQL
    condition suitable for joining with AND in a WHERE clause. Special
    cases, in precedence order: raw pre-rendered conditions ('__raw__'),
    already-complete function-call conditions ('string::...' with '='),
    record-id comparisons, INSIDE/NOT INSIDE list membership, stray
    RecordID values, string helper operators, set operators, tagged URL
    values, and finally generic escaped comparisons.

    Returns:
        List of condition strings
    """
    conditions = []
    for field, op, value in self.query_parts:
        # Handle raw query conditions (e.g. produced by Q objects)
        if field == '__raw__':
            conditions.append(value)
        # Function-call conditions built upstream are already complete
        # expressions; append the field text as-is.
        elif op == '=' and isinstance(field, str) and '::' in field:
            conditions.append(f"{field}")
        else:
            # Determine if field is declared as a RecordID field on the
            # owning document class (best effort; False on any failure).
            def _field_is_record_id(field_name: str) -> bool:
                document_class = getattr(self, 'document_class', None)
                if not document_class or not hasattr(document_class, '_fields'):
                    return False
                field_obj = document_class._fields.get(field_name)
                try:
                    from .fields.id import RecordIDField  # type: ignore
                    return isinstance(field_obj, RecordIDField)
                except Exception:
                    return False
            # Special handling for RecordIDs - only for id or RecordIDField or RecordID object
            if field == 'id' or _field_is_record_id(field) or isinstance(value, RecordID):
                # Ensure RecordID is properly formatted (record ids are
                # emitted unquoted so SurrealDB treats them as ids).
                if isinstance(value, str) and RecordIdUtils.is_valid_record_id(value):
                    conditions.append(f"{field} {op} {value}")
                elif isinstance(value, RecordID):
                    conditions.append(f"{field} {op} {str(value)}")
                else:
                    # Try to normalize the RecordID using the collection name
                    table_name = None
                    if hasattr(self, 'document_class') and self.document_class:
                        table_name = self.document_class._get_collection_name()
                    normalized = RecordIdUtils.normalize_record_id(value, table_name)
                    if normalized and RecordIdUtils.is_valid_record_id(normalized):
                        conditions.append(f"{field} {op} {normalized}")
                    else:
                        # Not a recognizable id: quote it like any literal
                        conditions.append(f"{field} {op} {escape_literal(value)}")
            # Special handling for INSIDE and NOT INSIDE operators
            elif op in ('INSIDE', 'NOT INSIDE'):
                # Only treat list items as record IDs if the field is a RecordID field
                treat_items_as_ids = _field_is_record_id(field)
                def _is_record_id_str(s):
                    return isinstance(s, str) and RecordIdUtils.is_valid_record_id(s)
                def _format_literal(item):
                    # Accept dicts with 'id' (e.g. fetched documents)
                    if isinstance(item, dict) and 'id' in item and _is_record_id_str(item['id']) and treat_items_as_ids:
                        return item['id']
                    # RecordID object
                    if isinstance(item, RecordID) and treat_items_as_ids:
                        return str(item)
                    # String record id
                    if _is_record_id_str(item) and treat_items_as_ids:
                        return item
                    # Fallback to escape_literal for proper quoting/escaping
                    return escape_literal(item)
                if isinstance(value, (list, tuple, set)):
                    items = ', '.join(_format_literal(v) for v in value)
                    value_str = f"[{items}]"
                else:
                    # Single non-iterable value - still format appropriately
                    value_str = _format_literal(value)
                conditions.append(f"{field} {op} {value_str}")
            elif isinstance(value, RecordID):
                # If value is a RecordID object but field is not RecordID-typed, quote it to be safe
                conditions.append(f"{field} {op} {escape_literal(str(value))}")
            elif op == 'STARTSWITH':
                conditions.append(f"string::starts_with({field}, {escape_literal(value)})")
            elif op == 'ENDSWITH':
                conditions.append(f"string::ends_with({field}, {escape_literal(value)})")
            elif op == 'CONTAINS':
                # Strings use the string helper; anything else the set operator
                if isinstance(value, str):
                    conditions.append(f"string::contains({field}, {escape_literal(value)})")
                else:
                    conditions.append(f"{field} CONTAINS {escape_literal(value)}")
            elif op in ('CONTAINSANY', 'CONTAINSALL', 'CONTAINSNONE', 'ALLINSIDE', 'ANYINSIDE', 'NONEINSIDE'):
                # Handle new set operators
                conditions.append(f"{field} {op} {escape_literal(value)}")
            # Special handling for URL values tagged by filter()
            elif isinstance(value, dict) and '__url_value__' in value:
                # Extract the URL value and ensure it's properly quoted
                url_value = value['__url_value__']
                conditions.append(f"{field} {op} {escape_literal(url_value)}")
            else:
                # Convert value to database format if we have field information
                db_value = self._convert_value_for_query(field, value)
                # Always use escape_literal to ensure proper escaping of all values
                # This is especially important for URLs, strings with special characters, Expr vars, and RecordIDs
                conditions.append(f"{field} {op} {escape_literal(db_value)}")
    return conditions
def _convert_value_for_query(self, field_name: str, value: Any) -> Any:
"""Convert a value to its database representation for query conditions.
This method checks if the document class has a field definition for the given
field name and uses its to_db() method to convert the value properly.
Args:
field_name: The name of the field
value: The value to convert
Returns:
The converted value ready for JSON serialization
"""
# Check if we have a document class with field definitions
document_class = getattr(self, 'document_class', None)
if document_class and hasattr(document_class, '_fields'):
# Get the field definition
field_obj = document_class._fields.get(field_name)
if field_obj and hasattr(field_obj, 'to_db'):
# Use the field's to_db method to convert the value
try:
return field_obj.to_db(value)
except Exception:
# If conversion fails, return the original value
pass
# If no field definition or conversion failed, return original value
return value
def _format_record_id(self, id_value: Any) -> str:
    """Render *id_value* as a SurrealDB record ID string.

    Normalization via RecordIdUtils is attempted first (it handles
    URL-encoded and other exotic forms); when it fails, a best-effort
    fallback formatting is applied.

    Args:
        id_value: The ID value to format

    Returns:
        Properly formatted record ID string
    """
    doc_cls = getattr(self, 'document_class', None)
    table_name = doc_cls._get_collection_name() if doc_cls else None
    # Preferred path: comprehensive normalization.
    normalized = RecordIdUtils.normalize_record_id(id_value, table_name)
    if normalized is not None:
        return normalized
    # Fallbacks when normalization could not produce a record id.
    if isinstance(id_value, str) and ':' in id_value:
        return id_value
    if isinstance(id_value, RecordID):
        return str(id_value)
    if table_name:
        return f"{table_name}:{id_value}"
    return str(id_value)
def _build_direct_record_query(self) -> Optional[str]:
"""Build optimized direct record access query if applicable.
Returns:
Optimized query string or None if not applicable
"""
# Handle bulk ID selection optimization
if self._bulk_id_selection:
if not self._bulk_id_selection: # Empty list
return None
record_ids = [self._format_record_id(id_val) for id_val in self._bulk_id_selection]
query = f"SELECT * FROM {', '.join(record_ids)}"
# Add other clauses (but skip WHERE since we're using direct access)
clauses = self._build_clauses()
for clause_name, clause_sql in clauses.items():
if clause_name != 'WHERE': # Skip WHERE for direct access
query += f" {clause_sql}"
return query
# Handle ID range selection optimization
if self._id_range_selection:
start_id, end_id, inclusive = self._id_range_selection
start_record_id = self._format_record_id(start_id)
end_record_id = self._format_record_id(end_id)
# Extract just the numeric part for range syntax
collection_name = getattr(self, 'document_class', None)
if collection_name:
collection_name = collection_name._get_collection_name()
# Extract numeric IDs from record IDs
start_num = str(start_id).split(':')[-1] if ':' in str(start_id) else str(start_id)
end_num = str(end_id).split(':')[-1] if ':' in str(end_id) else str(end_id)
range_op = "..=" if inclusive else ".."
query = f"SELECT * FROM {collection_name}:{start_num}{range_op}{end_num}"
else:
# Fall back to WHERE clause if we can't determine collection
return None
# Add other clauses (but skip WHERE since we're using direct access)
clauses = self._build_clauses()
for clause_name, clause_sql in clauses.items():
if clause_name != 'WHERE': # Skip WHERE for direct access
query += f" {clause_sql}"
return query
return None
def _build_clauses(self) -> Dict[str, str]:
"""Build query clauses from the query parameters.
This method builds the various clauses for the query string, including
WHERE, GROUP BY, SPLIT, WITH, ORDER BY, LIMIT, START, and FETCH.
Returns:
Dictionary of clause names and their string representations
"""
clauses = {}
# Build WHERE clause
if self.query_parts:
conditions = self._build_conditions()
clauses['WHERE'] = f"WHERE {' AND '.join(conditions)}"
if self.group_by_fields:
clauses['GROUP BY'] = f"GROUP BY {', '.join(self.group_by_fields)}"
elif self.group_by_all:
clauses['GROUP BY'] = "GROUP ALL"
# Build SPLIT clause
if self.split_fields:
clauses['SPLIT'] = f"SPLIT {', '.join(self.split_fields)}"
# Build WITH clause
if self.with_index:
clauses['WITH'] = f"WITH INDEX {self.with_index}"
# Build ORDER BY clause
if self.order_by_value:
field, direction = self.order_by_value
clauses['ORDER BY'] = f"ORDER BY {field} {direction}"
# Build LIMIT clause
if self.limit_value is not None:
clauses['LIMIT'] = f"LIMIT {self.limit_value}"
# Build START clause
if self.start_value is not None:
clauses['START'] = f"START {self.start_value}"
# IMPORTANT: In SurrealQL, FETCH must be the last clause
if self.fetch_fields:
clauses['FETCH'] = f"FETCH {', '.join(self.fetch_fields)}"
# Build TIMEOUT clause
if self.timeout_value:
clauses['TIMEOUT'] = f"TIMEOUT {self.timeout_value}"
# Build TEMPFILES clause
if self.tempfiles_value:
clauses['TEMPFILES'] = "TEMPFILES"
# Build EXPLAIN clause
if self.explain_value:
if self.explain_full_value:
clauses['EXPLAIN'] = "EXPLAIN FULL"
else:
clauses['EXPLAIN'] = "EXPLAIN"
return clauses
def _get_collection_name(self) -> Optional[str]:
"""Get the collection name for this queryset.
Returns:
Collection name or None if not available
"""
document_class = getattr(self, 'document_class', None)
if document_class and hasattr(document_class, '_get_collection_name'):
return document_class._get_collection_name()
return getattr(self, 'table_name', None)
[docs]
async def all(self) -> List[Any]:
    """Execute the query and return all results asynchronously.

    Abstract hook: subclasses run the built query against the
    connection and materialize the results.

    Returns:
        List of results

    Raises:
        NotImplementedError: always, on this base class
    """
    raise NotImplementedError("Subclasses must implement all")
[docs]
def all_sync(self) -> List[Any]:
    """Execute the query and return all results synchronously.

    Abstract hook: subclasses run the built query against the
    connection and materialize the results.

    Returns:
        List of results

    Raises:
        NotImplementedError: always, on this base class
    """
    raise NotImplementedError("Subclasses must implement all_sync")
[docs]
async def first(self) -> Optional[Any]:
"""Execute the query and return the first result asynchronously.
This method limits the query to one result and returns the first item
or None if no results are found.
Returns:
The first result or None if no results
"""
self.limit_value = 1
results = await self.all()
return results[0] if results else None
[docs]
def first_sync(self) -> Optional[Any]:
"""Execute the query and return the first result synchronously.
This method limits the query to one result and returns the first item
or None if no results are found.
Returns:
The first result or None if no results
"""
self.limit_value = 1
results = self.all_sync()
return results[0] if results else None
[docs]
async def get(self, **kwargs) -> Any:
"""Get a single document matching the query asynchronously.
This method applies filters and ensures that exactly one document is returned.
For ID-based lookups, it uses direct record syntax instead of WHERE clause.
Args:
**kwargs: Field names and values to filter by
Returns:
The matching document
Raises:
DoesNotExist: If no matching document is found
MultipleObjectsReturned: If multiple matching documents are found
"""
# Special handling for ID-based lookup
if len(kwargs) == 1 and 'id' in kwargs:
id_value = kwargs['id']
# If it's already a full record ID (table:id format)
if isinstance(id_value, str) and ':' in id_value:
query = f"SELECT * FROM {id_value}"
else:
# Get table name from document class if available
table_name = getattr(self, 'document_class', None)
if table_name:
table_name = table_name._get_collection_name()
else:
table_name = getattr(self, 'table_name', None)
if table_name:
query = f"SELECT * FROM {table_name}:{id_value}"
else:
# Fall back to regular filtering if we can't determine the table
return await self._get_with_filters(**kwargs)
result = await self.connection.client.query(query)
if not result or not result[0]:
raise DoesNotExist(f"Object with ID '{id_value}' does not exist.")
return result[0][0]
# For non-ID lookups, use regular filtering
return await self._get_with_filters(**kwargs)
[docs]
def get_sync(self, **kwargs) -> Any:
"""Get a single document matching the query synchronously.
This method applies filters and ensures that exactly one document is returned.
For ID-based lookups, it uses direct record syntax instead of WHERE clause.
Args:
**kwargs: Field names and values to filter by
Returns:
The matching document
Raises:
DoesNotExist: If no matching document is found
MultipleObjectsReturned: If multiple matching documents are found
"""
# Special handling for ID-based lookup
if len(kwargs) == 1 and 'id' in kwargs:
id_value = kwargs['id']
# If it's already a full record ID (table:id format)
if isinstance(id_value, str) and ':' in id_value:
query = f"SELECT * FROM {id_value}"
else:
# Get table name from document class if available
table_name = getattr(self, 'document_class', None)
if table_name:
table_name = table_name._get_collection_name()
else:
table_name = getattr(self, 'table_name', None)
if table_name:
query = f"SELECT * FROM {table_name}:{id_value}"
else:
# Fall back to regular filtering if we can't determine the table
return self._get_with_filters_sync(**kwargs)
result = self.connection.client.query(query)
if not result or not result[0]:
raise DoesNotExist(f"Object with ID '{id_value}' does not exist.")
return result[0][0]
# For non-ID lookups, use regular filtering
return self._get_with_filters_sync(**kwargs)
async def _get_with_filters(self, **kwargs) -> Any:
"""Internal method to get a single document using filters asynchronously.
Args:
**kwargs: Field names and values to filter by
Returns:
The matching document
Raises:
DoesNotExist: If no matching document is found
MultipleObjectsReturned: If multiple matching documents are found
"""
self.filter(**kwargs)
self.limit_value = 2 # Get 2 to check for multiple
results = await self.all()
if not results:
raise DoesNotExist(f"Object matching query does not exist.")
if len(results) > 1:
raise MultipleObjectsReturned(f"Multiple objects returned instead of one")
return results[0]
def _get_with_filters_sync(self, **kwargs) -> Any:
"""Internal method to get a single document using filters synchronously.
Args:
**kwargs: Field names and values to filter by
Returns:
The matching document
Raises:
DoesNotExist: If no matching document is found
MultipleObjectsReturned: If multiple matching documents are found
"""
self.filter(**kwargs)
self.limit_value = 2 # Get 2 to check for multiple
results = self.all_sync()
if not results:
raise DoesNotExist(f"Object matching query does not exist.")
if len(results) > 1:
raise MultipleObjectsReturned(f"Multiple objects returned instead of one")
return results[0]
[docs]
async def count(self) -> int:
    """Count documents matching the query asynchronously.

    Abstract hook: subclasses execute an appropriate counting query.

    Returns:
        Number of matching documents

    Raises:
        NotImplementedError: always, on this base class
    """
    raise NotImplementedError("Subclasses must implement count")
[docs]
def count_sync(self) -> int:
    """Count documents matching the query synchronously.

    Abstract hook: subclasses execute an appropriate counting query.

    Returns:
        Number of matching documents

    Raises:
        NotImplementedError: always, on this base class
    """
    raise NotImplementedError("Subclasses must implement count_sync")
[docs]
def __await__(self):
    """Make the queryset awaitable.

    ``await qs`` is shorthand for ``await qs.all()``.

    Returns:
        Awaitable that resolves to the query results
    """
    return self.all().__await__()
[docs]
def page(self, number: int, size: int) -> 'BaseQuerySet':
"""Set pagination parameters using page number and size.
This method calculates the appropriate LIMIT and START values
based on the page number and size, providing a more convenient
way to paginate results.
Args:
number: Page number (1-based, first page is 1)
size: Number of items per page
Returns:
The query set instance for method chaining
"""
if number < 1:
raise ValueError("Page number must be 1 or greater")
if size < 1:
raise ValueError("Page size must be 1 or greater")
self.limit_value = size
self.start_value = (number - 1) * size
return self
[docs]
async def paginate(self, page: int, per_page: int) -> PaginationResult:
    """Fetch one page of results plus pagination metadata, asynchronously.

    Runs a count query for the total, then fetches the requested page.

    Args:
        page: The page number (1-based)
        per_page: The number of items per page

    Returns:
        A PaginationResult containing the items and pagination metadata
    """
    total = await self.count()
    page_items = await self.page(page, per_page).all()
    return PaginationResult(page_items, page, per_page, total)
[docs]
def paginate_sync(self, page: int, per_page: int) -> PaginationResult:
    """Fetch one page of results plus pagination metadata, synchronously.

    Runs a count query for the total, then fetches the requested page.

    Args:
        page: The page number (1-based)
        per_page: The number of items per page

    Returns:
        A PaginationResult containing the items and pagination metadata
    """
    total = self.count_sync()
    page_items = self.page(page, per_page).all_sync()
    return PaginationResult(page_items, page, per_page, total)
[docs]
def get_raw_query(self) -> str:
    """Return the built query string without executing it.

    Useful for debugging or manual execution.

    Returns:
        The raw query string
    """
    return self._build_query()
[docs]
def aggregate(self):
    """Start an aggregation pipeline based on this queryset.

    Returns:
        An AggregationPipeline instance for building and executing
        multi-stage aggregation queries.
    """
    # Imported lazily to avoid a circular dependency.
    from .aggregation import AggregationPipeline
    pipeline = AggregationPipeline(self)
    return pipeline
def _clone(self) -> 'BaseQuerySet':
"""Create a new instance of the queryset with the same parameters.
This method creates a new instance of the same class as the current
instance and copies all the relevant attributes.
Returns:
A new queryset instance with the same parameters
"""
# Create a new instance of the same class
if hasattr(self, 'document_class'):
# For QuerySet subclass
clone = self.__class__(self.document_class, self.connection)
elif hasattr(self, 'table_name'):
# For SchemalessQuerySet subclass
clone = self.__class__(self.table_name, self.connection)
else:
# For BaseQuerySet or other subclasses
clone = self.__class__(self.connection)
# Copy all the query parameters
clone.query_parts = self.query_parts.copy()
clone.limit_value = self.limit_value
clone.start_value = self.start_value
clone.order_by_value = self.order_by_value
clone.group_by_fields = self.group_by_fields.copy()
clone.split_fields = self.split_fields.copy()
clone.fetch_fields = self.fetch_fields.copy()
clone.with_index = self.with_index
clone.select_fields = self.select_fields
# Copy performance optimization attributes
clone._bulk_id_selection = self._bulk_id_selection
clone._id_range_selection = self._id_range_selection
clone._prefer_direct_access = self._prefer_direct_access
# Copy traversal state
clone._traversal_path = self._traversal_path
clone._traversal_unique = self._traversal_unique
clone._traversal_max_depth = self._traversal_max_depth
return clone