import logging
import asyncio
import aiohttp
import time
from typing import Dict, Optional

from .models_config import LLM_CONFIG
from .config import get_settings

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token: Optional[str] = None, use_local_models: bool = False):
        """
        Initialize LLM Router with ZeroGPU Chat API (RunPod).

        Args:
            hf_token: Not used (kept for backward compatibility)
            use_local_models: Must be False (local models disabled)
        """
        if use_local_models:
            raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")

        self.settings = get_settings()

        # Validate required settings before using them, so a missing URL
        # raises a clear ValueError instead of an AttributeError.
        if not self.settings.zerogpu_base_url:
            raise ValueError(
                "ZEROGPU_BASE_URL is required. "
                "Set it in environment variables or .env file"
            )

        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
            raise ValueError(
                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
                "Set them in environment variables or .env file"
            )

        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
        self.access_token = None
        self.refresh_token = None
        self.token_expires_at = 0
        self.session = None
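
        # Illustrative .env sketch (variable names come from the error
        # messages above; values are placeholders):
        #   ZEROGPU_BASE_URL=https://example-host/api
        #   ZEROGPU_EMAIL=user@example.com
        #   ZEROGPU_PASSWORD=********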

        logger.info("ZeroGPU Chat API client initializing")
        logger.info(f"Base URL: {self.base_url}")

        # Authentication is deferred until the first request.
        logger.info("ZeroGPU Chat API client initialized (authentication on first request)")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Route inference to ZeroGPU Chat API.

        Args:
            task_type: Type of task (general_reasoning, intent_classification, etc.)
            prompt: Input prompt
            **kwargs: Additional parameters (max_tokens, temperature, etc.)

        Returns:
            Generated text response
        """
        logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}")

        try:
            await self._ensure_authenticated()

            # Translate the internal task type to the API's task vocabulary.
            api_task = self._map_task_type(task_type)

            # Preserve the original task type so the API call can pick the
            # matching model configuration.
            kwargs['original_task_type'] = task_type

            if task_type == "embedding_generation":
                logger.warning("Embedding generation via ZeroGPU API may require special implementation")
            result = await self._call_zerogpu_api(api_task, prompt, **kwargs)

            if result is None:
                logger.error(f"ZeroGPU Chat API returned None for task: {task_type}")
                raise RuntimeError(f"Inference failed for task: {task_type}")

            logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)")
            return result

        except Exception as e:
            logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
                f"ZeroGPU Chat API error: {e}"
            ) from e

    async def _ensure_authenticated(self):
        """Ensure we have a valid access token, logging in if needed."""
        # Reuse the current token if it is still valid (60 s safety margin).
        if self.access_token and time.time() < (self.token_expires_at - 60):
            return

        if self.session is None:
            self.session = aiohttp.ClientSession()

        await self._login()

    async def _login(self):
        """Login to ZeroGPU Chat API and get access/refresh tokens."""
        try:
            login_url = f"{self.base_url}/login"
            login_data = {
                "email": self.settings.zerogpu_email,
                "password": self.settings.zerogpu_password
            }

            async with self.session.post(login_url, json=login_data) as response:
                if response.status == 401:
                    raise ValueError("Invalid email or password for ZeroGPU Chat API")
                response.raise_for_status()
                data = await response.json()

                self.access_token = data.get("access_token")
                self.refresh_token = data.get("refresh_token")

                # Access tokens are assumed to live ~15 minutes;
                # _ensure_authenticated refreshes 60 s before this deadline.
                self.token_expires_at = time.time() + 900

                logger.info("Successfully authenticated with ZeroGPU Chat API")

        except aiohttp.ClientError as e:
            logger.error(f"Failed to login to ZeroGPU Chat API: {e}")
            raise RuntimeError(f"Authentication failed: {e}") from e

    async def _refresh_token(self):
        """Refresh access token using refresh token."""
        try:
            refresh_url = f"{self.base_url}/refresh"
            headers = {"X-Refresh-Token": self.refresh_token}

            async with self.session.post(refresh_url, headers=headers) as response:
                if response.status == 401:
                    # Refresh token rejected; fall back to a full login.
                    await self._login()
                    return

                response.raise_for_status()
                data = await response.json()

                self.access_token = data.get("access_token")
                self.refresh_token = data.get("refresh_token")
                self.token_expires_at = time.time() + 900

                logger.info("Successfully refreshed ZeroGPU Chat API token")

        except aiohttp.ClientError as e:
            logger.error(f"Failed to refresh token: {e}")
            # Network-level failure: retry with a full login.
            await self._login()

    def _map_task_type(self, internal_task: str) -> str:
        """Map internal task types to ZeroGPU Chat API task types."""
        task_mapping = {
            "general_reasoning": "general",
            "response_synthesis": "general",
            "intent_classification": "classification",
            "safety_check": "classification",
            "embedding_generation": "embedding"
        }
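        # e.g. _map_task_type("safety_check") -> "classification";
        # unknown internal task types fall back to "general".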
        return task_mapping.get(internal_task, "general")

    async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]:
        """Call ZeroGPU Chat API for inference."""
        if not self.session:
            self.session = aiohttp.ClientSession()

        # Recover the internal task type so the matching model config is used.
        original_task = kwargs.pop('original_task_type', None)
        model_config = self._select_model(original_task or 'general_reasoning')

        # Explicit kwargs override the model config's defaults.
        payload = {
            "message": prompt,
            "task": task,
            "max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)),
            "temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)),
            "top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)),
        }
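        # Resulting payload shape (illustrative):
        #   {"message": "...", "task": "general", "max_tokens": 512,
        #    "temperature": 0.7, "top_p": 0.9}
        # plus optional "context", "system_prompt", and "repetition_penalty"
        # keys added below.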

        # Attach optional conversation history in the API's message format.
        if 'context' in kwargs and kwargs['context']:
            context = kwargs['context']
            if isinstance(context, list) and len(context) > 0:
                api_context = []
                for item in context[:50]:  # cap history at 50 entries
                    if isinstance(item, (list, tuple)) and len(item) >= 2:
                        # (user_message, assistant_reply) pairs expand to two messages.
                        api_context.append({
                            "role": "user",
                            "content": str(item[0]),
                            "timestamp": kwargs.get('timestamp', time.time())
                        })
                        api_context.append({
                            "role": "assistant",
                            "content": str(item[1]),
                            "timestamp": kwargs.get('timestamp', time.time())
                        })
                    elif isinstance(item, dict):
                        api_context.append(item)
                payload["context"] = api_context

        if 'system_prompt' in kwargs and kwargs['system_prompt']:
            payload["system_prompt"] = kwargs['system_prompt']
        if 'repetition_penalty' in kwargs:
            payload["repetition_penalty"] = kwargs['repetition_penalty']

        headers = {
            "Authorization": f"Bearer {self.access_token}",
            "Content-Type": "application/json"
        }

        try:
            chat_url = f"{self.base_url}/chat"

            async with self.session.post(chat_url, json=payload, headers=headers) as response:
                # On 401, refresh the token once and retry the request.
                if response.status == 401:
                    logger.info("Token expired, refreshing...")
                    await self._refresh_token()
                    headers["Authorization"] = f"Bearer {self.access_token}"

                    async with self.session.post(chat_url, json=payload, headers=headers) as retry_response:
                        retry_response.raise_for_status()
                        data = await retry_response.json()
                        return data.get("response")

                response.raise_for_status()
                data = await response.json()

                result = data.get("response")
                if result:
                    logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})")
                    return result
                else:
                    logger.error("ZeroGPU Chat API returned empty response")
                    return None

        except aiohttp.ClientError as e:
            logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True)
            raise

    def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int:
        """
        Calculate safe max_tokens based on input token count and model context window.

        Args:
            prompt: Input prompt text
            requested_max_tokens: Desired max_tokens value

        Returns:
            int: Adjusted max_tokens that fits within context window
        """
        # Rough heuristic: ~4 characters per token.
        input_tokens = len(prompt) // 4

        context_window = self.settings.zerogpu_model_context_window

        logger.debug(
            f"Calculating safe max_tokens: input ~{input_tokens} tokens, "
            f"context_window={context_window}, requested={requested_max_tokens}"
        )

        # Reserve a 100-token safety margin for special tokens and overhead.
        available_tokens = context_window - input_tokens - 100

        safe_max_tokens = min(requested_max_tokens, available_tokens)

        # Never go below a 50-token floor, even for very long inputs.
        safe_max_tokens = max(50, safe_max_tokens)

        if safe_max_tokens < requested_max_tokens:
            logger.warning(
                f"Reduced max_tokens from {requested_max_tokens} to {safe_max_tokens} "
                f"(input: ~{input_tokens} tokens, context window: {context_window} tokens, "
                f"available: {available_tokens} tokens)"
            )

        return safe_max_tokens

    def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
        Format prompt for ZeroGPU Chat API.
        Can be customized based on model requirements.
        """
        formatted_prompt = prompt

        # Math queries get an explicit step-by-step directive.
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
            formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"

        return formatted_prompt

    def _is_math_query(self, prompt: str) -> bool:
        """Detect if query is mathematical"""
        math_keywords = [
            "solve", "calculate", "compute", "equation", "formula",
            "mathematical", "algebra", "geometry", "calculus", "integral",
            "derivative", "theorem", "proof", "problem"
        ]
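        # e.g. "Solve the equation x^2 = 4" -> True; "Tell me a joke" -> False.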
        prompt_lower = prompt.lower()
        return any(keyword in prompt_lower for keyword in math_keywords)

    def _clean_reasoning_tags(self, text: str) -> str:
        """Clean up reasoning tags from response if present"""
        if not text:
            return text

        # Strip <think>...</think> markers emitted by reasoning models.
        text = text.replace("<think>", "").replace("</think>", "")
        text = text.strip()
        return text

    def _select_model(self, task_type: str) -> dict:
        """Select model configuration based on task type"""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def get_available_models(self):
        """Get list of available models from ZeroGPU Chat API"""
        try:
            await self._ensure_authenticated()
            if not self.session:
                self.session = aiohttp.ClientSession()

            tasks_url = f"{self.base_url}/tasks"
            headers = {"Authorization": f"Bearer {self.access_token}"}

            async with self.session.get(tasks_url, headers=headers) as response:
                if response.status == 401:
                    await self._refresh_token()
                    headers["Authorization"] = f"Bearer {self.access_token}"
                    async with self.session.get(tasks_url, headers=headers) as retry_response:
                        retry_response.raise_for_status()
                        data = await retry_response.json()
                else:
                    response.raise_for_status()
                    data = await response.json()

            tasks = data.get("tasks", {})
            models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}"
                      for task, info in tasks.items()]
            return models if models else ["ZeroGPU Chat API"]
        except Exception as e:
            logger.error(f"Failed to get available models: {e}")
            return ["ZeroGPU Chat API"]

    async def health_check(self):
        """Perform health check on ZeroGPU Chat API"""
        try:
            if not self.session:
                self.session = aiohttp.ClientSession()

            health_url = f"{self.base_url}/health"
            async with self.session.get(health_url) as response:
                response.raise_for_status()
                data = await response.json()

            return {
                "provider": "zerogpu_chat_api",
                "status": "healthy" if data.get("status") == "healthy" else "unhealthy",
                "models_ready": data.get("models_ready", False),
                "base_url": self.base_url
            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
                "provider": "zerogpu_chat_api",
                "status": "unhealthy",
                "error": str(e)
            }

    async def __aenter__(self):
        """Async context manager entry"""
        if not self.session:
            self.session = aiohttp.ClientSession()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit"""
        if self.session:
            await self.session.close()
            self.session = None

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
                                user_input: Optional[str] = None) -> str:
        """
        Smart context windowing with user-input priority.
        Context is reduced to fit the budget; the user input itself is
        truncated only if it exceeds its own cap (user_input_max_tokens).

        Args:
            raw_context: Context dictionary
            max_tokens: Optional override (uses config default if None)
            user_input: Optional explicit user input (takes priority over raw_context['user_input'])
        """
        if max_tokens is None:
            max_tokens = self.settings.context_preparation_budget

        # Explicit user_input takes priority over the copy in raw_context.
        actual_user_input = user_input or raw_context.get('user_input', '')

        # Rough heuristic: ~4 characters per token.
        user_input_tokens = len(actual_user_input) // 4

        # Truncate the user input only if it exceeds its dedicated cap.
        user_input_max = self.settings.user_input_max_tokens
        if user_input_tokens > user_input_max:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds max ({user_input_max}), truncating")
            max_chars = user_input_max * 4
            actual_user_input = actual_user_input[:max_chars - 3] + "..."
            user_input_tokens = user_input_max

        remaining_tokens = max_tokens - user_input_tokens
        if remaining_tokens < 0:
            logger.warning(f"User input ({user_input_tokens} tokens) exceeds total budget ({max_tokens})")
            remaining_tokens = 0

        logger.info(f"Token allocation: User input={user_input_tokens}, Context budget={remaining_tokens}, Total={max_tokens}")

        # Context elements in descending priority order.
        priority_elements = [
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        formatted_context = []
        total_tokens = user_input_tokens

        # The user input always comes first.
        if actual_user_input:
            formatted_context.append(f"=== USER INPUT ===\n{actual_user_input}")

        for element, priority in priority_elements:
            element_key_map = {
                'recent_interactions': raw_context.get('interaction_contexts', []),
                'user_preferences': raw_context.get('preferences', {}),
                'session_summary': raw_context.get('session_context', {}),
                'historical_context': raw_context.get('user_context', '')
            }

            content = element_key_map.get(element, '')

            # Normalize dicts and lists to plain text.
            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5 and remaining_tokens > 0:
                # A high-priority element that does not fit whole is truncated
                # into the remaining budget; then stop adding elements.
                available = max_tokens - total_tokens
                if available > 100:
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    total_tokens += available
                break

        logger.info(f"Context prepared: {total_tokens}/{max_tokens} tokens (user input: {user_input_tokens}, context: {total_tokens - user_input_tokens})")
        return "\n\n".join(formatted_context)

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within token limit"""
        # Rough heuristic: ~4 characters per token.
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max_chars - 3] + "..."
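

# Illustrative usage sketch (an assumption, not part of the original module):
# requires ZEROGPU_BASE_URL, ZEROGPU_EMAIL, and ZEROGPU_PASSWORD to be set.
if __name__ == "__main__":
    async def _demo():
        async with LLMRouter() as router:
            # Check connectivity first, then run a small inference.
            print(await router.health_check())
            reply = await router.route_inference(
                "general_reasoning",
                "In one sentence, what does a token refresh flow do?",
                max_tokens=128,
            )
            print(reply)

    asyncio.run(_demo())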