import sys
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
import argparse
import torch
import re
import json
import warnings
import pandas as pd
import torch.nn as nn
from tqdm import tqdm
from torchmetrics.classification import Accuracy, Recall, Precision, MatthewsCorrCoef, AUROC, F1Score
from torchmetrics.classification import BinaryAccuracy, BinaryRecall, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryMatthewsCorrCoef
from torchmetrics.regression import SpearmanCorrCoef
from transformers import EsmTokenizer, EsmModel, BertModel, BertTokenizer
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModelForMaskedLM, AutoModel
from transformers import logging
from datasets import load_dataset
from torch.utils.data import DataLoader
# from utils.data_utils import BatchSampler
# from utils.metrics import MultilabelF1Max
# from models.adapter_mdoel import AdapterModel
from data.batch_sampler import BatchSampler
from training.metrics import MultilabelF1Max
from models.adapter_model import AdapterModel
from models.lora_model import LoraModel
from peft import PeftModel
from typing import Dict, Any, Union, Tuple
from data.dataloader import prepare_dataloaders
from datetime import datetime

# ignore warning information
logging.set_verbosity_error()
warnings.filterwarnings("ignore")


def evaluate(model, plm_model, metrics, dataloader, loss_function, device=None):
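    """Run one evaluation pass over `dataloader`, updating every torchmetrics
    object in `metrics` and accumulating the average loss.

    Returns the metrics dict (each value wrapped in a single-element list so it
    can be written out with pandas) and the list of predicted labels.
    """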
    total_loss = 0
    total_samples = len(dataloader.dataset)
    print(f"Total samples: {total_samples}")
    epoch_iterator = tqdm(dataloader)
    pred_labels = []
    for i, batch in enumerate(epoch_iterator, 1):
        for k, v in batch.items():
            batch[k] = v.to(device)
        label = batch["label"]
        logits = model(plm_model, batch)
        pred_labels.extend(logits.argmax(dim=1).cpu().numpy())
        for metric_name, metric in metrics.items():
            if args.problem_type == 'regression' and args.num_labels == 1:
                loss = loss_function(logits.squeeze(), label.squeeze())
                metric(logits.squeeze(), label.squeeze())
            elif args.problem_type == 'multi_label_classification':
                loss = loss_function(logits, label.float())
                metric(logits, label)
            else:
                loss = loss_function(logits, label)
                metric(torch.argmax(logits, 1), label)
        total_loss += loss.item() * len(label)
        epoch_iterator.set_postfix(eval_loss=loss.item())
    epoch_loss = total_loss / len(dataloader.dataset)
    for k, v in metrics.items():
        metrics[k] = [v.compute().item()]
        print(f"{k}: {metrics[k][0]}")
    metrics['loss'] = [epoch_loss]
    return metrics, pred_labels


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model params
    parser.add_argument('--eval_method', type=str, default=None, help='evaluation method')
    parser.add_argument('--hidden_size', type=int, default=None, help='embedding hidden size of the model')
    parser.add_argument('--num_attention_head', type=int, default=8, help='number of attention heads')
    parser.add_argument('--attention_probs_dropout', type=float, default=0, help='attention probs dropout prob')
    parser.add_argument('--plm_model', type=str, default='facebook/esm2_t33_650M_UR50D', help='protein language model name or path')
    parser.add_argument('--num_labels', type=int, default=2, help='number of labels')
    parser.add_argument('--pooling_method', type=str, default='mean', help='pooling method')
    parser.add_argument('--pooling_dropout', type=float, default=0.25, help='pooling dropout')

    # dataset
    parser.add_argument('--dataset', type=str, default=None, help='dataset name')
    parser.add_argument('--problem_type', type=str, default=None, help='problem type')
    parser.add_argument('--test_file', type=str, default=None, help='test file')
    parser.add_argument('--split', type=str, default=None, help='split name on Hugging Face')
    parser.add_argument('--test_result_dir', type=str, default=None, help='test result directory')
    parser.add_argument('--metrics', type=str, default=None, help='comma-separated evaluation metrics')
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    parser.add_argument('--max_seq_len', type=int, default=None, help='max sequence length')
    parser.add_argument('--batch_size', type=int, default=None, help='fixed batch size')
    parser.add_argument('--batch_token', type=int, default=10000, help='max number of tokens per batch')
    parser.add_argument('--use_foldseek', action='store_true', help='use foldseek sequence')
    parser.add_argument('--use_ss8', action='store_true', help='use ss8 sequence')

    # model path
    parser.add_argument('--output_model_name', type=str, default=None, help='model name')
    parser.add_argument('--output_root', default="result", help='root directory of trained models')
    parser.add_argument('--output_dir', default=None, help='directory of trained models')
    parser.add_argument('--model_path', default=None, help='model path directly')
    parser.add_argument('--structure_seq', type=str, default="", help='structure sequence types, e.g. "foldseek_seq,ss8_seq"')
    parser.add_argument('--training_method', type=str, default="freeze", help='training method')
    args = parser.parse_args()

    if 'foldseek_seq' in args.structure_seq:
        args.use_foldseek = True
        print("Enabled foldseek_seq based on structure_seq parameter")
    if 'ss8_seq' in args.structure_seq:
        args.use_ss8 = True
        print("Enabled ss8_seq based on structure_seq parameter")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.test_result_dir:
        os.makedirs(args.test_result_dir, exist_ok=True)

    # build tokenizer and protein language model
    if "esm" in args.plm_model:
        tokenizer = EsmTokenizer.from_pretrained(args.plm_model)
        plm_model = EsmModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    elif "bert" in args.plm_model:
        tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = BertModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    elif "prot_t5" in args.plm_model:
        tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.d_model
    elif "ankh" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.d_model
    elif "ProSST" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    elif "Prime" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model)
        plm_model = AutoModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size

    args.vocab_size = plm_model.config.vocab_size

    # Define metric configurations
    metric_configs = {
        'accuracy': {
            'binary': BinaryAccuracy,
            'multi': lambda: Accuracy(task="multiclass", num_classes=args.num_labels)
        },
        'recall': {
            'binary': BinaryRecall,
            'multi': lambda: Recall(task="multiclass", num_classes=args.num_labels)
        },
        'precision': {
            'binary': BinaryPrecision,
            'multi': lambda: Precision(task="multiclass", num_classes=args.num_labels)
        },
        'f1': {
            'binary': BinaryF1Score,
            'multi': lambda: F1Score(task="multiclass", num_classes=args.num_labels)
        },
        'mcc': {
            'binary': BinaryMatthewsCorrCoef,
            'multi': lambda: MatthewsCorrCoef(task="multiclass", num_classes=args.num_labels)
        },
        'auroc': {
            'binary': BinaryAUROC,
            'multi': lambda: AUROC(task="multiclass", num_classes=args.num_labels)
        },
        'f1_max': {
            'any': lambda: MultilabelF1Max(num_labels=args.num_labels)
        },
        'spearman_corr': {
            'any': SpearmanCorrCoef
        }
    }

    # Initialize metrics dictionary
    metrics_dict = {}
    args.metrics = args.metrics.split(',')

    # Create metrics based on configurations
    for metric_name in args.metrics:
        if metric_name not in metric_configs:
            raise ValueError(f"Invalid metric: {metric_name}")
        config = metric_configs[metric_name]
        if 'any' in config:
            metrics_dict[metric_name] = config['any']()
        else:
            metrics_dict[metric_name] = (config['binary']() if args.num_labels == 2
                                         else config['multi']())
        # Move metric to device
        metrics_dict[metric_name].to(device)

    # load adapter model
    print("---------- Load Model ----------")
    # model = AdapterModel(args)
    # if args.model_path is not None:
    #     model_path = args.model_path
    # else:
    #     model_path = f"{args.output_root}/{args.output_dir}/{args.output_model_name}"
    if args.eval_method in ["full", "ses-adapter", "freeze"]:
        model = AdapterModel(args)
    elif args.eval_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
        model = LoraModel(args)
    else:
        raise ValueError(f"Invalid eval_method: {args.eval_method}")

    if args.model_path is not None:
        model_path = args.model_path
    else:
        model_path = f"{args.output_root}/{args.output_dir}/{args.output_model_name}"

    if args.eval_method == "full":
        model_weights = torch.load(model_path, map_location=device)
        model.load_state_dict(model_weights['model_state_dict'])
        plm_model.load_state_dict(model_weights['plm_state_dict'])
    else:
        model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device).eval()

    if args.eval_method == 'plm-lora':
        lora_path = model_path.replace(".pt", "_lora")
        plm_model = PeftModel.from_pretrained(plm_model, lora_path)
        plm_model = plm_model.merge_and_unload()
    elif args.eval_method == 'plm-qlora':
        lora_path = model_path.replace(".pt", "_qlora")
        plm_model = PeftModel.from_pretrained(plm_model, lora_path)
        plm_model = plm_model.merge_and_unload()
    elif args.eval_method == "plm-dora":
        dora_path = model_path.replace(".pt", "_dora")
        plm_model = PeftModel.from_pretrained(plm_model, dora_path)
        plm_model = plm_model.merge_and_unload()
    elif args.eval_method == "plm-adalora":
        adalora_path = model_path.replace(".pt", "_adalora")
        plm_model = PeftModel.from_pretrained(plm_model, adalora_path)
        plm_model = plm_model.merge_and_unload()
    elif args.eval_method == "plm-ia3":
        ia3_path = model_path.replace(".pt", "_ia3")
        plm_model = PeftModel.from_pretrained(plm_model, ia3_path)
        plm_model = plm_model.merge_and_unload()
    plm_model.to(device).eval()

    def param_num(model):
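        """Return a human-readable count of trainable parameters in `model`."""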
        total = sum(param.numel() for param in model.parameters() if param.requires_grad)
        num_M = total / 1e6
        if num_M >= 1000:
            return "Number of parameters: %.2fB" % (num_M / 1e3)
        else:
            return "Number of parameters: %.2fM" % num_M

    print(param_num(model))

    def collate_fn(examples):
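        """Tokenize a list of raw examples into padded tensor batches.

        Handles the per-model input conventions (space-separated residues for
        ProtBert/ProtT5, character lists for Ankh, structure tokens for ProSST)
        and the optional foldseek/ss8 structure sequences.
        """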
        aa_seqs, labels = [], []
        if args.use_foldseek:
            foldseek_seqs = []
        if args.use_ss8:
            ss8_seqs = []
        prosst_stru_tokens = [] if "ProSST" in args.plm_model else None

        for e in examples:
            aa_seq = e["aa_seq"]
            if args.use_foldseek:
                foldseek_seq = e["foldseek_seq"]
            if args.use_ss8:
                ss8_seq = e["ss8_seq"]
            if "ProSST" in args.plm_model and "prosst_stru_token" in e:
                stru_token = e["prosst_stru_token"]
                if isinstance(stru_token, str):
                    seq_clean = stru_token.strip("[]").replace(" ", "")
                    tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
                elif isinstance(stru_token, (list, tuple)):
                    tokens = [int(x) for x in stru_token]
                else:
                    tokens = []
                prosst_stru_tokens.append(torch.tensor(tokens))

            if 'prot_bert' in args.plm_model or "prot_t5" in args.plm_model:
                aa_seq = " ".join(list(aa_seq))
                aa_seq = re.sub(r"[UZOB]", "X", aa_seq)
                if args.use_foldseek:
                    foldseek_seq = " ".join(list(foldseek_seq))
                if args.use_ss8:
                    ss8_seq = " ".join(list(ss8_seq))
            elif 'ankh' in args.plm_model:
                aa_seq = list(aa_seq)
                if args.use_foldseek:
                    foldseek_seq = list(foldseek_seq)
                if args.use_ss8:
                    ss8_seq = list(ss8_seq)

            aa_seqs.append(aa_seq)
            if args.use_foldseek:
                foldseek_seqs.append(foldseek_seq)
            if args.use_ss8:
                ss8_seqs.append(ss8_seq)
            labels.append(e["label"])

        if 'ankh' in args.plm_model:
            aa_inputs = tokenizer.batch_encode_plus(aa_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
            if args.use_foldseek:
                foldseek_input_ids = tokenizer.batch_encode_plus(foldseek_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")["input_ids"]
            if args.use_ss8:
                ss8_input_ids = tokenizer.batch_encode_plus(ss8_seqs, add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")["input_ids"]
        else:
            aa_inputs = tokenizer(aa_seqs, return_tensors="pt", padding=True, truncation=True)
            if args.use_foldseek:
                foldseek_input_ids = tokenizer(foldseek_seqs, return_tensors="pt", padding=True, truncation=True)["input_ids"]
            if args.use_ss8:
                ss8_input_ids = tokenizer(ss8_seqs, return_tensors="pt", padding=True, truncation=True)["input_ids"]

        aa_input_ids = aa_inputs["input_ids"]
        attention_mask = aa_inputs["attention_mask"]

        if args.problem_type == 'regression':
            labels = torch.as_tensor(labels, dtype=torch.float)
        else:
            labels = torch.as_tensor(labels, dtype=torch.long)

        data_dict = {
            "aa_seq_input_ids": aa_input_ids,
            "aa_seq_attention_mask": attention_mask,
            "label": labels
        }

        if "ProSST" in args.plm_model and prosst_stru_tokens:
            # pad (or truncate) structure tokens to the padded amino-acid length
            aa_max_length = len(aa_input_ids[0])
            padded_tokens = []
            for tokens in prosst_stru_tokens:
                if tokens is None or len(tokens) == 0:
                    padded_tokens.append([0] * aa_max_length)
                else:
                    struct_sequence = tokens.tolist()[:aa_max_length]
                    padded_tokens.append(struct_sequence + [0] * (aa_max_length - len(struct_sequence)))
            data_dict["aa_seq_stru_tokens"] = torch.tensor(padded_tokens, dtype=torch.long)

        if args.use_foldseek:
            data_dict["foldseek_seq_input_ids"] = foldseek_input_ids
        if args.use_ss8:
            data_dict["ss8_seq_input_ids"] = ss8_input_ids
        return data_dict

    # choose the loss according to the problem type (matches the branching in evaluate)
    if args.problem_type == 'regression':
        loss_function = nn.MSELoss()
    elif args.problem_type == 'multi_label_classification':
        loss_function = nn.BCEWithLogitsLoss()
    else:
        loss_function = nn.CrossEntropyLoss()

    def process_data_line(data):
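        """Convert one raw record into the format expected by `collate_fn`:
        expand multi-label strings into binary vectors, truncate sequences to
        `--max_seq_len`, and return the record together with its token count."""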
        if args.problem_type == 'multi_label_classification':
            label_list = data['label'].split(',')
            data['label'] = [int(l) for l in label_list]
            binary_list = [0] * args.num_labels
            for index in data['label']:
                binary_list[index] = 1
            data['label'] = binary_list
        if args.max_seq_len is not None:
            data["aa_seq"] = data["aa_seq"][:args.max_seq_len]
            if args.use_foldseek:
                data["foldseek_seq"] = data["foldseek_seq"][:args.max_seq_len]
            if args.use_ss8:
                data["ss8_seq"] = data["ss8_seq"][:args.max_seq_len]
            # For ProSST models the structure tokens must be truncated as well
            if "ProSST" in args.plm_model and "prosst_stru_token" in data:
                # structure tokens may be stored as a string or a list
                if isinstance(data["prosst_stru_token"], str):
                    pass
                elif isinstance(data["prosst_stru_token"], (list, tuple)):
                    data["prosst_stru_token"] = data["prosst_stru_token"][:args.max_seq_len]
            token_num = min(len(data["aa_seq"]), args.max_seq_len)
        else:
            token_num = len(data["aa_seq"])
        return data, token_num

    # process dataset from json file
    def process_dataset_from_json(file):
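        """Read a JSON-lines file and return (records, per-record token counts)."""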
        dataset, token_nums = [], []
        for l in open(file):
            data = json.loads(l)
            data, token_num = process_data_line(data)
            dataset.append(data)
            token_nums.append(token_num)
        return dataset, token_nums

    # process dataset from list
    def process_dataset_from_list(data_list):
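        """Process an iterable of records (e.g. a Hugging Face dataset split)."""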
        dataset, token_nums = [], []
        for l in data_list:
            data, token_num = process_data_line(l)
            dataset.append(data)
            token_nums.append(token_num)
        return dataset, token_nums

    if args.test_file.endswith('json'):
        test_dataset, test_token_num = process_dataset_from_json(args.test_file)
        if args.test_result_dir:
            # keep the raw records so predictions can be appended to them later
            test_result_df = pd.DataFrame(test_dataset)
    elif args.test_file.endswith('csv'):
        test_dataset, test_token_num = process_dataset_from_list(load_dataset("csv", data_files=args.test_file)['train'])
        if args.test_result_dir:
            test_result_df = pd.read_csv(args.test_file)
    elif '/' in args.test_file:  # Hugging Face dataset (only csv-style datasets for now)
        raw_dataset = load_dataset(args.test_file)
        # Use the requested split if available, otherwise fall back to test/validation/train.
        if args.split and args.split in raw_dataset:
            split = args.split
        elif 'test' in raw_dataset:
            split = 'test'
        elif 'validation' in raw_dataset:
            split = 'validation'
        elif 'train' in raw_dataset:
            split = 'train'
        else:
            split = list(raw_dataset.keys())[0]
        test_dataset, test_token_num = process_dataset_from_list(raw_dataset[split])
        if args.test_result_dir:
            test_result_df = pd.DataFrame(raw_dataset[split])
    else:
        raise ValueError("Invalid file format")

    if args.batch_size is None:
        if args.batch_token is None:
            raise ValueError("batch_size or batch_token must be specified")
        test_loader = DataLoader(
            test_dataset,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            batch_sampler=BatchSampler(test_token_num, args.batch_token, False)
        )
    else:
        test_loader = DataLoader(
            test_dataset,
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            collate_fn=collate_fn,
            shuffle=False
        )
| print("---------- Start Eval ----------") | |
| with torch.no_grad(): | |
| metric, pred_labels = evaluate(model, plm_model, metrics_dict, test_loader, loss_function, device) | |
| if args.test_result_dir: | |
| pd.DataFrame(metric).to_csv(f"{args.test_result_dir}/evaluation_metrics.csv", index=False) | |
| test_result_df["pred_label"] = pred_labels | |
| test_result_df.to_csv(f"{args.test_result_dir}/evaluation_result.csv", index=False) | |