Commit 0747201
Parent(s): bd329bc

Replace Novita AI with ZeroGPU Chat API (RunPod)

- Replace Novita AI API integration with ZeroGPU Chat API
- Update llm_router.py to use aiohttp for HTTP requests with JWT authentication
- Add automatic token refresh and authentication handling
- Update config.py with ZeroGPU settings (base_url, email, password)
- Update ENV_EXAMPLE_CONTENT.txt with ZeroGPU configuration
- Update flask_api_standalone.py references
- Remove OpenAI dependency from requirements.txt
- Implement task type mapping (general_reasoning -> general, etc.)
- Add context conversion for API format compatibility
- ENV_EXAMPLE_CONTENT.txt +12 -21
- flask_api_standalone.py +15 -14
- requirements.txt +2 -2
- src/config.py +43 -50
- src/llm_router.py +264 -140
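
For reference, the flow these changes implement is: authenticate against the RunPod endpoint with email/password to obtain a JWT, then call /chat with a Bearer token and a task type. A minimal standalone sketch of that flow, using only the /login and /chat endpoints and the access_token/response JSON fields that appear in the diffs below (base URL and credentials are the same placeholders as in ENV_EXAMPLE_CONTENT.txt):

import asyncio
import aiohttp

BASE_URL = "http://your-pod-ip:8000"  # placeholder RunPod endpoint

async def main():
    async with aiohttp.ClientSession() as session:
        # Step 1: login with email/password to get a JWT access token
        async with session.post(f"{BASE_URL}/login", json={
            "email": "your-email@example.com",
            "password": "your_secure_password_here",
        }) as resp:
            resp.raise_for_status()
            token = (await resp.json()).get("access_token")

        # Step 2: call /chat with the Bearer token and a task type
        headers = {"Authorization": f"Bearer {token}"}
        async with session.post(f"{BASE_URL}/chat", json={
            "message": "Hello!",
            "task": "general",
            "max_tokens": 128,
        }, headers=headers) as resp:
            resp.raise_for_status()
            print((await resp.json()).get("response"))

asyncio.run(main())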
ENV_EXAMPLE_CONTENT.txt
CHANGED

@@ -5,27 +5,18 @@
 # Never commit .env to version control!
 
 # =============================================================================
-# Novita AI API Configuration (REQUIRED)
+# ZeroGPU Chat API Configuration (REQUIRED)
 # =============================================================================
-NOVITA_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B:de-1a706eeafbf3ebc2
-
-# =============================================================================
-# DeepSeek-R1 Optimized Settings
-# =============================================================================
-# Temperature: 0.5-0.7 range (0.6 recommended for DeepSeek-R1)
-DEEPSEEK_R1_TEMPERATURE=0.6
-
-# Force reasoning trigger: Enable to ensure DeepSeek-R1 uses reasoning pattern
-# Set to True to add <think> prefix for reasoning tasks
-DEEPSEEK_R1_FORCE_REASONING=True
+# Base URL for your ZeroGPU Chat API endpoint (RunPod)
+# Format: http://your-pod-ip:8000 or https://your-domain.com
+# Example: http://bm9njt1ypzvuqw-8000.proxy.runpod.net
+ZEROGPU_BASE_URL=http://your-pod-ip:8000
+
+# Email for authentication (register first via /register endpoint)
+ZEROGPU_EMAIL=your-email@example.com
+
+# Password for authentication
+ZEROGPU_PASSWORD=your_secure_password_here
 
 # =============================================================================
 # Token Allocation Configuration

@@ -45,10 +36,10 @@ CONTEXT_PRUNING_THRESHOLD=115000
 PRIORITIZE_USER_INPUT=True
 
 # Model context window (actual limit for your deployed model)
+# Default: 8192 tokens (adjust based on your model)
 # This is the maximum total tokens (input + output) the model can handle
+# Common values: 4096, 8192, 16384, 32768, etc.
+ZEROGPU_MODEL_CONTEXT_WINDOW=8192
 
 # =============================================================================
 # Database Configuration
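The email above must be registered with the API before the router can log in (see the /register references in this commit). A one-time registration sketch, assuming /register accepts the same email/password JSON body as /login (the endpoint name comes from the diff; the payload shape is an assumption):

import asyncio
import aiohttp

async def register(base_url: str, email: str, password: str) -> None:
    # One-time account creation; payload assumed to mirror /login
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/register",
                                json={"email": email, "password": password}) as resp:
            resp.raise_for_status()
            print(await resp.json())

asyncio.run(register("http://your-pod-ip:8000",
                     "your-email@example.com",
                     "your_secure_password_here"))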
flask_api_standalone.py
CHANGED

@@ -145,7 +145,7 @@ initialization_attempted = False
 initialization_error = None
 
 def initialize_orchestrator():
-    """Initialize the AI orchestrator with Novita AI API"""
+    """Initialize the AI orchestrator with ZeroGPU Chat API (RunPod)"""
     global orchestrator, orchestrator_available, initialization_attempted, initialization_error
 
     initialization_attempted = True

@@ -153,7 +153,7 @@ def initialize_orchestrator():
 
     try:
         logger.info("=" * 60)
-        logger.info("INITIALIZING AI ORCHESTRATOR (Novita AI API)")
+        logger.info("INITIALIZING AI ORCHESTRATOR (ZeroGPU Chat API - RunPod)")
         logger.info("=" * 60)
 
         from src.agents.intent_agent import create_intent_agent

@@ -166,16 +166,16 @@ def initialize_orchestrator():
 
         logger.info("✓ Imports successful")
 
-        # Initialize LLM Router - Novita AI API
-        logger.info("Initializing LLM Router (Novita AI API)...")
+        # Initialize LLM Router - ZeroGPU Chat API
+        logger.info("Initializing LLM Router (ZeroGPU Chat API)...")
         try:
-            # Always use Novita AI API (local models disabled)
+            # Always use ZeroGPU Chat API (local models disabled)
            llm_router = LLMRouter(hf_token=None, use_local_models=False)
-            logger.info("✓ LLM Router initialized (Novita AI API)")
+            logger.info("✓ LLM Router initialized (ZeroGPU Chat API)")
         except Exception as e:
             logger.error(f"❌ Failed to initialize LLM Router: {e}", exc_info=True)
-            logger.error("This is a critical error - Novita AI API is required")
-            logger.error("Please ensure NOVITA_API_KEY is set in environment variables")
+            logger.error("This is a critical error - ZeroGPU Chat API is required")
+            logger.error("Please ensure ZEROGPU_BASE_URL, ZEROGPU_EMAIL, and ZEROGPU_PASSWORD are set in environment variables")
             raise
 
         logger.info("Initializing Agents...")

@@ -210,24 +210,25 @@ def initialize_orchestrator():
         orchestrator_available = True
         logger.info("=" * 60)
         logger.info("✓ AI ORCHESTRATOR READY")
-        logger.info(" - Novita AI API enabled")
+        logger.info(" - ZeroGPU Chat API enabled")
         logger.info(" - MAX_WORKERS: 4")
         logger.info("=" * 60)
 
         return True
 
     except ValueError as e:
-        # Handle configuration errors (e.g., missing Novita API key)
-        if "NOVITA" in str(e) or "required" in str(e).lower():
+        # Handle configuration errors (e.g., missing ZeroGPU credentials)
+        if "ZEROGPU" in str(e) or "required" in str(e).lower():
             logger.error("=" * 60)
             logger.error("❌ CONFIGURATION ERROR")
             logger.error("=" * 60)
             logger.error(f"Error: {e}")
             logger.error("")
             logger.error("SOLUTION:")
-            logger.error("1. Set NOVITA_API_KEY in environment variables")
+            logger.error("1. Set ZEROGPU_BASE_URL in environment variables (e.g., http://your-pod-ip:8000)")
+            logger.error("2. Set ZEROGPU_EMAIL in environment variables")
+            logger.error("3. Set ZEROGPU_PASSWORD in environment variables")
+            logger.error("4. Register your account first via the /register endpoint if needed")
             logger.error("=" * 60)
             orchestrator_available = False
             initialization_error = f"Configuration Error: {str(e)}"
requirements.txt
CHANGED

@@ -107,6 +107,6 @@ debugpy>=1.7.0
 bandit>=1.7.5  # Security linter for Python code
 safety>=2.3.5  # Dependency vulnerability scanner
 
+# HTTP Client for ZeroGPU Chat API (aiohttp already included above)
+# Note: No OpenAI client needed - using direct HTTP requests
 
src/config.py
CHANGED

@@ -174,37 +174,24 @@ class Settings(BaseSettings):
 
         return self._cached_cache_dir
 
-    # ==================== Novita AI Configuration ====================
-
-    novita_api_key: str = Field(
-        default="",
-        env="NOVITA_API_KEY"
-    )
-
-    novita_base_url: str = Field(
-        default="https://api.novita.ai/dedicated/v1/openai",
-        description="Novita AI dedicated endpoint base URL",
-        env="NOVITA_BASE_URL"
-    )
-
-    novita_model: str = Field(
-        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B:de-1a706eeafbf3ebc2",
-        description="Novita AI dedicated endpoint model ID",
-        env="NOVITA_MODEL"
-    )
-
-    deepseek_r1_temperature: float = Field(
-        default=0.6,
-        env="DEEPSEEK_R1_TEMPERATURE"
-    )
-
-    deepseek_r1_force_reasoning: bool = Field(
-        default=True,
-        env="DEEPSEEK_R1_FORCE_REASONING"
-    )
+    # ==================== ZeroGPU Chat API Configuration ====================
+
+    zerogpu_base_url: str = Field(
+        default="http://your-pod-ip:8000",
+        description="ZeroGPU Chat API base URL (RunPod endpoint)",
+        env="ZEROGPU_BASE_URL"
+    )
+
+    zerogpu_email: str = Field(
+        default="",
+        description="ZeroGPU Chat API email for authentication (required)",
+        env="ZEROGPU_EMAIL"
+    )
+
+    zerogpu_password: str = Field(
+        default="",
+        description="ZeroGPU Chat API password for authentication (required)",
+        env="ZEROGPU_PASSWORD"
+    )
 
     # Token Allocation Configuration

@@ -233,34 +220,40 @@ class Settings(BaseSettings):
     )
 
     # Model Context Window Configuration
-    novita_model_context_window: int = Field(
-        description="Maximum context window for Novita AI model",
-        env="NOVITA_MODEL_CONTEXT_WINDOW"
-    )
+    zerogpu_model_context_window: int = Field(
+        default=8192,
+        description="Maximum context window for ZeroGPU Chat API model (input + output tokens). Adjust based on your deployed model.",
+        env="ZEROGPU_MODEL_CONTEXT_WINDOW"
+    )
 
-    @validator("novita_api_key", pre=True)
-    def validate_novita_api_key(cls, v):
-        """Validate Novita API key"""
+    @validator("zerogpu_base_url", pre=True)
+    def validate_zerogpu_base_url(cls, v):
+        """Validate ZeroGPU base URL"""
+        if v is None:
+            return "http://your-pod-ip:8000"
+        url = str(v).strip()
+        # Remove trailing slash
+        if url.endswith('/'):
+            url = url[:-1]
+        return url
+
+    @validator("zerogpu_email", pre=True)
+    def validate_zerogpu_email(cls, v):
+        """Validate ZeroGPU email"""
+        if v is None:
+            return ""
+        email = str(v).strip()
+        if email and '@' not in email:
+            logger.warning("ZEROGPU_EMAIL may not be a valid email address")
+        return email
+
+    @validator("zerogpu_password", pre=True)
+    def validate_zerogpu_password(cls, v):
+        """Validate ZeroGPU password"""
         if v is None:
             return ""
         return str(v).strip()
 
-    @validator("deepseek_r1_temperature", pre=True)
-    def validate_deepseek_temperature(cls, v):
-        """Validate DeepSeek-R1 temperature is in recommended range"""
-        if isinstance(v, str):
-            v = float(v)
-        temp = float(v) if v else 0.6
-        return max(0.5, min(0.7, temp))
-
-    @validator("deepseek_r1_force_reasoning", pre=True)
-    def validate_force_reasoning(cls, v):
-        """Convert string to boolean for force_reasoning"""
-        if isinstance(v, str):
-            return v.lower() in ("true", "1", "yes", "on")
-        return bool(v)
-
     @validator("user_input_max_tokens", pre=True)
     def validate_user_input_tokens(cls, v):
         """Validate user input token limit"""

@@ -279,10 +272,10 @@ class Settings(BaseSettings):
         val = int(v) if v else 115000
         return max(4000, min(125000, val))  # Match context_preparation_budget limits
 
-    @validator("novita_model_context_window", pre=True)
+    @validator("zerogpu_model_context_window", pre=True)
     def validate_context_window(cls, v):
         """Validate context window size"""
+        val = int(v) if v else 8192
         return max(1000, min(200000, val))  # Support up to 200K for future models
 
     # ==================== Model Configuration ====================
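A quick way to sanity-check the new settings is to set the variables and read them back; a sketch assuming the Settings/get_settings definitions above are importable as src.config (note that get_settings may cache the instance, so the environment must be set before the first call):

import os

# Normally provided via .env; set here for illustration
os.environ["ZEROGPU_BASE_URL"] = "http://your-pod-ip:8000/"   # trailing slash on purpose
os.environ["ZEROGPU_EMAIL"] = "your-email@example.com"
os.environ["ZEROGPU_PASSWORD"] = "your_secure_password_here"

from src.config import get_settings

settings = get_settings()
print(settings.zerogpu_base_url)              # validator strips the trailing slash
print(settings.zerogpu_model_context_window)  # 8192 unless ZEROGPU_MODEL_CONTEXT_WINDOW is set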
src/llm_router.py
CHANGED

@@ -1,67 +1,61 @@
-# llm_router.py - Novita AI API
+# llm_router.py - ZeroGPU Chat API (RunPod)
 import logging
 import asyncio
+import aiohttp
+import time
 from typing import Dict, Optional
 from .models_config import LLM_CONFIG
 from .config import get_settings
 
-# Import OpenAI client for Novita AI API
-try:
-    from openai import OpenAI
-    OPENAI_AVAILABLE = True
-except ImportError:
-    OPENAI_AVAILABLE = False
-    logger = logging.getLogger(__name__)
-    logger.error("openai package not available - Novita AI API requires openai package")
-
 logger = logging.getLogger(__name__)
 
 class LLMRouter:
     def __init__(self, hf_token=None, use_local_models: bool = False):
         """
-        Initialize LLM Router with Novita AI API.
+        Initialize LLM Router with ZeroGPU Chat API (RunPod).
 
         Args:
             hf_token: Not used (kept for backward compatibility)
             use_local_models: Must be False (local models disabled)
         """
         if use_local_models:
-            raise ValueError("Local models are disabled. Only Novita AI API is supported.")
+            raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")
 
         self.settings = get_settings()
-        self.novita_client = None
-
-        # Validate API key
-        if not self.settings.novita_api_key:
-            raise ValueError(
-                "NOVITA_API_KEY is required. "
-                "Set it in environment variables or .env file"
-            )
+        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
+        self.access_token = None
+        self.refresh_token = None
+        self.token_expires_at = 0
+        self.session = None
+
+        # Validate base URL
+        if not self.settings.zerogpu_base_url:
+            raise ValueError(
+                "ZEROGPU_BASE_URL is required. "
+                "Set it in environment variables or .env file"
+            )
+
+        # Validate credentials
+        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
+            raise ValueError(
+                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
+                "Set them in environment variables or .env file"
+            )
+
+        logger.info("ZeroGPU Chat API client initializing")
+        logger.info(f"Base URL: {self.base_url}")
 
+        # Initialize session and authenticate
         try:
-            self.novita_client = OpenAI(
-                base_url=self.settings.novita_base_url,
-                api_key=self.settings.novita_api_key,
-            )
-            logger.info("Novita AI API client initialized")
-            logger.info(f"Base URL: {self.settings.novita_base_url}")
-            logger.info(f"Model: {self.settings.novita_model}")
-            logger.info(f"Context Window: {self.settings.novita_model_context_window} tokens")
+            # Authentication will happen on first request if needed
+            logger.info("ZeroGPU Chat API client initialized (authentication on first request)")
         except Exception as e:
-            logger.error(f"Failed to initialize Novita AI client: {e}")
-            raise RuntimeError(f"Could not initialize Novita AI client: {e}")
+            logger.error(f"Failed to initialize ZeroGPU Chat API client: {e}")
+            raise RuntimeError(f"Could not initialize ZeroGPU Chat API client: {e}") from e
 
     async def route_inference(self, task_type: str, prompt: str, **kwargs):
         """
-        Route inference to Novita AI API.
+        Route inference to ZeroGPU Chat API.
 
         Args:
             task_type: Type of task (general_reasoning, intent_classification, etc.)

@@ -71,101 +65,200 @@ class LLMRouter:
         Returns:
             Generated text response
         """
-        logger.info(f"Routing inference to Novita AI API for task: {task_type}")
-
-        if not self.novita_client:
-            raise RuntimeError("Novita AI client not initialized")
+        logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}")
 
         try:
+            # Ensure authenticated
+            await self._ensure_authenticated()
+
+            # Map internal task types to API task types
+            api_task = self._map_task_type(task_type)
+
+            # Pass original task type for model config lookup
+            kwargs['original_task_type'] = task_type
+
             # Handle embedding generation (may need special handling)
             if task_type == "embedding_generation":
-                logger.warning("Embedding generation via Novita AI API may require special implementation")
-                result = await self._call_novita_api(task_type, prompt, **kwargs)
+                logger.warning("Embedding generation via ZeroGPU API may require special implementation")
+                result = await self._call_zerogpu_api(api_task, prompt, **kwargs)
             else:
-                result = await self._call_novita_api(task_type, prompt, **kwargs)
+                result = await self._call_zerogpu_api(api_task, prompt, **kwargs)
 
             if result is None:
-                logger.error(f"Novita AI API returned None for task: {task_type}")
+                logger.error(f"ZeroGPU Chat API returned None for task: {task_type}")
                 raise RuntimeError(f"Inference failed for task: {task_type}")
 
-            logger.info(f"Inference complete for {task_type} (Novita AI API)")
+            logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)")
             return result
 
         except Exception as e:
-            logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
+            logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True)
             raise RuntimeError(
                 f"Inference failed for task: {task_type}. "
-                f"Novita AI API error: {e}"
+                f"ZeroGPU Chat API error: {e}"
             ) from e
 
-    async def _call_novita_api(self, task_type: str, prompt: str, **kwargs):
-        """Call Novita AI API for inference"""
-        ...
-                    delta = chunk.choices[0].delta
-                    if delta and delta.content:
-                        response_text += delta.content
-            else:
-                # Handle non-streaming response
-                response = self.novita_client.chat.completions.create(**request_params)
-
-                if response.choices and len(response.choices) > 0:
-                    result = response.choices[0].message.content
-                    # Clean up reasoning tags if present
-                    result = self._clean_reasoning_tags(result)
-                    logger.info(f"Novita AI API generated response (length: {len(result)})")
-                    return result
-                else:
-                    return None
-
-        except Exception as e:
-            logger.error(f"Error calling Novita AI API: {e}")
-            raise
+    async def _ensure_authenticated(self):
+        """Ensure we have a valid access token, login if needed."""
+        # Check if token is expired (with 60 second buffer)
+        if self.access_token and time.time() < (self.token_expires_at - 60):
+            return
+
+        # Create session if needed
+        if self.session is None:
+            self.session = aiohttp.ClientSession()
+
+        # Login to get tokens
+        await self._login()
+
+    async def _login(self):
+        """Login to ZeroGPU Chat API and get access/refresh tokens."""
+        try:
+            login_url = f"{self.base_url}/login"
+            login_data = {
+                "email": self.settings.zerogpu_email,
+                "password": self.settings.zerogpu_password
+            }
+
+            async with self.session.post(login_url, json=login_data) as response:
+                if response.status == 401:
+                    raise ValueError("Invalid email or password for ZeroGPU Chat API")
+                response.raise_for_status()
+                data = await response.json()
+
+            self.access_token = data.get("access_token")
+            self.refresh_token = data.get("refresh_token")
+
+            # Access tokens typically expire in 15 minutes (900 seconds)
+            self.token_expires_at = time.time() + 900
+
+            logger.info("Successfully authenticated with ZeroGPU Chat API")
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Failed to login to ZeroGPU Chat API: {e}")
+            raise RuntimeError(f"Authentication failed: {e}") from e
+
+    async def _refresh_token(self):
+        """Refresh access token using refresh token."""
+        try:
+            refresh_url = f"{self.base_url}/refresh"
+            headers = {"X-Refresh-Token": self.refresh_token}
+
+            async with self.session.post(refresh_url, headers=headers) as response:
+                if response.status == 401:
+                    # Refresh token expired, need to login again
+                    await self._login()
+                    return
+
+                response.raise_for_status()
+                data = await response.json()
+
+                self.access_token = data.get("access_token")
+                self.refresh_token = data.get("refresh_token")
+                self.token_expires_at = time.time() + 900
+
+                logger.info("Successfully refreshed ZeroGPU Chat API token")
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Failed to refresh token: {e}")
+            # Try login as fallback
+            await self._login()
+
+    def _map_task_type(self, internal_task: str) -> str:
+        """Map internal task types to ZeroGPU Chat API task types."""
+        task_mapping = {
+            "general_reasoning": "general",
+            "response_synthesis": "general",
+            "intent_classification": "classification",
+            "safety_check": "classification",
+            "embedding_generation": "embedding"
+        }
+        return task_mapping.get(internal_task, "general")
+
+    async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]:
+        """Call ZeroGPU Chat API for inference."""
+        if not self.session:
+            self.session = aiohttp.ClientSession()
+
+        # Store original task type for model config lookup
+        original_task = kwargs.pop('original_task_type', None)
+
+        # Get model config for defaults
+        model_config = self._select_model(original_task or 'general_reasoning')
+
+        # Build request payload according to API documentation
+        payload = {
+            "message": prompt,
+            "task": task,
+            "max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)),
+            "temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)),
+            "top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)),
+        }
+
+        # Add optional parameters
+        if 'context' in kwargs and kwargs['context']:
+            # Convert context to API format if needed
+            context = kwargs['context']
+            if isinstance(context, list) and len(context) > 0:
+                # Convert to API format: list of dicts with role, content, timestamp
+                api_context = []
+                for item in context[:50]:  # Max 50 messages
+                    if isinstance(item, (list, tuple)) and len(item) >= 2:
+                        # Format: [user_msg, assistant_msg]
+                        api_context.append({
+                            "role": "user",
+                            "content": str(item[0]),
+                            "timestamp": kwargs.get('timestamp', time.time())
+                        })
+                        api_context.append({
+                            "role": "assistant",
+                            "content": str(item[1]),
+                            "timestamp": kwargs.get('timestamp', time.time())
+                        })
+                    elif isinstance(item, dict):
+                        api_context.append(item)
+                payload["context"] = api_context
+
+        if 'system_prompt' in kwargs and kwargs['system_prompt']:
+            payload["system_prompt"] = kwargs['system_prompt']
+        if 'repetition_penalty' in kwargs:
+            payload["repetition_penalty"] = kwargs['repetition_penalty']
+
+        # Prepare headers
+        headers = {
+            "Authorization": f"Bearer {self.access_token}",
+            "Content-Type": "application/json"
+        }
+
+        try:
+            chat_url = f"{self.base_url}/chat"
+
+            async with self.session.post(chat_url, json=payload, headers=headers) as response:
+                # Handle token expiration
+                if response.status == 401:
+                    logger.info("Token expired, refreshing...")
+                    await self._refresh_token()
+                    headers["Authorization"] = f"Bearer {self.access_token}"
+                    # Retry request
+                    async with self.session.post(chat_url, json=payload, headers=headers) as retry_response:
+                        retry_response.raise_for_status()
+                        data = await retry_response.json()
+                        return data.get("response")
+
+                response.raise_for_status()
+                data = await response.json()
+
+                # Extract response from API
+                result = data.get("response")
+                if result:
+                    logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})")
+                    return result
+                else:
+                    logger.error("ZeroGPU Chat API returned empty response")
+                    return None
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True)
+            raise
 
     def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int:

@@ -184,7 +277,7 @@ class LLMRouter:
         input_tokens = len(prompt) // 4
 
         # Get model context window from settings
-        context_window = self.settings.novita_model_context_window
+        context_window = self.settings.zerogpu_model_context_window
 
         logger.debug(
             f"Calculating safe max_tokens: input ~{input_tokens} tokens, "

@@ -209,26 +302,14 @@ class LLMRouter:
 
         return safe_max_tokens
 
+    def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
         """
-        Format prompt for DeepSeek-R1:
-
-        - Force reasoning trigger for reasoning tasks
-        - Add math directive for mathematical problems
+        Format prompt for ZeroGPU Chat API.
+        Can be customized based on model requirements.
         """
         formatted_prompt = prompt
 
-        force_reasoning = (
-            self.settings.deepseek_r1_force_reasoning and
-            model_config.get("force_reasoning_prefix", False)
-        )
-
-        if force_reasoning:
-            # Force model to start with reasoning trigger
-            formatted_prompt = f"<think>\n\n{formatted_prompt}"
-
-        # Add math directive for mathematical problems
+        # Add math directive for mathematical problems if needed
         if self._is_math_query(prompt):
             math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
             formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"

@@ -246,7 +327,11 @@ class LLMRouter:
         return any(keyword in prompt_lower for keyword in math_keywords)
 
     def _clean_reasoning_tags(self, text: str) -> str:
-        """Clean up reasoning tags from response"""
+        """Clean up reasoning tags from response if present"""
+        if not text:
+            return text
+        # Remove common reasoning tags if present
+        text = text.replace("<think>", "").replace("</think>", "")
         text = text.replace("<think>", "").replace("</think>", "")
         text = text.strip()
         return text

@@ -263,33 +348,72 @@ class LLMRouter:
         return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
 
     async def get_available_models(self):
-        """Get list of available models from Novita AI API"""
+        """Get list of available models from ZeroGPU Chat API"""
+        try:
+            await self._ensure_authenticated()
+            if not self.session:
+                self.session = aiohttp.ClientSession()
+
+            tasks_url = f"{self.base_url}/tasks"
+            headers = {"Authorization": f"Bearer {self.access_token}"}
+
+            async with self.session.get(tasks_url, headers=headers) as response:
+                if response.status == 401:
+                    await self._refresh_token()
+                    headers["Authorization"] = f"Bearer {self.access_token}"
+                    async with self.session.get(tasks_url, headers=headers) as retry_response:
+                        retry_response.raise_for_status()
+                        data = await retry_response.json()
+                else:
+                    response.raise_for_status()
+                    data = await response.json()
+
+            tasks = data.get("tasks", {})
+            models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}"
+                      for task, info in tasks.items()]
+            return models if models else ["ZeroGPU Chat API"]
+        except Exception as e:
+            logger.error(f"Failed to get available models: {e}")
+            return ["ZeroGPU Chat API"]
 
     async def health_check(self):
-        """Perform health check on Novita AI API"""
+        """Perform health check on ZeroGPU Chat API"""
         try:
-            response = self.novita_client.chat.completions.create(
-                model=self.settings.novita_model,
-                messages=[{"role": "user", "content": "test"}],
-                max_tokens=5
-            )
+            if not self.session:
+                self.session = aiohttp.ClientSession()
 
+            # Check health endpoint (no auth required)
+            health_url = f"{self.base_url}/health"
+            async with self.session.get(health_url) as response:
+                response.raise_for_status()
+                data = await response.json()
+
+            return {
+                "provider": "zerogpu_chat_api",
+                "status": "healthy" if data.get("status") == "healthy" else "unhealthy",
+                "models_ready": data.get("models_ready", False),
+                "base_url": self.base_url
+            }
         except Exception as e:
             logger.error(f"Health check failed: {e}")
             return {
-                "provider": "novita_ai",
+                "provider": "zerogpu_chat_api",
                 "status": "unhealthy",
                 "error": str(e)
             }
 
+    async def __aenter__(self):
+        """Async context manager entry"""
+        if not self.session:
+            self.session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self.session:
+            await self.session.close()
+            self.session = None
+
     def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
                                 user_input: Optional[str] = None) -> str:
         """
|