# llm_router.py - ZeroGPU Chat API (RunPod)
import logging
import asyncio
import aiohttp
import time
from typing import Dict, Optional
from .models_config import LLM_CONFIG
from .config import get_settings
logger = logging.getLogger(__name__)
class LLMRouter:
def __init__(self, hf_token=None, use_local_models: bool = False):
"""
Initialize LLM Router with ZeroGPU Chat API (RunPod).
Args:
hf_token: Not used (kept for backward compatibility)
use_local_models: Must be False (local models disabled)
"""
if use_local_models:
raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")
        self.settings = get_settings()

        # Validate base URL before using it
        if not self.settings.zerogpu_base_url:
            raise ValueError(
                "ZEROGPU_BASE_URL is required. "
                "Set it in environment variables or .env file"
            )

        # Validate credentials
        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
            raise ValueError(
                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
                "Set them in environment variables or .env file"
            )

        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
        self.access_token = None
        self.refresh_token = None
        self.token_expires_at = 0
        self.session = None

        logger.info("ZeroGPU Chat API client initialized (authentication deferred to first request)")
        logger.info(f"Base URL: {self.base_url}")
async def route_inference(self, task_type: str, prompt: str, **kwargs):
"""
Route inference to ZeroGPU Chat API.
Args:
task_type: Type of task (general_reasoning, intent_classification, etc.)
prompt: Input prompt
**kwargs: Additional parameters (max_tokens, temperature, etc.)
Returns:
Generated text response
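
        Example (hypothetical usage; assumes ZeroGPU credentials are configured):
            async with LLMRouter() as router:
                answer = await router.route_inference(
                    "general_reasoning",
                    "Summarize the last meeting in two sentences.",
                    max_tokens=256,
                    temperature=0.3,
                )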
"""
logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}")
try:
# Ensure authenticated
await self._ensure_authenticated()
# Map internal task types to API task types
api_task = self._map_task_type(task_type)
# Pass original task type for model config lookup
kwargs['original_task_type'] = task_type
            # Embedding generation may need special handling on the API side
            if task_type == "embedding_generation":
                logger.warning("Embedding generation via ZeroGPU API may require special implementation")
            result = await self._call_zerogpu_api(api_task, prompt, **kwargs)
if result is None:
logger.error(f"ZeroGPU Chat API returned None for task: {task_type}")
raise RuntimeError(f"Inference failed for task: {task_type}")
logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)")
return result
except Exception as e:
logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True)
raise RuntimeError(
f"Inference failed for task: {task_type}. "
f"ZeroGPU Chat API error: {e}"
) from e
async def _ensure_authenticated(self):
"""Ensure we have a valid access token, login if needed."""
# Check if token is expired (with 60 second buffer)
if self.access_token and time.time() < (self.token_expires_at - 60):
return
# Create session if needed
if self.session is None:
self.session = aiohttp.ClientSession()
# Login to get tokens
await self._login()
async def _login(self):
"""Login to ZeroGPU Chat API and get access/refresh tokens."""
try:
login_url = f"{self.base_url}/login"
login_data = {
"email": self.settings.zerogpu_email,
"password": self.settings.zerogpu_password
}
async with self.session.post(login_url, json=login_data) as response:
if response.status == 401:
raise ValueError("Invalid email or password for ZeroGPU Chat API")
response.raise_for_status()
data = await response.json()
self.access_token = data.get("access_token")
self.refresh_token = data.get("refresh_token")
# Access tokens typically expire in 15 minutes (900 seconds)
self.token_expires_at = time.time() + 900
logger.info("Successfully authenticated with ZeroGPU Chat API")
except aiohttp.ClientError as e:
logger.error(f"Failed to login to ZeroGPU Chat API: {e}")
raise RuntimeError(f"Authentication failed: {e}") from e
async def _refresh_token(self):
"""Refresh access token using refresh token."""
        if not self.refresh_token:
            # No refresh token available yet; fall back to a fresh login
            await self._login()
            return
        try:
refresh_url = f"{self.base_url}/refresh"
headers = {"X-Refresh-Token": self.refresh_token}
async with self.session.post(refresh_url, headers=headers) as response:
if response.status == 401:
# Refresh token expired, need to login again
await self._login()
return
response.raise_for_status()
data = await response.json()
self.access_token = data.get("access_token")
self.refresh_token = data.get("refresh_token")
self.token_expires_at = time.time() + 900
logger.info("Successfully refreshed ZeroGPU Chat API token")
except aiohttp.ClientError as e:
logger.error(f"Failed to refresh token: {e}")
# Try login as fallback
await self._login()
def _map_task_type(self, internal_task: str) -> str:
"""Map internal task types to ZeroGPU Chat API task types."""
task_mapping = {
"general_reasoning": "general",
"response_synthesis": "general",
"intent_classification": "classification",
"safety_check": "classification",
"embedding_generation": "embedding"
}
return task_mapping.get(internal_task, "general")
async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]:
"""Call ZeroGPU Chat API for inference."""
if not self.session:
self.session = aiohttp.ClientSession()
# Store original task type for model config lookup
original_task = kwargs.pop('original_task_type', None)
# Get model config for defaults
model_config = self._select_model(original_task or 'general_reasoning')
# Build request payload according to API documentation
payload = {
"message": prompt,
"task": task,
"max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)),
"temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)),
"top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)),
}
# Add optional parameters
if 'context' in kwargs and kwargs['context']:
# Convert context to API format if needed
context = kwargs['context']
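            # Accepted history shapes (converted below):
            #   [("user msg", "assistant msg"), ...]           -> expanded into role dicts
            #   [{"role": "...", "content": "...", ...}, ...]  -> passed through unchanged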
if isinstance(context, list) and len(context) > 0:
# Convert to API format: list of dicts with role, content, timestamp
api_context = []
                for item in context[:50]:  # Cap history at 50 items (each pair expands to two messages)
                    if isinstance(item, (list, tuple)) and len(item) >= 2:
                        # Pair format: (user_msg, assistant_msg)
api_context.append({
"role": "user",
"content": str(item[0]),
"timestamp": kwargs.get('timestamp', time.time())
})
api_context.append({
"role": "assistant",
"content": str(item[1]),
"timestamp": kwargs.get('timestamp', time.time())
})
elif isinstance(item, dict):
api_context.append(item)
payload["context"] = api_context
if 'system_prompt' in kwargs and kwargs['system_prompt']:
payload["system_prompt"] = kwargs['system_prompt']
if 'repetition_penalty' in kwargs:
payload["repetition_penalty"] = kwargs['repetition_penalty']
# Prepare headers
headers = {
"Authorization": f"Bearer {self.access_token}",
"Content-Type": "application/json"
}
try:
chat_url = f"{self.base_url}/chat"
async with self.session.post(chat_url, json=payload, headers=headers) as response:
# Handle token expiration
if response.status == 401:
logger.info("Token expired, refreshing...")
await self._refresh_token()
headers["Authorization"] = f"Bearer {self.access_token}"
# Retry request
async with self.session.post(chat_url, json=payload, headers=headers) as retry_response:
retry_response.raise_for_status()
data = await retry_response.json()
return data.get("response")
response.raise_for_status()
data = await response.json()
# Extract response from API
result = data.get("response")
if result:
logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})")
return result
else:
logger.error("ZeroGPU Chat API returned empty response")
return None
except aiohttp.ClientError as e:
logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True)
raise
def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int:
"""
Calculate safe max_tokens based on input token count and model context window.
Args:
prompt: Input prompt text
requested_max_tokens: Desired max_tokens value
Returns:
int: Adjusted max_tokens that fits within context window
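
        Worked example (illustrative numbers only):
            prompt of ~8,000 chars  -> ~2,000 estimated input tokens
            context_window = 4,096  -> available = 4096 - 2000 - 100 = 1,996
            requested = 4,096       -> returned max_tokens = min(4096, 1996) = 1,996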
"""
# Estimate input tokens (rough: 1 token ≈ 4 characters)
# For more accuracy, you could use tiktoken if available
input_tokens = len(prompt) // 4
# Get model context window from settings
context_window = self.settings.zerogpu_model_context_window
logger.debug(
f"Calculating safe max_tokens: input ~{input_tokens} tokens, "
f"context_window={context_window}, requested={requested_max_tokens}"
)
# Reserve minimum 100 tokens for safety margin
available_tokens = context_window - input_tokens - 100
# Use the smaller of requested or available
safe_max_tokens = min(requested_max_tokens, available_tokens)
# Ensure minimum of 50 tokens for output
safe_max_tokens = max(50, safe_max_tokens)
if safe_max_tokens < requested_max_tokens:
logger.warning(
f"Reduced max_tokens from {requested_max_tokens} to {safe_max_tokens} "
f"(input: ~{input_tokens} tokens, context window: {context_window} tokens, "
f"available: {available_tokens} tokens)"
)
return safe_max_tokens
def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
"""
Format prompt for ZeroGPU Chat API.
Can be customized based on model requirements.
"""
formatted_prompt = prompt
# Add math directive for mathematical problems if needed
if self._is_math_query(prompt):
math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"
return formatted_prompt
def _is_math_query(self, prompt: str) -> bool:
"""Detect if query is mathematical"""
math_keywords = [
"solve", "calculate", "compute", "equation", "formula",
"mathematical", "algebra", "geometry", "calculus", "integral",
"derivative", "theorem", "proof", "problem"
]
prompt_lower = prompt.lower()
return any(keyword in prompt_lower for keyword in math_keywords)
def _clean_reasoning_tags(self, text: str) -> str:
"""Clean up reasoning tags from response if present"""
if not text:
return text
        # Strip <think>/</think> style reasoning tags if the model emits them
        # (the exact tag names depend on the deployed model; adjust as needed)
        for tag in ("<think>", "</think>"):
            text = text.replace(tag, "")
text = text.strip()
return text
def _select_model(self, task_type: str) -> dict:
"""Select model configuration based on task type"""
model_map = {
"intent_classification": LLM_CONFIG["models"]["classification_specialist"],
"embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
"safety_check": LLM_CONFIG["models"]["safety_checker"],
"general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
"response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
}
return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
async def get_available_models(self):
"""Get list of available models from ZeroGPU Chat API"""
try:
await self._ensure_authenticated()
if not self.session:
self.session = aiohttp.ClientSession()
tasks_url = f"{self.base_url}/tasks"
headers = {"Authorization": f"Bearer {self.access_token}"}
async with self.session.get(tasks_url, headers=headers) as response:
if response.status == 401:
await self._refresh_token()
headers["Authorization"] = f"Bearer {self.access_token}"
async with self.session.get(tasks_url, headers=headers) as retry_response:
retry_response.raise_for_status()
data = await retry_response.json()
else:
response.raise_for_status()
data = await response.json()
tasks = data.get("tasks", {})
models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}"
for task, info in tasks.items()]
return models if models else ["ZeroGPU Chat API"]
except Exception as e:
logger.error(f"Failed to get available models: {e}")
return ["ZeroGPU Chat API"]
async def health_check(self):
"""Perform health check on ZeroGPU Chat API"""
try:
if not self.session:
self.session = aiohttp.ClientSession()
# Check health endpoint (no auth required)
health_url = f"{self.base_url}/health"
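            # Expected /health response shape, inferred from the fields read below:
            #   {"status": "healthy", "models_ready": true, ...}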
async with self.session.get(health_url) as response:
response.raise_for_status()
data = await response.json()
return {
"provider": "zerogpu_chat_api",
"status": "healthy" if data.get("status") == "healthy" else "unhealthy",
"models_ready": data.get("models_ready", False),
"base_url": self.base_url
}
except Exception as e:
logger.error(f"Health check failed: {e}")
return {
"provider": "zerogpu_chat_api",
"status": "unhealthy",
"error": str(e)
}
async def __aenter__(self):
"""Async context manager entry"""
if not self.session:
self.session = aiohttp.ClientSession()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async context manager exit"""
if self.session:
await self.session.close()
self.session = None
def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
user_input: Optional[str] = None) -> str:
"""
Smart context windowing with user input priority.
        User input takes priority: stored context is reduced to fit, and user input itself
        is only truncated if it exceeds its own dedicated budget (user_input_max_tokens).
Args:
raw_context: Context dictionary
max_tokens: Optional override (uses config default if None)
user_input: Optional explicit user input (takes priority over raw_context['user_input'])
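
        Example (hypothetical input; keys match those read from raw_context below):
            context_str = router.prepare_context_for_llm(
                {"interaction_contexts": ["Q: ... A: ..."], "preferences": {"tone": "brief"}},
                user_input="What did we decide about deployment?"
            )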
"""
# Use config budget if not provided
if max_tokens is None:
max_tokens = self.settings.context_preparation_budget
# Get user input (explicit parameter takes priority)
actual_user_input = user_input or raw_context.get('user_input', '')
# Calculate user input tokens (simple estimation: 1 token ≈ 4 chars)
user_input_tokens = len(actual_user_input) // 4
# Ensure user input fits within dedicated budget
user_input_max = self.settings.user_input_max_tokens
if user_input_tokens > user_input_max:
logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
max_chars = user_input_max * 4
actual_user_input = actual_user_input[:max_chars - 3] + "..."
user_input_tokens = user_input_max
# Reserve space for user input (it has highest priority)
remaining_tokens = max_tokens - user_input_tokens
if remaining_tokens < 0:
logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
remaining_tokens = 0
logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")
# Priority order for context elements (user input already handled)
priority_elements = [
('recent_interactions', 0.8),
('user_preferences', 0.6),
('session_summary', 0.4),
('historical_context', 0.2)
]
formatted_context = []
total_tokens = user_input_tokens # Start with user input tokens
        # Add user input first (highest priority; already capped to its own budget above)
if actual_user_input:
formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")
# Now add context elements within remaining budget
for element, priority in priority_elements:
element_key_map = {
'recent_interactions': raw_context.get('interaction_contexts', []),
'user_preferences': raw_context.get('preferences', {}),
'session_summary': raw_context.get('session_context', {}),
'historical_context': raw_context.get('user_context', '')
}
content = element_key_map.get(element, '')
# Convert to string if needed
if isinstance(content, dict):
content = str(content)
elif isinstance(content, list):
content = "\n".join([str(item) for item in content[:10]])
if not content:
continue
# Estimate tokens (simple: 1 token ≈ 4 chars)
tokens = len(content) // 4
if total_tokens + tokens <= max_tokens:
formatted_context.append(f"=== {element.upper()} ===\n{content}")
total_tokens += tokens
elif priority > 0.5 and remaining_tokens > 0: # Critical elements - truncate if needed
available = max_tokens - total_tokens
if available > 100: # Only truncate if we have meaningful space
truncated = self._truncate_to_tokens(content, available)
formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
total_tokens += available
break
logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
return "\n\n".join(formatted_context)
def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
"""Truncate content to fit within token limit"""
# Simple character-based truncation (1 token ≈ 4 chars)
max_chars = max_tokens * 4
if len(content) <= max_chars:
return content
return content[:max_chars - 3] + "..."
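

# Minimal smoke test (illustrative only; requires ZEROGPU_BASE_URL, ZEROGPU_EMAIL and
# ZEROGPU_PASSWORD to be configured and the RunPod deployment to be reachable):
if __name__ == "__main__":
    async def _demo():
        async with LLMRouter() as router:
            print(await router.health_check())
            print(await router.route_inference(
                "general_reasoning", "Say hello in one sentence."))

    asyncio.run(_demo())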