File size: 34,453 Bytes
a8592c4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
import hashlib
import json
import pickle
from datetime import datetime
from pathlib import Path

import gradio as gr
import pandas as pd
import plotly.graph_objects as go
from datasets import load_dataset
from tqdm import tqdm

# Cache configuration.
# Every cache artefact lives under one directory relative to the current
# working directory. Names assigned at module level are already globals, so no
# `global` declarations are needed (they are a no-op outside a function).
CACHE_DIR = Path("./pwc_cache")
CACHE_DIR.mkdir(exist_ok=True)

# Directory structure for disk-based storage
TASKS_INDEX_FILE = CACHE_DIR / "tasks_index.json"  # Small JSON file with the sorted task list
TASK_DATA_DIR = CACHE_DIR / "task_data"  # One pickle per task (categories, datasets, date range)
DATASET_DATA_DIR = CACHE_DIR / "dataset_data"  # One pickled DataFrame per (task, dataset) pair
METRICS_INDEX_FILE = CACHE_DIR / "metrics_index.json"  # Per-metric count and lower-is-better flag

# Create the per-item data directories up front so writers never need to check.
TASK_DATA_DIR.mkdir(exist_ok=True)
DATASET_DATA_DIR.mkdir(exist_ok=True)


def sanitize_filename(name):
    """Return a filesystem-safe version of *name*.

    Path separators, Windows-reserved punctuation, spaces and dots all become
    underscores. Runs of underscores are collapsed, and leading/trailing
    underscores are dropped. A result longer than 200 characters is cut to 150
    characters and suffixed with ``_`` plus the first 8 hex digits of the MD5
    of the *original* name, so distinct long names are unlikely to collide.
    """
    unsafe_chars = '/\\:*?"<>| .'
    to_underscore = str.maketrans(unsafe_chars, '_' * len(unsafe_chars))
    pieces = [piece for piece in name.translate(to_underscore).split('_') if piece]
    cleaned = '_'.join(pieces)

    if len(cleaned) <= 200:
        return cleaned

    digest = hashlib.md5(name.encode()).hexdigest()
    return f"{cleaned[:150]}_{digest[:8]}"


def get_task_filename(task):
    """Return the pickle path under TASK_DATA_DIR holding *task*'s metadata."""
    return TASK_DATA_DIR / ("task_" + sanitize_filename(task) + ".pkl")


def get_dataset_filename(task, dataset_name):
    """Return the pickle path under DATASET_DATA_DIR for one (task, dataset) pair.

    The file name is ``data_<task>_<dataset>.pkl`` built from the sanitized
    parts. If that exceeds 255 characters, each part is truncated to 50
    characters and an 8-hex-digit MD5 of ``"<task>||<dataset>"`` is appended
    to keep distinct pairs apart.

    NOTE(review): the limit counts characters, not bytes, and the ``_``
    separator is ambiguous (task "a b" + dataset "c" and task "a" + dataset
    "b c" give the same name). Changing either would rename existing cache
    files, so both are left as they are.
    """
    task_part = sanitize_filename(task)
    dataset_part = sanitize_filename(dataset_name)
    candidate = f"data_{task_part}_{dataset_part}.pkl"
    if len(candidate) <= 255:
        return DATASET_DATA_DIR / candidate

    digest = hashlib.md5(f'{task}||{dataset_name}'.encode()).hexdigest()[:8]
    return DATASET_DATA_DIR / f"data_{task_part[:50]}_{dataset_part[:50]}_{digest}.pkl"


def cache_exists():
    """Return True when both index files of a finished cache build are on disk.

    ``build_disk_based_cache`` writes the tasks index and the metrics index
    only after its per-task / per-dataset loop has finished. Their presence
    therefore marks a completed build. This check is silent: the previous
    debug prints of the paths are gone.
    """
    return TASKS_INDEX_FILE.exists() and METRICS_INDEX_FILE.exists()


# Substrings that mark a metric as "lower is better" when it is first seen.
_LOWER_IS_BETTER_KEYWORDS = ('error', 'loss', 'time', 'cost')

# Release date used when a row's paper_date is missing or unparsable.
_DEFAULT_RELEASE_DATE = '2020-01-01'


def _empty_task_record():
    """Return a fresh in-memory task record with mutable set fields."""
    return {
        'categories': set(),
        'datasets': set(),
        'date_range': {'min': None, 'max': None},
    }


def _load_task_record(task_file):
    """Load a task record from *task_file*, or return an empty one if it is missing.

    Records are saved with ``categories``/``datasets`` as sorted lists (see
    ``_save_task_record``). They are turned back into sets here so callers can
    use ``.update``/``.add``. Without this conversion, re-loading a task file
    raised ``AttributeError``.
    """
    if not task_file.exists():
        return _empty_task_record()
    # NOTE: only unpickles files this module wrote into its own cache directory.
    with open(task_file, 'rb') as f:
        saved = pickle.load(f)
    return {
        'categories': set(saved.get('categories') or ()),
        'datasets': set(saved.get('datasets') or ()),
        'date_range': {'min': None, 'max': None, **(saved.get('date_range') or {})},
    }


def _save_task_record(task_file, task_data):
    """Pickle *task_data* to *task_file*, storing set fields as sorted lists."""
    with open(task_file, 'wb') as f:
        pickle.dump({
            'categories': sorted(task_data['categories']),
            'datasets': sorted(task_data['datasets']),
            'date_range': task_data['date_range'],
        }, f, protocol=pickle.HIGHEST_PROTOCOL)


def _parse_release_date(paper_date):
    """Parse a raw ``paper_date`` string; fall back to 2020-01-01 if absent, invalid or NaT."""
    if isinstance(paper_date, str) and paper_date:
        try:
            parsed = pd.to_datetime(paper_date)
        except (ValueError, TypeError, OverflowError):
            parsed = None
        if parsed is not None and pd.notna(parsed):
            return parsed
    return pd.to_datetime(_DEFAULT_RELEASE_DATE)


def _build_models_frame(rows, metrics_index, date_range):
    """Turn SOTA table ``rows`` into a DataFrame sorted by release date.

    Non-dict rows are skipped, and ``None`` metric values are dropped.
    Side effects: increments per-metric counts in ``metrics_index``, adding an
    entry on first sight. It also widens ``date_range['min']``/``['max']`` to
    cover every parsed release date.

    Returns:
        The DataFrame, or None when no usable row was found.
    """
    models_data = []
    for row in rows:
        if not isinstance(row, dict):
            continue

        metrics = {}
        raw_metrics = row.get('metrics')
        if isinstance(raw_metrics, dict):
            for metric_name, metric_value in raw_metrics.items():
                if metric_value is None:
                    continue
                metrics[metric_name] = metric_value
                if metric_name not in metrics_index:
                    metrics_index[metric_name] = {
                        'count': 0,
                        'is_lower_better': any(kw in metric_name.lower()
                                               for kw in _LOWER_IS_BETTER_KEYWORDS),
                    }
                metrics_index[metric_name]['count'] += 1

        release_date = _parse_release_date(row.get('paper_date'))
        if date_range['min'] is None or release_date < date_range['min']:
            date_range['min'] = release_date
        if date_range['max'] is None or release_date > date_range['max']:
            date_range['max'] = release_date

        code_links = row.get('code_links')
        models_data.append({
            # `or` also covers an explicit None, which the plotting code cannot label.
            'model_name': row.get('model_name') or 'Unknown Model',
            'release_date': release_date,
            'paper_date': row.get('paper_date', ''),  # Raw value, re-parsed at plot time
            'paper_url': row.get('paper_url', ''),
            'paper_title': row.get('paper_title', ''),
            'code_url': code_links[0] if code_links else '',
            **metrics,
        })

    if not models_data:
        return None
    return pd.DataFrame(models_data).sort_values('release_date')


def build_disk_based_cache():
    """Build the on-disk cache from the ``pwc-archive/evaluation-tables`` train split.

    For every item with a task, this updates ``task_<task>.pkl`` (categories,
    dataset names, release-date range). For every (task, dataset) pair with
    SOTA rows, it writes one pickled DataFrame. The tasks index and metrics
    index JSON files are written last.

    An existing dataset pickle is never overwritten. That lets an interrupted
    build resume, but it also means a (task, dataset) pair that occurs again
    later keeps only the rows from its first occurrence.

    Returns:
        Sorted list of task names.
    """
    print("=" * 60)
    print("Building disk-based cache (one-time operation)...")
    print("=" * 60)

    tasks_set = set()
    metrics_index = {}
    dataset_count = 0

    print("\n[1/4] Loading dataset and building cache...")

    # The split is fully loaded (streaming=False) so len() works for the
    # progress bar; per-task and per-dataset results go straight to disk.
    ds = load_dataset("pwc-archive/evaluation-tables", split="train", streaming=False)

    for item in tqdm(ds, total=len(ds)):
        task = item['task']
        if not task:
            continue
        tasks_set.add(task)

        task_file = get_task_filename(task)
        task_data = _load_task_record(task_file)

        if item['categories']:
            task_data['categories'].update(item['categories'])

        for dataset in item['datasets'] or ():
            if not isinstance(dataset, dict) or 'dataset' not in dataset:
                continue

            dataset_name = dataset['dataset']
            task_data['datasets'].add(dataset_name)

            dataset_file = get_dataset_filename(task, dataset_name)
            if dataset_file.exists():
                # Already written (earlier item or an interrupted previous build).
                continue

            # `sota` and its `rows` may be absent or explicitly None.
            sota = dataset.get('sota') or {}
            rows = sota.get('rows') if isinstance(sota, dict) else None
            if not rows:
                continue

            df = _build_models_frame(rows, metrics_index, task_data['date_range'])
            if df is None:
                continue

            with open(dataset_file, 'wb') as f:
                pickle.dump(df, f, protocol=pickle.HIGHEST_PROTOCOL)
            dataset_count += 1

        _save_task_record(task_file, task_data)

    print(f"\n✓ Processed {len(tasks_set)} tasks and {dataset_count} datasets")

    print("\n[2/4] Saving index files...")

    tasks_list = sorted(tasks_set)
    with open(TASKS_INDEX_FILE, 'w') as f:
        json.dump(tasks_list, f)
    print(f"  ✓ Saved tasks index ({len(tasks_list)} tasks)")

    with open(METRICS_INDEX_FILE, 'w') as f:
        json.dump(metrics_index, f, indent=2)
    print(f"  ✓ Saved metrics index ({len(metrics_index)} metrics)")

    print("\n[3/4] Calculating cache statistics...")

    task_files = list(TASK_DATA_DIR.glob("*.pkl"))
    dataset_files = list(DATASET_DATA_DIR.glob("*.pkl"))
    total_size = sum(path.stat().st_size for path in task_files + dataset_files)

    print(f"  ✓ Total cache size: {total_size / 1024 / 1024:.1f} MB")
    print(f"  ✓ Task files: {len(task_files)}")
    print(f"  ✓ Dataset files: {len(dataset_files)}")

    print("\n[4/4] Cache building complete!")
    print("=" * 60)

    return tasks_list


def load_tasks_index():
    """Read the JSON task list stored at TASKS_INDEX_FILE and return it."""
    return json.loads(TASKS_INDEX_FILE.read_text())


def load_task_data(task):
    """Return the unpickled metadata record for *task*, or None if no file exists for it."""
    path = get_task_filename(task)
    if not path.exists():
        return None
    with path.open('rb') as handle:
        return pickle.load(handle)


def load_dataset_data(task, dataset_name):
    """Return the pickled DataFrame for (task, dataset_name), or an empty DataFrame if absent."""
    path = get_dataset_filename(task, dataset_name)
    if not path.exists():
        return pd.DataFrame()
    with path.open('rb') as handle:
        return pickle.load(handle)


def load_metrics_index():
    """Return the metrics metadata dict from METRICS_INDEX_FILE, or {} when the file is missing."""
    if not METRICS_INDEX_FILE.exists():
        return {}
    with METRICS_INDEX_FILE.open('r') as handle:
        return json.load(handle)

# Initialise at import time: build the on-disk cache the first time, otherwise
# just read the task list back from its index file.
if not cache_exists():
    TASKS = build_disk_based_cache()
else:
    print("Loading task index from disk...")
    TASKS = load_tasks_index()
    print(f"βœ“ Loaded {len(TASKS)} tasks")

# The metrics index is small, so it is read once and kept in memory.
METRICS_INDEX = load_metrics_index()


# Memory-efficient accessor functions: thin wrappers that read from disk on demand.
def get_tasks():
    """Return the task list that was loaded (or built) when the module was imported."""
    tasks = TASKS
    return tasks


def get_task_data(task):
    """Read *task*'s metadata record from disk (None when it has not been cached)."""
    record = load_task_data(task)
    return record


def get_categories(task):
    """Return the categories stored for *task*, or [] when it has no (or an empty) record."""
    task_data = get_task_data(task)
    if not task_data:
        return []
    return task_data['categories']


def get_datasets_for_task(task):
    """Return the dataset names stored for *task*, or [] when it has no (or an empty) record."""
    task_data = get_task_data(task)
    if not task_data:
        return []
    return task_data['datasets']


def get_cached_model_data(task, dataset_name):
    """Return the per-model DataFrame for (task, dataset_name), read from the disk cache on each call."""
    frame = load_dataset_data(task, dataset_name)
    return frame


def parse_paper_date(paper_date, paper_title="", paper_url=""):
    """Parse paper date with improved fallback strategies."""
    import re

    # Try to parse the raw paper_date if available
    if paper_date and isinstance(paper_date, str) and paper_date.strip():
        try:
            # Try common date formats
            date_formats = [
                '%Y-%m-%d',
                '%Y/%m/%d',
                '%d-%m-%Y',
                '%d/%m/%Y',
                '%Y-%m',
                '%Y/%m',
                '%Y'
            ]

            for fmt in date_formats:
                try:
                    return pd.to_datetime(paper_date.strip(), format=fmt)
                except:
                    continue

            # Try pandas automatic parsing
            return pd.to_datetime(paper_date.strip())
        except:
            pass

    # Fallback: try to extract year from paper title or URL
    year_pattern = r'\b(19[5-9]\d|20[0-9]\d)\b'  # Match 1950-2099

    # Look for year in paper title
    if paper_title:
        years = re.findall(year_pattern, str(paper_title))
        if years:
            try:
                year = max(years)  # Use the latest year found
                return pd.to_datetime(f'{year}-01-01')
            except:
                pass

    # Look for year in paper URL
    if paper_url:
        years = re.findall(year_pattern, str(paper_url))
        if years:
            try:
                year = max(years)  # Use the latest year found
                return pd.to_datetime(f'{year}-01-01')
            except:
                pass

    # Final fallback: return None instead of a default year
    return None


def get_task_statistics(task):
    """Placeholder: no statistics are computed yet, so a new empty dict is returned for any *task*."""
    return dict()


# Length of the arrow between a SOTA point and its label, in screen pixels.
# Plotly reads an annotation's ``ay`` as a pixel offset while ``ayref`` keeps its
# default, so the offset must not be derived from the metric's data range.
_SOTA_LABEL_OFFSET_PX = 30


def _message_figure(text, title, font_size=20):
    """Return a blank 600px white figure that only shows *text* in its centre."""
    fig = go.Figure()
    fig.add_annotation(
        text=text,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=font_size)
    )
    fig.update_layout(
        title=title,
        height=600,
        plot_bgcolor='white',
        paper_bgcolor='white'
    )
    return fig


def _to_numeric_metric(series):
    """Convert metric values to numbers, accepting strings like ``"85.3%"``; unparsable values become NaN."""
    def strip_percent(value):
        if isinstance(value, str) and value.strip().endswith("%"):
            return value.strip()[:-1]
        return value

    return pd.to_numeric(series.apply(strip_percent), errors='coerce')


def _add_sota_labels(fig, df_sorted, sota_df, metric, is_lower_better,
                     label_offset_px=_SOTA_LABEL_OFFSET_PX):
    """Add a name label for every SOTA-setting model, staggering labels close in time.

    Args:
        fig: Figure (anything with ``add_annotation``) to annotate.
        df_sorted: All plotted rows; used for the overall date span.
        sota_df: Rows that set a new best; one label each, in row order.
        metric: Metric column holding the y values.
        is_lower_better: Chooses the side of the point the label goes on.
        label_offset_px: Base arrow length in pixels.
    """
    # Labels closer in time than this many days to an earlier label collide.
    try:
        date_span = (df_sorted['final_release_date'].max() - df_sorted['final_release_date'].min()).days
        min_separation = max(30, date_span * 0.05)  # Minimum 30 days or 5% of range
    except (TypeError, AttributeError):
        min_separation = 30

    # Lower-is-better: negative ay puts the label above the point (pixel y grows
    # downwards); higher-is-better: positive ay puts it below.
    direction = -1 if is_lower_better else 1
    base_ay = direction * label_offset_px
    base_yshift = direction * 8

    previous_x = []
    for i, (_, row) in enumerate(sota_df.iterrows()):
        current_x = row['final_release_date']

        collision = False
        for prev_x in previous_x:
            try:
                if abs((current_x - prev_x).days) < min_separation:
                    collision = True
                    break
            except (TypeError, AttributeError):
                continue  # Skip collision detection if the date arithmetic fails

        # Push every other colliding label further out so neighbours don't overlap.
        stagger = (i % 2) if collision else 0
        ay_offset = base_ay + direction * label_offset_px * 0.7 * stagger
        yshift = base_yshift + direction * 12 * stagger

        # str() guards against None/NaN model names, which have no len().
        name = str(row['model_name'])
        fig.add_annotation(
            x=current_x,
            y=row[metric],
            text=name[:25] + '...' if len(name) > 25 else name,
            showarrow=True,
            arrowhead=2,
            arrowsize=1,
            arrowwidth=1,
            arrowcolor='#00CED1',  # Match the SOTA line color
            ax=0,
            ay=ay_offset,  # Pixel offset from the point
            yshift=yshift,
            font=dict(size=8, color='#333333'),
            bgcolor='rgba(255, 255, 255, 0.9)',  # Semi-transparent background
            borderwidth=0
        )
        previous_x.append(current_x)


def create_sota_plot(df, metric):
    """Create a plot showing how the best score on *metric* evolved over time.

    Every model is drawn as a point, with SOTA-setting models highlighted. A
    line traces the running best: cumulative min for lower-is-better metrics,
    cumulative max otherwise. Each SOTA-setting model gets a text label.

    Args:
        df: DataFrame with one row per model. Uses ``model_name``,
            ``release_date``, ``paper_title``, ``paper_url``, ``code_url``,
            optionally ``paper_date``, and the *metric* column.
        metric: Metric column to plot on the y-axis.

    Returns:
        A ``plotly.graph_objects.Figure``. When there is nothing to plot, it is
        a blank figure carrying an explanatory message instead.
    """
    if df.empty or metric not in df.columns:
        return _message_figure("No data available for this metric", "No Data Available")

    # Remove rows where the metric is missing
    df_clean = df.dropna(subset=[metric]).copy()
    if df_clean.empty:
        return _message_figure("No valid data points for this metric", "No Data Available")

    try:
        df_clean[metric] = _to_numeric_metric(df_clean[metric])
        df_clean = df_clean.dropna(subset=[metric])
    except Exception as e:
        # UI boundary: show the problem in the plot instead of crashing the app.
        return _message_figure(f"Error processing metric data: {str(e)}",
                               "Data Processing Error", font_size=16)
    if df_clean.empty:
        return _message_figure(f"No numeric data available for metric: {metric}",
                               "No Numeric Data Available")

    # Re-derive release dates from the raw paper_date (with title/URL year
    # fallbacks); rows where that fails keep the release_date stored in the cache.
    if 'paper_date' in df_clean.columns:
        dynamic_dates = df_clean.apply(
            lambda row: parse_paper_date(
                row.get('paper_date', ''),
                row.get('paper_title', ''),
                row.get('paper_url', '')
            ), axis=1
        )
        df_clean['final_release_date'] = dynamic_dates.fillna(df_clean['release_date'])
    else:
        df_clean['final_release_date'] = df_clean['release_date']

    df_sorted = (df_clean[df_clean['final_release_date'].notna()]
                 .sort_values('final_release_date')
                 .copy())
    if df_sorted.empty:
        return _message_figure("No valid dates available for this dataset", "No Date Data Available")

    if metric in METRICS_INDEX:
        is_lower_better = METRICS_INDEX[metric].get('is_lower_better', False)
    else:
        is_lower_better = any(keyword in metric.lower() for keyword in ['error', 'loss', 'time', 'cost'])

    running_best = df_sorted[metric].cummin() if is_lower_better else df_sorted[metric].cummax()
    df_sorted['cumulative_best'] = running_best
    df_sorted['is_sota'] = df_sorted[metric] == running_best
    sota_df = df_sorted[df_sorted['is_sota']]

    x_values = df_sorted['final_release_date']
    x_axis_title = 'Release Date'

    fig = go.Figure()

    # All models as scatter points; SOTA-setting ones highlighted
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted[metric],
        mode='markers',
        name='All models',
        marker=dict(
            color=['#00CED1' if is_sota else 'lightgray'
                   for is_sota in df_sorted['is_sota']],
            size=8,
            opacity=0.7
        ),
        text=df_sorted['model_name'],
        # reindex tolerates frames that lack one of the hover columns.
        customdata=df_sorted.reindex(columns=['paper_title', 'paper_url', 'code_url']),
        hovertemplate='<b>%{text}</b><br>' +
                      f'{metric}: %{{y:.4f}}<br>' +
                      'Date: %{x}<br>' +
                      'Paper: %{customdata[0]}<br>' +
                      '<extra></extra>'
    ))

    # Running-best (SOTA) line
    fig.add_trace(go.Scatter(
        x=x_values,
        y=df_sorted['cumulative_best'],
        mode='lines',
        name=f'SOTA (cumulative {"min" if is_lower_better else "max"})',
        line=dict(color='#00CED1', width=2, dash='solid'),
        hovertemplate=f'SOTA {metric}: %{{y:.4f}}<br>{x_axis_title}: %{{x}}<extra></extra>'
    ))

    if not sota_df.empty:
        _add_sota_labels(fig, df_sorted, sota_df, metric, is_lower_better)

    fig.update_layout(
        title=f'SOTA Evolution: {metric}',
        xaxis_title=x_axis_title,
        yaxis_title=metric,
        xaxis=dict(showgrid=True, gridcolor='lightgray'),
        yaxis=dict(showgrid=True, gridcolor='lightgray'),
        plot_bgcolor='white',
        paper_bgcolor='white',
        height=600,
        legend=dict(yanchor="top", y=0.99, xanchor="left", x=0.01),
        hovermode='closest'
    )

    return fig


# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ“Š Papers with Code - SOTA Evolution Visualizer")
    gr.Markdown(
        "Navigate through ML tasks and datasets to visualize the evolution of state-of-the-art models over time.")
    gr.Markdown("*Optimized for low memory usage - data is loaded on-demand from disk*")

    # Status
    with gr.Row():
        gr.Markdown(f"""
        <div style="background-color: #f0f9ff; border-left: 4px solid #00CED1; padding: 10px; margin: 10px 0;">
        <b>πŸ’Ύ Disk-Based Storage Active</b><br>
        β€’ <b>{len(TASKS)}</b> tasks indexed<br>
        β€’ <b>{len(METRICS_INDEX)}</b> unique metrics tracked<br>
        β€’ Data loaded on-demand to minimize RAM usage
        </div>
        """)

    # State variables
    current_df = gr.State(pd.DataFrame())
    current_task = gr.State(None)

    # Navigation dropdowns
    with gr.Row():
        task_dropdown = gr.Dropdown(
            choices=get_tasks(),
            label="Select Task",
            interactive=True
        )
        category_dropdown = gr.Dropdown(
            choices=[],
            label="Categories (info only)",
            interactive=False
        )

    with gr.Row():
        dataset_dropdown = gr.Dropdown(
            choices=[],
            label="Select Dataset",
            interactive=True
        )
        metric_dropdown = gr.Dropdown(
            choices=[],
            label="Select Metric",
            interactive=True
        )

    # Info display
    info_text = gr.Markdown("πŸ‘† Please select a task to begin")

    # Plot
    plot = gr.Plot(label="SOTA Evolution")

    # Data display
    with gr.Row():
        show_data_btn = gr.Button("πŸ“‹ Show/Hide Model Data")
        export_btn = gr.Button("πŸ’Ύ Export Current Data (CSV)")
        clear_memory_btn = gr.Button("🧹 Clear Memory", variant="secondary")

    df_display = gr.Dataframe(
        label="Model Data",
        visible=False
    )


    # Update functions
    def update_task_selection(task):
        """Reset the category/dataset/metric widgets for a newly selected task.

        Args:
            task: Task name chosen in ``task_dropdown``; falsy when the
                selection is cleared.

        Returns:
            A 7-tuple in the order of the ``task_dropdown.change`` outputs:
            category dropdown, dataset dropdown, metric dropdown, info
            markdown, ``current_df`` value, plot value, ``current_task`` value.
        """
        if not task:
            # Return dropdown updates rather than bare ``[]`` values: a bare
            # value only sets the dropdown's selection (to an empty list) and
            # leaves the previous task's choices in place.
            return (
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                gr.Dropdown(choices=[], value=None),
                "πŸ‘† Please select a task to begin",
                pd.DataFrame(),
                None,
                None,
            )

        # Look up the task's categories and datasets (helpers defined elsewhere
        # in this file).
        categories = get_categories(task)
        datasets = get_datasets_for_task(task)

        info = f"### πŸ“‚ **Task:** {task}\n"
        if categories:
            # Preview at most three category names, followed by the total count.
            ellipsis = '...' if len(categories) > 3 else ''
            info += f"- **Categories:** {', '.join(categories[:3])}{ellipsis} ({len(categories)} total)\n"

        return (
            # Pre-select the first category so the read-only dropdown shows something.
            gr.Dropdown(choices=categories, value=categories[0] if categories else None),
            # New task => new dataset list, with nothing selected yet.
            gr.Dropdown(choices=datasets, value=None),
            # Metrics depend on the dataset, so clear them until one is picked.
            gr.Dropdown(choices=[], value=None),
            info,
            pd.DataFrame(),  # drop any previously loaded dataset rows
            None,  # clear the plot
            task,  # stored in the current_task state
        )


    def update_dataset_selection(task, dataset_name):
        """Load the selected dataset and offer its metric columns.

        Args:
            task: Currently selected task name.
            dataset_name: Dataset chosen in ``dataset_dropdown``; falsy when
                the selection is cleared.

        Returns:
            A 4-tuple in the order of the ``dataset_dropdown.change`` outputs:
            metric dropdown, info markdown, ``current_df`` value, plot value.
        """
        if not task or not dataset_name:
            # Clear both choices and selection; a bare ``[]`` would only set the
            # metric dropdown's value to an empty list.
            return gr.Dropdown(choices=[], value=None), "", pd.DataFrame(), None

        # Load the dataset's model rows (loader defined elsewhere in this file).
        df = get_cached_model_data(task, dataset_name)

        if df.empty:
            return (
                gr.Dropdown(choices=[], value=None),
                f"⚠️ No models found for dataset: {dataset_name}",
                df,
                None,
            )

        # Every column that is not model metadata is treated as a metric.
        exclude_cols = ('model_name', 'release_date', 'paper_date', 'paper_url', 'paper_title', 'code_url')
        metric_cols = [col for col in df.columns if col not in exclude_cols]

        info = f"### πŸ“Š **Dataset:** {dataset_name}\n"
        info += f"- **Models:** {len(df)} models\n"
        info += f"- **Metrics:** {len(metric_cols)} metrics available\n"

        if 'release_date' in df.columns:
            # min()/max() skip NaT but return NaT when no date is valid, and
            # NaT.strftime raises ValueError, so only report a range that exists.
            first_date = df['release_date'].min()
            last_date = df['release_date'].max()
            if pd.notna(first_date) and pd.notna(last_date):
                info += f"- **Date Range:** {first_date.strftime('%Y-%m-%d')} to {last_date.strftime('%Y-%m-%d')}\n"

        if metric_cols:
            info += f"- **Available Metrics:** {', '.join(metric_cols[:5])}{'...' if len(metric_cols) > 5 else ''}"

        return (
            # Pre-select the first metric so a chart can be drawn right away.
            gr.Dropdown(choices=metric_cols, value=metric_cols[0] if metric_cols else None),
            info,
            df,
            None,  # clear the plot; the metric handler redraws it
        )


    def update_plot(df, metric):
        """Draw the SOTA chart for ``metric``, or blank the plot if there is nothing to draw.

        Args:
            df: Model rows held in the ``current_df`` state.
            metric: Column picked in ``metric_dropdown``; falsy means no metric.

        Returns:
            The result of ``create_sota_plot(df, metric)``, or None when ``df``
            is empty or no metric is selected.
        """
        if not df.empty and metric:
            return create_sota_plot(df, metric)
        return None


    def toggle_dataframe(df):
        """Reveal the model table for the currently loaded dataset.

        ``model_name`` and ``release_date`` come first, followed by every
        non-metadata column; ``release_date`` is rendered as ``YYYY-MM-DD``.

        NOTE(review): despite the name and the "Show/Hide" button label, this
        never hides a populated table; it only hides when ``df`` is empty.

        Args:
            df: Model rows held in the ``current_df`` state.

        Returns:
            A ``gr.Dataframe`` update carrying the table and its visibility.
        """
        if df.empty:
            return gr.Dataframe(value=pd.DataFrame(), visible=False)

        leading = ['model_name', 'release_date']
        metadata = ('model_name', 'release_date', 'paper_date', 'paper_url', 'paper_title', 'code_url')
        ordered = leading + [name for name in df.columns if name not in metadata]

        table = df[ordered].copy()
        table['release_date'] = table['release_date'].dt.strftime('%Y-%m-%d')
        return gr.Dataframe(value=table, visible=True)


    def export_data(df):
        """Write ``df`` to a timestamped CSV file in the current working directory.

        Args:
            df: Model rows held in the ``current_df`` state.

        Returns:
            A status message for ``info_text``: the written filename and row
            count, or a warning when there is nothing to export.
        """
        if not df.empty:
            stamp = datetime.now().strftime('%Y%m%d_%H%M%S')
            out_name = f"sota_export_{stamp}.csv"
            df.to_csv(out_name, index=False)
            return f"βœ… Data exported to {out_name} ({len(df)} models)"
        return "⚠️ No data to export"


    def clear_memory():
        """Run a full garbage-collection pass and report completion.

        This only triggers ``gc.collect()``; it does not reset any ``gr.State``
        values, so the currently loaded dataset stays referenced.

        Returns:
            A fixed status message for ``info_text``.
        """
        import gc as _gc
        _gc.collect()
        return "βœ… Memory cleared"


    # Event handlers
    # Drill-down cascade: picking a task resets the dataset/metric dropdowns,
    # picking a dataset loads its rows into current_df and pre-selects the
    # first metric, and a metric change redraws the plot from current_df.
    # NOTE(review): Gradio fires .change on programmatic updates, but only when
    # the value actually changes. If a new dataset's first metric equals the
    # previous selection, the plot cleared by the dataset handler may not be
    # redrawn; confirm.
    task_dropdown.change(
        fn=update_task_selection,
        inputs=task_dropdown,
        outputs=[category_dropdown, dataset_dropdown,
                 metric_dropdown, info_text, current_df, plot, current_task]
    )

    # Reads task_dropdown directly rather than the current_task state.
    dataset_dropdown.change(
        fn=update_dataset_selection,
        inputs=[task_dropdown, dataset_dropdown],
        outputs=[metric_dropdown, info_text, current_df, plot]
    )

    metric_dropdown.change(
        fn=update_plot,
        inputs=[current_df, metric_dropdown],
        outputs=plot
    )

    show_data_btn.click(
        fn=toggle_dataframe,
        inputs=current_df,
        outputs=df_display
    )

    # Export and clear-memory both write their status into info_text,
    # replacing the dataset summary shown there.
    export_btn.click(
        fn=export_data,
        inputs=current_df,
        outputs=info_text
    )

    clear_memory_btn.click(
        fn=clear_memory,
        inputs=[],
        outputs=info_text
    )

    # Static usage notes rendered below the controls.
    # NOTE(review): the "Infinite disk space" bullet overstates what the app
    # does; consider rewording it.
    gr.Markdown("""
    ---
    ### πŸ“– How to Use
    1. **Select a Task** from the first dropdown
    2. **Select a Dataset** to analyze
    3. **Select a Metric** to visualize
    4. The plot shows SOTA model evolution over time with dynamically calculated dates

    ### πŸ’Ύ Memory Optimization
    - Data is stored on disk and loaded on-demand
    - Only the current task and dataset are kept in memory
    - Use "Clear Memory" button if needed
    - Infinite disk space is utilized for permanent caching

    ### 🎨 Plot Features
    - **πŸ”΅ Cyan dots**: SOTA models when released
    - **βšͺ Gray dots**: Other models
    - **πŸ“ˆ Cyan line**: SOTA progression
    - **πŸ” Hover**: View model details
    - **🏷️ Smart Labels**: SOTA model labels positioned close to the line with intelligent collision detection
    """)


def test_sota_label_positioning():
    """Smoke-test ``create_sota_plot`` against a handful of edge-case inputs.

    Each case only checks that ``create_sota_plot`` does not raise; the
    returned figures are not inspected. Progress is printed to stdout and a
    failing case does not stop the remaining ones.

    Returns:
        bool: True if every case ran without raising, False otherwise.
        (Previously this always returned True, even when cases failed.)
    """
    import pandas as pd
    from datetime import datetime

    print("πŸ§ͺ Testing SOTA label positioning...")

    # Test data with different metric types (including all required columns)
    base_df = pd.DataFrame({
        'model_name': ['Model A', 'Model B', 'Model C', 'Model D'],
        'release_date': [
            datetime(2020, 1, 1),
            datetime(2020, 6, 1),
            datetime(2021, 1, 1),
            datetime(2021, 6, 1),
        ],
        'paper_title': ['Paper A', 'Paper B', 'Paper C', 'Paper D'],
        'paper_url': ['http://example.com/a', 'http://example.com/b', 'http://example.com/c', 'http://example.com/d'],
        'code_url': ['http://github.com/a', 'http://github.com/b', 'http://github.com/c', 'http://github.com/d'],
        'accuracy': [0.85, 0.87, 0.90, 0.92],  # Higher-better metric
        'error_rate': [0.15, 0.13, 0.10, 0.08],  # Lower-better metric
    })

    # Non-numeric metric values should be handled gracefully.
    string_df = base_df.copy()
    string_df['string_metric'] = ['low', 'medium', 'high', 'very_high']

    mixed_df = base_df.copy()
    mixed_df['mixed_metric'] = [0.85, 'N/A', 0.90, 0.92]

    # Includes malformed and year-only dates to exercise paper_date parsing.
    dates_df = base_df.copy()
    dates_df['paper_date'] = ['2015-03-15', '2018-invalid', '2021-12-01', '2022']

    # (progress description, result label, frame, metric, suffix on success)
    cases = [
        ("higher-better metric (accuracy)", "Higher-better metric", base_df, 'accuracy', ""),
        ("lower-better metric (error_rate)", "Lower-better metric", base_df, 'error_rate', ""),
        ("empty dataframe", "Empty data", pd.DataFrame(), 'test_metric', ""),
        ("string metric data", "String metric", string_df, 'string_metric', " (handled gracefully)"),
        ("mixed data types", "Mixed data", mixed_df, 'mixed_metric', ""),
        ("paper_date column", "Paper date parsing", dates_df, 'accuracy', ""),
    ]

    failures = []
    for description, label, frame, metric, suffix in cases:
        print(f"  Testing with {description}...")
        try:
            create_sota_plot(frame, metric)
        except Exception as exc:  # smoke test: report every failure, keep going
            print(f"  ❌ {label} test failed: {exc}")
            failures.append(label)
        else:
            print(f"  βœ… {label} test passed{suffix}")

    print("πŸŽ‰ SOTA label positioning tests completed!")
    if failures:
        print(f"  {len(failures)} case(s) failed: {', '.join(failures)}")
    return not failures

# Start the web server only when this file is run as a script (e.g.
# `python app.py`), so importing the module, for instance to call
# test_sota_label_positioning(), does not launch the app as a side effect.
if __name__ == "__main__":
    demo.launch()