# llm_router.py - UPDATED FOR LOCAL GPU MODEL LOADING + ZEROGPU API import logging import asyncio import os from typing import Dict, Optional, List from .models_config import LLM_CONFIG logger = logging.getLogger(__name__) class LLMRouter: def __init__(self, hf_token, use_local_models: bool = True, zero_gpu_config: Optional[Dict] = None): self.hf_token = hf_token self.health_status = {} self.use_local_models = use_local_models self.local_loader = None self.zero_gpu_client = None # Service account client (Option A) self.zero_gpu_user_manager = None # Per-user manager (Option B) self.use_zero_gpu = False self.zero_gpu_mode = "service_account" # "service_account" or "per_user" logger.info("LLMRouter initialized") if hf_token: logger.info("HF token available") else: logger.warning("No HF token provided") # Initialize ZeroGPU client if configured if zero_gpu_config and zero_gpu_config.get("enabled", False): # Check if per-user mode is enabled per_user_mode = zero_gpu_config.get("per_user_mode", False) if per_user_mode: # Option B: Per-User Accounts (Multi-tenant) try: from zero_gpu_user_manager import ZeroGPUUserManager base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", "https://bm9njt1ypzvuqw-8000.proxy.runpod.net")) admin_email = zero_gpu_config.get("admin_email", os.getenv("ZERO_GPU_ADMIN_EMAIL", "")) admin_password = zero_gpu_config.get("admin_password", os.getenv("ZERO_GPU_ADMIN_PASSWORD", "")) db_path = zero_gpu_config.get("db_path", os.getenv("DB_PATH", "/tmp/sessions.db")) if admin_email and admin_password: self.zero_gpu_user_manager = ZeroGPUUserManager( base_url, admin_email, admin_password, db_path ) self.use_zero_gpu = True self.zero_gpu_mode = "per_user" logger.info("✓ ZeroGPU per-user mode enabled (multi-tenant)") logger.info(" → ZeroGPU API is PRIMARY inference method (first priority)") else: logger.warning("ZeroGPU per-user mode enabled but admin credentials not provided") except ImportError: logger.warning("zero_gpu_user_manager not available, falling back to service account mode") per_user_mode = False except Exception as e: logger.warning(f"Could not initialize ZeroGPU user manager: {e}. Falling back to service account mode.") per_user_mode = False if not per_user_mode: # Option A: Service Account (Single-tenant) try: from zero_gpu_client import ZeroGPUChatClient base_url = zero_gpu_config.get("base_url", os.getenv("ZERO_GPU_API_URL", "https://bm9njt1ypzvuqw-8000.proxy.runpod.net")) email = zero_gpu_config.get("email", os.getenv("ZERO_GPU_EMAIL", "")) password = zero_gpu_config.get("password", os.getenv("ZERO_GPU_PASSWORD", "")) if email and password: self.zero_gpu_client = ZeroGPUChatClient(base_url, email, password) self.use_zero_gpu = True self.zero_gpu_mode = "service_account" logger.info("✓ ZeroGPU API client initialized (service account mode)") logger.info(" → ZeroGPU API is PRIMARY inference method (first priority)") # Check API readiness (non-blocking, will still try during inference) try: if self.zero_gpu_client.wait_for_ready(timeout=10): logger.info(" ✓ ZeroGPU API is ready") else: logger.warning(" ⚠ ZeroGPU API not ready yet (will retry during inference)") # Keep use_zero_gpu=True - we'll try it first anyway except Exception as e: logger.warning(f" ⚠ Could not verify ZeroGPU API readiness: {e}") logger.info(" → Will still attempt ZeroGPU first during inference") # Keep use_zero_gpu=True - we'll try it first anyway else: logger.warning("ZeroGPU enabled but credentials not provided") except ImportError: logger.warning("zero_gpu_client not available, ZeroGPU disabled") except Exception as e: logger.warning(f"Could not initialize ZeroGPU client: {e}. Falling back to HF API.") self.use_zero_gpu = False # Initialize local model loader if enabled (but don't load models yet - lazy loading) if self.use_local_models: try: from .local_model_loader import LocalModelLoader # Initialize loader but don't load models yet self.local_loader = LocalModelLoader() logger.info("✓ Local model loader initialized (models will load on-demand as fallback)") logger.info("Models will only load if ZeroGPU API fails") except Exception as e: logger.warning(f"Could not initialize local model loader: {e}. Local fallback unavailable.") logger.warning("This is normal if transformers/torch not available") self.use_local_models = False self.local_loader = None async def route_inference(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs): """ Smart routing based on task specialization PRIORITY ORDER (ZeroGPU is FIRST): 1. ZeroGPU API (PRIMARY - always tried first if configured) 2. Local models (fallback - lazy loading, only if ZeroGPU fails) 3. HF Inference API (final fallback - only if both above fail) Args: task_type: Task type (e.g., "intent_classification", "general_reasoning") prompt: User prompt/message context: Optional conversation context user_id: Optional user ID for per-user ZeroGPU accounts (Option B) **kwargs: Additional generation parameters """ logger.info(f"Routing inference for task: {task_type}") model_config = self._select_model(task_type) logger.info(f"Selected model: {model_config['model_id']}") # PRIORITY 1: Try ZeroGPU API first (PRIMARY inference method) # Always try if client is configured, even if initialization had warnings if self.zero_gpu_client or self.zero_gpu_user_manager: logger.info("→ Attempting ZeroGPU API (PRIMARY inference method)...") try: result = await self._call_zero_gpu_endpoint(task_type, prompt, context, user_id, **kwargs) if result is not None and result.strip(): # Check for non-empty result logger.info(f"✓ Inference complete for {task_type} (ZeroGPU API - PRIMARY)") return result else: logger.warning("ZeroGPU API returned empty result, falling back to local models") except Exception as e: logger.warning(f"ZeroGPU API inference failed: {e}") logger.info("→ Falling back to local models (lazy loading)...") logger.debug("Exception details:", exc_info=True) # PRIORITY 2: Fallback to local models (lazy loading - only if ZeroGPU fails) if self.use_local_models and self.local_loader: try: logger.info("→ Loading local model as fallback (ZeroGPU unavailable)...") # Handle embedding generation separately if task_type == "embedding_generation": result = await self._call_local_embedding(model_config, prompt, **kwargs) else: result = await self._call_local_model(model_config, prompt, task_type, **kwargs) if result is not None and result.strip(): # Check for non-empty result logger.info(f"✓ Inference complete for {task_type} (local model fallback)") return result else: logger.warning("Local model returned empty result, falling back to HF API") except Exception as e: logger.warning(f"Local model inference failed: {e}") logger.info("→ Falling back to HF Inference API (final fallback)...") logger.debug("Exception details:", exc_info=True) # PRIORITY 3: Final fallback to HF Inference API (only if ZeroGPU and local models fail) logger.info("→ Using HF Inference API as final fallback...") # Health check and fallback logic if not await self._is_model_healthy(model_config["model_id"]): logger.warning(f"Model unhealthy, using fallback") model_config = self._get_fallback_model(task_type) logger.info(f"Fallback model: {model_config['model_id']}") result = await self._call_hf_endpoint(model_config, prompt, task_type, **kwargs) logger.info(f"Inference complete for {task_type}") return result async def _call_local_model(self, model_config: dict, prompt: str, task_type: str, **kwargs) -> Optional[str]: """Call local model for inference (lazy loading - only used as fallback).""" if not self.local_loader: return None model_id = model_config["model_id"] max_tokens = kwargs.get('max_tokens', 512) temperature = kwargs.get('temperature', 0.7) try: # Ensure model is loaded (lazy loading on first use) if model_id not in self.local_loader.loaded_models: logger.info(f"Lazy loading local model {model_id} as fallback (ZeroGPU unavailable)") self.local_loader.load_chat_model(model_id, load_in_8bit=False) # Format as chat messages if needed messages = [{"role": "user", "content": prompt}] # Generate using local model result = await asyncio.to_thread( self.local_loader.generate_chat_completion, model_id=model_id, messages=messages, max_tokens=max_tokens, temperature=temperature ) logger.info(f"Local model {model_id} generated response (length: {len(result)})") logger.info("=" * 80) logger.info("LOCAL MODEL RESPONSE:") logger.info("=" * 80) logger.info(f"Model: {model_id}") logger.info(f"Task Type: {task_type}") logger.info(f"Response Length: {len(result)} characters") logger.info("-" * 40) logger.info("FULL RESPONSE CONTENT:") logger.info("-" * 40) logger.info(result) logger.info("-" * 40) logger.info("END OF RESPONSE") logger.info("=" * 80) return result except Exception as e: logger.error(f"Error calling local model: {e}", exc_info=True) return None async def _call_local_embedding(self, model_config: dict, text: str, **kwargs) -> Optional[list]: """Call local embedding model (lazy loading - only used as fallback).""" if not self.local_loader: return None model_id = model_config["model_id"] try: # Ensure model is loaded (lazy loading on first use) if model_id not in self.local_loader.loaded_embedding_models: logger.info(f"Lazy loading local embedding model {model_id} as fallback (ZeroGPU unavailable)") self.local_loader.load_embedding_model(model_id) # Generate embedding embedding = await asyncio.to_thread( self.local_loader.get_embedding, model_id=model_id, text=text ) logger.info(f"Local embedding model {model_id} generated vector (dim: {len(embedding)})") return embedding except Exception as e: logger.error(f"Error calling local embedding model: {e}", exc_info=True) return None async def _call_zero_gpu_endpoint(self, task_type: str, prompt: str, context: Optional[List[Dict]] = None, user_id: Optional[str] = None, **kwargs) -> Optional[str]: """ Call ZeroGPU API endpoint Args: task_type: Task type (e.g., "intent_classification", "general_reasoning") prompt: User prompt/message context: Optional conversation context user_id: Optional user ID for per-user accounts (Option B) **kwargs: Additional generation parameters Returns: Generated text response or None if failed """ # Get appropriate client based on mode client = None if self.zero_gpu_mode == "per_user" and self.zero_gpu_user_manager and user_id: # Option B: Per-user accounts client = await self.zero_gpu_user_manager.get_or_create_user_client(user_id) if not client: logger.warning(f"Could not get ZeroGPU client for user {user_id}, falling back to service account") client = self.zero_gpu_client else: # Option A: Service account client = self.zero_gpu_client if not client: return None try: # Map task type to ZeroGPU task task_mapping = LLM_CONFIG.get("zero_gpu_task_mapping", {}) zero_gpu_task = task_mapping.get(task_type, "general") logger.info(f"Calling ZeroGPU API for task: {task_type} -> {zero_gpu_task}") logger.debug(f"Prompt length: {len(prompt)}") logger.info("=" * 80) logger.info("ZEROGPU API REQUEST:") logger.info("=" * 80) logger.info(f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}") logger.info(f"Prompt Length: {len(prompt)} characters") logger.info("-" * 40) logger.info("FULL PROMPT CONTENT:") logger.info("-" * 40) logger.info(prompt) logger.info("-" * 40) logger.info("END OF PROMPT") logger.info("=" * 80) # Prepare context if provided context_messages = None if context: context_messages = [] for msg in context[-50:]: # Limit to 50 messages as per API context_messages.append({ "role": msg.get("role", "user"), "content": msg.get("content", ""), "timestamp": msg.get("timestamp", "") }) # Prepare generation parameters generation_params = { "max_tokens": kwargs.get('max_tokens', 512), "temperature": kwargs.get('temperature', 0.7), } # Add optional parameters if 'top_p' in kwargs: generation_params["top_p"] = kwargs['top_p'] if 'system_prompt' in kwargs: generation_params["system_prompt"] = kwargs['system_prompt'] # Call ZeroGPU API response = client.chat( message=prompt, task=zero_gpu_task, context=context_messages, **generation_params ) # Extract response text if response and "response" in response: generated_text = response["response"] if not generated_text or generated_text.strip() == "": logger.warning("ZeroGPU API returned empty response") return None logger.info(f"ZeroGPU API returned response (length: {len(generated_text)})") logger.info("=" * 80) logger.info("COMPLETE ZEROGPU API RESPONSE:") logger.info("=" * 80) logger.info(f"Task Type: {task_type} -> ZeroGPU Task: {zero_gpu_task}") logger.info(f"Response Length: {len(generated_text)} characters") # Log metrics if available if "tokens_used" in response: tokens = response["tokens_used"] logger.info(f"Tokens: input={tokens.get('input', 0)}, output={tokens.get('output', 0)}, total={tokens.get('total', 0)}") if "inference_metrics" in response: metrics = response["inference_metrics"] logger.info(f"Inference Duration: {metrics.get('inference_duration', 0):.2f}s") logger.info(f"Tokens/Second: {metrics.get('tokens_per_second', 0):.2f}") logger.info("-" * 40) logger.info("FULL RESPONSE CONTENT:") logger.info("-" * 40) logger.info(generated_text) logger.info("-" * 40) logger.info("END OF RESPONSE") logger.info("=" * 80) return generated_text else: logger.error(f"Unexpected ZeroGPU response format: {response}") return None except Exception as e: logger.error(f"Error calling ZeroGPU API: {e}", exc_info=True) return None def _select_model(self, task_type: str) -> dict: model_map = { "intent_classification": LLM_CONFIG["models"]["classification_specialist"], "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], "safety_check": LLM_CONFIG["models"]["safety_checker"], "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] } return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) async def _is_model_healthy(self, model_id: str) -> bool: """ Check if the model is healthy and available Mark models as healthy by default - actual availability checked at API call time """ # Check cached health status if model_id in self.health_status: return self.health_status[model_id] # All models marked healthy initially - real check happens during API call self.health_status[model_id] = True return True def _get_fallback_model(self, task_type: str) -> dict: """ Get fallback model configuration for the task type """ # Fallback mapping fallback_map = { "intent_classification": LLM_CONFIG["models"]["reasoning_primary"], "embedding_generation": LLM_CONFIG["models"]["embedding_specialist"], "safety_check": LLM_CONFIG["models"]["reasoning_primary"], "general_reasoning": LLM_CONFIG["models"]["reasoning_primary"], "response_synthesis": LLM_CONFIG["models"]["reasoning_primary"] } return fallback_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"]) async def _call_hf_endpoint(self, model_config: dict, prompt: str, task_type: str, **kwargs): """ FIXED: Make actual call to Hugging Face Chat Completions API Uses the correct chat completions protocol with retry logic and exponential backoff IMPORTANT: task_type parameter is now properly included in the method signature """ # Retry configuration max_retries = kwargs.get('max_retries', 3) initial_delay = kwargs.get('initial_delay', 1.0) # Start with 1 second max_delay = kwargs.get('max_delay', 16.0) # Cap at 16 seconds timeout = kwargs.get('timeout', 30) try: import requests from requests.exceptions import Timeout, RequestException, ConnectionError as RequestsConnectionError model_id = model_config["model_id"] # Use the chat completions endpoint api_url = "https://router.huggingface.co/v1/chat/completions" logger.info(f"Calling HF Chat Completions API for model: {model_id}") logger.debug(f"Prompt length: {len(prompt)}") logger.info("=" * 80) logger.info("LLM API REQUEST - COMPLETE PROMPT:") logger.info("=" * 80) logger.info(f"Model: {model_id}") # FIXED: task_type is now properly available as a parameter logger.info(f"Task Type: {task_type}") logger.info(f"Prompt Length: {len(prompt)} characters") logger.info("-" * 40) logger.info("FULL PROMPT CONTENT:") logger.info("-" * 40) logger.info(prompt) logger.info("-" * 40) logger.info("END OF PROMPT") logger.info("=" * 80) # Prepare the request payload max_tokens = kwargs.get('max_tokens', 512) temperature = kwargs.get('temperature', 0.7) payload = { "model": model_id, "messages": [ { "role": "user", "content": prompt } ], "max_tokens": max_tokens, "temperature": temperature, "stream": False } headers = { "Authorization": f"Bearer {self.hf_token}", "Content-Type": "application/json" } # Retry logic with exponential backoff last_exception = None for attempt in range(max_retries + 1): try: if attempt > 0: # Calculate exponential backoff delay delay = min(initial_delay * (2 ** (attempt - 1)), max_delay) logger.warning(f"Retry attempt {attempt}/{max_retries} after {delay:.1f}s delay (exponential backoff)") await asyncio.sleep(delay) logger.info(f"Sending request to: {api_url} (attempt {attempt + 1}/{max_retries + 1})") logger.debug(f"Payload: {payload}") response = requests.post(api_url, json=payload, headers=headers, timeout=timeout) if response.status_code == 200: result = response.json() logger.debug(f"Raw response: {result}") if 'choices' in result and len(result['choices']) > 0: generated_text = result['choices'][0]['message']['content'] if not generated_text or generated_text.strip() == "": logger.warning(f"Empty or invalid response, using fallback") return None if attempt > 0: logger.info(f"Successfully retrieved response after {attempt} retry attempts") logger.info(f"HF API returned response (length: {len(generated_text)})") logger.info("=" * 80) logger.info("COMPLETE LLM API RESPONSE:") logger.info("=" * 80) logger.info(f"Model: {model_id}") # FIXED: task_type is now properly available logger.info(f"Task Type: {task_type}") logger.info(f"Response Length: {len(generated_text)} characters") logger.info("-" * 40) logger.info("FULL RESPONSE CONTENT:") logger.info("-" * 40) logger.info(generated_text) logger.info("-" * 40) logger.info("END OF LLM RESPONSE") logger.info("=" * 80) return generated_text else: logger.error(f"Unexpected response format: {result}") return None elif response.status_code == 503: # Model is loading - this is retryable if attempt < max_retries: logger.warning(f"Model loading (503), will retry (attempt {attempt + 1}/{max_retries + 1})") last_exception = Exception(f"Model loading (503)") continue else: # After max retries, try fallback model logger.warning(f"Model loading (503) after {max_retries} retries, trying fallback model") fallback_config = self._get_fallback_model(task_type) # FIXED: Ensure task_type is passed in recursive call return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs) else: # Non-retryable HTTP errors logger.error(f"HF API error: {response.status_code} - {response.text}") return None except Timeout as e: last_exception = e if attempt < max_retries: logger.warning(f"Request timeout (attempt {attempt + 1}/{max_retries + 1}): {str(e)}") continue else: logger.error(f"Request timeout after {max_retries} retries: {str(e)}") # Try fallback model on final timeout logger.warning("Attempting fallback model due to persistent timeout") fallback_config = self._get_fallback_model(task_type) return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs) except (RequestsConnectionError, RequestException) as e: last_exception = e if attempt < max_retries: logger.warning(f"Connection error (attempt {attempt + 1}/{max_retries + 1}): {str(e)}") continue else: logger.error(f"Connection error after {max_retries} retries: {str(e)}") # Try fallback model on final connection error logger.warning("Attempting fallback model due to persistent connection error") fallback_config = self._get_fallback_model(task_type) return await self._call_hf_endpoint(fallback_config, prompt, task_type, **kwargs) # If we exhausted all retries and didn't return if last_exception: logger.error(f"Failed after {max_retries} retries. Last error: {last_exception}") return None except ImportError: logger.warning("requests library not available, using mock response") return f"[Mock] Response to: {prompt[:100]}..." except Exception as e: logger.error(f"Error calling HF endpoint: {e}", exc_info=True) return None async def get_available_models(self): """ Get list of available models for testing """ return list(LLM_CONFIG["models"].keys()) async def health_check(self): """ Perform health check on all models """ health_status = {} for model_name, model_config in LLM_CONFIG["models"].items(): model_id = model_config["model_id"] is_healthy = await self._is_model_healthy(model_id) health_status[model_name] = { "model_id": model_id, "healthy": is_healthy } return health_status def prepare_context_for_llm(self, raw_context: Dict, max_tokens: int = 4000) -> str: """Smart context windowing for LLM calls""" try: from transformers import AutoTokenizer # Initialize tokenizer lazily if not hasattr(self, 'tokenizer'): try: self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct") except Exception as e: logger.warning(f"Could not load tokenizer: {e}, using character count estimation") self.tokenizer = None except ImportError: logger.warning("transformers library not available, using character count estimation") self.tokenizer = None # Priority order for context elements priority_elements = [ ('current_query', 1.0), ('recent_interactions', 0.8), ('user_preferences', 0.6), ('session_summary', 0.4), ('historical_context', 0.2) ] formatted_context = [] total_tokens = 0 for element, priority in priority_elements: # Map element names to context keys element_key_map = { 'current_query': raw_context.get('user_input', ''), 'recent_interactions': raw_context.get('interaction_contexts', []), 'user_preferences': raw_context.get('preferences', {}), 'session_summary': raw_context.get('session_context', {}), 'historical_context': raw_context.get('user_context', '') } content = element_key_map.get(element, '') # Convert to string if needed if isinstance(content, dict): content = str(content) elif isinstance(content, list): content = "\n".join([str(item) for item in content[:10]]) # Limit to 10 items if not content: continue # Estimate tokens if self.tokenizer: try: tokens = len(self.tokenizer.encode(content)) except: # Fallback to character-based estimation (rough: 1 token ≈ 4 chars) tokens = len(content) // 4 else: # Character-based estimation (rough: 1 token ≈ 4 chars) tokens = len(content) // 4 if total_tokens + tokens <= max_tokens: formatted_context.append(f"=== {element.upper()} ===\n{content}") total_tokens += tokens elif priority > 0.5: # Critical elements - truncate if needed available = max_tokens - total_tokens if available > 100: # Only truncate if we have meaningful space truncated = self._truncate_to_tokens(content, available) formatted_context.append(f"=== {element.upper()} (TRUNCATED) ===\n{truncated}") break return "\n\n".join(formatted_context) def _truncate_to_tokens(self, content: str, max_tokens: int) -> str: """Truncate content to fit within token limit""" if not self.tokenizer: # Simple character-based truncation max_chars = max_tokens * 4 if len(content) <= max_chars: return content return content[:max_chars-3] + "..." try: # Tokenize and truncate tokens = self.tokenizer.encode(content) if len(tokens) <= max_tokens: return content truncated_tokens = tokens[:max_tokens-3] # Leave room for "..." truncated_text = self.tokenizer.decode(truncated_tokens) return truncated_text + "..." except Exception as e: logger.warning(f"Error truncating with tokenizer: {e}, using character truncation") max_chars = max_tokens * 4 if len(content) <= max_chars: return content return content[:max_chars-3] + "..."