Spaces:
Runtime error
Runtime error
| import dataclasses | |
| import re | |
| import copy | |
| import yaml | |
| import argparse | |
| from pathlib import Path | |
| from dataclasses import dataclass, field | |
| from typing import Any, Iterable, List, NewType, Optional, Tuple, Union, Dict | |
| from transformers.hf_argparser import HfArgumentParser as ArgumentParser | |
# Nominal aliases over `Any`, used purely for annotations: `DataClass` marks a
# dataclass instance (see the return annotation of `parse_yaml_file`).
# NOTE(review): `DataClassType` is not referenced in the visible part of this
# file; presumably it marks a dataclass *class* -- confirm against callers.
| DataClass = NewType("DataClass", Any) | |
| DataClassType = NewType("DataClassType", Any) | |
def lambda_field(default, **kwargs):
    """Return a dataclass field whose default is a fresh copy of *default*.

    Dataclasses reject mutable defaults such as lists or dicts; wrapping the
    value in a ``default_factory`` that copies it gives every instance its own
    object.

    Args:
        default: The value to copy for each new instance. The copy is shallow
            (``copy.copy``), so nested containers are still shared.
        **kwargs: Extra keyword arguments forwarded to ``dataclasses.field``
            (e.g. ``metadata``, ``init``, ``repr``). Previously these were
            accepted but silently discarded.

    Returns:
        The ``dataclasses.Field`` object produced by ``dataclasses.field``.
    """
    return field(default_factory=lambda: copy.copy(default), **kwargs)
class HfArgumentParser(ArgumentParser):
    """Argument parser that can also populate its dataclasses from a YAML file."""

    # Broader float pattern than PyYAML's default so values such as ``5e-6``
    # load as floats instead of strings. See
    # https://stackoverflow.com/questions/30458977/yaml-loads-5e-6-as-string-and-not-a-number
    _YAML_FLOAT_PATTERN = re.compile(
        r"""^(?:
         [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
        |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
        |\.[0-9_]+(?:[eE][-+][0-9]+)?
        |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
        |[-+]?\.(?:inf|Inf|INF)
        |\.(?:nan|NaN|NAN))$""",
        re.X,
    )

    @classmethod
    def _make_yaml_loader(cls):
        """Build a ``SafeLoader`` subclass carrying the extra float resolver.

        The resolver is registered on a private subclass rather than on
        ``yaml.SafeLoader`` itself, so other users of PyYAML in the same
        process are unaffected and repeated calls do not keep appending
        duplicate resolvers to the shared class.
        """

        class _FloatAwareSafeLoader(yaml.SafeLoader):
            pass

        _FloatAwareSafeLoader.add_implicit_resolver(
            "tag:yaml.org,2002:float",
            cls._YAML_FLOAT_PATTERN,
            list("-+0123456789."),
        )
        return _FloatAwareSafeLoader

    def parse_yaml_file(self, yaml_file: str) -> Tuple[DataClass, ...]:
        """
        Alternative helper method that does not use `argparse` at all, instead loading a yaml file and populating the
        dataclass types.

        Args:
            yaml_file: Path of the YAML file to read.

        Returns:
            A tuple with one instance per entry of `self.dataclass_types`, in the same order.

        Raises:
            KeyError: If the YAML document has no top-level section for one of the dataclass types.
        """
        data = yaml.load(Path(yaml_file).read_text(), Loader=self._make_yaml_loader())

        outputs = []
        for dtype in self.dataclass_types:
            # Only fields that `__init__` accepts may be passed to the constructor.
            init_names = {f.name for f in dataclasses.fields(dtype) if f.init}
            # The YAML section is keyed by the name of the class just below
            # `object` in the MRO, so subclasses read their root base's section.
            section_name = dtype.__mro__[-2].__name__
            section = data[section_name]
            # Keys in the section that are not init fields are ignored.
            inputs = {k: v for k, v in section.items() if k in init_names}
            outputs.append(dtype(**inputs))
        return tuple(outputs)