# Research_AI_Assistant/src/llm_router.py
# Last commit f5d3311 by JatsTheAIGen:
#   "fix: Resolve database permission errors and OMP_NUM_THREADS warning"
# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING + ZEROGPU API
import logging
import asyncio
import os
from typing import Dict, Optional, List
from .models_config import LLM_CONFIG
# Module-level logger, named after this module via __name__.
logger: logging.Logger = logging.getLogger(__name__)
class LLMRouter:
    """Route LLM inference requests across ZeroGPU, local models and the HF API.

    Backends are tried in this order by ``route_inference``:

    1. The ZeroGPU API (service-account client or per-user client), when enabled.
    2. A local model, loaded lazily on first use through ``LocalModelLoader``.
    3. The Hugging Face Chat Completions API (final fallback).
    """

    # ZeroGPU endpoint used when neither the config dict nor the
    # ZERO_GPU_API_URL environment variable provides one.
    _DEFAULT_ZERO_GPU_URL = "https://bm9njt1ypzvuqw-8000.proxy.runpod.net"

    def __init__(self, hf_token, use_local_models: bool = True, zero_gpu_config: Optional[Dict] = None):
        """Create the router and initialise whichever backends are configured.

        Args:
            hf_token: Hugging Face token, sent as a Bearer token to the HF API.
            use_local_models: When True, create a ``LocalModelLoader`` for the
                local fallback (models themselves are only loaded on demand).
            zero_gpu_config: Optional dict. ``enabled`` turns ZeroGPU on and
                ``per_user_mode`` selects per-user accounts. Also read:
                ``base_url``; ``email``/``password`` (service account);
                ``admin_email``/``admin_password``/``db_path`` (per-user).
                Missing keys fall back to the corresponding environment variables.
        """
        self.hf_token = hf_token
        self.health_status = {}
        self.use_local_models = use_local_models
        self.local_loader = None
        self.zero_gpu_client = None  # Service account client (Option A)
        self.zero_gpu_user_manager = None  # Per-user manager (Option B)
        self.use_zero_gpu = False
        self.zero_gpu_mode = "service_account"  # "service_account" or "per_user"

        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

        if zero_gpu_config and zero_gpu_config.get("enabled", False):
            per_user_enabled = False
            if zero_gpu_config.get("per_user_mode", False):
                per_user_enabled = self._init_zero_gpu_per_user(zero_gpu_config)
            # Every per-user failure (missing module, missing admin credentials,
            # init error) falls back to the single shared service account.
            if not per_user_enabled:
                self._init_zero_gpu_service_account(zero_gpu_config)

        # Only the loader is created here; models load lazily as a fallback.
        if self.use_local_models:
            self._init_local_loader()

    def _init_zero_gpu_per_user(self, zero_gpu_config: Dict) -> bool:
        """Try to enable per-user ZeroGPU accounts (Option B, multi-tenant).

        Returns:
            True when per-user mode was enabled; False when the caller should
            fall back to service-account mode.
        """
        try:
            from zero_gpu_user_manager import ZeroGPUUserManager
        except ImportError:
            logger.warning("zero_gpu_user_manager not available, falling back to service account mode")
            return False

        base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", self._DEFAULT_ZERO_GPU_URL))
        admin_email = zero_gpu_config.get("admin_email", os.getenv("ZERO_GPU_ADMIN_EMAIL", ""))
        admin_password = zero_gpu_config.get("admin_password", os.getenv("ZERO_GPU_ADMIN_PASSWORD", ""))
        db_path = zero_gpu_config.get("db_path", os.getenv("DB_PATH", "/tmp/sessions.db"))

        if not (admin_email and admin_password):
            logger.warning(
                "ZeroGPU per-user mode enabled but admin credentials not provided, "
                "falling back to service account mode"
            )
            return False

        try:
            self.zero_gpu_user_manager = ZeroGPUUserManager(base_url, admin_email, admin_password, db_path)
        except Exception as e:
            logger.warning(f"Could not initialize ZeroGPU user manager: {e}. Falling back to service account mode.")
            return False

        self.use_zero_gpu = True
        self.zero_gpu_mode = "per_user"
        logger.info("✓ ZeroGPU per-user mode enabled (multi-tenant)")
        return True

    def _init_zero_gpu_service_account(self, zero_gpu_config: Dict) -> None:
        """Enable the shared ZeroGPU service-account client (Option A, single-tenant).

        Leaves ``use_zero_gpu`` False if the client module, the credentials or
        the readiness check are unavailable.
        """
        try:
            from zero_gpu_client import ZeroGPUChatClient

            base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", self._DEFAULT_ZERO_GPU_URL))
            email = zero_gpu_config.get("email", os.getenv("ZERO_GPU_EMAIL", ""))
            password = zero_gpu_config.get("password", os.getenv("ZERO_GPU_PASSWORD", ""))

            if not (email and password):
                logger.warning("ZeroGPU enabled but credentials not provided")
                return

            self.zero_gpu_client = ZeroGPUChatClient(base_url, email, password)
            self.use_zero_gpu = True
            self.zero_gpu_mode = "service_account"
            logger.info("✓ ZeroGPU API client initialized (service account mode)")

            # Give the API up to 10 s to become ready; otherwise disable it
            # so requests go straight to the fallbacks.
            try:
                if not self.zero_gpu_client.wait_for_ready(timeout=10):
                    logger.warning("ZeroGPU API not ready, will use HF fallback")
                    self.use_zero_gpu = False
            except Exception as e:
                logger.warning(f"Could not verify ZeroGPU API readiness: {e}. Will use HF fallback.")
                self.use_zero_gpu = False
        except ImportError:
            logger.warning("zero_gpu_client not available, ZeroGPU disabled")
        except Exception as e:
            logger.warning(f"Could not initialize ZeroGPU client: {e}. Falling back to HF API.")
            self.use_zero_gpu = False

    def _init_local_loader(self) -> None:
        """Create the lazy local model loader; disable the local fallback on failure."""
        try:
            from .local_model_loader import LocalModelLoader

            # Initialize the loader only; no model is loaded here.
            self.local_loader = LocalModelLoader()
            logger.info("✓ Local model loader initialized (models will load on-demand as fallback)")
            logger.info("Models will only load if ZeroGPU API fails")
        except Exception as e:
            logger.warning(f"Could not initialize local model loader: {e}. Local fallback unavailable.")
            logger.warning("This is normal if transformers/torch not available")
            self.use_local_models = False
            self.local_loader = None

    @staticmethod
    def _log_banner(title: str, header_lines: List[str], content_label: str, content, end_label: str) -> None:
        """Log ``content`` inside the framed banner used for full prompt/response dumps.

        NOTE(review): this logs complete prompts/responses at INFO level, which
        may include user data.
        """
        logger.info("=" * 80)
        logger.info(title)
        logger.info("=" * 80)
        for line in header_lines:
            logger.info(line)
        logger.info("-" * 40)
        logger.info(content_label)
        logger.info("-" * 40)
        logger.info(content)
        logger.info("-" * 40)
        logger.info(end_label)
        logger.info("=" * 80)

    async def route_inference(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs):
        """Run inference for ``task_type``, falling back across backends.

        Tries the ZeroGPU API first, then a lazily loaded local model, then the
        HF Inference API. A backend "fails" when it raises or returns None.

        Args:
            task_type: Task type (e.g. "intent_classification", "general_reasoning").
            prompt: User prompt/message.
            context: Optional conversation context (used by the ZeroGPU path only).
            user_id: Optional user ID for per-user ZeroGPU accounts (Option B).
            **kwargs: Generation/retry parameters forwarded to each backend.

        Returns:
            The first non-None backend result. For the HF path this may be None.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # 1. ZeroGPU API (primary path).
        if self.use_zero_gpu:
            try:
                result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs)
                if result is not None:
                    logger.info(f"Inference complete for {task_type} (ZeroGPU API)")
                    return result
                logger.warning("ZeroGPU API returned None, falling back to local models")
            except Exception as e:
                logger.warning(f"ZeroGPU API inference failed: {e}. Falling back to local models.")
                logger.debug("Exception details:", exc_info=True)

        # 2. Local models (loaded lazily only when this path is reached).
        if self.use_local_models and self.local_loader:
            try:
                logger.info("ZeroGPU API unavailable, loading local model as fallback...")
                if task_type == "embedding_generation":
                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
                else:
                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)
                if result is not None:
                    logger.info(f"Inference complete for {task_type} (local model fallback)")
                    return result
                logger.warning("Local model returned None, falling back to HF API")
            except Exception as e:
                logger.warning(f"Local model inference failed: {e}. Falling back to HF API.")
                logger.debug("Exception details:", exc_info=True)

        # 3. HF Inference API (final fallback).
        # NOTE(review): for "embedding_generation" this sends the embedding model
        # to the chat-completions endpoint.
        logger.info("Using HF Inference API as final fallback")
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning("Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Generate a chat completion with a local model, loading it on first use.

        Recognised kwargs: ``max_tokens`` (default 512), ``temperature`` (0.7).

        Returns:
            The generated text, or None when no loader exists or generation fails.
        """
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)

        try:
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Lazy loading local model {model_id} as fallback (ZeroGPU unavailable)")
                # NOTE(review): this load runs synchronously on the event loop.
                self.local_loader.load_chat_model(model_id, load_in_8bit=False)

            messages = [{"role": "user", "content": prompt}]
            # Generation is offloaded to a worker thread so the loop stays free.
            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature,
            )

            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            self._log_banner(
                "LOCAL MODEL RESPONSE:",
                [
                    f"Model: {model_id}",
                    f"Task Type: {task_type}",
                    f"Response Length: {len(result)} characters",
                ],
                "FULL RESPONSE CONTENT:",
                result,
                "END OF RESPONSE",
            )
            return result
        except Exception as e:
            logger.error(f"Error calling local model: {e}", exc_info=True)
            return None

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Embed ``text`` with a local embedding model, loading it on first use.

        Returns:
            The embedding returned by the loader, or None on failure / no loader.
        """
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]
        try:
            if model_id not in self.local_loader.loaded_embedding_models:
                logger.info(f"Lazy loading local embedding model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_embedding_model(model_id)

            embedding = await asyncio.to_thread(
                self.local_loader.get_embedding,
                model_id=model_id,
                text=text,
            )
            logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
            return embedding
        except Exception as e:
            logger.error(f"Error calling local embedding model: {e}", exc_info=True)
            return None

    @staticmethod
    def _zero_gpu_metric_lines(response: dict) -> List[str]:
        """Format the optional token/latency metrics of a ZeroGPU response.

        Malformed metrics are skipped (logged at DEBUG) instead of raising, so a
        bad metrics payload never causes a valid response to be discarded.
        """
        lines = []
        try:
            if "tokens_used" in response:
                tokens = response["tokens_used"]
                lines.append(
                    f"Tokens: input={tokens.get('input', 0)}, output={tokens.get('output', 0)}, total={tokens.get('total', 0)}"
                )
            if "inference_metrics" in response:
                metrics = response["inference_metrics"]
                lines.append(f"Inference Duration: {metrics.get('inference_duration', 0):.2f}s")
                lines.append(f"Tokens/Second: {metrics.get('tokens_per_second', 0):.2f}")
        except (AttributeError, TypeError, ValueError) as e:
            logger.debug(f"Could not format ZeroGPU metrics: {e}")
        return lines

    async def _call_zero_gpu_endpoint(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs) -> Optional[str]:
        """Call the ZeroGPU chat API.

        Args:
            task_type: Task type, mapped via ``LLM_CONFIG["zero_gpu_task_mapping"]``
                (unmapped tasks use "general").
            prompt: User prompt/message.
            context: Optional conversation context; only the last 50 messages are sent.
            user_id: Optional user ID for per-user accounts (Option B).
            **kwargs: ``max_tokens`` (512), ``temperature`` (0.7) and, when
                present, ``top_p`` and ``system_prompt``.

        Returns:
            The generated text, or None when no client is available, the reply is
            empty or malformed, or the call raises.
        """
        # Pick the client: the per-user client when available, else the service account.
        if self.zero_gpu_mode == "per_user" and self.zero_gpu_user_manager and user_id:
            client = await self.zero_gpu_user_manager.get_or_create_user_client(user_id)
            if not client:
                logger.warning(f"Could not get ZeroGPU client for user {user_id}, falling back to service account")
                client = self.zero_gpu_client
        else:
            client = self.zero_gpu_client

        if not client:
            return None

        try:
            task_mapping = LLM_CONFIG.get("zero_gpu_task_mapping", {})
            zero_gpu_task = task_mapping.get(task_type, "general")

            logger.info(f"Calling ZeroGPU API for task: {task_type} -> {zero_gpu_task}")
            logger.debug(f"Prompt length: {len(prompt)}")
            self._log_banner(
                "ZEROGPU API REQUEST:",
                [
                    f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}",
                    f"Prompt Length: {len(prompt)} characters",
                ],
                "FULL PROMPT CONTENT:",
                prompt,
                "END OF PROMPT",
            )

            context_messages = None
            if context:
                # Limit to the 50 most recent messages as per API.
                context_messages = [
                    {
                        "role": msg.get("role", "user"),
                        "content": msg.get("content", ""),
                        "timestamp": msg.get("timestamp", ""),
                    }
                    for msg in context[-50:]
                ]

            generation_params = {
                "max_tokens": kwargs.get('max_tokens', 512),
                "temperature": kwargs.get('temperature', 0.7),
            }
            for optional_param in ("top_p", "system_prompt"):
                if optional_param in kwargs:
                    generation_params[optional_param] = kwargs[optional_param]

            # NOTE(review): client.chat is called synchronously and blocks the event
            # loop; consider asyncio.to_thread once the client is confirmed thread-safe.
            response = client.chat(
                message=prompt,
                task=zero_gpu_task,
                context=context_messages,
                **generation_params,
            )

            if not (response and "response" in response):
                logger.error(f"Unexpected ZeroGPU response format: {response}")
                return None

            generated_text = response["response"]
            if not generated_text or generated_text.strip() == "":
                logger.warning("ZeroGPU API returned empty response")
                return None

            logger.info(f"ZeroGPU API returned response (length: {len(generated_text)})")
            header_lines = [
                f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}",
                f"Response Length: {len(generated_text)} characters",
            ]
            header_lines.extend(self._zero_gpu_metric_lines(response))
            self._log_banner(
                "COMPLETE ZEROGPU API RESPONSE:",
                header_lines,
                "FULL RESPONSE CONTENT:",
                generated_text,
                "END OF RESPONSE",
            )
            return generated_text
        except Exception as e:
            logger.error(f"Error calling ZeroGPU API: {e}", exc_info=True)
            return None

    def _select_model(self, task_type: str) -> dict:
        """Return the primary model config for ``task_type`` (default: reasoning_primary)."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """Return the cached health flag for ``model_id``.

        Models are marked healthy the first time they are seen; real
        availability is only discovered when the API is called.
        """
        if model_id in self.health_status:
            return self.health_status[model_id]
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """Return the fallback model config for ``task_type`` (default: reasoning_primary)."""
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _try_hf_fallback_model(self, model_config: dict, prompt: str, task_type: str, **kwargs):
        """Call the HF API with the task's fallback model, unless it is the model that just failed.

        ``_get_fallback_model`` depends only on ``task_type``, so this makes at
        most one extra hop. The same-model check stops the retry cycle from
        recursing forever when a task's fallback is its own primary model.
        """
        fallback_config = self._get_fallback_model(task_type)
        if fallback_config["model_id"] == model_config["model_id"]:
            logger.error(
                f"No alternative fallback model for task {task_type} "
                f"(fallback is {model_config['model_id']}); giving up"
            )
            return None
        return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)

    def _parse_hf_response(self, result, model_id: str, task_type: str, attempt: int) -> Optional[str]:
        """Extract the generated text from a decoded chat-completions response body."""
        logger.debug(f"Raw response: {result}")
        choices = result.get('choices') if isinstance(result, dict) else None
        if not choices:
            logger.error(f"Unexpected response format: {result}")
            return None

        generated_text = choices[0]['message']['content']
        if not generated_text or generated_text.strip() == "":
            logger.warning("Empty or invalid response, using fallback")
            return None

        if attempt > 0:
            logger.info(f"Successfully retrieved response after {attempt} retry attempts")
        logger.info(f"HF API returned response (length: {len(generated_text)})")
        self._log_banner(
            "COMPLETE LLM API RESPONSE:",
            [
                f"Model: {model_id}",
                f"Task Type: {task_type}",
                f"Response Length: {len(generated_text)} characters",
            ],
            "FULL RESPONSE CONTENT:",
            generated_text,
            "END OF LLM RESPONSE",
        )
        return generated_text

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
        """Call the Hugging Face Chat Completions API with retries and backoff.

        Retryable failures (HTTP 503 "model loading", timeouts, connection
        errors) are retried with exponential backoff. When retries run out,
        the task's fallback model is tried once, unless it is the same model.

        Recognised kwargs: ``max_tokens`` (512), ``temperature`` (0.7),
        ``max_retries`` (3), ``initial_delay`` (1.0 s), ``max_delay`` (16.0 s),
        ``timeout`` (30 s, per request).

        Returns:
            The generated text. Returns None on empty/unexpected responses,
            non-retryable HTTP errors, or exhausted retries. Returns a
            "[Mock]" string when ``requests`` is not installed.
        """
        max_retries = kwargs.get('max_retries', 3)
        initial_delay = kwargs.get('initial_delay', 1.0)
        max_delay = kwargs.get('max_delay', 16.0)
        timeout = kwargs.get('timeout', 30)

        try:
            import requests
            # Timeout and ConnectionError are both subclasses of RequestException.
            from requests.exceptions import Timeout, RequestException
        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."

        try:
            model_id = model_config["model_id"]
            api_url = "https://router.huggingface.co/v1/chat/completions"

            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            self._log_banner(
                "LLM API REQUEST - COMPLETE PROMPT:",
                [
                    f"Model: {model_id}",
                    f"Task Type: {task_type}",
                    f"Prompt Length: {len(prompt)} characters",
                ],
                "FULL PROMPT CONTENT:",
                prompt,
                "END OF PROMPT",
            )

            payload = {
                "model": model_id,
                "messages": [{"role": "user", "content": prompt}],
                "max_tokens": kwargs.get('max_tokens', 512),
                "temperature": kwargs.get('temperature', 0.7),
                "stream": False,
            }
            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }

            last_exception = None
            for attempt in range(max_retries + 1):
                is_last_attempt = attempt >= max_retries
                try:
                    if attempt > 0:
                        delay = min(initial_delay * (2 ** (attempt - 1)), max_delay)
                        logger.warning(f"Retry attempt {attempt}/{max_retries} after {delay:.1f}s delay (exponential backoff)")
                        await asyncio.sleep(delay)

                    logger.info(f"Sending request to: {api_url} (attempt {attempt + 1}/{max_retries + 1})")
                    logger.debug(f"Payload: {payload}")
                    # requests is blocking; run it in a worker thread so the event loop stays responsive.
                    response = await asyncio.to_thread(
                        requests.post, api_url, json=payload, headers=headers, timeout=timeout
                    )

                    if response.status_code == 200:
                        return self._parse_hf_response(response.json(), model_id, task_type, attempt)

                    if response.status_code == 503:
                        # Model is still loading: retryable.
                        if not is_last_attempt:
                            logger.warning(f"Model loading (503), will retry (attempt {attempt + 1}/{max_retries + 1})")
                            last_exception = Exception("Model loading (503)")
                            continue
                        logger.warning(f"Model loading (503) after {max_retries} retries, trying fallback model")
                        return await self._try_hf_fallback_model(model_config, prompt, task_type, **kwargs)

                    # Any other status is treated as non-retryable.
                    logger.error(f"HF API error: {response.status_code} - {response.text}")
                    return None

                except Timeout as e:
                    last_exception = e
                    if not is_last_attempt:
                        logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    logger.error(f"Request timeout after {max_retries} retries: {str(e)}")
                    logger.warning("Attempting fallback model due to persistent timeout")
                    return await self._try_hf_fallback_model(model_config, prompt, task_type, **kwargs)

                except RequestException as e:
                    last_exception = e
                    if not is_last_attempt:
                        logger.warning(f"Connection error (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    logger.error(f"Connection error after {max_retries} retries: {str(e)}")
                    logger.warning("Attempting fallback model due to persistent connection error")
                    return await self._try_hf_fallback_model(model_config, prompt, task_type, **kwargs)

            # Safety net: the final attempt always returns above, so this is
            # only reached when max_retries is negative (no attempts at all).
            if last_exception:
                logger.error(f"Failed after {max_retries} retries. Last error: {last_exception}")
            return None
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None

    async def get_available_models(self):
        """Return the model keys configured in ``LLM_CONFIG["models"]``."""
        return list(LLM_CONFIG["models"].keys())

    async def health_check(self):
        """Return ``{model_name: {"model_id": ..., "healthy": bool}}`` for every configured model."""
        health_status = {}
        for model_name, model_config in LLM_CONFIG["models"].items():
            model_id = model_config["model_id"]
            is_healthy = await self._is_model_healthy(model_id)
            health_status[model_name] = {
                "model_id": model_id,
                "healthy": is_healthy,
            }
        return health_status

    @staticmethod
    def _load_tokenizer():
        """Load the Qwen tokenizer used for token counting, or return None if unavailable."""
        try:
            from transformers import AutoTokenizer
        except ImportError:
            logger.warning("transformers library not available, using character count estimation")
            return None
        try:
            return AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
        except Exception as e:
            logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
            return None

    @staticmethod
    def _stringify_context_value(value) -> str:
        """Render one raw context value as text; empty/falsy values become ""."""
        # Check emptiness before stringifying: str({}) == "{}" is truthy and
        # would otherwise produce empty sections.
        if not value:
            return ""
        if isinstance(value, list):
            return "\n".join(str(item) for item in value[:10])  # Limit to 10 items
        if isinstance(value, str):
            return value
        return str(value)

    def _estimate_tokens(self, text: str) -> int:
        """Count tokens with the tokenizer if present, else estimate as len(text) // 4."""
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is not None:
            try:
                return len(tokenizer.encode(text))
            except Exception:
                # Tokenizer errors are non-fatal; use the character heuristic below.
                pass
        # Rough heuristic: 1 token ≈ 4 characters.
        return len(text) // 4

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str:
        """Build a token-budgeted context string from ``raw_context``.

        Sections are considered in priority order and rendered as
        ``"=== NAME ===\\n<content>"``, joined by blank lines. The sections and
        their source keys are: current_query (user_input), recent_interactions
        (interaction_contexts, first 10 items), user_preferences (preferences),
        session_summary (session_context), historical_context (user_context).

        Empty values are skipped. A section that does not fit the remaining
        budget is skipped. The exception is a high-priority section
        (priority > 0.5): it is truncated into the remaining budget when more
        than 100 tokens remain, and assembly stops after it.

        Args:
            raw_context: Context dict; None is treated as empty.
            max_tokens: Token budget for the whole result.

        Returns:
            The assembled context string ("" when nothing is included).
        """
        # Tokenizer is loaded once per instance; None means "estimate by characters".
        if not hasattr(self, "tokenizer"):
            self.tokenizer = self._load_tokenizer()

        raw_context = raw_context or {}
        # (section name, raw_context key, priority)
        sections = (
            ("current_query", "user_input", 1.0),
            ("recent_interactions", "interaction_contexts", 0.8),
            ("user_preferences", "preferences", 0.6),
            ("session_summary", "session_context", 0.4),
            ("historical_context", "user_context", 0.2),
        )

        formatted_context = []
        total_tokens = 0
        for element, source_key, priority in sections:
            content = self._stringify_context_value(raw_context.get(source_key))
            if not content:
                continue

            tokens = self._estimate_tokens(content)
            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5:  # Critical elements - truncate if needed
                available = max_tokens - total_tokens
                if available > 100:  # Only truncate if we have meaningful space
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                    break

        return "\n\n".join(formatted_context)

    @staticmethod
    def _truncate_to_chars(content: str, max_tokens: int) -> str:
        """Truncate to roughly ``max_tokens`` tokens using 4 characters per token."""
        max_chars = max_tokens * 4
        if len(content) <= max_chars:
            return content
        return content[:max(max_chars - 3, 0)] + "..."

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate ``content`` to fit within ``max_tokens``, appending "..." when cut.

        Uses the tokenizer when one is loaded and falls back to character-based
        truncation otherwise (or if the tokenizer raises).
        """
        tokenizer = getattr(self, "tokenizer", None)
        if tokenizer is None:
            return self._truncate_to_chars(content, max_tokens)

        try:
            tokens = tokenizer.encode(content)
            if len(tokens) <= max_tokens:
                return content
            # Leave room for the "..." suffix.
            truncated_tokens = tokens[:max(max_tokens - 3, 0)]
            return tokenizer.decode(truncated_tokens) + "..."
        except Exception as e:
            logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
            return self._truncate_to_chars(content, max_tokens)