# llm_router.py
import logging
from models_config import LLM_CONFIG
logger = logging.getLogger(__name__)
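
# Note: the router below assumes models_config.LLM_CONFIG exposes (at least) the
# model roles referenced in this file, each mapping to a dict with a "model_id" key.
# The sketch below is illustrative only; the actual IDs live in models_config.py:
#
#   LLM_CONFIG = {
#       "models": {
#           "reasoning_primary": {"model_id": "..."},
#           "classification_specialist": {"model_id": "..."},
#           "embedding_specialist": {"model_id": "..."},
#           "safety_checker": {"model_id": "..."},
#       }
#   }
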
class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization.
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")
        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model {model_config['model_id']} unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")
        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check whether the model is healthy and available.
        Models are marked healthy by default; actual availability is checked at API call time.
        """
        # Return the cached health status if we have one
        if model_id in self.health_status:
            return self.health_status[model_id]
        # All models are marked healthy initially; the real check happens during the API call
        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get the fallback model configuration for the task type.
        """
        # Fallback mapping
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"],
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make the actual call to the Hugging Face Chat Completions API,
        using the Chat Completions protocol.
        """
        try:
            import requests
            model_id = model_config["model_id"]
            # Use the chat completions endpoint
            api_url = "https://router.huggingface.co/v1/chat/completions"
            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json",
            }
            # Prepare the payload in chat completions format.
            # Extract the actual question from the prompt if it is in a structured format.
            if "User Question:" in prompt:
                user_message = prompt.split("User Question:")[1].split("\n")[0].strip()
            else:
                user_message = prompt
            payload = {
                "model": f"{model_id}:together",  # Use the Together endpoint as specified
                "messages": [
                    {
                        "role": "user",
                        "content": user_message,
                    }
                ],
                "max_tokens": kwargs.get("max_tokens", 2000),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.95),
            }
            # Make the API call.
            # Note: requests is synchronous, so this call blocks the event loop for up to the timeout.
            response = requests.post(api_url, json=payload, headers=headers, timeout=60)
            if response.status_code == 200:
                result = response.json()
                # Handle the chat completions response format
                if "choices" in result and len(result["choices"]) > 0:
                    message = result["choices"][0].get("message", {})
                    generated_text = message.get("content", "")
                    # Guard against an empty or non-string completion
                    if not generated_text or not isinstance(generated_text, str):
                        logger.warning("Empty or invalid response from HF API")
                        return None
                    logger.info(f"HF API returned response (length: {len(generated_text)})")
                    return generated_text
                else:
                    logger.error(f"Unexpected response format: {result}")
                    return None
            elif response.status_code == 503:
                # Model is loading; retry once with the fallback model, but only if it differs
                # from the current one to avoid infinite recursion
                logger.warning("Model loading (503), trying fallback")
                fallback_config = self._get_fallback_model("response_synthesis")
                if fallback_config["model_id"] != model_id:
                    return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
                return None
            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None
        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None
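

# Illustrative usage sketch (an assumption, not part of the original module): this
# assumes an HF_TOKEN environment variable and an importable models_config.LLM_CONFIG.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo():
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN"))
        # Route a general reasoning request through the specialist map
        answer = await router.route_inference(
            "general_reasoning",
            "Summarize the benefits of task-specialized model routing.",
            max_tokens=256,
            temperature=0.3,
        )
        print(answer)

    asyncio.run(_demo())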