#!/usr/bin/env python
# -*- coding: utf-8 -*-
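"""Single-sequence protein function prediction.

Loads a trained prediction head (and, for PEFT methods, the matching adapter
weights) on top of a protein language model, runs one amino acid sequence --
optionally with foldseek / secondary structure sequences -- through the model,
and prints the result as JSON.
"""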
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
import argparse
import torch
import re
import json
import warnings
import numpy as np
from pathlib import Path
from transformers import EsmTokenizer, EsmModel, BertModel, BertTokenizer, AutoModelForMaskedLM
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel
from transformers import logging
from peft import PeftModel

# Import project modules
from models.adapter_model import AdapterModel
from models.lora_model import LoraModel
from models.pooling import MeanPooling, Attention1dPoolingHead, LightAttentionPoolingHead

# Suppress warning output
logging.set_verbosity_error()
warnings.filterwarnings("ignore")
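
# Example invocation (script name, paths, and model name are illustrative
# placeholders, not shipped defaults):
#   python predict.py \
#       --model_path ckpts/model.pt \
#       --plm_model facebook/esm2_t33_650M_UR50D \
#       --aa_seq MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ \
#       --problem_type single_label_classification --num_labels 2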
def parse_args():
    parser = argparse.ArgumentParser(description="Predict protein function for a single sequence")
    # Model parameters
    # Choices cover every method handled in load_model_and_tokenizer below
    parser.add_argument('--eval_method', type=str, default="freeze",
                        choices=["freeze", "full", "ses-adapter", "plm-lora", "plm-qlora", "plm-dora", "plm-adalora", "plm-ia3"],
                        help="Evaluation method")
    parser.add_argument('--model_path', type=str, required=True, help="Path to the trained model")
    parser.add_argument('--plm_model', type=str, required=True, help="Pretrained language model name or path")
    parser.add_argument('--pooling_method', type=str, default="mean", choices=["mean", "attention1d", "light_attention"], help="Pooling method")
    parser.add_argument('--problem_type', type=str, default="single_label_classification",
                        choices=["single_label_classification", "multi_label_classification", "regression"],
                        help="Problem type")
    parser.add_argument('--num_labels', type=int, default=2, help="Number of labels")
    parser.add_argument('--hidden_size', type=int, default=None, help="Embedding hidden size of the model")
    parser.add_argument('--num_attention_head', type=int, default=8, help="Number of attention heads")
    parser.add_argument('--attention_probs_dropout', type=float, default=0, help="Attention probs dropout prob")
    parser.add_argument('--pooling_dropout', type=float, default=0.25, help="Pooling dropout")
    # Input sequence parameters
    parser.add_argument('--aa_seq', type=str, required=True, help="Amino acid sequence")
    parser.add_argument('--foldseek_seq', type=str, default="", help="Foldseek sequence (optional)")
    parser.add_argument('--ss8_seq', type=str, default="", help="Secondary structure sequence (optional)")
    parser.add_argument('--dataset', type=str, default="single", help="Dataset name (optional)")
    parser.add_argument('--use_foldseek', action='store_true', help="Use foldseek sequence")
    parser.add_argument('--use_ss8', action='store_true', help="Use secondary structure sequence")
    parser.add_argument('--structure_seq', type=str, default=None, help="Structure sequence types to use (comma-separated)")
    # Other parameters
    parser.add_argument('--max_seq_len', type=int, default=1024, help="Maximum sequence length")
    args = parser.parse_args()
    # Automatically determine whether to use structure sequences based on input
    args.use_foldseek = bool(args.foldseek_seq)
    args.use_ss8 = bool(args.ss8_seq)
    return args
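
# Structure-aware example (ses-adapter): providing --foldseek_seq / --ss8_seq
# (or --structure_seq "foldseek_seq,ss8_seq") switches the corresponding use_*
# flags on automatically. Sequence values below are shortened placeholders:
#   python predict.py ... --eval_method ses-adapter \
#       --aa_seq MKTAYI... --foldseek_seq dpvvlc... --ss8_seq CCHHHH...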
def load_model_and_tokenizer(args):
    print("---------- Loading Model and Tokenizer ----------")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Check if model file exists
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")
    # Load model configuration if available
    config_path = os.path.join(os.path.dirname(args.model_path), "config.json")
    try:
        with open(config_path, "r") as f:
            config = json.load(f)
        print(f"Loaded configuration from {config_path}")
        # Update args with config values if they exist
        if "pooling_method" in config:
            args.pooling_method = config["pooling_method"]
        if "problem_type" in config:
            args.problem_type = config["problem_type"]
        if "num_labels" in config:
            args.num_labels = config["num_labels"]
        if "num_attention_head" in config:
            args.num_attention_head = config["num_attention_head"]
        if "attention_probs_dropout" in config:
            args.attention_probs_dropout = config["attention_probs_dropout"]
        if "pooling_dropout" in config:
            args.pooling_dropout = config["pooling_dropout"]
    except FileNotFoundError:
        print(f"Model config not found at {config_path}. Using command line arguments.")
    # Build tokenizer and protein language model
    if "esm" in args.plm_model:
        tokenizer = EsmTokenizer.from_pretrained(args.plm_model)
        plm_model = EsmModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    elif "bert" in args.plm_model:
        tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = BertModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    elif "prot_t5" in args.plm_model:
        tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.d_model
    elif "ankh" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.d_model
    elif "ProSST" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    elif "Prime" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model)
        plm_model = AutoModel.from_pretrained(args.plm_model)
        args.hidden_size = plm_model.config.hidden_size
    args.vocab_size = plm_model.config.vocab_size
    # Determine structure sequence types
    if args.structure_seq is None:
        args.structure_seq = ""
        print("Warning: structure_seq was None, setting to empty string")
    # Auto-set structure sequence flags based on the structure_seq parameter
    if 'foldseek_seq' in args.structure_seq:
        args.use_foldseek = True
        print("Enabled foldseek_seq based on structure_seq parameter")
    if 'ss8_seq' in args.structure_seq:
        args.use_ss8 = True
        print("Enabled ss8_seq based on structure_seq parameter")
    # If the flags are set but structure_seq is not, update structure_seq
    structure_seq_list = []
    if args.use_foldseek and 'foldseek_seq' not in args.structure_seq:
        structure_seq_list.append("foldseek_seq")
    if args.use_ss8 and 'ss8_seq' not in args.structure_seq:
        structure_seq_list.append("ss8_seq")
    if structure_seq_list and not args.structure_seq:
        args.structure_seq = ",".join(structure_seq_list)
    print(f"Evaluation method: {args.eval_method}")
    print(f"Structure sequence: {args.structure_seq}")
    print(f"Use foldseek: {args.use_foldseek}")
    print(f"Use ss8: {args.use_ss8}")
    print(f"Problem type: {args.problem_type}")
    print(f"Number of labels: {args.num_labels}")
    print(f"Number of attention heads: {args.num_attention_head}")
    # Create and load model
    try:
        if args.eval_method in ["full", "ses-adapter", "freeze"]:
            model = AdapterModel(args)
        elif args.eval_method in ['plm-lora', 'plm-qlora', 'plm-dora', 'plm-adalora', 'plm-ia3']:
            model = LoraModel(args)
        # model_path is a required argument, so use it directly
        model_path = args.model_path
        # map_location keeps loading working on CPU-only machines
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device).eval()
        # For PEFT methods, load the adapter weights saved alongside the head
        # checkpoint and merge them into the base language model
        peft_suffix = {
            'plm-lora': "_lora",
            'plm-qlora': "_qlora",
            'plm-dora': "_dora",
            'plm-adalora': "_adalora",
            'plm-ia3': "_ia3",
        }
        if args.eval_method in peft_suffix:
            peft_path = model_path.replace(".pt", peft_suffix[args.eval_method])
            plm_model = PeftModel.from_pretrained(plm_model, peft_path)
            plm_model = plm_model.merge_and_unload()
        plm_model.to(device).eval()
        return model, plm_model, tokenizer, device
    except Exception as e:
        print(f"Error: {str(e)}")
        raise
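
# Checkpoint layout implied by the path convention above (names illustrative):
#   ckpts/model.pt      -> prediction head state_dict, loaded via torch.load
#   ckpts/model_lora/   -> PEFT adapter directory, loaded via PeftModel.from_pretrained
#                          (suffix varies by method: _qlora, _dora, _adalora, _ia3)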
def process_sequences(args, tokenizer, plm_model_name):
    """Process and prepare input sequences for prediction"""
    print("---------- Processing Input Sequences ----------")
    # Process amino acid sequence
    aa_seq = args.aa_seq.strip()
    if not aa_seq:
        raise ValueError("Amino acid sequence is empty")
    # Process structure sequences if needed
    foldseek_seq = args.foldseek_seq.strip() if args.foldseek_seq else ""
    ss8_seq = args.ss8_seq.strip() if args.ss8_seq else ""
    # Check if structure sequences are required but not provided
    if args.use_foldseek and not foldseek_seq:
        print("Warning: Foldseek sequence is required but not provided.")
    if args.use_ss8 and not ss8_seq:
        print("Warning: SS8 sequence is required but not provided.")
    # Truncate before model-specific formatting so the cut happens per residue,
    # not per character of a space-joined string
    if args.max_seq_len:
        aa_seq = aa_seq[:args.max_seq_len]
        if args.use_foldseek and foldseek_seq:
            foldseek_seq = foldseek_seq[:args.max_seq_len]
        if args.use_ss8 and ss8_seq:
            ss8_seq = ss8_seq[:args.max_seq_len]
    # Format sequences based on model type
    if 'prot_bert' in plm_model_name or "prot_t5" in plm_model_name:
        # ProtBert/ProtT5 expect space-separated residues, with rare amino
        # acids mapped to X
        aa_seq = " ".join(list(aa_seq))
        aa_seq = re.sub(r"[UZOB]", "X", aa_seq)
        if args.use_foldseek and foldseek_seq:
            foldseek_seq = " ".join(list(foldseek_seq))
        if args.use_ss8 and ss8_seq:
            ss8_seq = " ".join(list(ss8_seq))
    elif 'ankh' in plm_model_name:
        # Ankh expects pre-split residue lists (is_split_into_words=True below)
        aa_seq = list(aa_seq)
        if args.use_foldseek and foldseek_seq:
            foldseek_seq = list(foldseek_seq)
        if args.use_ss8 and ss8_seq:
            ss8_seq = list(ss8_seq)
    # Tokenize sequences
    if 'ankh' in plm_model_name:
        aa_inputs = tokenizer.batch_encode_plus([aa_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
        if args.use_foldseek and foldseek_seq:
            foldseek_inputs = tokenizer.batch_encode_plus([foldseek_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
        if args.use_ss8 and ss8_seq:
            ss8_inputs = tokenizer.batch_encode_plus([ss8_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
    else:
        aa_inputs = tokenizer([aa_seq], return_tensors="pt", padding=True, truncation=True)
        if args.use_foldseek and foldseek_seq:
            foldseek_inputs = tokenizer([foldseek_seq], return_tensors="pt", padding=True, truncation=True)
        if args.use_ss8 and ss8_seq:
            ss8_inputs = tokenizer([ss8_seq], return_tensors="pt", padding=True, truncation=True)
    # Prepare data dictionary
    data_dict = {
        "aa_seq_input_ids": aa_inputs["input_ids"],
        "aa_seq_attention_mask": aa_inputs["attention_mask"],
    }
    # Only ProSST models need structure tokens
    if "ProSST" in plm_model_name and hasattr(args, 'prosst_stru_token') and args.prosst_stru_token:
        try:
            # Parse ProSST structure tokens: either a "[1, 2, 3]"-style string
            # or an actual list/tuple of ints
            if isinstance(args.prosst_stru_token, str):
                seq_clean = args.prosst_stru_token.strip("[]").replace(" ", "")
                tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
            elif isinstance(args.prosst_stru_token, (list, tuple)):
                tokens = [int(x) for x in args.prosst_stru_token]
            else:
                tokens = []
            # Add to the data dictionary
            if tokens:
                stru_tokens = torch.tensor([tokens], dtype=torch.long)
                data_dict["aa_seq_stru_tokens"] = stru_tokens
            else:
                # Fall back to zero padding when no tokens are given
                data_dict["aa_seq_stru_tokens"] = torch.zeros_like(aa_inputs["input_ids"], dtype=torch.long)
        except Exception as e:
            print(f"Warning: Failed to process ProSST structure tokens: {e}")
            # Fall back to zero padding
            data_dict["aa_seq_stru_tokens"] = torch.zeros_like(aa_inputs["input_ids"], dtype=torch.long)
    if args.use_foldseek and foldseek_seq:
        data_dict["foldseek_seq_input_ids"] = foldseek_inputs["input_ids"]
    if args.use_ss8 and ss8_seq:
        data_dict["ss8_seq_input_ids"] = ss8_inputs["input_ids"]
    print("Processed input sequences with keys:", list(data_dict.keys()))
    return data_dict
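
# For a plain ESM run without structure inputs, data_dict would contain just
# (shapes illustrative; the ESM tokenizer adds cls/eos special tokens):
#   {"aa_seq_input_ids": LongTensor of shape [1, L+2],
#    "aa_seq_attention_mask": LongTensor of shape [1, L+2]}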
def predict(model, data_dict, device, args, plm_model):
    """Run prediction on the processed input data"""
    print("---------- Running Prediction ----------")
    # Move data to device
    for k, v in data_dict.items():
        data_dict[k] = v.to(device)
    # Run model inference
    with torch.no_grad():
        outputs = model(plm_model, data_dict)  # Pass the actual plm_model instead of None
    # Process outputs based on problem type
    if args.problem_type == "regression":
        predictions = outputs.squeeze().item()
        print(f"Prediction result: {predictions}")
        return {"prediction": predictions}
    elif args.problem_type == "single_label_classification":
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        class_probs = probabilities.squeeze().tolist()
        # Ensure class_probs is a list
        if not isinstance(class_probs, list):
            class_probs = [class_probs]
        print(f"Predicted class: {predicted_class}")
        print(f"Class probabilities: {class_probs}")
        return {
            "predicted_class": predicted_class,
            "probabilities": class_probs
        }
    elif args.problem_type == "multi_label_classification":
        sigmoid_outputs = torch.sigmoid(outputs)
        predictions = (sigmoid_outputs > 0.5).int().squeeze().tolist()
        probabilities = sigmoid_outputs.squeeze().tolist()
        # Ensure predictions and probabilities are lists
        if not isinstance(predictions, list):
            predictions = [predictions]
        if not isinstance(probabilities, list):
            probabilities = [probabilities]
        print(f"Predicted labels: {predictions}")
        print(f"Label probabilities: {probabilities}")
        return {
            "predictions": predictions,
            "probabilities": probabilities
        }
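
# Example JSON printed by main() for each problem type (values illustrative):
#   regression:                  {"prediction": 0.73}
#   single_label_classification: {"predicted_class": 1, "probabilities": [0.12, 0.88]}
#   multi_label_classification:  {"predictions": [1, 0], "probabilities": [0.91, 0.08]}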
def main():
    try:
        # Parse arguments
        args = parse_args()
        # Load model, tokenizer and get device
        model, plm_model, tokenizer, device = load_model_and_tokenizer(args)
        # Process input sequences
        data_dict = process_sequences(args, tokenizer, args.plm_model)
        # Run prediction
        results = predict(model, data_dict, device, args, plm_model)
        # Output results
        print("\n---------- Prediction Results ----------")
        print(json.dumps(results, indent=2))
        return 0
    except Exception as e:
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        return 1

if __name__ == "__main__":
    sys.exit(main())