""" Authentication and authorization layer for LangGraph Multi-Agent MCTS Framework. Provides: - API key authentication with secure hashing - JWT token support (optional) - Rate limiting per client - Role-based access control """ import hashlib import secrets import time from collections import defaultdict from dataclasses import dataclass, field from datetime import datetime, timedelta from src.api.exceptions import ( AuthenticationError, AuthorizationError, RateLimitError, ) @dataclass class ClientInfo: """Information about an authenticated client.""" client_id: str roles: set[str] = field(default_factory=lambda: {"user"}) created_at: datetime = field(default_factory=datetime.utcnow) last_access: datetime = field(default_factory=datetime.utcnow) request_count: int = 0 @dataclass class RateLimitConfig: """Rate limiting configuration.""" requests_per_minute: int = 60 requests_per_hour: int = 1000 requests_per_day: int = 10000 burst_limit: int = 100 # Max requests in 1 second class APIKeyAuthenticator: """ API key-based authentication with secure hashing. Keys are stored as SHA-256 hashes to prevent exposure. """ def __init__( self, valid_keys: list[str] | None = None, rate_limit_config: RateLimitConfig | None = None, ): """ Initialize authenticator. Args: valid_keys: List of valid API keys (will be hashed) rate_limit_config: Rate limiting configuration """ self._key_to_client: dict[str, ClientInfo] = {} self._rate_limits: dict[str, list[float]] = defaultdict(list) self.rate_limit_config = rate_limit_config or RateLimitConfig() # Hash and store initial keys if valid_keys: for i, key in enumerate(valid_keys): client_id = f"client_{i}" self._add_key(key, client_id) def _hash_key(self, api_key: str) -> str: """ Securely hash an API key. Uses SHA-256 with consistent encoding. """ return hashlib.sha256(api_key.encode("utf-8")).hexdigest() def _add_key(self, api_key: str, client_id: str, roles: set[str] | None = None) -> None: """ Add a new API key. Args: api_key: Raw API key client_id: Client identifier roles: Set of roles (defaults to {"user"}) """ key_hash = self._hash_key(api_key) self._key_to_client[key_hash] = ClientInfo( client_id=client_id, roles=roles or {"user"}, ) def authenticate(self, api_key: str | None) -> ClientInfo: """ Authenticate an API key. Args: api_key: API key to validate Returns: ClientInfo for the authenticated client Raises: AuthenticationError: If authentication fails """ if not api_key: raise AuthenticationError( user_message="API key is required", internal_details="No API key provided in request", ) # Constant-time comparison to prevent timing attacks key_hash = self._hash_key(api_key) if key_hash not in self._key_to_client: raise AuthenticationError( user_message="Invalid API key", internal_details=f"API key hash not found: {key_hash[:16]}...", ) client_info = self._key_to_client[key_hash] client_info.last_access = datetime.utcnow() client_info.request_count += 1 # Check rate limits self._check_rate_limit(client_info.client_id) return client_info def _check_rate_limit(self, client_id: str) -> None: """ Check if client has exceeded rate limits. Args: client_id: Client identifier Raises: RateLimitError: If rate limit exceeded """ now = time.time() request_times = self._rate_limits[client_id] # Clean old entries one_day_ago = now - 86400 request_times = [t for t in request_times if t > one_day_ago] self._rate_limits[client_id] = request_times # Check burst limit (1 second window) one_second_ago = now - 1 burst_count = sum(1 for t in request_times if t > one_second_ago) if burst_count >= self.rate_limit_config.burst_limit: raise RateLimitError( user_message="Too many requests. Please slow down.", internal_details=f"Client {client_id} exceeded burst limit: {burst_count}/{self.rate_limit_config.burst_limit}", retry_after_seconds=1, ) # Check per-minute limit one_minute_ago = now - 60 minute_count = sum(1 for t in request_times if t > one_minute_ago) if minute_count >= self.rate_limit_config.requests_per_minute: raise RateLimitError( user_message="Rate limit exceeded. Please wait a minute.", internal_details=f"Client {client_id} exceeded minute limit: {minute_count}/{self.rate_limit_config.requests_per_minute}", retry_after_seconds=60, ) # Check per-hour limit one_hour_ago = now - 3600 hour_count = sum(1 for t in request_times if t > one_hour_ago) if hour_count >= self.rate_limit_config.requests_per_hour: raise RateLimitError( user_message="Hourly rate limit exceeded. Please try again later.", internal_details=f"Client {client_id} exceeded hour limit: {hour_count}/{self.rate_limit_config.requests_per_hour}", retry_after_seconds=3600, ) # Check per-day limit day_count = len(request_times) if day_count >= self.rate_limit_config.requests_per_day: raise RateLimitError( user_message="Daily rate limit exceeded. Please try again tomorrow.", internal_details=f"Client {client_id} exceeded day limit: {day_count}/{self.rate_limit_config.requests_per_day}", retry_after_seconds=86400, ) # Record this request request_times.append(now) def require_auth(self, api_key: str | None) -> ClientInfo: """ Require authentication for a request. Convenience method that raises on failure. Args: api_key: API key to validate Returns: ClientInfo for authenticated client Raises: AuthenticationError: If authentication fails """ return self.authenticate(api_key) def require_role(self, client_info: ClientInfo, required_role: str) -> None: """ Require a specific role for an operation. Args: client_info: Authenticated client info required_role: Role that is required Raises: AuthorizationError: If client doesn't have required role """ if required_role not in client_info.roles: raise AuthorizationError( user_message="You do not have permission for this operation", internal_details=f"Client {client_info.client_id} missing role: {required_role}", required_permission=required_role, ) def generate_api_key(self) -> str: """ Generate a secure random API key. Returns: New API key (32 bytes hex = 64 characters) """ return secrets.token_hex(32) def revoke_key(self, api_key: str) -> bool: """ Revoke an API key. Args: api_key: Key to revoke Returns: True if key was revoked, False if not found """ key_hash = self._hash_key(api_key) if key_hash in self._key_to_client: del self._key_to_client[key_hash] return True return False def add_client( self, client_id: str, roles: set[str] | None = None, ) -> str: """ Add a new client and generate their API key. Args: client_id: Unique client identifier roles: Set of roles for the client Returns: Generated API key (save this securely!) """ api_key = self.generate_api_key() self._add_key(api_key, client_id, roles) return api_key def get_client_stats(self, client_id: str) -> dict: """ Get statistics for a client. Args: client_id: Client identifier Returns: Dictionary with client statistics """ now = time.time() request_times = self._rate_limits.get(client_id, []) return { "total_requests_today": len([t for t in request_times if t > now - 86400]), "requests_last_hour": len([t for t in request_times if t > now - 3600]), "requests_last_minute": len([t for t in request_times if t > now - 60]), } class JWTAuthenticator: """ JWT token-based authentication. Note: Requires PyJWT library for full functionality. This is a placeholder for JWT support. """ def __init__(self, secret_key: str, algorithm: str = "HS256"): """ Initialize JWT authenticator. Args: secret_key: Secret key for signing tokens algorithm: JWT signing algorithm """ self.secret_key = secret_key self.algorithm = algorithm self._token_blacklist: set[str] = set() def create_token( self, client_id: str, roles: set[str], expires_in_hours: int = 24, ) -> str: """ Create a JWT token. Args: client_id: Client identifier roles: Client roles expires_in_hours: Token validity period Returns: JWT token string """ try: import jwt except ImportError: raise ImportError("PyJWT library required for JWT authentication. Install with: pip install PyJWT") now = datetime.utcnow() payload = { "sub": client_id, "roles": list(roles), "iat": now, "exp": now + timedelta(hours=expires_in_hours), "jti": secrets.token_hex(16), # Unique token ID } return jwt.encode(payload, self.secret_key, algorithm=self.algorithm) def verify_token(self, token: str) -> ClientInfo: """ Verify a JWT token. Args: token: JWT token string Returns: ClientInfo from token claims Raises: AuthenticationError: If token is invalid """ try: import jwt except ImportError: raise ImportError("PyJWT library required for JWT authentication") if token in self._token_blacklist: raise AuthenticationError( user_message="Token has been revoked", internal_details="Token found in blacklist", ) try: payload = jwt.decode( token, self.secret_key, algorithms=[self.algorithm], ) return ClientInfo( client_id=payload["sub"], roles=set(payload.get("roles", ["user"])), ) except jwt.ExpiredSignatureError: raise AuthenticationError( user_message="Token has expired", internal_details="JWT signature expired", ) except jwt.InvalidTokenError as e: raise AuthenticationError( user_message="Invalid token", internal_details=f"JWT validation failed: {str(e)}", ) def revoke_token(self, token: str) -> None: """ Revoke a JWT token by adding to blacklist. Args: token: Token to revoke """ self._token_blacklist.add(token) # Default authenticator instance _default_authenticator: APIKeyAuthenticator | None = None def get_authenticator() -> APIKeyAuthenticator: """ Get or create the default authenticator instance. Returns: APIKeyAuthenticator instance """ global _default_authenticator if _default_authenticator is None: _default_authenticator = APIKeyAuthenticator() return _default_authenticator def set_authenticator(authenticator: APIKeyAuthenticator) -> None: """ Set the default authenticator instance. Args: authenticator: Authenticator to use """ global _default_authenticator _default_authenticator = authenticator # Exports __all__ = [ "APIKeyAuthenticator", "JWTAuthenticator", "ClientInfo", "RateLimitConfig", "get_authenticator", "set_authenticator", ]