import json
import os
import time
import boto3
import openai
import requests
from dotenv import load_dotenv
from model_config import MODEL_TO_PROVIDER, MODEL_TO_INFERENCE_PROFILE_ARN
# Lazy initialization of Google Gemini client
_google_client = None
def get_google_client():
"""Get or create the Google Gemini client with proper error handling."""
global _google_client
if _google_client is None:
try:
import google.generativeai as genai
except ImportError:
raise ValueError(
"google-generativeai package not installed. "
"Please add 'google-generativeai' to requirements.txt"
)
google_api_key = os.getenv("GOOGLE_API_KEY", "").strip()
if not google_api_key:
raise ValueError(
"Google API key not found. Please set GOOGLE_API_KEY "
"as a secret in Hugging Face Spaces settings."
)
try:
genai.configure(api_key=google_api_key)
_google_client = genai
except Exception as e:
raise ValueError(
f"Failed to initialize Google Gemini client: {str(e)}. "
"Please verify your GOOGLE_API_KEY is correct."
) from e
return _google_client
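# Example usage of the lazy client (the model name below is illustrative):
#   genai = get_google_client()
#   model = genai.GenerativeModel("gemini-1.5-flash")
# Repeat calls return the cached, already-configured module.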
# ──────────────────────────────────────────────────────────────
# Load environment variables
load_dotenv()
# ──────────────────────────────────────────────────────────────
# ──────────────────────────────────────────────────────────────
# Configuration
# ──────────────────────────────────────────────────────────────
MODEL_STRING = "gpt-4.1-mini" # we default on gpt-4.1-mini
api_key = os.getenv("MODEL_API_KEY")
client = openai.OpenAI(api_key=api_key)
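# Note: unlike the Bedrock and Gemini clients below, this OpenAI client is
# created eagerly at import time, so a missing MODEL_API_KEY surfaces when the
# module loads (unless the OPENAI_API_KEY environment variable is set as a
# fallback) rather than at request time.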
# Lazy initialization of bedrock client to avoid errors if credentials are missing
_bedrock_runtime = None
def get_bedrock_client():
"""Get or create the Bedrock runtime client with proper error handling."""
global _bedrock_runtime
if _bedrock_runtime is None:
aws_access_key = os.getenv("AWS_ACCESS_KEY_ID", "").strip()
aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY", "").strip()
aws_region = os.getenv("AWS_DEFAULT_REGION", "us-east-1").strip()
if not aws_access_key or not aws_secret_key:
raise ValueError(
"AWS credentials not found. Please set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY "
"as secrets in Hugging Face Spaces settings. "
f"Current values: AWS_ACCESS_KEY_ID={'***' if aws_access_key else 'EMPTY'}, "
f"AWS_SECRET_ACCESS_KEY={'***' if aws_secret_key else 'EMPTY'}"
)
try:
_bedrock_runtime = boto3.client(
"bedrock-runtime",
region_name=aws_region,
aws_access_key_id=aws_access_key,
aws_secret_access_key=aws_secret_key
)
except Exception as e:
raise ValueError(
f"Failed to initialize AWS Bedrock client: {str(e)}. "
"Please verify your AWS credentials are valid and have Bedrock access."
) from e
return _bedrock_runtime
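# Missing AWS credentials are reported as a ValueError here, before any
# Bedrock API call is made; invalid (as opposed to absent) credentials still
# surface at invoke time and are handled in chat().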
# ──────────────────────────────────────────────────────────────
# Model switcher
# ──────────────────────────────────────────────────────────────
def set_model(model_id: str) -> None:
global MODEL_STRING
MODEL_STRING = model_id
print(f"Model changed to: {model_id}")
def set_provider(provider: str) -> None:
    # Mirrors set_model; PROVIDER is informational only, since chat() derives
    # the provider from MODEL_TO_PROVIDER.
    global PROVIDER
    PROVIDER = provider
    print(f"Provider changed to: {provider}")
# ──────────────────────────────────────────────────────────────
# High-level Chat wrapper
# ──────────────────────────────────────────────────────────────
def chat(messages, persona=None):
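    """Route a chat request to the provider configured for MODEL_STRING.

    Returns a tuple of (text, latency_seconds, total_tokens, tokens_per_second).
    `persona` is currently unused and kept for interface compatibility.
    """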
provider = MODEL_TO_PROVIDER[MODEL_STRING]
if provider == "openai":
print("Using openai: ", MODEL_STRING)
t0 = time.time()
        # System prompt (currently empty; fill in to steer model behavior)
        system_prompt = ""
# Prepare messages with system prompt
chat_messages = [{"role": "system", "content": system_prompt}]
for msg in messages:
chat_messages.append({
"role": msg["role"],
"content": msg["content"]
})
request_kwargs = {
"model": MODEL_STRING,
"messages": chat_messages,
"max_completion_tokens": 4000,
}
# Some newer OpenAI models only support the default temperature.
if MODEL_STRING not in {"gpt-5", "gpt-5-nano", "gpt-5-mini"}:
request_kwargs["temperature"] = 0.3
response = client.chat.completions.create(**request_kwargs)
dt = time.time() - t0
text = response.choices[0].message.content.strip()
# Calculate tokens
total_tok = response.usage.total_tokens if response.usage else len(text.split())
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
elif provider == "anthropic":
print("Using anthropic: ", MODEL_STRING)
t0 = time.time()
        # System prompt (currently empty; fill in to steer model behavior)
        system_prompt = ""
claude_messages = [
{"role": m["role"], "content": m["content"]} for m in messages
]
try:
bedrock_runtime = get_bedrock_client()
# Use inference profile ARN if available (for provisioned throughput models)
# Otherwise use modelId (for on-demand models)
invoke_kwargs = {
"contentType": "application/json",
"accept": "application/json",
"body": json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
"system": system_prompt,
"messages": claude_messages,
"max_tokens": 4000, # Much higher limit for longer responses
"temperature": 0.3, # Lower temperature for more focused responses
}
),
}
# Check if this model has an inference profile ARN (provisioned throughput)
# For provisioned throughput, use the ARN as the modelId
if MODEL_STRING in MODEL_TO_INFERENCE_PROFILE_ARN:
invoke_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[MODEL_STRING]
else:
invoke_kwargs["modelId"] = MODEL_STRING
response = bedrock_runtime.invoke_model(**invoke_kwargs)
dt = time.time() - t0
body = json.loads(response["body"].read())
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
except Exception as e:
error_msg = str(e)
if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
raise ValueError(
f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
f"Error: {error_msg}. "
"Please verify the model ID is correct and the model is available in your AWS region. "
"Common Claude model IDs: 'anthropic.claude-3-5-sonnet-20241022-v2' or 'anthropic.claude-3-haiku-20240307-v1'"
) from e
elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
raise ValueError(
f"AWS Bedrock authentication failed: {error_msg}. "
"Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
"are correct and have Bedrock access permissions."
) from e
raise
text = "".join(
part["text"] for part in body["content"] if part["type"] == "text"
).strip()
total_tok = len(text.split())
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
elif provider == "google":
print("Using google (Gemini): ", MODEL_STRING)
t0 = time.time()
try:
genai = get_google_client()
# Get the model
model = genai.GenerativeModel(MODEL_STRING)
# Convert messages to Gemini format
# Gemini API expects a chat history format with "user" and "model" roles
chat_history = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
# Skip system messages (we'll handle them separately)
if role == "system":
continue
# Gemini uses "model" instead of "assistant"
if role == "assistant":
role = "model"
chat_history.append({
"role": role,
"parts": [content]
})
# Separate history from the last user message
if chat_history and chat_history[-1]["role"] == "user":
history = chat_history[:-1]
last_user_message = chat_history[-1]["parts"][0]
else:
history = []
last_user_message = chat_history[-1]["parts"][0] if chat_history else ""
# Start a chat session with history
            chat_session = model.start_chat(history=history)
            # Send the last message
            response = chat_session.send_message(
last_user_message,
generation_config=genai.types.GenerationConfig(
max_output_tokens=4000,
temperature=0.3,
)
)
dt = time.time() - t0
text = response.text.strip()
# Calculate tokens (approximate)
total_tok = len(text.split())
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
except Exception as e:
error_msg = str(e)
if "API key" in error_msg or "invalid" in error_msg.lower() or "401" in error_msg or "403" in error_msg:
raise ValueError(
f"Google API authentication failed: {error_msg}. "
"Please verify your GOOGLE_API_KEY secret is correct and has Gemini API access."
) from e
elif "not found" in error_msg.lower() or "404" in error_msg:
raise ValueError(
f"Invalid Gemini model ID: '{MODEL_STRING}'. "
f"Error: {error_msg}. "
"Please verify the model ID is correct. "
"Common Gemini model IDs: 'gemini-1.5-pro', 'gemini-1.5-flash', 'gemini-2.0-flash-exp', 'gemini-pro'"
) from e
raise
elif provider == "deepseek":
print("Using deepseek: ", MODEL_STRING)
t0 = time.time()
        # System prompt (currently empty; fill in to steer model behavior)
        system_prompt = ""
ds_messages = [
{
"role": "system",
"content": [{"type": "text", "text": system_prompt}],
}
]
for msg in messages:
role = msg.get("role", "user")
ds_messages.append(
{
"role": role,
"content": [{"type": "text", "text": msg["content"]}],
}
)
try:
bedrock_runtime = get_bedrock_client()
response = bedrock_runtime.invoke_model(
modelId=MODEL_STRING,
contentType="application/json",
accept="application/json",
body=json.dumps(
{
"messages": ds_messages,
"max_completion_tokens": 500,
"temperature": 0.5,
"top_p": 0.9,
}
),
)
dt = time.time() - t0
body = json.loads(response["body"].read())
        except ValueError:
            # Re-raise ValueError (credential errors) as-is
            raise
except Exception as e:
error_msg = str(e)
if "ValidationException" in error_msg and "model identifier is invalid" in error_msg:
raise ValueError(
f"Invalid Bedrock model ID: '{MODEL_STRING}'. "
f"Error: {error_msg}. "
"Please verify the model ID is correct and the model is available in your AWS region."
) from e
elif "UnrecognizedClientException" in error_msg or "invalid" in error_msg.lower():
raise ValueError(
f"AWS Bedrock authentication failed: {error_msg}. "
"Please verify your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY secrets "
"are correct and have Bedrock access permissions."
) from e
raise
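        # Parse the response body; the code below tolerates two payload shapes,
        # an "output" list of content chunks and a top-level "response" object.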
outputs = body.get("output", [])
text_chunks = []
for item in outputs:
for content in item.get("content", []):
chunk_text = content.get("text") or content.get("output_text")
if chunk_text:
text_chunks.append(chunk_text)
text = "".join(text_chunks).strip()
if not text and "response" in body:
text = body["response"].get("output_text", "").strip()
total_tok = len(text.split())
return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "meta":
# print("Using meta (LLaMA): ", MODEL_STRING)
# t0 = time.time()
# # Add system prompt for better behavior
# system_prompt = ""
# # Format conversation properly for Llama3
# formatted_prompt = "<|begin_of_text|>"
# # Add system prompt
# formatted_prompt += "<|start_header_id|>system<|end_header_id|>\n" + system_prompt + "<|eot_id|>\n"
# # Add conversation history
# for msg in messages:
# if msg["role"] == "user":
# formatted_prompt += "<|start_header_id|>user<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
# elif msg["role"] == "assistant":
# formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n" + msg["content"] + "<|eot_id|>\n"
# # Add final assistant prompt
# formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n"
# response = bedrock_runtime.invoke_model(
# modelId=MODEL_STRING,
# contentType="application/json",
# accept="application/json",
# body=json.dumps(
# {
# "prompt": formatted_prompt,
# "max_gen_len": 512, # Shorter responses
# "temperature": 0.3, # Lower temperature for more focused responses
# }
# ),
# )
# dt = time.time() - t0
# body = json.loads(response["body"].read())
# text = body.get("generation", "").strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "mistral":
# print("Using mistral: ", MODEL_STRING)
# t0 = time.time()
# prompt = messages[-1]["content"]
# formatted_prompt = f"<s>[INST] {prompt} [/INST]"
# response = bedrock_runtime.invoke_model(
# modelId=MODEL_STRING,
# contentType="application/json",
# accept="application/json",
# body=json.dumps(
# {"prompt": formatted_prompt, "max_tokens": 512, "temperature": 0.5}
# ),
# )
# dt = time.time() - t0
# body = json.loads(response["body"].read())
# text = body["outputs"][0]["text"].strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# elif provider == "ollama":
# print("Using ollama: ", MODEL_STRING)
# t0 = time.time()
# # Format messages for Ollama API with system prompt
# ollama_messages = []
# # Add system prompt for better behavior
# system_prompt = ""
# ollama_messages.append({
# "role": "system",
# "content": system_prompt
# })
# for msg in messages:
# ollama_messages.append({
# "role": msg["role"],
# "content": msg["content"]
# })
# # Make request to Ollama API
# response = requests.post(
# f"{OLLAMA_BASE_URL}/api/chat",
# json={
# "model": MODEL_STRING,
# "messages": ollama_messages,
# "stream": False,
# "options": {
# "temperature": 0.3, # Lower temperature for more focused responses
# # "num_predict": 4000, # Much higher limit for longer responses
# "top_p": 0.9,
# "repeat_penalty": 1.1
# }
# },
# timeout=60
# )
# dt = time.time() - t0
# if response.status_code == 200:
# result = response.json()
# text = result["message"]["content"].strip()
# total_tok = len(text.split())
# return text, dt, total_tok, (total_tok / dt if dt else total_tok)
# else:
# raise Exception(f"Ollama API error: {response.status_code} - {response.text}")
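    # No branch above matched: raise instead of implicitly returning None so
    # misconfigured models fail loudly.
    raise ValueError(f"Unsupported provider '{provider}' for model: {MODEL_STRING}")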
# ──────────────────────────────────────────────────────────────
# Diagnostics / CLI test
# ──────────────────────────────────────────────────────────────
def check_credentials():
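    """Verify credentials for the provider behind MODEL_STRING.

    Returns True when the required secrets are present (and, for Bedrock and
    Gemini, a lightweight live check succeeds); otherwise False.
    """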
# # Check if using Ollama (no API key required)
# if MODEL_TO_PROVIDER.get(MODEL_STRING) == "ollama":
# # Test Ollama connection
# try:
# response = requests.get(f"{OLLAMA_BASE_URL}/api/tags", timeout=5)
# if response.status_code == 200:
# print("Ollama connection successful")
# return True
# else:
# print(f"Ollama connection failed: {response.status_code}")
# return False
# except Exception as e:
# print(f"Ollama connection failed: {e}")
# return False
    # Check if using Bedrock-backed providers (both anthropic and deepseek
    # route through AWS Bedrock in chat())
    bedrock_providers = ["anthropic", "deepseek"]
if MODEL_TO_PROVIDER.get(MODEL_STRING) in bedrock_providers:
# Test AWS Bedrock connection by trying to invoke a simple model
try:
bedrock_runtime = get_bedrock_client()
# Try a simple test invocation to verify credentials
test_model = "anthropic.claude-haiku-4-5-20251001-v1:0"
test_kwargs = {
"contentType": "application/json",
"accept": "application/json",
"body": json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"messages": [{"role": "user", "content": "test"}],
"max_tokens": 10,
"temperature": 0.1
})
}
# Use inference profile ARN if available (use ARN as modelId for provisioned throughput)
if test_model in MODEL_TO_INFERENCE_PROFILE_ARN:
test_kwargs["modelId"] = MODEL_TO_INFERENCE_PROFILE_ARN[test_model]
else:
test_kwargs["modelId"] = test_model
test_response = bedrock_runtime.invoke_model(**test_kwargs)
print("Bedrock connection successful")
return True
except Exception as e:
print(f"Bedrock connection failed: {e}")
print("Make sure AWS credentials are configured and you have access to Bedrock")
return False
# For OpenAI, check API key
if MODEL_TO_PROVIDER.get(MODEL_STRING) == "openai":
required = ["MODEL_API_KEY"]
missing = [var for var in required if not os.getenv(var)]
if missing:
print(f"Missing environment variables: {missing}")
return False
return True
# For Google Gemini, check API key
if MODEL_TO_PROVIDER.get(MODEL_STRING) == "google":
required = ["GOOGLE_API_KEY"]
missing = [var for var in required if not os.getenv(var)]
if missing:
print(f"Missing environment variables: {missing}")
return False
# Try to initialize the client to verify the key works
try:
get_google_client()
return True
except Exception as e:
print(f"Google API client initialization failed: {e}")
return False
return True
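# Note: providers without a dedicated check above fall through to the final
# `return True`, so a passing result does not by itself prove credentials.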
def test_chat():
print("Testing chat...")
try:
test_messages = [
{
"role": "user",
"content": "Hello! Please respond with just 'Test successful'.",
}
]
text, latency, tokens, tps = chat(test_messages)
        print(f"Test passed! {text} | {latency:.2f}s | {tokens} tok ⚑ {tps:.1f} tps")
except Exception as e:
print(f"Test failed: {e}")
if __name__ == "__main__":
    print("Running diagnostics...")
if check_credentials():
test_chat()
print("\nDone.")
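# Example diagnostic session (model ID is illustrative; see
# model_config.MODEL_TO_PROVIDER for the supported list):
#   set_model("gpt-4.1-mini")
#   text, latency, tokens, tps = chat([{"role": "user", "content": "Hi"}])
#   print(f"{text} ({latency:.2f}s, {tps:.1f} tok/s)")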