import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os
def load_assets(model_config):
"""
Loads all models, tokenizer, and optimizer states.
Args:
model_config (dict): The model configuration dictionary from the YAML file.
Returns:
tuple: (pretrained_model, finetuned_model, optimizer_v_state, tokenizer)
"""
device = model_config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
dtype_str = model_config.get("dtype", "bfloat16")
    dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if dtype_str not in dtype_map:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    dtype = dtype_map[dtype_str]
print(f"Using device: {device} and dtype: {dtype_str}")
# Load base model (w_0)
print(f"Loading base model: {model_config['base_model_id']}")
pretrained_model = AutoModelForCausalLM.from_pretrained(
model_config['base_model_id'],
torch_dtype=dtype,
device_map=device,
trust_remote_code=True
)
print("βœ“ Base model loaded.")
# Load fine-tuned model (w_T)
print(f"Loading fine-tuned model: {model_config['finetuned_model_id']}")
finetuned_model = AutoModelForCausalLM.from_pretrained(
model_config['finetuned_model_id'],
torch_dtype=dtype,
device_map=device,
trust_remote_code=True
)
print("βœ“ Fine-tuned model loaded.")
# Load tokenizer
print(f"Loading tokenizer from: {model_config['finetuned_model_id']}")
tokenizer = AutoTokenizer.from_pretrained(
model_config['finetuned_model_id'],
trust_remote_code=True
)
print("βœ“ Tokenizer loaded.")
# Load optimizer states (v_T)
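    # v_T can come from either the HF Hub ("repo_id:filename") or a local
    # safetensors path; if neither key is set, optimizer_v_state stays None.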
optimizer_v_state = None
if model_config.get('optimizer_states_file'):
print(f"Loading optimizer states from HF: {model_config['optimizer_states_file']}")
        # The config value encodes the Hub source as "<repo_id>:<filename>";
        # split only on the first colon in case the filename contains one.
        repo_id, filename = model_config['optimizer_states_file'].split(":", 1)
try:
cached_file = hf_hub_download(repo_id=repo_id, filename=filename)
optimizer_v_state = load_file(cached_file)
print(f"βœ“ Loaded {len(optimizer_v_state)} optimizer state tensors.")
except Exception as e:
print(f"Could not download optimizer states from HF Hub: {e}")
raise
elif model_config.get('local_optimizer_states_path'):
path = model_config['local_optimizer_states_path']
print(f"Loading optimizer states from local path: {path}")
if os.path.exists(path):
optimizer_v_state = load_file(path)
print(f"βœ“ Loaded {len(optimizer_v_state)} optimizer state tensors.")
else:
raise FileNotFoundError(f"Optimizer states file not found at: {path}")
return pretrained_model, finetuned_model, optimizer_v_state, tokenizer
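

# Minimal usage sketch. The model IDs below are tiny public stand-ins chosen so
# the example runs anywhere; in the actual Space these values come from the YAML
# config, and the optional optimizer-state keys are commented out because no
# such file ships with this example.
if __name__ == "__main__":
    example_config = {
        "base_model_id": "sshleifer/tiny-gpt2",       # stand-in, not the real base model
        "finetuned_model_id": "sshleifer/tiny-gpt2",  # stand-in, not the real fine-tune
        "device": "cpu",
        "dtype": "float32",
        # "optimizer_states_file": "<repo_id>:<filename>",          # HF Hub source
        # "local_optimizer_states_path": "/path/to/v.safetensors",  # local source
    }
    w0, wT, v_state, tokenizer = load_assets(example_config)
    print(f"Base model parameters: {sum(p.numel() for p in w0.parameters())}; "
          f"optimizer states present: {v_state is not None}")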