kt-test-account commited on
Commit
e6d7498
·
1 Parent(s): 3774694

moving parts to fragments

Browse files
Files changed (1) hide show
  1. app.py +60 -38
app.py CHANGED
@@ -119,13 +119,19 @@ def get_unique_teams(teams):
119
 
120
  @st.cache_data
121
  def filter_teams(temp, selected_team):
122
- return temp.query(f"team=='{selected_team}'")
 
123
 
124
  def make_roc_curves(task, submission_ids):
125
 
126
  rocs = load_roc_file(task, submission_ids)
127
 
128
- roc_chart = alt.Chart(rocs).mark_line().encode(x="fpr", y="tpr", color="team:N", detail="submission_id:N")
 
 
 
 
 
129
 
130
  return roc_chart
131
 
@@ -145,6 +151,7 @@ st.set_page_config(
145
  with st.sidebar:
146
 
147
  hf_token = os.getenv("HF_TOKEN")
 
148
  password = st.text_input("Admin login:", type="password")
149
 
150
  if password == hf_token:
@@ -189,8 +196,14 @@ with st.sidebar:
189
  else:
190
  split = "public"
191
 
 
 
192
 
193
- def show_leaderboard(results, task):
 
 
 
 
194
  source_split_map = {}
195
  if split == "private":
196
  _sol_df = pd.read_csv(COMP_CACHE / task / "solution.csv")
@@ -410,7 +423,7 @@ def make_acc(results):
410
  alt.Chart(results)
411
  .mark_circle(size=200)
412
  .encode(
413
- x=alt.X("total_time:Q", title="🕒 Inference Time", scale=alt.Scale(type = "log")),
414
  y=alt.Y(
415
  "balanced_accuracy:Q",
416
  title="Balanced Accuracy",
@@ -421,7 +434,7 @@ def make_acc(results):
421
  .properties(width=400, height=400, title="Inference Time vs Balanced Accuracy")
422
  )
423
  diag_line = (
424
- alt.Chart(pd.DataFrame(dict(t=[100, results["total_time"].max()], y=[0.5, 0.5])))
425
  .mark_line(color="lightgray", strokeDash=[8, 4])
426
  .encode(x="t", y="y")
427
  )
@@ -440,52 +453,61 @@ def get_heatmaps(temp):
440
  st.altair_chart(h3, use_container_width=True)
441
 
442
 
443
- def make_plots_for_task(task, split, best_only):
444
-
445
-
446
- results = load_results(task, best_only=best_only)
 
447
  temp = results[f"{split}_score"].reset_index()
448
  teams = get_unique_teams(temp["team"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
449
 
 
 
 
450
 
451
- t1, t2 = st.tabs(["Tables", "Charts"])
452
- with t1:
453
- show_leaderboard(results, task)
454
 
455
- with t2:
456
- # st.write(temp)
457
- if split == "private":
458
- best_only = st.toggle("Best Only", value=True, key = f"best only {task}")
459
- if not best_only:
460
- results = load_results(task, best_only=best_only)
461
- temp = results[f"{split}_score"].reset_index()
462
- selected_team = st.pills("Team",["ALL"] + teams, key = f"teams {task}", default="ALL")
463
- if not selected_team:
464
- selected_team = "ALL"
465
- if selected_team != "ALL":
466
- temp = filter_teams(temp, selected_team)
467
 
468
- # with st.spinner("making plots...", show_time=True):
469
- roc_scatter = make_roc(temp)
470
- acc_vs_time = make_acc(temp)
471
 
472
- if split == "private" and hf_token is not None:
473
- full_curves = st.toggle("Full curve", value=True, key=f"all curves {task}")
474
 
475
- if full_curves:
476
- roc_scatter = make_roc_curves(task, temp["submission_id"].values.tolist()) + roc_scatter
477
 
478
- st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
479
- else:
480
- st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
481
 
482
- st.info(f"loading {temp['submission_id'].nunique()} submissions")
 
 
 
 
 
 
483
 
484
 
485
 
486
  updated = get_updated_time()
487
  st.markdown(updated)
488
- best_only = True
489
 
490
 
491
  tp, t1, volume_tab, all_submission_tab = st.tabs(
@@ -493,10 +515,10 @@ tp, t1, volume_tab, all_submission_tab = st.tabs(
493
  )
494
  with tp:
495
  "*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
496
- make_plots_for_task(TASKS[0], split, best_only)
497
  with t1:
498
  "*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
499
- make_plots_for_task(TASKS[1], split, best_only)
500
 
501
  with volume_tab:
502
  subs = get_volume()
 
119
 
120
  @st.cache_data
121
  def filter_teams(temp, selected_team):
122
+ mask = temp.loc[:,"team"].isin(selected_team)
123
+ return temp.loc[mask]
124
 
125
  def make_roc_curves(task, submission_ids):
126
 
127
  rocs = load_roc_file(task, submission_ids)
128
 
129
+ # if rocs["team"].nunique() > 1:
130
+ color_field = "team:N"
131
+ # else:
132
+ # color_field = "submission_id:N"
133
+
134
+ roc_chart = alt.Chart(rocs).mark_line().encode(x="fpr", y="tpr", color=color_field, detail="submission_id:N")
135
 
136
  return roc_chart
137
 
 
151
  with st.sidebar:
152
 
153
  hf_token = os.getenv("HF_TOKEN")
154
+ st.session_state["hf_token"] = hf_token
155
  password = st.text_input("Admin login:", type="password")
156
 
157
  if password == hf_token:
 
196
  else:
197
  split = "public"
198
 
199
+ st.session_state["split"] = split
200
+
201
 
202
+ @st.fragment
203
+ def show_leaderboard(task):
204
+ split = st.session_state.get("split","public")
205
+ results = load_results(task, best_only=True)
206
+ temp = results[f"{split}_score"].reset_index()
207
  source_split_map = {}
208
  if split == "private":
209
  _sol_df = pd.read_csv(COMP_CACHE / task / "solution.csv")
 
423
  alt.Chart(results)
424
  .mark_circle(size=200)
425
  .encode(
426
+ x=alt.X("total_time:Q", title="🕒 Inference Time (sec)", scale=alt.Scale(type = "log",domain = [100,100000])),
427
  y=alt.Y(
428
  "balanced_accuracy:Q",
429
  title="Balanced Accuracy",
 
434
  .properties(width=400, height=400, title="Inference Time vs Balanced Accuracy")
435
  )
436
  diag_line = (
437
+ alt.Chart(pd.DataFrame(dict(t=[100, 100000], y=[0.5, 0.5])))
438
  .mark_line(color="lightgray", strokeDash=[8, 4])
439
  .encode(x="t", y="y")
440
  )
 
453
  st.altair_chart(h3, use_container_width=True)
454
 
455
 
456
+ @st.fragment
457
+ def show_charts(task):
458
+ split = st.session_state.get("split","public")
459
+ hf_token = st.session_state.get("hf_token",None)
460
+ results = load_results(task, best_only=True)
461
  temp = results[f"{split}_score"].reset_index()
462
  teams = get_unique_teams(temp["team"])
463
+
464
+ if split == "private":
465
+ best_only = st.toggle("Best Only", value=True, key = f"best only {task}")
466
+ if not best_only:
467
+ results = load_results(task, best_only=best_only)
468
+ temp = results[f"{split}_score"].reset_index()
469
+ selected_team = st.pills("Team",["ALL"] + teams, key = f"teams {task}", default=["ALL"],selection_mode="multi")
470
+
471
+ if selected_team is None or len(selected_team) == 0:
472
+ return
473
+
474
+ if "ALL" in selected_team:
475
+ selected_team = ["ALL"]
476
+
477
+ if "ALL" not in selected_team:
478
+ temp = filter_teams(temp, selected_team)
479
 
480
+ # with st.spinner("making plots...", show_time=True):
481
+ roc_scatter = make_roc(temp)
482
+ acc_vs_time = make_acc(temp)
483
 
484
+ if split == "private" and hf_token is not None:
485
+ full_curves = st.toggle("Full curve", value=True, key=f"all curves {task}")
 
486
 
487
+ if full_curves:
488
+ roc_scatter = make_roc_curves(task, temp["submission_id"].values.tolist()) + roc_scatter
 
 
 
 
 
 
 
 
 
 
489
 
490
+ st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
491
+ else:
492
+ st.altair_chart(roc_scatter | acc_vs_time, use_container_width=False)
493
 
494
+ st.info(f"loading {temp['submission_id'].nunique()} submissions")
 
495
 
 
 
496
 
497
+ def make_plots_for_task(task):
 
 
498
 
499
+ t1, t2 = st.tabs(["Tables", "Charts"])
500
+ with t1:
501
+ show_leaderboard(task)
502
+
503
+ with t2:
504
+ show_charts(task)
505
+
506
 
507
 
508
 
509
  updated = get_updated_time()
510
  st.markdown(updated)
 
511
 
512
 
513
  tp, t1, volume_tab, all_submission_tab = st.tabs(
 
515
  )
516
  with tp:
517
  "*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
518
+ make_plots_for_task(TASKS[0])
519
  with t1:
520
  "*Detection of Synthetic Video Content. Video files are unmodified from the original output from the models or the real sources.*"
521
+ make_plots_for_task(TASKS[1])
522
 
523
  with volume_tab:
524
  subs = get_volume()