import logging

from models_config import LLM_CONFIG
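
# For reference, LLM_CONFIG is expected to expose a "models" mapping whose
# entries each carry at least a "model_id". A minimal sketch of the assumed
# shape (the keys below are the ones this router reads; the model IDs are
# placeholders, not taken from models_config):
#
#     LLM_CONFIG = {
#         "models": {
#             "reasoning_primary": {"model_id": "org/primary-reasoning-model"},
#             "classification_specialist": {"model_id": "org/classifier-model"},
#             "embedding_specialist": {"model_id": "org/embedding-model"},
#             "safety_checker": {"model_id": "org/safety-model"},
#         }
#     }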

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token: str):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """Smart routing based on task specialization."""
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Swap to the task-specific fallback if the selected model is marked unhealthy.
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model {model_config['model_id']} unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    def _select_model(self, task_type: str) -> dict:
        """Map a task type to its specialist model, defaulting to the primary reasoning model."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check whether the model is healthy and available.

        Models are marked healthy by default; actual availability is checked
        at API call time.
        """
        if model_id in self.health_status:
            return self.health_status[model_id]

        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """Get the fallback model configuration for the task type."""
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make the actual call to the Hugging Face Chat Completions API,
        using the chat completions protocol.
        """
        try:
            # Imported lazily so the module still loads (with mock responses)
            # in environments where requests is unavailable.
            import requests

            model_id = model_config["model_id"]
            api_url = "https://router.huggingface.co/v1/chat/completions"

            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")

            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }

            # If the prompt embeds a "User Question:" section, send only that
            # question; otherwise send the prompt as-is.
            if "User Question:" in prompt:
                user_message = prompt.split("User Question:")[1].split("\n")[0].strip()
            else:
                user_message = prompt

            payload = {
                "model": f"{model_id}:together",
                "messages": [
                    {
                        "role": "user",
                        "content": user_message,
                    }
                ],
                "max_tokens": kwargs.get("max_tokens", 2000),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.95),
            }

            # Note: requests is synchronous, so this call blocks the event loop
            # for up to the 60-second timeout.
            response = requests.post(api_url, json=payload, headers=headers, timeout=60)

            if response.status_code == 200:
                result = response.json()

                if "choices" in result and len(result["choices"]) > 0:
                    message = result["choices"][0].get("message", {})
                    generated_text = message.get("content", "")

                    if not generated_text or not isinstance(generated_text, str):
                        logger.warning("Empty or invalid response content, returning None")
                        return None

                    logger.info(f"HF API returned response (length: {len(generated_text)})")
                    return generated_text
                else:
                    logger.error(f"Unexpected response format: {result}")
                    return None
            elif response.status_code == 503:
                # Model is still loading; retry with the fallback model, but avoid
                # recursing indefinitely when the fallback is the same model.
                logger.warning("Model loading (503), trying fallback")
                fallback_config = self._get_fallback_model("response_synthesis")
                if fallback_config["model_id"] == model_id:
                    logger.error("Fallback model matches the failing model, giving up")
                    return None
                return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None

        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None
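

# Minimal usage sketch, kept behind a __main__ guard so importing this module is
# unaffected. Assumptions (not defined in this file): an HF_TOKEN environment
# variable holding the Hugging Face token, and a models_config module providing
# the LLM_CONFIG entries referenced above.
if __name__ == "__main__":
    import asyncio
    import os

    logging.basicConfig(level=logging.INFO)

    async def _demo():
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN", ""))
        answer = await router.route_inference(
            "general_reasoning",
            "User Question: What does this router do?\n",
        )
        print(answer)

    asyncio.run(_demo())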