"""Custom LLM Wrapper - Direct API calls using requests (no LiteLLM)"""
import os
import sys
import logging
import requests
# Add current directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from config import MODEL_CONFIG
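
# Illustrative shape of MODEL_CONFIG as read by this module (keys inferred
# from the lookups below; the actual values live in config.py):
# MODEL_CONFIG = {
#     "gemini": {"model": "gemini-2.0-flash-exp"},
#     "inflection_ai": {
#         "model": "Pi-3.1",
#         "api_endpoint": "https://api.inflection.ai/external/api/inference",
#     },
# }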


class LLMWrapper:
    """Custom LLM wrapper for Gemini and Inflection AI using direct API calls"""

    def __init__(self, gemini_model=None, inflection_model=None, config=None):
        """
        Initialize LLM Wrapper with models and configuration

        Args:
            gemini_model: Gemini model name (e.g., 'gemini-2.0-flash-exp')
            inflection_model: Inflection AI model name (e.g., 'Pi-3.1')
            config: Configuration dict (optional, will load from MODEL_CONFIG if not provided)
        """
        self.config = config or MODEL_CONFIG or {}
        self.gemini_api_key = os.getenv("GEMINI_API_KEY")
        self.inflection_ai_api_key = os.getenv("INFLECTION_AI_API_KEY")

        # Set models from parameters or config
        if gemini_model:
            self.gemini_model = gemini_model
        else:
            self.gemini_model = self.config.get('gemini', {}).get('model', 'gemini-2.0-flash-exp')
        if inflection_model:
            self.inflection_model = inflection_model
        else:
            self.inflection_model = self.config.get('inflection_ai', {}).get('model', 'Pi-3.1')

        # Remove 'gemini/' prefix if present
        if self.gemini_model.startswith('gemini/'):
            self.gemini_model = self.gemini_model.replace('gemini/', '')

        logging.info(f"[LLMWrapper] Initialized with Gemini model: {self.gemini_model}, Inflection model: {self.inflection_model}")

    def call_gemini(self, messages, temperature=0.7, max_tokens=1024):
        """
        Call Gemini API directly using requests

        Args:
            messages: List of message dicts with 'role' and 'content'
            temperature: Temperature for generation
            max_tokens: Maximum tokens to generate

        Returns:
            Response text or None if failed
        """
        if not self.gemini_api_key:
            logging.error("[LLMWrapper] GEMINI_API_KEY not found")
            return None

        # Use the model set during initialization
        model = self.gemini_model

        # Gemini API endpoint
        url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
        headers = {
            "Content-Type": "application/json",
            "X-goog-api-key": self.gemini_api_key
        }

        # Convert messages to Gemini format
        contents = []
        system_instruction = None
        for msg in messages:
            role = msg.get('role', 'user')
            content = msg.get('content', '')
            if role == 'system':
                system_instruction = content
            elif role == 'user':
                contents.append({
                    "role": "user",
                    "parts": [{"text": content}]
                })
            elif role == 'assistant':
                contents.append({
                    "role": "model",
                    "parts": [{"text": content}]
                })

        # Build request payload
        payload = {
            "contents": contents,
            "generationConfig": {
                "temperature": temperature,
                "maxOutputTokens": max_tokens
            }
        }

        # Add system instruction if present
        if system_instruction:
            payload["systemInstruction"] = {
                "parts": [{"text": system_instruction}]
            }
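
        # Illustrative request body produced above for a two-turn chat
        # (shape follows directly from the payload construction in this method):
        # {
        #   "systemInstruction": {"parts": [{"text": "You are..."}]},
        #   "contents": [
        #     {"role": "user", "parts": [{"text": "Hi"}]},
        #     {"role": "model", "parts": [{"text": "Hello!"}]}
        #   ],
        #   "generationConfig": {"temperature": 0.7, "maxOutputTokens": 1024}
        # }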

        try:
            logging.info(f"[LLMWrapper] Calling Gemini API: {model}")
            response = requests.post(url, headers=headers, json=payload, timeout=30)
            if response.status_code == 200:
                result = response.json()
                # Extract text from Gemini response
                if 'candidates' in result and len(result['candidates']) > 0:
                    candidate = result['candidates'][0]
                    if 'content' in candidate and 'parts' in candidate['content']:
                        parts = candidate['content']['parts']
                        if len(parts) > 0 and 'text' in parts[0]:
                            text = parts[0]['text']
                            logging.info("[LLMWrapper] ✓ Gemini response received")
                            return text.strip()
                logging.error(f"[LLMWrapper] Unexpected Gemini response format: {result}")
                return None
            else:
                # Raise an exception carrying the status code and response text
                # so downstream validation can catch it
                error_text = response.text
                logging.error(f"[LLMWrapper] Gemini API returned status {response.status_code}: {error_text}")
                api_error = Exception(f"Gemini API status {response.status_code}: {error_text}")
                api_error.status_code = response.status_code
                api_error.response_text = error_text
                raise api_error
        except requests.RequestException as e:
            # Network/request errors - log and return None
            logging.error(f"[LLMWrapper] Network error calling Gemini API: {e}")
            return None
        except Exception as e:
            # Re-raise status code errors so validation can catch them
            if hasattr(e, 'status_code'):
                raise
            # Other errors - log and return None
            logging.error(f"[LLMWrapper] Error calling Gemini API: {e}")
            return None

    def call_inflection_ai(self, context_parts):
        """
        Call Inflection AI API directly using requests

        Args:
            context_parts: List of context dicts with 'text' and 'type'

        Returns:
            Response text or None if failed
        """
        if not self.inflection_ai_api_key:
            logging.error("[LLMWrapper] INFLECTION_AI_API_KEY not found")
            return None

        # Use the model set during initialization
        model = self.inflection_model

        # Get endpoint from config
        inflection_config = self.config.get('inflection_ai', {})
        url = inflection_config.get('api_endpoint', 'https://api.inflection.ai/external/api/inference')
        headers = {
            "Authorization": f"Bearer {self.inflection_ai_api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "context": context_parts,
            "config": model
        }
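
        # Illustrative request body (the 'type'/'text' shape mirrors what this
        # method receives; the exact set of accepted 'type' values is defined
        # by the Inflection AI API and is not verified here):
        # {
        #   "context": [{"type": "Human", "text": "Hi Pi"}],
        #   "config": "Pi-3.1"
        # }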

        try:
            logging.info(f"[LLMWrapper] Calling Inflection AI API: {model}")
            logging.info(f"[LLMWrapper] Request URL: {url}")
            logging.info(f"[LLMWrapper] Request context parts count: {len(context_parts)}")
            logging.debug(f"[LLMWrapper] Request data: {data}")
            response = requests.post(url, headers=headers, json=data, timeout=30)
            logging.info(f"[LLMWrapper] Response status: {response.status_code}")
            if response.status_code == 200:
                try:
                    result = response.json()
                except Exception as json_error:
                    logging.error(f"[LLMWrapper] Failed to parse JSON response: {json_error}")
                    logging.error(f"[LLMWrapper] Raw response text: {response.text[:500]}")
                    return None
                logging.info(f"[LLMWrapper] Response JSON: {result}")
                logging.info(f"[LLMWrapper] Response type: {type(result)}")
                logging.info(f"[LLMWrapper] Response keys: {list(result.keys()) if isinstance(result, dict) else 'N/A'}")

                # Extract response text - Inflection AI returns it in the 'text' field
                # Based on actual API response: {"created": ..., "text": "...", "tool_calls": [], "reasoning_content": null}
                text = None
                if isinstance(result, dict):
                    # Prioritize 'text' since that's what the API actually returns
                    if 'text' in result:
                        text = result['text']
                        logging.debug(f"[LLMWrapper] Found text in 'text' field: {text[:100]}...")
                    elif 'output' in result:
                        text = result['output']
                        logging.debug("[LLMWrapper] Found text in 'output' field")
                    elif 'response' in result:
                        text = result['response']
                        logging.debug("[LLMWrapper] Found text in 'response' field")
                    elif 'message' in result:
                        text = result['message']
                        logging.debug("[LLMWrapper] Found text in 'message' field")
                    else:
                        # No known field - fall back to the first non-empty string value
                        logging.warning("[LLMWrapper] No standard text field found, searching for string values...")
                        for key, value in result.items():
                            if isinstance(value, str) and value.strip():
                                text = value
                                logging.debug(f"[LLMWrapper] Found text in '{key}' field")
                                break
                        if not text:
                            logging.error(f"[LLMWrapper] No text found in response dict. Keys: {list(result.keys())}")
                            text = str(result)
                elif isinstance(result, str):
                    text = result
                    logging.debug("[LLMWrapper] Response is a string")
                else:
                    logging.warning(f"[LLMWrapper] Unexpected response type: {type(result)}")
                    text = str(result)

                if text and isinstance(text, str) and text.strip():
                    logging.info(f"[LLMWrapper] ✓ Inflection AI response received: {text[:100]}...")
                    return text.strip()
                else:
                    logging.error(f"[LLMWrapper] No valid text found in response. Text value: {text}, Type: {type(text)}")
                    logging.error(f"[LLMWrapper] Full response: {result}")
                    return None
            else:
                logging.error(f"[LLMWrapper] Inflection AI API returned status {response.status_code}")
                try:
                    error_detail = response.json()
                    logging.error(f"[LLMWrapper] Error details: {error_detail}")
                except Exception:
                    logging.error(f"[LLMWrapper] Error response text: {response.text[:500]}")
                return None
        except Exception as e:
            logging.error(f"[LLMWrapper] Error calling Inflection AI API: {e}")
            return None
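

# A minimal usage sketch (illustrative, not part of the original module):
# it assumes GEMINI_API_KEY is set in the environment and that config.py
# provides MODEL_CONFIG as described above.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    wrapper = LLMWrapper()
    reply = wrapper.call_gemini(
        [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one sentence."},
        ],
        temperature=0.5,
        max_tokens=64,
    )
    print(reply)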