Integrate Novita AI as exclusive inference provider

- Add Novita AI API integration with DeepSeek-R1-Distill-Qwen-7B model
- Remove all local model dependencies
- Optimize token allocation for user inputs and context
- Add Anaconda environment setup files
- Add comprehensive test scripts and documentation
927854c
| """ | |
| Configuration Management Module | |
| This module provides secure, robust configuration management with: | |
| - Environment variable handling with secure defaults | |
| - Cache directory management with automatic fallbacks | |
| - Comprehensive logging and error handling | |
| - Security best practices for sensitive data | |
| - Backward compatibility with existing code | |
| Environment Variables: | |
| HF_TOKEN: HuggingFace API token (required for API access) | |
| HF_HOME: Primary cache directory for HuggingFace models | |
| TRANSFORMERS_CACHE: Alternative cache directory path | |
| MAX_WORKERS: Maximum worker threads (default: 4) | |
| CACHE_TTL: Cache time-to-live in seconds (default: 3600) | |
| DB_PATH: Database file path (default: sessions.db) | |
| LOG_LEVEL: Logging level (default: INFO) | |
| LOG_FORMAT: Log format (default: json) | |
| Security Notes: | |
| - Never commit .env files to version control | |
| - Use environment variables for all sensitive data | |
| - Cache directories are automatically secured with proper permissions | |
| """ | |
import os
import logging
from pathlib import Path
from typing import Optional
from pydantic_settings import BaseSettings
from pydantic import Field, PrivateAttr, validator

# Configure logging
logger = logging.getLogger(__name__)

class CacheDirectoryManager:
    """
    Manages cache directory with secure fallback mechanism.

    Implements:
    - Multi-level fallback strategy
    - Permission validation
    - Automatic directory creation
    - Security best practices
    """

    @staticmethod
    def get_cache_directory() -> str:
        """
        Get cache directory with secure fallback chain.

        Priority order:
        1. HF_HOME environment variable
        2. TRANSFORMERS_CACHE environment variable
        3. User home directory (~/.cache/huggingface)
        4. User-specific fallback directory
        5. Temporary directory (last resort)

        Returns:
            str: Path to writable cache directory
        """
        # In Docker, ~ may resolve to / and cause permission issues, so we
        # prefer /tmp over ~/.cache in containerized environments. Only
        # /.dockerenv is checked: /tmp exists on virtually every system,
        # so its presence says nothing about being in a container.
        is_docker = os.path.exists("/.dockerenv")
        home = os.path.expanduser("~")
        cache_candidates = [
            os.getenv("HF_HOME"),
            os.getenv("TRANSFORMERS_CACHE"),
            # In Docker, prefer /tmp over ~/.cache
            "/tmp/huggingface_cache" if is_docker else None,
            os.path.join(home, ".cache", "huggingface") if home and not is_docker else None,
            os.path.join(home, ".cache", "huggingface_fallback") if home and not is_docker else None,
            "/tmp/huggingface_cache" if not is_docker else None,
            "/tmp/huggingface",  # Final fallback
        ]
        for cache_dir in cache_candidates:
            if not cache_dir:
                continue
            try:
                # Ensure directory exists
                cache_path = Path(cache_dir)
                cache_path.mkdir(parents=True, exist_ok=True)
                # Set secure permissions (rwxr-xr-x)
                try:
                    os.chmod(cache_path, 0o755)
                except (OSError, PermissionError):
                    # If we can't set permissions, continue if directory is writable
                    pass
                # Test write access
                test_file = cache_path / ".write_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.info(f"✓ Cache directory verified: {cache_dir}")
                    return str(cache_path)
                except (PermissionError, OSError) as e:
                    logger.debug(f"Write test failed for {cache_dir}: {e}")
                    continue
            except (PermissionError, OSError) as e:
                logger.debug(f"Could not create/access {cache_dir}: {e}")
                continue

        # If all candidates failed, use emergency fallback
        fallback = "/tmp/huggingface_emergency"
        try:
            Path(fallback).mkdir(parents=True, exist_ok=True)
            logger.warning(f"Using emergency fallback cache: {fallback}")
            return fallback
        except Exception as e:
            logger.error(f"Emergency fallback also failed: {e}")
            # Return a default that will fail gracefully later
            return "/tmp/huggingface"
class Settings(BaseSettings):
    """
    Application settings with secure defaults and validation.

    Backward Compatibility:
    - All existing attributes are preserved
    - hf_token is accessible as a string (regular field)
    - hf_cache_dir is accessible as a property (works like before)
    - All defaults match the original implementation
    """

    # ==================== HuggingFace Configuration ====================
    # BACKWARD COMPAT: hf_token kept as a regular string field
    hf_token: str = Field(
        default="",
        description="HuggingFace API token",
        env="HF_TOKEN"
    )
    @validator("hf_token", pre=True)
    def validate_hf_token(cls, v):
        """Validate HF token (backward compatible)."""
        if v is None:
            return ""
        token = str(v) if v else ""
        if not token:
            logger.debug("HF_TOKEN not set")
        return token

    # Cached result of the cache-directory lookup, declared as a private
    # attribute so assignment works under pydantic's validate_assignment
    _cached_cache_dir: Optional[str] = PrivateAttr(default=None)

    @property
    def hf_cache_dir(self) -> str:
        """
        Get cache directory with automatic fallback and validation.

        BACKWARD COMPAT: Works like the original hf_cache_dir field.

        Returns:
            str: Path to writable cache directory
        """
        if self._cached_cache_dir is None:
            try:
                self._cached_cache_dir = CacheDirectoryManager.get_cache_directory()
            except Exception as e:
                logger.error(f"Cache directory setup failed: {e}")
                # Fallback to original default
                fallback = os.getenv("HF_HOME", "/tmp/huggingface")
                Path(fallback).mkdir(parents=True, exist_ok=True)
                self._cached_cache_dir = fallback
        return self._cached_cache_dir
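    # Backward-compatible access (illustrative): the property resolves and
    # caches the directory on first use.
    #
    #   settings = Settings()
    #   print(settings.hf_cache_dir)   # e.g. /home/user/.cache/huggingface
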
    # ==================== Novita AI Configuration ====================
    novita_api_key: str = Field(
        default="",
        description="Novita AI API key (required)",
        env="NOVITA_API_KEY"
    )
    novita_base_url: str = Field(
        default="https://api.novita.ai/dedicated/v1/openai",
        description="Novita AI dedicated endpoint base URL",
        env="NOVITA_BASE_URL"
    )
    novita_model: str = Field(
        default="deepseek-ai/DeepSeek-R1-Distill-Qwen-7B:de-1a706eeafbf3ebc2",
        description="Novita AI dedicated endpoint model ID",
        env="NOVITA_MODEL"
    )
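    # How these settings are consumed (a sketch, assuming the standard
    # `openai` SDK; the actual inference client lives elsewhere):
    #
    #   from openai import OpenAI
    #   client = OpenAI(api_key=settings.novita_api_key,
    #                   base_url=settings.novita_base_url)
    #   resp = client.chat.completions.create(
    #       model=settings.novita_model,
    #       messages=[{"role": "user", "content": "Hello"}],
    #   )
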
    # DeepSeek-R1 optimized settings
    deepseek_r1_temperature: float = Field(
        default=0.6,
        description="Temperature for DeepSeek-R1 models (0.5-0.7 range, 0.6 recommended)",
        env="DEEPSEEK_R1_TEMPERATURE"
    )
    deepseek_r1_force_reasoning: bool = Field(
        default=True,
        description="Force DeepSeek-R1 to start with reasoning trigger",
        env="DEEPSEEK_R1_FORCE_REASONING"
    )
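    # Note (an assumption, based on published DeepSeek-R1 usage guidance):
    # "force reasoning" typically means prefixing the assistant turn with
    # "<think>\n" so the model always emits its reasoning trace; that
    # mechanism belongs to the inference client, not to this settings module.
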
    # Token Allocation Configuration
    user_input_max_tokens: int = Field(
        default=8000,
        description="Maximum tokens dedicated to user input (prioritized over context)",
        env="USER_INPUT_MAX_TOKENS"
    )
    context_preparation_budget: int = Field(
        default=28000,
        description="Maximum tokens for context preparation (user input + context)",
        env="CONTEXT_PREPARATION_BUDGET"
    )
    context_pruning_threshold: int = Field(
        default=28000,
        description="Context pruning threshold (should match context_preparation_budget)",
        env="CONTEXT_PRUNING_THRESHOLD"
    )
    prioritize_user_input: bool = Field(
        default=True,
        description="Always prioritize user input over historical context",
        env="PRIORITIZE_USER_INPUT"
    )
    @validator("novita_api_key", pre=True)
    def validate_novita_api_key(cls, v):
        """Validate and clean Novita API key."""
        if v is None:
            return ""
        return str(v).strip()

    @validator("deepseek_r1_temperature", pre=True)
    def validate_deepseek_temperature(cls, v):
        """Clamp DeepSeek-R1 temperature to the recommended 0.5-0.7 range."""
        if v is None or v == "":
            return 0.6
        try:
            temp = float(v)
        except (ValueError, TypeError):
            return 0.6
        return max(0.5, min(0.7, temp))

    @validator("deepseek_r1_force_reasoning", pre=True)
    def validate_force_reasoning(cls, v):
        """Convert string to boolean for force_reasoning."""
        if isinstance(v, str):
            return v.lower() in ("true", "1", "yes", "on")
        return bool(v)

    @validator("user_input_max_tokens", pre=True)
    def validate_user_input_tokens(cls, v):
        """Clamp user input token limit to 1,000-20,000."""
        try:
            val = int(v) if v else 8000
        except (ValueError, TypeError):
            return 8000
        return max(1000, min(20000, val))

    @validator("context_preparation_budget", "context_pruning_threshold", pre=True)
    def validate_context_budget(cls, v):
        """Clamp context preparation budget to 4,000-120,000."""
        try:
            val = int(v) if v else 28000
        except (ValueError, TypeError):
            return 28000
        return max(4000, min(120000, val))
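    # Worked example with the defaults above: a 28,000-token preparation
    # budget minus the 8,000 tokens reserved for user input leaves roughly
    # 20,000 tokens for historical context before pruning is triggered:
    #
    #   context_allowance = context_preparation_budget - user_input_max_tokens
    #                     = 28000 - 8000
    #                     = 20000
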
    # ==================== Model Configuration ====================
    default_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
    )
    embedding_model: str = Field(
        default="intfloat/e5-large-v2",
        description="Model for embeddings (upgraded: 1024-dim embeddings)"
    )
    classification_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Model for classification tasks (Cerebras deployment)"
    )

    # ==================== Performance Configuration ====================
    max_workers: int = Field(
        default=4,
        description="Maximum worker threads for parallel processing",
        env="MAX_WORKERS"
    )
    @validator("max_workers", pre=True)
    def validate_max_workers(cls, v):
        """Validate and convert max_workers (backward compatible)."""
        if v is None:
            return 4
        try:
            val = int(v)
        except (ValueError, TypeError):
            logger.warning(f"Invalid MAX_WORKERS value: {v}, using default 4")
            return 4
        return max(1, min(16, val))  # Clamp between 1 and 16

    cache_ttl: int = Field(
        default=3600,
        description="Cache time-to-live in seconds",
        env="CACHE_TTL"
    )
    @validator("cache_ttl", pre=True)
    def validate_cache_ttl(cls, v):
        """Validate cache TTL (backward compatible)."""
        if v is None:
            return 3600
        try:
            return max(0, int(v))
        except (ValueError, TypeError):
            return 3600

    # ==================== Database Configuration ====================
    db_path: str = Field(
        default="sessions.db",
        description="Path to SQLite database file",
        env="DB_PATH"
    )
    @validator("db_path", pre=True)
    def validate_db_path(cls, v):
        """Validate db_path with Docker fallback (backward compatible)."""
        if v is None:
            # In Docker (e.g. HF Spaces) the working directory may be
            # read-only, so fall back to /tmp
            if os.path.exists("/.dockerenv"):
                return "/tmp/sessions.db"
            return "sessions.db"
        return str(v)

    faiss_index_path: str = Field(
        default="embeddings.faiss",
        description="Path to FAISS index file",
        env="FAISS_INDEX_PATH"
    )
    @validator("faiss_index_path", pre=True)
    def validate_faiss_path(cls, v):
        """Validate FAISS index path with Docker fallback (backward compatible)."""
        if v is None:
            # In Docker (e.g. HF Spaces) the working directory may be
            # read-only, so fall back to /tmp
            if os.path.exists("/.dockerenv"):
                return "/tmp/embeddings.faiss"
            return "embeddings.faiss"
        return str(v)

    # ==================== Session Configuration ====================
    session_timeout: int = Field(
        default=3600,
        description="Session timeout in seconds",
        env="SESSION_TIMEOUT"
    )
    @validator("session_timeout", pre=True)
    def validate_session_timeout(cls, v):
        """Validate session timeout (backward compatible)."""
        if v is None:
            return 3600
        try:
            return max(60, int(v))
        except (ValueError, TypeError):
            return 3600

    max_session_size_mb: int = Field(
        default=10,
        description="Maximum session size in megabytes",
        env="MAX_SESSION_SIZE_MB"
    )
    @validator("max_session_size_mb", pre=True)
    def validate_max_session_size(cls, v):
        """Validate max session size (backward compatible)."""
        if v is None:
            return 10
        try:
            return max(1, min(100, int(v)))
        except (ValueError, TypeError):
            return 10

    # ==================== Mobile Optimization ====================
    mobile_max_tokens: int = Field(
        default=800,
        description="Maximum tokens for mobile responses",
        env="MOBILE_MAX_TOKENS"
    )
    @validator("mobile_max_tokens", pre=True)
    def validate_mobile_max_tokens(cls, v):
        """Validate mobile max tokens (backward compatible)."""
        if v is None:
            return 800
        try:
            return max(100, min(2000, int(v)))
        except (ValueError, TypeError):
            return 800

    mobile_timeout: int = Field(
        default=15000,
        description="Mobile request timeout in milliseconds",
        env="MOBILE_TIMEOUT"
    )
    @validator("mobile_timeout", pre=True)
    def validate_mobile_timeout(cls, v):
        """Validate mobile timeout (backward compatible)."""
        if v is None:
            return 15000
        try:
            return max(5000, min(60000, int(v)))
        except (ValueError, TypeError):
            return 15000

    # ==================== API Configuration ====================
    gradio_port: int = Field(
        default=7860,
        description="Gradio server port",
        env="GRADIO_PORT"
    )
    @validator("gradio_port", pre=True)
    def validate_gradio_port(cls, v):
        """Validate Gradio port (backward compatible)."""
        if v is None:
            return 7860
        try:
            return max(1024, min(65535, int(v)))
        except (ValueError, TypeError):
            return 7860

    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio server host",
        env="GRADIO_HOST"
    )
    # ==================== Logging Configuration ====================
    log_level: str = Field(
        default="INFO",
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        env="LOG_LEVEL"
    )
    @validator("log_level", pre=True)
    def validate_log_level(cls, v):
        """Validate log level (backward compatible)."""
        if not v:
            return "INFO"
        level = str(v).upper()
        if level not in ("DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"):
            logger.warning(f"Invalid log level: {v}, using INFO")
            return "INFO"
        return level

    log_format: str = Field(
        default="json",
        description="Log format (json or text)",
        env="LOG_FORMAT"
    )
    @validator("log_format", pre=True)
    def validate_log_format(cls, v):
        """Validate log format (backward compatible)."""
        if not v:
            return "json"
        fmt = str(v).lower()
        if fmt not in ("json", "text"):
            logger.warning(f"Invalid log format: {v}, using json")
            return "json"
        return fmt

    # ==================== Pydantic Configuration ====================
    class Config:
        """Pydantic configuration"""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        validate_assignment = True
        # Allow extra fields for backward compatibility
        extra = "ignore"

    # ==================== Utility Methods ====================
    def validate_configuration(self) -> bool:
        """
        Validate configuration and log status.

        Returns:
            bool: True if configuration is valid, False otherwise
        """
        try:
            # Validate cache directory
            cache_dir = self.hf_cache_dir
            if logger.isEnabledFor(logging.INFO):
                logger.info("Configuration validated:")
                logger.info(f"  - Cache directory: {cache_dir}")
                logger.info(f"  - Max workers: {self.max_workers}")
                logger.info(f"  - Log level: {self.log_level}")
                logger.info(f"  - HF token: {'Set' if self.hf_token else 'Not set'}")
            return True
        except Exception as e:
            logger.error(f"Configuration validation failed: {e}")
            return False

# ==================== Global Settings Instance ====================
def get_settings() -> Settings:
    """
    Get or create global settings instance.

    Returns:
        Settings: Global settings instance

    Note:
        This function ensures settings are loaded once and cached.
    """
    if not hasattr(get_settings, '_instance'):
        get_settings._instance = Settings()
        # Validate on first load (non-blocking)
        try:
            get_settings._instance.validate_configuration()
        except Exception as e:
            logger.warning(f"Configuration validation warning: {e}")
    return get_settings._instance
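# Typical import-site usage (illustrative, assuming this file is importable
# as `config`; most callers can just use the module-level `settings` below):
#
#   from config import get_settings
#   cfg = get_settings()
#   print(cfg.novita_base_url, cfg.max_workers)
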
# Create global settings instance (backward compatible)
settings = get_settings()

# Log configuration on import (at INFO level, non-blocking)
if logger.isEnabledFor(logging.INFO):
    try:
        logger.info("=" * 60)
        logger.info("Configuration Loaded")
        logger.info("=" * 60)
        logger.info(f"Cache directory: {settings.hf_cache_dir}")
        logger.info(f"Max workers: {settings.max_workers}")
        logger.info(f"Log level: {settings.log_level}")
        logger.info("=" * 60)
    except Exception as e:
        logger.debug(f"Configuration logging skipped: {e}")