# llm_router.py
import logging
from models_config import LLM_CONFIG

logger = logging.getLogger(__name__)
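
# Illustrative sketch (assumption, not taken from this repository's models_config.py):
# LLM_CONFIG is expected to look roughly like the structure below, with one entry per
# specialist role referenced in this module. The keys match the lookups used here;
# the model ids are placeholders.
#
# LLM_CONFIG = {
#     "models": {
#         "reasoning_primary":         {"model_id": "<primary-reasoning-model>"},
#         "classification_specialist": {"model_id": "<intent-classification-model>"},
#         "embedding_specialist":      {"model_id": "<embedding-model>"},
#         "safety_checker":            {"model_id": "<safety-model>"},
#     }
# }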

class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")
        
    async def route_inference(self, task_type: str, prompt: str, **kwargs):
        """
        Smart routing based on task specialization
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")
        
        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")
            
        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result
    
    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
    
    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check if the model is healthy and available
        Mark models as healthy by default - actual availability checked at API call time
        """
        # Check cached health status
        if model_id in self.health_status:
            return self.health_status[model_id]
        
        # All models marked healthy initially - real check happens during API call
        self.health_status[model_id] = True
        return True
    
    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get fallback model configuration for the task type
        """
        # Fallback mapping
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
    
    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs):
        """
        Make actual call to Hugging Face Chat Completions API
        Uses the correct chat completions protocol
        """
        try:
            # Imported lazily so a missing dependency degrades to the mock response below
            import requests
            
            model_id = model_config["model_id"]
            
            # Use the chat completions endpoint
            api_url = "https://router.huggingface.co/v1/chat/completions"
            
            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            
            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json"
            }
            
            # Prepare payload in chat-completions format.
            # If the prompt embeds a structured "User Question:" section, send only that
            # question; otherwise send the prompt verbatim.
            if "User Question:" in prompt:
                user_message = prompt.split("User Question:")[1].split("\n")[0].strip()
            else:
                user_message = prompt
            
            payload = {
                "model": f"{model_id}:together",  # Use the Together endpoint as specified
                "messages": [
                    {
                        "role": "user",
                        "content": user_message
                    }
                ],
                "max_tokens": kwargs.get("max_tokens", 2000),
                "temperature": kwargs.get("temperature", 0.7),
                "top_p": kwargs.get("top_p", 0.95)
            }
            
            # Make the API call (note: requests is synchronous, so this blocks the
            # event loop for the duration of the request)
            response = requests.post(api_url, json=payload, headers=headers, timeout=60)
            
            if response.status_code == 200:
                result = response.json()
                # Handle chat completions response format
                if "choices" in result and len(result["choices"]) > 0:
                    message = result["choices"][0].get("message", {})
                    generated_text = message.get("content", "")
                    
                    # Return None for empty or non-string content so the caller can fall back
                    if not generated_text or not isinstance(generated_text, str):
                        logger.warning("Empty or invalid response content from HF API")
                        return None
                    
                    logger.info(f"HF API returned response (length: {len(generated_text)})")
                    return generated_text
                else:
                    logger.error(f"Unexpected response format: {result}")
                    return None
            elif response.status_code == 503:
                # Model is still loading; retry with the fallback model, but only if it
                # is actually a different model (guards against infinite recursion)
                fallback_config = self._get_fallback_model("response_synthesis")
                if fallback_config["model_id"] != model_id:
                    logger.warning(f"Model {model_id} loading (503), trying fallback")
                    return await self._call_hf_endpoint(fallback_config, prompt, **kwargs)
                logger.error(f"Model {model_id} still loading (503) and no distinct fallback")
                return None
            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None
                
        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None
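

# Usage sketch (illustrative, not part of the original module): the router is async,
# so callers drive it from an event loop. The environment variable name and the
# prompt below are placeholders.
if __name__ == "__main__":
    import asyncio
    import os

    async def _demo():
        router = LLMRouter(hf_token=os.environ.get("HF_TOKEN"))
        answer = await router.route_inference(
            "general_reasoning",
            "Summarize why task-specialized model routing is useful.",
            max_tokens=256,
        )
        print(answer)

    asyncio.run(_demo())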