Spaces:
Running
Running
| import json | |
| import math | |
| from statistics import mean | |
| from datetime import datetime | |
| import pandas as pd | |
| import gradio as gr | |
| import plotly.graph_objects as go | |
| from gradio_leaderboard import Leaderboard | |
| from src.utils import( | |
| _safe_numeric, | |
| calculate_cumulative_average, | |
| create_dataframe, | |
| get_aggregated_columns, | |
| load_data, | |
| load_model_metadata, | |
| load_raw_model_data, | |
| build_year_column_mapping, | |
| ) | |
| from content import LLMLAGBENCH_INTRO, LEADERBOARD_INTRO, MODEL_COMPARISON_INTRO, AUTHORS, CIT_BTN_TEXT, CIT_BTN_LABEL, EXEMPLARY_QUESTIONS_INTRO, EXEMPLARY_QUESTIONS_DATA | |
### CONFIGURATION
# Paths to the bundled data files plus the year / month suffixes that
# build_year_column_mapping turns into per-month column names.
cfg = dict(
    data_path="data/leaderboard_graph_data.json",
    metadata_path="data/model_metadata.json",
    years=[str(year) for year in range(2021, 2026)],  # "2021" .. "2025"
    months=["%02d" % month for month in range(1, 13)],  # "01" .. "12"
)
### CALLBACKS
def update_dashboard(graph_years, graph_model_filter):
    """Rebuild the leaderboard table and the model-trend line plot.

    The leaderboard always covers every configured year and every model; only
    the plot is narrowed by the user's year and model selections.

    Args:
        graph_years: Year strings selected for the plot (e.g. ``["2024", "2025"]``)
            from ``graph_year_selector``. Empty/None means all configured years.
        graph_model_filter: Model names to draw on the plot, or None/empty to
            draw every model.

    Returns:
        ``(table_df, fig)``: the leaderboard DataFrame sorted by
        "Overall Average" (descending), and a Plotly figure with one spline per
        model plus open-circle markers at that model's trend changepoints.
    """
    # The table is intentionally unfiltered: all configured years, all models.
    table_years = cfg.get("years")
    # No year ticked for the graph -> plot every configured year.
    if not graph_years:
        graph_years = cfg.get("years")

    # Metadata columns, in display order.
    metadata_cols = ["Model", "Overall Average", "1st Detected cutoff", "2nd Detected cutoff", "Provider cutoff", "Provider", "Release date", "Self-declared cutoff", "trend_changepoints", "Parameters", "Evaluation period"]

    # ---- Leaderboard table -------------------------------------------------
    if len(table_years) > 1:
        # Several years: one aggregated column per year; no monthly columns and
        # no raw changepoint lists.
        year_cols = [y for y in cfg.get("aggregated_cols_year") if y in table_years]
        table_cols = [c for c in metadata_cols if c != "trend_changepoints"] + year_cols
    else:
        # Single year: its aggregate column followed by its months, in the
        # global chronological order, keeping only months present in the data.
        table_month_set = {c for y in table_years for c in year_to_columns.get(y, [])}
        month_cols = [c for c in cfg.get("aggregated_cols_month") if c in table_month_set and c in df.columns]
        table_cols = metadata_cols + list(table_years) + month_cols
    table_df = df[table_cols].sort_values(by="Overall Average", ascending=False)

    # ---- Trend plot --------------------------------------------------------
    # Monthly columns for the selected years, chronologically ordered and
    # restricted to those that exist in the data.
    graph_month_set = {c for y in graph_years for c in year_to_columns.get(y, [])}
    graph_month_cols = [c for c in cfg.get("aggregated_cols_month") if c in graph_month_set and c in df.columns]
    # Only the model name, its changepoints and the monthly scores are needed;
    # the yearly aggregates are never plotted, so they are not selected.
    graph_df = df[["Model", "trend_changepoints"] + graph_month_cols]
    if graph_model_filter:
        graph_df = graph_df[graph_df["Model"].isin(graph_model_filter)]

    # Long ("tidy") frame with one row per (model, month). Explicit columns keep
    # the frame well-formed (and lineplot_df["x"] valid) when nothing matches.
    lineplot_df = pd.DataFrame(
        [
            {"x": col, "y": _safe_numeric(row.get(col)), "Model": row["Model"]}
            for _, row in graph_df.iterrows()
            for col in graph_month_cols
        ],
        columns=["x", "y", "Model"],
    )
    present_months = set(lineplot_df["x"])
    chronological_order = [c for c in graph_month_cols if c in present_months]
    lineplot_df["x"] = pd.Categorical(lineplot_df["x"], categories=chronological_order, ordered=True)
    lineplot_df = lineplot_df.sort_values(by="x")

    fig = go.Figure()
    for _, row in graph_df.iterrows():
        model = row["Model"]
        color = GLOBAL_MODEL_COLORS[model]
        model_data = lineplot_df[lineplot_df["Model"] == model]
        fig.add_trace(go.Scatter(
            x=model_data["x"],
            y=model_data["y"],
            mode="lines",
            name=model,
            line=dict(width=2, color=color),
            hovertemplate="Model: %{text}<br>x=%{x}<br>y=%{y}",
            text=[model] * len(model_data),
            showlegend=True,
            line_shape='spline'
        ))

        # Highlight changepoints (there can be several). Only changepoints that
        # fall on a plotted month are drawn.
        changepoints = row.get("trend_changepoints", [])
        if isinstance(changepoints, list):
            for idx, bp in enumerate(changepoints):
                if bp in model_data["x"].values:
                    cp_row = model_data[model_data["x"] == bp]
                    # The first changepoint gets a larger marker than any later ones.
                    marker_size = 12 if idx == 0 else 6
                    fig.add_trace(go.Scatter(
                        x=cp_row["x"],
                        y=cp_row["y"],
                        mode="markers",
                        marker=dict(
                            size=marker_size,
                            color=color,
                            symbol="circle-open",
                            line=dict(width=3, color="white")
                        ),
                        hovertemplate=f"<b>Trend Changepoint</b><br>Model: {model}<br>x=%{{x}}<br>y=%{{y}}",
                        showlegend=False
                    ))

    # Dark styling; lock the x axis to chronological month order.
    fig.update_layout(
        xaxis=dict(
            categoryorder="array",
            categoryarray=chronological_order,
            title="Year_Month",
            color="#e5e7eb",
            gridcolor="#374151",
            nticks=30  # Limit number of x-axis ticks displayed
        ),
        yaxis=dict(title="Average Faithfulness (0-2 scale)", color="#e5e7eb", gridcolor="#374151"),
        paper_bgcolor="#1f2937",
        plot_bgcolor="#1f2937",
        font=dict(family="IBM Plex Sans", size=12, color="#e5e7eb"),
        hoverlabel=dict(bgcolor="#374151", font=dict(color="#e5e7eb"), bordercolor="#4b5563"),
        margin=dict(l=40, r=20, t=60, b=40),
        showlegend=True,
        yaxis_range=[-0.1, 2.1],
        xaxis_tickangle=-45
    )

    return table_df, fig
def create_faithfulness_plot(model_name):
    """Build the per-model faithfulness chart used in the comparison section.

    On a shared date x-axis the chart overlays:
      * individual faithfulness scores (markers, left axis, range -0.05..2.05),
      * their cumulative average from ``calculate_cumulative_average`` (green line),
      * one red horizontal line per segment at that segment's mean faithfulness,
      * dashed vertical lines at the changepoint dates,
      * cumulative refusals (orange line, right axis),
      * a refusal-rate label per segment, centred between its start and end dates.

    Args:
        model_name: Name of the model to plot.

    Returns:
        A Plotly ``go.Figure``. An empty figure is returned when ``model_name``
        is falsy or ``load_raw_model_data`` yields nothing; None is never
        returned.
    """
    if not model_name:
        return go.Figure()

    model_data = load_raw_model_data(cfg.get("data_path"), model_name)
    if not model_data:
        return go.Figure()

    dates = model_data.get('dates', [])
    faithfulness = model_data.get('faithfulness', [])
    cumulative_refusals = model_data.get('cumulative_refusals', [])
    segments = model_data.get('segments', [])
    changepoint_dates = model_data.get('changepoint_dates', [])
    total_obs = model_data.get('total_observations', max(cumulative_refusals) if cumulative_refusals else 1)
    # Guard the right-axis upper bound: with no refusals (max == 0) or a
    # missing/zero total, the range would collapse to [0, 0].
    if not total_obs or total_obs <= 0:
        total_obs = 1

    cumulative_avg_faithfulness = calculate_cumulative_average(faithfulness) if faithfulness else []

    fig = go.Figure()

    # Individual faithfulness scores (left axis).
    fig.add_trace(go.Scatter(
        x=dates,
        y=faithfulness,
        mode='markers',
        name='Faithfulness',
        marker=dict(size=4, color='steelblue', opacity=0.6),
        hovertemplate='Date: %{x}<br>Faithfulness: %{y}<extra></extra>',
        yaxis='y'
    ))

    # Cumulative average faithfulness (green curve).
    if cumulative_avg_faithfulness:
        fig.add_trace(go.Scatter(
            x=dates,
            y=cumulative_avg_faithfulness,
            mode='lines',
            name='Cumulative Average',
            line=dict(color='#22c55e', width=2.5),
            hovertemplate='Date: %{x}<br>Cumulative Avg: %{y:.3f}<extra></extra>',
            yaxis='y'
        ))

    # One horizontal line per segment at its mean faithfulness.
    for seg in segments:
        fig.add_trace(go.Scatter(
            x=[seg['start_date'], seg['end_date']],
            y=[seg['mean_faithfulness'], seg['mean_faithfulness']],
            mode='lines',
            name=f"Segment Mean ({seg['mean_faithfulness']:.2f})",
            line=dict(color='red', width=2),
            hovertemplate=f"Mean: {seg['mean_faithfulness']:.2f}<br>Refusal Rate: {seg['refusal_rate_percent']:.1f}%<extra></extra>",
            yaxis='y',
            showlegend=False
        ))

    # Changepoints as dashed vertical lines.
    for cp_date in changepoint_dates:
        fig.add_vline(
            x=cp_date,
            line=dict(color='darkred', dash='dash', width=1.5),
            opacity=0.7
        )

    # Cumulative refusals on the secondary (right) y-axis.
    fig.add_trace(go.Scatter(
        x=dates,
        y=cumulative_refusals,
        mode='lines',
        name='Cumulative Refusals',
        line=dict(color='darkorange', width=2),
        hovertemplate='Date: %{x}<br>Cumulative Refusals: %{y}<extra></extra>',
        yaxis='y2'
    ))

    # Refusal-rate label for each segment, placed at the segment's midpoint date
    # near the top of the faithfulness axis. Segment dates are parsed as YYYY-MM-DD.
    for seg in segments:
        start = datetime.strptime(seg['start_date'], '%Y-%m-%d')
        end = datetime.strptime(seg['end_date'], '%Y-%m-%d')
        mid_date = start + (end - start) / 2
        fig.add_annotation(
            x=mid_date.strftime('%Y-%m-%d'),
            y=1.85,
            text=f"{seg['refusal_rate_percent']:.1f}%",
            showarrow=False,
            font=dict(size=10, color='#fbbf24', family='IBM Plex Sans'),
            bgcolor='rgba(55, 65, 81, 0.9)',
            bordercolor='#fbbf24',
            borderwidth=1,
            yref='y'
        )

    # Dark styling with dual y-axes (y2 overlays y on the right).
    fig.update_layout(
        title=dict(
            text=f"{model_name}: Faithfulness with PELT Changepoints",
            x=0.5,
            font=dict(color='#e5e7eb', size=14, family='IBM Plex Sans')
        ),
        xaxis=dict(
            title='Date',
            color='#e5e7eb',
            gridcolor='#374151',
            tickangle=-45
        ),
        yaxis=dict(
            title='Faithfulness Score',
            color='#e5e7eb',
            gridcolor='#374151',
            range=[-0.05, 2.05],
            side='left'
        ),
        yaxis2=dict(
            title='Cumulative Refusals',
            color='#fbbf24',
            gridcolor='#374151',
            range=[0, total_obs],
            overlaying='y',
            side='right'
        ),
        paper_bgcolor='#1f2937',
        plot_bgcolor='#1f2937',
        font=dict(family='IBM Plex Sans', size=12, color='#e5e7eb'),
        hoverlabel=dict(bgcolor='#374151', font=dict(color='#e5e7eb'), bordercolor='#4b5563'),
        margin=dict(l=60, r=60, t=80, b=80),
        # Legend is hidden; its styling only applies if showlegend is turned on.
        showlegend=False,
        legend=dict(
            x=0.02,
            y=0.98,
            bgcolor='rgba(55, 65, 81, 0.9)',
            bordercolor='#4b5563',
            borderwidth=1
        )
    )

    return fig
def update_model_comparison(split_enabled, model1, model2):
    """Refresh the faithfulness comparison plots after a control changes.

    Args:
        split_enabled: When truthy, show two plots side by side.
        model1: Model for the first plot, which is always visible.
        model2: Model for the second plot; ignored unless ``split_enabled``.

    Returns:
        Five values matching the wired outputs:
        (plot_1, plot_2, col_plot_1 update, col_model_2 update, col_plot_2 update).
        Without split mode, plot_2 is an empty figure and its column and
        dropdown are hidden.
    """
    show_second = bool(split_enabled)
    first_plot = create_faithfulness_plot(model1) if model1 else go.Figure()
    if show_second:
        second_plot = create_faithfulness_plot(model2) if model2 else go.Figure()
    else:
        second_plot = go.Figure()
    return (
        first_plot,
        second_plot,
        gr.update(visible=True),
        gr.update(visible=show_second),
        gr.update(visible=show_second),
    )
def initialize_model_comparison():
    """Pick random models for the comparison section, in the single-plot layout.

    When at least two models exist, two distinct ones are sampled. With exactly
    one, it becomes Model 1 and Model 2 stays empty; with none, both stay empty.
    Only the first plot is rendered; the second dropdown column and second plot
    are hidden.

    NOTE(review): not referenced by ``demo.load`` in this module, which uses
    ``initialize_all_components`` instead.

    Returns:
        7-tuple: (model_dropdown_1 update, model_dropdown_2 update,
        comparison_plot_1, comparison_plot_2 (empty figure),
        col_plot_1, col_model_2, col_plot_2 visibility updates).
    """
    import random

    if len(all_models) >= 2:
        first, second = random.sample(all_models, 2)
    else:
        first = all_models[0] if all_models else None
        second = None

    first_plot = go.Figure() if not first else create_faithfulness_plot(first)

    return (
        gr.update(value=first),
        gr.update(value=second),
        first_plot,
        go.Figure(),
        gr.update(visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
    )
def initialize_main_dashboard(graph_year_selector_value):
    """Seed the trend plot with up to five random models and build the dashboard.

    Args:
        graph_year_selector_value: Years currently ticked in the graph year selector.

    Returns:
        (graph_model_filter update, leaderboard value, line plot figure).
    """
    import random

    # sample(..., 0) on an empty list yields [], so no special case is needed.
    picked_models = random.sample(all_models, min(5, len(all_models)))
    table, figure = update_dashboard(graph_year_selector_value, picked_models)
    return gr.update(value=picked_models), table, figure
def initialize_all_components(graph_year_selector_value):
    """Populate the main dashboard and the model comparison on page load.

    Both sections are initialized from one load handler, to avoid
    double-rendering issues in HF Spaces.

    The trend plot is seeded with up to five random models. The comparison
    section gets two distinct random models shown side by side (split view). With
    only one model, the second slot stays empty; with none, both do.

    Args:
        graph_year_selector_value: Years currently ticked in the graph year selector.

    Returns:
        10-tuple matching the ``demo.load`` outputs: graph_model_filter update,
        leaderboard value, line plot, two dropdown updates, two comparison
        figures, and visibility updates for col_plot_1, col_model_2, col_plot_2.
    """
    import random

    # --- main dashboard ---
    graph_models = random.sample(all_models, min(5, len(all_models)))
    table, trend_fig = update_dashboard(graph_year_selector_value, graph_models)

    # --- model comparison (split view) ---
    if len(all_models) >= 2:
        left_model, right_model = random.sample(all_models, 2)
    else:
        left_model = all_models[0] if all_models else None
        right_model = None

    left_fig = create_faithfulness_plot(left_model) if left_model else go.Figure()
    right_fig = create_faithfulness_plot(right_model) if right_model else go.Figure()

    return (
        gr.update(value=graph_models),
        table,
        trend_fig,
        gr.update(value=left_model),
        gr.update(value=right_model),
        left_fig,
        right_fig,
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
    )
### HELPER FUNCTIONS
def generate_distinct_colors(n):
    """Return ``n`` hex colour strings spread evenly around the hue wheel.

    Hues are spaced ``360 / n`` degrees apart. Saturation cycles through
    70/80/90 % and lightness alternates between 55/65 %, so neighbouring entries
    are easier to tell apart. Each HSL triple is converted to RGB and formatted
    as ``#rrggbb``, with channels truncated (not rounded) to 0-255.

    Args:
        n: Number of colours to produce.

    Returns:
        List of ``n`` strings such as ``"#dc3b3b"``.
    """
    def channel(low, high, t):
        # Piecewise-linear HSL -> RGB mapping for one colour channel.
        if t < 0:
            t += 1
        if t > 1:
            t -= 1
        if t < 1 / 6:
            return low + (high - low) * 6 * t
        if t < 1 / 2:
            return high
        if t < 2 / 3:
            return low + (high - low) * (2 / 3 - t) * 6
        return low

    palette = []
    for idx in range(n):
        hue = ((idx * 360 / n) % 360) / 360
        sat = (70 + (idx % 3) * 10) / 100
        light = (55 + (idx % 2) * 10) / 100

        if sat == 0:
            red = green = blue = light
        else:
            high = light + sat - light * sat if light >= 0.5 else light * (1 + sat)
            low = 2 * light - high
            red = channel(low, high, hue + 1 / 3)
            green = channel(low, high, hue)
            blue = channel(low, high, hue - 1 / 3)

        palette.append("#" + "".join(f"{int(c * 255):02x}" for c in (red, green, blue)))
    return palette
### DATA PREP
# Raw leaderboard results and per-model metadata from the configured paths.
data = load_data(cfg.get("data_path"))
model_metadata = load_model_metadata(cfg.get("metadata_path"))

# Year -> its monthly column names (looked up per year by the callbacks),
# then the wide per-model table built from the loaded data and metadata.
year_to_columns = build_year_column_mapping(cfg.get("years"), cfg.get("months"))
df = create_dataframe(cfg, data, model_metadata=model_metadata)

# Ordered yearly / monthly column lists, stored on cfg so callbacks can read them.
aggregated_cols_year, aggregated_cols_month = get_aggregated_columns(cfg.get("years"), year_to_columns)
cfg.update(aggregated_cols_year=aggregated_cols_year, aggregated_cols_month=aggregated_cols_month)

# One colour per model, computed once; sorting keeps the assignment stable across runs.
all_models = sorted(df["Model"].unique().tolist())
colors = generate_distinct_colors(len(all_models))
GLOBAL_MODEL_COLORS = dict(zip(all_models, colors))
### BUILD UI
# Base Gradio theme: green primary/secondary hues, large corner radius, small text.
theme = gr.themes.Base(
    primary_hue="green",
    secondary_hue="green",
    radius_size="lg",
    text_size="sm",
)
# Custom CSS for scrollable dropdown and table styling.
# NOTE: everything inside the triple-quoted string is CSS sent to the browser
# (runtime content), not Python comments.
custom_css = """
/* Limit the height of selected items in multiselect dropdown */
.scrollable-dropdown .wrap-inner {
max-height: 100px !important;
overflow-y: auto !important;
}
/* Alternative selector for the selected items container */
.scrollable-dropdown div[data-testid="block-label-inner"] ~ div {
max-height: 100px !important;
overflow-y: auto !important;
}
/* Style the leaderboard table background */
.gradio-container .gr-table-wrap,
.gradio-container .gr-dataframe,
.gradio-leaderboard {
background-color: #fafafa !important;
}
.gradio-container table {
background-color: #fafafa !important;
}
/* Header row - gray background */
.gradio-container table thead tr,
.gradio-container table thead th {
background-color: #f3f4f6 !important;
}
/* First column (td:first-child) - gray background for all rows */
.gradio-container table tbody tr td:first-child {
background-color: #f3f4f6 !important;
}
/* Odd rows - very light background (excluding first column) */
.gradio-container table tbody tr:nth-child(odd) td {
background-color: #fafafa !important;
}
/* Even rows - white background (excluding first column) */
.gradio-container table tbody tr:nth-child(even) td {
background-color: white !important;
}
/* Keep first column gray for both odd and even rows */
.gradio-container table tbody tr:nth-child(odd) td:first-child,
.gradio-container table tbody tr:nth-child(even) td:first-child {
background-color: #f3f4f6 !important;
}
/* Hover effect for all rows */
.gradio-container table tbody tr:hover td {
background-color: #f3f4f6 !important;
}
/* Keep first column darker gray on hover */
.gradio-container table tbody tr:hover td:first-child {
background-color: #e5e7eb !important;
}
"""
# JavaScript to force light mode. It is passed to gr.Blocks(js=...). If the page
# URL does not already carry ?__theme=light, it sets that query parameter and
# reloads the page. The string content is runtime JavaScript.
js_func = """
function refresh() {
const url = new URL(window.location);
if (url.searchParams.get('__theme') !== 'light') {
url.searchParams.set('__theme', 'light');
window.location.href = url.href;
}
}
"""
with gr.Blocks(theme=theme, css=custom_css, js=js_func) as demo:
    # ---- Header ------------------------------------------------------------
    gr.Markdown(
        """
<div style='text-align: center;'>
<h1>📶 LLMLagBench - All LLMs lag behind</h1>
</div>
"""
    )
    gr.Markdown(AUTHORS)
    gr.Markdown("<hr>")
    gr.Markdown(LLMLAGBENCH_INTRO)
    gr.Markdown("<hr>")

    # ---- Trend plot controls ----------------------------------------------
    with gr.Row():
        with gr.Column(scale=1):
            # Year selector for the graph; every configured year is ticked by default.
            graph_year_selector = gr.CheckboxGroup(choices=cfg.get("years"), value=list(cfg.get("years")), label="Select Years for Graph")
        with gr.Column(scale=1):
            graph_model_filter = gr.Dropdown(
                # Sorted model list, the same one the comparison dropdowns use.
                choices=all_models,
                multiselect=True,
                filterable=True,
                value=None,  # Set to random models by demo.load
                label="Select Models for Graph",
                elem_classes="scrollable-dropdown"
            )
    gr.Markdown("## Model Comparison with Trend Changepoints")
    line_plot = gr.Plot(label="Model Trends")
    gr.Markdown('<hr>')
    gr.Markdown('<br>')

    # ---- Leaderboard --------------------------------------------------------
    gr.Markdown(LEADERBOARD_INTRO)
    leaderboard = Leaderboard(
        value=df,
        search_columns=["Model"],
        interactive=False,
    )

    # User edits to the graph controls refresh both the leaderboard and the plot.
    # `.input` (rather than `.change`) fires only for user interaction, so the
    # value demo.load sets programmatically does not re-run the callback and
    # render the dashboard a second time.
    for comp in (graph_year_selector, graph_model_filter):
        comp.input(
            fn=update_dashboard,
            inputs=[graph_year_selector, graph_model_filter],
            outputs=[leaderboard, line_plot],
        )
    gr.Markdown('<hr>')
    gr.Markdown('<br>')

    # ---- Model comparison (wrapped in one container to prevent duplication) ---
    with gr.Column():
        gr.Markdown("## Model Comparison: Faithfulness to Ideal Answer with PELT Changepoints")
        gr.Markdown(MODEL_COMPARISON_INTRO)
        with gr.Row():
            split_checkbox = gr.Checkbox(
                label="Split into 2 segments",
                value=True,
                info="Enable to compare two models side by side"
            )
        with gr.Row():
            with gr.Column(scale=1):
                model_dropdown_1 = gr.Dropdown(
                    choices=all_models,
                    value=None,  # Set randomly by demo.load
                    label="Select Model 1",
                    filterable=True,
                    elem_classes="scrollable-dropdown"
                )
            # Hidden initially; demo.load / update_model_comparison toggle it.
            with gr.Column(scale=1, visible=False) as col_model_2:
                model_dropdown_2 = gr.Dropdown(
                    choices=all_models,
                    value=None,  # Set randomly by demo.load
                    label="Select Model 2",
                    filterable=True,
                    elem_classes="scrollable-dropdown"
                )
        with gr.Row():
            with gr.Column(scale=1) as col_plot_1:
                comparison_plot_1 = gr.Plot(label="Model Faithfulness Analysis")
            with gr.Column(scale=1, visible=False) as col_plot_2:
                comparison_plot_2 = gr.Plot(label="Model Faithfulness Analysis")

        # The three comparison controls share one handler. User-only `.input`
        # again avoids re-running it when demo.load sets the dropdown values.
        comparison_inputs = [split_checkbox, model_dropdown_1, model_dropdown_2]
        comparison_outputs = [comparison_plot_1, comparison_plot_2, col_plot_1, col_model_2, col_plot_2]
        for comp in comparison_inputs:
            comp.input(
                fn=update_model_comparison,
                inputs=comparison_inputs,
                outputs=comparison_outputs
            )
    gr.Markdown('<hr>')
    gr.Markdown('<br>')

    # ---- Exemplary questions ----------------------------------------------
    gr.Markdown(EXEMPLARY_QUESTIONS_INTRO)
    exemplary_questions_df = pd.DataFrame(
        EXEMPLARY_QUESTIONS_DATA,
        columns=["Date", "Question", "Gold Answer", "Possible decision"]
    )
    gr.Dataframe(
        value=exemplary_questions_df,
        interactive=False,
        wrap=True
    )

    # ---- Citation ------------------------------------------------------------
    gr.Markdown('<hr>')
    gr.Markdown('<br>')
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            citation_button = gr.Textbox(
                value=CIT_BTN_TEXT,
                label=CIT_BTN_LABEL,
                lines=20,
                elem_id="citation-button",
                show_copy_button=True,
            )

    # Initialize every section from a single load call (prevents double-rendering in HF Spaces).
    demo.load(
        fn=initialize_all_components,
        inputs=[graph_year_selector],
        outputs=[graph_model_filter, leaderboard, line_plot, model_dropdown_1, model_dropdown_2, comparison_plot_1, comparison_plot_2, col_plot_1, col_model_2, col_plot_2]
    )

if __name__ == "__main__":
    demo.launch()