"""
LLM Client Factory and Provider Registry.
This module provides a factory function to instantiate the correct LLM client
based on provider settings, with lazy loading of adapters.
"""
import importlib
import logging
from typing import Any
from .base import BaseLLMClient, LLMClient, LLMResponse, LLMToolResponse, ToolCall
from .exceptions import (
CircuitBreakerOpenError,
LLMAuthenticationError,
LLMClientError,
LLMConnectionError,
LLMContentFilterError,
LLMContextLengthError,
LLMInvalidRequestError,
LLMModelNotFoundError,
LLMQuotaExceededError,
LLMRateLimitError,
LLMResponseParseError,
LLMServerError,
LLMStreamError,
LLMTimeoutError,
)
logger = logging.getLogger(__name__)
# Provider registry with lazy loading
# Maps provider name to (module_path, class_name)
_PROVIDER_REGISTRY: dict[str, tuple[str, str]] = {
"openai": ("src.adapters.llm.openai_client", "OpenAIClient"),
"anthropic": ("src.adapters.llm.anthropic_client", "AnthropicClient"),
"lmstudio": ("src.adapters.llm.lmstudio_client", "LMStudioClient"),
"local": ("src.adapters.llm.lmstudio_client", "LMStudioClient"), # Alias
}
# Cache for loaded client classes
_CLIENT_CACHE: dict[str, type[BaseLLMClient]] = {}
def register_provider(
    name: str,
    module_path: str,
    class_name: str,
    override: bool = False,
) -> None:
"""
Register a new LLM provider.
Args:
name: Provider identifier (e.g., "azure", "bedrock")
module_path: Full module path (e.g., "src.adapters.llm.azure_client")
class_name: Class name in the module (e.g., "AzureOpenAIClient")
        override: If True, allow replacing an existing provider

    Raises:
        ValueError: If the provider is already registered and override is False
    """
if name in _PROVIDER_REGISTRY and not override:
raise ValueError(f"Provider '{name}' already registered. Use override=True to replace.")
_PROVIDER_REGISTRY[name] = (module_path, class_name)
# Clear cache if overriding
if name in _CLIENT_CACHE:
del _CLIENT_CACHE[name]
logger.info(f"Registered LLM provider: {name} -> {module_path}.{class_name}")
def list_providers() -> list[str]:
"""
List all registered provider names.
Returns:
List of provider identifiers
"""
return list(_PROVIDER_REGISTRY.keys())
def get_provider_class(provider: str) -> type[BaseLLMClient]:
"""
Get the client class for a provider (with lazy loading).
Args:
provider: Provider identifier
Returns:
Client class (not instantiated)
Raises:
ValueError: If provider not registered
ImportError: If module cannot be loaded
"""
if provider not in _PROVIDER_REGISTRY:
available = ", ".join(list_providers())
raise ValueError(f"Unknown provider '{provider}'. Available: {available}")
# Check cache first
if provider in _CLIENT_CACHE:
return _CLIENT_CACHE[provider]
# Lazy load the module
module_path, class_name = _PROVIDER_REGISTRY[provider]
try:
module = importlib.import_module(module_path)
client_class = getattr(module, class_name)
except ImportError as e:
raise ImportError(f"Failed to load provider '{provider}': {e}") from e
except AttributeError as e:
raise ImportError(f"Class '{class_name}' not found in module '{module_path}'") from e
# Cache for future use
_CLIENT_CACHE[provider] = client_class
return client_class
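
# Example: fetch a provider class without instantiating it, e.g. to subclass
# it or construct it manually. Constructor kwargs mirror the pass-through
# behavior of create_client below:
#
#     OpenAIClient = get_provider_class("openai")
#     client = OpenAIClient(model="gpt-4-turbo-preview", timeout=30.0)
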
def create_client(
provider: str = "openai",
*,
api_key: str | None = None,
model: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int | None = None,
**kwargs: Any,
) -> BaseLLMClient:
"""
Create an LLM client instance.
This is the main factory function for creating provider clients.
Args:
provider: Provider name ("openai", "anthropic", "lmstudio", etc.)
api_key: API key (may be optional for some providers)
model: Model identifier
base_url: Base URL for API
timeout: Request timeout in seconds
max_retries: Maximum retry attempts
**kwargs: Provider-specific parameters
Returns:
        Configured client instance (a BaseLLMClient subclass)

    Raises:
        ValueError: If the provider is not registered
        ImportError: If the provider module cannot be loaded

    Examples:
        # OpenAI client
        client = create_client("openai", model="gpt-4-turbo-preview")

        # Anthropic client
        client = create_client("anthropic", model="sonnet")

        # Local LM Studio
        client = create_client("lmstudio", base_url="http://localhost:1234/v1")

        # With custom settings
        client = create_client(
            "openai",
            api_key="sk-...",
            timeout=120.0,
            max_retries=5,
            organization="org-...",
        )
"""
client_class = get_provider_class(provider)
# Build kwargs for client initialization
init_kwargs = {**kwargs}
if api_key is not None:
init_kwargs["api_key"] = api_key
if model is not None:
init_kwargs["model"] = model
if base_url is not None:
init_kwargs["base_url"] = base_url
if timeout is not None:
init_kwargs["timeout"] = timeout
if max_retries is not None:
init_kwargs["max_retries"] = max_retries
logger.info(f"Creating {provider} client with model={model or 'default'}")
return client_class(**init_kwargs)
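
# A hedged error-handling sketch around a created client; the complete()
# method name is an assumption about BaseLLMClient's interface, so swap in
# whatever method the adapters actually expose:
#
#     client = create_client("openai", model="gpt-4-turbo-preview")
#     try:
#         response = client.complete("Summarize this document.")
#     except LLMRateLimitError:
#         ...  # back off and retry
#     except LLMClientError as exc:
#         logger.error(f"LLM call failed: {exc}")
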
def create_client_from_config(config: dict[str, Any]) -> BaseLLMClient:
"""
Create an LLM client from a configuration dictionary.
Useful for loading settings from YAML/JSON config files.
Args:
config: Configuration dictionary with keys:
- provider: Required provider name
- Other keys passed to create_client
Returns:
Configured LLMClient instance
Example:
config = {
"provider": "openai",
"model": "gpt-4-turbo-preview",
"timeout": 60.0,
"max_retries": 3
}
client = create_client_from_config(config)
"""
    config = config.copy()  # don't mutate the caller's dict
    provider = config.pop("provider", "openai")
return create_client(provider, **config)
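
# A minimal sketch of wiring create_client_from_config to a YAML file,
# assuming PyYAML is installed; the config/llm.yaml path is illustrative:
#
#     import yaml
#
#     with open("config/llm.yaml") as f:
#         client = create_client_from_config(yaml.safe_load(f))
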
# Convenience aliases for common use cases
def create_openai_client(**kwargs: Any) -> BaseLLMClient:
    """Create an OpenAI client."""
    return create_client("openai", **kwargs)


def create_anthropic_client(**kwargs: Any) -> BaseLLMClient:
    """Create an Anthropic Claude client."""
    return create_client("anthropic", **kwargs)


def create_local_client(**kwargs: Any) -> BaseLLMClient:
    """Create a local LM Studio client."""
    return create_client("lmstudio", **kwargs)
__all__ = [
# Base types
"LLMClient",
"LLMResponse",
"LLMToolResponse",
"ToolCall",
"BaseLLMClient",
# Exceptions
"LLMClientError",
"LLMAuthenticationError",
"LLMRateLimitError",
"LLMQuotaExceededError",
"LLMModelNotFoundError",
"LLMContextLengthError",
"LLMInvalidRequestError",
"LLMTimeoutError",
"LLMConnectionError",
"LLMServerError",
"LLMResponseParseError",
"LLMStreamError",
"LLMContentFilterError",
"CircuitBreakerOpenError",
# Factory functions
"create_client",
"create_client_from_config",
"create_openai_client",
"create_anthropic_client",
"create_local_client",
# Registry functions
"register_provider",
"list_providers",
"get_provider_class",
]