Source code for surrealengine.graph

import json
from typing import Any, Dict, List, Optional, Type, Union


class GraphQuery:
    """Helper for complex graph queries.

    This class provides a fluent interface for building and executing complex
    graph traversal queries in SurrealDB. It allows defining a starting point,
    traversal path, end point, and filters for the query.

    Attributes:
        connection: The database connection to use for queries
        query_parts: List of query parts (currently unused by ``execute``)
        start_class: The document class to start the traversal from
        start_filters: Filters to apply to the starting documents
        path_spec: The traversal path specification
        end_class: The document class to end the traversal at
        end_filters: Filters to apply to the end results
    """

    def __init__(self, connection: Any) -> None:
        """Initialize a new GraphQuery.

        Args:
            connection: The database connection to use for queries. ``execute``
                awaits ``connection.client.query(query_string)``.
        """
        self.connection = connection
        self.query_parts: List[Any] = []
        # Initialise every builder slot up front so ``execute`` can check
        # values instead of probing with ``hasattr``.
        self.start_class: Optional[Type] = None
        self.start_filters: Dict[str, Any] = {}
        self.path_spec: Optional[str] = None
        self.end_class: Optional[Type] = None
        self.end_filters: Dict[str, Any] = {}

    def start_from(self, document_class: Type, **filters: Any) -> 'GraphQuery':
        """Set the starting point for the graph query.

        Args:
            document_class: The document class to start the traversal from
            **filters: Equality filters (``field=value``) to apply to the
                starting documents

        Returns:
            The GraphQuery instance for method chaining
        """
        self.start_class = document_class
        self.start_filters = filters
        return self

    def traverse(self, path_spec: str) -> 'GraphQuery':
        """Define a traversal path.

        Args:
            path_spec: The traversal path specification, e.g., "->[relation]->"

        Returns:
            The GraphQuery instance for method chaining
        """
        self.path_spec = path_spec
        return self

    def end_at(self, document_class: Optional[Type] = None) -> 'GraphQuery':
        """Set the end point document type.

        Args:
            document_class: The document class to end the traversal at. When
                ``None``, ``execute`` returns raw path results instead of
                document instances.

        Returns:
            The GraphQuery instance for method chaining
        """
        self.end_class = document_class
        return self

    def filter_results(self, **filters: Any) -> 'GraphQuery':
        """Add filters to the end results.

        Calling this again replaces any previously set end filters.

        Args:
            **filters: Equality filters (``field=value``) to apply to the
                end results

        Returns:
            The GraphQuery instance for method chaining
        """
        self.end_filters = filters
        return self

    @staticmethod
    def _format_conditions(filters: Dict[str, Any]) -> List[str]:
        """Render ``field = <escaped literal>`` conditions for each filter."""
        # Imported lazily (as before) so the SurrealQL helper is only needed
        # when filters are actually used.
        from .surrealql import escape_literal

        return [f"{field} = {escape_literal(value)}" for field, value in filters.items()]

    async def execute(self) -> List[Any]:
        """Execute the graph query.

        Builds a single SurrealQL ``SELECT`` from the components set through
        the fluent methods, runs it on ``connection.client``, and post-processes
        the first result set.

        Returns:
            When an end class was set, a list of ``end_class.from_db(doc)``
            instances; otherwise the raw first result set. An empty list if the
            query returned nothing.

        Raises:
            ValueError: If ``start_from()`` or ``traverse()`` was not called
                (or was given an empty value)
        """
        if self.start_class is None:
            raise ValueError("Must specify a starting document class with start_from()")
        if not self.path_spec:
            raise ValueError("Must specify a traversal path with traverse()")

        collection = self.start_class._get_collection_name()
        is_end_query = self.end_class is not None

        if is_end_query:
            query = f"SELECT * FROM {self.end_class._get_collection_name()}"
        else:
            query = f"SELECT {self.path_spec} as path FROM {collection}"

        where_clauses: List[str] = []
        if self.start_filters:
            start_conditions = self._format_conditions(self.start_filters)
            if is_end_query:
                # Constrain end records by the path to matching start records.
                # The clause must NOT carry its own "WHERE": all clauses are
                # joined under a single " WHERE " below.
                where_clauses.append(
                    f"{self.path_spec}({collection} WHERE {' AND '.join(start_conditions)})"
                )
            else:
                where_clauses.extend(start_conditions)
        # NOTE(review): for an end query without start filters, path_spec is
        # not applied at all (every end record is selected) — confirm intended.

        if self.end_filters:
            where_clauses.extend(self._format_conditions(self.end_filters))

        if where_clauses:
            query += " WHERE " + " AND ".join(where_clauses)

        result = await self.connection.client.query(query)

        if not result or not result[0]:
            return []
        if is_end_query:
            # Hydrate document instances from the raw rows.
            return [self.end_class.from_db(doc) for doc in result[0]]
        return result[0]