"""
Configuration Management Module

This module provides secure, robust configuration management with:

- Environment variable handling with secure defaults
- Cache directory management with automatic fallbacks
- Comprehensive logging and error handling
- Security best practices for sensitive data
- Backward compatibility with existing code

Environment Variables:
    HF_TOKEN: HuggingFace API token (required for API access)
    HF_HOME: Primary cache directory for HuggingFace models
    TRANSFORMERS_CACHE: Alternative cache directory path
    MAX_WORKERS: Maximum worker threads (default: 4)
    CACHE_TTL: Cache time-to-live in seconds (default: 3600)
    DB_PATH: Database file path (default: sessions.db)
    LOG_LEVEL: Logging level (default: INFO)
    LOG_FORMAT: Log format (default: json)

Security Notes:
- Never commit .env files to version control
- Use environment variables for all sensitive data
- Cache directories are created with 0o755 permissions where the platform allows
"""

import os
import logging
from pathlib import Path

from pydantic_settings import BaseSettings
from pydantic import Field, validator

logger = logging.getLogger(__name__)


class CacheDirectoryManager:
    """
    Manages the cache directory with a secure fallback mechanism.

    Implements:
    - Multi-level fallback strategy
    - Permission validation
    - Automatic directory creation
    - Security best practices
    """

    @staticmethod
    def get_cache_directory() -> str:
        """
        Get a cache directory via a secure fallback chain.

        Priority order:
        1. HF_HOME environment variable
        2. TRANSFORMERS_CACHE environment variable
        3. /tmp/huggingface_cache (tried first when running in Docker)
        4. User home directory (~/.cache/huggingface)
        5. User-specific fallback directory (~/.cache/huggingface_fallback)
        6. Temporary directories (last resort)

        Returns:
            str: Path to a writable cache directory
        """

        # Detect Docker via its /.dockerenv sentinel file.
        is_docker = os.path.exists("/.dockerenv")

        home = os.path.expanduser("~")
        cache_candidates = [
            os.getenv("HF_HOME"),
            os.getenv("TRANSFORMERS_CACHE"),
            # Inside Docker, prefer /tmp over home-directory locations.
            "/tmp/huggingface_cache" if is_docker else None,
            os.path.join(home, ".cache", "huggingface") if home and not is_docker else None,
            os.path.join(home, ".cache", "huggingface_fallback") if home and not is_docker else None,
            "/tmp/huggingface_cache" if not is_docker else None,
            "/tmp/huggingface",
        ]

        for cache_dir in cache_candidates:
            if not cache_dir:
                continue

            try:
                cache_path = Path(cache_dir)
                cache_path.mkdir(parents=True, exist_ok=True)

                # Tighten permissions where the platform allows; best effort only.
                try:
                    os.chmod(cache_path, 0o755)
                except (OSError, PermissionError):
                    pass

                # Verify the directory is actually writable before committing to it.
                test_file = cache_path / ".write_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.info(f"✓ Cache directory verified: {cache_dir}")
                    return str(cache_path)
                except (PermissionError, OSError) as e:
                    logger.debug(f"Write test failed for {cache_dir}: {e}")
                    continue

            except (PermissionError, OSError) as e:
                logger.debug(f"Could not create/access {cache_dir}: {e}")
                continue

        # Every candidate failed; fall back to an emergency /tmp location.
        fallback = "/tmp/huggingface_emergency"
        try:
            Path(fallback).mkdir(parents=True, exist_ok=True)
            logger.warning(f"Using emergency fallback cache: {fallback}")
            return fallback
        except Exception as e:
            logger.error(f"Emergency fallback also failed: {e}")

        return "/tmp/huggingface"


class Settings(BaseSettings):
    """
    Application settings with secure defaults and validation.

    Backward Compatibility:
    - All existing attributes are preserved
    - hf_token is accessible as a plain string field
    - hf_cache_dir is accessible as a property (works like the original field)
    - All defaults match the original implementation
    """

    # --- HuggingFace ---

    hf_token: str = Field(
        default="",
        description="HuggingFace API token",
        env="HF_TOKEN"
    )

    @validator("hf_token", pre=True)
    def validate_hf_token(cls, v):
        """Validate HF token (backward compatible)."""
        if v is None:
            return ""
        token = str(v) if v else ""
        if not token:
            logger.debug("HF_TOKEN not set")
        return token

    @property
    def hf_cache_dir(self) -> str:
        """
        Get the cache directory with automatic fallback and validation.

        BACKWARD COMPAT: Works like the original hf_cache_dir field.

        Returns:
            str: Path to a writable cache directory
        """
        # Resolve once and memoize on the instance.
        if not hasattr(self, '_cached_cache_dir'):
            try:
                self._cached_cache_dir = CacheDirectoryManager.get_cache_directory()
            except Exception as e:
                logger.error(f"Cache directory setup failed: {e}")
                fallback = os.getenv("HF_HOME", "/tmp/huggingface")
                Path(fallback).mkdir(parents=True, exist_ok=True)
                self._cached_cache_dir = fallback

        return self._cached_cache_dir

    # --- ZeroGPU Chat API ---

    zerogpu_base_url: str = Field(
        default="http://your-pod-ip:8000",
        description="ZeroGPU Chat API base URL (RunPod endpoint)",
        env="ZEROGPU_BASE_URL"
    )

    zerogpu_email: str = Field(
        default="",
        description="ZeroGPU Chat API email for authentication (required)",
        env="ZEROGPU_EMAIL"
    )

    zerogpu_password: str = Field(
        default="",
        description="ZeroGPU Chat API password for authentication (required)",
        env="ZEROGPU_PASSWORD"
    )

    # --- Token budgets ---

    user_input_max_tokens: int = Field(
        default=32000,
        description="Maximum tokens dedicated to user input (prioritized over context)",
        env="USER_INPUT_MAX_TOKENS"
    )

    context_preparation_budget: int = Field(
        default=115000,
        description="Maximum tokens for context preparation (user input + context)",
        env="CONTEXT_PREPARATION_BUDGET"
    )

    context_pruning_threshold: int = Field(
        default=115000,
        description="Context pruning threshold (should match context_preparation_budget)",
        env="CONTEXT_PRUNING_THRESHOLD"
    )

    prioritize_user_input: bool = Field(
        default=True,
        description="Always prioritize user input over historical context",
        env="PRIORITIZE_USER_INPUT"
    )

    zerogpu_model_context_window: int = Field(
        default=8192,
        description="Maximum context window for the ZeroGPU Chat API model (input + output tokens). Adjust to your deployed model.",
        env="ZEROGPU_MODEL_CONTEXT_WINDOW"
    )
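
    # Worked budget example under the defaults above (illustrative numbers):
    # a request may carry up to 32,000 tokens of raw user input; user input
    # plus retrieved context together are capped at 115,000 tokens, and
    # pruning triggers at that same threshold, so
    #   user_input_max_tokens (32,000) + context (<= 83,000) <= 115,000.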

    @validator("zerogpu_base_url", pre=True)
    def validate_zerogpu_base_url(cls, v):
        """Validate the ZeroGPU base URL and strip any trailing slash."""
        if v is None:
            return "http://your-pod-ip:8000"
        url = str(v).strip()
        if url.endswith('/'):
            url = url[:-1]
        return url

    @validator("zerogpu_email", pre=True)
    def validate_zerogpu_email(cls, v):
        """Validate the ZeroGPU email."""
        if v is None:
            return ""
        email = str(v).strip()
        if email and '@' not in email:
            logger.warning("ZEROGPU_EMAIL may not be a valid email address")
        return email

    @validator("zerogpu_password", pre=True)
    def validate_zerogpu_password(cls, v):
        """Validate the ZeroGPU password."""
        if v is None:
            return ""
        return str(v).strip()

    @validator("user_input_max_tokens", pre=True)
    def validate_user_input_tokens(cls, v):
        """Clamp the user input token limit to 1,000..50,000."""
        val = int(v) if v else 32000
        return max(1000, min(50000, val))

    @validator("context_preparation_budget", pre=True)
    def validate_context_budget(cls, v):
        """Clamp the context preparation budget to 4,000..125,000."""
        val = int(v) if v else 115000
        return max(4000, min(125000, val))

    @validator("context_pruning_threshold", pre=True)
    def validate_pruning_threshold(cls, v):
        """Clamp the context pruning threshold to 4,000..125,000."""
        val = int(v) if v else 115000
        return max(4000, min(125000, val))

    @validator("zerogpu_model_context_window", pre=True)
    def validate_context_window(cls, v):
        """Clamp the context window size to 1,000..200,000."""
        val = int(v) if v else 8192
        return max(1000, min(200000, val))

    # --- Models ---

    default_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
    )

    embedding_model: str = Field(
        default="intfloat/e5-large-v2",
        description="Model for embeddings (1024-dimensional)"
    )

    classification_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Model for classification tasks (Cerebras deployment)"
    )

    # --- Performance ---

    max_workers: int = Field(
        default=4,
        description="Maximum worker threads for parallel processing",
        env="MAX_WORKERS"
    )

    @validator("max_workers", pre=True)
    def validate_max_workers(cls, v):
        """Validate and clamp max_workers to 1..16 (backward compatible)."""
        if v is None:
            return 4
        try:
            return max(1, min(16, int(v)))
        except (ValueError, TypeError):
            logger.warning(f"Invalid MAX_WORKERS value: {v}, using default 4")
            return 4
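
    # Illustrative clamping behavior of the validator above:
    #   MAX_WORKERS=0    -> 1   (clamped to the lower bound)
    #   MAX_WORKERS=99   -> 16  (clamped to the upper bound)
    #   MAX_WORKERS=abc  -> 4   (unparseable; default, with a warning)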

    cache_ttl: int = Field(
        default=3600,
        description="Cache time-to-live in seconds",
        env="CACHE_TTL"
    )

    @validator("cache_ttl", pre=True)
    def validate_cache_ttl(cls, v):
        """Validate cache TTL; non-negative, default 3600 (backward compatible)."""
        if v is None:
            return 3600
        try:
            return max(0, int(v))
        except (ValueError, TypeError):
            return 3600

    # --- Storage ---

    db_path: str = Field(
        default="sessions.db",
        description="Path to SQLite database file",
        env="DB_PATH"
    )

    @validator("db_path", pre=True)
    def validate_db_path(cls, v):
        """Validate db_path with a Docker fallback (backward compatible)."""
        if v is None:
            # Inside Docker, only /tmp is reliably writable.
            if os.path.exists("/.dockerenv"):
                return "/tmp/sessions.db"
            return "sessions.db"
        return str(v)

    faiss_index_path: str = Field(
        default="embeddings.faiss",
        description="Path to FAISS index file",
        env="FAISS_INDEX_PATH"
    )

    @validator("faiss_index_path", pre=True)
    def validate_faiss_path(cls, v):
        """Validate faiss_index_path with a Docker fallback (backward compatible)."""
        if v is None:
            if os.path.exists("/.dockerenv"):
                return "/tmp/embeddings.faiss"
            return "embeddings.faiss"
        return str(v)

    # --- Sessions ---

    session_timeout: int = Field(
        default=3600,
        description="Session timeout in seconds",
        env="SESSION_TIMEOUT"
    )

    @validator("session_timeout", pre=True)
    def validate_session_timeout(cls, v):
        """Validate session timeout; at least 60 seconds (backward compatible)."""
        if v is None:
            return 3600
        try:
            return max(60, int(v))
        except (ValueError, TypeError):
            return 3600

    max_session_size_mb: int = Field(
        default=10,
        description="Maximum session size in megabytes",
        env="MAX_SESSION_SIZE_MB"
    )

    @validator("max_session_size_mb", pre=True)
    def validate_max_session_size(cls, v):
        """Clamp max session size to 1..100 MB (backward compatible)."""
        if v is None:
            return 10
        try:
            return max(1, min(100, int(v)))
        except (ValueError, TypeError):
            return 10

    # --- Mobile ---

    mobile_max_tokens: int = Field(
        default=800,
        description="Maximum tokens for mobile responses",
        env="MOBILE_MAX_TOKENS"
    )

    @validator("mobile_max_tokens", pre=True)
    def validate_mobile_max_tokens(cls, v):
        """Clamp mobile max tokens to 100..2000 (backward compatible)."""
        if v is None:
            return 800
        try:
            return max(100, min(2000, int(v)))
        except (ValueError, TypeError):
            return 800

    mobile_timeout: int = Field(
        default=15000,
        description="Mobile request timeout in milliseconds",
        env="MOBILE_TIMEOUT"
    )

    @validator("mobile_timeout", pre=True)
    def validate_mobile_timeout(cls, v):
        """Clamp mobile timeout to 5,000..60,000 ms (backward compatible)."""
        if v is None:
            return 15000
        try:
            return max(5000, min(60000, int(v)))
        except (ValueError, TypeError):
            return 15000

    # --- Gradio server ---

    gradio_port: int = Field(
        default=7860,
        description="Gradio server port",
        env="GRADIO_PORT"
    )

    @validator("gradio_port", pre=True)
    def validate_gradio_port(cls, v):
        """Clamp the Gradio port to the unprivileged range 1024..65535."""
        if v is None:
            return 7860
        try:
            return max(1024, min(65535, int(v)))
        except (ValueError, TypeError):
            return 7860

    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio server host",
        env="GRADIO_HOST"
    )

    # --- Logging ---

    log_level: str = Field(
        default="INFO",
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        env="LOG_LEVEL"
    )

    @validator("log_level")
    def validate_log_level(cls, v):
        """Normalize the log level to upper case; fall back to INFO."""
        if not v:
            return "INFO"
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            logger.warning(f"Invalid log level: {v}, using INFO")
            return "INFO"
        return v.upper()

    log_format: str = Field(
        default="json",
        description="Log format (json or text)",
        env="LOG_FORMAT"
    )

    @validator("log_format")
    def validate_log_format(cls, v):
        """Normalize the log format to lower case; fall back to json."""
        if not v:
            return "json"
        if v.lower() not in ["json", "text"]:
            logger.warning(f"Invalid log format: {v}, using json")
            return "json"
        return v.lower()

    class Config:
        """Pydantic configuration."""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        validate_assignment = True
        # Ignore unknown keys in .env rather than failing hard.
        extra = "ignore"

    def validate_configuration(self) -> bool:
        """
        Validate configuration and log status.

        Returns:
            bool: True if configuration is valid, False otherwise
        """
        try:
            # Resolving the property exercises the full cache fallback chain.
            cache_dir = self.hf_cache_dir
            if logger.isEnabledFor(logging.INFO):
                logger.info("Configuration validated:")
                logger.info(f"  - Cache directory: {cache_dir}")
                logger.info(f"  - Max workers: {self.max_workers}")
                logger.info(f"  - Log level: {self.log_level}")
                logger.info(f"  - HF token: {'Set' if self.hf_token else 'Not set'}")
            return True
        except Exception as e:
            logger.error(f"Configuration validation failed: {e}")
            return False


def get_settings() -> Settings:
    """
    Get or create the global settings instance.

    Returns:
        Settings: Global settings instance

    Note:
        Settings are loaded once and cached on the function object.
    """
    if not hasattr(get_settings, '_instance'):
        get_settings._instance = Settings()
        try:
            get_settings._instance.validate_configuration()
        except Exception as e:
            logger.warning(f"Configuration validation warning: {e}")
    return get_settings._instance
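
# Usage sketch (illustrative; assumes this module is importable as `config`):
#
#   from config import settings, get_settings
#   assert settings is get_settings()        # same cached instance
#   print(settings.hf_cache_dir, settings.max_workers)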


# Module-level singleton, created at import time for backward compatibility.
settings = get_settings()

# Log the loaded configuration once at import.
if logger.isEnabledFor(logging.INFO):
    try:
        logger.info("=" * 60)
        logger.info("Configuration Loaded")
        logger.info("=" * 60)
        logger.info(f"Cache directory: {settings.hf_cache_dir}")
        logger.info(f"Max workers: {settings.max_workers}")
        logger.info(f"Log level: {settings.log_level}")
        logger.info("=" * 60)
    except Exception as e:
        logger.debug(f"Configuration logging skipped: {e}")