Source code for surrealengine.aggregation

"""Aggregation pipeline for SurrealEngine.

This module provides support for building and executing aggregation pipelines
in SurrealEngine. Aggregation pipelines allow for complex data transformations
and analysis through a series of stages.
"""
import json
from .surrealql import escape_literal
import re
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from .connection import ConnectionRegistry

if TYPE_CHECKING:
    from .query import QuerySet


class AggregationPipeline:
    """Pipeline for building and executing aggregation queries.

    This class provides a fluent interface for building complex aggregation
    pipelines with multiple stages, similar to MongoDB's aggregation framework.
    Each builder method only records a stage; :meth:`build_query` later
    translates all recorded stages into one SurrealQL statement by rewriting
    the raw query returned by the wrapped QuerySet.
    """

    # Locates the FROM keyword of the base query. Word boundaries keep
    # identifiers such as ``from_date`` from being mistaken for the keyword.
    _FROM_REGEX = re.compile(r'\bFROM\b', re.IGNORECASE)

    # Clauses that may follow ``FROM <target>``, listed in the order the
    # SurrealQL SELECT grammar requires them (FETCH comes after START).
    # Word boundaries stop identifiers like ``string::starts_with`` or
    # ``credit_limit`` from matching the START / LIMIT keywords.
    _CLAUSE_REGEXES = {
        'WITH': re.compile(r'\bWITH\b', re.IGNORECASE),
        'WHERE': re.compile(r'\bWHERE\b', re.IGNORECASE),
        'SPLIT': re.compile(r'\bSPLIT\b', re.IGNORECASE),
        'GROUP': re.compile(r'\bGROUP\s+(?:BY|ALL)\b', re.IGNORECASE),
        'ORDER': re.compile(r'\bORDER\s+BY\b', re.IGNORECASE),
        'LIMIT': re.compile(r'\bLIMIT\b', re.IGNORECASE),
        'START': re.compile(r'\bSTART\b', re.IGNORECASE),
        'FETCH': re.compile(r'\bFETCH\b', re.IGNORECASE),
    }
    _CLAUSE_ORDER = tuple(_CLAUSE_REGEXES)

    # Django-style lookup suffixes shared by match() and having(). Any other
    # suffix falls back to equality on the part before the last ``__``.
    _OPERATOR_TEMPLATES = {
        'gt': '{field} > {value}',
        'lt': '{field} < {value}',
        'gte': '{field} >= {value}',
        'lte': '{field} <= {value}',
        'ne': '{field} != {value}',
        'in': '{field} IN {value}',
        'nin': '{field} NOT IN {value}',
        'contains': '{field} CONTAINS {value}',
        'startswith': 'string::starts_with({field}, {value})',
        'endswith': 'string::ends_with({field}, {value})',
    }
    _EQUALITY_TEMPLATE = '{field} = {value}'

    def __init__(self, query_set: 'QuerySet'):
        """Initialize a new AggregationPipeline.

        Args:
            query_set: The QuerySet to build the pipeline from. Its
                ``get_raw_query()`` supplies the base query and its
                ``connection`` is the default connection used for execution.
        """
        self.query_set = query_set
        self.stages: List[Dict[str, Any]] = []
        self.connection = query_set.connection

    # ------------------------------------------------------------------
    # Stage builders
    # ------------------------------------------------------------------

    def group(self, by_fields=None, **aggregations):
        """Group by fields and apply aggregations.

        Args:
            by_fields: Field name, or list/tuple of field names, to group by.
                When empty, ``GROUP ALL`` is emitted if aggregations are given.
            **aggregations: Named aggregations; each is rendered as
                ``<aggregation> AS <name>`` in the SELECT list.

        Returns:
            The pipeline instance for method chaining
        """
        if isinstance(by_fields, (list, tuple)):
            fields = list(by_fields)
        elif by_fields:
            fields = [by_fields]
        else:
            fields = []
        self.stages.append({
            'type': 'group',
            'by_fields': fields,
            'aggregations': aggregations,
        })
        return self

    def project(self, **fields):
        """Select or compute fields to include in output.

        Args:
            **fields: Field mappings for projection. A value of ``True`` keeps
                the field as-is; any other value is rendered as
                ``<value> AS <name>``.

        Returns:
            The pipeline instance for method chaining
        """
        self.stages.append({'type': 'project', 'fields': fields})
        return self

    def sort(self, **fields):
        """Sort results by fields.

        Args:
            **fields: Field names and sort directions ('ASC' or 'DESC')

        Returns:
            The pipeline instance for method chaining
        """
        self.stages.append({'type': 'sort', 'fields': fields})
        return self

    def limit(self, count):
        """Limit number of results.

        Args:
            count: Maximum number of results to return

        Returns:
            The pipeline instance for method chaining
        """
        self.stages.append({'type': 'limit', 'count': count})
        return self

    def skip(self, count):
        """Skip number of results (emitted as a SurrealQL ``START`` clause).

        Args:
            count: Number of results to skip

        Returns:
            The pipeline instance for method chaining
        """
        self.stages.append({'type': 'skip', 'count': count})
        return self

    def with_index(self, index):
        """Use the specified index for the query.

        Args:
            index: Name of the index to use

        Returns:
            The pipeline instance for method chaining
        """
        self.stages.append({'type': 'with_index', 'index': index})
        return self

    def match(self, **conditions):
        """Filter documents before aggregation (similar to WHERE clause).

        This method adds filtering conditions that are applied before any
        aggregation operations. Multiple conditions are combined with AND, and
        are ANDed onto any WHERE clause already present in the base query.

        Args:
            **conditions: Field-value pairs for filtering (e.g., status='active').
                Supported suffixes: gt, lt, gte, lte, ne, in, nin, contains,
                startswith, endswith.

        Returns:
            The pipeline instance for method chaining

        Example:
            pipeline.match(status='completed', price__gt=100)
        """
        self.stages.append({'type': 'match', 'conditions': conditions})
        return self

    def having(self, **conditions):
        """Filter aggregated results (similar to HAVING clause).

        This method adds filtering conditions that are applied after
        aggregation operations. Use this to filter based on aggregated values.
        Conditions from every having() stage are combined with AND.

        Args:
            **conditions: Field-value pairs for filtering aggregated results,
                using the same suffixes as :meth:`match`.

        Returns:
            The pipeline instance for method chaining

        Example:
            pipeline.group(by_fields='category', total=Sum('price')).having(total__gt=1000)
        """
        self.stages.append({'type': 'having', 'conditions': conditions})
        return self

    # ------------------------------------------------------------------
    # Query construction helpers
    # ------------------------------------------------------------------

    @classmethod
    def _find_clause(cls, text, name, pos=0):
        """Return the match for clause keyword ``name`` at/after ``pos``, or None."""
        return cls._CLAUSE_REGEXES[name].search(text, pos)

    @classmethod
    def _next_clause_start(cls, text, name, pos=0):
        """Return where the earliest clause that must follow ``name`` starts, or len(text)."""
        following = cls._CLAUSE_ORDER[cls._CLAUSE_ORDER.index(name) + 1:]
        starts = []
        for other in following:
            found = cls._find_clause(text, other, pos)
            if found is not None:
                starts.append(found.start())
        return min(starts, default=len(text))

    @staticmethod
    def _join_parts(*parts):
        """Join non-empty query fragments with single spaces."""
        stripped = (part.strip() for part in parts)
        return ' '.join(part for part in stripped if part)

    @classmethod
    def _set_clause(cls, rest, name, clause):
        """Replace clause ``name`` in ``rest`` with ``clause``, or insert it in grammar order."""
        existing = cls._find_clause(rest, name)
        if existing is not None:
            # The old clause runs until the next clause that may follow it.
            end = cls._next_clause_start(rest, name, existing.end())
            return cls._join_parts(rest[:existing.start()], clause, rest[end:])
        insert_at = cls._next_clause_start(rest, name)
        return cls._join_parts(rest[:insert_at], clause, rest[insert_at:])

    @classmethod
    def _add_where_conditions(cls, rest, conditions):
        """AND ``conditions`` into the WHERE clause of ``rest``, creating it if needed."""
        new_conditions = ' AND '.join(conditions)
        existing = cls._find_clause(rest, 'WHERE')
        if existing is None:
            return cls._set_clause(rest, 'WHERE', f"WHERE {new_conditions}")
        end = cls._next_clause_start(rest, 'WHERE', existing.end())
        current = rest[existing.end():end].strip()
        # Parenthesise the existing expression so an OR inside it cannot bind
        # more loosely than the conditions appended here.
        if current:
            merged = f"WHERE ({current}) AND {new_conditions}"
        else:
            merged = f"WHERE {new_conditions}"
        return cls._join_parts(rest[:existing.start()], merged, rest[end:])

    @classmethod
    def _build_conditions(cls, conditions):
        """Translate ``field[__op]=value`` keyword pairs into SurrealQL expressions."""
        expressions = []
        for key, value in conditions.items():
            if '__' in key:
                field_name, op = key.rsplit('__', 1)
                template = cls._OPERATOR_TEMPLATES.get(op, cls._EQUALITY_TEMPLATE)
            else:
                field_name, template = key, cls._EQUALITY_TEMPLATE
            expressions.append(
                template.format(field=field_name, value=escape_literal(value))
            )
        return expressions

    def build_query(self):
        """Build the SurrealQL query from the pipeline stages.

        Stages are applied in the order they were added: match, group, sort,
        limit, skip and with_index stages edit the clauses after FROM; group
        (with aggregations) and project stages replace the SELECT list. The
        conditions of all having stages are ANDed together and applied by
        wrapping the query as ``SELECT * FROM (<query>) WHERE ...``.

        If the base query contains no FROM keyword it is returned unchanged.

        Returns:
            The SurrealQL query string
        """
        base_query = self.query_set.get_raw_query()

        from_match = self._FROM_REGEX.search(base_query)
        if from_match is None:
            return base_query

        # Split the query into the SELECT part and the FROM-onwards part.
        select_part = base_query[:from_match.start()].strip()
        rest_part = base_query[from_match.start():].strip()
        having_conditions: List[str] = []

        for stage in self.stages:
            stage_type = stage['type']

            if stage_type == 'match':
                if stage['conditions']:
                    rest_part = self._add_where_conditions(
                        rest_part, self._build_conditions(stage['conditions'])
                    )

            elif stage_type == 'group':
                by_fields = stage['by_fields']
                aggregations = stage['aggregations']
                if by_fields or aggregations:
                    # GROUP ALL collapses every row into one group when no
                    # explicit grouping fields are given.
                    if by_fields:
                        group_clause = f"GROUP BY {', '.join(by_fields)}"
                    else:
                        group_clause = "GROUP ALL"
                    rest_part = self._set_clause(rest_part, 'GROUP', group_clause)
                if aggregations:
                    select_fields = list(by_fields)
                    select_fields.extend(
                        f"{agg} AS {name}" for name, agg in aggregations.items()
                    )
                    select_part = f"SELECT {', '.join(select_fields)}"

            elif stage_type == 'project':
                fields = stage['fields']
                if fields:
                    select_fields = [
                        name if expr is True else f"{expr} AS {name}"
                        for name, expr in fields.items()
                    ]
                    select_part = f"SELECT {', '.join(select_fields)}"

            elif stage_type == 'having':
                if stage['conditions']:
                    having_conditions.extend(self._build_conditions(stage['conditions']))

            elif stage_type == 'sort':
                fields = stage['fields']
                if fields:
                    order_by = ', '.join(
                        f"{field} {direction}" for field, direction in fields.items()
                    )
                    rest_part = self._set_clause(rest_part, 'ORDER', f"ORDER BY {order_by}")

            elif stage_type == 'limit':
                rest_part = self._set_clause(rest_part, 'LIMIT', f"LIMIT {stage['count']}")

            elif stage_type == 'skip':
                rest_part = self._set_clause(rest_part, 'START', f"START {stage['count']}")

            elif stage_type == 'with_index':
                rest_part = self._set_clause(rest_part, 'WITH', f"WITH INDEX {stage['index']}")

        final_query = f"{select_part} {rest_part}"

        # SurrealQL has no HAVING; filter the aggregated rows in an outer query.
        if having_conditions:
            having_where = " AND ".join(having_conditions)
            final_query = f"SELECT * FROM ({final_query}) WHERE {having_where}"

        return final_query

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    @staticmethod
    def _normalize_results(results):
        """Normalize a raw query response into a list of row dicts.

        A list whose first element is a dict is taken as the rows directly;
        otherwise the last list found inside the response is used. A single
        dict is wrapped in a list, and empty responses yield ``[]``.
        """
        if not results:
            return []
        if isinstance(results, list):
            if isinstance(results[0], dict):
                rows = results
            else:
                rows = next(
                    (part for part in reversed(results) if isinstance(part, list)),
                    None,
                )
        else:
            rows = results
        if not rows:
            return []
        if isinstance(rows, dict):
            return [rows]
        return rows

    async def execute(self, connection=None):
        """Execute the pipeline and return results.

        Args:
            connection: Optional connection to use; falls back to the
                QuerySet's connection, then the registry default.

        Returns:
            A list of result rows (dicts) from the final SELECT statement
        """
        query = self.build_query()
        connection = connection or self.connection or ConnectionRegistry.get_default_connection()
        results = await connection.client.query(query)
        return self._normalize_results(results)

    def execute_sync(self, connection=None):
        """Execute the pipeline synchronously.

        Args:
            connection: Optional connection to use; falls back to the
                QuerySet's connection, then the registry default.

        Returns:
            A list of result rows (dicts) from the final SELECT statement
        """
        query = self.build_query()
        connection = connection or self.connection or ConnectionRegistry.get_default_connection()
        results = connection.client.query(query)
        return self._normalize_results(results)