import hashlib
import json
import pickle
from datetime import datetime
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from tqdm import tqdm
# Cache configuration
CACHE_DIR = Path("./pwc_cache")
CACHE_DIR.mkdir(exist_ok=True)

# Directory structure for disk-based storage
TASKS_INDEX_FILE = CACHE_DIR / "tasks_index.json"  # Small JSON file with the task list
TASK_DATA_DIR = CACHE_DIR / "task_data"  # Directory for individual task files
DATASET_DATA_DIR = CACHE_DIR / "dataset_data"  # Directory for individual dataset files
METRICS_INDEX_FILE = CACHE_DIR / "metrics_index.json"  # Metrics metadata

# Create directories
TASK_DATA_DIR.mkdir(exist_ok=True)
DATASET_DATA_DIR.mkdir(exist_ok=True)
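
# For reference, the resulting cache layout looks roughly like this
# (illustrative sketch; actual file names depend on the dataset contents):
#
#   pwc_cache/
#   ├── tasks_index.json              # sorted list of task names
#   ├── metrics_index.json            # per-metric metadata (count, is_lower_better)
#   ├── task_data/
#   │   └── task_<task>.pkl           # categories, datasets, date range per task
#   └── dataset_data/
#       └── data_<task>_<dataset>.pkl # pickled DataFrame of model results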


def sanitize_filename(name):
    """Convert a string to a safe filename."""
    # Replace problematic characters with underscores
    safe_name = name.replace('/', '_').replace('\\', '_').replace(':', '_')
    safe_name = safe_name.replace('*', '_').replace('?', '_').replace('"', '_')
    safe_name = safe_name.replace('<', '_').replace('>', '_').replace('|', '_')
    safe_name = safe_name.replace(' ', '_').replace('.', '_')
    # Collapse repeated underscores and trim
    safe_name = '_'.join(filter(None, safe_name.split('_')))
    # Limit length to avoid filesystem issues
    if len(safe_name) > 200:
        # If too long, use the first 150 chars plus a hash of the full name
        safe_name = safe_name[:150] + '_' + hashlib.md5(name.encode()).hexdigest()[:8]
    return safe_name


def get_task_filename(task):
    """Generate a safe filename for a task."""
    safe_name = sanitize_filename(task)
    return TASK_DATA_DIR / f"task_{safe_name}.pkl"


def get_dataset_filename(task, dataset_name):
    """Generate a safe filename for a dataset."""
    safe_task = sanitize_filename(task)
    safe_dataset = sanitize_filename(dataset_name)
    # Include both task and dataset in the filename for clarity
    filename = f"data_{safe_task}_{safe_dataset}.pkl"
    # If the combined name is too long, shorten it with a hash
    if len(filename) > 255:
        filename = f"data_{safe_task[:50]}_{safe_dataset[:50]}_{hashlib.md5(f'{task}||{dataset_name}'.encode()).hexdigest()[:8]}.pkl"
    return DATASET_DATA_DIR / filename
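
# Illustrative examples of the name-to-file mapping (hypothetical inputs,
# not taken from the source dataset):
#   sanitize_filename("Image Classification")   -> "Image_Classification"
#   sanitize_filename("KITTI Depth / val 2.0")  -> "KITTI_Depth_val_2_0"
#   get_dataset_filename("Image Classification", "ImageNet")
#       -> pwc_cache/dataset_data/data_Image_Classification_ImageNet.pkl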


def cache_exists():
    """Check whether the cache structure exists."""
    return TASKS_INDEX_FILE.exists() and METRICS_INDEX_FILE.exists()


def build_disk_based_cache():
    """Build the cache with minimal peak memory: each task and dataset is
    written to disk as soon as it has been processed."""
    print("=" * 60)
    print("Building disk-based cache (one-time operation)...")
    print("=" * 60)

    # Initialize tracking structures (kept small)
    tasks_set = set()
    metrics_index = {}

    print("\n[1/4] Loading dataset and building cache...")
    # Load the dataset eagerly; len() is needed for the progress bar below
    ds = load_dataset("pwc-archive/evaluation-tables", split="train", streaming=False)
    total_items = len(ds)
    processed_count = 0
    dataset_count = 0

    for item in tqdm(ds, total=total_items):
        task = item['task']
        if not task:
            continue
        tasks_set.add(task)

        # Load existing task data from disk or create new
        task_file = get_task_filename(task)
        if task_file.exists():
            with open(task_file, 'rb') as f:
                task_data = pickle.load(f)
            # Saved task files store lists (see below), so convert back to sets
            task_data['categories'] = set(task_data['categories'])
            task_data['datasets'] = set(task_data['datasets'])
        else:
            task_data = {
                'categories': set(),
                'datasets': set(),
                'date_range': {'min': None, 'max': None}
            }

        # Update task data
        if item['categories']:
            task_data['categories'].update(item['categories'])

        # Process datasets
        if item['datasets']:
            for dataset in item['datasets']:
                if not isinstance(dataset, dict) or 'dataset' not in dataset:
                    continue
                dataset_name = dataset['dataset']
                dataset_file = get_dataset_filename(task, dataset_name)
                task_data['datasets'].add(dataset_name)

                # Skip if this dataset was already processed
                if dataset_file.exists():
                    continue

                # Process SOTA data
                if 'sota' not in dataset or 'rows' not in dataset['sota']:
                    continue

                models_data = []
                for row in dataset['sota']['rows']:
                    if not isinstance(row, dict):
                        continue
                    model_name = row.get('model_name', 'Unknown Model')

                    # Extract metrics
                    metrics = {}
                    if 'metrics' in row and isinstance(row['metrics'], dict):
                        for metric_name, metric_value in row['metrics'].items():
                            if metric_value is not None:
                                metrics[metric_name] = metric_value
                                # Track metric metadata
                                if metric_name not in metrics_index:
                                    metrics_index[metric_name] = {
                                        'count': 0,
                                        'is_lower_better': any(kw in metric_name.lower()
                                                               for kw in ['error', 'loss', 'time', 'cost'])
                                    }
                                metrics_index[metric_name]['count'] += 1

                    # Parse the paper date, defaulting when missing or unparseable
                    paper_date = row.get('paper_date')
                    try:
                        if paper_date and isinstance(paper_date, str):
                            release_date = pd.to_datetime(paper_date)
                        else:
                            release_date = pd.to_datetime('2020-01-01')
                    except Exception:
                        release_date = pd.to_datetime('2020-01-01')

                    # Update the date range
                    if task_data['date_range']['min'] is None or release_date < task_data['date_range']['min']:
                        task_data['date_range']['min'] = release_date
                    if task_data['date_range']['max'] is None or release_date > task_data['date_range']['max']:
                        task_data['date_range']['max'] = release_date

                    # Build the model entry
                    model_entry = {
                        'model_name': model_name,
                        'release_date': release_date,
                        'paper_date': row.get('paper_date', ''),  # Raw paper_date kept for dynamic parsing
                        'paper_url': row.get('paper_url', ''),
                        'paper_title': row.get('paper_title', ''),
                        'code_url': row.get('code_links', [''])[0] if row.get('code_links') else '',
                        **metrics
                    }
                    models_data.append(model_entry)

                if models_data:
                    df = pd.DataFrame(models_data)
                    df = df.sort_values('release_date')
                    # Save the dataset to its own file
                    with open(dataset_file, 'wb') as f:
                        pickle.dump(df, f, protocol=pickle.HIGHEST_PROTOCOL)
                    dataset_count += 1
                    # Clear the DataFrame from memory
                    del df
                    del models_data

        # Save the updated task data back to disk
        with open(task_file, 'wb') as f:
            # Convert sets to lists for serialization
            task_data_to_save = {
                'categories': sorted(task_data['categories']),
                'datasets': sorted(task_data['datasets']),
                'date_range': task_data['date_range']
            }
            pickle.dump(task_data_to_save, f, protocol=pickle.HIGHEST_PROTOCOL)

        # Clear the task data from memory
        del task_data
        processed_count += 1
| print(f"\nβ Processed {len(tasks_set)} tasks and {dataset_count} datasets") | |
| print("\n[2/4] Saving index files...") | |
| # Save tasks index (small file) | |
| tasks_list = sorted(list(tasks_set)) | |
| with open(TASKS_INDEX_FILE, 'w') as f: | |
| json.dump(tasks_list, f) | |
| print(f" β Saved tasks index ({len(tasks_list)} tasks)") | |
| # Save metrics index | |
| with open(METRICS_INDEX_FILE, 'w') as f: | |
| json.dump(metrics_index, f, indent=2) | |
| print(f" β Saved metrics index ({len(metrics_index)} metrics)") | |
| print("\n[3/4] Calculating cache statistics...") | |
| # Calculate total cache size | |
| total_size = 0 | |
| for file in TASK_DATA_DIR.glob("*.pkl"): | |
| total_size += file.stat().st_size | |
| for file in DATASET_DATA_DIR.glob("*.pkl"): | |
| total_size += file.stat().st_size | |
| print(f" β Total cache size: {total_size / 1024 / 1024:.1f} MB") | |
| print(f" β Task files: {len(list(TASK_DATA_DIR.glob('*.pkl')))}") | |
| print(f" β Dataset files: {len(list(DATASET_DATA_DIR.glob('*.pkl')))}") | |
| print("\n[4/4] Cache building complete!") | |
| print("=" * 60) | |
| return tasks_list | |


def load_tasks_index():
    """Load just the task list from disk."""
    with open(TASKS_INDEX_FILE, 'r') as f:
        return json.load(f)


def load_task_data(task):
    """Load data for a specific task from disk."""
    task_file = get_task_filename(task)
    if task_file.exists():
        with open(task_file, 'rb') as f:
            return pickle.load(f)
    return None


def load_dataset_data(task, dataset_name):
    """Load a specific dataset from disk."""
    dataset_file = get_dataset_filename(task, dataset_name)
    if dataset_file.exists():
        with open(dataset_file, 'rb') as f:
            return pickle.load(f)
    return pd.DataFrame()


def load_metrics_index():
    """Load the metrics index from disk."""
    if METRICS_INDEX_FILE.exists():
        with open(METRICS_INDEX_FILE, 'r') as f:
            return json.load(f)
    return {}
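
# For reference, the metrics index is a small JSON mapping of the form
# below (hypothetical values, shown for illustration only):
#
#   {
#     "Accuracy":   {"count": 1523, "is_lower_better": false},
#     "Error rate": {"count": 214,  "is_lower_better": true}
#   }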


# Initialize: build the cache if it doesn't exist yet
if cache_exists():
    print("Loading task index from disk...")
    TASKS = load_tasks_index()
    print(f"✓ Loaded {len(TASKS)} tasks")
else:
    TASKS = build_disk_based_cache()

# Load the metrics index once (it's small)
METRICS_INDEX = load_metrics_index()


# Memory-efficient accessor functions
def get_tasks():
    """Get all tasks from the index."""
    return TASKS


def get_task_data(task):
    """Load task data from disk on demand."""
    return load_task_data(task)


def get_categories(task):
    """Get categories for a task (loads from disk)."""
    task_data = get_task_data(task)
    return task_data['categories'] if task_data else []


def get_datasets_for_task(task):
    """Get datasets for a task (loads from disk)."""
    task_data = get_task_data(task)
    return task_data['datasets'] if task_data else []


def get_cached_model_data(task, dataset_name):
    """Load a dataset from disk on demand."""
    return load_dataset_data(task, dataset_name)


def parse_paper_date(paper_date, paper_title="", paper_url=""):
    """Parse a paper date with fallback strategies: explicit formats first,
    then pandas' automatic parsing, then year extraction from the title or URL."""
    import re

    # Try to parse the raw paper_date if available
    if paper_date and isinstance(paper_date, str) and paper_date.strip():
        # Try common date formats first
        date_formats = [
            '%Y-%m-%d',
            '%Y/%m/%d',
            '%d-%m-%Y',
            '%d/%m/%Y',
            '%Y-%m',
            '%Y/%m',
            '%Y'
        ]
        for fmt in date_formats:
            try:
                return pd.to_datetime(paper_date.strip(), format=fmt)
            except Exception:
                continue
        # Fall back to pandas' automatic parsing
        try:
            return pd.to_datetime(paper_date.strip())
        except Exception:
            pass

    # Fallback: try to extract a year from the paper title or URL
    year_pattern = r'\b(19[5-9]\d|20[0-9]\d)\b'  # Match 1950-2099

    # Look for a year in the paper title
    if paper_title:
        years = re.findall(year_pattern, str(paper_title))
        if years:
            year = max(years)  # Use the latest year found
            return pd.to_datetime(f'{year}-01-01')

    # Look for a year in the paper URL
    if paper_url:
        years = re.findall(year_pattern, str(paper_url))
        if years:
            year = max(years)  # Use the latest year found
            return pd.to_datetime(f'{year}-01-01')

    # Final fallback: return None instead of a default year
    return None
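
# Illustrative behaviour of the fallback chain (hypothetical inputs):
#   parse_paper_date("2021-06-15")                       -> Timestamp('2021-06-15')
#   parse_paper_date("2019")                             -> Timestamp('2019-01-01')
#   parse_paper_date("", paper_title="BERT (2018)")      -> Timestamp('2018-01-01')
#   parse_paper_date("", paper_url="https://arxiv.org/") -> None (no year found)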


def get_task_statistics(task):
    """Get statistics about a task (placeholder; not implemented yet)."""
    return {}


def _empty_plot(message, title, font_size=20):
    """Build a placeholder figure with a centered message (consolidates the
    repeated empty-figure construction used by create_sota_plot)."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=font_size)
    )
    fig.update_layout(
        title=title,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig


def create_sota_plot(df, metric):
    """Create a plot showing model performance evolution over time.

    Args:
        df: DataFrame with model data
        metric: Metric name to plot on the y-axis
    """
    if df.empty or metric not in df.columns:
        return _empty_plot("No data available for this metric", "No Data Available")

    # Remove rows where the metric is NaN
    df_clean = df.dropna(subset=[metric]).copy()
    if df_clean.empty:
        return _empty_plot("No valid data points for this metric", "No Data Available")

    # Convert the metric column to numeric, stripping a trailing "%" from strings
    try:
        df_clean[metric] = pd.to_numeric(
            df_clean[metric].apply(lambda x: x.strip()[:-1] if isinstance(x, str) and x.strip().endswith("%") else x),
            errors='coerce')
        # Remove any rows that couldn't be converted to numeric
        df_clean = df_clean.dropna(subset=[metric])
        if df_clean.empty:
            return _empty_plot(f"No numeric data available for metric: {metric}",
                               "No Numeric Data Available")
    except Exception as e:
        return _empty_plot(f"Error processing metric data: {str(e)}",
                           "Data Processing Error", font_size=16)

    # Recalculate release dates dynamically from the raw paper_date if available
    df_processed = df_clean.copy()
    if 'paper_date' in df_processed.columns:
        # Parse dates dynamically using the improved logic
        df_processed['dynamic_release_date'] = df_processed.apply(
            lambda row: parse_paper_date(
                row.get('paper_date', ''),
                row.get('paper_title', ''),
                row.get('paper_url', '')
            ), axis=1
        )
        # Use dynamic dates if available, otherwise fall back to the original release_date
        df_processed['final_release_date'] = df_processed['dynamic_release_date'].fillna(df_processed['release_date'])
    else:
        # If there is no paper_date column, use the existing release_date
        df_processed['final_release_date'] = df_processed['release_date']

    # Filter out rows with no valid date
    df_with_dates = df_processed[df_processed['final_release_date'].notna()].copy()
    if df_with_dates.empty:
        return _empty_plot("No valid dates available for this dataset", "No Date Data Available")

    # Sort by final release date
    df_sorted = df_with_dates.sort_values('final_release_date').copy()

    # Check whether the metric is lower-better
    if metric in METRICS_INDEX:
        is_lower_better = METRICS_INDEX[metric].get('is_lower_better', False)
    else:
        is_lower_better = any(keyword in metric.lower() for keyword in ['error', 'loss', 'time', 'cost'])

    # Running best so far: cumulative min for lower-better metrics, max otherwise
    if is_lower_better:
        df_sorted['cumulative_best'] = df_sorted[metric].cummin()
    else:
        df_sorted['cumulative_best'] = df_sorted[metric].cummax()
    df_sorted['is_sota'] = df_sorted[metric] == df_sorted['cumulative_best']
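
    # Worked example of the cumulative-best logic (hypothetical values,
    # higher-better metric, already sorted by date):
    #   metric values:   [0.85, 0.83, 0.90, 0.88]
    #   cumulative_best: [0.85, 0.85, 0.90, 0.90]   (cummax)
    #   is_sota:         [True, False, True, False]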

    # Get the SOTA models
    sota_df = df_sorted[df_sorted['is_sota']].copy()

    # Use the dynamically calculated dates for the x-axis
    x_values = df_sorted['final_release_date']
    x_axis_title = 'Release Date'

    # Create the plot
    fig = go.Figure()

    # Add all models as scatter points
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted[metric],
        mode='markers',
        name='All models',
        marker=dict(
            color=['#00CED1' if is_sota else 'lightgray'
                   for is_sota in df_sorted['is_sota']],
            size=8,
            opacity=0.7
        ),
        text=df_sorted['model_name'],
        customdata=df_sorted[['paper_title', 'paper_url', 'code_url']],
        hovertemplate='<b>%{text}</b><br>' +
                      f'{metric}: %{{y:.4f}}<br>' +
                      'Date: %{x}<br>' +
                      'Paper: %{customdata[0]}<br>' +
                      '<extra></extra>'
    ))

    # Add the SOTA line
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted['cumulative_best'],
        mode='lines',
        name=f'SOTA (cumulative {"min" if is_lower_better else "max"})',
        line=dict(color='#00CED1', width=2, dash='solid'),
        hovertemplate=f'SOTA {metric}: %{{y:.4f}}<br>{x_axis_title}: %{{x}}<extra></extra>'
    ))

    # Add labels for SOTA models
    if not sota_df.empty:
        # Calculate a dynamic offset based on the data range
        y_range = df_sorted[metric].max() - df_sorted[metric].min()
        # Use a percentage of the range for the offset, with minimum and maximum bounds
        if y_range > 0:
            base_offset = y_range * 0.03  # 3% of the data range
            # Ensure a minimum offset for readability and a maximum to prevent excessive spacing
            label_offset = max(y_range * 0.01, min(base_offset, y_range * 0.08))
        else:
            # Fallback for when all values are the same
            label_offset = 1

        # Track label positions to prevent overlaps
        previous_labels = []

        # For the date-based x-axis, use date separation
        try:
            date_range = (df_sorted['final_release_date'].max() - df_sorted['final_release_date'].min()).days
            min_separation = max(30, date_range * 0.05)  # Minimum 30 days or 5% of range
        except (TypeError, AttributeError):
            # Fallback if the date calculation fails
            min_separation = 30

        for i, (_, row) in enumerate(sota_df.iterrows()):
            # Determine the base label position based on the metric type
            if is_lower_better:
                # For lower-better metrics, place the label above the point (negative ay)
                base_ay_offset = -label_offset
                base_yshift = -8
                alternate_multiplier = -1
            else:
                # For higher-better metrics, place the label below the point (positive ay)
                base_ay_offset = label_offset
                base_yshift = 8
                alternate_multiplier = 1

            # Check for collisions with previous labels
            current_x = row['final_release_date']
            collision_detected = False
            for prev_x, _ in previous_labels:
                try:
                    x_diff = abs((current_x - prev_x).days)
                    if x_diff < min_separation:
                        collision_detected = True
                        break
                except (TypeError, AttributeError):
                    # Skip collision detection if the calculation fails
                    continue

            # Adjust the position if a collision was detected
            if collision_detected:
                # Alternate the label position (above/below) to avoid overlap
                ay_offset = base_ay_offset + (alternate_multiplier * label_offset * 0.7 * (i % 2))
                yshift = base_yshift + (alternate_multiplier * 12 * (i % 2))
            else:
                ay_offset = base_ay_offset
                yshift = base_yshift

            # Add the annotation
            fig.add_annotation(
                x=current_x,
                y=row[metric],
                text=row['model_name'][:25] + '...' if len(row['model_name']) > 25 else row['model_name'],
                showarrow=True,
                arrowhead=2,
                arrowsize=1,
                arrowwidth=1,
                arrowcolor='#00CED1',  # Match the SOTA line color
                ax=0,
                ay=ay_offset,  # Dynamic offset based on data range and collision detection
                yshift=yshift,  # Fine-tune positioning
                font=dict(size=8, color='#333333'),
                bgcolor='rgba(255, 255, 255, 0.9)',  # Semi-transparent background
                borderwidth=0  # No border
            )

            # Track this label position
            previous_labels.append((current_x, ay_offset))

    # Update the layout
    fig.update_layout(
        title=f'SOTA Evolution: {metric}',
        xaxis_title=x_axis_title,
        yaxis_title=metric,
        xaxis=dict(showgrid=True, gridcolor='lightgray'),
        yaxis=dict(showgrid=True, gridcolor='lightgray'),
        plot_bgcolor='white',
        paper_bgcolor='white',
        height=600,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        hovermode='closest'
    )

    # Free intermediate DataFrames after plotting
    del df_clean
    del df_sorted
    del sota_df

    return fig


# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 📈 Papers with Code - SOTA Evolution Visualizer")
    gr.Markdown(
        "Navigate through ML tasks and datasets to visualize the evolution of state-of-the-art models over time.")
    gr.Markdown("*Optimized for low memory usage - data is loaded on-demand from disk*")

    # Status
    with gr.Row():
        gr.Markdown(f"""
        <div style="background-color: #f0f9ff; border-left: 4px solid #00CED1; padding: 10px; margin: 10px 0;">
        <b>💾 Disk-Based Storage Active</b><br>
        • <b>{len(TASKS)}</b> tasks indexed<br>
        • <b>{len(METRICS_INDEX)}</b> unique metrics tracked<br>
        • Data loaded on-demand to minimize RAM usage
        </div>
        """)

    # State variables
    current_df = gr.State(pd.DataFrame())
    current_task = gr.State(None)

    # Navigation dropdowns
    with gr.Row():
        task_dropdown = gr.Dropdown(
            choices=get_tasks(),
            label="Select Task",
            interactive=True
        )
        category_dropdown = gr.Dropdown(
            choices=[],
            label="Categories (info only)",
            interactive=False
        )
    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=[],
            label="Select Dataset",
            interactive=True
        )
        metric_dropdown = gr.Dropdown(
            choices=[],
            label="Select Metric",
            interactive=True
        )

    # Info display
    info_text = gr.Markdown("📍 Please select a task to begin")

    # Plot
    plot = gr.Plot(label="SOTA Evolution")

    # Data display
    with gr.Row():
        show_data_btn = gr.Button("📊 Show/Hide Model Data")
        export_btn = gr.Button("💾 Export Current Data (CSV)")
        clear_memory_btn = gr.Button("🧹 Clear Memory", variant="secondary")
    df_display = gr.Dataframe(
        label="Model Data",
        visible=False
    )

    # Update functions
    def update_task_selection(task):
        """Update the dropdowns when a task is selected."""
        if not task:
            return (
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                "📍 Please select a task to begin",
                pd.DataFrame(),
                None,
                None
            )

        # Load the task data from disk
        categories = get_categories(task)
        datasets = get_datasets_for_task(task)

        info = f"### 📋 **Task:** {task}\n"
        if categories:
            info += f"- **Categories:** {', '.join(categories[:3])}{'...' if len(categories) > 3 else ''} ({len(categories)} total)\n"

        return (
            gr.Dropdown(choices=categories, value=categories[0] if categories else None),
            gr.Dropdown(choices=datasets, value=None),
            gr.Dropdown(choices=[], value=None),
            info,
            pd.DataFrame(),
            None,
            task  # Store the current task
        )

    def update_dataset_selection(task, dataset_name):
        """Update when a dataset is selected; loads the data from disk."""
        if not task or not dataset_name:
            return gr.Dropdown(choices=[], value=None), "", pd.DataFrame(), None

        # Load the dataset from disk
        df = get_cached_model_data(task, dataset_name)
        if df.empty:
            return gr.Dropdown(choices=[], value=None), f"⚠️ No models found for dataset: {dataset_name}", df, None

        # Get the metric columns
        exclude_cols = ['model_name', 'release_date', 'paper_date', 'paper_url', 'paper_title', 'code_url']
        metric_cols = [col for col in df.columns if col not in exclude_cols]

        info = f"### 📊 **Dataset:** {dataset_name}\n"
        info += f"- **Models:** {len(df)} models\n"
        info += f"- **Metrics:** {len(metric_cols)} metrics available\n"
        info += f"- **Date Range:** {df['release_date'].min().strftime('%Y-%m-%d')} to {df['release_date'].max().strftime('%Y-%m-%d')}\n"
        if metric_cols:
            info += f"- **Available Metrics:** {', '.join(metric_cols[:5])}{'...' if len(metric_cols) > 5 else ''}"

        return (
            gr.Dropdown(choices=metric_cols, value=metric_cols[0] if metric_cols else None),
            info,
            df,
            None
        )

    def update_plot(df, metric):
        """Update the plot when a metric is selected."""
        if df.empty or not metric:
            return None
        return create_sota_plot(df, metric)

    def toggle_dataframe(df):
        """Toggle dataframe visibility."""
        if df.empty:
            return gr.Dataframe(value=pd.DataFrame(), visible=False)

        # Show the relevant columns
        display_cols = ['model_name', 'release_date'] + [col for col in df.columns
                                                         if col not in ['model_name', 'release_date', 'paper_date',
                                                                        'paper_url', 'paper_title', 'code_url']]
        display_df = df[display_cols].copy()
        display_df['release_date'] = display_df['release_date'].dt.strftime('%Y-%m-%d')
        return gr.Dataframe(value=display_df, visible=True)

    def export_data(df):
        """Export the current dataframe to CSV."""
        if df.empty:
            return "⚠️ No data to export"
        filename = f"sota_export_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv"
        df.to_csv(filename, index=False)
        return f"✓ Data exported to {filename} ({len(df)} models)"

    def clear_memory():
        """Free memory by forcing garbage collection."""
        import gc
        gc.collect()
        return "✓ Memory cleared"

    # Event handlers
    task_dropdown.change(
        fn=update_task_selection,
        inputs=task_dropdown,
        outputs=[category_dropdown, dataset_dropdown,
                 metric_dropdown, info_text, current_df, plot, current_task]
    )
    dataset_dropdown.change(
        fn=update_dataset_selection,
        inputs=[task_dropdown, dataset_dropdown],
        outputs=[metric_dropdown, info_text, current_df, plot]
    )
    metric_dropdown.change(
        fn=update_plot,
        inputs=[current_df, metric_dropdown],
        outputs=plot
    )
    show_data_btn.click(
        fn=toggle_dataframe,
        inputs=current_df,
        outputs=df_display
    )
    export_btn.click(
        fn=export_data,
        inputs=current_df,
        outputs=info_text
    )
    clear_memory_btn.click(
        fn=clear_memory,
        inputs=[],
        outputs=info_text
    )
| gr.Markdown(""" | |
| --- | |
| ### π How to Use | |
| 1. **Select a Task** from the first dropdown | |
| 2. **Select a Dataset** to analyze | |
| 3. **Select a Metric** to visualize | |
| 4. The plot shows SOTA model evolution over time with dynamically calculated dates | |
| ### πΎ Memory Optimization | |
| - Data is stored on disk and loaded on-demand | |
| - Only the current task and dataset are kept in memory | |
| - Use "Clear Memory" button if needed | |
| - Infinite disk space is utilized for permanent caching | |
| ### π¨ Plot Features | |
| - **π΅ Cyan dots**: SOTA models when released | |
| - **βͺ Gray dots**: Other models | |
| - **π Cyan line**: SOTA progression | |
| - **π Hover**: View model details | |
| - **π·οΈ Smart Labels**: SOTA model labels positioned close to the line with intelligent collision detection | |
| """) | |


def test_sota_label_positioning():
    """Test function to validate the SOTA label positioning improvements."""
    print("🧪 Testing SOTA label positioning...")

    # Test data with different metric types (including all required columns)
    test_data = {
        'model_name': ['Model A', 'Model B', 'Model C', 'Model D'],
        'release_date': [
            datetime(2020, 1, 1),
            datetime(2020, 6, 1),
            datetime(2021, 1, 1),
            datetime(2021, 6, 1)
        ],
        'paper_title': ['Paper A', 'Paper B', 'Paper C', 'Paper D'],
        'paper_url': ['http://example.com/a', 'http://example.com/b', 'http://example.com/c', 'http://example.com/d'],
        'code_url': ['http://github.com/a', 'http://github.com/b', 'http://github.com/c', 'http://github.com/d'],
        'accuracy': [0.85, 0.87, 0.90, 0.92],    # Higher-better metric
        'error_rate': [0.15, 0.13, 0.10, 0.08]   # Lower-better metric
    }
    df_test = pd.DataFrame(test_data)

    # Test with a higher-better metric (accuracy)
    print("  Testing with higher-better metric (accuracy)...")
    try:
        create_sota_plot(df_test, 'accuracy')
        print("  ✓ Higher-better metric test passed")
    except Exception as e:
        print(f"  ✗ Higher-better metric test failed: {e}")

    # Test with a lower-better metric (error_rate)
    print("  Testing with lower-better metric (error_rate)...")
    try:
        create_sota_plot(df_test, 'error_rate')
        print("  ✓ Lower-better metric test passed")
    except Exception as e:
        print(f"  ✗ Lower-better metric test failed: {e}")

    # Test with empty data
    print("  Testing with empty dataframe...")
    try:
        create_sota_plot(pd.DataFrame(), 'test_metric')
        print("  ✓ Empty data test passed")
    except Exception as e:
        print(f"  ✗ Empty data test failed: {e}")

    # Test with string metric data (should be handled gracefully)
    print("  Testing with string metric data...")
    try:
        df_test_string = df_test.copy()
        df_test_string['string_metric'] = ['low', 'medium', 'high', 'very_high']
        create_sota_plot(df_test_string, 'string_metric')
        print("  ✓ String metric test passed (handled gracefully)")
    except Exception as e:
        print(f"  ✗ String metric test failed: {e}")

    # Test with mixed numeric/string data
    print("  Testing with mixed data types...")
    try:
        df_test_mixed = df_test.copy()
        df_test_mixed['mixed_metric'] = [0.85, 'N/A', 0.90, 0.92]
        create_sota_plot(df_test_mixed, 'mixed_metric')
        print("  ✓ Mixed data test passed")
    except Exception as e:
        print(f"  ✗ Mixed data test failed: {e}")

    # Test paper_date parsing
    print("  Testing with paper_date column...")
    try:
        df_test_dates = df_test.copy()
        df_test_dates['paper_date'] = ['2015-03-15', '2018-invalid', '2021-12-01', '2022']
        create_sota_plot(df_test_dates, 'accuracy')
        print("  ✓ Paper date parsing test passed")
    except Exception as e:
        print(f"  ✗ Paper date parsing test failed: {e}")

    print("🎉 SOTA label positioning tests completed!")
    return True


demo.launch()