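# NOTE(review): "Spaces: Runtime error" was captured above this file, apparently a web-page status banner; it is not part of the script.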
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
| import csv | |
| from collections import defaultdict | |
| from dataclasses import dataclass, field | |
| from typing import List, Optional | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from matplotlib.ticker import ScalarFormatter | |
| from transformers import HfArgumentParser | |
def list_field(default=None, metadata=None):
    """
    Create a dataclass field whose default factory hands back ``default``.

    Going through ``default_factory`` lets a list (or any other value) be used as a dataclass default. The factory
    returns the very same ``default`` object on every call; it does not copy it.

    Args:
        default: Value the factory returns.
        metadata: Mapping forwarded unchanged to ``dataclasses.field``.

    Returns:
        The ``dataclasses.Field`` built by ``dataclasses.field``.
    """

    def _return_default():
        return default

    return field(default_factory=_return_default, metadata=metadata)
@dataclass
class PlotArguments:
    """
    Command-line options selecting the benchmark csv file to plot and controlling how the plot is drawn.

    The class must be a dataclass: the ``field(...)`` declarations below only become constructor arguments with
    defaults once ``@dataclass`` processes them.
    """

    # Path of the csv file to read; the only option without a default.
    csv_file: str = field(
        metadata={"help": "The csv file to plot."},
    )
    # x axis selection: batch size when True, sequence length otherwise.
    plot_along_batch: bool = field(
        default=False,
        metadata={"help": "Whether to plot along batch size or sequence length. Defaults to sequence length."},
    )
    # Kind of measurement stored in the csv file: time when True, memory otherwise.
    is_time: bool = field(
        default=False,
        metadata={"help": "Whether the csv file has time results or memory results. Defaults to memory results."},
    )
    # Logarithmic axes are used unless this flag is set.
    no_log_scale: bool = field(
        default=False,
        metadata={"help": "Disable logarithmic scale when plotting"},
    )
    # Benchmark mode the results come from: training when True, inference otherwise.
    is_train: bool = field(
        default=False,
        metadata={
            "help": "Whether the csv file has training results or inference results. Defaults to inference results."
        },
    )
    # Output image path; None means no file is written.
    figure_png_file: Optional[str] = field(
        default=None,
        metadata={"help": "Filename under which the plot will be saved. If unused no plot is saved."},
    )
    # Replacement labels for the model names found in the csv file.
    short_model_names: Optional[List[str]] = list_field(
        default=None, metadata={"help": "List of model names that are used instead of the ones in the csv file."}
    )
def can_convert_to_int(string):
    """
    Tell whether ``int(string)`` succeeds.

    Returns:
        ``True`` when the conversion works, ``False`` when it raises ``ValueError``. Any other exception raised by
        ``int`` is not handled here.
    """
    try:
        int(string)
    except ValueError:
        return False
    else:
        return True
def can_convert_to_float(string):
    """
    Tell whether ``float(string)`` succeeds.

    Returns:
        ``True`` when the conversion works, ``False`` when it raises ``ValueError``. Any other exception raised by
        ``float`` is not handled here.
    """
    try:
        float(string)
    except ValueError:
        return False
    else:
        return True
class Plot:
    """
    Reads benchmark results from a csv file and draws one curve per model and inner-loop value.

    The csv file must provide the columns ``model``, ``batch_size``, ``sequence_length`` and ``result``.
    """

    def __init__(self, args):
        """
        Load every row of ``args.csv_file`` into ``self.result_dict``.

        ``self.result_dict`` maps a model name to a dict holding the batch sizes seen (``"bsz"``), the sequence
        lengths seen (``"seq_len"``) and the measurements (``"result"``) keyed by
        ``(batch_size, sequence_length)``. Rows whose ``result`` cell is not a number still contribute their batch
        size and sequence length but add no entry to ``"result"``.

        Args:
            args: Object exposing the plotting options as attributes; only ``csv_file`` is read here.

        Raises:
            ValueError: If a ``batch_size`` or ``sequence_length`` cell is not an integer.
            KeyError: If one of the required columns is missing from the csv header.
        """
        self.args = args
        self.result_dict = defaultdict(lambda: {"bsz": [], "seq_len": [], "result": {}})

        with open(self.args.csv_file, newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                model_results = self.result_dict[row["model"]]
                batch_size = int(row["batch_size"])
                sequence_length = int(row["sequence_length"])
                model_results["bsz"].append(batch_size)
                model_results["seq_len"].append(sequence_length)

                value = self._parse_result(row["result"])
                if value is not None:
                    model_results["result"][(batch_size, sequence_length)] = value

    @staticmethod
    def _parse_result(raw):
        """Return ``raw`` as an ``int`` when possible, else as a ``float``, else ``None``."""
        for convert in (int, float):
            try:
                return convert(raw)
            except (TypeError, ValueError):
                # TypeError covers a missing cell (``None``), ValueError a non-numeric string.
                continue
        return None

    @staticmethod
    def _series_for(results, x_values, inner_value, along_batch):
        """
        Build the aligned x and y arrays of one curve, leaving out x values that have no measurement.

        Args:
            results: Measurements keyed by ``(batch_size, sequence_length)``.
            x_values: Candidate x axis values, in plotting order.
            inner_value: The fixed batch size or sequence length of this curve.
            along_batch: ``True`` when x is the batch size, ``False`` when x is the sequence length.

        Returns:
            A pair ``(xs, ys)`` of equally long arrays; ``xs`` has an integer dtype and ``ys`` is ``float32``.
        """
        xs, ys = [], []
        for x in x_values:
            key = (x, inner_value) if along_batch else (inner_value, x)
            if key in results:
                # Keep x and y together so a gap in the middle cannot shift later points.
                xs.append(x)
                ys.append(results[key])
        # A float dtype on both paths keeps fractional time results from being truncated.
        return np.asarray(xs, dtype=int), np.asarray(ys, dtype=np.float32)

    def plot(self):
        """
        Draw all loaded results on a single figure, then save or show it.

        The figure is written to ``args.figure_png_file`` when that attribute is not ``None``; otherwise it is
        displayed with ``plt.show()``. Names in ``args.short_model_names`` replace the csv model names by position;
        models beyond the end of that list keep their csv name.
        """
        fig, ax = plt.subplots()
        along_batch = self.args.plot_along_batch
        short_names = self.args.short_model_names

        title_str = "Time usage" if self.args.is_time else "Memory usage"
        title_str += " for training" if self.args.is_train else " for inference"

        if not self.args.no_log_scale:
            # set logarithm scales
            ax.set_xscale("log")
            ax.set_yscale("log")

        for axis in [ax.xaxis, ax.yaxis]:
            axis.set_major_formatter(ScalarFormatter())

        # The labels depend only on the arguments, so they are set even when the csv file has no rows.
        x_axis_label, inner_loop_label = ("batch_size", "len") if along_batch else ("in #tokens", "bsz")
        y_axis_label = "Time in s" if self.args.is_time else "Memory in MB"

        label_model_names = []
        for model_name_idx, (model_name, model_results) in enumerate(self.result_dict.items()):
            batch_sizes = sorted(set(model_results["bsz"]))
            sequence_lengths = sorted(set(model_results["seq_len"]))
            x_values, inner_values = (
                (batch_sizes, sequence_lengths) if along_batch else (sequence_lengths, batch_sizes)
            )

            if short_names is not None and model_name_idx < len(short_names):
                label_model_name = short_names[model_name_idx]
            else:
                label_model_name = model_name
            label_model_names.append(label_model_name)

            for inner_value in inner_values:
                xs, ys = self._series_for(model_results["result"], x_values, inner_value, along_batch)
                ax.scatter(xs, ys, label=f"{label_model_name} - {inner_loop_label}: {inner_value}")
                ax.plot(xs, ys, "--")

        if label_model_names:
            title_str += " " + " vs. ".join(label_model_names)

        # plot
        ax.set_title(title_str)
        ax.set_xlabel(x_axis_label)
        ax.set_ylabel(y_axis_label)
        ax.legend()

        if self.args.figure_png_file is not None:
            fig.savefig(self.args.figure_png_file)
        else:
            plt.show()
def main():
    """Parse the command line into ``PlotArguments``, build a ``Plot`` from it and draw it."""
    parsed_dataclasses = HfArgumentParser(PlotArguments).parse_args_into_dataclasses()
    # ``PlotArguments`` is the only dataclass given to the parser, so its instance is the first item.
    Plot(args=parsed_dataclasses[0]).plot()


# Run the plotter only when the file is executed as a script.
if __name__ == "__main__":
    main()