Spaces:
Runtime error
Runtime error
| from datetime import datetime | |
| import numpy as np | |
| import json | |
| import re | |
| import heapq | |
| from collections import defaultdict | |
| import tempfile | |
| from typing import Dict, Tuple, List, Literal | |
| import gradio as gr | |
| from datatrove.utils.stats import MetricStatsDict | |
| from src.logic.graph_settings import Grouping | |
# Accepted values for the ``direction`` argument of ``prepare_for_group_plotting``:
# groups with the highest means, the lowest means, or the largest document counts.
PARTITION_OPTIONS = Literal[
    "Top",
    "Bottom",
    "Most frequent (n_docs)",
]
def prepare_for_non_grouped_plotting(
    metric: Dict[str, "MetricStatsDict"],
    normalization: bool,
    rounding: int,
) -> Dict[float, float]:
    """Aggregate per-key totals into rounded histogram bins.

    Every key of ``metric`` is parsed with ``float()`` and rounded to
    ``rounding`` decimals; the ``total`` of all entries that land on the same
    rounded key is summed into one bin.

    Args:
        metric: Mapping from a numeric key (as a string) to a stats object
            exposing a ``total`` attribute.
        normalization: If true, divide each bin by the sum of all bins so the
            bins sum to 1. Skipped when that sum is 0 (empty input or all-zero
            totals), which would otherwise turn every bin into NaN.
        rounding: Number of decimals passed to ``np.round`` (may be negative).

    Returns:
        Mapping from rounded key to (optionally normalized) summed total,
        in ascending key order (``np.unique`` sorts its output).
    """
    keys = np.array([float(key) for key in metric.keys()], dtype=float)
    totals = np.array([value.total for value in metric.values()], dtype=float)

    rounded_keys = np.round(keys, rounding)
    # ``inverse`` maps each original entry to the index of its rounded bin.
    unique_keys, inverse = np.unique(rounded_keys, return_inverse=True)
    binned = np.zeros_like(unique_keys, dtype=float)
    # Unbuffered scatter-add: repeated bin indices accumulate instead of overwriting.
    np.add.at(binned, inverse, totals)

    if normalization:
        normalizer = binned.sum()
        # Avoid 0/0 -> NaN bins when there is nothing to normalize by.
        if normalizer != 0:
            binned /= normalizer

    return dict(zip(unique_keys.tolist(), binned.tolist()))
| def prepare_for_group_plotting(metric: Dict[str, MetricStatsDict], top_k: int, direction: PARTITION_OPTIONS, regex: str | None, rounding: int) -> Tuple[List[str], List[float], List[float]]: | |
| regex_compiled = re.compile(regex) if regex else None | |
| filtered_metric = {key: value for key, value in metric.items() if not regex or regex_compiled.match(key)} | |
| keys = np.array(list(filtered_metric.keys())) | |
| means = np.array([float(value.mean) for value in filtered_metric.values()]) | |
| stds = np.array([value.standard_deviation for value in filtered_metric.values()]) | |
| rounded_means = np.round(means, rounding) | |
| if direction == "Top": | |
| top_indices = np.argsort(rounded_means)[-top_k:][::-1] | |
| elif direction == "Most frequent (n_docs)": | |
| totals = np.array([int(value.n) for value in filtered_metric.values()]) | |
| top_indices = np.argsort(totals)[-top_k:][::-1] | |
| else: | |
| top_indices = np.argsort(rounded_means)[:top_k] | |
| top_keys = keys[top_indices] | |
| top_means = rounded_means[top_indices] | |
| top_stds = stds[top_indices] | |
| return top_keys.tolist(), top_means.tolist(), top_stds.tolist() | |
def export_data(exported_data: Dict[str, "MetricStatsDict"], metric_name: str, grouping: "Grouping"):
    """Write the displayed metric data to a timestamped JSON file for download.

    The JSON maps each entry name of ``exported_data`` to a list of records
    ``{"value": key, **stats}`` built from ``dt.to_dict()``, sorted by
    ``"value"``. The file is written to the system temporary directory rather
    than the working directory, so exports do not accumulate in the app folder.

    Args:
        exported_data: Mapping from name to an object exposing ``to_dict()``.
        metric_name: Metric name, used in the file name.
        grouping: Grouping used in the file name (formatted with ``str``).

    Returns:
        ``None`` if there is nothing to export, otherwise a visible
        ``gr.File`` pointing at the written JSON file.
    """
    import os

    if not exported_data:
        return None

    # Build the payload first so a failing to_dict() does not leave a truncated file.
    payload = {
        name: sorted(
            ({"value": key, **value} for key, value in dt.to_dict().items()),
            key=lambda record: record["value"],
        )
        for name, dt in exported_data.items()
    }

    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    # Path separators in the names would otherwise make open() target a sub-directory.
    stem = f"{metric_name}_{grouping}".replace(os.sep, "_").replace("/", "_")
    file_path = os.path.join(tempfile.gettempdir(), f"{stem}_{timestamp}.json")

    with open(file_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=2)
    return gr.File(value=file_path, visible=True)