Commit 0747201 · 1 Parent(s): bd329bc
JatsTheAIGen committed

Replace Novita AI with ZeroGPU Chat API (RunPod)

- Replace Novita AI API integration with ZeroGPU Chat API
- Update llm_router.py to use aiohttp for HTTP requests with JWT authentication
- Add automatic token refresh and authentication handling (request flow sketched below)
- Update config.py with ZeroGPU settings (base_url, email, password)
- Update ENV_EXAMPLE_CONTENT.txt with ZeroGPU configuration
- Update flask_api_standalone.py initialization logging and error messages
- Remove OpenAI dependency from requirements.txt
- Implement task type mapping (general_reasoning -> general, etc.)
- Add context conversion for API format compatibility
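
For orientation, a minimal sketch of the request flow the new router implements: login for a JWT, then a bearer-authenticated chat call. It assumes the /login and /chat endpoints and the JSON field names shown in the src/llm_router.py diff below; it is not a drop-in client.

import asyncio
import aiohttp

async def demo(base_url: str, email: str, password: str) -> str:
    async with aiohttp.ClientSession() as session:
        # Step 1: exchange email/password for a JWT access token
        async with session.post(f"{base_url}/login",
                                json={"email": email, "password": password}) as resp:
            resp.raise_for_status()
            access_token = (await resp.json()).get("access_token")

        # Step 2: call /chat with the bearer token; "task" selects the model route
        payload = {"message": "Hello!", "task": "general", "max_tokens": 128}
        headers = {"Authorization": f"Bearer {access_token}"}
        async with session.post(f"{base_url}/chat", json=payload, headers=headers) as resp:
            resp.raise_for_status()
            return (await resp.json()).get("response")

# asyncio.run(demo("http://your-pod-ip:8000", "you@example.com", "secret"))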

ENV_EXAMPLE_CONTENT.txt CHANGED
@@ -5,27 +5,18 @@
 # Never commit .env to version control!
 
 # =============================================================================
-# Novita AI Configuration (REQUIRED)
+# ZeroGPU Chat API Configuration (REQUIRED)
 # =============================================================================
-# Get your API key from: https://novita.ai
-NOVITA_API_KEY=your_novita_api_key_here
+# Base URL for your ZeroGPU Chat API endpoint (RunPod)
+# Format: http://your-pod-ip:8000 or https://your-domain.com
+# Example: http://bm9njt1ypzvuqw-8000.proxy.runpod.net
+ZEROGPU_BASE_URL=http://your-pod-ip:8000
 
-# Dedicated endpoint base URL (default for dedicated endpoints)
-NOVITA_BASE_URL=https://api.novita.ai/dedicated/v1/openai
+# Email for authentication (register first via /register endpoint)
+ZEROGPU_EMAIL=your-email@example.com
 
-# Your dedicated endpoint model ID
-# Format: model-name:endpoint-id
-NOVITA_MODEL=deepseek-ai/DeepSeek-R1-Distill-Qwen-7B:de-1a706eeafbf3ebc2
-
-# =============================================================================
-# DeepSeek-R1 Optimized Settings
-# =============================================================================
-# Temperature: 0.5-0.7 range (0.6 recommended for DeepSeek-R1)
-DEEPSEEK_R1_TEMPERATURE=0.6
-
-# Force reasoning trigger: Enable to ensure DeepSeek-R1 uses reasoning pattern
-# Set to True to add <think> prefix for reasoning tasks
-DEEPSEEK_R1_FORCE_REASONING=True
+# Password for authentication
+ZEROGPU_PASSWORD=your_secure_password_here
 
 # =============================================================================
 # Token Allocation Configuration
@@ -45,10 +36,10 @@ CONTEXT_PRUNING_THRESHOLD=115000
 PRIORITIZE_USER_INPUT=True
 
 # Model context window (actual limit for your deployed model)
-# Default: 128000 tokens for DeepSeek R1 (128K context window)
+# Default: 8192 tokens (adjust based on your model)
 # This is the maximum total tokens (input + output) the model can handle
-# Take full advantage of DeepSeek R1's 128K capability
-NOVITA_MODEL_CONTEXT_WINDOW=128000
+# Common values: 4096, 8192, 16384, 32768, etc.
+ZEROGPU_MODEL_CONTEXT_WINDOW=8192
 
 # =============================================================================
 # Database Configuration
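
The comments above say to register before logging in. A hypothetical one-time registration call follows; the /register endpoint is only referenced by name in this commit, so the email/password request schema is an assumption.

import asyncio
import aiohttp

async def register(base_url: str, email: str, password: str) -> None:
    # POST /register with the same credential fields used by /login
    # (assumed schema; the diff only mentions the endpoint by name)
    async with aiohttp.ClientSession() as session:
        async with session.post(f"{base_url}/register",
                                json={"email": email, "password": password}) as resp:
            resp.raise_for_status()
            print(await resp.json())

# asyncio.run(register("http://your-pod-ip:8000", "your-email@example.com", "your_secure_password_here"))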
flask_api_standalone.py CHANGED
@@ -145,7 +145,7 @@ initialization_attempted = False
 initialization_error = None
 
 def initialize_orchestrator():
-    """Initialize the AI orchestrator with Novita AI API only"""
+    """Initialize the AI orchestrator with ZeroGPU Chat API (RunPod)"""
     global orchestrator, orchestrator_available, initialization_attempted, initialization_error
 
     initialization_attempted = True
@@ -153,7 +153,7 @@ def initialize_orchestrator():
 
     try:
         logger.info("=" * 60)
-        logger.info("INITIALIZING AI ORCHESTRATOR (Novita AI API Only)")
+        logger.info("INITIALIZING AI ORCHESTRATOR (ZeroGPU Chat API - RunPod)")
         logger.info("=" * 60)
 
         from src.agents.intent_agent import create_intent_agent
@@ -166,16 +166,16 @@ def initialize_orchestrator():
 
         logger.info("✓ Imports successful")
 
-        # Initialize LLM Router - Novita AI API only
-        logger.info("Initializing LLM Router (Novita AI API only)...")
+        # Initialize LLM Router - ZeroGPU Chat API
+        logger.info("Initializing LLM Router (ZeroGPU Chat API)...")
         try:
-            # Always use Novita AI API (local models disabled)
+            # Always use ZeroGPU Chat API (local models disabled)
             llm_router = LLMRouter(hf_token=None, use_local_models=False)
-            logger.info("✓ LLM Router initialized (Novita AI API)")
+            logger.info("✓ LLM Router initialized (ZeroGPU Chat API)")
         except Exception as e:
             logger.error(f"❌ Failed to initialize LLM Router: {e}", exc_info=True)
-            logger.error("This is a critical error - Novita AI API is required")
-            logger.error("Please ensure NOVITA_API_KEY is set in environment variables")
+            logger.error("This is a critical error - ZeroGPU Chat API is required")
+            logger.error("Please ensure ZEROGPU_BASE_URL, ZEROGPU_EMAIL, and ZEROGPU_PASSWORD are set in environment variables")
             raise
 
         logger.info("Initializing Agents...")
@@ -210,24 +210,25 @@ def initialize_orchestrator():
        orchestrator_available = True
        logger.info("=" * 60)
        logger.info("✓ AI ORCHESTRATOR READY")
-       logger.info(" - Novita AI API enabled")
+       logger.info(" - ZeroGPU Chat API enabled")
        logger.info(" - MAX_WORKERS: 4")
        logger.info("=" * 60)
 
        return True
 
    except ValueError as e:
-        # Handle configuration errors (e.g., missing NOVITA_API_KEY)
-        if "NOVITA_API_KEY" in str(e) or "required" in str(e).lower():
+        # Handle configuration errors (e.g., missing ZeroGPU credentials)
+        if "ZEROGPU" in str(e) or "required" in str(e).lower():
            logger.error("=" * 60)
            logger.error("❌ CONFIGURATION ERROR")
            logger.error("=" * 60)
            logger.error(f"Error: {e}")
            logger.error("")
            logger.error("SOLUTION:")
-            logger.error("1. Set NOVITA_API_KEY in environment variables")
-            logger.error("2. Ensure NOVITA_BASE_URL is correct")
-            logger.error("3. Verify NOVITA_MODEL matches your endpoint")
+            logger.error("1. Set ZEROGPU_BASE_URL in environment variables (e.g., http://your-pod-ip:8000)")
+            logger.error("2. Set ZEROGPU_EMAIL in environment variables")
+            logger.error("3. Set ZEROGPU_PASSWORD in environment variables")
+            logger.error("4. Register your account first via the /register endpoint if needed")
            logger.error("=" * 60)
            orchestrator_available = False
            initialization_error = f"Configuration Error: {str(e)}"
requirements.txt CHANGED
@@ -107,6 +107,6 @@ debugpy>=1.7.0
 bandit>=1.7.5 # Security linter for Python code
 safety>=2.3.5 # Dependency vulnerability scanner
 
-# LLM API Client (required for Novita AI API)
-openai>=1.0.0
+# HTTP Client for ZeroGPU Chat API (aiohttp already included above)
+# Note: No OpenAI client needed - using direct HTTP requests
 
src/config.py CHANGED
@@ -174,37 +174,24 @@ class Settings(BaseSettings):
 
         return self._cached_cache_dir
 
-    # ==================== Novita AI Configuration ====================
+    # ==================== ZeroGPU Chat API Configuration ====================
 
-    novita_api_key: str = Field(
-        default="",
-        description="Novita AI API key (required)",
-        env="NOVITA_API_KEY"
-    )
-
-    novita_base_url: str = Field(
-        default="https://api.novita.ai/dedicated/v1/openai",
-        description="Novita AI dedicated endpoint base URL",
-        env="NOVITA_BASE_URL"
-    )
-
-    novita_model: str = Field(
-        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B:de-1a706eeafbf3ebc2",
-        description="Novita AI dedicated endpoint model ID",
-        env="NOVITA_MODEL"
+    zerogpu_base_url: str = Field(
+        default="http://your-pod-ip:8000",
+        description="ZeroGPU Chat API base URL (RunPod endpoint)",
+        env="ZEROGPU_BASE_URL"
     )
 
-    # DeepSeek-R1 optimized settings
-    deepseek_r1_temperature: float = Field(
-        default=0.6,
-        description="Temperature for DeepSeek-R1 models (0.5-0.7 range, 0.6 recommended)",
-        env="DEEPSEEK_R1_TEMPERATURE"
+    zerogpu_email: str = Field(
+        default="",
+        description="ZeroGPU Chat API email for authentication (required)",
+        env="ZEROGPU_EMAIL"
     )
 
-    deepseek_r1_force_reasoning: bool = Field(
-        default=True,
-        description="Force DeepSeek-R1 to start with reasoning trigger",
-        env="DEEPSEEK_R1_FORCE_REASONING"
+    zerogpu_password: str = Field(
+        default="",
+        description="ZeroGPU Chat API password for authentication (required)",
+        env="ZEROGPU_PASSWORD"
     )
 
     # Token Allocation Configuration
@@ -233,34 +220,40 @@ class Settings(BaseSettings):
     )
 
     # Model Context Window Configuration
-    novita_model_context_window: int = Field(
-        default=128000,
-        description="Maximum context window for Novita AI model (input + output tokens). DeepSeek R1 supports 128K tokens.",
-        env="NOVITA_MODEL_CONTEXT_WINDOW"
+    zerogpu_model_context_window: int = Field(
+        default=8192,
+        description="Maximum context window for ZeroGPU Chat API model (input + output tokens). Adjust based on your deployed model.",
+        env="ZEROGPU_MODEL_CONTEXT_WINDOW"
     )
 
-    @validator("novita_api_key", pre=True)
-    def validate_novita_api_key(cls, v):
-        """Validate and clean Novita API key"""
+    @validator("zerogpu_base_url", pre=True)
+    def validate_zerogpu_base_url(cls, v):
+        """Validate ZeroGPU base URL"""
+        if v is None:
+            return "http://your-pod-ip:8000"
+        url = str(v).strip()
+        # Remove trailing slash
+        if url.endswith('/'):
+            url = url[:-1]
+        return url
+
+    @validator("zerogpu_email", pre=True)
+    def validate_zerogpu_email(cls, v):
+        """Validate ZeroGPU email"""
+        if v is None:
+            return ""
+        email = str(v).strip()
+        if email and '@' not in email:
+            logger.warning("ZEROGPU_EMAIL may not be a valid email address")
+        return email
+
+    @validator("zerogpu_password", pre=True)
+    def validate_zerogpu_password(cls, v):
+        """Validate ZeroGPU password"""
         if v is None:
             return ""
         return str(v).strip()
 
-    @validator("deepseek_r1_temperature", pre=True)
-    def validate_deepseek_temperature(cls, v):
-        """Validate DeepSeek-R1 temperature is in recommended range"""
-        if isinstance(v, str):
-            v = float(v)
-        temp = float(v) if v else 0.6
-        return max(0.5, min(0.7, temp))
-
-    @validator("deepseek_r1_force_reasoning", pre=True)
-    def validate_force_reasoning(cls, v):
-        """Convert string to boolean for force_reasoning"""
-        if isinstance(v, str):
-            return v.lower() in ("true", "1", "yes", "on")
-        return bool(v)
-
     @validator("user_input_max_tokens", pre=True)
     def validate_user_input_tokens(cls, v):
         """Validate user input token limit"""
@@ -279,10 +272,10 @@ class Settings(BaseSettings):
         val = int(v) if v else 115000
         return max(4000, min(125000, val))  # Match context_preparation_budget limits
 
-    @validator("novita_model_context_window", pre=True)
+    @validator("zerogpu_model_context_window", pre=True)
     def validate_context_window(cls, v):
         """Validate context window size"""
-        val = int(v) if v else 128000
+        val = int(v) if v else 8192
         return max(1000, min(200000, val))  # Support up to 200K for future models
 
     # ==================== Model Configuration ====================
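
To illustrate what the new validators do at load time, a small sketch (assumes the ZEROGPU_* variables are exported before src.config is imported and that get_settings reads a fresh environment):

import os

# Set before importing, since pydantic reads the environment at load time
os.environ["ZEROGPU_BASE_URL"] = "http://pod:8000/"        # note trailing slash
os.environ["ZEROGPU_EMAIL"] = "user@example.com"
os.environ["ZEROGPU_PASSWORD"] = "secret"
os.environ["ZEROGPU_MODEL_CONTEXT_WINDOW"] = "500"         # below the 1000 floor

from src.config import get_settings

settings = get_settings()
print(settings.zerogpu_base_url)              # "http://pod:8000" (trailing slash stripped)
print(settings.zerogpu_model_context_window)  # 1000 (clamped to [1000, 200000])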
src/llm_router.py CHANGED
@@ -1,67 +1,61 @@
-# llm_router.py - NOVITA AI API ONLY
+# llm_router.py - ZeroGPU Chat API (RunPod)
 import logging
 import asyncio
+import aiohttp
+import time
 from typing import Dict, Optional
 from .models_config import LLM_CONFIG
 from .config import get_settings
 
-# Import OpenAI client for Novita AI API
-try:
-    from openai import OpenAI
-    OPENAI_AVAILABLE = True
-except ImportError:
-    OPENAI_AVAILABLE = False
-    logger = logging.getLogger(__name__)
-    logger.error("openai package not available - Novita AI API requires openai package")
-
 logger = logging.getLogger(__name__)
 
 class LLMRouter:
     def __init__(self, hf_token=None, use_local_models: bool = False):
         """
-        Initialize LLM Router with Novita AI API only.
+        Initialize LLM Router with ZeroGPU Chat API (RunPod).
 
         Args:
             hf_token: Not used (kept for backward compatibility)
             use_local_models: Must be False (local models disabled)
         """
         if use_local_models:
-            raise ValueError("Local models are disabled. Only Novita AI API is supported.")
+            raise ValueError("Local models are disabled. Only ZeroGPU Chat API is supported.")
 
         self.settings = get_settings()
-        self.novita_client = None
-
-        # Validate OpenAI package
-        if not OPENAI_AVAILABLE:
-            raise ImportError(
-                "openai package is required for Novita AI API. "
-                "Install it with: pip install openai>=1.0.0"
+        self.base_url = self.settings.zerogpu_base_url.rstrip('/')
+        self.access_token = None
+        self.refresh_token = None
+        self.token_expires_at = 0
+        self.session = None
+
+        # Validate base URL
+        if not self.settings.zerogpu_base_url:
+            raise ValueError(
+                "ZEROGPU_BASE_URL is required. "
+                "Set it in environment variables or .env file"
             )
 
-        # Validate API key
-        if not self.settings.novita_api_key:
+        # Validate credentials
+        if not self.settings.zerogpu_email or not self.settings.zerogpu_password:
             raise ValueError(
-                "NOVITA_API_KEY is required. "
-                "Set it in environment variables or .env file"
+                "ZEROGPU_EMAIL and ZEROGPU_PASSWORD are required. "
+                "Set them in environment variables or .env file"
             )
 
-        # Initialize Novita AI client
+        logger.info("ZeroGPU Chat API client initializing")
+        logger.info(f"Base URL: {self.base_url}")
+
+        # Initialize session and authenticate
         try:
-            self.novita_client = OpenAI(
-                base_url=self.settings.novita_base_url,
-                api_key=self.settings.novita_api_key,
-            )
-            logger.info("Novita AI API client initialized")
-            logger.info(f"Base URL: {self.settings.novita_base_url}")
-            logger.info(f"Model: {self.settings.novita_model}")
-            logger.info(f"Context Window: {self.settings.novita_model_context_window} tokens")
+            # Authentication will happen on first request if needed
+            logger.info("ZeroGPU Chat API client initialized (authentication on first request)")
         except Exception as e:
-            logger.error(f"Failed to initialize Novita AI client: {e}")
-            raise RuntimeError(f"Could not initialize Novita AI API client: {e}") from e
+            logger.error(f"Failed to initialize ZeroGPU Chat API client: {e}")
+            raise RuntimeError(f"Could not initialize ZeroGPU Chat API client: {e}") from e
 
     async def route_inference(self, task_type: str, prompt: str, **kwargs):
         """
-        Route inference to Novita AI API.
+        Route inference to ZeroGPU Chat API.
 
         Args:
             task_type: Type of task (general_reasoning, intent_classification, etc.)
@@ -71,101 +65,200 @@ class LLMRouter:
         Returns:
             Generated text response
         """
-        logger.info(f"Routing inference to Novita AI API for task: {task_type}")
-
-        if not self.novita_client:
-            raise RuntimeError("Novita AI client not initialized")
+        logger.info(f"Routing inference to ZeroGPU Chat API for task: {task_type}")
 
         try:
+            # Ensure authenticated
+            await self._ensure_authenticated()
+
+            # Map internal task types to API task types
+            api_task = self._map_task_type(task_type)
+
+            # Pass original task type for model config lookup
+            kwargs['original_task_type'] = task_type
+
             # Handle embedding generation (may need special handling)
             if task_type == "embedding_generation":
-                logger.warning("Embedding generation via Novita API may require special implementation")
-                # For now, use chat completion (may need adjustment based on Novita API capabilities)
-                result = await self._call_novita_api(task_type, prompt, **kwargs)
+                logger.warning("Embedding generation via ZeroGPU API may require special implementation")
+                result = await self._call_zerogpu_api(api_task, prompt, **kwargs)
             else:
-                result = await self._call_novita_api(task_type, prompt, **kwargs)
+                result = await self._call_zerogpu_api(api_task, prompt, **kwargs)
 
             if result is None:
-                logger.error(f"Novita AI API returned None for task: {task_type}")
+                logger.error(f"ZeroGPU Chat API returned None for task: {task_type}")
                 raise RuntimeError(f"Inference failed for task: {task_type}")
 
-            logger.info(f"Inference complete for {task_type} (Novita AI API)")
+            logger.info(f"Inference complete for {task_type} (ZeroGPU Chat API)")
            return result
 
        except Exception as e:
-            logger.error(f"Novita AI API inference failed: {e}", exc_info=True)
+            logger.error(f"ZeroGPU Chat API inference failed: {e}", exc_info=True)
            raise RuntimeError(
                f"Inference failed for task: {task_type}. "
-                f"Novita AI API error: {e}"
+                f"ZeroGPU Chat API error: {e}"
            ) from e
 
-    async def _call_novita_api(self, task_type: str, prompt: str, **kwargs) -> Optional[str]:
-        """Call Novita AI API for inference."""
-        if not self.novita_client:
-            return None
-
-        # Get model config
-        model_config = self._select_model(task_type)
-        model_name = kwargs.get('model', self.settings.novita_model)
-
-        # Get optimized parameters
-        requested_max_tokens = kwargs.get('max_tokens', model_config.get('max_tokens', 4096))
-        temperature = kwargs.get('temperature',
-            model_config.get('temperature', self.settings.deepseek_r1_temperature))
-        top_p = kwargs.get('top_p', model_config.get('top_p', 0.95))
-        stream = kwargs.get('stream', False)
-
-        # Format prompt according to DeepSeek-R1 best practices
-        formatted_prompt = self._format_deepseek_r1_prompt(prompt, task_type, model_config)
-
-        # IMPORTANT: Calculate safe max_tokens based on input size
-        max_tokens = self._calculate_safe_max_tokens(formatted_prompt, requested_max_tokens)
-
-        # IMPORTANT: No system prompt - all instructions in user prompt
-        messages = [{"role": "user", "content": formatted_prompt}]
-
-        # Build request parameters
-        request_params = {
-            "model": model_name,
-            "messages": messages,
-            "stream": stream,
-            "max_tokens": max_tokens,
-            "temperature": temperature,
-            "top_p": top_p,
+    async def _ensure_authenticated(self):
+        """Ensure we have a valid access token, login if needed."""
+        # Check if token is expired (with 60 second buffer)
+        if self.access_token and time.time() < (self.token_expires_at - 60):
+            return
+
+        # Create session if needed
+        if self.session is None:
+            self.session = aiohttp.ClientSession()
+
+        # Login to get tokens
+        await self._login()
+
+    async def _login(self):
+        """Login to ZeroGPU Chat API and get access/refresh tokens."""
+        try:
+            login_url = f"{self.base_url}/login"
+            login_data = {
+                "email": self.settings.zerogpu_email,
+                "password": self.settings.zerogpu_password
+            }
+
+            async with self.session.post(login_url, json=login_data) as response:
+                if response.status == 401:
+                    raise ValueError("Invalid email or password for ZeroGPU Chat API")
+                response.raise_for_status()
+                data = await response.json()
+
+                self.access_token = data.get("access_token")
+                self.refresh_token = data.get("refresh_token")
+
+                # Access tokens typically expire in 15 minutes (900 seconds)
+                self.token_expires_at = time.time() + 900
+
+                logger.info("Successfully authenticated with ZeroGPU Chat API")
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Failed to login to ZeroGPU Chat API: {e}")
+            raise RuntimeError(f"Authentication failed: {e}") from e
+
+    async def _refresh_token(self):
+        """Refresh access token using refresh token."""
+        try:
+            refresh_url = f"{self.base_url}/refresh"
+            headers = {"X-Refresh-Token": self.refresh_token}
+
+            async with self.session.post(refresh_url, headers=headers) as response:
+                if response.status == 401:
+                    # Refresh token expired, need to login again
+                    await self._login()
+                    return
+
+                response.raise_for_status()
+                data = await response.json()
+
+                self.access_token = data.get("access_token")
+                self.refresh_token = data.get("refresh_token")
+                self.token_expires_at = time.time() + 900
+
+                logger.info("Successfully refreshed ZeroGPU Chat API token")
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Failed to refresh token: {e}")
+            # Try login as fallback
+            await self._login()
+
+    def _map_task_type(self, internal_task: str) -> str:
+        """Map internal task types to ZeroGPU Chat API task types."""
+        task_mapping = {
+            "general_reasoning": "general",
+            "response_synthesis": "general",
+            "intent_classification": "classification",
+            "safety_check": "classification",
+            "embedding_generation": "embedding"
+        }
+        return task_mapping.get(internal_task, "general")
+
+    async def _call_zerogpu_api(self, task: str, prompt: str, **kwargs) -> Optional[str]:
+        """Call ZeroGPU Chat API for inference."""
+        if not self.session:
+            self.session = aiohttp.ClientSession()
+
+        # Store original task type for model config lookup
+        original_task = kwargs.pop('original_task_type', None)
+
+        # Get model config for defaults
+        model_config = self._select_model(original_task or 'general_reasoning')
+
+        # Build request payload according to API documentation
+        payload = {
+            "message": prompt,
+            "task": task,
+            "max_tokens": kwargs.get('max_tokens', model_config.get('max_tokens', 512)),
+            "temperature": kwargs.get('temperature', model_config.get('temperature', 0.7)),
+            "top_p": kwargs.get('top_p', model_config.get('top_p', 0.9)),
+        }
+
+        # Add optional parameters
+        if 'context' in kwargs and kwargs['context']:
+            # Convert context to API format if needed
+            context = kwargs['context']
+            if isinstance(context, list) and len(context) > 0:
+                # Convert to API format: list of dicts with role, content, timestamp
+                api_context = []
+                for item in context[:50]:  # Max 50 messages
+                    if isinstance(item, (list, tuple)) and len(item) >= 2:
+                        # Format: [user_msg, assistant_msg]
+                        api_context.append({
+                            "role": "user",
+                            "content": str(item[0]),
+                            "timestamp": kwargs.get('timestamp', time.time())
+                        })
+                        api_context.append({
+                            "role": "assistant",
+                            "content": str(item[1]),
+                            "timestamp": kwargs.get('timestamp', time.time())
+                        })
+                    elif isinstance(item, dict):
+                        api_context.append(item)
+                payload["context"] = api_context
+
+        if 'system_prompt' in kwargs and kwargs['system_prompt']:
+            payload["system_prompt"] = kwargs['system_prompt']
+        if 'repetition_penalty' in kwargs:
+            payload["repetition_penalty"] = kwargs['repetition_penalty']
+
+        # Prepare headers
+        headers = {
+            "Authorization": f"Bearer {self.access_token}",
+            "Content-Type": "application/json"
         }
 
        try:
-            if stream:
-                # Handle streaming response
-                response_text = ""
-                stream_response = self.novita_client.chat.completions.create(**request_params)
+            chat_url = f"{self.base_url}/chat"
+
+            async with self.session.post(chat_url, json=payload, headers=headers) as response:
+                # Handle token expiration
+                if response.status == 401:
+                    logger.info("Token expired, refreshing...")
+                    await self._refresh_token()
+                    headers["Authorization"] = f"Bearer {self.access_token}"
+                    # Retry request
+                    async with self.session.post(chat_url, json=payload, headers=headers) as retry_response:
+                        retry_response.raise_for_status()
+                        data = await retry_response.json()
+                        return data.get("response")
 
-                for chunk in stream_response:
-                    if chunk.choices and len(chunk.choices) > 0:
-                        delta = chunk.choices[0].delta
-                        if delta and delta.content:
-                            response_text += delta.content
+                response.raise_for_status()
+                data = await response.json()
 
-                # Clean up reasoning tags if present
-                response_text = self._clean_reasoning_tags(response_text)
-                logger.info(f"Novita AI API generated response (length: {len(response_text)})")
-                return response_text
-            else:
-                # Handle non-streaming response
-                response = self.novita_client.chat.completions.create(**request_params)
-
-                if response.choices and len(response.choices) > 0:
-                    result = response.choices[0].message.content
-                    # Clean up reasoning tags if present
-                    result = self._clean_reasoning_tags(result)
-                    logger.info(f"Novita AI API generated response (length: {len(result)})")
+                # Extract response from API
+                result = data.get("response")
+                if result:
+                    logger.info(f"ZeroGPU Chat API generated response (length: {len(result)})")
                    return result
                else:
-                    logger.error("Novita AI API returned empty response")
+                    logger.error("ZeroGPU Chat API returned empty response")
                    return None
-
-        except Exception as e:
-            logger.error(f"Error calling Novita AI API: {e}", exc_info=True)
+
+        except aiohttp.ClientError as e:
+            logger.error(f"Error calling ZeroGPU Chat API: {e}", exc_info=True)
            raise
 
    def _calculate_safe_max_tokens(self, prompt: str, requested_max_tokens: int) -> int:
@@ -184,7 +277,7 @@ class LLMRouter:
        input_tokens = len(prompt) // 4
 
        # Get model context window from settings
-        context_window = self.settings.novita_model_context_window
+        context_window = self.settings.zerogpu_model_context_window
 
        logger.debug(
            f"Calculating safe max_tokens: input ~{input_tokens} tokens, "
@@ -209,26 +302,14 @@ class LLMRouter:
 
        return safe_max_tokens
 
-    def _format_deepseek_r1_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
+    def _format_prompt(self, prompt: str, task_type: str, model_config: dict) -> str:
        """
-        Format prompt according to DeepSeek-R1 best practices:
-        - No system prompt (all instructions in user prompt)
-        - Force reasoning trigger for reasoning tasks
-        - Add math directive for mathematical problems
+        Format prompt for ZeroGPU Chat API.
+        Can be customized based on model requirements.
        """
        formatted_prompt = prompt
 
-        # Check if we should force reasoning prefix
-        force_reasoning = (
-            self.settings.deepseek_r1_force_reasoning and
-            model_config.get("force_reasoning_prefix", False)
-        )
-
-        if force_reasoning:
-            # Force model to start with reasoning trigger
-            formatted_prompt = f"<think>\n\n{formatted_prompt}"
-
-        # Add math directive for mathematical problems
+        # Add math directive for mathematical problems if needed
        if self._is_math_query(prompt):
            math_directive = "Please reason step by step, and put your final answer within \\boxed{}."
            formatted_prompt = f"{formatted_prompt}\n\n{math_directive}"
@@ -246,7 +327,11 @@ class LLMRouter:
        return any(keyword in prompt_lower for keyword in math_keywords)
 
    def _clean_reasoning_tags(self, text: str) -> str:
-        """Clean up reasoning tags from response"""
+        """Clean up reasoning tags from response if present"""
+        if not text:
+            return text
+        # Remove common reasoning tags if present
+        text = text.replace("<think>", "").replace("</think>", "")
        text = text.replace("<think>", "").replace("</think>", "")
        text = text.strip()
        return text
@@ -263,33 +348,72 @@ class LLMRouter:
        return model_map.get(task_type, LLM_CONFIG["models"]["reasoning_primary"])
 
    async def get_available_models(self):
-        """Get list of available models (Novita AI only)"""
-        return ["Novita AI API - DeepSeek-R1-Distill-Qwen-7B"]
+        """Get list of available models from ZeroGPU Chat API"""
+        try:
+            await self._ensure_authenticated()
+            if not self.session:
+                self.session = aiohttp.ClientSession()
+
+            tasks_url = f"{self.base_url}/tasks"
+            headers = {"Authorization": f"Bearer {self.access_token}"}
+
+            async with self.session.get(tasks_url, headers=headers) as response:
+                if response.status == 401:
+                    await self._refresh_token()
+                    headers["Authorization"] = f"Bearer {self.access_token}"
+                    async with self.session.get(tasks_url, headers=headers) as retry_response:
+                        retry_response.raise_for_status()
+                        data = await retry_response.json()
+                else:
+                    response.raise_for_status()
+                    data = await response.json()
+
+            tasks = data.get("tasks", {})
+            models = [f"ZeroGPU Chat API - {task}: {info.get('model', 'N/A')}"
+                      for task, info in tasks.items()]
+            return models if models else ["ZeroGPU Chat API"]
+        except Exception as e:
+            logger.error(f"Failed to get available models: {e}")
+            return ["ZeroGPU Chat API"]
 
    async def health_check(self):
-        """Perform health check on Novita AI API"""
+        """Perform health check on ZeroGPU Chat API"""
        try:
-            # Test API with a simple request
-            test_response = self.novita_client.chat.completions.create(
-                model=self.settings.novita_model,
-                messages=[{"role": "user", "content": "test"}],
-                max_tokens=5
-            )
+            if not self.session:
+                self.session = aiohttp.ClientSession()
 
-            return {
-                "provider": "novita_api",
-                "status": "healthy",
-                "model": self.settings.novita_model,
-                "base_url": self.settings.novita_base_url
-            }
+            # Check health endpoint (no auth required)
+            health_url = f"{self.base_url}/health"
+            async with self.session.get(health_url) as response:
+                response.raise_for_status()
+                data = await response.json()
+
+            return {
+                "provider": "zerogpu_chat_api",
+                "status": "healthy" if data.get("status") == "healthy" else "unhealthy",
+                "models_ready": data.get("models_ready", False),
+                "base_url": self.base_url
+            }
        except Exception as e:
            logger.error(f"Health check failed: {e}")
            return {
-                "provider": "novita_api",
+                "provider": "zerogpu_chat_api",
                "status": "unhealthy",
                "error": str(e)
            }
 
+    async def __aenter__(self):
+        """Async context manager entry"""
+        if not self.session:
+            self.session = aiohttp.ClientSession()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        """Async context manager exit"""
+        if self.session:
+            await self.session.close()
+            self.session = None
+
    def prepare_context_for_llm(self, raw_context: Dict, max_tokens: Optional[int] = None,
                                user_input: Optional[str] = None) -> str:
        """