import re
import uuid
import decimal
import base64
import socket
import urllib.parse
import io
import os
from decimal import Decimal
from typing import Any, Dict, List, Optional, Pattern, Type, Union, BinaryIO
from .base import Field
from .scalar import StringField, NumberField
from ..exceptions import ValidationError
class BytesFieldWrapper:
    """File-like wrapper around an in-memory ``io.BytesIO`` buffer.

    Offers the usual file operations (``read``, ``write``, ``seek``,
    ``tell``, ``truncate`` ...), context-manager support, helpers that copy
    the content to/from files and streams, and simple file metadata
    (``filename``, ``content_type``, ``metadata``).

    After ``close()`` every I/O method raises ``ValueError``. ``size``,
    ``len()``, ``bool()`` and ``repr()`` keep working and report the size the
    buffer had at the moment it was closed.
    """

    def __init__(self, data: bytes = b'', filename: Optional[str] = None,
                 content_type: Optional[str] = None, metadata: Optional[dict] = None):
        """Initialize the file-like wrapper.

        Args:
            data: Initial binary data
            filename: Original filename (if any)
            content_type: MIME content type
            metadata: Additional file metadata
        """
        self._buffer = io.BytesIO(data)
        self.filename = filename
        self.content_type = content_type
        self.metadata = metadata or {}
        self._closed = False
        # Size remembered by close() so size/len/repr survive closing.
        self._size_at_close = 0

    def _check_open(self) -> None:
        """Raise ``ValueError`` if the wrapper has been closed."""
        if self._closed:
            raise ValueError("I/O operation on closed file")

    @property
    def closed(self) -> bool:
        """Check if the file is closed."""
        return self._closed

    @property
    def size(self) -> int:
        """Get the size of the data in bytes (last known size once closed)."""
        if self._closed:
            return self._size_at_close
        position = self._buffer.tell()
        try:
            # seek() returns the new absolute position, i.e. the total size.
            return self._buffer.seek(0, io.SEEK_END)
        finally:
            self._buffer.seek(position)

    def read(self, size: int = -1) -> bytes:
        """Read and return up to ``size`` bytes (-1 reads all remaining data).

        Raises:
            ValueError: If the wrapper is closed
        """
        self._check_open()
        return self._buffer.read(size)

    def read_text(self, encoding: str = 'utf-8', errors: str = 'strict') -> str:
        """Read the remaining content and decode it as text.

        Args:
            encoding: Text encoding to use
            errors: How to handle decoding errors

        Returns:
            Text content from the current position to the end
        """
        return self.read().decode(encoding, errors)

    def readline(self, size: int = -1) -> bytes:
        """Read and return one line as bytes (at most ``size`` bytes)."""
        self._check_open()
        return self._buffer.readline(size)

    def readlines(self, hint: int = -1) -> List[bytes]:
        """Read and return a list of lines; ``hint`` bounds the bytes read."""
        self._check_open()
        return self._buffer.readlines(hint)

    def write(self, data: Union[bytes, str]) -> int:
        """Write data at the current position.

        Args:
            data: Data to write; ``str`` is encoded as UTF-8 first

        Returns:
            Number of bytes written
        """
        self._check_open()
        if isinstance(data, str):
            data = data.encode('utf-8')
        return self._buffer.write(data)

    def writelines(self, lines: List[Union[bytes, str]]) -> None:
        """Write each item of ``lines`` in order (no separators are added)."""
        self._check_open()
        for line in lines:
            self.write(line)

    def seek(self, offset: int, whence: int = io.SEEK_SET) -> int:
        """Change the stream position.

        Args:
            offset: Stream position
            whence: How to interpret offset (SEEK_SET, SEEK_CUR, SEEK_END)

        Returns:
            New absolute position
        """
        self._check_open()
        return self._buffer.seek(offset, whence)

    def tell(self) -> int:
        """Get the current stream position."""
        self._check_open()
        return self._buffer.tell()

    def flush(self) -> None:
        """Flush the underlying buffer."""
        self._check_open()
        self._buffer.flush()

    def truncate(self, size: Optional[int] = None) -> int:
        """Truncate to at most ``size`` bytes (current position if None).

        Returns:
            New size
        """
        self._check_open()
        return self._buffer.truncate(size)

    def close(self) -> None:
        """Close the file; calling it again is a no-op."""
        if self._closed:
            return
        # Capture the size first: a closed BytesIO can no longer report it.
        self._size_at_close = self.size
        self._buffer.close()
        self._closed = True

    def getvalue(self) -> bytes:
        """Get the entire contents as bytes without moving the position."""
        self._check_open()
        return self._buffer.getvalue()

    def save_to_file(self, filepath: str, chunk_size: int = 8192) -> None:
        """Save the whole content to a file on disk.

        The stream position is left at the end of the data afterwards.

        Args:
            filepath: Path to save the file
            chunk_size: Size of chunks to write

        Raises:
            ValueError: If the wrapper is closed (the target is not touched)
        """
        # Rewind (and thereby verify we are open) BEFORE opening the target,
        # otherwise a closed wrapper would truncate an existing file.
        self.seek(0)
        with open(filepath, 'wb') as f:
            while True:
                chunk = self.read(chunk_size)
                if not chunk:
                    break
                f.write(chunk)

    def load_from_file(self, filepath: str, chunk_size: int = 8192) -> None:
        """Replace the content with the bytes of a file on disk.

        ``filename`` is set to the basename of ``filepath`` and the position
        is reset to 0. If reading fails, the existing content and filename
        are left untouched.

        Args:
            filepath: Path to load from
            chunk_size: Size of chunks to read

        Raises:
            ValueError: If the wrapper is closed
        """
        self._check_open()
        fresh = io.BytesIO()
        with open(filepath, 'rb') as f:
            while True:
                chunk = f.read(chunk_size)
                if not chunk:
                    break
                fresh.write(chunk)
        fresh.seek(0)
        # Swap only after a fully successful read.
        self._buffer = fresh
        self.filename = os.path.basename(filepath)

    def copy_to_stream(self, stream: BinaryIO, chunk_size: int = 8192) -> int:
        """Copy the whole content to another stream.

        Args:
            stream: Target stream to copy to
            chunk_size: Size of chunks to copy

        Returns:
            Number of bytes copied
        """
        total_bytes = 0
        self.seek(0)
        while True:
            chunk = self.read(chunk_size)
            if not chunk:
                break
            stream.write(chunk)
            total_bytes += len(chunk)
        return total_bytes

    def copy_from_stream(self, stream: BinaryIO, chunk_size: int = 8192) -> int:
        """Replace the content with everything readable from ``stream``.

        The position is reset to 0. If reading fails, the existing content
        is left untouched.

        Args:
            stream: Source stream to copy from
            chunk_size: Size of chunks to copy

        Returns:
            Number of bytes copied

        Raises:
            ValueError: If the wrapper is closed
        """
        self._check_open()
        fresh = io.BytesIO()
        total_bytes = 0
        while True:
            chunk = stream.read(chunk_size)
            if not chunk:
                break
            fresh.write(chunk)
            total_bytes += len(chunk)
        fresh.seek(0)
        self._buffer = fresh
        return total_bytes

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit: closes the wrapper."""
        self.close()

    def __len__(self) -> int:
        """Get the size of the data."""
        return self.size

    def __bool__(self) -> bool:
        """Check if there's any data."""
        return self.size > 0

    def __repr__(self) -> str:
        """String representation (safe to call on a closed wrapper)."""
        status = "closed" if self._closed else "open"
        return f"BytesFieldWrapper(size={self.size}, filename='{self.filename}', {status})"
class BytesField(Field):
    """Bytes field type with a file-like convenience interface.

    Stores binary data, validates it (optionally against ``max_size``) and
    converts between Python ``bytes`` and the SurrealDB
    ``<bytes>"base64..."`` representation.

    Example::

        class Document(SurrealDocument):
            file_data = BytesField(max_size=1024*1024)  # 1MB limit

        field = Document.file_data  # the field instance

        # Stand-alone file-like object
        with field.open() as f:
            f.write(b"Hello, World!")
            f.seek(0)
            content = f.read()

        # Returns a new BytesFieldWrapper; it is NOT stored on the field
        wrapper = field.load_from_file("/path/to/image.jpg")
        print(f"Loaded {len(wrapper)} bytes from {wrapper.filename}")

        # Convenience buffer kept on the field itself
        field.write_text("Hello, World!")
        field.save_to_file("/path/to/output.txt")

    NOTE(review): ``allowed_types`` is stored but not enforced anywhere in
    this class.
    """

    def __init__(self, max_size: Optional[int] = None,
                 allowed_types: Optional[List[str]] = None, **kwargs: Any) -> None:
        """Initialize a new BytesField.

        Args:
            max_size: Maximum size in bytes (None, or any falsy value, for
                unlimited)
            allowed_types: List of allowed content types/extensions
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(**kwargs)
        self.py_type = bytes
        self.max_size = max_size
        self.allowed_types = allowed_types or []
        # Buffer behind the convenience read/write methods; created lazily.
        self._wrapper = None

    def _check_size(self, data: bytes) -> None:
        """Raise ``ValueError`` if ``data`` is larger than ``max_size``."""
        if self.max_size and len(data) > self.max_size:
            raise ValueError(
                f"Data size {len(data)} bytes exceeds maximum {self.max_size} bytes "
                f"for field '{self.name}'"
            )

    @staticmethod
    def _wrapper_from_metadata(metadata: Dict[str, Any]) -> BytesFieldWrapper:
        """Build an empty wrapper, splitting out filename/content_type."""
        extras = {k: v for k, v in metadata.items()
                  if k not in ('filename', 'content_type')}
        return BytesFieldWrapper(
            filename=metadata.get('filename'),
            content_type=metadata.get('content_type'),
            metadata=extras,
        )

    def validate(self, value: Any) -> Optional[bytes]:
        """Validate the value and return it as ``bytes``.

        Accepts ``bytes``, a ``BytesFieldWrapper`` (its full content is
        used) or ``str`` (encoded as UTF-8).

        Args:
            value: The value to validate

        Returns:
            The validated bytes value, or None

        Raises:
            TypeError: If the value cannot be converted to bytes
            ValueError: If the value exceeds ``max_size``
        """
        value = super().validate(value)
        if value is None:
            return value
        if isinstance(value, BytesFieldWrapper):
            value = value.getvalue()
        if isinstance(value, bytes):
            self._check_size(value)
            return value
        if isinstance(value, str):
            try:
                data = value.encode('utf-8')
            except UnicodeEncodeError:
                # Un-encodable text (e.g. lone surrogates) is reported as a
                # type problem below.
                data = None
            if data is not None:
                self._check_size(data)
                return data
        raise TypeError(f"Expected bytes for field '{self.name}', got {type(value)}")

    def to_db(self, value: Any) -> Optional[str]:
        """Convert Python bytes to the SurrealDB ``<bytes>"base64"`` format.

        Args:
            value: bytes, a BytesFieldWrapper, or a string that already has
                the ``<bytes>"..."`` form (returned unchanged)

        Returns:
            The SurrealDB bytes literal, or None

        Raises:
            TypeError: If the value is of any other type
        """
        if value is None:
            return None
        if isinstance(value, BytesFieldWrapper):
            value = value.getvalue()
        if isinstance(value, bytes):
            encoded = base64.b64encode(value).decode('ascii')
            return f'<bytes>"{encoded}"'
        if isinstance(value, str) and value.startswith('<bytes>"') and value.endswith('"'):
            return value
        raise TypeError(f"Cannot convert {type(value)} to bytes")

    def from_db(self, value: Any) -> Optional[BytesFieldWrapper]:
        """Convert a database value to a BytesFieldWrapper.

        Raw ``bytes`` and ``<bytes>"base64"`` strings are wrapped; any other
        value (including None) is returned unchanged.

        Args:
            value: The database value to convert

        Returns:
            The BytesFieldWrapper object, or the value as received
        """
        if isinstance(value, bytes):
            return BytesFieldWrapper(value)
        if isinstance(value, str) and value.startswith('<bytes>"') and value.endswith('"'):
            # Strip the leading <bytes>" (8 chars) and the closing quote.
            return BytesFieldWrapper(base64.b64decode(value[8:-1]))
        return value

    def open(self, data: Optional[bytes] = None, **kwargs) -> BytesFieldWrapper:
        """Create a new stand-alone file-like wrapper.

        Args:
            data: Initial data (if None, uses empty bytes)
            **kwargs: Additional arguments for BytesFieldWrapper

        Returns:
            BytesFieldWrapper instance
        """
        return BytesFieldWrapper(data or b'', **kwargs)

    def load_from_file(self, filepath: str, **metadata) -> BytesFieldWrapper:
        """Load data from a file and return it in a new BytesFieldWrapper.

        Args:
            filepath: Path to the file to load
            **metadata: ``filename`` and ``content_type`` are set on the
                wrapper; everything else goes into its ``metadata`` dict

        Returns:
            BytesFieldWrapper with loaded data
        """
        wrapper = self._wrapper_from_metadata(metadata)
        wrapper.load_from_file(filepath)
        explicit_name = metadata.get('filename')
        if explicit_name is not None:
            # The wrapper's loader resets filename to the path basename;
            # an explicitly supplied filename must win.
            wrapper.filename = explicit_name
        return wrapper

    def from_stream(self, stream: BinaryIO, **metadata) -> BytesFieldWrapper:
        """Create a BytesFieldWrapper from everything readable in a stream.

        Args:
            stream: Source stream to read from
            **metadata: ``filename`` and ``content_type`` are set on the
                wrapper; everything else goes into its ``metadata`` dict

        Returns:
            BytesFieldWrapper with stream data
        """
        wrapper = self._wrapper_from_metadata(metadata)
        wrapper.copy_from_stream(stream)
        return wrapper

    # Convenience methods operating on the field's own buffer.
    # They test ``is not None`` rather than truthiness: an empty wrapper is
    # falsy (size 0) but still exists and still carries metadata.

    @property
    def size(self) -> int:
        """Get the current size of stored data."""
        if self._wrapper is not None:
            return self._wrapper.size
        return 0

    @property
    def filename(self) -> Optional[str]:
        """Get the filename of stored data."""
        if self._wrapper is not None:
            return self._wrapper.filename
        return None

    @property
    def content_type(self) -> Optional[str]:
        """Get the content type of stored data."""
        if self._wrapper is not None:
            return self._wrapper.content_type
        return None

    def read(self, size: int = -1) -> bytes:
        """Read from the field's buffer (empty bytes if nothing was written)."""
        if self._wrapper is not None:
            return self._wrapper.read(size)
        return b''

    def write(self, data: Union[bytes, str]) -> int:
        """Write to the field's buffer, creating it on first use."""
        if self._wrapper is None:
            self._wrapper = BytesFieldWrapper()
        return self._wrapper.write(data)

    def write_text(self, text: str, encoding: str = 'utf-8') -> int:
        """Encode ``text`` and write it to the field's buffer.

        Args:
            text: Text to write
            encoding: Text encoding to use

        Returns:
            Number of bytes written
        """
        return self.write(text.encode(encoding))

    def read_text(self, encoding: str = 'utf-8', errors: str = 'strict') -> str:
        """Decode the buffer's entire content without moving its position.

        Args:
            encoding: Text encoding to use
            errors: How to handle decoding errors

        Returns:
            Text content ("" if nothing was written)
        """
        if self._wrapper is None:
            return ""
        current_pos = self._wrapper.tell()
        self._wrapper.seek(0)
        data = self._wrapper.read()
        self._wrapper.seek(current_pos)
        return data.decode(encoding, errors)

    def save_to_file(self, filepath: str) -> None:
        """Save the buffer's content to a file (no-op if nothing was written)."""
        if self._wrapper is not None:
            self._wrapper.save_to_file(filepath)

    def __len__(self) -> int:
        """Get the size of stored data."""
        return self.size

    def __bool__(self) -> bool:
        """Check if there's any stored data."""
        return self.size > 0
class RegexField(Field):
    """Regular expression field type.

    Holds compiled regular expressions and converts between Python pattern
    objects and the ``/pattern/flags`` text form used for storage.
    """

    # (re flag, letter) pairs understood by to_db/from_db, in output order.
    _FLAG_LETTERS = (
        (re.IGNORECASE, "i"),
        (re.MULTILINE, "m"),
        (re.DOTALL, "s"),
    )

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new RegexField.

        Args:
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(**kwargs)
        self.py_type = Pattern

    def validate(self, value: Any) -> Optional[Pattern]:
        """Validate the value and return it as a compiled pattern.

        Args:
            value: A compiled pattern, a pattern string, or None

        Returns:
            The compiled pattern (or None)

        Raises:
            TypeError: If the value is neither a pattern nor a string
            ValueError: If the string is not a valid regular expression
        """
        checked = super().validate(value)
        if checked is None or isinstance(checked, Pattern):
            return checked
        if not isinstance(checked, str):
            raise TypeError(f"Expected regex pattern for field '{self.name}', got {type(checked)}")
        try:
            compiled = re.compile(checked)
        except re.error as e:
            raise ValueError(f"Invalid regex pattern for field '{self.name}': {str(e)}")
        return compiled

    def to_db(self, value: Any) -> Optional[str]:
        """Convert a Python pattern to the ``/pattern/flags`` form.

        Only the IGNORECASE, MULTILINE and DOTALL flags are emitted (as
        ``i``, ``m``, ``s``). Strings are passed through unchanged.

        Args:
            value: The Python regex pattern to convert

        Returns:
            The ``/pattern/flags`` string for the database

        Raises:
            TypeError: If the value is neither a pattern nor a string
        """
        if value is None:
            return None
        if isinstance(value, Pattern):
            letters = "".join(
                letter for bit, letter in self._FLAG_LETTERS if value.flags & bit
            )
            return f"/{value.pattern}/{letters}"
        if isinstance(value, str):
            # Strings are assumed to be in the stored form already.
            return value
        raise TypeError(f"Cannot convert {type(value)} to regex")

    def from_db(self, value: Any) -> Optional[Pattern]:
        """Convert a ``/pattern/flags`` string to a compiled pattern.

        Values that are not strings of that shape (including None and
        already-compiled patterns) are returned unchanged.

        Args:
            value: The database value to convert

        Returns:
            The Python regex pattern
        """
        if value is None or isinstance(value, Pattern):
            return value
        looks_stored = (
            isinstance(value, str)
            and value.startswith('/')
            and '/' in value[1:]
        )
        if not looks_stored:
            return value
        # Everything between the first and the last slash is the pattern.
        body, _, letters = value[1:].rpartition('/')
        compiled_flags = 0
        for bit, letter in self._FLAG_LETTERS:
            if letter in letters:
                compiled_flags |= bit
        return re.compile(body, compiled_flags)
class DecimalField(NumberField):
    """Decimal field type.

    Stores arbitrary-precision decimal values using Python's ``Decimal``
    class and validates that a value is, or can be converted to, a decimal.
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new DecimalField.

        Args:
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(**kwargs)
        self.py_type = Decimal

    def validate(self, value: Any) -> Optional[Decimal]:
        """Validate the value and return it as a ``Decimal``.

        Non-Decimal values are converted through their ``str()`` form.

        Args:
            value: The value to validate

        Returns:
            The validated decimal value (or None)

        Raises:
            TypeError: If the value cannot be converted to a decimal
        """
        value = super().validate(value)
        if value is None or isinstance(value, Decimal):
            return value
        try:
            return Decimal(str(value))
        except (TypeError, ValueError, decimal.InvalidOperation):
            raise TypeError(f"Expected decimal for field '{self.name}', got {type(value)}")

    def to_db(self, value: Any) -> Optional[str]:
        """Convert a Python decimal to its string form for the database.

        A string is used so that no precision is lost. Values that cannot
        be interpreted as a decimal are returned unchanged.

        Args:
            value: The Python Decimal to convert

        Returns:
            The string representation for the database
        """
        if value is None:
            return value
        if isinstance(value, Decimal):
            return str(value)
        try:
            return str(Decimal(str(value)))
        except (TypeError, ValueError, decimal.InvalidOperation):
            return value

    def from_db(self, value: Any) -> Optional[Decimal]:
        """Convert a database value to a Python ``Decimal``.

        Values that cannot be interpreted as a decimal are returned
        unchanged.

        Args:
            value: The database value to convert

        Returns:
            The Python Decimal object
        """
        if value is None:
            return value
        try:
            return Decimal(str(value))
        except (TypeError, ValueError, decimal.InvalidOperation):
            # Decimal() signals malformed text with InvalidOperation, which
            # is not a ValueError; it must be listed for the fall-through
            # to work.
            return value
class UUIDField(Field):
    """UUID field type.

    Holds UUID values and converts between ``uuid.UUID`` objects and their
    string form for storage.

    Example::

        class User(Document):
            id = UUIDField(default=uuid.uuid4)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new UUIDField.

        Args:
            **kwargs: Additional arguments to pass to the parent class
        """
        super().__init__(**kwargs)
        self.py_type = uuid.UUID

    def validate(self, value: Any) -> Optional[uuid.UUID]:
        """Validate the value and return it as a ``uuid.UUID``.

        Non-UUID values are parsed from their ``str()`` form.

        Args:
            value: The value to validate

        Returns:
            The validated UUID value (or None)

        Raises:
            ValueError: If the value cannot be parsed as a UUID
        """
        candidate = super().validate(value)
        if candidate is None or isinstance(candidate, uuid.UUID):
            return candidate
        try:
            parsed = uuid.UUID(str(candidate))
        except (TypeError, ValueError) as e:
            raise ValueError(f"Invalid UUID format for field '{self.name}': {str(e)}")
        return parsed

    def to_db(self, value: Any) -> Optional[str]:
        """Convert a UUID to its canonical string form for the database.

        Values that cannot be parsed as a UUID are returned unchanged.

        Args:
            value: The Python UUID to convert

        Returns:
            The string representation for the database
        """
        if value is None:
            return value
        if isinstance(value, uuid.UUID):
            return str(value)
        try:
            canonical = str(uuid.UUID(str(value)))
        except (TypeError, ValueError):
            return value
        return canonical

    def from_db(self, value: Any) -> Optional[uuid.UUID]:
        """Convert a database value to a ``uuid.UUID``.

        Values that cannot be parsed as a UUID are returned unchanged.

        Args:
            value: The database value to convert

        Returns:
            The Python UUID object
        """
        if value is None:
            return value
        try:
            parsed = uuid.UUID(str(value))
        except (TypeError, ValueError):
            return value
        return parsed
class LiteralField(Field):
    """Field for union/enum-like values.

    Accepts a value if it equals one of the allowed literal values or if
    one of the allowed field instances validates it.

    Example::

        class Product(Document):
            status = LiteralField(["active", "discontinued", "out_of_stock"])
            id_or_name = LiteralField([IntField(), StringField()])
    """

    def __init__(self, allowed_values: List[Any], **kwargs: Any) -> None:
        """Initialize a new LiteralField.

        Args:
            allowed_values: List of allowed values and/or Field instances
            **kwargs: Additional arguments to pass to the parent class
        """
        self.allowed_values = allowed_values
        # Field instances act as "any value this field accepts"; everything
        # else is compared by equality.
        self.allowed_fields = [v for v in allowed_values if isinstance(v, Field)]
        self.allowed_literals = [v for v in allowed_values if not isinstance(v, Field)]
        super().__init__(**kwargs)
        self.py_type = Union[tuple(f.py_type for f in self.allowed_fields)] if self.allowed_fields else Any

    def validate(self, value: Any) -> Any:
        """Validate that the value is one of the allowed values or types.

        Literals are checked first; then each allowed field is tried in
        order and the first successful validation result is returned.

        Args:
            value: The value to validate

        Returns:
            The validated value

        Raises:
            ValidationError: If the value matches no allowed literal or field
        """
        value = super().validate(value)
        if value is None:
            return None
        if value in self.allowed_literals:
            return value
        for field in self.allowed_fields:
            try:
                return field.validate(value)
            except (TypeError, ValueError, ValidationError):
                # A rejection by one candidate (including a nested
                # LiteralField, which raises ValidationError) must not stop
                # the remaining candidates from being tried.
                continue

        field_types = ", ".join(f.__class__.__name__ for f in self.allowed_fields)
        if self.allowed_literals:
            literals_str = ", ".join(repr(v) for v in self.allowed_literals)
            error_msg = f"Value for field '{self.name}' must be one of: {literals_str}"
            if self.allowed_fields:
                error_msg += f" or a valid {field_types}"
        else:
            error_msg = f"Value for field '{self.name}' must be a valid {field_types}"
        raise ValidationError(error_msg)

    def to_db(self, value: Any) -> Any:
        """Convert a Python value to its database representation.

        Literals are returned as they are. Otherwise the first allowed
        field that validates the value performs the conversion; if none
        does, the value is returned unchanged.

        Args:
            value: The Python value to convert

        Returns:
            The database representation of the value
        """
        if value is None:
            return None
        if value in self.allowed_literals:
            return value
        for field in self.allowed_fields:
            try:
                field.validate(value)  # Only convert with a field that accepts it
                return field.to_db(value)
            except (TypeError, ValueError, ValidationError):
                continue
        return value

    def from_db(self, value: Any) -> Any:
        """Convert a database value to its Python representation.

        Literals are returned as they are. Otherwise the result of the
        first allowed field whose ``from_db`` does not raise is returned;
        if all raise, the value is returned unchanged.

        Args:
            value: The database value to convert

        Returns:
            The Python representation of the value
        """
        if value is None:
            return None
        if value in self.allowed_literals:
            return value
        for field in self.allowed_fields:
            try:
                return field.from_db(value)
            except (TypeError, ValueError, ValidationError):
                continue
        return value
class EmailField(StringField):
    """Email field type.

    Stores email addresses and validates that the value looks like one.

    Example::

        class User(Document):
            email = EmailField(required=True)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new EmailField.

        Any ``regex`` passed by the caller is replaced by the email pattern.

        Args:
            **kwargs: Additional arguments to pass to the parent class
        """
        # Local part: the characters listed below; domain: dot-separated
        # labels of at most 63 alphanumeric/hyphen characters that neither
        # start nor end with a hyphen.
        kwargs['regex'] = r'^[a-zA-Z0-9.!#$%&\'*+/=?^_`{|}~-]+@[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(?:\.[a-zA-Z0-9](?:[a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$'
        super().__init__(**kwargs)

    def validate(self, value: Any) -> Optional[str]:
        """Validate the email address.

        Runs the parent validation first, then checks the structure of the
        address: exactly one ``@``, non-empty local and domain parts, and
        at least one dot in the domain.

        Args:
            value: The value to validate

        Returns:
            The validated email address (or None)

        Raises:
            ValueError: If the email address is invalid
        """
        value = super().validate(value)
        if value is None:
            return value

        if '@' not in value:
            raise ValueError(f"Invalid email address for field '{self.name}': missing @ symbol")
        if value.count('@') > 1:
            raise ValueError(f"Invalid email address for field '{self.name}': multiple @ symbols")

        # Exactly one '@' at this point, so partition splits cleanly.
        local, _, domain = value.partition('@')
        if not local:
            raise ValueError(f"Invalid email address for field '{self.name}': empty local part")
        if not domain:
            raise ValueError(f"Invalid email address for field '{self.name}': empty domain part")
        if '.' not in domain:
            raise ValueError(f"Invalid email address for field '{self.name}': invalid domain")
        return value
[docs]
class URLField(StringField):
"""Enhanced URL field type with urllib integration.
This field type stores URLs and provides validation using urllib.parse.
It also provides convenient access to URL components and allows flexible
URL formats including host-only URLs.
Features:
- Access URL components via properties (.scheme, .host, .path, .query, etc.)
- Allow host-only URLs (automatically adds scheme)
- Robust URL validation using urllib.parse
- Flexible scheme handling (http/https/ftp/etc.)
Example::
class Website(Document):
url = URLField(default_scheme='https', allow_host_only=True)
# Usage examples:
site = Website()
site.url = "example.com" # Auto-converts to "https://example.com"
print(site.url.host) # "example.com"
print(site.url.scheme) # "https"
print(site.url.port) # None
site.url = "https://api.example.com:8080/v1/users?active=true"
print(site.url.host) # "api.example.com"
print(site.url.port) # 8080
print(site.url.path) # "/v1/users"
print(site.url.query) # "active=true"
"""
[docs]
def __init__(self,
default_scheme: str = 'https',
allow_host_only: bool = True,
allowed_schemes: Optional[List[str]] = None,
**kwargs: Any) -> None:
"""Initialize a new enhanced URLField.
Args:
default_scheme: Default scheme to use for host-only URLs
allow_host_only: Whether to allow host-only URLs (will add default_scheme)
allowed_schemes: List of allowed schemes (None = allow all)
**kwargs: Additional arguments to pass to the parent class
"""
self.default_scheme = default_scheme
self.allow_host_only = allow_host_only
self.allowed_schemes = allowed_schemes or ['http', 'https', 'ftp', 'ftps']
# Remove the basic regex validation from parent
if 'regex' in kwargs:
del kwargs['regex']
super().__init__(**kwargs)
self._parsed_url = None
[docs]
def validate(self, value: Any) -> Optional[str]:
"""Validate and normalize the URL.
This method uses urllib.parse for robust URL validation and
automatically adds schemes to host-only URLs if allowed.
Args:
value: The value to validate
Returns:
The validated and normalized URL
Raises:
ValueError: If the URL is invalid
"""
# First run parent validation (handles None, basic string checks)
value = super(StringField, self).validate(value) # Skip StringField's regex
if value is not None:
original_value = value
# Handle host-only URLs
if self.allow_host_only and '://' not in value:
# Check if it looks like a valid hostname/domain
if self._is_valid_hostname(value):
value = f"{self.default_scheme}://{value}"
else:
raise ValueError(f"Invalid hostname for field '{self.name}': {original_value}")
# Parse and validate the URL
try:
parsed = urllib.parse.urlparse(value)
self._parsed_url = parsed
# Validate scheme
if parsed.scheme not in self.allowed_schemes:
allowed_str = ', '.join(self.allowed_schemes)
raise ValueError(f"Invalid URL scheme for field '{self.name}': '{parsed.scheme}'. Allowed: {allowed_str}")
# Validate that we have at least a netloc (host)
if not parsed.netloc:
raise ValueError(f"Invalid URL for field '{self.name}': missing host")
# Reconstruct the URL to ensure it's properly formatted
return urllib.parse.urlunparse(parsed)
except Exception as e:
raise ValueError(f"Invalid URL for field '{self.name}': {str(e)}")
return value
def _is_valid_hostname(self, hostname: str) -> bool:
"""Check if a string is a valid hostname/domain.
Args:
hostname: The hostname to validate
Returns:
True if valid hostname, False otherwise
"""
# Basic hostname validation
hostname_pattern = r'^[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?(\.[a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)*$'
# Allow localhost and IP addresses
if hostname in ['localhost', '127.0.0.1'] or hostname.startswith('192.168.') or hostname.startswith('10.'):
return True
return bool(re.match(hostname_pattern, hostname)) and len(hostname) <= 253
# URL Component Properties
@property
def scheme(self) -> Optional[str]:
"""Get the URL scheme (protocol)."""
if self._parsed_url:
return self._parsed_url.scheme
return None
@property
def host(self) -> Optional[str]:
"""Get the URL host/domain."""
if self._parsed_url:
return self._parsed_url.hostname
return None
@property
def hostname(self) -> Optional[str]:
"""Alias for host property."""
return self.host
@property
def port(self) -> Optional[int]:
"""Get the URL port."""
if self._parsed_url:
return self._parsed_url.port
return None
@property
def path(self) -> str:
"""Get the URL path."""
if self._parsed_url:
return self._parsed_url.path
return ""
@property
def query(self) -> str:
"""Get the URL query string."""
if self._parsed_url:
return self._parsed_url.query
return ""
@property
def fragment(self) -> str:
"""Get the URL fragment (hash)."""
if self._parsed_url:
return self._parsed_url.fragment
return ""
@property
def netloc(self) -> str:
"""Get the network location (host:port)."""
if self._parsed_url:
return self._parsed_url.netloc
return ""
@property
def params(self) -> str:
"""Get the URL parameters."""
if self._parsed_url:
return self._parsed_url.params
return ""
[docs]
def get_query_params(self) -> Dict[str, str]:
"""Parse query string into a dictionary.
Returns:
Dictionary of query parameters
"""
if self._parsed_url and self._parsed_url.query:
return dict(urllib.parse.parse_qsl(self._parsed_url.query))
return {}
[docs]
def get_query_param(self, param_name: str, default: Any = None) -> Any:
"""Get a specific query parameter value.
Args:
param_name: Name of the parameter to get
default: Default value if parameter not found
Returns:
Parameter value or default
"""
params = self.get_query_params()
return params.get(param_name, default)
[docs]
def is_secure(self) -> bool:
"""Check if the URL uses a secure scheme (https/ftps)."""
if self._parsed_url:
return self._parsed_url.scheme in ['https', 'ftps']
return False
[docs]
def get_base_url(self) -> str:
"""Get the base URL (scheme + netloc).
Returns:
Base URL string
"""
if self._parsed_url:
return f"{self._parsed_url.scheme}://{self._parsed_url.netloc}"
return ""
[docs]
def to_db(self, value: Any) -> Optional[str]:
"""Convert Python URL to database representation.
Args:
value: The Python URL to convert
Returns:
The string representation for the database
"""
# The validated value is already a proper URL string
return value
[docs]
def from_db(self, value: Any) -> Optional[str]:
    """Convert a stored database value back into a Python URL value.

    Args:
        value: The value read from the database

    Returns:
        The result of ``self.validate(value)`` when validation succeeds.
        ``None`` is returned as-is, and a value whose validation raises
        ``ValueError`` is returned unchanged.
    """
    if value is None:
        return value
    try:
        # Re-run validation so the stored URL is parsed again.
        return self.validate(value)
    except ValueError:
        # Best effort: fall back to the raw stored value.
        return value
[docs]
def __str__(self) -> str:
    """Return the URL as a string.

    Returns:
        The parsed URL reassembled with ``urllib.parse.urlunparse``; when
        no parsed URL is available, the parent class's string form.
    """
    parsed = self._parsed_url
    if not parsed:
        return super().__str__()
    return urllib.parse.urlunparse(parsed)
[docs]
def __repr__(self) -> str:
    """Return a debugging representation of the URL field.

    Returns:
        The text ``URLField('<url>')`` where ``<url>`` is the result of
        ``self.__str__()``.
    """
    rendered = self.__str__()
    return f"URLField('{rendered}')"
[docs]
class IPAddressField(StringField):
    """IP address field type.

    This field type stores IP addresses as strings and validates that the
    value is a well-formed IPv4 (dotted quad) or IPv6 address. The field
    can be restricted to a single address family.

    Example::

        class Server(Document):
            ip_address = IPAddressField(required=True)
            ip_v4 = IPAddressField(ipv4_only=True)
            ip_v6 = IPAddressField(ipv6_only=True)
    """

    def __init__(self, ipv4_only: bool = False, ipv6_only: bool = False,
                 version: Optional[str] = None, **kwargs: Any) -> None:
        """Initialize a new IPAddressField.

        Args:
            ipv4_only: Whether to only allow IPv4 addresses
            ipv6_only: Whether to only allow IPv6 addresses
            version: IP version to validate ('ipv4', 'ipv6', or 'both',
                case-insensitive). When given, it overrides ``ipv4_only``
                and ``ipv6_only``.
            **kwargs: Additional arguments to pass to the parent class

        Raises:
            ValueError: If ``version`` is not a recognised value, or if
                both ``ipv4_only`` and ``ipv6_only`` are true
        """
        # ``version`` is an alternative spelling of the two boolean flags.
        # It is a named parameter, so it never reaches ``kwargs`` and is not
        # forwarded to the parent class.
        if version is not None:
            version = version.lower()
            if version not in ('ipv4', 'ipv6', 'both'):
                raise ValueError("version must be 'ipv4', 'ipv6', or 'both'")
            ipv4_only = (version == 'ipv4')
            ipv6_only = (version == 'ipv6')
        self.ipv4_only = ipv4_only
        self.ipv6_only = ipv6_only
        if ipv4_only and ipv6_only:
            raise ValueError("Cannot set both ipv4_only and ipv6_only to True")
        super().__init__(**kwargs)

    def validate(self, value: Any) -> Optional[str]:
        """Validate the IP address.

        The value is first validated by the parent class. A non-``None``
        result must then be a dotted-quad IPv4 address with every octet in
        0-255 and/or an IPv6 address accepted by ``socket.inet_pton``,
        depending on the ``ipv4_only`` / ``ipv6_only`` flags.

        Args:
            value: The value to validate

        Returns:
            The validated IP address, or ``None`` if the parent validation
            returned ``None``

        Raises:
            ValueError: If the IP address is invalid
        """
        value = super().validate(value)
        if value is None:
            return value

        # IPv4: four groups of one to three ASCII digits. ``fullmatch`` with
        # an explicit ``[0-9]`` class rejects a trailing newline and
        # non-ASCII digits, both of which ``^...\d...$`` would let through.
        if self.ipv4_only or not self.ipv6_only:
            if re.fullmatch(r'(?:[0-9]{1,3}\.){3}[0-9]{1,3}', value):
                if all(0 <= int(octet) <= 255 for octet in value.split('.')):
                    return value
                if self.ipv4_only:
                    raise ValueError(f"Invalid IPv4 address for field '{self.name}': octets must be between 0 and 255")

        # IPv6: delegate parsing to the platform's inet_pton.
        if self.ipv6_only or not self.ipv4_only:
            try:
                socket.inet_pton(socket.AF_INET6, value)
            except (OSError, ValueError) as exc:
                if self.ipv6_only:
                    raise ValueError(f"Invalid IPv6 address for field '{self.name}'") from exc
            else:
                return value

        # Neither family accepted the value.
        if self.ipv4_only:
            raise ValueError(f"Invalid IPv4 address for field '{self.name}'")
        elif self.ipv6_only:
            raise ValueError(f"Invalid IPv6 address for field '{self.name}'")
        else:
            raise ValueError(f"Invalid IP address for field '{self.name}'")
[docs]
class SlugField(StringField):
    """Slug field type.

    Stores slugs (URL-friendly strings): lowercase letters and digits in
    groups separated by single hyphens.

    Example::

        class Article(Document):
            slug = SlugField(required=True)
    """

    def __init__(self, **kwargs: Any) -> None:
        """Initialize a new SlugField.

        The ``regex`` keyword is always set to the slug pattern, replacing
        any ``regex`` supplied by the caller.

        Args:
            **kwargs: Additional arguments to pass to the parent class
        """
        kwargs.update(regex=r'^[a-z0-9]+(?:-[a-z0-9]+)*$')
        super().__init__(**kwargs)

    def validate(self, value: Any) -> Optional[str]:
        """Validate the slug.

        Runs the parent validation, then rejects empty slugs, slugs that
        begin or end with a hyphen, and slugs containing two hyphens in a
        row.

        Args:
            value: The value to validate

        Returns:
            The validated slug, or ``None`` if the parent validation
            returned ``None``

        Raises:
            ValueError: If the slug is invalid
        """
        value = super().validate(value)
        if value is None:
            return value

        if not value:
            raise ValueError(f"Slug for field '{self.name}' cannot be empty")

        hyphen_at_edge = value.startswith('-') or value.endswith('-')
        if hyphen_at_edge:
            raise ValueError(f"Slug for field '{self.name}' cannot start or end with a hyphen")

        if value.find('--') != -1:
            raise ValueError(f"Slug for field '{self.name}' cannot contain consecutive hyphens")

        return value
[docs]
class ChoiceField(Field):
    """Choice field type.

    Stores a value drawn from a predefined set of choices and rejects
    anything that is not one of them.

    Example::

        class Product(Document):
            status = ChoiceField(choices=['active', 'inactive', 'discontinued'])
    """

    def __init__(self, choices: List[Union[str, tuple]], **kwargs: Any) -> None:
        """Initialize a new ChoiceField.

        Args:
            choices: List of allowed choices. Each choice can be a string or a
                tuple of (value, display_name); for tuples only the first
                element is used as the allowed value.
            **kwargs: Additional arguments to pass to the parent class
        """
        self.choices = choices
        allowed = []
        for choice in choices:
            allowed.append(choice[0] if isinstance(choice, tuple) else choice)
        self.values = allowed
        super().__init__(**kwargs)
        # Assigned after the parent initializer has run.
        self.py_type = str

    def validate(self, value: Any) -> Optional[str]:
        """Validate the choice value.

        Runs the parent validation, then checks membership in the allowed
        values. ``None`` is passed through without a membership check.

        Args:
            value: The value to validate

        Returns:
            The validated choice value

        Raises:
            ValueError: If the value is not one of the allowed choices
        """
        value = super().validate(value)
        if value is None or value in self.values:
            return value
        choices_str = ", ".join(map(repr, self.values))
        raise ValueError(f"Value for field '{self.name}' must be one of: {choices_str}")