"""
Configuration Management Module
This module provides secure, robust configuration management with:
- Environment variable handling with secure defaults
- Cache directory management with automatic fallbacks
- Comprehensive logging and error handling
- Security best practices for sensitive data
- Backward compatibility with existing code
Environment Variables:
HF_TOKEN: HuggingFace API token (required for API access)
HF_HOME: Primary cache directory for HuggingFace models
TRANSFORMERS_CACHE: Alternative cache directory path
MAX_WORKERS: Maximum worker threads (default: 4)
CACHE_TTL: Cache time-to-live in seconds (default: 3600)
DB_PATH: Database file path (default: sessions.db)
LOG_LEVEL: Logging level (default: INFO)
LOG_FORMAT: Log format (default: json)
Security Notes:
- Never commit .env files to version control
- Use environment variables for all sensitive data
- Cache directories are automatically secured with proper permissions
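
Example (illustrative sketch; assumes this file is importable as `config`):

    from config import settings

    token = settings.hf_token          # "" when HF_TOKEN is unset
    cache_dir = settings.hf_cache_dir  # resolved, writable cache path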
"""
import os
import logging
from pathlib import Path

from pydantic_settings import BaseSettings
from pydantic import Field, validator

# Configure logging
logger = logging.getLogger(__name__)
class CacheDirectoryManager:
    """
    Manages cache directory with secure fallback mechanism.

    Implements:
    - Multi-level fallback strategy
    - Permission validation
    - Automatic directory creation
    - Security best practices
    """

    @staticmethod
    def get_cache_directory() -> str:
        """
        Get cache directory with secure fallback chain.

        Priority order:
        1. HF_HOME environment variable
        2. TRANSFORMERS_CACHE environment variable
        3. /tmp cache directory (preferred first inside Docker)
        4. User home directory (~/.cache/huggingface)
        5. User-specific fallback directory (~/.cache/huggingface_fallback)
        6. Temporary directory (last resort)

        Returns:
            str: Path to writable cache directory
        """
        # In Docker, ~ may resolve to / which causes permission issues,
        # so we prefer /tmp over ~/.cache in containerized environments.
        # Only /.dockerenv is checked here: testing for /tmp would match
        # virtually every Unix system and defeat the detection.
        is_docker = os.path.exists("/.dockerenv")
        home = os.path.expanduser("~")
        cache_candidates = [
            os.getenv("HF_HOME"),
            os.getenv("TRANSFORMERS_CACHE"),
            # In Docker, prefer /tmp over ~/.cache
            "/tmp/huggingface_cache" if is_docker else None,
            os.path.join(home, ".cache", "huggingface") if home and not is_docker else None,
            os.path.join(home, ".cache", "huggingface_fallback") if home and not is_docker else None,
            "/tmp/huggingface_cache" if not is_docker else None,
            "/tmp/huggingface",  # Final fallback
        ]
        for cache_dir in cache_candidates:
            if not cache_dir:
                continue
            try:
                # Ensure directory exists
                cache_path = Path(cache_dir)
                cache_path.mkdir(parents=True, exist_ok=True)
                # Set secure permissions (rwxr-xr-x)
                try:
                    os.chmod(cache_path, 0o755)
                except (OSError, PermissionError):
                    # If we can't set permissions, continue if directory is writable
                    pass
                # Test write access
                test_file = cache_path / ".write_test"
                try:
                    test_file.write_text("test")
                    test_file.unlink()
                    logger.info(f"✓ Cache directory verified: {cache_dir}")
                    return str(cache_path)
                except (PermissionError, OSError) as e:
                    logger.debug(f"Write test failed for {cache_dir}: {e}")
                    continue
            except (PermissionError, OSError) as e:
                logger.debug(f"Could not create/access {cache_dir}: {e}")
                continue
        # If all candidates failed, use emergency fallback
        fallback = "/tmp/huggingface_emergency"
        try:
            Path(fallback).mkdir(parents=True, exist_ok=True)
            logger.warning(f"Using emergency fallback cache: {fallback}")
            return fallback
        except Exception as e:
            logger.error(f"Emergency fallback also failed: {e}")
            # Return a default that will fail gracefully later
            return "/tmp/huggingface"
class Settings(BaseSettings):
    """
    Application settings with secure defaults and validation.

    Backward Compatibility:
    - All existing attributes are preserved
    - hf_token is accessible as a plain string field
    - hf_cache_dir is accessible as a property (works like before)
    - All defaults match the original implementation
    """

    # ==================== HuggingFace Configuration ====================
    # BACKWARD COMPAT: hf_token as regular field (backward compatible)
    hf_token: str = Field(
        default="",
        description="HuggingFace API token",
        env="HF_TOKEN"
    )

    @validator("hf_token", pre=True)
    def validate_hf_token(cls, v):
        """Validate HF token (backward compatible)"""
        if v is None:
            return ""
        token = str(v) if v else ""
        if not token:
            logger.debug("HF_TOKEN not set")
        return token

    @property
    def hf_cache_dir(self) -> str:
        """
        Get cache directory with automatic fallback and validation.

        BACKWARD COMPAT: Works like the original hf_cache_dir field.

        Returns:
            str: Path to writable cache directory
        """
        if not hasattr(self, '_cached_cache_dir'):
            try:
                self._cached_cache_dir = CacheDirectoryManager.get_cache_directory()
            except Exception as e:
                logger.error(f"Cache directory setup failed: {e}")
                # Fallback to original default
                fallback = os.getenv("HF_HOME", "/tmp/huggingface")
                Path(fallback).mkdir(parents=True, exist_ok=True)
                self._cached_cache_dir = fallback
        return self._cached_cache_dir
    # ==================== ZeroGPU Chat API Configuration ====================
    zerogpu_base_url: str = Field(
        default="http://your-pod-ip:8000",
        description="ZeroGPU Chat API base URL (RunPod endpoint)",
        env="ZEROGPU_BASE_URL"
    )
    zerogpu_email: str = Field(
        default="",
        description="ZeroGPU Chat API email for authentication (required)",
        env="ZEROGPU_EMAIL"
    )
    zerogpu_password: str = Field(
        default="",
        description="ZeroGPU Chat API password for authentication (required)",
        env="ZEROGPU_PASSWORD"
    )

    # Token Allocation Configuration
    user_input_max_tokens: int = Field(
        default=32000,
        description="Maximum tokens dedicated for user input (prioritized over context)",
        env="USER_INPUT_MAX_TOKENS"
    )
    context_preparation_budget: int = Field(
        default=115000,
        description="Maximum tokens for context preparation (includes user input + context)",
        env="CONTEXT_PREPARATION_BUDGET"
    )
    context_pruning_threshold: int = Field(
        default=115000,
        description="Context pruning threshold (should match context_preparation_budget)",
        env="CONTEXT_PRUNING_THRESHOLD"
    )
    prioritize_user_input: bool = Field(
        default=True,
        description="Always prioritize user input over historical context",
        env="PRIORITIZE_USER_INPUT"
    )

    # Model Context Window Configuration
    zerogpu_model_context_window: int = Field(
        default=8192,
        description="Maximum context window for ZeroGPU Chat API model (input + output tokens). Adjust based on your deployed model.",
        env="ZEROGPU_MODEL_CONTEXT_WINDOW"
    )

    @validator("zerogpu_base_url", pre=True)
    def validate_zerogpu_base_url(cls, v):
        """Validate ZeroGPU base URL"""
        if v is None:
            return "http://your-pod-ip:8000"
        url = str(v).strip()
        # Remove trailing slash
        if url.endswith('/'):
            url = url[:-1]
        return url

    @validator("zerogpu_email", pre=True)
    def validate_zerogpu_email(cls, v):
        """Validate ZeroGPU email"""
        if v is None:
            return ""
        email = str(v).strip()
        if email and '@' not in email:
            logger.warning("ZEROGPU_EMAIL may not be a valid email address")
        return email

    @validator("zerogpu_password", pre=True)
    def validate_zerogpu_password(cls, v):
        """Validate ZeroGPU password"""
        if v is None:
            return ""
        return str(v).strip()

    @validator("user_input_max_tokens", pre=True)
    def validate_user_input_tokens(cls, v):
        """Validate user input token limit"""
        val = int(v) if v else 32000
        return max(1000, min(50000, val))  # Allow up to 50K for large inputs

    @validator("context_preparation_budget", pre=True)
    def validate_context_budget(cls, v):
        """Validate context preparation budget"""
        val = int(v) if v else 115000
        return max(4000, min(125000, val))  # Allow up to 125K for 128K context window

    @validator("context_pruning_threshold", pre=True)
    def validate_pruning_threshold(cls, v):
        """Validate context pruning threshold"""
        val = int(v) if v else 115000
        return max(4000, min(125000, val))  # Match context_preparation_budget limits

    @validator("zerogpu_model_context_window", pre=True)
    def validate_context_window(cls, v):
        """Validate context window size"""
        val = int(v) if v else 8192
        return max(1000, min(200000, val))  # Support up to 200K for future models
    # ==================== Model Configuration ====================
    default_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Primary model for reasoning tasks (Cerebras deployment with 4-bit quantization)"
    )
    embedding_model: str = Field(
        default="intfloat/e5-large-v2",
        description="Model for embeddings (upgraded: 1024-dim embeddings)"
    )
    classification_model: str = Field(
        default="meta-llama/Llama-3.1-8B-Instruct:cerebras",
        description="Model for classification tasks (Cerebras deployment)"
    )

    # ==================== Performance Configuration ====================
    max_workers: int = Field(
        default=4,
        description="Maximum worker threads for parallel processing",
        env="MAX_WORKERS"
    )

    @validator("max_workers", pre=True)
    def validate_max_workers(cls, v):
        """Validate and convert max_workers (backward compatible)"""
        if v is None:
            return 4
        try:
            val = int(v)
        except (ValueError, TypeError):
            logger.warning(f"Invalid MAX_WORKERS value: {v}, using default 4")
            return 4
        return max(1, min(16, val))  # Clamp between 1 and 16

    cache_ttl: int = Field(
        default=3600,
        description="Cache time-to-live in seconds",
        env="CACHE_TTL"
    )

    @validator("cache_ttl", pre=True)
    def validate_cache_ttl(cls, v):
        """Validate cache TTL (backward compatible)"""
        if v is None:
            return 3600
        try:
            return max(0, int(v))
        except (ValueError, TypeError):
            return 3600
    # ==================== Database Configuration ====================
    db_path: str = Field(
        default="sessions.db",
        description="Path to SQLite database file",
        env="DB_PATH"
    )

    @validator("db_path", pre=True)
    def validate_db_path(cls, v):
        """Validate db_path with Docker fallback (backward compatible)"""
        if v is None:
            # Check if we're in Docker (e.g. HF Spaces) - if so, use /tmp
            if os.path.exists("/.dockerenv"):
                return "/tmp/sessions.db"
            return "sessions.db"
        return str(v)

    faiss_index_path: str = Field(
        default="embeddings.faiss",
        description="Path to FAISS index file",
        env="FAISS_INDEX_PATH"
    )

    @validator("faiss_index_path", pre=True)
    def validate_faiss_path(cls, v):
        """Validate faiss path with Docker fallback (backward compatible)"""
        if v is None:
            # Check if we're in Docker (e.g. HF Spaces) - if so, use /tmp
            if os.path.exists("/.dockerenv"):
                return "/tmp/embeddings.faiss"
            return "embeddings.faiss"
        return str(v)
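
    # Resolution example (illustrative): inside a container where /.dockerenv
    # exists and DB_PATH is unset, db_path resolves to "/tmp/sessions.db" and
    # faiss_index_path to "/tmp/embeddings.faiss"; on a host machine both stay
    # relative to the current working directory.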
    # ==================== Session Configuration ====================
    session_timeout: int = Field(
        default=3600,
        description="Session timeout in seconds",
        env="SESSION_TIMEOUT"
    )

    @validator("session_timeout", pre=True)
    def validate_session_timeout(cls, v):
        """Validate session timeout (backward compatible)"""
        if v is None:
            return 3600
        try:
            return max(60, int(v))
        except (ValueError, TypeError):
            return 3600

    max_session_size_mb: int = Field(
        default=10,
        description="Maximum session size in megabytes",
        env="MAX_SESSION_SIZE_MB"
    )

    @validator("max_session_size_mb", pre=True)
    def validate_max_session_size(cls, v):
        """Validate max session size (backward compatible)"""
        if v is None:
            return 10
        try:
            return max(1, min(100, int(v)))
        except (ValueError, TypeError):
            return 10

    # ==================== Mobile Optimization ====================
    mobile_max_tokens: int = Field(
        default=800,
        description="Maximum tokens for mobile responses",
        env="MOBILE_MAX_TOKENS"
    )

    @validator("mobile_max_tokens", pre=True)
    def validate_mobile_max_tokens(cls, v):
        """Validate mobile max tokens (backward compatible)"""
        if v is None:
            return 800
        try:
            return max(100, min(2000, int(v)))
        except (ValueError, TypeError):
            return 800

    mobile_timeout: int = Field(
        default=15000,
        description="Mobile request timeout in milliseconds",
        env="MOBILE_TIMEOUT"
    )

    @validator("mobile_timeout", pre=True)
    def validate_mobile_timeout(cls, v):
        """Validate mobile timeout (backward compatible)"""
        if v is None:
            return 15000
        try:
            return max(5000, min(60000, int(v)))
        except (ValueError, TypeError):
            return 15000

    # ==================== API Configuration ====================
    gradio_port: int = Field(
        default=7860,
        description="Gradio server port",
        env="GRADIO_PORT"
    )

    @validator("gradio_port", pre=True)
    def validate_gradio_port(cls, v):
        """Validate gradio port (backward compatible)"""
        if v is None:
            return 7860
        try:
            return max(1024, min(65535, int(v)))
        except (ValueError, TypeError):
            return 7860

    gradio_host: str = Field(
        default="0.0.0.0",
        description="Gradio server host",
        env="GRADIO_HOST"
    )

    # ==================== Logging Configuration ====================
    log_level: str = Field(
        default="INFO",
        description="Logging level (DEBUG, INFO, WARNING, ERROR, CRITICAL)",
        env="LOG_LEVEL"
    )

    @validator("log_level")
    def validate_log_level(cls, v):
        """Validate log level (backward compatible)"""
        if not v:
            return "INFO"
        valid_levels = ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
        if v.upper() not in valid_levels:
            logger.warning(f"Invalid log level: {v}, using INFO")
            return "INFO"
        return v.upper()

    log_format: str = Field(
        default="json",
        description="Log format (json or text)",
        env="LOG_FORMAT"
    )

    @validator("log_format")
    def validate_log_format(cls, v):
        """Validate log format (backward compatible)"""
        if not v:
            return "json"
        if v.lower() not in ["json", "text"]:
            logger.warning(f"Invalid log format: {v}, using json")
            return "json"
        return v.lower()

    # ==================== Pydantic Configuration ====================
    class Config:
        """Pydantic configuration"""
        env_file = ".env"
        env_file_encoding = "utf-8"
        case_sensitive = False
        validate_assignment = True
        # Allow extra fields for backward compatibility
        extra = "ignore"
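
    # Example .env file consumed via the Config above (illustrative values;
    # never commit a real .env - see the module-level Security Notes):
    #
    #   HF_TOKEN=hf_xxxxxxxxxxxxxxxxxxxx
    #   ZEROGPU_BASE_URL=http://your-pod-ip:8000
    #   MAX_WORKERS=8
    #   LOG_LEVEL=DEBUG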
    # ==================== Utility Methods ====================
    def validate_configuration(self) -> bool:
        """
        Validate configuration and log status.

        Returns:
            bool: True if configuration is valid, False otherwise
        """
        try:
            # Validate cache directory
            cache_dir = self.hf_cache_dir
            if logger.isEnabledFor(logging.INFO):
                logger.info("Configuration validated:")
                logger.info(f"  - Cache directory: {cache_dir}")
                logger.info(f"  - Max workers: {self.max_workers}")
                logger.info(f"  - Log level: {self.log_level}")
                logger.info(f"  - HF token: {'Set' if self.hf_token else 'Not set'}")
            return True
        except Exception as e:
            logger.error(f"Configuration validation failed: {e}")
            return False
# ==================== Global Settings Instance ====================
def get_settings() -> Settings:
    """
    Get or create global settings instance.

    Returns:
        Settings: Global settings instance

    Note:
        This function ensures settings are loaded once and cached.
    """
    if not hasattr(get_settings, '_instance'):
        get_settings._instance = Settings()
        # Validate on first load (non-blocking)
        try:
            get_settings._instance.validate_configuration()
        except Exception as e:
            logger.warning(f"Configuration validation warning: {e}")
    return get_settings._instance
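
# Usage sketch (illustrative; assumes this module is named `config`):
#
#   from config import get_settings
#   settings = get_settings()                    # same cached instance everywhere
#   timeout_s = settings.mobile_timeout / 1000   # stored in milliseconds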
# Create global settings instance (backward compatible)
settings = get_settings()

# Log configuration on import (at INFO level, non-blocking)
if logger.isEnabledFor(logging.INFO):
    try:
        logger.info("=" * 60)
        logger.info("Configuration Loaded")
        logger.info("=" * 60)
        logger.info(f"Cache directory: {settings.hf_cache_dir}")
        logger.info(f"Max workers: {settings.max_workers}")
        logger.info(f"Log level: {settings.log_level}")
        logger.info("=" * 60)
    except Exception as e:
        logger.debug(f"Configuration logging skipped: {e}")
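
# Minimal smoke test - an illustrative sketch, assuming the file is saved as
# config.py and run directly; it only prints the resolved values.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    s = get_settings()
    print(f"Cache directory: {s.hf_cache_dir}")
    print(f"Database path:   {s.db_path}")
    print(f"Max workers:     {s.max_workers}")
    print(f"Configuration valid: {s.validate_configuration()}")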