# local_model_loader.py
# Local GPU-based model loading for NVIDIA T4 Medium (16GB VRAM)
# Optimized with 4-bit quantization to fit larger models

import logging
import os
import torch
from typing import Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

# Import GatedRepoError for handling gated repositories
try:
    from huggingface_hub.exceptions import GatedRepoError
    from huggingface_hub import login as hf_login
except ImportError:
    # Fallback if huggingface_hub is not available
    GatedRepoError = Exception
    hf_login = None

# Import settings for cache directory and HF token
try:
    from .config import settings
except ImportError:
    try:
        from config import settings
    except ImportError:
        settings = None

logger = logging.getLogger(__name__)


class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for NVIDIA T4 Medium with 16GB VRAM using 4-bit quantization.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the model loader with GPU device detection."""
        # Detect device
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        # Get cache directory from settings
        if settings:
            self.cache_dir = settings.hf_cache_dir
            self.hf_token = settings.hf_token
        else:
            # Fallback to environment variables
            self.cache_dir = os.getenv("HF_HOME") or os.getenv("TRANSFORMERS_CACHE") or "/tmp/huggingface"
            self.hf_token = os.getenv("HF_TOKEN", "")

        # Ensure cache directory exists and is writable
        os.makedirs(self.cache_dir, exist_ok=True)

        # Set environment variables for transformers/huggingface_hub
        if not os.getenv("HF_HOME"):
            os.environ["HF_HOME"] = self.cache_dir
        if not os.getenv("TRANSFORMERS_CACHE"):
            os.environ["TRANSFORMERS_CACHE"] = self.cache_dir

        logger.info(f"Cache directory: {self.cache_dir}")

        # Login to Hugging Face if token is provided (needed for gated repositories)
        if self.hf_token and hf_login:
            try:
                hf_login(token=self.hf_token, add_to_git_credential=False)
                logger.info("✓ HF_TOKEN authenticated for gated model access")
            except Exception as e:
                logger.warning(f"HF_TOKEN login failed (may not be needed): {e}")

        # Model cache
        self.loaded_models: Dict[str, Any] = {}
        self.loaded_tokenizers: Dict[str, Any] = {}
        self.loaded_embedding_models: Dict[str, Any] = {}

    def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
        """
        Load a chat model and tokenizer on GPU.

        Args:
            model_id: HuggingFace model identifier
            load_in_8bit: Use 8-bit quantization (saves memory)
            load_in_4bit: Use 4-bit quantization (saves more memory)

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")

            # Strip API-specific suffixes (e.g., :cerebras, :novita) for local loading
            # These suffixes are typically used for API endpoints, not local model identifiers
            base_model_id = model_id.split(':')[0] if ':' in model_id else model_id
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            # Load tokenizer with cache directory
            # This will fail with actual GatedRepoError if model is gated
            try:
                tokenizer = AutoTokenizer.from_pretrained(
                    base_model_id,
                    cache_dir=self.cache_dir,
                    token=self.hf_token if self.hf_token else None,
                    trust_remote_code=True
                )
            except Exception as e:
                # Check if this is actually a gated repo error
                error_str = str(e).lower()
                if "gated" in error_str or "authorized" in error_str or "access" in error_str:
                    # This might be a gated repo error
                    try:
                        from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
                        if isinstance(e, RealGatedRepoError):
                            logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                            logger.error(f" Access to model {base_model_id} is restricted and you are not in the authorized list.")
                            logger.error(f" Visit https://huggingface.co/{base_model_id} to request access.")
                            logger.error(f" Error details: {e}")
                            raise RealGatedRepoError(
                                f"Cannot access gated repository {base_model_id}. "
                                f"Visit https://huggingface.co/{base_model_id} to request access."
                            ) from e
                    except ImportError:
                        pass
                # If it's not a gated repo error, re-raise as-is
                raise

            # Determine quantization config
            if load_in_4bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    logger.info("Using 4-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
                    quantization_config = None
            elif load_in_8bit and self.device == "cuda":
                quantization_config = {"load_in_8bit": True}
                logger.info("Using 8-bit quantization")
            else:
                quantization_config = None

            # Load model with GPU optimization and cache directory
            # Try with quantization first, fallback to no quantization if bitsandbytes fails
            load_kwargs = {
                "cache_dir": self.cache_dir,
                "token": self.hf_token if self.hf_token else None,
                "trust_remote_code": True
            }

            if self.device == "cuda":
                # Use explicit device placement to avoid meta device issues
                # device_map="auto" works well with quantization, but can cause issues without it
                load_kwargs.update({
                    "torch_dtype": torch.float16,  # Use FP16 for memory efficiency
                })
                # Only use device_map="auto" with quantization, otherwise use explicit placement
                # This prevents "Tensor on device meta" errors

            # Try loading with quantization first
            model = None
            quantization_failed = False

            if quantization_config and self.device == "cuda":
                try:
                    if isinstance(quantization_config, dict):
                        load_kwargs.update(quantization_config)
                    else:
                        load_kwargs["quantization_config"] = quantization_config
                    # With quantization, device_map="auto" works correctly
                    load_kwargs["device_map"] = "auto"
                    model = AutoModelForCausalLM.from_pretrained(
                        base_model_id,
                        **load_kwargs
                    )
                    logger.info("✓ Model loaded with quantization")
                except (RuntimeError, ModuleNotFoundError, ImportError) as e:
                    error_str = str(e).lower()
                    # Check if error is related to bitsandbytes
                    if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str or "validate_bnb_backend" in error_str:
                        logger.warning(f"⚠ BitsAndBytes error detected: {e}")
                        logger.warning("⚠ Falling back to loading without quantization")
                        quantization_failed = True
                        # Remove quantization config and retry
                        load_kwargs.pop("quantization_config", None)
                        load_kwargs.pop("load_in_8bit", None)
                        load_kwargs.pop("load_in_4bit", None)
                    else:
                        # Re-raise if it's not a bitsandbytes error
                        raise

            # If quantization failed or not using quantization, load without it
            if model is None:
                try:
                    if self.device == "cuda":
                        # Without quantization, use explicit device placement to avoid meta device issues
                        # Don't use device_map="auto" here - it can cause tensor placement errors
                        # Drop any device_map left over from a failed quantized attempt
                        load_kwargs.pop("device_map", None)
                        model = AutoModelForCausalLM.from_pretrained(
                            base_model_id,
                            **load_kwargs
                        )
                        # Explicitly move to GPU after loading
                        model = model.to(self.device)
                        logger.info(f"✓ Model loaded without quantization on {self.device}")
                    else:
                        load_kwargs.update({
                            "torch_dtype": torch.float32,
                        })
                        model = AutoModelForCausalLM.from_pretrained(
                            base_model_id,
                            **load_kwargs
                        )
                        model = model.to(self.device)
                except Exception as e:
                    # Check if this is a gated repo error (not bitsandbytes)
                    error_str = str(e).lower()
                    if "bitsandbytes" in error_str or "int8_mm_dequant" in error_str:
                        # BitsAndBytes error - should have been caught earlier
                        logger.error(f"❌ Unexpected BitsAndBytes error: {e}")
                        raise RuntimeError(f"BitsAndBytes compatibility issue: {e}") from e
                    # Check for actual gated repo error
                    try:
                        from huggingface_hub.exceptions import GatedRepoError as RealGatedRepoError
                        if isinstance(e, RealGatedRepoError) or "gated" in error_str or "authorized" in error_str:
                            logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                            logger.error(f" Access to model {base_model_id} is restricted and you are not in the authorized list.")
                            logger.error(f" Visit https://huggingface.co/{base_model_id} to request access.")
                            logger.error(f" Error details: {e}")
                            raise RealGatedRepoError(
                                f"Cannot access gated repository {base_model_id}. "
                                f"Visit https://huggingface.co/{base_model_id} to request access."
                            ) from e
                    except ImportError:
                        pass
                    # Re-raise other errors as-is
                    raise

            # Ensure padding token is set
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Cache models (use original model_id for cache key to maintain API compatibility)
            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            # Log memory usage
            if self.device == "cuda":
                allocated = torch.cuda.memory_allocated(0) / 1024**3
                reserved = torch.cuda.memory_reserved(0) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

            logger.info(f"✓ Model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model, tokenizer

        except GatedRepoError:
            # Re-raise GatedRepoError to be handled by caller
            raise
        except Exception as e:
            logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise

    def load_embedding_model(self, model_id: str) -> SentenceTransformer:
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier

        Returns:
            SentenceTransformer model
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")

            # Strip API-specific suffixes for local loading
            base_model_id = model_id.split(':')[0] if ':' in model_id else model_id
            if base_model_id != model_id:
                logger.info(f"Stripping API suffix from {model_id}, using base model: {base_model_id}")

            # SentenceTransformer automatically handles GPU
            # Note: SentenceTransformer uses cache_dir from environment or default location
            # We can't directly pass cache_dir, but we've set HF_HOME and TRANSFORMERS_CACHE
            try:
                model = SentenceTransformer(
                    base_model_id,
                    device=self.device
                )
            except GatedRepoError as e:
                logger.error(f"❌ Gated Repository Error: Cannot access gated repo {base_model_id}")
                logger.error(f" Access to model {base_model_id} is restricted and you are not in the authorized list.")
                logger.error(f" Visit https://huggingface.co/{base_model_id} to request access.")
                logger.error(f" Error details: {e}")
                raise GatedRepoError(
                    f"Cannot access gated repository {base_model_id}. "
                    f"Visit https://huggingface.co/{base_model_id} to request access."
                ) from e

            # Cache model (use original model_id for cache key)
            self.loaded_embedding_models[model_id] = model

            logger.info(f"✓ Embedding model {model_id} (base: {base_model_id}) loaded successfully on {self.device}")
            return model

        except GatedRepoError:
            # Re-raise GatedRepoError to be handled by caller
            raise
        except Exception as e:
            logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Tokenize input
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

            # Prepare generation kwargs
            generation_kwargs = {
                "max_new_tokens": max_tokens,
                "temperature": temperature,
                "do_sample": True,
                "pad_token_id": tokenizer.pad_token_id,
                "eos_token_id": tokenizer.eos_token_id,
            }

            # Add compatibility fix for Phi-3 DynamicCache issues
            # Phi-3 models may use DynamicCache which doesn't have seen_tokens in some versions
            if "phi" in model_id.lower():
                # Use use_cache=False as workaround for DynamicCache.seen_tokens AttributeError
                generation_kwargs["use_cache"] = False
                logger.debug("Using use_cache=False for Phi-3 model to avoid DynamicCache compatibility issues")

            # Merge additional kwargs (may override above settings)
            generation_kwargs.update(kwargs)

            # Generate
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    **generation_kwargs
                )

            # Decode
            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # Remove prompt from output if present
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            return generated_text

        except AttributeError as e:
            # Handle DynamicCache.seen_tokens AttributeError specifically
            if "seen_tokens" in str(e) or "DynamicCache" in str(e):
                logger.warning(f"DynamicCache compatibility issue detected ({e}), retrying without cache")
                try:
                    # Retry without cache to avoid DynamicCache issues
                    with torch.no_grad():
                        outputs = model.generate(
                            **inputs,
                            max_new_tokens=max_tokens,
                            temperature=temperature,
                            do_sample=True,
                            use_cache=False,  # Disable cache to avoid DynamicCache issues
                            pad_token_id=tokenizer.pad_token_id,
                            eos_token_id=tokenizer.eos_token_id,
                            **{k: v for k, v in kwargs.items() if k != "use_cache"}  # Remove use_cache from kwargs
                        )
                    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
                    if generated_text.startswith(prompt):
                        generated_text = generated_text[len(prompt):].strip()
                    logger.info("✓ Generation successful after DynamicCache workaround")
                    return generated_text
                except Exception as retry_error:
                    logger.error(f"Retry without cache also failed: {retry_error}", exc_info=True)
                    raise RuntimeError(f"Generation failed even with cache disabled: {retry_error}") from retry_error
            # Re-raise if it's a different AttributeError
            raise
        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise

    def generate_chat_completion(
        self,
        model_id: str,
        messages: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate chat completion using a loaded model.

        Args:
            model_id: Model identifier
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Format messages as prompt
            if hasattr(tokenizer, 'apply_chat_template'):
                # Use chat template if available
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback: simple formatting
                prompt = "\n".join([
                    f"{msg['role']}: {msg['content']}"
                    for msg in messages
                ]) + "\nassistant: "

            # Generate
            return self.generate_text(
                model_id=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                **kwargs
            )

        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get embedding vector for text.

        Args:
            model_id: Embedding model identifier
            text: Input text

        Returns:
            Embedding vector
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")

        model = self.loaded_embedding_models[model_id]

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise

    def clear_cache(self):
        """Clear all loaded models from memory."""
        logger.info("Clearing model cache...")

        # Clear models
        for model_id in list(self.loaded_models.keys()):
            del self.loaded_models[model_id]
        for model_id in list(self.loaded_tokenizers.keys()):
            del self.loaded_tokenizers[model_id]
        for model_id in list(self.loaded_embedding_models.keys()):
            del self.loaded_embedding_models[model_id]

        # Clear GPU cache
        if self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current GPU memory usage in GB."""
        if self.device != "cuda":
            return {"device": "cpu", "gpu_available": False}

        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
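

# Example usage: a minimal sketch showing the intended call sequence
# (load -> generate -> embed -> inspect memory -> clear). The model IDs below
# are placeholders chosen as assumptions, not part of this module; substitute
# whatever chat and embedding models fit the T4's 16GB VRAM in your deployment.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = LocalModelLoader()

    # Load a small chat model with 4-bit quantization to stay within T4 memory
    chat_model_id = "microsoft/Phi-3-mini-4k-instruct"  # placeholder model ID
    loader.load_chat_model(chat_model_id, load_in_4bit=True)
    reply = loader.generate_chat_completion(
        model_id=chat_model_id,
        messages=[{"role": "user", "content": "Summarize what 4-bit quantization does."}],
        max_tokens=128,
    )
    print(reply)

    # Load an embedding model and embed a sentence
    embed_model_id = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder model ID
    loader.load_embedding_model(embed_model_id)
    vector = loader.get_embedding(embed_model_id, "hello world")
    print(f"Embedding dimension: {len(vector)}")

    # Inspect GPU memory usage, then release everything
    print(loader.get_memory_usage())
    loader.clear_cache()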