# Source code for quantumengine.aggregation

"""Aggregation pipeline for SurrealEngine.

This module provides support for building and executing aggregation pipelines
in SurrealEngine. Aggregation pipelines allow for complex data transformations
and analysis through a series of stages.
"""
import re
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING

from .connection import ConnectionRegistry

if TYPE_CHECKING:
    from .query import QuerySet


class AggregationPipeline:
    """Pipeline for building and executing aggregation queries.

    This class provides a fluent interface for building complex aggregation
    pipelines with multiple stages, similar to MongoDB's aggregation framework.
    """
[docs] def __init__(self, query_set: 'QuerySet'): """Initialize a new AggregationPipeline. Args: query_set: The QuerySet to build the pipeline from """ self.query_set = query_set self.stages = [] self.connection = query_set.connection self.backend = query_set.backend
[docs] def group(self, by_fields=None, **aggregations): """Group by fields and apply aggregations. Args: by_fields: Field or list of fields to group by **aggregations: Named aggregation functions to apply Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'group', 'by_fields': by_fields if isinstance(by_fields, list) else ([by_fields] if by_fields else []), 'aggregations': aggregations }) return self
[docs] def project(self, **fields): """Select or compute fields to include in output. Args: **fields: Field mappings for projection Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'project', 'fields': fields }) return self
[docs] def sort(self, **fields): """Sort results by fields. Args: **fields: Field names and sort directions ('ASC' or 'DESC') Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'sort', 'fields': fields }) return self
[docs] def limit(self, count): """Limit number of results. Args: count: Maximum number of results to return Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'limit', 'count': count }) return self
[docs] def skip(self, count): """Skip number of results. Args: count: Number of results to skip Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'skip', 'count': count }) return self
[docs] def with_index(self, index): """Use the specified index for the query. Args: index: Name of the index to use Returns: The pipeline instance for method chaining """ self.stages.append({ 'type': 'with_index', 'index': index }) return self
[docs] def build_query(self): """Build the SurrealQL query from the pipeline stages. Returns: The SurrealQL query string """ # Start with the base query from the query set base_query = self.query_set.get_raw_query() # Extract the FROM clause and any clauses that come after it from_index = base_query.upper().find("FROM") if from_index == -1: return base_query # Split the query into the SELECT part and the rest select_part = base_query[:from_index].strip() rest_part = base_query[from_index:].strip() # Process the stages to modify the query for stage in self.stages: if stage['type'] == 'group': # Handle GROUP BY stage by_fields = stage['by_fields'] aggregations = stage['aggregations'] # Build the GROUP BY clause if by_fields: group_by_clause = f"GROUP BY {', '.join(by_fields)}" # Check if there's already a GROUP BY clause if "GROUP BY" in rest_part.upper(): # Replace the existing GROUP BY clause rest_part = re.sub(r'GROUP BY.*?(?=(ORDER BY|LIMIT|START|$))', group_by_clause, rest_part, flags=re.IGNORECASE) else: # Add the GROUP BY clause before ORDER BY, LIMIT, or START for clause in ["ORDER BY", "LIMIT", "START"]: clause_index = rest_part.upper().find(clause) if clause_index != -1: rest_part = f"{rest_part[:clause_index]}{group_by_clause} {rest_part[clause_index:]}" break else: # No ORDER BY, LIMIT, or START, so add to the end rest_part = f"{rest_part} {group_by_clause}" # Build the SELECT part with aggregations if aggregations: # Start with the group by fields select_fields = by_fields.copy() if by_fields else [] # Add the aggregations for name, agg in aggregations.items(): select_fields.append(f"{agg} AS {name}") # Replace the SELECT part select_part = f"SELECT {', '.join(select_fields)}" elif stage['type'] == 'project': # Handle PROJECT stage fields = stage['fields'] # Build the SELECT part with projections if fields: select_fields = [] # Add the projections for name, expr in fields.items(): if expr is True: # Include the field as is select_fields.append(name) 
else: # Include the field with an expression select_fields.append(f"{expr} AS {name}") # Replace the SELECT part select_part = f"SELECT {', '.join(select_fields)}" elif stage['type'] == 'sort': # Handle SORT stage fields = stage['fields'] # Build the ORDER BY clause if fields: order_by_parts = [] # Add the sort fields for field, direction in fields.items(): order_by_parts.append(f"{field} {direction}") order_by_clause = f"ORDER BY {', '.join(order_by_parts)}" # Check if there's already an ORDER BY clause if "ORDER BY" in rest_part.upper(): # Replace the existing ORDER BY clause rest_part = re.sub(r'ORDER BY.*?(?=(LIMIT|START|$))', order_by_clause, rest_part, flags=re.IGNORECASE) else: # Add the ORDER BY clause before LIMIT or START for clause in ["LIMIT", "START"]: clause_index = rest_part.upper().find(clause) if clause_index != -1: rest_part = f"{rest_part[:clause_index]}{order_by_clause} {rest_part[clause_index:]}" break else: # No LIMIT or START, so add to the end rest_part = f"{rest_part} {order_by_clause}" elif stage['type'] == 'limit': # Handle LIMIT stage count = stage['count'] # Build the LIMIT clause limit_clause = f"LIMIT {count}" # Check if there's already a LIMIT clause if "LIMIT" in rest_part.upper(): # Replace the existing LIMIT clause rest_part = re.sub(r'LIMIT.*?(?=(START|$))', limit_clause, rest_part, flags=re.IGNORECASE) else: # Add the LIMIT clause before START start_index = rest_part.upper().find("START") if start_index != -1: rest_part = f"{rest_part[:start_index]}{limit_clause} {rest_part[start_index:]}" else: # No START, so add to the end rest_part = f"{rest_part} {limit_clause}" elif stage['type'] == 'skip': # Handle SKIP stage count = stage['count'] # Build the START clause start_clause = f"START {count}" # Check if there's already a START clause if "START" in rest_part.upper(): # Replace the existing START clause rest_part = re.sub(r'START.*?(?=$)', start_clause, rest_part, flags=re.IGNORECASE) else: # Add the START clause to the end 
rest_part = f"{rest_part} {start_clause}" elif stage['type'] == 'with_index': # Handle WITH_INDEX stage index = stage['index'] # Build the WITH clause with_clause = f"WITH INDEX {index}" # Check if there's already a WITH clause if "WITH" in rest_part.upper(): # Replace the existing WITH clause rest_part = re.sub(r'WITH.*?(?=(WHERE|GROUP BY|SPLIT|FETCH|ORDER BY|LIMIT|START|$))', with_clause, rest_part, flags=re.IGNORECASE) else: # Add the WITH clause before WHERE, GROUP BY, SPLIT, FETCH, ORDER BY, LIMIT, or START for clause in ["WHERE", "GROUP BY", "SPLIT", "FETCH", "ORDER BY", "LIMIT", "START"]: clause_index = rest_part.upper().find(clause) if clause_index != -1: rest_part = f"{rest_part[:clause_index]}{with_clause} {rest_part[clause_index:]}" break else: # No WHERE, GROUP BY, SPLIT, FETCH, ORDER BY, LIMIT, or START, so add to the end rest_part = f"{rest_part} {with_clause}" # Combine the SELECT part with the rest of the query return f"{select_part} {rest_part}"
[docs] async def execute(self, connection=None): """Execute the pipeline and return results. Args: connection: Optional connection to use Returns: The query results """ query = self.build_query() # Use backend for execution if hasattr(self, 'backend') and self.backend: result = await self.backend.execute_raw(query) else: # Fallback to direct client execution connection = connection or self.connection or ConnectionRegistry.get_default_connection() result = await connection.client.query(query) return result
[docs] def execute_sync(self, connection=None): """Execute the pipeline synchronously. Args: connection: Optional connection to use Returns: The query results """ query = self.build_query() # For sync operations, fall back to direct client for now # TODO: Add sync backend methods or convert to async connection = connection or self.connection or ConnectionRegistry.get_default_connection() return connection.client.query(query)