# llm_router.py
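"""Task-aware router for LLM inference.

Dispatches requests to task-specialized Hugging Face models, with cached
health checks and per-task fallback models.
"""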
import asyncio
import logging
from typing import Optional

from models_config import LLM_CONFIG
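
# Assumed shape of LLM_CONFIG (defined in models_config.py, not shown here).
# Only the keys this module reads are listed; actual model IDs are elided:
#
#   LLM_CONFIG = {
#       "models": {
#           "reasoning_primary":         {"model_id": "..."},
#           "classification_specialist": {"model_id": "..."},
#           "embedding_specialist":      {"model_id": "..."},
#           "safety_checker":            {"model_id": "..."},
#       }
#   }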

logger = logging.getLogger(__name__)

class LLMRouter:
    def __init__(self, hf_token):
        self.hf_token = hf_token
        self.health_status = {}
        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")
        
    async def route_inference(self, task_type: str, prompt: str, **kwargs) -> Optional[str]:
        """
        Smart routing based on task specialization
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")
        
        # Health check and fallback logic
        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning(f"Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")
            
        result = await self._call_hf_endpoint(model_config, prompt, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result
    
    def _select_model(self, task_type: str) -> dict:
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
    
    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check if the model is healthy and available
        """
        # Check cached health status
        if model_id in self.health_status:
            return self.health_status[model_id]
        
        # Default to healthy for now; a real health check could populate this cache
        self.health_status[model_id] = True
        return True
    
    def _get_fallback_model(self, task_type: str) -> dict:
        """
        Get fallback model configuration for the task type
        """
        # Fallback mapping
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
    
    async def _call_hf_endpoint(self, model_config: dict, prompt: str, **kwargs) -> Optional[str]:
        """
        Make actual call to Hugging Face Inference API
        """
        try:
            import requests
            
            model_id = model_config["model_id"]
            api_url = f"https://api-inference.huggingface.co/models/{model_id}"
            
            logger.info(f"Calling HF API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            
            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json"
            }
            
            # Prepare payload
            payload = {
                "inputs": prompt,
                "parameters": {
                    "max_new_tokens": kwargs.get("max_tokens", 250),
                    "temperature": kwargs.get("temperature", 0.7),
                    "top_p": kwargs.get("top_p", 0.95),
                    "return_full_text": False
                }
            }
            
            # Run the blocking requests call in a worker thread so the
            # event loop is not blocked while waiting on the network
            response = await asyncio.to_thread(
                requests.post, api_url, json=payload, headers=headers, timeout=30
            )
            
            if response.status_code == 200:
                result = response.json()
                # Handle different response formats
                if isinstance(result, list) and len(result) > 0:
                    generated_text = result[0].get("generated_text", "")
                else:
                    generated_text = str(result)
                logger.info(f"HF API returned response (length: {len(generated_text)})")
                return generated_text
            else:
                logger.error(f"HF API error: {response.status_code} - {response.text}")
                return None
                
        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None