import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
import os

def load_assets(model_config):
    """
    Loads all models, the tokenizer, and optimizer states.

    Args:
        model_config (dict): The model configuration dictionary from the YAML file.

    Returns:
        tuple: (pretrained_model, finetuned_model, optimizer_v_state, tokenizer)
    """
    device = model_config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    dtype_str = model_config.get("dtype", "bfloat16")
    if dtype_str == "bfloat16":
        dtype = torch.bfloat16
    elif dtype_str == "float16":
        dtype = torch.float16
    elif dtype_str == "float32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {dtype_str}")
    print(f"Using device: {device} and dtype: {dtype_str}")

    # Load base model (w_0)
    print(f"Loading base model: {model_config['base_model_id']}")
    pretrained_model = AutoModelForCausalLM.from_pretrained(
        model_config['base_model_id'],
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print("✓ Base model loaded.")

    # Load fine-tuned model (w_T)
    print(f"Loading fine-tuned model: {model_config['finetuned_model_id']}")
    finetuned_model = AutoModelForCausalLM.from_pretrained(
        model_config['finetuned_model_id'],
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True
    )
    print("✓ Fine-tuned model loaded.")

    # Load tokenizer (from the fine-tuned checkpoint so special tokens match)
    print(f"Loading tokenizer from: {model_config['finetuned_model_id']}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_config['finetuned_model_id'],
        trust_remote_code=True
    )
    print("✓ Tokenizer loaded.")

    # Load optimizer states (v_T). 'optimizer_states_file' is expected to be a
    # "repo_id:filename" pair pointing at a safetensors file on the HF Hub;
    # 'local_optimizer_states_path' is the local-disk fallback.
    optimizer_v_state = None
    if model_config.get('optimizer_states_file'):
        print(f"Loading optimizer states from HF: {model_config['optimizer_states_file']}")
        repo_id, filename = model_config['optimizer_states_file'].split(":")
        try:
            cached_file = hf_hub_download(repo_id=repo_id, filename=filename)
            optimizer_v_state = load_file(cached_file)
            print(f"✓ Loaded {len(optimizer_v_state)} optimizer state tensors.")
        except Exception as e:
            print(f"Could not download optimizer states from HF Hub: {e}")
            raise
    elif model_config.get('local_optimizer_states_path'):
        path = model_config['local_optimizer_states_path']
        print(f"Loading optimizer states from local path: {path}")
        if os.path.exists(path):
            optimizer_v_state = load_file(path)
            print(f"✓ Loaded {len(optimizer_v_state)} optimizer state tensors.")
        else:
            raise FileNotFoundError(f"Optimizer states file not found at: {path}")

    return pretrained_model, finetuned_model, optimizer_v_state, tokenizer
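

# --- Example usage (illustrative sketch, not part of the original module) ---
# The repo ids and file names below are hypothetical placeholders; in practice
# `model_config` would come from the project's YAML config, e.g. via
# `yaml.safe_load(open("config.yaml"))["model"]`.
if __name__ == "__main__":
    example_config = {
        "device": "cuda",
        "dtype": "bfloat16",
        "base_model_id": "org/base-model",            # hypothetical HF repo id
        "finetuned_model_id": "org/finetuned-model",  # hypothetical HF repo id
        # Either an HF Hub "repo_id:filename" pair ...
        "optimizer_states_file": "org/finetuned-model:optimizer_v_state.safetensors",
        # ... or a local path, used only when the key above is absent/empty:
        # "local_optimizer_states_path": "checkpoints/optimizer_v_state.safetensors",
    }
    w_0, w_T, v_T, tokenizer = load_assets(example_config)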