""" Base LLM client interface for provider-agnostic model access. This module defines the protocol and data structures for LLM clients, enabling seamless switching between providers (OpenAI, Anthropic, LM Studio, etc.) """ import asyncio import time from abc import ABC, abstractmethod from collections.abc import AsyncIterator from dataclasses import dataclass, field from datetime import datetime from typing import Any, Protocol, runtime_checkable @dataclass class LLMResponse: """Standardized response from any LLM provider.""" text: str usage: dict = field(default_factory=dict) model: str = "" raw_response: Any = None finish_reason: str = "stop" created_at: datetime = field(default_factory=datetime.utcnow) @property def total_tokens(self) -> int: """Total tokens used in request/response.""" return self.usage.get("total_tokens", 0) @property def prompt_tokens(self) -> int: """Tokens used in prompt.""" return self.usage.get("prompt_tokens", 0) @property def completion_tokens(self) -> int: """Tokens used in completion.""" return self.usage.get("completion_tokens", 0) @dataclass class ToolCall: """Represents a tool/function call from the LLM.""" id: str name: str arguments: dict type: str = "function" @dataclass class LLMToolResponse(LLMResponse): """Response containing tool calls.""" tool_calls: list[ToolCall] = field(default_factory=list) class TokenBucketRateLimiter: """ Token bucket rate limiter for controlling request rates. This implementation uses a token bucket algorithm where: - Tokens are added at a fixed rate (rate_per_second) - Each request consumes one token - If no tokens available, caller waits until one becomes available """ def __init__(self, rate_per_minute: int = 60): """ Initialize the rate limiter. Args: rate_per_minute: Maximum requests allowed per minute """ self.rate_per_second = rate_per_minute / 60.0 self.max_tokens = float(rate_per_minute) self.tokens = self.max_tokens self.last_refill = time.monotonic() self._lock = asyncio.Lock() self._wait_count = 0 self._total_wait_time = 0.0 async def acquire(self) -> float: """ Acquire a token, waiting if necessary. Returns: Time spent waiting (0.0 if no wait was needed) """ async with self._lock: now = time.monotonic() elapsed = now - self.last_refill # Refill tokens based on elapsed time self.tokens = min(self.max_tokens, self.tokens + elapsed * self.rate_per_second) self.last_refill = now wait_time = 0.0 if self.tokens < 1: # Calculate how long to wait for one token wait_time = (1 - self.tokens) / self.rate_per_second self._wait_count += 1 self._total_wait_time += wait_time # Release lock during sleep to allow other operations self._lock.release() try: await asyncio.sleep(wait_time) finally: await self._lock.acquire() # After sleeping, update time and set tokens to 0 self.last_refill = time.monotonic() self.tokens = 0 else: self.tokens -= 1 return wait_time @property def stats(self) -> dict: """Get rate limiter statistics.""" return { "rate_limit_waits": self._wait_count, "total_rate_limit_wait_time": self._total_wait_time, "current_tokens": self.tokens, } @runtime_checkable class LLMClient(Protocol): """ Protocol for LLM clients. This protocol defines the interface that all LLM provider adapters must implement. Using Protocol allows for structural subtyping (duck typing) while maintaining type safety. """ async def generate( self, *, messages: list[dict] | None = None, prompt: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, tools: list[dict] | None = None, stream: bool = False, stop: list[str] | None = None, **kwargs: Any, ) -> LLMResponse | AsyncIterator[str]: """ Generate a response from the LLM. Args: messages: List of message dicts in OpenAI format [{"role": "...", "content": "..."}] prompt: Simple string prompt (converted to single user message) temperature: Sampling temperature (0.0 to 2.0) max_tokens: Maximum tokens to generate tools: List of tool definitions for function calling stream: If True, returns AsyncIterator[str] for streaming stop: Stop sequences **kwargs: Provider-specific parameters Returns: LLMResponse if stream=False, AsyncIterator[str] if stream=True Raises: LLMClientError: Base exception for all client errors """ ... class BaseLLMClient(ABC): """ Abstract base class for LLM clients. Provides common functionality and enforces the interface contract. All concrete implementations should inherit from this class. """ def __init__( self, api_key: str | None = None, model: str = "default", base_url: str | None = None, timeout: float = 60.0, max_retries: int = 3, rate_limit_per_minute: int | None = None, ): """ Initialize the LLM client. Args: api_key: API key for authentication model: Model identifier base_url: Base URL for API requests timeout: Request timeout in seconds max_retries: Maximum number of retry attempts rate_limit_per_minute: Rate limit (requests per minute), None to disable """ self.api_key = api_key self.model = model self.base_url = base_url self.timeout = timeout self.max_retries = max_retries self._request_count = 0 self._total_tokens_used = 0 self._rate_limited_requests = 0 # Initialize rate limiter if configured if rate_limit_per_minute is not None and rate_limit_per_minute > 0: self._rate_limiter: TokenBucketRateLimiter | None = TokenBucketRateLimiter( rate_per_minute=rate_limit_per_minute ) else: self._rate_limiter = None @abstractmethod async def generate( self, *, messages: list[dict] | None = None, prompt: str | None = None, temperature: float = 0.7, max_tokens: int | None = None, tools: list[dict] | None = None, stream: bool = False, stop: list[str] | None = None, **kwargs: Any, ) -> LLMResponse | AsyncIterator[str]: """Generate a response from the LLM.""" pass def _build_messages( self, messages: list[dict] | None = None, prompt: str | None = None, ) -> list[dict]: """ Build message list from either messages or prompt. Args: messages: Pre-formatted message list prompt: Simple string prompt Returns: List of message dicts Raises: ValueError: If neither messages nor prompt provided """ if messages is not None: return messages elif prompt is not None: return [{"role": "user", "content": prompt}] else: raise ValueError("Either 'messages' or 'prompt' must be provided") def _update_stats(self, response: LLMResponse) -> None: """Update internal statistics.""" self._request_count += 1 self._total_tokens_used += response.total_tokens async def _apply_rate_limit(self) -> None: """ Apply rate limiting if configured. Waits if necessary to comply with rate limits. Tracks rate-limited requests in metrics. """ if self._rate_limiter is not None: wait_time = await self._rate_limiter.acquire() if wait_time > 0: self._rate_limited_requests += 1 @property def stats(self) -> dict: """Get client statistics.""" base_stats = { "request_count": self._request_count, "total_tokens_used": self._total_tokens_used, "rate_limited_requests": self._rate_limited_requests, } # Include rate limiter stats if available if self._rate_limiter is not None: base_stats.update(self._rate_limiter.stats) return base_stats async def close(self) -> None: # noqa: B027 """Clean up resources. Override in subclasses if needed.""" pass async def __aenter__(self): """Async context manager entry.""" return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit.""" await self.close()