Spaces:
Running
Running
| from pathlib import Path | |
| import re | |
| from datasets import load_dataset | |
| import json | |
| import gradio as gr | |
| from matplotlib import pyplot as plt | |
| import pandas as pd | |
# Inline <link> injected into the page <head>: loads the "PT Mono" Google
# font referenced by gr.themes.Ocean(font_mono="PT Mono") further below.
HEAD_HTML = """
<link href='https://fonts.googleapis.com/css?family=PT Mono' rel='stylesheet'>
"""
def normalize_spaces(text):
    """Collapse every run of two-or-more spaces into one space, per line."""
    collapsed = []
    for line in text.split("\n"):
        collapsed.append(re.sub(r" {2,}", " ", line))
    return "\n".join(collapsed)
def load_json(file_path):
    """Load and return the JSON document stored at *file_path*.

    The file is read explicitly as UTF-8 so parsing does not depend on the
    platform's locale default encoding (JSON interchange is UTF-8, RFC 8259).
    """
    with open(file_path, "r", encoding="utf-8") as file:
        return json.load(file)
def on_select(evt: gr.SelectData, current_split):
    """Fill every per-method widget plus the image/question/answer widgets
    for the table row the user clicked.

    Returns values in the exact order the UI registered its output
    components: (input text, prediction, correct?) per method, then the
    sample image, the question, and the ground-truth answer.
    """
    item_id = evt.row_value[0]
    filename = evt.row_value[1]
    split_items = item_by_id_dict[current_split]
    split_eval = evaluation_dict[current_split]
    split_frame = input_dataframe[current_split]
    results = []
    for method in METHOD_LIST:
        method_entry = split_eval[method][filename]
        results.append(split_items[filename][method])
        results.append(method_entry["pred"])
        results.append(method_entry["score"] == 1)
    results.append(split_items[filename]["image"])
    results.append(split_frame["questions"][item_id])
    results.append(split_frame["answers"][item_id])
    return results
def on_dataset_change(current_split):
    """Rebuild the average-score plot and the file-list table when the
    user switches dataset split."""
    split_scores = method_scores[current_split]
    figure = generate_plot(
        providers=METHOD_LIST,
        scores=[split_scores[m] for m in METHOD_LIST],
    )
    table = pd.DataFrame(input_dataframe[current_split])
    return figure, table
def generate_plot(providers, scores):
    """Render a horizontal bar chart of per-method average scores.

    Args:
        providers: method names, in top-to-bottom display order.
        scores: average score per method, same order as *providers*;
            values are ratios expected in [0, 1].

    Returns:
        The matplotlib Figure (used as the value of a gr.Plot component).
    """
    fig, ax = plt.subplots(figsize=(4, 3))
    # Reverse both lists so the first provider is drawn at the top.
    bars = ax.barh(providers[::-1], scores[::-1])
    min_score = min(scores)
    max_score = max(scores)
    # Customize plot
    ax.set_title("Methods Average Scores")
    ax.set_ylabel("Methods")
    ax.set_xlabel("Scores")
    # Zoom in around the data, but clamp to [0, 1]: scores are ratios, so a
    # negative lower bound (when min_score < 0.1) would be misleading.
    ax.set_xlim(max(min_score - 0.1, 0.0), min(max_score + 0.1, 1.0))
    for bar in bars:
        width = bar.get_width()
        # Annotate each bar with its numeric value, just right of the bar end.
        ax.text(
            width,
            bar.get_y() + bar.get_height() / 2.0,
            f"{width:.3f}",
            ha="left",
            va="center",
        )
    plt.tight_layout()
    # Deregister the figure from the pyplot state so repeated calls (one per
    # split change) do not accumulate open figures and trigger matplotlib's
    # "more than 20 figures" warning; the Figure object stays renderable.
    plt.close(fig)
    return fig
# ---------------------------------------------------------------------------
# Load the benchmark dataset and the pre-computed evaluation results.
# ---------------------------------------------------------------------------
evaluation_json_dir = Path("eval_output")
dataset = load_dataset(path="terryoo/TableVQA-Bench")
SPLIT_NAMES = ["fintabnetqa", "vwtq_syn"]
DEFAULT_SPLIT_NAME = "fintabnetqa"
METHOD_LIST = ["text_2d", "text_1d", "html"]
# item_by_id_dict[split][qa_id] -> per-method input representations + image
item_by_id_dict = {}
# input_dataframe[split] -> column-oriented rows backing the gr.Dataframe
input_dataframe = {}
# evaluation_dict[split][method][qa_id] -> {"pred": ..., "score": ...}
evaluation_dict = {}
# method_scores[split][method] -> average score as a 0-1 ratio, 2 decimals
method_scores = {}
for split_name in SPLIT_NAMES:
    input_text_path = Path(
        f"dataset_tablevqa_{split_name}_2d_text"
    )
    item_by_id_dict[split_name] = {}
    input_dataframe[split_name] = {
        "ids": [],
        "filenames": [],
        "questions": [],
        "answers": [],
    }
    evaluation_dict[split_name] = {}
    method_scores[split_name] = {}
    for idx, sample in enumerate(dataset[split_name]):
        sample_id = sample["qa_id"]
        text_path = input_text_path / f"{sample_id}.txt"
        # Read explicitly as UTF-8 so loading does not depend on the
        # platform's locale default encoding.
        with open(text_path, "r", encoding="utf-8") as f:
            text_2d = f.read()
        item_by_id_dict[split_name][sample_id] = {
            "text_2d": text_2d,
            "text_1d": normalize_spaces(text_2d),
            "image": sample["image"],
            "html": sample["text_html_table"],
        }
        input_dataframe[split_name]["ids"].append(idx)
        input_dataframe[split_name]["filenames"].append(sample_id)
        input_dataframe[split_name]["questions"].append(sample["question"])
        input_dataframe[split_name]["answers"].append(sample["gt"])
    for method in METHOD_LIST:
        evaluation_json_path = evaluation_json_dir / f"{split_name}_{method}.json"
        evaluation_data = load_json(evaluation_json_path)
        evaluation_dict[split_name][method] = {
            item["qa_id"]: {
                "pred": item["pred"],
                "score": item["scores"]["a"],
            }
            for item in evaluation_data["instances"]
        }
        # average_scores[0] is a percentage; convert to a 0-1 ratio.
        method_scores[split_name][method] = round(
            evaluation_data["evaluation_meta"]["average_scores"][0] / 100,
            2,
        )
# ---------------------------------------------------------------------------
# Gradio UI: split selector, average-score plot, clickable file table, and
# per-method output tabs.
# ---------------------------------------------------------------------------
with gr.Blocks(
    theme=gr.themes.Ocean(
        font_mono="PT Mono",
    ),
    head=HEAD_HTML,
) as demo:
    gr.Markdown(
        "# 2D Layout-Preserving Text Benchmark\n"
        "Dataset: [TableVQA-Bench](https://huggingface.co/datasets/terryoo/TableVQA-Bench)\n"
    )
    dataset_name = gr.Dropdown(
        label="Dataset split",
        value=DEFAULT_SPLIT_NAME,
        choices=["fintabnetqa", "vwtq_syn"],
    )
    gr.Markdown("### File List")
    # Average-score bar chart, pre-rendered for the default split.
    plot_avg = gr.Plot(
        label="Average scores",
        value=generate_plot(
            providers=METHOD_LIST,
            scores=[
                method_scores[DEFAULT_SPLIT_NAME][method]
                for method in METHOD_LIST
            ],
        ),
        container=False,
    )
    # Clickable sample table; selecting a row populates the widgets below.
    file_list = gr.Dataframe(
        value=pd.DataFrame(input_dataframe[DEFAULT_SPLIT_NAME]),
        max_height=300,
        show_row_numbers=False,
        show_search=True,
        column_widths=["10%", "30%", "30%", "30%"],
    )
    with gr.Row():
        with gr.Column():
            demo_image = gr.Image(
                label="Input Image",
                interactive=False,
                height=400,
                width=600,
            )
        with gr.Column():
            question = gr.Textbox(
                label="Question",
                interactive=False,
            )
            answer_gt = gr.Textbox(
                label="GT Answer",
                interactive=False,
            )
    # One tab per method: the method's input representation (HTML is rendered,
    # text variants are shown as code), the prediction, and the score.
    output_elements = []
    with gr.Tabs():
        for method in METHOD_LIST:
            with gr.Tab(method):
                if "html" in method:
                    output = gr.HTML(
                        container=False,
                        show_label=False,
                    )
                else:
                    output = gr.Code(
                        container=False,
                        language="markdown",
                        show_line_numbers=False,
                    )
                pred = gr.Textbox(
                    label="Predicted Answer",
                    interactive=False,
                )
                score = gr.Textbox(
                    label="Score",
                    interactive=False,
                )
                # NOTE: registration order must match the value order that
                # on_select returns (per-method triples, then image/Q/A).
                output_elements.extend([output, pred, score])
    file_list.select(
        fn=on_select,
        inputs=[dataset_name],
        outputs=output_elements +
        [
            demo_image,
            question,
            answer_gt
        ],
    )
    dataset_name.change(
        fn=on_dataset_change,
        inputs=dataset_name,
        outputs=[plot_avg, file_list],
    )
demo.launch()