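"""Local model loading utilities.

Loads Hugging Face chat models and sentence-transformers embedding models onto a
local GPU (falling back to CPU when none is available) and exposes simple helpers
for text generation, chat completion, and embeddings.
"""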
import logging
from typing import Optional, Dict, Any

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

logger = logging.getLogger(__name__)


class LocalModelLoader:
    """
    Loads and manages models locally on GPU for faster inference.
    Optimized for NVIDIA T4 (medium tier) hardware with 16 GB of VRAM.
    """

    def __init__(self, device: Optional[str] = None):
        """Initialize the model loader with GPU device detection."""
        if device is None:
            if torch.cuda.is_available():
                self.device = "cuda"
                self.device_name = torch.cuda.get_device_name(0)
                logger.info(f"GPU detected: {self.device_name}")
                logger.info(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")
            else:
                self.device = "cpu"
                self.device_name = "CPU"
                logger.warning("No GPU detected, using CPU")
        else:
            self.device = device
            self.device_name = device

        self.loaded_models: Dict[str, Any] = {}
        self.loaded_tokenizers: Dict[str, Any] = {}
        self.loaded_embedding_models: Dict[str, Any] = {}

    def load_chat_model(self, model_id: str, load_in_8bit: bool = False, load_in_4bit: bool = False) -> tuple:
        """
        Load a chat model and tokenizer on GPU.

        Args:
            model_id: HuggingFace model identifier
            load_in_8bit: Use 8-bit quantization (saves memory)
            load_in_4bit: Use 4-bit quantization (saves more memory)

        Returns:
            Tuple of (model, tokenizer)
        """
        if model_id in self.loaded_models:
            logger.info(f"Model {model_id} already loaded, reusing")
            return self.loaded_models[model_id], self.loaded_tokenizers[model_id]

        try:
            logger.info(f"Loading model {model_id} on {self.device}...")

            tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )

            # Optional quantization (requires bitsandbytes); fall back to full precision if unavailable.
            quantization_config = None
            if load_in_4bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(
                        load_in_4bit=True,
                        bnb_4bit_compute_dtype=torch.float16,
                        bnb_4bit_use_double_quant=True,
                        bnb_4bit_quant_type="nf4"
                    )
                    logger.info("Using 4-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")
            elif load_in_8bit and self.device == "cuda":
                try:
                    from transformers import BitsAndBytesConfig
                    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                    logger.info("Using 8-bit quantization")
                except ImportError:
                    logger.warning("bitsandbytes not available, loading without quantization")

            if self.device == "cuda":
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    device_map="auto",
                    torch_dtype=torch.float16,
                    trust_remote_code=True,
                    **({"quantization_config": quantization_config} if quantization_config is not None else {})
                )
            else:
                model = AutoModelForCausalLM.from_pretrained(
                    model_id,
                    torch_dtype=torch.float32,
                    trust_remote_code=True
                )
                model = model.to(self.device)

            # Some tokenizers (e.g. GPT-style models) ship without a pad token.
            if tokenizer.pad_token is None:
                tokenizer.pad_token = tokenizer.eos_token

            self.loaded_models[model_id] = model
            self.loaded_tokenizers[model_id] = tokenizer

            if self.device == "cuda":
                allocated = torch.cuda.memory_allocated(0) / 1024**3
                reserved = torch.cuda.memory_reserved(0) / 1024**3
                logger.info(f"GPU Memory - Allocated: {allocated:.2f} GB, Reserved: {reserved:.2f} GB")

            logger.info(f"✓ Model {model_id} loaded successfully on {self.device}")
            return model, tokenizer

        except Exception as e:
            logger.error(f"Error loading model {model_id}: {e}", exc_info=True)
            raise

    def load_embedding_model(self, model_id: str) -> SentenceTransformer:
        """
        Load a sentence transformer model for embeddings.

        Args:
            model_id: HuggingFace model identifier

        Returns:
            SentenceTransformer model
        """
        if model_id in self.loaded_embedding_models:
            logger.info(f"Embedding model {model_id} already loaded, reusing")
            return self.loaded_embedding_models[model_id]

        try:
            logger.info(f"Loading embedding model {model_id}...")

            model = SentenceTransformer(
                model_id,
                device=self.device
            )

            self.loaded_embedding_models[model_id] = model

            logger.info(f"✓ Embedding model {model_id} loaded successfully on {self.device}")
            return model

        except Exception as e:
            logger.error(f"Error loading embedding model {model_id}: {e}", exc_info=True)
            raise

    def generate_text(
        self,
        model_id: str,
        prompt: str,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate text using a loaded chat model.

        Args:
            model_id: Model identifier
            prompt: Input prompt
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated text
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        model = self.loaded_models[model_id]
        tokenizer = self.loaded_tokenizers[model_id]

        try:
            inputs = tokenizer(prompt, return_tensors="pt").to(self.device)

            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    **kwargs
                )

            generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

            # The decoded output usually echoes the prompt; strip it so only new text is returned.
            if generated_text.startswith(prompt):
                generated_text = generated_text[len(prompt):].strip()

            return generated_text

        except Exception as e:
            logger.error(f"Error generating text: {e}", exc_info=True)
            raise

    def generate_chat_completion(
        self,
        model_id: str,
        messages: list,
        max_tokens: int = 512,
        temperature: float = 0.7,
        **kwargs
    ) -> str:
        """
        Generate chat completion using a loaded model.

        Args:
            model_id: Model identifier
            messages: List of message dicts with 'role' and 'content'
            max_tokens: Maximum tokens to generate
            temperature: Sampling temperature

        Returns:
            Generated response
        """
        if model_id not in self.loaded_models:
            raise ValueError(f"Model {model_id} not loaded. Call load_chat_model() first.")

        tokenizer = self.loaded_tokenizers[model_id]

        try:
            # Prefer the model's own chat template; otherwise fall back to a simple "role: content" format.
            if hasattr(tokenizer, 'apply_chat_template'):
                prompt = tokenizer.apply_chat_template(
                    messages,
                    tokenize=False,
                    add_generation_prompt=True
                )
            else:
                prompt = "\n".join([
                    f"{msg['role']}: {msg['content']}"
                    for msg in messages
                ]) + "\nassistant: "

            return self.generate_text(
                model_id=model_id,
                prompt=prompt,
                max_tokens=max_tokens,
                temperature=temperature,
                **kwargs
            )

        except Exception as e:
            logger.error(f"Error generating chat completion: {e}", exc_info=True)
            raise

    def get_embedding(self, model_id: str, text: str) -> list:
        """
        Get embedding vector for text.

        Args:
            model_id: Embedding model identifier
            text: Input text

        Returns:
            Embedding vector
        """
        if model_id not in self.loaded_embedding_models:
            raise ValueError(f"Embedding model {model_id} not loaded. Call load_embedding_model() first.")

        model = self.loaded_embedding_models[model_id]

        try:
            embedding = model.encode(text, convert_to_numpy=True)
            return embedding.tolist()
        except Exception as e:
            logger.error(f"Error getting embedding: {e}", exc_info=True)
            raise

    def clear_cache(self):
        """Clear all loaded models from memory."""
        logger.info("Clearing model cache...")

        for model_id in list(self.loaded_models.keys()):
            del self.loaded_models[model_id]
        for model_id in list(self.loaded_tokenizers.keys()):
            del self.loaded_tokenizers[model_id]
        for model_id in list(self.loaded_embedding_models.keys()):
            del self.loaded_embedding_models[model_id]

        if self.device == "cuda":
            torch.cuda.empty_cache()

        logger.info("✓ Model cache cleared")

    def get_memory_usage(self) -> Dict[str, Any]:
        """Get current GPU memory usage in GB."""
        if self.device != "cuda":
            return {"device": "cpu", "gpu_available": False}

        return {
            "device": self.device_name,
            "gpu_available": True,
            "allocated_gb": torch.cuda.memory_allocated(0) / 1024**3,
            "reserved_gb": torch.cuda.memory_reserved(0) / 1024**3,
            "total_gb": torch.cuda.get_device_properties(0).total_memory / 1024**3
        }
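

# Usage sketch: a minimal smoke test of the loader. The model IDs below are
# illustrative assumptions, not models required by this module; substitute the
# checkpoints your deployment actually uses.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    loader = LocalModelLoader()

    # Hypothetical chat model; 4-bit quantization helps a 7B model fit on a T4.
    chat_id = "mistralai/Mistral-7B-Instruct-v0.2"
    loader.load_chat_model(chat_id, load_in_4bit=True)
    reply = loader.generate_chat_completion(
        chat_id,
        [{"role": "user", "content": "In one sentence, what does a model loader do?"}],
        max_tokens=64,
    )
    print(reply)

    # Hypothetical embedding model.
    embed_id = "sentence-transformers/all-MiniLM-L6-v2"
    loader.load_embedding_model(embed_id)
    print(f"Embedding length: {len(loader.get_embedding(embed_id, 'hello world'))}")

    print(loader.get_memory_usage())
    loader.clear_cache()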