# local_model_loader.py
# Local GPU-based model loading for NVIDIA T4 Medium (16 GB VRAM)

import logging
import torch
from typing import Optional, Dict, Any
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for the NVIDIA T4 Medium hardware tier (16 GB VRAM).
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the model loader with GPU device detection."""
        # Detect device
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        # Model caches keyed by model_id
        self.loaded_models: Dict[str, Any] = {}
        self.loaded_tokenizers: Dict[str, Any] = {}
        self.loaded_embedding_models: Dict[str, Any] = {}

    def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
        """
        Load a chat model and tokenizer on GPU.

        Args:
            model_id: HuggingFace model identifier
            load_in_8bit: Use 8-bit quantization (saves memory)
            load_in_4bit: Use 4-bit quantization (saves more memory)

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")

            # Load tokenizer
            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Determine quantization config (requires bitsandbytes on CUDA)
            quantization_config = None
            if self.device == "cuda" and (load_in_4bit or load_in_8bit):
                try:
                    from transformers import BitsAndBytesConfig
                    if load_in_4bit:
                        quantization_config = BitsAndBytesConfig(
                            load_in_4bit=True,
                            bnb_4bit_compute_dtype=torch.float16,
                            bnb_4bit_use_double_quant=True,
                            bnb_4bit_quant_type="nf4"
                        )
                        logger.info("Using 4-bit quantization")
                    else:
                        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                        logger.info("Using 8-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")

            # Load model with GPU optimization
            if self.device == "cuda":
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",  # Automatically places weights on the GPU
                    torch_dtype=torch.float16,  # FP16 for memory efficiency
                    trust_remote_code=True,
                    **({"quantization_config": quantization_config} if quantization_config else {})
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )
                model = model.to(self.device)

            # Ensure a padding token is set
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            # Cache model and tokenizer
            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            # Log memory usage
            if self.device == "cuda":
                allocated = torch.cuda.memory_allocated(0) / 1024**3
                reserved = torch.cuda.memory_reserved(0) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

            logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
            return model, tokenizer

        except Exception as e:
            logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise
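
    # Usage sketch (illustrative only; the model id below is an assumption,
    # not necessarily one this project is configured with):
    #
    #   loader = LocalModelLoader()
    #   model, tokenizer = loader.load_chat_model(
    #       "mistralai/Mistral-7B-Instruct-v0.2", load_in_4bit=True
    #   )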

    def load_embedding_model(self, model_id: str) -> SentenceTransformer:
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier

        Returns:
            SentenceTransformer model
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")

            # SentenceTransformer places the model on the requested device
            model = SentenceTransformer(
                model_id,
                device=self.device
            )

            # Cache model
            self.loaded_embedding_models[model_id] = model
            logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
            return model

        except Exception as e:
            logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise
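
    # Usage sketch (illustrative; the embedding model id is an assumption):
    #
    #   embedder = loader.load_embedding_model("sentence-transformers/all-MiniLM-L6-v2")
    #   vectors = embedder.encode(["query text"], convert_to_numpy=True)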

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Tokenize input
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

            # Generate
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    **kwargs
                )

            # Decode only the newly generated tokens so the prompt is not echoed back
            prompt_length = inputs["input_ids"].shape[1]
            generated_text = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
            return generated_text.strip()

        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise
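
    # Usage sketch (assumes a chat model was already loaded via load_chat_model):
    #
    #   answer = loader.generate_text(
    #       model_id="mistralai/Mistral-7B-Instruct-v0.2",  # assumed model id
    #       prompt="Summarize the attention mechanism in two sentences.",
    #       max_tokens=128,
    #       temperature=0.3,
    #   )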

    def generate_chat_completion(
        self,
        model_id: str,
        messages: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate a chat completion using a loaded model.

        Args:
            model_id: Model identifier
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Format messages as a prompt
            if getattr(tokenizer, "chat_template", None) is not None:
                # Use the tokenizer's chat template if one is defined
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                # Fallback: simple role-prefixed formatting
                prompt = "\n".join([
                    f"{msg['role']}: {msg['content']}"
                    for msg in messages
                ]) + "\nassistant: "

            # Generate
            return self.generate_text(
                model_id=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                **kwargs
            )

        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise
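
    # Usage sketch showing the expected message format (OpenAI-style role/content
    # dicts; the model id is again an assumption):
    #
    #   reply = loader.generate_chat_completion(
    #       model_id="mistralai/Mistral-7B-Instruct-v0.2",
    #       messages=[
    #           {"role": "system", "content": "You are a concise research assistant."},
    #           {"role": "user", "content": "List two uses of sentence embeddings."},
    #       ],
    #       max_tokens=128,
    #   )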

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get the embedding vector for text.

        Args:
            model_id: Embedding model identifier
            text: Input text

        Returns:
            Embedding vector as a list of floats
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")

        model = self.loaded_embedding_models[model_id]

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise

    def clear_cache(self):
        """Clear all loaded models from memory."""
        logger.info("Clearing model cache...")

        # Drop references so the models can be garbage collected
        self.loaded_models.clear()
        self.loaded_tokenizers.clear()
        self.loaded_embedding_models.clear()

        # Release cached GPU memory
        if self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current GPU memory usage in GB, plus basic device info."""
        if self.device != "cuda":
            return {"device": "cpu", "gpu_available": False}

        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
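

# Minimal smoke-test sketch, assuming the module is run directly (e.g. on the
# Space hardware). The model ids below are illustrative assumptions, not project
# configuration; swap in whatever models the rest of the app actually uses.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = LocalModelLoader()
    print(loader.get_memory_usage())

    # Embeddings are cheap enough to exercise even on CPU.
    embedder_id = "sentence-transformers/all-MiniLM-L6-v2"  # assumed model id
    loader.load_embedding_model(embedder_id)
    vector = loader.get_embedding(embedder_id, "local GPU model loading")
    print(f"Embedding dimension: {len(vector)}")

    # Chat generation is only worth attempting when a GPU is available.
    if torch.cuda.is_available():
        chat_id = "mistralai/Mistral-7B-Instruct-v0.2"  # assumed model id
        loader.load_chat_model(chat_id, load_in_4bit=True)
        reply = loader.generate_chat_completion(
            chat_id,
            messages=[{"role": "user", "content": "Say hello in one sentence."}],
            max_tokens=32,
        )
        print(reply)

    loader.clear_cache()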