#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))

import argparse
import torch
import re
import json
import warnings
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from transformers import EsmTokenizer, EsmModel, BertModel, BertTokenizer
from transformers import T5Tokenizer, T5EncoderModel, AutoTokenizer, AutoModel, AutoModelForMaskedLM
from transformers import logging
from peft import PeftModel

# Import project modules
from models.adapter_model import AdapterModel
from models.lora_model import LoraModel
from models.pooling import MeanPooling, Attention1dPoolingHead, LightAttentionPoolingHead

# Suppress warning output
logging.set_verbosity_error()
warnings.filterwarnings("ignore")
def parse_args():
    parser = argparse.ArgumentParser(description="Batch predict protein function for multiple sequences")
    # Model parameters
    parser.add_argument('--eval_method', type=str, default="freeze",
                        choices=["full", "freeze", "plm-lora", "plm-qlora", "ses-adapter", "plm-dora", "plm-adalora", "plm-ia3"],
                        help="Evaluation method")
    parser.add_argument('--model_path', type=str, required=True, help="Path to the trained model")
    parser.add_argument('--plm_model', type=str, required=True, help="Pretrained language model name or path")
    parser.add_argument('--pooling_method', type=str, default="mean",
                        choices=["mean", "attention1d", "light_attention"], help="Pooling method")
    parser.add_argument('--problem_type', type=str, default="single_label_classification",
                        choices=["single_label_classification", "multi_label_classification", "regression"],
                        help="Problem type")
    parser.add_argument('--num_labels', type=int, default=2, help="Number of labels")
    parser.add_argument('--hidden_size', type=int, default=None, help="Embedding hidden size of the model")
    parser.add_argument('--num_attention_head', type=int, default=8, help="Number of attention heads")
    parser.add_argument('--attention_probs_dropout', type=float, default=0.0, help="Attention probability dropout rate")
    parser.add_argument('--pooling_dropout', type=float, default=0.25, help="Pooling dropout rate")
    # Input and output parameters
    parser.add_argument('--input_file', type=str, required=True, help="Path to input CSV file with sequences")
    parser.add_argument('--output_dir', type=str, required=True, help="Directory for the output CSV file")
    parser.add_argument('--output_file', type=str, required=True, help="Output CSV file name for predictions")
    parser.add_argument('--use_foldseek', action='store_true', help="Use foldseek structure sequence")
    parser.add_argument('--use_ss8', action='store_true', help="Use secondary structure sequence")
    parser.add_argument('--structure_seq', type=str, default=None, help="Structure sequence types to use (comma-separated)")
    # Other parameters
    parser.add_argument('--max_seq_len', type=int, default=1024, help="Maximum sequence length")
    parser.add_argument('--batch_size', type=int, default=1, help="Batch size for prediction")
    parser.add_argument('--dataset', type=str, default="Protein-wise", help="Dataset name")
    return parser.parse_args()
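
# Example invocation based on the arguments above (script name, paths, and
# checkpoint name are illustrative, not prescribed by this repo):
#   python predict.py \
#       --model_path ckpt/model.pt \
#       --plm_model facebook/esm2_t33_650M_UR50D \
#       --input_file data/sequences.csv \
#       --output_dir results \
#       --output_file predictions.csv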
def load_model_and_tokenizer(args):
    print("---------- Loading Model and Tokenizer ----------")
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Check that the model file exists
    if not os.path.exists(args.model_path):
        raise FileNotFoundError(f"Model file not found: {args.model_path}")

    # Load the model configuration if available
    config_path = os.path.join(os.path.dirname(args.model_path), "config.json")
    try:
        with open(config_path, "r") as f:
            config = json.load(f)
        print(f"Loaded configuration from {config_path}")
        # Override args with any matching values from the saved config
        for key in ("pooling_method", "problem_type", "num_labels",
                    "num_attention_head", "attention_probs_dropout", "pooling_dropout"):
            if key in config:
                setattr(args, key, config[key])
    except FileNotFoundError:
        print(f"Model config not found at {config_path}. Using command line arguments.")
    # Build the tokenizer and protein language model
    if "esm" in args.plm_model:
        tokenizer = EsmTokenizer.from_pretrained(args.plm_model)
        plm_model = EsmModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    elif "bert" in args.plm_model:
        tokenizer = BertTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = BertModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    elif "prot_t5" in args.plm_model:
        tokenizer = T5Tokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.d_model
    elif "ankh" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = T5EncoderModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.d_model
    elif "ProSST" in args.plm_model or "Prime" in args.plm_model:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model, do_lower_case=False)
        plm_model = AutoModelForMaskedLM.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    else:
        tokenizer = AutoTokenizer.from_pretrained(args.plm_model)
        plm_model = AutoModel.from_pretrained(args.plm_model).to(device).eval()
        args.hidden_size = plm_model.config.hidden_size
    args.vocab_size = plm_model.config.vocab_size
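
    # Illustrative checkpoint names matched by the substring checks above
    # (examples, not an exhaustive or repo-mandated list):
    #   esm           -> facebook/esm2_t33_650M_UR50D
    #   bert          -> Rostlab/prot_bert
    #   prot_t5       -> Rostlab/prot_t5_xl_uniref50
    #   ankh          -> ElnaggarLab/ankh-base
    #   ProSST/Prime  -> AI4Protein checkpoints (e.g. AI4Protein/ProSST-2048)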
    # Determine structure sequence types
    if args.structure_seq is None:
        args.structure_seq = ""
        print("Warning: structure_seq was None, setting to empty string")

    # Auto-set structure sequence flags based on the structure_seq parameter
    if 'foldseek_seq' in args.structure_seq:
        args.use_foldseek = True
        print("Enabled foldseek_seq based on structure_seq parameter")
    if 'ss8_seq' in args.structure_seq:
        args.use_ss8 = True
        print("Enabled ss8_seq based on structure_seq parameter")

    # If the flags are set but structure_seq is not, update structure_seq
    structure_seq_list = []
    if args.use_foldseek and 'foldseek_seq' not in args.structure_seq:
        structure_seq_list.append("foldseek_seq")
    if args.use_ss8 and 'ss8_seq' not in args.structure_seq:
        structure_seq_list.append("ss8_seq")
    if structure_seq_list and not args.structure_seq:
        args.structure_seq = ",".join(structure_seq_list)

    print(f"Evaluation method: {args.eval_method}")
    print(f"Structure sequence: {args.structure_seq}")
    print(f"Use foldseek: {args.use_foldseek}")
    print(f"Use ss8: {args.use_ss8}")
    print(f"Problem type: {args.problem_type}")
    print(f"Number of labels: {args.num_labels}")
    print(f"Number of attention heads: {args.num_attention_head}")
    # Create and load the model
    try:
        if args.eval_method in ["full", "ses-adapter", "freeze"]:
            model = AdapterModel(args)
        elif args.eval_method in ["plm-lora", "plm-qlora", "plm-dora", "plm-adalora", "plm-ia3"]:
            model = LoraModel(args)

        # --model_path is a required argument, so it is always set here
        model_path = args.model_path
        if args.eval_method == "full":
            # A "full" checkpoint stores both the head and the fine-tuned PLM weights
            model_weights = torch.load(model_path, map_location=device)
            model.load_state_dict(model_weights['model_state_dict'])
            plm_model.load_state_dict(model_weights['plm_state_dict'])
        else:
            model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device).eval()

        # For PEFT methods, load the adapter saved next to the checkpoint and
        # merge its weights into the base model for inference
        peft_suffixes = {
            "plm-lora": "_lora",
            "plm-qlora": "_qlora",
            "plm-dora": "_dora",
            "plm-adalora": "_adalora",
            "plm-ia3": "_ia3",
        }
        if args.eval_method in peft_suffixes:
            adapter_path = model_path.replace(".pt", peft_suffixes[args.eval_method])
            plm_model = PeftModel.from_pretrained(plm_model, adapter_path)
            plm_model = plm_model.merge_and_unload()
        plm_model.to(device).eval()
        return model, plm_model, tokenizer, device
    except Exception as e:
        print(f"Error: {str(e)}")
        raise
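
# Checkpoint layout assumed by load_model_and_tokenizer (file names illustrative):
#   model.pt       classification head state dict (plus PLM weights for --eval_method full)
#   model_lora/    PEFT adapter directory, suffix matching the eval method (_qlora, _dora, ...)
#   config.json    optional saved hyperparameters, read at load time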
def process_sequence(args, tokenizer, plm_model_name, aa_seq, foldseek_seq="", ss8_seq="", prosst_stru_token=None):
    """Process and prepare a single input sequence for prediction."""
    # Process the amino acid sequence
    aa_seq = aa_seq.strip()
    if not aa_seq:
        raise ValueError("Amino acid sequence is empty")

    # Process structure sequences if needed
    foldseek_seq = foldseek_seq.strip() if foldseek_seq else ""
    ss8_seq = ss8_seq.strip() if ss8_seq else ""

    # Check if structure sequences are required but not provided
    if args.use_foldseek and not foldseek_seq:
        print(f"Warning: Foldseek sequence is required but not provided for sequence: {aa_seq[:20]}...")
    if args.use_ss8 and not ss8_seq:
        print(f"Warning: SS8 sequence is required but not provided for sequence: {aa_seq[:20]}...")

    # Truncate by residue count before any formatting, so the limit is not
    # distorted by the spaces inserted for prot_bert/prot_t5 models
    if args.max_seq_len:
        aa_seq = aa_seq[:args.max_seq_len]
        foldseek_seq = foldseek_seq[:args.max_seq_len]
        ss8_seq = ss8_seq[:args.max_seq_len]

    # Format sequences based on the model type
    if 'prot_bert' in plm_model_name or "prot_t5" in plm_model_name:
        aa_seq = " ".join(list(aa_seq))
        aa_seq = re.sub(r"[UZOB]", "X", aa_seq)
        if args.use_foldseek and foldseek_seq:
            foldseek_seq = " ".join(list(foldseek_seq))
        if args.use_ss8 and ss8_seq:
            ss8_seq = " ".join(list(ss8_seq))
    elif 'ankh' in plm_model_name:
        aa_seq = list(aa_seq)
        if args.use_foldseek and foldseek_seq:
            foldseek_seq = list(foldseek_seq)
        if args.use_ss8 and ss8_seq:
            ss8_seq = list(ss8_seq)

    # Tokenize sequences
    if 'ankh' in plm_model_name:
        aa_inputs = tokenizer.batch_encode_plus([aa_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
        if args.use_foldseek and foldseek_seq:
            foldseek_inputs = tokenizer.batch_encode_plus([foldseek_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
        if args.use_ss8 and ss8_seq:
            ss8_inputs = tokenizer.batch_encode_plus([ss8_seq], add_special_tokens=True, padding=True, is_split_into_words=True, return_tensors="pt")
    else:
        aa_inputs = tokenizer([aa_seq], return_tensors="pt", padding=True, truncation=True)
        if args.use_foldseek and foldseek_seq:
            foldseek_inputs = tokenizer([foldseek_seq], return_tensors="pt", padding=True, truncation=True)
        if args.use_ss8 and ss8_seq:
            ss8_inputs = tokenizer([ss8_seq], return_tensors="pt", padding=True, truncation=True)

    # Prepare the data dictionary
    data_dict = {
        "aa_seq_input_ids": aa_inputs["input_ids"],
        "aa_seq_attention_mask": aa_inputs["attention_mask"],
    }

    # ProSST models additionally take quantized structure tokens
    if "ProSST" in plm_model_name and prosst_stru_token is not None:
        try:
            if isinstance(prosst_stru_token, str):
                # Accept a string such as "[1, 2, 3]" or "1,2,3"
                seq_clean = prosst_stru_token.strip("[]").replace(" ", "")
                tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
            elif isinstance(prosst_stru_token, (list, tuple)):
                tokens = [int(x) for x in prosst_stru_token]
            else:
                tokens = []
            if tokens:
                data_dict["aa_seq_stru_tokens"] = torch.tensor([tokens], dtype=torch.long)
            else:
                data_dict["aa_seq_stru_tokens"] = torch.zeros_like(aa_inputs["input_ids"], dtype=torch.long)
        except Exception as e:
            print(f"Warning: Failed to process ProSST structure tokens: {e}")
            data_dict["aa_seq_stru_tokens"] = torch.zeros_like(aa_inputs["input_ids"], dtype=torch.long)

    if args.use_foldseek and foldseek_seq:
        data_dict["foldseek_seq_input_ids"] = foldseek_inputs["input_ids"]
    if args.use_ss8 and ss8_seq:
        data_dict["ss8_seq_input_ids"] = ss8_inputs["input_ids"]
    return data_dict
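
# Example of the model-specific preparation above (illustrative): for a
# prot_t5 checkpoint, "MKTUV" becomes "M K T X V" (residues space-separated,
# with ambiguous residues U/Z/O/B mapped to X) before tokenization; ankh
# checkpoints instead receive the sequence pre-split into a list of residues,
# tokenized with is_split_into_words=True.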
def predict_batch(model, plm_model, data_dict, device, args):
    """Run prediction on a batch of processed input data."""
    # Move data to the device
    for k, v in data_dict.items():
        data_dict[k] = v.to(device)

    # Run model inference
    with torch.no_grad():
        outputs = model(plm_model, data_dict)

    # Process outputs based on the problem type
    if args.problem_type == "regression":
        predictions = outputs.squeeze().cpu().numpy()
        # Return a scalar if there is a single prediction
        if np.isscalar(predictions):
            return {"predictions": predictions}
        else:
            # For a batch, return the whole array
            return {"predictions": predictions.tolist() if isinstance(predictions, np.ndarray) else predictions}
    elif args.problem_type == "single_label_classification":
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_classes = torch.argmax(probabilities, dim=1).cpu().numpy()
        class_probs = probabilities.cpu().numpy()
        return {
            "predicted_classes": predicted_classes.tolist(),
            "probabilities": class_probs.tolist()
        }
    elif args.problem_type == "multi_label_classification":
        sigmoid_outputs = torch.sigmoid(outputs)
        predictions = (sigmoid_outputs > 0.5).int().cpu().numpy()
        probabilities = sigmoid_outputs.cpu().numpy()
        return {
            "predictions": predictions.tolist(),
            "probabilities": probabilities.tolist()
        }
    else:
        raise ValueError(f"Unsupported problem type: {args.problem_type}")
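
# Shapes of the values returned above (batch size B, label count C):
#   regression                   {"predictions": float or [B floats]}
#   single_label_classification  {"predicted_classes": [B], "probabilities": [B][C]}
#   multi_label_classification   {"predictions": [B][C] of 0/1, "probabilities": [B][C]}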
def main():
    # Parse command line arguments
    args = parse_args()
    try:
        # Load the model and tokenizer
        model, plm_model, tokenizer, device = load_model_and_tokenizer(args)

        # Read the input CSV file
        print(f"---------- Reading input file: {args.input_file} ----------")
        try:
            df = pd.read_csv(args.input_file)
            print(f"Found {len(df)} sequences in input file")
        except Exception as e:
            print(f"Error reading input file: {str(e)}")
            sys.exit(1)

        # Check required columns
        required_columns = ["aa_seq"]
        if args.use_foldseek:
            required_columns.append("foldseek_seq")
        if args.use_ss8:
            required_columns.append("ss8_seq")
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            print(f"Error: Input file is missing required columns: {', '.join(missing_columns)}")
            sys.exit(1)
        # Collect per-sequence results
        results = []

        # Process each sequence
        print("---------- Processing sequences ----------")
        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Predicting"):
            try:
                # Get the sequences from the row
                aa_seq = row["aa_seq"]
                foldseek_seq = row["foldseek_seq"] if "foldseek_seq" in df.columns and args.use_foldseek else ""
                ss8_seq = row["ss8_seq"] if "ss8_seq" in df.columns and args.use_ss8 else ""
                # ProSST structure tokens, if present (column name assumed to
                # match the process_sequence parameter; only used by ProSST)
                prosst_stru_token = row["prosst_stru_token"] if "prosst_stru_token" in df.columns else None

                # Process the sequence
                data_dict = process_sequence(args, tokenizer, args.plm_model, aa_seq, foldseek_seq, ss8_seq, prosst_stru_token)

                # Run prediction
                prediction_results = predict_batch(model, plm_model, data_dict, device, args)

                # Create the result row
                result_row = {"aa_seq": aa_seq}
                # Add the sequence ID if available
                if "id" in df.columns:
                    result_row["id"] = row["id"]

                # Add prediction results based on the problem type
                if args.problem_type == "regression":
                    if isinstance(prediction_results["predictions"], (list, np.ndarray)):
                        result_row["prediction"] = prediction_results["predictions"][0]
                    else:
                        result_row["prediction"] = prediction_results["predictions"]
                elif args.problem_type == "single_label_classification":
                    result_row["predicted_class"] = prediction_results["predicted_classes"][0]
                    # Add class probabilities
                    for i, prob in enumerate(prediction_results["probabilities"][0]):
                        result_row[f"class_{i}_prob"] = prob
                elif args.problem_type == "multi_label_classification":
                    # Add binary predictions
                    for i, pred in enumerate(prediction_results["predictions"][0]):
                        result_row[f"label_{i}"] = pred
                    # Add probabilities
                    for i, prob in enumerate(prediction_results["probabilities"][0]):
                        result_row[f"label_{i}_prob"] = prob
                results.append(result_row)
            except Exception as e:
                print(f"Error processing sequence at index {idx}: {str(e)}")
                # Record an error row so the failure is visible in the output
                error_row = {"aa_seq": row.get("aa_seq", ""), "error": str(e)}
                if "id" in df.columns:
                    error_row["id"] = row["id"]
                results.append(error_row)
        # Create the results dataframe
        results_df = pd.DataFrame(results)

        # Save the results to the output file
        os.makedirs(args.output_dir, exist_ok=True)
        output_file = os.path.join(args.output_dir, args.output_file)
        print(f"---------- Saving results to {output_file} ----------")
        results_df.to_csv(output_file, index=False)
        print(f"Saved {len(results_df)} prediction results")
        print("---------- Batch prediction completed successfully ----------")
    except Exception as e:
        print(f"Error: {str(e)}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
if __name__ == "__main__":
    main()
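
# Expected input CSV layout (column names from the checks in main; values illustrative):
#   aa_seq              required amino-acid sequence, e.g. "MKTVRQERLK..."
#   foldseek_seq        required only with --use_foldseek
#   ss8_seq             required only with --use_ss8
#   id                  optional identifier, copied into the output
#   prosst_stru_token   optional; assumed column for ProSST structure tokens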