import asyncio
import logging
import os
from typing import Dict, List, Optional

from .models_config import LLM_CONFIG

logger = logging.getLogger(__name__)


class LLMRouter:
    """Routes LLM inference requests across the ZeroGPU API, local models, and the HF Inference API."""

    def __init__(self, hf_token: Optional[str], use_local_models: bool = True, zero_gpu_config: Optional[Dict] = None):
        self.hf_token = hf_token
        self.health_status = {}
        self.use_local_models = use_local_models
        self.local_loader = None
        self.zero_gpu_client = None
        self.zero_gpu_user_manager = None
        self.use_zero_gpu = False
        self.zero_gpu_mode = "service_account"

        logger.info("LLMRouter initialized")
        if hf_token:
            logger.info("HF token available")
        else:
            logger.warning("No HF token provided")

        if zero_gpu_config and zero_gpu_config.get("enabled", False):
            per_user_mode = zero_gpu_config.get("per_user_mode", False)

            if per_user_mode:
                # Per-user mode: each user gets their own ZeroGPU account, managed via admin credentials.
                try:
                    from zero_gpu_user_manager import ZeroGPUUserManager

                    base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", "https://bm9njt1ypzvuqw-8000.proxy.runpod.net"))
                    admin_email = zero_gpu_config.get("admin_email", os.getenv("ZERO_GPU_ADMIN_EMAIL", ""))
                    admin_password = zero_gpu_config.get("admin_password", os.getenv("ZERO_GPU_ADMIN_PASSWORD", ""))
                    db_path = zero_gpu_config.get("db_path", os.getenv("DB_PATH", "/tmp/sessions.db"))

                    if admin_email and admin_password:
                        self.zero_gpu_user_manager = ZeroGPUUserManager(
                            base_url, admin_email, admin_password, db_path
                        )
                        self.use_zero_gpu = True
                        self.zero_gpu_mode = "per_user"
                        logger.info("✓ ZeroGPU per-user mode enabled (multi-tenant)")
                    else:
                        logger.warning("ZeroGPU per-user mode enabled but admin credentials not provided")
                except ImportError:
                    logger.warning("zero_gpu_user_manager not available, falling back to service account mode")
                    per_user_mode = False
                except Exception as e:
                    logger.warning(f"Could not initialize ZeroGPU user manager: {e}. Falling back to service account mode.")
                    per_user_mode = False

            if not per_user_mode:
                # Service-account mode: a single shared ZeroGPU account handles all requests.
                try:
                    from zero_gpu_client import ZeroGPUChatClient

                    base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", "https://bm9njt1ypzvuqw-8000.proxy.runpod.net"))
                    email = zero_gpu_config.get("email", os.getenv("ZERO_GPU_EMAIL", ""))
                    password = zero_gpu_config.get("password", os.getenv("ZERO_GPU_PASSWORD", ""))

                    if email and password:
                        self.zero_gpu_client = ZeroGPUChatClient(base_url, email, password)
                        self.use_zero_gpu = True
                        self.zero_gpu_mode = "service_account"
                        logger.info("✓ ZeroGPU API client initialized (service account mode)")

                        try:
                            if not self.zero_gpu_client.wait_for_ready(timeout=10):
                                logger.warning("ZeroGPU API not ready, will use HF fallback")
                                self.use_zero_gpu = False
                        except Exception as e:
                            logger.warning(f"Could not verify ZeroGPU API readiness: {e}. Will use HF fallback.")
                            self.use_zero_gpu = False
                    else:
                        logger.warning("ZeroGPU enabled but credentials not provided")
                except ImportError:
                    logger.warning("zero_gpu_client not available, ZeroGPU disabled")
                except Exception as e:
                    logger.warning(f"Could not initialize ZeroGPU client: {e}. Falling back to HF API.")
                    self.use_zero_gpu = False

        if self.use_local_models:
            try:
                from .local_model_loader import LocalModelLoader

                self.local_loader = LocalModelLoader()
                logger.info("✓ Local model loader initialized (models will load on-demand as fallback)")
                logger.info("Models will only load if ZeroGPU API fails")
            except Exception as e:
                logger.warning(f"Could not initialize local model loader: {e}. Local fallback unavailable.")
                logger.warning("This is normal if transformers/torch not available")
                self.use_local_models = False
                self.local_loader = None
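
    # Example zero_gpu_config (illustrative only; the key names mirror the lookups above, and the
    # values shown here are placeholders, not real endpoints or credentials):
    #
    #   zero_gpu_config = {
    #       "enabled": True,
    #       "per_user_mode": False,          # True switches to ZeroGPUUserManager (multi-tenant)
    #       "base_url": "https://example-zero-gpu-endpoint",
    #       "email": "service-account@example.com",
    #       "password": "...",               # or admin_email/admin_password/db_path in per-user mode
    #   }
    #   router = LLMRouter(hf_token=os.getenv("HF_TOKEN"), zero_gpu_config=zero_gpu_config)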

    async def route_inference(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs):
        """
        Smart routing based on task specialization.

        Tries the ZeroGPU API first, then local models as a fallback (lazy loading), then the HF Inference API.

        Args:
            task_type: Task type (e.g., "intent_classification", "general_reasoning")
            prompt: User prompt/message
            context: Optional conversation context
            user_id: Optional user ID for per-user ZeroGPU accounts (Option B)
            **kwargs: Additional generation parameters
        """
        logger.info(f"Routing inference for task: {task_type}")
        model_config = self._select_model(task_type)
        logger.info(f"Selected model: {model_config['model_id']}")

        # 1. ZeroGPU API (preferred)
        if self.use_zero_gpu:
            try:
                result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs)
                if result is not None:
                    logger.info(f"Inference complete for {task_type} (ZeroGPU API)")
                    return result
                else:
                    logger.warning("ZeroGPU API returned None, falling back to local models")
            except Exception as e:
                logger.warning(f"ZeroGPU API inference failed: {e}. Falling back to local models.")
                logger.debug("Exception details:", exc_info=True)

        # 2. Local models (lazy-loaded fallback)
        if self.use_local_models and self.local_loader:
            try:
                logger.info("ZeroGPU API unavailable, loading local model as fallback...")

                if task_type == "embedding_generation":
                    result = await self._call_local_embedding(model_config, prompt, **kwargs)
                else:
                    result = await self._call_local_model(model_config, prompt, task_type, **kwargs)

                if result is not None:
                    logger.info(f"Inference complete for {task_type} (local model fallback)")
                    return result
                else:
                    logger.warning("Local model returned None, falling back to HF API")
            except Exception as e:
                logger.warning(f"Local model inference failed: {e}. Falling back to HF API.")
                logger.debug("Exception details:", exc_info=True)

        # 3. HF Inference API (final fallback)
        logger.info("Using HF Inference API as final fallback")

        if not await self._is_model_healthy(model_config["model_id"]):
            logger.warning("Model unhealthy, using fallback")
            model_config = self._get_fallback_model(task_type)
            logger.info(f"Fallback model: {model_config['model_id']}")

        result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs)
        logger.info(f"Inference complete for {task_type}")
        return result

    async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]:
        """Call local model for inference (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]
        max_tokens = kwargs.get('max_tokens', 512)
        temperature = kwargs.get('temperature', 0.7)

        try:
            if model_id not in self.local_loader.loaded_models:
                logger.info(f"Lazy loading local model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_chat_model(model_id, load_in_8bit=False)

            messages = [{"role": "user", "content": prompt}]

            result = await asyncio.to_thread(
                self.local_loader.generate_chat_completion,
                model_id=model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=temperature
            )

            logger.info(f"Local model {model_id} generated response (length: {len(result)})")
            logger.info("=" * 80)
            logger.info("LOCAL MODEL RESPONSE:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
            logger.info(f"Response Length: {len(result)} characters")
            logger.info("-" * 40)
            logger.info("FULL RESPONSE CONTENT:")
            logger.info("-" * 40)
            logger.info(result)
            logger.info("-" * 40)
            logger.info("END OF RESPONSE")
            logger.info("=" * 80)

            return result

        except Exception as e:
            logger.error(f"Error calling local model: {e}", exc_info=True)
            return None

    async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]:
        """Call local embedding model (lazy loading - only used as fallback)."""
        if not self.local_loader:
            return None

        model_id = model_config["model_id"]

        try:
            if model_id not in self.local_loader.loaded_embedding_models:
                logger.info(f"Lazy loading local embedding model {model_id} as fallback (ZeroGPU unavailable)")
                self.local_loader.load_embedding_model(model_id)

            embedding = await asyncio.to_thread(
                self.local_loader.get_embedding,
                model_id=model_id,
                text=text
            )

            logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})")
            return embedding

        except Exception as e:
            logger.error(f"Error calling local embedding model: {e}", exc_info=True)
            return None

    async def _call_zero_gpu_endpoint(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs) -> Optional[str]:
        """
        Call the ZeroGPU API endpoint.

        Args:
            task_type: Task type (e.g., "intent_classification", "general_reasoning")
            prompt: User prompt/message
            context: Optional conversation context
            user_id: Optional user ID for per-user accounts (Option B)
            **kwargs: Additional generation parameters

        Returns:
            Generated text response, or None if the call failed
        """
        client = None
        if self.zero_gpu_mode == "per_user" and self.zero_gpu_user_manager and user_id:
            client = await self.zero_gpu_user_manager.get_or_create_user_client(user_id)
            if not client:
                logger.warning(f"Could not get ZeroGPU client for user {user_id}, falling back to service account")
                client = self.zero_gpu_client
        else:
            client = self.zero_gpu_client

        if not client:
            return None

        try:
            task_mapping = LLM_CONFIG.get("zero_gpu_task_mapping", {})
            zero_gpu_task = task_mapping.get(task_type, "general")

            logger.info(f"Calling ZeroGPU API for task: {task_type} -> {zero_gpu_task}")
            logger.debug(f"Prompt length: {len(prompt)}")
            logger.info("=" * 80)
            logger.info("ZEROGPU API REQUEST:")
            logger.info("=" * 80)
            logger.info(f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}")
            logger.info(f"Prompt Length: {len(prompt)} characters")
            logger.info("-" * 40)
            logger.info("FULL PROMPT CONTENT:")
            logger.info("-" * 40)
            logger.info(prompt)
            logger.info("-" * 40)
            logger.info("END OF PROMPT")
            logger.info("=" * 80)

            # Keep only the most recent 50 context messages
            context_messages = None
            if context:
                context_messages = []
                for msg in context[-50:]:
                    context_messages.append({
                        "role": msg.get("role", "user"),
                        "content": msg.get("content", ""),
                        "timestamp": msg.get("timestamp", "")
                    })

            generation_params = {
                "max_tokens": kwargs.get('max_tokens', 512),
                "temperature": kwargs.get('temperature', 0.7),
            }

            if 'top_p' in kwargs:
                generation_params["top_p"] = kwargs['top_p']
            if 'system_prompt' in kwargs:
                generation_params["system_prompt"] = kwargs['system_prompt']

            response = client.chat(
                message=prompt,
                task=zero_gpu_task,
                context=context_messages,
                **generation_params
            )

            if response and "response" in response:
                generated_text = response["response"]

                if not generated_text or generated_text.strip() == "":
                    logger.warning("ZeroGPU API returned empty response")
                    return None

                logger.info(f"ZeroGPU API returned response (length: {len(generated_text)})")
                logger.info("=" * 80)
                logger.info("COMPLETE ZEROGPU API RESPONSE:")
                logger.info("=" * 80)
                logger.info(f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}")
                logger.info(f"Response Length: {len(generated_text)} characters")

                if "tokens_used" in response:
                    tokens = response["tokens_used"]
                    logger.info(f"Tokens: input={tokens.get('input', 0)}, output={tokens.get('output', 0)}, total={tokens.get('total', 0)}")

                if "inference_metrics" in response:
                    metrics = response["inference_metrics"]
                    logger.info(f"Inference Duration: {metrics.get('inference_duration', 0):.2f}s")
                    logger.info(f"Tokens/Second: {metrics.get('tokens_per_second', 0):.2f}")

                logger.info("-" * 40)
                logger.info("FULL RESPONSE CONTENT:")
                logger.info("-" * 40)
                logger.info(generated_text)
                logger.info("-" * 40)
                logger.info("END OF RESPONSE")
                logger.info("=" * 80)

                return generated_text
            else:
                logger.error(f"Unexpected ZeroGPU response format: {response}")
                return None

        except Exception as e:
            logger.error(f"Error calling ZeroGPU API: {e}", exc_info=True)
            return None
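
    # Illustrative ZeroGPU response shape, inferred from the fields read above (the actual API may
    # return additional keys; only these are consumed here, and the values are placeholders):
    #
    #   {
    #       "response": "generated text...",
    #       "tokens_used": {"input": 123, "output": 456, "total": 579},
    #       "inference_metrics": {"inference_duration": 2.31, "tokens_per_second": 197.4}
    #   }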

    def _select_model(self, task_type: str) -> dict:
        """Map a task type to its specialist model configuration."""
        model_map = {
            "intent_classification": LLM_CONFIG["models"]["classification_specialist"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["safety_checker"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _is_model_healthy(self, model_id: str) -> bool:
        """
        Check whether the model is healthy and available.

        Models are marked healthy by default; actual availability is checked at API call time.
        """
        if model_id in self.health_status:
            return self.health_status[model_id]

        self.health_status[model_id] = True
        return True

    def _get_fallback_model(self, task_type: str) -> dict:
        """Get the fallback model configuration for the task type."""
        fallback_map = {
            "intent_classification": LLM_CONFIG["models"]["reasoning_primary"],
            "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"],
            "safety_check": LLM_CONFIG["models"]["reasoning_primary"],
            "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"],
            "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"]
        }
        return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])

    async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs):
        """
        Call the Hugging Face Chat Completions API, with retry logic and exponential backoff.
        """
        max_retries = kwargs.get('max_retries', 3)
        initial_delay = kwargs.get('initial_delay', 1.0)
        max_delay = kwargs.get('max_delay', 16.0)
        timeout = kwargs.get('timeout', 30)

        try:
            import requests
            from requests.exceptions import Timeout, RequestException, ConnectionError as RequestsConnectionError

            model_id = model_config["model_id"]

            api_url = "https://router.huggingface.co/v1/chat/completions"

            logger.info(f"Calling HF Chat Completions API for model: {model_id}")
            logger.debug(f"Prompt length: {len(prompt)}")
            logger.info("=" * 80)
            logger.info("LLM API REQUEST - COMPLETE PROMPT:")
            logger.info("=" * 80)
            logger.info(f"Model: {model_id}")
            logger.info(f"Task Type: {task_type}")
            logger.info(f"Prompt Length: {len(prompt)} characters")
            logger.info("-" * 40)
            logger.info("FULL PROMPT CONTENT:")
            logger.info("-" * 40)
            logger.info(prompt)
            logger.info("-" * 40)
            logger.info("END OF PROMPT")
            logger.info("=" * 80)

            max_tokens = kwargs.get('max_tokens', 512)
            temperature = kwargs.get('temperature', 0.7)

            payload = {
                "model": model_id,
                "messages": [
                    {
                        "role": "user",
                        "content": prompt
                    }
                ],
                "max_tokens": max_tokens,
                "temperature": temperature,
                "stream": False
            }

            headers = {
                "Authorization": f"Bearer {self.hf_token}",
                "Content-Type": "application/json"
            }

            last_exception = None
            for attempt in range(max_retries + 1):
                try:
                    if attempt > 0:
                        # Exponential backoff: 1s, 2s, 4s, ... capped at max_delay
                        delay = min(initial_delay * (2 ** (attempt - 1)), max_delay)
                        logger.warning(f"Retry attempt {attempt}/{max_retries} after {delay:.1f}s delay (exponential backoff)")
                        await asyncio.sleep(delay)

                    logger.info(f"Sending request to: {api_url} (attempt {attempt + 1}/{max_retries + 1})")
                    logger.debug(f"Payload: {payload}")

                    response = requests.post(api_url, json=payload, headers=headers, timeout=timeout)

                    if response.status_code == 200:
                        result = response.json()
                        logger.debug(f"Raw response: {result}")

                        if 'choices' in result and len(result['choices']) > 0:
                            generated_text = result['choices'][0]['message']['content']

                            if not generated_text or generated_text.strip() == "":
                                logger.warning("Empty or invalid response, using fallback")
                                return None

                            if attempt > 0:
                                logger.info(f"Successfully retrieved response after {attempt} retry attempts")

                            logger.info(f"HF API returned response (length: {len(generated_text)})")
                            logger.info("=" * 80)
                            logger.info("COMPLETE LLM API RESPONSE:")
                            logger.info("=" * 80)
                            logger.info(f"Model: {model_id}")
                            logger.info(f"Task Type: {task_type}")
                            logger.info(f"Response Length: {len(generated_text)} characters")
                            logger.info("-" * 40)
                            logger.info("FULL RESPONSE CONTENT:")
                            logger.info("-" * 40)
                            logger.info(generated_text)
                            logger.info("-" * 40)
                            logger.info("END OF LLM RESPONSE")
                            logger.info("=" * 80)
                            return generated_text
                        else:
                            logger.error(f"Unexpected response format: {result}")
                            return None
                    elif response.status_code == 503:
                        # 503 means the model is still loading on the HF side
                        if attempt < max_retries:
                            logger.warning(f"Model loading (503), will retry (attempt {attempt + 1}/{max_retries + 1})")
                            last_exception = Exception("Model loading (503)")
                            continue
                        else:
                            logger.warning(f"Model loading (503) after {max_retries} retries, trying fallback model")
                            fallback_config = self._get_fallback_model(task_type)
                            return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)
                    else:
                        logger.error(f"HF API error: {response.status_code} - {response.text}")
                        return None

                except Timeout as e:
                    last_exception = e
                    if attempt < max_retries:
                        logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    else:
                        logger.error(f"Request timeout after {max_retries} retries: {str(e)}")
                        logger.warning("Attempting fallback model due to persistent timeout")
                        fallback_config = self._get_fallback_model(task_type)
                        return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)

                except (RequestsConnectionError, RequestException) as e:
                    last_exception = e
                    if attempt < max_retries:
                        logger.warning(f"Connection error (attempt {attempt + 1}/{max_retries + 1}): {str(e)}")
                        continue
                    else:
                        logger.error(f"Connection error after {max_retries} retries: {str(e)}")
                        logger.warning("Attempting fallback model due to persistent connection error")
                        fallback_config = self._get_fallback_model(task_type)
                        return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs)

            if last_exception:
                logger.error(f"Failed after {max_retries} retries. Last error: {last_exception}")
            return None

        except ImportError:
            logger.warning("requests library not available, using mock response")
            return f"[Mock] Response to: {prompt[:100]}..."
        except Exception as e:
            logger.error(f"Error calling HF endpoint: {e}", exc_info=True)
            return None

    async def get_available_models(self):
        """Get the list of available models for testing."""
        return list(LLM_CONFIG["models"].keys())

    async def health_check(self):
        """Perform a health check on all models."""
        health_status = {}
        for model_name, model_config in LLM_CONFIG["models"].items():
            model_id = model_config["model_id"]
            is_healthy = await self._is_model_healthy(model_id)
            health_status[model_name] = {
                "model_id": model_id,
                "healthy": is_healthy
            }

        return health_status

    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str:
        """Smart context windowing for LLM calls."""
        try:
            from transformers import AutoTokenizer

            if not hasattr(self, 'tokenizer'):
                try:
                    self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
                except Exception as e:
                    logger.warning(f"Could not load tokenizer: {e}, using character count estimation")
                    self.tokenizer = None
        except ImportError:
            logger.warning("transformers library not available, using character count estimation")
            self.tokenizer = None

        # Context elements in decreasing priority; higher-priority elements are included first.
        priority_elements = [
            ('current_query', 1.0),
            ('recent_interactions', 0.8),
            ('user_preferences', 0.6),
            ('session_summary', 0.4),
            ('historical_context', 0.2)
        ]

        element_content_map = {
            'current_query': raw_context.get('user_input', ''),
            'recent_interactions': raw_context.get('interaction_contexts', []),
            'user_preferences': raw_context.get('preferences', {}),
            'session_summary': raw_context.get('session_context', {}),
            'historical_context': raw_context.get('user_context', '')
        }

        formatted_context = []
        total_tokens = 0

        for element, priority in priority_elements:
            content = element_content_map.get(element, '')

            if isinstance(content, dict):
                content = str(content)
            elif isinstance(content, list):
                content = "\n".join([str(item) for item in content[:10]])

            if not content:
                continue

            # Count tokens with the tokenizer when available; otherwise estimate ~4 characters per token.
            if self.tokenizer:
                try:
                    tokens = len(self.tokenizer.encode(content))
                except Exception:
                    tokens = len(content) // 4
            else:
                tokens = len(content) // 4

            if total_tokens + tokens <= max_tokens:
                formatted_context.append(f"=== {element.upper()} ===\n{content}")
                total_tokens += tokens
            elif priority > 0.5:
                # High-priority element that does not fit: truncate it into the remaining budget.
                available = max_tokens - total_tokens
                if available > 100:
                    truncated = self._truncate_to_tokens(content, available)
                    formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}")
                break

        return "\n\n".join(formatted_context)
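
    # Illustrative input/output (hypothetical values; the keys match the raw_context lookups above):
    #
    #   raw_context = {
    #       "user_input": "What's the weather tomorrow?",
    #       "interaction_contexts": ["user: hi", "assistant: hello"],
    #       "preferences": {"units": "metric"},
    #   }
    #   router.prepare_context_for_llm(raw_context) returns sections such as:
    #       === CURRENT_QUERY ===
    #       What's the weather tomorrow?
    #
    #       === RECENT_INTERACTIONS ===
    #       ...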

    def _truncate_to_tokens(self, content: str, max_tokens: int) -> str:
        """Truncate content to fit within a token limit."""
        if not self.tokenizer:
            # No tokenizer available: estimate ~4 characters per token.
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."

        try:
            tokens = self.tokenizer.encode(content)
            if len(tokens) <= max_tokens:
                return content

            truncated_tokens = tokens[:max_tokens - 3]
            truncated_text = self.tokenizer.decode(truncated_tokens)
            return truncated_text + "..."
        except Exception as e:
            logger.warning(f"Error truncating with tokenizer: {e}, using character truncation")
            max_chars = max_tokens * 4
            if len(content) <= max_chars:
                return content
            return content[:max_chars - 3] + "..."
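

# Minimal usage sketch (illustrative only; assumes this module is imported as part of its package,
# that HF_TOKEN is set in the environment, and the module path below is hypothetical):
#
#   import asyncio, os
#   from .llm_router import LLMRouter   # hypothetical module name
#
#   async def main():
#       router = LLMRouter(hf_token=os.getenv("HF_TOKEN"), use_local_models=False)
#       reply = await router.route_inference("general_reasoning", "Summarize the project status.")
#       print(reply)
#
#   asyncio.run(main())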