import os

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_assets(model_config):
    """
    Loads the pretrained and fine-tuned models, the tokenizer, and (optionally)
    the optimizer states.

    Args:
        model_config (dict): The model configuration dictionary from the YAML file.

    Returns:
        tuple: (pretrained_model, finetuned_model, optimizer_v_state, tokenizer)
    """
    # Resolve target device and parameter dtype from the config, with defaults.
    device = model_config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
    dtype_str = model_config.get("dtype", "bfloat16")

    if dtype_str == "bfloat16":
        dtype = torch.bfloat16
    elif dtype_str == "float16":
        dtype = torch.float16
    elif dtype_str == "float32":
        dtype = torch.float32
    else:
        raise ValueError(f"Unsupported dtype: {dtype_str}")

    print(f"Using device: {device} and dtype: {dtype_str}")

    # Load the base (pretrained) model.
    print(f"Loading base model: {model_config['base_model_id']}")
    pretrained_model = AutoModelForCausalLM.from_pretrained(
        model_config['base_model_id'],
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
    )
    print("✓ Base model loaded.")

    # Load the fine-tuned model.
    print(f"Loading fine-tuned model: {model_config['finetuned_model_id']}")
    finetuned_model = AutoModelForCausalLM.from_pretrained(
        model_config['finetuned_model_id'],
        torch_dtype=dtype,
        device_map=device,
        trust_remote_code=True,
    )
    print("✓ Fine-tuned model loaded.")

    # Load the tokenizer from the fine-tuned checkpoint.
    print(f"Loading tokenizer from: {model_config['finetuned_model_id']}")
    tokenizer = AutoTokenizer.from_pretrained(
        model_config['finetuned_model_id'],
        trust_remote_code=True,
    )
    print("✓ Tokenizer loaded.")

    # Optionally load optimizer v-states, either from the Hugging Face Hub
    # or from a local safetensors file.
    optimizer_v_state = None
    if model_config.get('optimizer_states_file'):
        print(f"Loading optimizer states from HF: {model_config['optimizer_states_file']}")
        # Expected format: "<repo_id>:<filename>"; maxsplit=1 keeps any further
        # colons inside the filename intact.
        repo_id, filename = model_config['optimizer_states_file'].split(":", 1)
        try:
            cached_file = hf_hub_download(repo_id=repo_id, filename=filename)
            optimizer_v_state = load_file(cached_file)
            print(f"✓ Loaded {len(optimizer_v_state)} optimizer state tensors.")
        except Exception as e:
            print(f"Could not download optimizer states from HF Hub: {e}")
            raise
    elif model_config.get('local_optimizer_states_path'):
        path = model_config['local_optimizer_states_path']
        print(f"Loading optimizer states from local path: {path}")
        if os.path.exists(path):
            optimizer_v_state = load_file(path)
            print(f"✓ Loaded {len(optimizer_v_state)} optimizer state tensors.")
        else:
            raise FileNotFoundError(f"Optimizer states file not found at: {path}")

    return pretrained_model, finetuned_model, optimizer_v_state, tokenizer
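

if __name__ == "__main__":
    # Minimal usage sketch (an assumed entry point, not part of the original
    # pipeline): the repo ids below are placeholders and must be replaced with
    # real model ids before running.
    example_config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "dtype": "bfloat16",
        "base_model_id": "org/base-model",            # placeholder
        "finetuned_model_id": "org/finetuned-model",  # placeholder
    }
    pretrained, finetuned, v_state, tok = load_assets(example_config)
    print(f"Loaded {type(pretrained).__name__} / {type(finetuned).__name__}; "
          f"optimizer states present: {v_state is not None}")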