# llm_router.py - ZeroGPU Chat API (RunPod)
import logging
import asyncio
import aiohttp
import time
from typing import Dict, Optional

from .models_config import LLM_CONFIG
from .config import get_settings

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token=None, use_local_models: bool = False):
        """
        Initialize LLM Router with ZeroGPU Chat API (RunPod).

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models disabled)
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")

        self.settings = get_settings()

        # Validate base URL before deriving anything from it
        if not self.settings.zerogpu_base_url:
            raise ValueError(
                "ZEROGPU_BASE_URL is required. "
                "Set it in environment variables or .env file"
            )

        # Validate credentials
        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
            raise ValueError(
                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
                "Set them in environment variables or .env file"
            )

        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
        self.access_token = None
        self.refresh_token = None
        self.token_expires_at = 0
        self.session = None

        logger.info("ZeroGPU Chat API client initializing")
        logger.info(f"Base URL: {self.base_url}")

        # Authentication happens lazily on the first request
        logger.info("ZeroGPU Chat API client initialized (authentication on first request)")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Route inference to ZeroGPU Chat API.

        Args:
            task_type: Type of task (general_reasoning, intent_classification, etc.)
            prompt: Input prompt
            **kwargs: Additional parameters (max_tokens, temperature, etc.)

        Returns:
            Generated text response
        """
        logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}")

        try:
            # Ensure authenticated
            await self._ensure_authenticated()

            # Map internal task types to API task types
            api_task = self._map_task_type(task_type)

            # Pass original task type for model config lookup
            kwargs['original_task_type'] = task_type

            # Embedding generation uses the same call path but may need special handling
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via ZeroGPU API may require special implementation")

            result = await self._call_zerogpu_api(api_task, prompt, **kwargs)

            if result is None:
                logger.error(f"ZeroGPU Chat API returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)")
            return result

        except Exception as e:
            logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
" f"ZeroGPU Chat API error: {e}" ) from e async def _ensure_authenticated(self): """Ensure we have a valid access token, login if needed.""" # Check if token is expired (with 60 second buffer) if self.access_token and time.time() < (self.token_expires_at - 60): return # Create session if needed if self.session is None: self.session = aiohttp.ClientSession() # Login to get tokens await self._login() async def _login(self): """Login to ZeroGPU Chat API and get access/refresh tokens.""" try: login_url = f"{self.base_url}/login" login_data = { "email": self.settings.zerogpu_email, "password": self.settings.zerogpu_password } async with self.session.post(login_url, json=login_data) as response: if response.status == 401: raise ValueError("Invalid email or password for ZeroGPU Chat API") response.raise_for_status() data = await response.json() self.access_token = data.get("access_token") self.refresh_token = data.get("refresh_token") # Access tokens typically expire in 15 minutes (900 seconds) self.token_expires_at = time.time() + 900 logger.info("Successfully authenticated with ZeroGPU Chat API") except aiohttp.ClientError as e: logger.error(f"Failed to login to ZeroGPU Chat API: {e}") raise RuntimeError(f"Authentication failed: {e}") from e async def _refresh_token(self): """Refresh access token using refresh token.""" try: refresh_url = f"{self.base_url}/refresh" headers = {"X-Refresh-Token": self.refresh_token} async with self.session.post(refresh_url, headers=headers) as response: if response.status == 401: # Refresh token expired, need to login again await self._login() return response.raise_for_status() data = await response.json() self.access_token = data.get("access_token") self.refresh_token = data.get("refresh_token") self.token_expires_at = time.time() + 900 logger.info("Successfully refreshed ZeroGPU Chat API token") except aiohttp.ClientError as e: logger.error(f"Failed to refresh token: {e}") # Try login as fallback await self._login() def _map_task_type(self, internal_task: str) -> str: """Map internal task types to ZeroGPU Chat API task types.""" task_mapping = { "general_reasoning": "general", "response_synthesis": "general", "intent_classification": "classification", "safety_check": "classification", "embedding_generation": "embedding" } return task_mapping.get(internal_task, "general") async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]: """Call ZeroGPU Chat API for inference.""" if not self.session: self.session = aiohttp.ClientSession() # Store original task type for model config lookup original_task = kwargs.pop('original_task_type', None) # Get model config for defaults model_config = self._select_model(original_task or 'general_reasoning') # Build request payload according to API documentation payload = { "message": prompt, "task": task, "max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)), "temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)), "top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)), } # Add optional parameters if 'context' in kwargs and kwargs['context']: # Convert context to API format if needed context = kwargs['context'] if isinstance(context, list) and len(context) > 0: # Convert to API format: list of dicts with role, content, timestamp api_context = [] for item in context[:50]: # Max 50 messages if isinstance(item, (list, tuple)) and len(item) >= 2: # Format: [user_msg, assistant_msg] api_context.append({ "role": "user", "content": str(item[0]), "timestamp": 
                            "timestamp": kwargs.get('timestamp', time.time())
                        })
                        api_context.append({
                            "role": "assistant",
                            "content": str(item[1]),
                            "timestamp": kwargs.get('timestamp', time.time())
                        })
                    elif isinstance(item, dict):
                        api_context.append(item)
                payload["context"] = api_context

        if 'system_prompt' in kwargs and kwargs['system_prompt']:
            payload["system_prompt"] = kwargs['system_prompt']

        if 'repetition_penalty' in kwargs:
            payload["repetition_penalty"] = kwargs['repetition_penalty']

        # Prepare headers
        headers = {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json"
        }

        try:
            chat_url = f"{self.base_url}/chat"
            async with self.session.post(chat_url, json=payload, headers=headers) as response:
                # Handle token expiration
                if response.status == 401:
                    logger.info("Token expired, refreshing...")
                    await self._refresh_token()
                    headers["Authorization"] = f"Bearer {self.access_token}"
                    # Retry request
                    async with self.session.post(chat_url, json=payload, headers=headers) as retry_response:
                        retry_response.raise_for_status()
                        data = await retry_response.json()
                        return data.get("response")

                response.raise_for_status()
                data = await response.json()

                # Extract response from API
                result = data.get("response")
                if result:
                    logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})")
                    return result
                else:
                    logger.error("ZeroGPU Chat API returned empty response")
                    return None

        except aiohttp.ClientError as e:
            logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True)
            raise

    def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int:
        """
        Calculate safe max_tokens based on input token count and model context window.

        Args:
            prompt: Input prompt text
            requested_max_tokens: Desired max_tokens value

        Returns:
            int: Adjusted max_tokens that fits within context window
        """
        # Estimate input tokens (rough: 1 token ≈ 4 characters)
        # For more accuracy, you could use tiktoken if available
        input_tokens = len(prompt) // 4

        # Get model context window from settings
        context_window = self.settings.zerogpu_model_context_window

        logger.debug(
            f"Calculating safe max_tokens: input ~{input_tokens} tokens, "
            f"context_window={context_window}, requested={requested_max_tokens}"
        )

        # Reserve minimum 100 tokens for safety margin
        available_tokens = context_window - input_tokens - 100

        # Use the smaller of requested or available
        safe_max_tokens = min(requested_max_tokens, available_tokens)

        # Ensure minimum of 50 tokens for output
        safe_max_tokens = max(50, safe_max_tokens)

        if safe_max_tokens < requested_max_tokens:
            logger.warning(
                f"Reduced max_tokens from {requested_max_tokens} to {safe_max_tokens} "
                f"(input: ~{input_tokens} tokens, context window: {context_window} tokens, "
                f"available: {available_tokens} tokens)"
            )

        return safe_max_tokens

    def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
        Format prompt for ZeroGPU Chat API.
        Can be customized based on model requirements.
        """
        formatted_prompt = prompt

        # Add math directive for mathematical problems if needed
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
formatted_prompt = f"{formatted_prompt}\n\n{math_directive}" return formatted_prompt def _is_math_query(self, prompt: str) -> bool: """Detect if query is mathematical""" math_keywords = [ "solve", "calculate", "compute", "equation", "formula", "mathematical", "algebra", "geometry", "calculus", "integral", "derivative", "theorem", "proof", "problem" ] prompt_lower = prompt.lower() return any(keyword in prompt_lower for keyword in math_keywords) def _clean_reasoning_tags(self, text: str) -> str: """Clean up reasoning tags from response if present""" if not text: return text # Remove common reasoning tags if present text = text.replace("``", "").replace("``", "") text = text.replace("``", "").replace("``", "") text = text.strip() return text def _select_model(self, task_type: str) -> dict: """Select model configuration based on task type""" model_map = { "intent_classification": LLM_CONFIG["models"]["classification_specialist"], "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], "safety_check": LLM_CONFIG["models"]["safety_checker"], "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] } return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) async def get_available_models(self): """Get list of available models from ZeroGPU Chat API""" try: await self._ensure_authenticated() if not self.session: self.session = aiohttp.ClientSession() tasks_url = f"{self.base_url}/tasks" headers = {"Authorization": f"Bearer {self.access_token}"} async with self.session.get(tasks_url, headers=headers) as response: if response.status == 401: await self._refresh_token() headers["Authorization"] = f"Bearer {self.access_token}" async with self.session.get(tasks_url, headers=headers) as retry_response: retry_response.raise_for_status() data = await retry_response.json() else: response.raise_for_status() data = await response.json() tasks = data.get("tasks", {}) models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}" for task, info in tasks.items()] return models if models else ["ZeroGPU Chat API"] except Exception as e: logger.error(f"Failed to get available models: {e}") return ["ZeroGPU Chat API"] async def health_check(self): """Perform health check on ZeroGPU Chat API""" try: if not self.session: self.session = aiohttp.ClientSession() # Check health endpoint (no auth required) health_url = f"{self.base_url}/health" async with self.session.get(health_url) as response: response.raise_for_status() data = await response.json() return { "provider": "zerogpu_chat_api", "status": "healthy" if data.get("status") == "healthy" else "unhealthy", "models_ready": data.get("models_ready", False), "base_url": self.base_url } except Exception as e: logger.error(f"Health check failed: {e}") return { "provider": "zerogpu_chat_api", "status": "unhealthy", "error": str(e) } async def __aenter__(self): """Async context manager entry""" if not self.session: self.session = aiohttp.ClientSession() return self async def __aexit__(self, exc_type, exc_val, exc_tb): """Async context manager exit""" if self.session: await self.session.close() self.session = None def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None, user_input: Optional[str] = None) -> str: """ Smart context windowing with user input priority. User input is NEVER truncated - context is reduced to fit. 
        Args:
            raw_context: Context dictionary
            max_tokens: Optional override (uses config default if None)
            user_input: Optional explicit user input (takes priority over raw_context['user_input'])
        """
        # Use config budget if not provided
        if max_tokens is None:
            max_tokens = self.settings.context_preparation_budget

        # Get user input (explicit parameter takes priority)
        actual_user_input = user_input or raw_context.get('user_input', '')

        # Calculate user input tokens (simple estimation: 1 token ≈ 4 chars)
        user_input_tokens = len(actual_user_input) // 4

        # Ensure user input fits within its dedicated budget
        user_input_max = self.settings.user_input_max_tokens
        if user_input_tokens > user_input_max:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
            max_chars = user_input_max * 4
            actual_user_input = actual_user_input[:max_chars - 3] + "..."
            user_input_tokens = user_input_max

        # Reserve space for user input (it has highest priority)
        remaining_tokens = max_tokens - user_input_tokens
        if remaining_tokens < 0:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
            remaining_tokens = 0

        logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")

        # Priority order for context elements (user input already handled)
        priority_elements = [
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        formatted_context = []
        total_tokens = user_input_tokens  # Start with user input tokens

        # Add user input first (highest priority)
        if actual_user_input:
            formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")

        # Now add context elements within remaining budget
        for element, priority in priority_elements:
            element_key_map = {
                'recent_interactions': raw_context.get('interaction_contexts', []),
                'user_preferences': raw_context.get('preferences', {}),
                'session_summary': raw_context.get('session_context', {}),
                'historical_context': raw_context.get('user_context', '')
            }
            content = element_key_map.get(element, '')

            # Convert to string if needed
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Estimate tokens (simple: 1 token ≈ 4 chars)
            tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5 and remaining_tokens > 0:
                # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if we have meaningful space
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    total_tokens += available
                break

        logger.info(
            f"Context prepared: {total_tokens}/{max_tokens} tokens "
            f"(user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})"
        )

        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within token limit"""
        # Simple character-based truncation (1 token ≈ 4 chars)
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max_chars - 3] + "..."
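

# --- Usage sketch ---
# Minimal, illustrative driver for this module, not part of the production API
# surface. It assumes the required settings (ZEROGPU_BASE_URL, ZEROGPU_EMAIL,
# ZEROGPU_PASSWORD) are already present in the environment or a .env file; the
# prompt and parameter values below are placeholders, not values mandated by
# the ZeroGPU Chat API.
if __name__ == "__main__":
    async def _demo() -> None:
        # The async context manager opens the aiohttp session and closes it on exit.
        async with LLMRouter() as router:
            status = await router.health_check()
            print(f"Health: {status}")
            if status.get("status") == "healthy":
                answer = await router.route_inference(
                    "general_reasoning",
                    "Summarize the benefits of token budgeting in one sentence.",
                    max_tokens=128,
                    temperature=0.3,
                )
                print(f"Response: {answer}")

    asyncio.run(_demo())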