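"""Gradio "Evaluation" tab: configure a trained model and dataset, launch
src/eval.py as a subprocess, stream its progress into the UI, and render the
resulting metrics table."""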
import gradio as gr
import json
import os
import subprocess
import sys
import threading
import queue
import time
import re

import pandas as pd
from datasets import load_dataset

from web.utils.command import preview_eval_command
def create_eval_tab(constant):
    plm_models = constant["plm_models"]
    dataset_configs = constant["dataset_configs"]

    is_evaluating = False
    current_process = None
    output_queue = queue.Queue()
    stop_thread = False
    process_aborted = False  # Whether the process was manually terminated
    def format_metrics(metrics_file):
        """Convert metrics to an HTML table for display."""
        try:
            df = pd.read_csv(metrics_file)
            metrics_dict = df.iloc[0].to_dict()

            # Priority order in which metrics should be listed
            priority_metrics = ['loss', 'accuracy', 'f1', 'precision', 'recall', 'auroc', 'mcc']

            def get_priority(item):
                name = item[0]
                if name in priority_metrics:
                    return priority_metrics.index(name)
                return len(priority_metrics)

            # Sort by priority; unknown metrics go last
            sorted_metrics = sorted(metrics_dict.items(), key=get_priority)
            metrics_count = len(sorted_metrics)

            html = f"""
            <div style="max-width: 800px; margin: 0 auto; font-family: Arial, sans-serif;">
                <p style="text-align: center; margin-bottom: 15px; color: #666;">{metrics_count} metrics found</p>
                <table style="width: 100%; border-collapse: collapse; font-size: 14px; border: 1px solid #ddd; box-shadow: 0 2px 3px rgba(0,0,0,0.1);">
                    <thead>
                        <tr style="background-color: #f0f0f0;">
                            <th style="padding: 12px; text-align: center; border: 1px solid #ddd; font-weight: bold; width: 50%;">Metric</th>
                            <th style="padding: 12px; text-align: center; border: 1px solid #ddd; font-weight: bold; width: 50%;">Value</th>
                        </tr>
                    </thead>
                    <tbody>
            """

            # Add one row per metric, with alternating row colors
            for i, (metric_name, metric_value) in enumerate(sorted_metrics):
                row_style = 'background-color: #f9f9f9;' if i % 2 == 0 else ''
                # Bold the priority metrics
                name_style = 'font-weight: bold;' if metric_name in priority_metrics else ''
                # Display abbreviations in upper case, everything else capitalized
                if metric_name.lower() in ['f1', 'mcc', 'auroc']:
                    display_name = metric_name.upper()
                else:
                    display_name = metric_name.capitalize()
                # Guard against non-numeric values so formatting never raises
                value_str = f"{metric_value:.4f}" if isinstance(metric_value, (int, float)) else str(metric_value)
                html += f"""
                        <tr style="{row_style}">
                            <td style="padding: 10px; text-align: center; border: 1px solid #ddd; {name_style}">{display_name}</td>
                            <td style="padding: 10px; text-align: center; border: 1px solid #ddd;">{value_str}</td>
                        </tr>
                """

            html += """
                    </tbody>
                </table>
                <p style="text-align: center; margin-top: 10px; color: #888; font-size: 12px;">Test completed at: """ + time.strftime("%Y-%m-%d %H:%M:%S") + """</p>
            </div>
            """
            return html
        except Exception as e:
            return f"Error formatting metrics: {str(e)}"
    def process_output(process, out_queue):
        """Drain the subprocess's stdout into a queue from a daemon thread."""
        nonlocal stop_thread
        while True:
            if stop_thread:
                break
            output = process.stdout.readline()
            if output == '' and process.poll() is not None:
                break
            if output:
                out_queue.put(output.strip())
        process.stdout.close()
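    # The reader thread above exists so the child's stdout pipe is drained
    # continuously; otherwise a full pipe buffer could block the evaluation
    # subprocess. The UI loop polls output_queue instead of reading the pipe.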
    def evaluate_model(plm_model, model_path, eval_method, is_custom_dataset,
                       dataset_defined, dataset_custom, problem_type, num_labels,
                       metrics, batch_mode, batch_size, batch_token,
                       eval_structure_seq, pooling_method):
        nonlocal is_evaluating, current_process, stop_thread, process_aborted

        if is_evaluating:
            return "Evaluation is already in progress. Please wait...", gr.update(visible=False)

        # First reset all state variables to ensure a clean start
        is_evaluating = True
        stop_thread = False
        process_aborted = False  # Reset abort flag

        # Clear the output queue
        while not output_queue.empty():
            try:
                output_queue.get_nowait()
            except queue.Empty:
                break

        # Initialize progress info and start time
        start_time = time.time()
        progress_info = {
            "stage": "Preparing",
            "progress": 0,
            "total_samples": 0,
            "current": 0,
            "total": 0,
            "elapsed_time": "00:00:00",
            "lines": []
        }

        # Create the initial, completely empty progress bar
        initial_progress_html = generate_progress_bar(progress_info)
        yield initial_progress_html, gr.update(visible=False)
        try:
            # Validate inputs
            if not model_path or not os.path.exists(os.path.dirname(model_path)):
                is_evaluating = False
                yield """
                <div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
                    <p style="margin: 0; color: #c62828; font-weight: bold;">Error: Invalid model path</p>
                </div>
                """, gr.update(visible=False)
                return

            if is_custom_dataset == "Use Custom Dataset":
                dataset = dataset_custom
                test_file = dataset_custom
            else:
                dataset = dataset_defined
                if dataset not in dataset_configs:
                    is_evaluating = False
                    yield """
                    <div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
                        <p style="margin: 0; color: #c62828; font-weight: bold;">Error: Invalid dataset selection</p>
                    </div>
                    """, gr.update(visible=False)
                    return
                config_path = dataset_configs[dataset]
                with open(config_path, 'r') as f:
                    dataset_config = json.load(f)
                test_file = dataset_config["dataset"]

            # Derive a short dataset name for display and result paths
            dataset_display_name = dataset.split('/')[-1]
            test_result_name = f"test_results_{os.path.basename(model_path)}_{dataset_display_name}"
            test_result_dir = os.path.join(os.path.dirname(model_path), test_result_name)
            # Prepare command
            cmd = [sys.executable, "src/eval.py"]
            args_dict = {
                "eval_method": eval_method,
                "model_path": model_path,
                "test_file": test_file,
                "problem_type": problem_type,
                "num_labels": num_labels,
                "metrics": ",".join(metrics),
                "plm_model": plm_models[plm_model],
                "test_result_dir": test_result_dir,
                "dataset": dataset_display_name,
                "pooling_method": pooling_method,
            }
            if batch_mode == "Batch Size Mode":
                args_dict["batch_size"] = batch_size
            else:
                args_dict["batch_token"] = batch_token

            if eval_method == "ses-adapter":
                args_dict["structure_seq"] = ",".join(eval_structure_seq) if eval_structure_seq else None
                # Add flags for using foldseek and ss8
                if "foldseek_seq" in eval_structure_seq:
                    args_dict["use_foldseek"] = True
                if "ss8_seq" in eval_structure_seq:
                    args_dict["use_ss8"] = True
            else:
                args_dict["structure_seq"] = ""

            # Booleans become bare flags; everything else becomes "--key value"
            for k, v in args_dict.items():
                if v is True:
                    cmd.append(f"--{k}")
                elif v is not False and v is not None:
                    cmd.append(f"--{k}")
                    cmd.append(str(v))
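            # With typical inputs the assembled command looks roughly like this
            # (illustrative values only):
            #
            #   python src/eval.py --eval_method freeze \
            #       --model_path ckpt/demo/demo_provided.pt --test_file user/dataset \
            #       --problem_type single_label_classification --num_labels 2 \
            #       --metrics accuracy,mcc,f1 --plm_model <hf-model-id> \
            #       --test_result_dir ckpt/demo/test_results_demo_provided.pt_dataset \
            #       --dataset dataset --pooling_method mean --batch_size 16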
            # Start the evaluation process in its own session so it can be
            # terminated independently of the UI process
            current_process = subprocess.Popen(
                cmd,
                stdout=subprocess.PIPE,
                stderr=subprocess.STDOUT,
                text=True,
                bufsize=1,
                preexec_fn=os.setsid
            )

            output_thread = threading.Thread(target=process_output, args=(current_process, output_queue))
            output_thread.daemon = True
            output_thread.start()
            sample_pattern = r"Total samples: (\d+)"
            progress_pattern = r"(\d+)/(\d+)"
            last_update_time = time.time()
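            # The patterns above target output lines such as (illustrative):
            #   "Total samples: 1024"              -> total sample count
            #   "37/64 [00:12<00:09, 2.95it/s]"    -> tqdm-style batch progress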
            while True:
                # Stop once the process has been aborted or has exited
                if process_aborted or current_process is None or current_process.poll() is not None:
                    break
                try:
                    lines_processed = 0
                    # Drain up to 10 queued lines per iteration
                    while lines_processed < 10:
                        try:
                            line = output_queue.get_nowait()
                            progress_info["lines"].append(line)

                            # Parse total samples
                            if "Total samples" in line:
                                match = re.search(sample_pattern, line)
                                if match:
                                    progress_info["total_samples"] = int(match.group(1))
                                    progress_info["stage"] = "Evaluating"

                            # Parse progress
                            if "it/s" in line and "/" in line:
                                match = re.search(progress_pattern, line)
                                if match:
                                    progress_info["current"] = int(match.group(1))
                                    progress_info["total"] = int(match.group(2))
                                    progress_info["progress"] = (progress_info["current"] / progress_info["total"]) * 100

                            if "Evaluation completed" in line:
                                progress_info["stage"] = "Completed"
                                progress_info["progress"] = 100

                            lines_processed += 1
                        except queue.Empty:
                            break

                    # Update elapsed time whether or not new lines arrived
                    elapsed = time.time() - start_time
                    hours, remainder = divmod(int(elapsed), 3600)
                    minutes, seconds = divmod(remainder, 60)
                    progress_info["elapsed_time"] = f"{hours:02}:{minutes:02}:{seconds:02}"

                    # Refresh the progress bar on new output, or at least every 0.5s
                    current_time = time.time()
                    if lines_processed > 0 or (current_time - last_update_time) >= 0.5:
                        progress_html = generate_progress_bar(progress_info)
                        yield progress_html, gr.update(visible=False)
                        last_update_time = current_time

                    time.sleep(0.1)  # Short loop interval keeps updates frequent
                except Exception as e:
                    yield f"""
                    <div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
                        <p style="margin: 0; color: #c62828;">Error reading output: {str(e)}</p>
                    </div>
                    """, gr.update(visible=False)
            # If the run was aborted, handle_abort has already cleared
            # current_process and produced the status message; stop here
            if process_aborted or current_process is None:
                return

            if current_process.returncode == 0:
                # Load and format results
                result_file = os.path.join(test_result_dir, "evaluation_metrics.csv")
                if os.path.exists(result_file):
                    metrics_html = format_metrics(result_file)

                    # Calculate total evaluation time
                    total_time = time.time() - start_time
                    hours, remainder = divmod(int(total_time), 3600)
                    minutes, seconds = divmod(remainder, 60)
                    time_str = f"{hours:02}:{minutes:02}:{seconds:02}"

                    summary_html = f"""
                    <div style="padding: 15px; background-color: #e8f5e9; border-radius: 5px; margin-bottom: 15px;">
                        <h3 style="margin-top: 0; color: #2e7d32;">Evaluation completed successfully!</h3>
                        <p><b>Total evaluation time:</b> {time_str}</p>
                        <p><b>Evaluation dataset:</b> {dataset_display_name}</p>
                        <p><b>Total samples:</b> {progress_info.get('total_samples', 0)}</p>
                    </div>
                    <div style="margin-top: 20px; font-weight: bold; font-size: 18px; text-align: center;">Evaluation Results</div>
                    {metrics_html}
                    """
                    # Show the download button and point it at the results file
                    yield summary_html, gr.update(value=result_file, visible=True)
                else:
                    yield f"""
                    <div style="padding: 10px; background-color: #fff8e1; border-radius: 5px; margin-bottom: 10px;">
                        <p style="margin: 0; color: #f57f17; font-weight: bold;">Evaluation completed, but metrics file not found at: {result_file}</p>
                    </div>
                    """, gr.update(visible=False)
            else:
                error_output = "\n".join(progress_info.get("lines", []))
                if not error_output:
                    error_output = "No output captured from the evaluation process"
                yield f"""
                <div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
                    <p style="margin: 0; color: #c62828; font-weight: bold;">Evaluation failed:</p>
                    <pre style="margin: 5px 0 0; white-space: pre-wrap; max-height: 300px; overflow-y: auto;">{error_output}</pre>
                </div>
                """, gr.update(visible=False)
        except Exception as e:
            yield f"""
            <div style="padding: 10px; background-color: #ffebee; border-radius: 5px; margin-bottom: 10px;">
                <p style="margin: 0; color: #c62828; font-weight: bold;">Error during evaluation process:</p>
                <pre style="margin: 5px 0 0; white-space: pre-wrap;">{str(e)}</pre>
            </div>
            """, gr.update(visible=False)
        finally:
            if current_process:
                stop_thread = True
            is_evaluating = False
            current_process = None
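    # evaluate_model is a generator: Gradio streams every yielded
    # (status_html, download_button_update) pair to the outputs, which is
    # what produces the live progress display.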
    def generate_progress_bar(progress_info):
        """Generate HTML for the evaluation progress bar."""
        stage = progress_info.get("stage", "Preparing")
        progress = progress_info.get("progress", 0)
        current = progress_info.get("current", 0)
        total = progress_info.get("total", 0)
        total_samples = progress_info.get("total_samples", 0)
        elapsed_time = progress_info.get("elapsed_time", "")

        # Clamp progress to the 0-100 range
        progress = max(0, min(100, progress))

        html = f"""
        <div style="background-color: #f8f9fa; border-radius: 10px; padding: 20px; margin-bottom: 15px; box-shadow: 0 2px 5px rgba(0,0,0,0.05);">
            <div style="display: flex; justify-content: space-between; margin-bottom: 12px;">
                <div>
                    <span style="font-weight: 600; font-size: 16px;">Evaluation Status: </span>
                    <span style="color: #1976d2; font-weight: 500; font-size: 16px;">{stage}</span>
                </div>
                <div>
                    <span style="font-weight: 600; color: #333;">{progress:.1f}%</span>
                </div>
            </div>
            <div style="margin-bottom: 15px; background-color: #e9ecef; height: 10px; border-radius: 5px; overflow: hidden;">
                <div style="background-color: #4285f4; width: {progress}%; height: 100%; border-radius: 5px; transition: width 0.3s ease;"></div>
            </div>
            <div style="display: flex; flex-wrap: wrap; gap: 10px; font-size: 14px; color: #555;">
                {f'<div style="background-color: #e3f2fd; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Total samples:</span> {total_samples}</div>' if total_samples > 0 else ''}
                {f'<div style="background-color: #e8f5e9; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Progress:</span> {current}/{total}</div>' if current > 0 and total > 0 else ''}
                {f'<div style="background-color: #fff8e1; padding: 5px 10px; border-radius: 4px;"><span style="font-weight: 500;">Time:</span> {elapsed_time}</div>' if elapsed_time else ''}
            </div>
        </div>
        """
        return html
    def handle_abort():
        """Handle abortion of the evaluation process."""
        nonlocal is_evaluating, current_process, stop_thread, process_aborted

        if current_process is None:
            return """
            <div style="padding: 10px; background-color: #f5f5f5; border-radius: 5px;">
                <p style="margin: 0;">No evaluation in progress to terminate.</p>
            </div>
            """, gr.update(visible=False)

        try:
            # Set the abort flag before terminating the process
            process_aborted = True
            stop_thread = True

            # Using terminate instead of killpg for safety
            current_process.terminate()

            # Wait for the process to terminate (with timeout)
            try:
                current_process.wait(timeout=5)
            except subprocess.TimeoutExpired:
                current_process.kill()

            # Reset state completely
            current_process = None
            is_evaluating = False

            # Drain the output queue to clear any pending messages
            while not output_queue.empty():
                try:
                    output_queue.get_nowait()
                except queue.Empty:
                    break

            return """
            <div style="padding: 10px; background-color: #e8f5e9; border-radius: 5px;">
                <p style="margin: 0; color: #2e7d32; font-weight: bold;">Evaluation successfully terminated!</p>
                <p style="margin: 5px 0 0; color: #388e3c;">All evaluation state has been reset.</p>
            </div>
            """, gr.update(visible=False)
        except Exception as e:
            # Still reset state even if termination raised an error
            current_process = None
            is_evaluating = False
            process_aborted = False

            # Drain the output queue
            while not output_queue.empty():
                try:
                    output_queue.get_nowait()
                except queue.Empty:
                    break

            return f"""
            <div style="padding: 10px; background-color: #ffebee; border-radius: 5px;">
                <p style="margin: 0; color: #c62828; font-weight: bold;">Failed to terminate evaluation: {str(e)}</p>
                <p style="margin: 5px 0 0; color: #c62828;">Evaluation state has been reset.</p>
            </div>
            """, gr.update(visible=False)
    with gr.Tab("Evaluation"):
        gr.Markdown("### Model and Dataset Configuration")

        # Original evaluation interface components
        with gr.Group():
            with gr.Row():
                eval_model_path = gr.Textbox(
                    label="Model Path",
                    value="ckpt/demo/demo_provided.pt",
                    placeholder="Path to the trained model"
                )
                eval_plm_model = gr.Dropdown(
                    choices=list(plm_models.keys()),
                    label="Protein Language Model"
                )
            with gr.Row():
                eval_method = gr.Dropdown(
                    choices=["full", "freeze", "ses-adapter", "plm-lora", "plm-qlora", "plm_adalora", "plm_dora", "plm_ia3"],
                    label="Evaluation Method",
                    value="freeze"
                )
                eval_pooling_method = gr.Dropdown(
                    choices=["mean", "attention1d", "light_attention"],
                    label="Pooling Method",
                    value="mean"
                )
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Row():
                        is_custom_dataset = gr.Radio(
                            choices=["Use Custom Dataset", "Use Pre-defined Dataset"],
                            label="Dataset Selection",
                            value="Use Pre-defined Dataset"
                        )
                        eval_dataset_defined = gr.Dropdown(
                            choices=list(dataset_configs.keys()),
                            label="Evaluation Dataset",
                            visible=True
                        )
                        eval_dataset_custom = gr.Textbox(
                            label="Custom Dataset Path",
                            placeholder="Huggingface Dataset, e.g. user/dataset",
                            visible=False
                        )
                with gr.Column(scale=1, min_width=120, elem_classes="preview-button-container"):
                    # Dataset preview trigger (distinct from the command
                    # preview button further down)
                    dataset_preview_button = gr.Button(
                        "Preview Dataset",
                        variant="primary",
                        size="lg",
                        elem_classes="preview-button"
                    )
        # Dataset statistics and the preview table live in a collapsible panel
        with gr.Row():
            with gr.Accordion("Dataset Preview", open=False) as preview_accordion:
                # Statistics area
                with gr.Row():
                    dataset_stats_md = gr.HTML("", elem_classes=["dataset-stats"])
                # Table area
                with gr.Row():
                    preview_table = gr.Dataframe(
                        headers=["Name", "Sequence", "Label"],
                        value=[["No dataset selected", "-", "-"]],
                        wrap=True,
                        interactive=False,
                        row_count=3,
                        elem_classes=["preview-table"]
                    )
        # Add CSS styles
        gr.HTML("""
        <style>
            /* Dataset statistics styles */
            .dataset-stats {
                margin: 0 0 15px 0;
                padding: 0;
            }
            .dataset-stats table {
                width: 100%;
                border-collapse: collapse;
                font-size: 0.9em;
                box-shadow: 0 2px 4px rgba(0,0,0,0.05);
                border-radius: 8px;
                overflow: hidden;
                table-layout: fixed;
            }
            .dataset-stats th {
                background-color: #e0e0e0;
                font-weight: bold;
                padding: 6px 10px;
                text-align: center;
                border: 1px solid #ddd;
                font-size: 0.95em;
                white-space: nowrap;
                overflow: hidden;
                min-width: 120px;
            }
            .dataset-stats td {
                padding: 6px 10px;
                text-align: center;
                border: 1px solid #ddd;
            }
            .dataset-stats h2 {
                font-size: 1.1em;
                margin: 0 0 10px 0;
                text-align: center;
            }
            /* Preview table styles */
            .preview-table table {
                background-color: white !important;
                font-size: 0.9em !important;
                width: 100%;
                table-layout: fixed !important;
                margin-top: 0;
                border-radius: 8px;
                overflow: hidden;
                box-shadow: 0 2px 4px rgba(0,0,0,0.05);
            }
            .preview-table .gr-block.gr-box {
                background-color: transparent !important;
            }
            .preview-table .gr-input-label {
                background-color: transparent !important;
            }
            /* Header cell styles */
            .preview-table th {
                background-color: #e0e0e0 !important;
                font-weight: bold !important;
                padding: 6px !important;
                border-bottom: 1px solid #ccc !important;
                font-size: 0.95em !important;
                text-align: center !important;
                white-space: nowrap !important;
                min-width: 120px !important;
            }
            /* Data cell styles */
            .preview-table td {
                padding: 4px 6px !important;
                max-width: 300px !important;
                overflow: hidden;
                text-overflow: ellipsis;
                white-space: nowrap;
                text-align: left !important;
            }
            /* Hover effect */
            .preview-table tr:hover {
                background-color: #f0f0f0 !important;
            }
            /* Accordion styles */
            .gr-accordion {
                border: 1px solid #e0e0e0;
                border-radius: 8px;
                overflow: hidden;
                margin-bottom: 15px;
            }
            /* Accordion header styles */
            .gr-accordion .label-wrap {
                background-color: #f5f5f5;
                padding: 8px 15px;
                font-weight: bold;
            }
            .preview-button {
                height: 86px !important;
            }
        </style>
        """, visible=True)
        ### These are settings for the custom dataset. ###
        with gr.Row(visible=True) as custom_dataset_settings:
            problem_type = gr.Dropdown(
                choices=["single_label_classification", "multi_label_classification", "regression"],
                label="Problem Type",
                value="single_label_classification",
                scale=23,
                interactive=False
            )
            num_labels = gr.Number(
                value=2,
                label="Number of Labels",
                scale=11,
                interactive=False
            )
            metrics = gr.Dropdown(
                choices=["accuracy", "recall", "precision", "f1", "mcc", "auroc", "f1_max", "spearman_corr", "mse"],
                label="Metrics",
                value=["accuracy", "mcc", "f1", "precision", "recall", "auroc"],
                scale=101,
                multiselect=True,
                interactive=False
            )
        # Dataset preview logic
        def update_dataset_preview(dataset_type=None, defined_dataset=None, custom_dataset=None):
            """Update the dataset preview content."""
            # Decide which dataset to use based on the radio selection
            if dataset_type == "Use Custom Dataset" and custom_dataset:
                try:
                    # Try to load the custom dataset
                    dataset = load_dataset(custom_dataset)
                    stats_html = f"""
                    <div style="text-align: center; margin: 20px 0;">
                        <table style="width: 100%; border-collapse: collapse; margin: 0 auto;">
                            <tr>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Dataset</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Train Samples</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Val Samples</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Test Samples</th>
                            </tr>
                            <tr>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{custom_dataset}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["train"]) if "train" in dataset else 0}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["validation"]) if "validation" in dataset else 0}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["test"]) if "test" in dataset else 0}</td>
                            </tr>
                        </table>
                    </div>
                    """
                    # Take up to three samples from the first available split
                    split = "train" if "train" in dataset else list(dataset.keys())[0]
                    samples = dataset[split].select(range(min(3, len(dataset[split]))))
                    if len(samples) == 0:
                        return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)

                    # Use the fields actually present in the dataset
                    available_fields = list(samples[0].keys())

                    # Build sample rows, keeping the full sequence text
                    sample_data = []
                    for sample in samples:
                        sample_data.append({field: str(sample[field]) for field in available_fields})
                    df = pd.DataFrame(sample_data)
                    return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True)
                except Exception as e:
                    error_html = f"""
                    <div>
                        <h2>Error loading dataset</h2>
                        <p style="color: #c62828;">{str(e)}</p>
                    </div>
                    """
                    return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)

            # Use a pre-defined dataset
            elif dataset_type == "Use Pre-defined Dataset" and defined_dataset:
                try:
                    config_path = dataset_configs[defined_dataset]
                    with open(config_path, 'r') as f:
                        config = json.load(f)

                    # Load dataset statistics
                    dataset = load_dataset(config["dataset"])
                    stats_html = f"""
                    <div style="text-align: center; margin: 20px 0;">
                        <table style="width: 100%; border-collapse: collapse; margin: 0 auto;">
                            <tr>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Dataset</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Train Samples</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Val Samples</th>
                                <th style="padding: 8px; font-size: 14px; border: 1px solid #ddd; background-color: #e0e0e0; font-weight: bold; border-bottom: 1px solid #ccc; text-align: center;">Test Samples</th>
                            </tr>
                            <tr>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{config["dataset"]}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["train"]) if "train" in dataset else 0}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["validation"]) if "validation" in dataset else 0}</td>
                                <td style="padding: 15px; font-size: 14px; border: 1px solid #ddd; text-align: center;">{len(dataset["test"]) if "test" in dataset else 0}</td>
                            </tr>
                        </table>
                    </div>
                    """
                    # Take up to three samples; guard against a missing train split
                    split = "train" if "train" in dataset else list(dataset.keys())[0]
                    samples = dataset[split].select(range(min(3, len(dataset[split]))))
                    if len(samples) == 0:
                        return gr.update(value=stats_html), gr.update(value=[["No data available", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)

                    # Use the fields actually present in the dataset
                    available_fields = list(samples[0].keys())

                    # Build sample rows, keeping the full sequence text
                    sample_data = []
                    for sample in samples:
                        sample_data.append({field: str(sample[field]) for field in available_fields})
                    df = pd.DataFrame(sample_data)
                    return gr.update(value=stats_html), gr.update(value=df.values.tolist(), headers=df.columns.tolist()), gr.update(open=True)
                except Exception as e:
                    error_html = f"""
                    <div>
                        <h2>Error loading dataset</h2>
                        <p style="color: #c62828;">{str(e)}</p>
                    </div>
                    """
                    return gr.update(value=error_html), gr.update(value=[["Error", str(e), "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)

            # No valid dataset information was provided
            return gr.update(value=""), gr.update(value=[["No dataset selected", "-", "-"]], headers=["Name", "Sequence", "Label"]), gr.update(open=True)
        # Dataset preview button click event
        dataset_preview_button.click(
            fn=update_dataset_preview,
            inputs=[is_custom_dataset, eval_dataset_defined, eval_dataset_custom],
            outputs=[dataset_stats_md, preview_table, preview_accordion]
        )
        def update_dataset_settings(choice, dataset_name=None):
            if choice == "Use Pre-defined Dataset":
                # Load configuration from the dataset config file
                if dataset_name and dataset_name in dataset_configs:
                    with open(dataset_configs[dataset_name], 'r') as f:
                        config = json.load(f)
                    # Metrics may be stored as a comma-separated string; the
                    # multiselect component needs a list
                    metrics_value = config.get("metrics", "accuracy,mcc,f1,precision,recall,auroc")
                    if isinstance(metrics_value, str):
                        metrics_value = metrics_value.split(",")
                    return [
                        gr.update(visible=True),   # eval_dataset_defined
                        gr.update(visible=False),  # eval_dataset_custom
                        gr.update(value=config.get("problem_type", ""), interactive=False),
                        gr.update(value=config.get("num_labels", 1), interactive=False),
                        gr.update(value=metrics_value, interactive=False)
                    ]
                # No valid pre-defined dataset selected yet; only fix visibility
                return [
                    gr.update(visible=True),
                    gr.update(visible=False),
                    gr.update(),
                    gr.update(),
                    gr.update()
                ]
            else:
                # Custom dataset settings
                return [
                    gr.update(visible=False),  # eval_dataset_defined
                    gr.update(visible=True),   # eval_dataset_custom
                    gr.update(value="", interactive=True),
                    gr.update(value=2, interactive=True),
                    gr.update(value=[], interactive=True)  # empty list for the multiselect
                ]
        is_custom_dataset.change(
            fn=update_dataset_settings,
            inputs=[is_custom_dataset, eval_dataset_defined],
            outputs=[eval_dataset_defined, eval_dataset_custom,
                     problem_type, num_labels, metrics]
        )
        eval_dataset_defined.change(
            fn=lambda x: update_dataset_settings("Use Pre-defined Dataset", x),
            inputs=[eval_dataset_defined],
            outputs=[eval_dataset_defined, eval_dataset_custom,
                     problem_type, num_labels, metrics]
        )
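        # A pre-defined dataset config is a small JSON file; the keys read
        # above imply a shape roughly like this (illustrative example):
        #
        #   {
        #       "dataset": "user/demo_dataset",
        #       "problem_type": "single_label_classification",
        #       "num_labels": 2,
        #       "metrics": "accuracy,mcc,f1,precision,recall,auroc"
        #   }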
        ### These are settings for the different training methods. ###
        # For ses-adapter
        with gr.Row(visible=False) as structure_seq_row:
            eval_structure_seq = gr.CheckboxGroup(
                label="Structure Sequence",
                choices=["foldseek_seq", "ss8_seq"],
                value=["foldseek_seq", "ss8_seq"]
            )

        def update_training_method(method):
            return {
                structure_seq_row: gr.update(visible=method == "ses-adapter")
            }

        eval_method.change(
            fn=update_training_method,
            inputs=[eval_method],
            outputs=[structure_seq_row]
        )
| gr.Markdown("### Batch Processing Configuration") | |
| with gr.Group(): | |
| with gr.Row(equal_height=True): | |
| with gr.Column(scale=1): | |
| batch_mode = gr.Radio( | |
| choices=["Batch Size Mode", "Batch Token Mode"], | |
| label="Batch Processing Mode", | |
| value="Batch Size Mode" | |
| ) | |
| with gr.Column(scale=2): | |
| batch_size = gr.Slider( | |
| minimum=1, | |
| maximum=128, | |
| value=16, | |
| step=1, | |
| label="Batch Size", | |
| visible=True | |
| ) | |
| batch_token = gr.Slider( | |
| minimum=1000, | |
| maximum=50000, | |
| value=10000, | |
| step=1000, | |
| label="Tokens per Batch", | |
| visible=False | |
| ) | |
| def update_batch_inputs(mode): | |
| return { | |
| batch_size: gr.update(visible=mode == "Batch Size Mode"), | |
| batch_token: gr.update(visible=mode == "Batch Token Mode") | |
| } | |
| # Update visibility when mode changes | |
| batch_mode.change( | |
| fn=update_batch_inputs, | |
| inputs=[batch_mode], | |
| outputs=[batch_size, batch_token] | |
| ) | |
        with gr.Row():
            command_preview_button = gr.Button("Preview Command")
            abort_button = gr.Button("Abort", variant="stop")
            eval_button = gr.Button("Start Evaluation", variant="primary")

        with gr.Row():
            command_preview = gr.Code(
                label="Command Preview",
                language="shell",
                interactive=False,
                visible=False
            )
        # Track preview visibility in a plain flag: gr.update changes the
        # frontend, but it does not write back to command_preview.visible
        command_preview_shown = {"value": False}

        def handle_preview(plm_model, model_path, eval_method, is_custom_dataset, dataset_defined,
                           dataset_custom, problem_type, num_labels, metrics, batch_mode,
                           batch_size, batch_token, eval_structure_seq, eval_pooling_method):
            """Handle the Preview Command button click (toggles the preview)."""
            if command_preview_shown["value"]:
                command_preview_shown["value"] = False
                return gr.update(visible=False)

            # Build the argument dictionary
            args = {
                "plm_model": plm_models[plm_model],
                "model_path": model_path,
                "eval_method": eval_method,
                "pooling_method": eval_pooling_method
            }

            # Dataset-related arguments
            if is_custom_dataset == "Use Custom Dataset":
                args["dataset"] = dataset_custom
                args["problem_type"] = problem_type
                args["num_labels"] = num_labels
                args["metrics"] = ",".join(metrics)
            else:
                args["dataset_config"] = dataset_configs[dataset_defined]

            # Batch processing arguments
            if batch_mode == "Batch Size Mode":
                args["batch_size"] = batch_size
            else:
                args["batch_token"] = batch_token

            # Structure sequence arguments
            if eval_method == "ses-adapter" and eval_structure_seq:
                args["structure_seq"] = ",".join(eval_structure_seq)
                if "foldseek_seq" in eval_structure_seq:
                    args["use_foldseek"] = True
                if "ss8_seq" in eval_structure_seq:
                    args["use_ss8"] = True

            # Generate the preview command
            preview_text = preview_eval_command(args)
            command_preview_shown["value"] = True
            return gr.update(value=preview_text, visible=True)
        # Wire up the command preview button
        command_preview_button.click(
            fn=handle_preview,
            inputs=[
                eval_plm_model,
                eval_model_path,
                eval_method,
                is_custom_dataset,
                eval_dataset_defined,
                eval_dataset_custom,
                problem_type,
                num_labels,
                metrics,
                batch_mode,
                batch_size,
                batch_token,
                eval_structure_seq,
                eval_pooling_method
            ],
            outputs=[command_preview]
        )
        eval_output = gr.HTML(
            value="<div style='padding: 15px; background-color: #f5f5f5; border-radius: 5px;'><p style='margin: 0;'>Click the 「Start Evaluation」 button to begin model evaluation</p></div>",
            label="Evaluation Status & Results"
        )
        with gr.Row():
            with gr.Column(scale=4):
                pass
            with gr.Column(scale=1):
                download_csv_btn = gr.DownloadButton(
                    "Download CSV",
                    visible=False,
                    size="lg"
                )
            with gr.Column(scale=4):
                pass

        # Connect buttons to functions
        eval_button.click(
            fn=evaluate_model,
            inputs=[
                eval_plm_model,
                eval_model_path,
                eval_method,
                is_custom_dataset,
                eval_dataset_defined,
                eval_dataset_custom,
                problem_type,
                num_labels,
                metrics,
                batch_mode,
                batch_size,
                batch_token,
                eval_structure_seq,
                eval_pooling_method
            ],
            outputs=[eval_output, download_csv_btn]
        )

        abort_button.click(
            fn=handle_abort,
            inputs=[],
            outputs=[eval_output, download_csv_btn]
        )

    return {
        "eval_button": eval_button,
        "eval_output": eval_output
    }
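# Typical wiring from the app entry point (illustrative; the real "constant"
# dict, model names, and config paths live elsewhere in this repo):
#
#   constant = {
#       "plm_models": {"ESM2-650M": "facebook/esm2_t33_650M_UR50D"},
#       "dataset_configs": {"demo": "configs/demo.json"},
#   }
#   with gr.Blocks() as demo:
#       create_eval_tab(constant)
#   demo.launch()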