```python
import torch
from typing import Any, Dict, List, Tuple

from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
    EsmModel,
    EsmTokenizer,
    PreTrainedModel,
    T5EncoderModel,
    T5Tokenizer,
)

def prepare_for_lora_model(
    base_model,
    lora_r: int = 8,
    lora_alpha: int = 32,
    lora_dropout: float = 0.1,
    target_modules: List[str] = ["key", "query", "value"],
):
    if not isinstance(base_model, PreTrainedModel):
        raise TypeError("base_model must be a PreTrainedModel instance")
    # Validate that every requested target module exists in the model.
    available_modules = [name for name, _ in base_model.named_modules()]
    for module in target_modules:
        if not any(module in name for name in available_modules):
            raise ValueError(f"Target module {module} not found in model")
    # Build the LoRA config and wrap the base model with adapters.
    lora_config = LoraConfig(
        r=lora_r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=target_modules,
    )
    model = get_peft_model(base_model, lora_config)
    print("LoRA model is ready! Number of trainable parameters:")
    model.print_trainable_parameters()
    return model
```
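As a quick, hedged illustration of how `prepare_for_lora_model` might be called (the ESM-2 checkpoint name below is an assumption, not something this module prescribes; the default `["key", "query", "value"]` targets match the attention projection names in ESM-style models):

```python
# Hypothetical usage: wrap an ESM-2 checkpoint with LoRA adapters.
base = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")  # example checkpoint
lora_model = prepare_for_lora_model(base, lora_r=8, lora_alpha=32)
# Only the injected LoRA adapter weights require gradients from here on.
```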
```python
def load_lora_model(base_model, lora_ckpt_path):
    # Attach previously trained LoRA adapter weights to a base model.
    model = PeftModel.from_pretrained(base_model, lora_ckpt_path)
    return model
```
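`load_lora_model` assumes the adapter weights were saved with PEFT's `save_pretrained`, which writes only the adapter tensors and config. A minimal round-trip sketch, with placeholder paths:

```python
# Hypothetical save/load round trip; "lora_ckpt" is a placeholder path.
lora_model.save_pretrained("lora_ckpt")  # persists adapter weights only
fresh_base = EsmModel.from_pretrained("facebook/esm2_t12_35M_UR50D")
restored = load_lora_model(fresh_base, "lora_ckpt")
```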
```python
def load_eval_base_model(plm_model):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if "esm" in plm_model:
        base_model = EsmModel.from_pretrained(plm_model).to(device)
    elif "bert" in plm_model:
        base_model = BertModel.from_pretrained(plm_model).to(device)
    elif "prot_t5" in plm_model or "ankh" in plm_model:
        # ProtT5 and Ankh are both T5-style, encoder-only checkpoints.
        base_model = T5EncoderModel.from_pretrained(plm_model).to(device)
    elif "ProSST" in plm_model:
        base_model = AutoModelForMaskedLM.from_pretrained(plm_model).to(device)
    else:
        raise ValueError(f"Unrecognized PLM checkpoint name: {plm_model}")
    return base_model
```
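Dispatch is purely substring-based on the checkpoint name, so typical Hub identifiers route as expected. The names below are illustrative examples rather than checkpoints this module requires:

```python
# Illustrative routing (checkpoint names are assumptions):
model = load_eval_base_model("facebook/esm2_t33_650M_UR50D")  # -> EsmModel
# "Rostlab/prot_t5_xl_uniref50" -> T5EncoderModel (contains "prot_t5")
# "ElnaggarLab/ankh-base"       -> T5EncoderModel (contains "ankh")
# "AI4Protein/ProSST-2048"      -> AutoModelForMaskedLM (contains "ProSST")
```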
```python
def check_lora_params(model):
    # Parameters injected by PEFT carry "lora_" in their names.
    lora_params = [
        (name, param) for name, param in model.named_parameters() if "lora_" in name
    ]
    print(f"\nNumber of LoRA params: {len(lora_params)}")
    if len(lora_params) == 0:
        print("Warning: no LoRA params found!")
    else:
        print("\nFirst LoRA param:")
        name, param = lora_params[0]
        print(f"name: {name}")
        print(f"param.shape: {param.shape}")
        print(f"param.dtype: {param.dtype}")
        print(f"param.device: {param.device}")
        # print(f"param value:\n{param.data.cpu().numpy()}")
        print(f"requires_grad: {param.requires_grad}")
```