# Research_AI_Assistant / llm_router.py
# Last commit: a814110 by JatsTheAIGen — "workflow errors debugging v10"
# llm_router.py
import asyncio
import functools
import logging
from typing import Optional

from models_config import LLM_CONFIG
# Module logger, named after this module so handlers/levels can target it.
logger = logging.getLogger(name=__name__)
class LLMRouter:
    """Route inference requests to task-specific models via the HF router.

    Model configurations are dicts taken from ``LLM_CONFIG["models"]`` and must
    contain at least a ``"model_id"`` key. Requests are sent to the
    OpenAI-compatible chat-completions endpoint on ``router.huggingface.co``.
    """

    # Task type -> key in LLM_CONFIG["models"] for the preferred model.
    _PRIMARY_MODEL_KEYS = {
        "intent_classification": "classification_specialist",
        "embedding_generation": "embedding_specialist",
        "safety_check": "safety_checker",
        "general_reasoning": "reasoning_primary",
        "response_synthesis": "reasoning_primary",
    }
    # Task type -> key in LLM_CONFIG["models"] used when the preferred model
    # is considered unhealthy.
    _FALLBACK_MODEL_KEYS = {
        "intent_classification": "reasoning_primary",
        "embedding_generation": "embedding_specialist",
        "safety_check": "reasoning_primary",
        "general_reasoning": "reasoning_primary",
        "response_synthesis": "reasoning_primary",
    }
    # Key used for any task type not listed in the maps above.
    _DEFAULT_MODEL_KEY = "reasoning_primary"

    _CHAT_COMPLETIONS_URL = "https://router.huggingface.co/v1/chat/completions"
    _REQUEST_TIMEOUT_S = 60
    _USER_QUESTION_MARKER = "User Question:"

    def __init__(self, hf_token):
        """Create a router.

        Args:
            hf_token: Hugging Face API token; may be falsy, in which case
                requests are sent without an Authorization header.
        """
        self.hf_token = hf_token
        # model_id -> bool, filled lazily by _is_model_healthy().
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """Pick a model for ``task_type`` and run ``prompt`` through it.

        Args:
            task_type: Task name, e.g. ``"intent_classification"``; unknown
                names route to the primary reasoning model.
            prompt: Prompt text (see ``_extract_user_message`` for how
                structured prompts are reduced).
            **kwargs: Optional ``max_tokens``, ``temperature``, ``top_p``.

        Returns:
            The generated text, ``None`` on any API/parsing failure, or a
            ``"[Mock] ..."`` string when ``requests`` is not installed.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Health check and fallback logic.
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning("Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    def _select_model(self, task_type: str) -> dict:
        """Return the preferred model config for ``task_type``."""
        key = self._PRIMARY_MODEL_KEYS.get(task_type, self._DEFAULT_MODEL_KEY)
        return LLM_CONFIG["models"][key]

    async def _is_model_healthy(self, model_id: str) -> bool:
        """Return the cached health flag for ``model_id``.

        Unknown models are recorded as healthy; real availability is only
        discovered at API-call time. Nothing in this class ever stores
        ``False``, so the fallback branch in ``route_inference`` only fires if
        ``health_status`` is modified from outside.
        """
        if model_id in self.health_status:
            return self.health_status[model_id]
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """Return the fallback model config for ``task_type``."""
        key = self._FALLBACK_MODEL_KEYS.get(task_type, self._DEFAULT_MODEL_KEY)
        return LLM_CONFIG["models"][key]

    @staticmethod
    def _extract_user_message(prompt: str) -> str:
        """Reduce a structured prompt to its user question.

        If ``prompt`` contains ``"User Question:"``, returns the first
        non-blank text after the marker, up to the end of that line (so a
        question placed on the line after the marker is still found). If
        nothing follows the marker, or there is no marker, the whole prompt
        is returned so an empty message is never sent.

        Note that any context preceding the marker is intentionally dropped.
        """
        marker = LLMRouter._USER_QUESTION_MARKER
        if marker not in prompt:
            return prompt
        remainder = prompt.partition(marker)[2].lstrip()
        question = remainder.split("\n", 1)[0].strip()
        return question or prompt

    async def _post_chat_completion(self, requests_module, model_config: dict,
                                    prompt: str, options: dict):
        """POST one chat-completions request and return the raw response.

        The blocking ``requests.post`` runs in the loop's default executor
        so it does not stall the event loop for up to the request timeout.
        """
        model_id = model_config["model_id"]
        logger.info(f"Calling HF Chat Completions API for model: {model_id}")
        logger.debug(f"Prompt length: {len(prompt)}")

        headers = {"Content-Type": "application/json"}
        if self.hf_token:
            # Avoid sending the literal "Bearer None" when no token is set.
            headers["Authorization"] = f"Bearer {self.hf_token}"

        payload = {
            # ":together" suffix routes to the Together provider, as specified.
            "model": f"{model_id}:together",
            "messages": [
                {"role": "user", "content": self._extract_user_message(prompt)},
            ],
            "max_tokens": options.get("max_tokens", 2000),
            "temperature": options.get("temperature", 0.7),
            "top_p": options.get("top_p", 0.95),
        }

        post = functools.partial(
            requests_module.post,
            self._CHAT_COMPLETIONS_URL,
            json=payload,
            headers=headers,
            timeout=self._REQUEST_TIMEOUT_S,
        )
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, post)

    @staticmethod
    def _parse_chat_response(response) -> Optional[str]:
        """Extract the first choice's message content from a response.

        Returns ``None`` (after logging) for non-200 status codes, a body
        without ``choices``, or empty / non-string content.
        """
        if response.status_code != 200:
            logger.error(f"HF API error: {response.status_code} - {response.text}")
            return None

        result = response.json()
        choices = result.get("choices") if isinstance(result, dict) else None
        if not choices:
            logger.error(f"Unexpected response format: {result}")
            return None

        first = choices[0] if isinstance(choices[0], dict) else {}
        # "message" may be missing or explicitly null.
        message = first.get("message") or {}
        generated_text = message.get("content", "") if isinstance(message, dict) else ""
        if not generated_text or not isinstance(generated_text, str):
            logger.warning("Empty or invalid response, using fallback")
            return None

        logger.info(f"HF API returned response (length: {len(generated_text)})")
        return generated_text

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """Call the HF Chat Completions API for ``model_config``.

        On HTTP 503 (model loading) the request is retried exactly once with
        the ``response_synthesis`` fallback model. The retry is bounded: a
        second 503 is treated as an ordinary API error rather than recursing
        again, which previously looped when the fallback was the same model.

        Returns:
            Generated text; ``None`` on HTTP error, malformed/empty response,
            or any exception; a ``"[Mock] ..."`` string if ``requests`` is
            not installed.
        """
        try:
            import requests
        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."

        try:
            response = await self._post_chat_completion(
                requests, model_config, prompt, kwargs)
            if response.status_code == 503:
                # Model is loading; retry once with the fallback model.
                logger.warning("Model loading (503), trying fallback")
                fallback_config = self._get_fallback_model("response_synthesis")
                response = await self._post_chat_completion(
                    requests, fallback_config, prompt, kwargs)
            return self._parse_chat_response(response)
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None