"""
Configuration Management Module
This module provides secure, robust configuration management with:
- Environment variable handling with secure defaults
- Cache directory management with automatic fallbacks
- Comprehensive logging and error handling
- Security best practices for sensitive data
- Backward compatibility with existing code
Environment Variables:
    HF_TOKEN: HuggingFace API token (required for API access)
    HF_HOME: Primary cache directory for HuggingFace models
    TRANSFORMERS_CACHE: Alternative cache directory path
    ZEROGPU_BASE_URL: ZeroGPU Chat API base URL (RunPod endpoint)
    ZEROGPU_EMAIL: ZeroGPU Chat API email for authentication
    ZEROGPU_PASSWORD: ZeroGPU Chat API password for authentication
    MAX_WORKERS: Maximum worker threads (default: 4)
    CACHE_TTL: Cache time-to-live in seconds (default: 3600)
    DB_PATH: Database file path (default: sessions.db)
    LOG_LEVEL: Logging level (default: INFO)
    LOG_FORMAT: Log format (default: json)
Security Notes:
- Never commit .env files to version control
- Use environment variables for all sensitive data
- Cache directories are automatically secured with proper permissions
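Example .env (illustrative placeholder values; the variable names are the ones
read by this module -- never commit real secrets):
    HF_TOKEN=hf_xxxxxxxxxxxxxxxx
    ZEROGPU_BASE_URL=http://your-pod-ip:8000
    ZEROGPU_EMAIL=you@example.com
    ZEROGPU_PASSWORD=change-me
    MAX_WORKERS=4
    LOG_LEVEL=INFO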
"""
import os
import logging
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field, validator
# Configure logging
logger = logging.getLogger(__name__)
class CacheDirectoryManager:
"""
Manages cache directory with secure fallback mechanism.
Implements:
- Multi-level fallback strategy
- Permission validation
- Automatic directory creation
- Security best practices
"""
@staticmethod
def get_cache_directory() -> str:
"""
Get cache directory with secure fallback chain.
Priority order:
1. HF_HOME environment variable
2. TRANSFORMERS_CACHE environment variable
3. User home directory (~/.cache/huggingface)
4. User-specific fallback directory
5. Temporary directory (last resort)
Returns:
str: Path to writable cache directory
"""
        # Priority order for cache directory.
        # In containers (e.g. HF Spaces), ~ may resolve to a non-writable location,
        # so /tmp is preferred over ~/.cache there. The presence of /tmp alone is
        # not a container signal, so detection relies on /.dockerenv and the
        # SPACE_ID variable that Hugging Face Spaces sets.
        is_docker = os.path.exists("/.dockerenv") or bool(os.getenv("SPACE_ID"))
        home = os.path.expanduser("~")
        cache_candidates = [
            os.getenv("HF_HOME"),
            os.getenv("TRANSFORMERS_CACHE"),
            # In containers, prefer /tmp over ~/.cache
            "/tmp/huggingface_cache" if is_docker else None,
            os.path.join(home, ".cache", "huggingface") if not is_docker else None,
            os.path.join(home, ".cache", "huggingface_fallback") if not is_docker else None,
            "/tmp/huggingface_cache" if not is_docker else None,
            "/tmp/huggingface",  # Final fallback
        ]
for cache_dir in cache_candidates:
if not cache_dir:
continue
try:
# Ensure directory exists
cache_path = Path(cache_dir)
cache_path.mkdir(parents=True, exist_ok=True)
# Set secure permissions (rwxr-xr-x)
try:
os.chmod(cache_path, 0o755)
except (OSError, PermissionError):
# If we can't set permissions, continue if directory is writable
pass
# Test write access
test_file = cache_path / ".write_test"
try:
test_file.write_text("test")
test_file.unlink()
logger.info(f"✓ Cache directory verified: {cache_dir}")
return str(cache_path)
except (PermissionError, OSError) as e:
logger.debug(f"Write test failed for {cache_dir}: {e}")
continue
except (PermissionError, OSError) as e:
logger.debug(f"Could not create/access {cache_dir}: {e}")
continue
# If all candidates failed, use emergency fallback
fallback = "/tmp/huggingface_emergency"
try:
Path(fallback).mkdir(parents=True, exist_ok=True)
logger.warning(f"Using emergency fallback cache: {fallback}")
return fallback
except Exception as e:
logger.error(f"Emergency fallback also failed: {e}")
# Return a default that will fail gracefully later
return "/tmp/huggingface"
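# Illustrative use (a sketch, nothing here runs at import time): resolve the cache
# directory once and export it so downstream HuggingFace libraries pick up the
# same location.
#
#     cache_dir = CacheDirectoryManager.get_cache_directory()
#     os.environ.setdefault("HF_HOME", cache_dir)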
class Settings(BaseSettings):
"""
Application settings with secure defaults and validation.
Backward Compatibility:
- All existing attributes are preserved
- hf_token is accessible as string (via property)
- hf_cache_dir is accessible as property (works like before)
- All defaults match original implementation
"""
# ==================== HuggingFace Configuration ====================
# BACKWARD COMPAT: hf_token as regular field (backward compatible)
hf_token: str = Field(
default="",
description="HuggingFace API token",
env="HF_TOKEN"
)
@validator("hf_token", pre=True)
def validate_hf_token(cls, v):
"""Validate HF token (backward compatible)"""
if v is None:
return ""
token = str(v) if v else ""
if not token:
logger.debug("HF_TOKEN not set")
return token
@property
def hf_cache_dir(self) -> str:
"""
Get cache directory with automatic fallback and validation.
BACKWARD COMPAT: Works like the original hf_cache_dir field.
Returns:
str: Path to writable cache directory
"""
if not hasattr(self, '_cached_cache_dir'):
try:
self._cached_cache_dir = CacheDirectoryManager.get_cache_directory()
except Exception as e:
logger.error(f"Cache directory setup failed: {e}")
# Fallback to original default
fallback = os.getenv("HF_HOME", "/tmp/huggingface")
Path(fallback).mkdir(parents=True, exist_ok=True)
self._cached_cache_dir = fallback
return self._cached_cache_dir
# ==================== ZeroGPU Chat API Configuration ====================
zerogpu_base_url: str = Field(
default="http://your-pod-ip:8000",
description="ZeroGPU Chat API base URL (RunPod endpoint)",
env="ZEROGPU_BASE_URL"
)
zerogpu_email: str = Field(
default="",
description="ZeroGPU Chat API email for authentication (required)",
env="ZEROGPU_EMAIL"
)
zerogpu_password: str = Field(
default="",
description="ZeroGPU Chat API password for authentication (required)",
env="ZEROGPU_PASSWORD"
)
# Token Allocation Configuration
user_input_max_tokens: int = Field(
default=32000,
description="Maximum tokens dedicated for user input (prioritized over context)",
env="USER_INPUT_MAX_TOKENS"
)
context_preparation_budget: int = Field(
default=115000,
description="Maximum tokens for context preparation (includes user input + context)",
env="CONTEXT_PREPARATION_BUDGET"
)
context_pruning_threshold: int = Field(
default=115000,
description="Context pruning threshold (should match context_preparation_budget)",
env="CONTEXT_PRUNING_THRESHOLD"
)
prioritize_user_input: bool = Field(
default=True,
description="Always prioritize user input over historical context",
env="PRIORITIZE_USER_INPUT"
)
# Model Context Window Configuration
zerogpu_model_context_window: int = Field(
default=8192,
description="Maximum context window for ZeroGPU Chat API model (input + output tokens). Adjust based on your deployed model.",
env="ZEROGPU_MODEL_CONTEXT_WINDOW"
)
@validator("zerogpu_base_url", pre=True)
def validate_zerogpu_base_url(cls, v):
"""Validate ZeroGPU base URL"""
if v is None:
return "http://your-pod-ip:8000"
url = str(v).strip()
# Remove trailing slash
if url.endswith('/'):
url = url[:-1]
return url
@validator("zerogpu_email", pre=True)
def validate_zerogpu_email(cls, v):
"""Validate ZeroGPU email"""
if v is None:
return ""
email = str(v).strip()
if email and '@' not in email:
logger.warning("ZEROGPU_EMAIL may not be a valid email address")
return email
@validator("zerogpu_password", pre=True)
def validate_zerogpu_password(cls, v):
"""Validate ZeroGPU password"""
if v is None:
return ""
return str(v).strip()
    @validator("user_input_max_tokens", pre=True)
    def validate_user_input_tokens(cls, v):
        """Validate user input token limit"""
        try:
            val = int(v) if v else 32000
        except (ValueError, TypeError):
            return 32000
        return max(1000, min(50000, val))  # Allow up to 50K for large inputs
    @validator("context_preparation_budget", pre=True)
    def validate_context_budget(cls, v):
        """Validate context preparation budget"""
        try:
            val = int(v) if v else 115000
        except (ValueError, TypeError):
            return 115000
        return max(4000, min(125000, val))  # Allow up to 125K for 128K context window
    @validator("context_pruning_threshold", pre=True)
    def validate_pruning_threshold(cls, v):
        """Validate context pruning threshold"""
        try:
            val = int(v) if v else 115000
        except (ValueError, TypeError):
            return 115000
        return max(4000, min(125000, val))  # Match context_preparation_budget limits
    @validator("zerogpu_model_context_window", pre=True)
    def validate_context_window(cls, v):
        """Validate context window size"""
        try:
            val = int(v) if v else 8192
        except (ValueError, TypeError):
            return 8192
        return max(1000, min(200000, val))  # Support up to 200K for future models
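    # Illustrative relationship between the budgets above (a sketch, not enforced
    # here; actual enforcement lives wherever prompts are assembled):
    #     effective_budget  = min(context_preparation_budget, context_pruning_threshold)
    #     history_allowance = max(0, effective_budget - user_input_max_tokens)
    #     # the deployed model's own limit, zerogpu_model_context_window, still applies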
# ==================== Model Configuration ====================
default_model: str = Field(
default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
)
embedding_model: str = Field(
default="intfloat/e5-large-v2",
description="Model for embeddings (upgraded: 1024-dim embeddings)"
)
classification_model: str = Field(
default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
description="Model for classification tasks (Cerebras deployment)"
)
# ==================== Performance Configuration ====================
max_workers: int = Field(
default=4,
description="Maximum worker threads for parallel processing",
env="MAX_WORKERS"
)
@validator("max_workers", pre=True)
def validate_max_workers(cls, v):
"""Validate and convert max_workers (backward compatible)"""
if v is None:
return 4
if isinstance(v, str):
try:
v = int(v)
except ValueError:
logger.warning(f"Invalid MAX_WORKERS value: {v}, using default 4")
return 4
try:
val = int(v)
return max(1, min(16, val)) # Clamp between 1 and 16
except (ValueError, TypeError):
return 4
cache_ttl: int = Field(
default=3600,
description="Cache time-to-live in seconds",
env="CACHE_TTL"
)
@validator("cache_ttl", pre=True)
def validate_cache_ttl(cls, v):
"""Validate cache TTL (backward compatible)"""
if v is None:
return 3600
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 3600
try:
return max(0, int(v))
except (ValueError, TypeError):
return 3600
# ==================== Database Configuration ====================
db_path: str = Field(
default="sessions.db",
description="Path to SQLite database file",
env="DB_PATH"
)
@validator("db_path", pre=True)
def validate_db_path(cls, v):
"""Validate db_path with Docker fallback (backward compatible)"""
if v is None:
            # Check if we're in a container (Docker / HF Spaces) - if so, use /tmp
            if os.path.exists("/.dockerenv") or bool(os.getenv("SPACE_ID")):
return "/tmp/sessions.db"
return "sessions.db"
return str(v)
faiss_index_path: str = Field(
default="embeddings.faiss",
description="Path to FAISS index file",
env="FAISS_INDEX_PATH"
)
@validator("faiss_index_path", pre=True)
def validate_faiss_path(cls, v):
"""Validate faiss path with Docker fallback (backward compatible)"""
if v is None:
            # Check if we're in a container (Docker / HF Spaces) - if so, use /tmp
            if os.path.exists("/.dockerenv") or bool(os.getenv("SPACE_ID")):
return "/tmp/embeddings.faiss"
return "embeddings.faiss"
return str(v)
# ==================== Session Configuration ====================
session_timeout: int = Field(
default=3600,
description="Session timeout in seconds",
env="SESSION_TIMEOUT"
)
@validator("session_timeout", pre=True)
def validate_session_timeout(cls, v):
"""Validate session timeout (backward compatible)"""
if v is None:
return 3600
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 3600
try:
return max(60, int(v))
except (ValueError, TypeError):
return 3600
max_session_size_mb: int = Field(
default=10,
description="Maximum session size in megabytes",
env="MAX_SESSION_SIZE_MB"
)
@validator("max_session_size_mb", pre=True)
def validate_max_session_size(cls, v):
"""Validate max session size (backward compatible)"""
if v is None:
return 10
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 10
try:
return max(1, min(100, int(v)))
except (ValueError, TypeError):
return 10
# ==================== Mobile Optimization ====================
mobile_max_tokens: int = Field(
default=800,
description="Maximum tokens for mobile responses",
env="MOBILE_MAX_TOKENS"
)
@validator("mobile_max_tokens", pre=True)
def validate_mobile_max_tokens(cls, v):
"""Validate mobile max tokens (backward compatible)"""
if v is None:
return 800
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 800
try:
return max(100, min(2000, int(v)))
except (ValueError, TypeError):
return 800
mobile_timeout: int = Field(
default=15000,
description="Mobile request timeout in milliseconds",
env="MOBILE_TIMEOUT"
)
@validator("mobile_timeout", pre=True)
def validate_mobile_timeout(cls, v):
"""Validate mobile timeout (backward compatible)"""
if v is None:
return 15000
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 15000
try:
return max(5000, min(60000, int(v)))
except (ValueError, TypeError):
return 15000
# ==================== API Configuration ====================
gradio_port: int = Field(
default=7860,
description="Gradio server port",
env="GRADIO_PORT"
)
@validator("gradio_port", pre=True)
def validate_gradio_port(cls, v):
"""Validate gradio port (backward compatible)"""
if v is None:
return 7860
if isinstance(v, str):
try:
v = int(v)
except ValueError:
return 7860
try:
return max(1024, min(65535, int(v)))
except (ValueError, TypeError):
return 7860
gradio_host: str = Field(
default="0.0.0.0",
description="Gradio server host",
env="GRADIO_HOST"
)
# ==================== Logging Configuration ====================
log_level: str = Field(
default="INFO",
description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
env="LOG_LEVEL"
)
@validator("log_level")
def validate_log_level(cls, v):
"""Validate log level (backward compatible)"""
if not v:
return "INFO"
valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
if v.upper() not in valid_levels:
logger.warning(f"Invalid log level: {v}, using INFO")
return "INFO"
return v.upper()
log_format: str = Field(
default="json",
description="Log format (json or text)",
env="LOG_FORMAT"
)
@validator("log_format")
def validate_log_format(cls, v):
"""Validate log format (backward compatible)"""
if not v:
return "json"
if v.lower() not in ["json", "text"]:
logger.warning(f"Invalid log format: {v}, using json")
return "json"
return v.lower()
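    # Minimal wiring sketch (assumption: plain stdlib logging; the project's JSON
    # formatter, if any, is configured elsewhere). Uses the module-level `settings`
    # instance defined at the bottom of this file:
    #     logging.basicConfig(level=getattr(logging, settings.log_level, logging.INFO))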
# ==================== Pydantic Configuration ====================
class Config:
"""Pydantic configuration"""
env_file = ".env"
env_file_encoding = "utf-8"
case_sensitive = False
validate_assignment = True
# Allow extra fields for backward compatibility
extra = "ignore"
# ==================== Utility Methods ====================
def validate_configuration(self) -> bool:
"""
Validate configuration and log status.
Returns:
bool: True if configuration is valid, False otherwise
"""
try:
# Validate cache directory
cache_dir = self.hf_cache_dir
if logger.isEnabledFor(logging.INFO):
logger.info("Configuration validated:")
logger.info(f" - Cache directory: {cache_dir}")
logger.info(f" - Max workers: {self.max_workers}")
logger.info(f" - Log level: {self.log_level}")
logger.info(f" - HF token: {'Set' if self.hf_token else 'Not set'}")
return True
except Exception as e:
logger.error(f"Configuration validation failed: {e}")
return False
# ==================== Global Settings Instance ====================
def get_settings() -> Settings:
"""
Get or create global settings instance.
Returns:
Settings: Global settings instance
Note:
This function ensures settings are loaded once and cached.
"""
if not hasattr(get_settings, '_instance'):
get_settings._instance = Settings()
# Validate on first load (non-blocking)
try:
get_settings._instance.validate_configuration()
except Exception as e:
logger.warning(f"Configuration validation warning: {e}")
return get_settings._instance
# Create global settings instance (backward compatible)
settings = get_settings()
# Log configuration on import (at INFO level, non-blocking)
if logger.isEnabledFor(logging.INFO):
try:
logger.info("=" * 60)
logger.info("Configuration Loaded")
logger.info("=" * 60)
logger.info(f"Cache directory: {settings.hf_cache_dir}")
logger.info(f"Max workers: {settings.max_workers}")
logger.info(f"Log level: {settings.log_level}")
logger.info("=" * 60)
except Exception as e:
logger.debug(f"Configuration logging skipped: {e}")
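# Minimal sketch of a manual sanity check (assumption: this module can be run
# directly; nothing below executes on import).
if __name__ == "__main__":
    ok = settings.validate_configuration()
    print(f"Configuration valid: {ok}")
    print(f"Cache directory:     {settings.hf_cache_dir}")
    print(f"Database path:       {settings.db_path}")
    print(f"ZeroGPU base URL:    {settings.zerogpu_base_url}")
    print(f"HF token set:        {bool(settings.hf_token)}")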