# HonestAI / src/local_model_loader.py
# Last commit ea87e33 by JatsTheAIGen:
#   "Fix: DynamicCache compatibility, dependencies, and Docker configuration"
# local_model_loader.py
# Local GPU-based model loading for NVIDIA T4 Medium (16GB VRAM)
# Optimized with 4-bit quantization to fit larger models
import logging
import os
import torch
from typing import Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer
# Import GatedRepoError for handling gated repositories.
# Recent huggingface_hub releases expose it from ``huggingface_hub.errors``,
# older ones from ``huggingface_hub.utils``; try both.  The login helper is
# imported separately so that failing to locate the exception class can never
# disable HF_TOKEN authentication.
try:
    from huggingface_hub.errors import GatedRepoError
except ImportError:
    try:
        from huggingface_hub.utils import GatedRepoError
    except ImportError:
        class GatedRepoError(Exception):
            """Placeholder used when huggingface_hub is not available.

            Deliberately a distinct subclass rather than ``Exception`` itself,
            so ``except GatedRepoError`` clauses do not swallow unrelated errors.
            """

try:
    from huggingface_hub import login as hf_login
except ImportError:
    # Fallback if huggingface_hub is not available
    hf_login = None
# Project settings provide the Hugging Face cache directory and token.
# Look for the package-relative ``config`` module first, then a top-level
# ``config`` module; if neither can be imported, ``settings`` stays None and
# the loader falls back to environment variables.
settings = None
try:
    from .config import settings
except ImportError:
    try:
        from config import settings
    except ImportError:
        pass

logger = logging.getLogger(__name__)
class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for NVIDIA T4 Medium with 16GB VRAM using 4-bit quantization.

    Chat models, their tokenizers and embedding models are cached in
    per-instance dictionaries keyed by the model id exactly as the caller
    passed it (API suffixes such as ``:cerebras`` included); the suffix is
    stripped only when resolving the repository to load.
    """

    # Lower-cased substrings that identify a bitsandbytes failure.
    _BNB_ERROR_MARKERS = ("bitsandbytes", "int8_mm_dequant", "validate_bnb_backend")

    def __init__(self, device: Optional[str] = None):
        """Initialize the loader, detect the device and configure the HF cache.

        Args:
            device: Torch device string such as ``"cuda"``, ``"cuda:1"`` or
                ``"cpu"``. ``None`` selects ``"cuda"`` when CUDA is available
                and ``"cpu"`` otherwise.
        """
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        # Cache directory / token come from project settings when available,
        # otherwise from the standard Hugging Face environment variables.
        env_cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/tmp/huggingface"
        if settings:
            # Use the env-derived directory if settings leaves it empty rather
            # than crashing in os.makedirs(None).
            self.cache_dir = settings.hf_cache_dir or env_cache_dir
            self.hf_token = settings.hf_token
        else:
            self.cache_dir = env_cache_dir
            self.hf_token = os.getenv("HF_TOKEN", "")

        os.makedirs(self.cache_dir, exist_ok=True)

        # Export the cache location for libraries that read it from the
        # environment. NOTE(review): huggingface_hub may resolve HF_HOME when it
        # is first imported (this module imports transformers at load time), so
        # the loaders below also pass the cache directory explicitly.
        if not os.getenv("HF_HOME"):
            os.environ["HF_HOME"] = self.cache_dir
        if not os.getenv("TRANSFORMERS_CACHE"):
            os.environ["TRANSFORMERS_CACHE"] = self.cache_dir
        logger.info(f"Cache directory: {self.cache_dir}")

        # Log in so gated repositories are accessible; failure is non-fatal.
        if self.hf_token and hf_login:
            try:
                hf_login(token=self.hf_token, add_to_git_credential=False)
                logger.info("✓ HF_TOKEN authenticated for gated model access")
            except Exception as e:
                logger.warning(f"HF_TOKEN login failed (may not be needed): {e}")

        # Model caches, keyed by the caller-supplied model id.
        self.loaded_models: Dict[str, Any] = {}
        self.loaded_tokenizers: Dict[str, Any] = {}
        self.loaded_embedding_models: Dict[str, Any] = {}

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @property
    def _on_cuda(self) -> bool:
        """True when the configured device is CUDA (``"cuda"`` or ``"cuda:N"``)."""
        return str(self.device).startswith("cuda")

    @staticmethod
    def _strip_api_suffix(model_id: str) -> str:
        """Drop an API-provider suffix: ``"org/model:novita"`` -> ``"org/model"``.

        Such suffixes are used for API endpoints and are not part of the
        identifier used for local loading.
        """
        return model_id.partition(":")[0]

    @staticmethod
    def _gated_repo_error_class():
        """Return huggingface_hub's ``GatedRepoError`` class, or ``None``.

        Recent huggingface_hub releases expose it from ``huggingface_hub.errors``,
        older ones from ``huggingface_hub.utils``.
        """
        try:
            from huggingface_hub.errors import GatedRepoError as gated_cls
        except ImportError:
            try:
                from huggingface_hub.utils import GatedRepoError as gated_cls
            except ImportError:
                return None
        return gated_cls

    @classmethod
    def _find_gated_repo_error(cls, error: BaseException) -> Optional[BaseException]:
        """Return the ``GatedRepoError`` in *error*'s cause/context chain, if any.

        Wrapping libraries may re-raise hub errors as a different exception
        type chained with ``from``; walking the chain finds the original.
        """
        gated_cls = cls._gated_repo_error_class()
        if gated_cls is None:
            return None
        seen = set()
        current: Optional[BaseException] = error
        while current is not None and id(current) not in seen:
            if isinstance(current, gated_cls):
                return current
            seen.add(id(current))
            current = current.__cause__ or current.__context__
        return None

    @classmethod
    def _is_gated_repo_error(cls, error: BaseException) -> bool:
        """True if *error* is, or was caused by, a gated-repository error."""
        return cls._find_gated_repo_error(error) is not None

    @classmethod
    def _raise_gated_repo_error(cls, base_model_id: str, error: BaseException):
        """Log a gated-repository failure and raise it with an actionable message.

        Raises:
            GatedRepoError: same class as the gated error found in *error*'s
                chain, chained ``from error``.
            Exception: *error* itself when no gated error is in its chain or
                the replacement cannot be constructed.
        """
        gated = cls._find_gated_repo_error(error)
        if gated is None:
            raise error
        logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
        logger.error(f" Access to model {base_model_id} is restricted and you are not in the authorized list.")
        logger.error(f" Visit https://huggingface.co/{base_model_id} to request access.")
        logger.error(f" Error details: {error}")
        message = (
            f"Cannot access gated repository {base_model_id}. "
            f"Visit https://huggingface.co/{base_model_id} to request access."
        )
        try:
            # Forward the HTTP response: some huggingface_hub versions may
            # require it when constructing HTTP error subclasses.
            replacement = type(gated)(message, response=getattr(gated, "response", None))
        except TypeError:
            raise error
        raise replacement from error

    @classmethod
    def _is_bitsandbytes_error(cls, error: BaseException) -> bool:
        """True if the error message points at a bitsandbytes problem."""
        message = str(error).lower()
        return any(marker in message for marker in cls._BNB_ERROR_MARKERS)

    def _load_tokenizer(self, base_model_id: str, trust_remote_code: bool):
        """Load the tokenizer for *base_model_id*, translating gated-repo errors."""
        try:
            return AutoTokenizer.from_pretrained(
                base_model_id,
                cache_dir=self.cache_dir,
                token=self.hf_token or None,
                trust_remote_code=trust_remote_code,
            )
        except Exception as e:
            if self._is_gated_repo_error(e):
                self._raise_gated_repo_error(base_model_id, e)
            raise

    def _build_quantization_config(self, load_in_8bit: bool, load_in_4bit: bool):
        """Return a ``BitsAndBytesConfig`` for the requested mode, or ``None``.

        Quantization is only attempted on CUDA; 4-bit wins if both flags are set.
        """
        if not self._on_cuda or not (load_in_4bit or load_in_8bit):
            return None
        try:
            from transformers import BitsAndBytesConfig

            if load_in_4bit:
                config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_compute_dtype=torch.float16,
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_quant_type="nf4",
                )
                logger.info("Using 4-bit quantization")
            else:
                # Express 8-bit through BitsAndBytesConfig as well; passing a raw
                # load_in_8bit kwarg to from_pretrained is deprecated in recent
                # transformers releases.
                config = BitsAndBytesConfig(load_in_8bit=True)
                logger.info("Using 8-bit quantization")
            return config
        except ImportError:
            logger.warning("bitsandbytes not available, loading without quantization")
            return None

    def _load_causal_lm(self, base_model_id: str, load_in_8bit: bool, load_in_4bit: bool, trust_remote_code: bool):
        """Load causal-LM weights, preferring quantization when requested.

        Falls back to an unquantized load when bitsandbytes fails at load time.
        """
        common_kwargs = {
            "cache_dir": self.cache_dir,
            "token": self.hf_token or None,
            "trust_remote_code": trust_remote_code,
        }

        quantization_config = self._build_quantization_config(load_in_8bit, load_in_4bit)
        if quantization_config is not None:
            try:
                model = AutoModelForCausalLM.from_pretrained(
                    base_model_id,
                    torch_dtype=torch.float16,
                    quantization_config=quantization_config,
                    # With quantization, device_map="auto" works correctly.
                    device_map="auto",
                    **common_kwargs,
                )
                logger.info("✓ Model loaded with quantization")
                return model
            except (RuntimeError, ModuleNotFoundError, ImportError) as e:
                if not self._is_bitsandbytes_error(e):
                    raise
                logger.warning(f"⚠ BitsAndBytes error detected: {e}")
                logger.warning("⚠ Falling back to loading without quantization")

        # Unquantized load. The kwargs are built fresh so neither a
        # quantization option nor device_map="auto" leaks in from a failed
        # attempt above; placement is explicit via .to() to avoid
        # "Tensor on device meta" errors.
        dtype = torch.float16 if self._on_cuda else torch.float32
        try:
            model = AutoModelForCausalLM.from_pretrained(base_model_id, torch_dtype=dtype, **common_kwargs)
            model = model.to(self.device)
        except Exception as e:
            if self._is_bitsandbytes_error(e):
                logger.error(f"❌ Unexpected BitsAndBytes error: {e}")
                raise RuntimeError(f"BitsAndBytes compatibility issue: {e}") from e
            if self._is_gated_repo_error(e):
                self._raise_gated_repo_error(base_model_id, e)
            raise
        if self._on_cuda:
            logger.info(f"✓ Model loaded without quantization on {self.device}")
        return model

    @staticmethod
    def _is_phi_model(model_id: str, model: Any = None) -> bool:
        """True for Phi-family models (checked via config.model_type, then the repo name).

        The name check only matches a repository name that *starts* with
        "phi" followed by a non-letter, so ids like "dolphin" or "phind" are
        not treated as Phi models.
        """
        import re

        config = getattr(model, "config", None)
        model_type = str(getattr(config, "model_type", "") or "").lower()
        if model_type.startswith("phi"):
            return True
        name = model_id.lower().partition(":")[0].rsplit("/", 1)[-1]
        return re.match(r"phi(?![a-z])", name) is not None

    @staticmethod
    def _build_generation_kwargs(tokenizer, max_tokens: int, temperature: float,
                                 disable_cache: bool, extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Assemble ``generate()`` keyword arguments; *extra_kwargs* win on conflicts.

        A non-positive temperature selects greedy decoding, because sampling
        requires a strictly positive temperature.
        """
        generation_kwargs: Dict[str, Any] = {
            "max_new_tokens": max_tokens,
            "pad_token_id": tokenizer.pad_token_id,
            "eos_token_id": tokenizer.eos_token_id,
        }
        if temperature is not None and temperature <= 0:
            generation_kwargs["do_sample"] = False
        else:
            generation_kwargs["do_sample"] = True
            generation_kwargs["temperature"] = temperature
        if disable_cache:
            generation_kwargs["use_cache"] = False
        generation_kwargs.update(extra_kwargs)
        return generation_kwargs

    @staticmethod
    def _decode_new_tokens(tokenizer, sequence, prompt_length: int) -> str:
        """Decode only the tokens produced after the prompt.

        Slicing by token count avoids comparing decoded text with the prompt
        string, which fails whenever decoding (e.g. with
        ``skip_special_tokens=True``) does not reproduce the prompt exactly.
        """
        return tokenizer.decode(sequence[prompt_length:], skip_special_tokens=True).strip()

    @staticmethod
    def _format_chat_prompt(tokenizer, messages: list) -> tuple:
        """Render *messages* into a prompt string.

        Returns:
            ``(prompt, used_chat_template)``. The tokenizer's chat template is
            used only when one is configured; otherwise a plain
            ``"role: content"`` transcript ending in ``"assistant: "`` is built.
        """
        if hasattr(tokenizer, "apply_chat_template") and getattr(tokenizer, "chat_template", None):
            prompt = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
            )
            return prompt, True
        prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages) + "\nassistant: "
        return prompt, False

    def _generate(self, model_id: str, prompt: str, max_tokens: int, temperature: float,
                  extra_kwargs: Dict[str, Any], add_special_tokens: bool = True) -> str:
        """Shared implementation behind generate_text() and generate_chat_completion().

        *add_special_tokens* is False for prompts rendered by a chat template,
        which typically already contain BOS/special tokens.
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        # Stays None until everything needed for a retry has been prepared.
        generation_kwargs = None
        try:
            # Follow the model's own device: with device_map="auto" it may be
            # more specific than self.device (e.g. "cuda:0" vs "cuda").
            target_device = getattr(model, "device", None)
            if target_device is None:
                target_device = self.device
            inputs = tokenizer(
                prompt,
                return_tensors="pt",
                add_special_tokens=add_special_tokens,
            ).to(target_device)
            prompt_length = inputs["input_ids"].shape[-1]

            # Phi-3 models may use DynamicCache which doesn't have seen_tokens
            # in some versions; disable the KV cache up front for them.
            disable_cache = self._is_phi_model(model_id, model)
            if disable_cache:
                logger.debug("Using use_cache=False for Phi-3 model to avoid DynamicCache compatibility issues")
            generation_kwargs = self._build_generation_kwargs(
                tokenizer, max_tokens, temperature, disable_cache, extra_kwargs
            )

            with torch.no_grad():
                outputs = model.generate(**inputs, **generation_kwargs)
            return self._decode_new_tokens(tokenizer, outputs[0], prompt_length)
        except AttributeError as e:
            message = str(e)
            if generation_kwargs is None or ("seen_tokens" not in message and "DynamicCache" not in message):
                raise
            logger.warning(f"DynamicCache compatibility issue detected ({e}), retrying without cache")
            # Reuse the exact kwargs of the failed call (caller overrides kept,
            # no keyword passed twice), forcing the cache off.
            retry_kwargs = dict(generation_kwargs, use_cache=False)
            try:
                with torch.no_grad():
                    outputs = model.generate(**inputs, **retry_kwargs)
                generated_text = self._decode_new_tokens(tokenizer, outputs[0], prompt_length)
            except Exception as retry_error:
                logger.error(f"Retry without cache also failed: {retry_error}", exc_info=True)
                raise RuntimeError(f"Generation failed even with cache disabled: {retry_error}") from retry_error
            logger.info("✓ Generation successful after DynamicCache workaround")
            return generated_text
        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def load_chat_model(
        self,
        model_id: str,
        load_in_8bit: bool = False,
        load_in_4bit: bool = False,
        trust_remote_code: bool = True,
    ) -> tuple:
        """
        Load a chat model and tokenizer on the configured device.

        Args:
            model_id: HuggingFace model identifier. An API suffix such as
                ``":cerebras"`` is stripped for loading but kept as cache key.
            load_in_8bit: Use 8-bit quantization (saves memory; CUDA only).
            load_in_4bit: Use 4-bit quantization (saves more memory; CUDA only;
                takes precedence over 8-bit).
            trust_remote_code: Forwarded to ``from_pretrained``. Defaults to
                True for backward compatibility. This runs Python code shipped
                in the model repository, so only enable it for trusted repos.

        Returns:
            Tuple of (model, tokenizer)

        Raises:
            GatedRepoError: the repository is gated and not accessible.
            RuntimeError: an unrecoverable bitsandbytes failure.
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")
            base_model_id = self._strip_api_suffix(model_id)
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            tokenizer = self._load_tokenizer(base_model_id, trust_remote_code)
            model = self._load_causal_lm(base_model_id, load_in_8bit, load_in_4bit, trust_remote_code)

            # Ensure padding token is set (generation passes pad_token_id).
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Cache under the original model_id to maintain API compatibility.
            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            if self._on_cuda:
                allocated = torch.cuda.memory_allocated(self.device) / 1024**3
                reserved = torch.cuda.memory_reserved(self.device) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")
            logger.info(f"✓ Model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model, tokenizer
        except Exception as e:
            # Gated-repo failures were already logged in detail and are left
            # for the caller to handle; everything else is logged here.
            if not self._is_gated_repo_error(e):
                logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise

    def load_embedding_model(self, model_id: str) -> "SentenceTransformer":
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier (API suffix stripped for
                loading, kept as cache key).

        Returns:
            SentenceTransformer model placed on ``self.device``.

        Raises:
            GatedRepoError: the repository is gated and not accessible.
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")
            base_model_id = self._strip_api_suffix(model_id)
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            try:
                # Pass the cache directory explicitly instead of relying on
                # environment variables set after the libraries were imported.
                model = SentenceTransformer(
                    base_model_id,
                    device=self.device,
                    cache_folder=self.cache_dir,
                )
            except Exception as e:
                if self._is_gated_repo_error(e):
                    self._raise_gated_repo_error(base_model_id, e)
                raise

            self.loaded_embedding_models[model_id] = model
            logger.info(f"✓ Embedding model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model
        except Exception as e:
            if not self._is_gated_repo_error(e):
                logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier previously passed to load_chat_model().
            prompt: Input prompt (tokenized with the tokenizer's special tokens).
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature; values <= 0 select greedy decoding.
            **kwargs: Extra ``generate()`` arguments; they override the defaults.

        Returns:
            The generated continuation only (prompt tokens are excluded).

        Raises:
            ValueError: the model has not been loaded.
            RuntimeError: generation failed even with the KV cache disabled.
        """
        return self._generate(model_id, prompt, max_tokens, temperature, kwargs)

    def generate_chat_completion(
        self,
        model_id: str,
        messages: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate chat completion using a loaded model.

        Args:
            model_id: Model identifier previously passed to load_chat_model().
            messages: List of message dicts with 'role' and 'content'.
            max_tokens: Maximum number of new tokens to generate.
            temperature: Sampling temperature; values <= 0 select greedy decoding.
            **kwargs: Extra ``generate()`` arguments.

        Returns:
            Generated assistant response.

        Raises:
            ValueError: the model has not been loaded.
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")
        tokenizer = self.loaded_tokenizers[model_id]
        try:
            prompt, used_template = self._format_chat_prompt(tokenizer, messages)
            # A rendered chat template typically already contains the special
            # tokens (e.g. BOS); adding them again at tokenization duplicates them.
            return self._generate(
                model_id,
                prompt,
                max_tokens,
                temperature,
                kwargs,
                add_special_tokens=not used_template,
            )
        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get embedding vector for text.

        Args:
            model_id: Embedding model identifier passed to load_embedding_model().
            text: Input text.

        Returns:
            Embedding vector as a plain Python list.

        Raises:
            ValueError: the embedding model has not been loaded.
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")
        model = self.loaded_embedding_models[model_id]
        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise

    def clear_cache(self):
        """Drop all cached models/tokenizers and release cached GPU memory."""
        import gc

        logger.info("Clearing model cache...")
        self.loaded_models.clear()
        self.loaded_tokenizers.clear()
        self.loaded_embedding_models.clear()
        # Collect first so unreachable model objects are actually freed before
        # asking the CUDA allocator to return its cached blocks.
        gc.collect()
        if self._on_cuda:
            torch.cuda.empty_cache()
        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current GPU memory usage in GB for the configured device.

        Returns ``{"device": "cpu", "gpu_available": False}`` on non-CUDA devices.
        """
        if not self._on_cuda:
            return {"device": "cpu", "gpu_available": False}
        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(self.device) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(self.device) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(self.device).total_memory / 1024**3,
        }