# llm_router.py
import logging

from models_config import LLM_CONFIG

logger = logging.getLogger(__name__)


class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model {model_config['model_id']} unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check if the model is healthy and available.
        """
        # Check cached health status
        if model_id in self.health_status:
            return self.health_status[model_id]

        # Default to healthy for now (can implement actual health checks)
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get the fallback model configuration for the task type.
        """
        # Fallback mapping
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make the actual call to the Hugging Face Inference API.
        """
        try:
            import requests

            model_id = model_config["model_id"]
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"

            logger.info(f"Calling HF API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")

            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }

            # Prepare payload
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": kwargs.get("max_tokens", 250),
                    "temperature": kwargs.get("temperature", 0.7),
                    "top_p": kwargs.get("top_p", 0.95),
                    "return_full_text": False,
                },
            }

            # Make the API call (note: requests is synchronous, so this blocks
            # the event loop for the duration of the request)
            response = requests.post(api_url, json=payload, headers=headers, timeout=30)

            if response.status_code == 200:
                result = response.json()

                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "")
                else:
                    generated_text = str(result)

                logger.info(f"HF API returned response (length: {len(generated_text)})")
                return generated_text
            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None

        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None
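

# --- Usage sketch (illustrative, not part of the original module) ---
# Assumptions: models_config.LLM_CONFIG follows the shape referenced above
# (LLM_CONFIG["models"][<role>] is a dict with at least a "model_id" key), and
# the token is exported as the HF_TOKEN environment variable; the variable
# name and the demo prompt are hypothetical.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo() -> None:
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN"))
        # Route a general-reasoning task; kwargs map to the generation
        # parameters consumed in _call_hf_endpoint.
        answer = await router.route_inference(
            "general_reasoning",
            "Summarize the trade-offs of routing tasks to specialized models.",
            max_tokens=128,
            temperature=0.5,
        )
        print(answer)

    asyncio.run(_demo())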