# llm_router.py
from models_config import LLM_CONFIG


class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.
        """
        model_config = self._select_model(task_type)

        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            model_config = self._get_fallback_model(task_type)

        return await self._call_hf_endpoint(model_config, prompt, **kwargs)

    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check if the model is healthy and available.
        """
        # Check cached health status
        if model_id in self.health_status:
            return self.health_status[model_id]

        # Default to healthy for now (actual health checks can be added later)
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get the fallback model configuration for the task type.
        """
        # Fallback mapping; embedding_generation keeps the specialist because
        # no alternative embedding model is configured.
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make the actual call to the Hugging Face Inference API.
        """
        try:
            import requests

            model_id = model_config["model_id"]
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"
            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }

            # Prepare payload
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": kwargs.get("max_tokens", 250),
                    "temperature": kwargs.get("temperature", 0.7),
                    "top_p": kwargs.get("top_p", 0.95),
                    "return_full_text": False,
                },
            }

            # Make the API call (note: requests is synchronous, so this blocks
            # the event loop for the duration of the request)
            response = requests.post(api_url, json=payload, headers=headers, timeout=30)

            if response.status_code == 200:
                result = response.json()
                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "")
                else:
                    generated_text = str(result)
                return generated_text
            else:
                print(f"HF API error: {response.status_code} - {response.text}")
                return None

        except ImportError:
            print("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            print(f"Error calling HF endpoint: {e}")
            return None
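

# A minimal usage sketch, assuming each LLM_CONFIG["models"][...] entry exposes a
# "model_id" key (as the router expects) and that a Hugging Face token is available
# in an HF_TOKEN environment variable; both are assumptions for illustration only.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo():
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN", ""))
        reply = await router.route_inference(
            "general_reasoning",
            "Summarize why task-specialized model routing is useful.",
            max_tokens=150,
        )
        print(reply)

    asyncio.run(_demo())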