import os
import time
import json
import math
import copy
import collections
from typing import Optional, List, Dict, Tuple, Callable, Any, Union, NewType

import numpy as np
from tqdm import tqdm

import datasets
from transformers import AutoTokenizer
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from transformers.utils import logging
from transformers.trainer_utils import EvalLoopOutput, EvalPrediction

from .args import (
    HfArgumentParser,
    RetroArguments,
    TrainingArguments,
)
from .base import BaseReader
from . import constants as C
from .preprocess import (
    get_sketch_features,
    get_intensive_features,
)
from .metrics import (
    compute_classification_metric,
    compute_squad_v2,
)

DataClassType = NewType("DataClassType", Any)
logger = logging.get_logger(__name__)

class SketchReader(BaseReader):
    name: str = "sketch"

    def postprocess(
        self,
        output: Union[np.ndarray, EvalLoopOutput],
        eval_examples: datasets.Dataset,
        eval_dataset: datasets.Dataset,
        mode: str = "evaluate",
    ) -> Union[EvalPrediction, Dict[str, float]]:
        # External Front Verification (E-FV)
        if isinstance(output, EvalLoopOutput):
            logits = output.predictions
        else:
            logits = output
        example_id_to_index = {k: i for i, k in enumerate(eval_examples[C.ID_COLUMN_NAME])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(eval_dataset):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)
        count_map = {k: len(v) for k, v in features_per_example.items()}

        logits_ans = np.zeros(len(count_map))
        logits_na = np.zeros(len(count_map))
        for example_index, example in enumerate(tqdm(eval_examples)):
            # Average the answerable/unanswerable logits over all strided
            # features belonging to this example.
            feature_indices = features_per_example[example_index]
            n_strides = count_map[example_index]
            for feature_index in feature_indices:
                logits_ans[example_index] += logits[feature_index, 0] / n_strides
                logits_na[example_index] += logits[feature_index, 1] / n_strides

        # Calculate E-FV score
        score_ext = logits_ans - logits_na

        # Save external front verification score
        final_map = dict(zip(eval_examples[C.ID_COLUMN_NAME], score_ext.tolist()))
        with open(os.path.join(self.args.output_dir, C.SCORE_EXT_FILE_NAME), "w") as writer:
            writer.write(json.dumps(final_map, indent=4) + "\n")

        if mode == "evaluate":
            return EvalPrediction(
                predictions=logits, label_ids=output.label_ids,
            )
        else:
            return final_map
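
# --- Illustrative sketch (not part of the original module) ----------------
# E-FV reduces to: average the "answerable" and "no-answer" logits over the
# strided features of each example, then take their difference.  A minimal
# standalone version of that reduction, assuming `feature_logits` has shape
# (n_features, 2) and `feature_to_example[i]` gives the example index of
# feature row `i` (both names are illustrative, not from the module above):
def _efv_score_sketch(feature_logits, feature_to_example, n_examples):
    logits_ans = np.zeros(n_examples)
    logits_na = np.zeros(n_examples)
    counts = np.zeros(n_examples)
    for row, example_index in enumerate(feature_to_example):
        logits_ans[example_index] += feature_logits[row, 0]
        logits_na[example_index] += feature_logits[row, 1]
        counts[example_index] += 1
    # score_ext > 0 leans answerable, score_ext < 0 leans unanswerable.
    return logits_ans / counts - logits_na / counts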

class IntensiveReader(BaseReader):
    name: str = "intensive"

    def postprocess(
        self,
        output: EvalLoopOutput,
        eval_examples: datasets.Dataset,
        eval_dataset: datasets.Dataset,
        log_level: int = logging.WARNING,
        mode: str = "evaluate",
    ) -> Union[List[Dict[str, Any]], EvalPrediction]:
        # Internal Front Verification (I-FV)
        # Verification is already done inside the model.
        # Post-processing: we match the start logits and end logits to answers in the original context.
        predictions, nbest_json, scores_diff_json = self.compute_predictions(
            eval_examples,
            eval_dataset,
            output.predictions,
            version_2_with_negative=self.data_args.version_2_with_negative,
            n_best_size=self.data_args.n_best_size,
            max_answer_length=self.data_args.max_answer_length,
            null_score_diff_threshold=self.data_args.null_score_diff_threshold,
            output_dir=self.args.output_dir,
            log_level=log_level,
            n_tops=(self.data_args.start_n_top, self.data_args.end_n_top),
        )
        if mode == "retro_inference":
            return nbest_json, scores_diff_json
        # Format the result to the format the metric expects.
        if self.data_args.version_2_with_negative:
            formatted_predictions = [
                {"id": k, "prediction_text": v, "no_answer_probability": scores_diff_json[k]}
                for k, v in predictions.items()
            ]
        else:
            formatted_predictions = [
                {"id": k, "prediction_text": v}
                for k, v in predictions.items()
            ]
        if mode == "predict":
            return formatted_predictions
        else:
            references = [
                {"id": ex[C.ID_COLUMN_NAME], "answers": ex[C.ANSWER_COLUMN_NAME]}
                for ex in eval_examples
            ]
            return EvalPrediction(
                predictions=formatted_predictions, label_ids=references,
            )

    def compute_predictions(
        self,
        examples: datasets.Dataset,
        features: datasets.Dataset,
        predictions: Tuple[np.ndarray, np.ndarray],
        version_2_with_negative: bool = False,
        n_best_size: int = 20,
        max_answer_length: int = 30,
        null_score_diff_threshold: float = 0.0,
        output_dir: Optional[str] = None,
        log_level: Optional[int] = logging.WARNING,
        n_tops: Tuple[int, int] = (-1, -1),
        use_choice_logits: bool = False,
    ):
        # Threshold-based Answerable Verification (TAV)
        if len(predictions) not in [2, 3]:
            raise ValueError(
                "`predictions` should be a tuple with two or three elements "
                "(start_logits, end_logits, choice_logits)."
            )
        all_start_logits, all_end_logits = predictions[:2]
        all_choice_logits = None
        if len(predictions) == 3:
            all_choice_logits = predictions[-1]

        # Build a map from each example to its corresponding features.
        example_id_to_index = {k: i for i, k in enumerate(examples[C.ID_COLUMN_NAME])}
        features_per_example = collections.defaultdict(list)
        for i, feature in enumerate(features):
            features_per_example[example_id_to_index[feature["example_id"]]].append(i)

        all_predictions = collections.OrderedDict()
        all_nbest_json = collections.OrderedDict()
        scores_diff_json = collections.OrderedDict() if version_2_with_negative else None

        # Logging.
        logger.setLevel(log_level)
        logger.info(f"Post-processing {len(examples)} example predictions split into {len(features)} features.")

        # Let's loop over all the examples!
        for example_index, example in enumerate(tqdm(examples)):
            # Those are the indices of the features associated with the current example.
            feature_indices = features_per_example[example_index]
            min_null_prediction = None
            prelim_predictions = []

            # Loop through all the features associated with the current example.
            for feature_index in feature_indices:
                # We grab the predictions of the model for this feature.
                start_logits = all_start_logits[feature_index]
                end_logits = all_end_logits[feature_index]
                # score_null = s1 + e1
                feature_null_score = start_logits[0] + end_logits[0]
                if all_choice_logits is not None:
                    choice_logits = all_choice_logits[feature_index]
                    if use_choice_logits:
                        feature_null_score = choice_logits[1]
                # This is what will allow us to map some of the positions
                # in our logits to spans of text in the original context.
                offset_mapping = features[feature_index]["offset_mapping"]
                # Optional `token_is_max_context`: if provided, we will remove answers
                # that do not have the maximum context available in the current feature.
                token_is_max_context = features[feature_index].get("token_is_max_context", None)

                # Update minimum null prediction.
                if (
                    min_null_prediction is None
                    or min_null_prediction["score"] > feature_null_score
                ):
                    min_null_prediction = {
                        "offsets": (0, 0),
                        "score": feature_null_score,
                        "start_logit": start_logits[0],
                        "end_logit": end_logits[0],
                    }

                # Go through all possibilities for the `n_best_size` greatest start and end logits.
                start_indexes = np.argsort(start_logits)[-1 : -n_best_size - 1 : -1].tolist()
                end_indexes = np.argsort(end_logits)[-1 : -n_best_size - 1 : -1].tolist()
                for start_index in start_indexes:
                    for end_index in end_indexes:
                        # Don't consider out-of-scope answers, either because the indices
                        # are out of bounds or correspond to parts of the input_ids that
                        # are not in the context.
                        if (
                            start_index >= len(offset_mapping)
                            or end_index >= len(offset_mapping)
                            or not offset_mapping[start_index]
                            or not offset_mapping[end_index]
                        ):
                            continue
                        # Don't consider answers with a negative length or a length > max_answer_length.
                        if end_index < start_index or end_index - start_index + 1 > max_answer_length:
                            continue
                        # Don't consider answers that don't have the maximum context available
                        # (if such information is provided).
                        if token_is_max_context is not None and not token_is_max_context.get(str(start_index), False):
                            continue
                        prelim_predictions.append(
                            {
                                "offsets": (offset_mapping[start_index][0], offset_mapping[end_index][1]),
                                "score": start_logits[start_index] + end_logits[end_index],
                                "start_logit": start_logits[start_index],
                                "end_logit": end_logits[end_index],
                            }
                        )

            if version_2_with_negative:
                # Add the minimum null prediction.
                prelim_predictions.append(min_null_prediction)
                null_score = min_null_prediction["score"]

            # Only keep the best `n_best_size` predictions.
            predictions = sorted(prelim_predictions, key=lambda x: x["score"], reverse=True)[:n_best_size]

            # Add back the minimum null prediction if it was removed because of its low score.
            if version_2_with_negative and not any(p["offsets"] == (0, 0) for p in predictions):
                predictions.append(min_null_prediction)

            # Use the offsets to gather the answer text in the original context.
            context = example["context"]
            for pred in predictions:
                offsets = pred.pop("offsets")
                pred["text"] = context[offsets[0] : offsets[1]]

            # In the very rare edge case where there is not a single non-null prediction,
            # we create a fake prediction to avoid failure.
            if len(predictions) == 0 or (len(predictions) == 1 and predictions[0]["text"] == ""):
                predictions.insert(0, {"text": "", "start_logit": 0.0, "end_logit": 0.0, "score": 0.0})

            # Compute the softmax of all scores (we do it with numpy to stay independent
            # from torch/tf in this file, using the log-sum-exp trick).
            scores = np.array([pred.pop("score") for pred in predictions])
            exp_scores = np.exp(scores - np.max(scores))
            probs = exp_scores / exp_scores.sum()

            # Include the probabilities in our predictions.
            for prob, pred in zip(probs, predictions):
                pred["probability"] = prob

            # Pick the best prediction. If the null answer is not possible, this is easy.
            if not version_2_with_negative:
                all_predictions[example[C.ID_COLUMN_NAME]] = predictions[0]["text"]
            else:
                # Otherwise we first need to find the best non-empty prediction.
                i = 0
                try:
                    while predictions[i]["text"] == "":
                        i += 1
                except IndexError:
                    i = 0
                best_non_null_pred = predictions[i]

                # Then we compare it to the null prediction using the threshold.
                score_diff = null_score - best_non_null_pred["start_logit"] - best_non_null_pred["end_logit"]
                scores_diff_json[example[C.ID_COLUMN_NAME]] = float(score_diff)  # To be JSON-serializable.
                if score_diff > null_score_diff_threshold:
                    all_predictions[example[C.ID_COLUMN_NAME]] = ""
                else:
                    all_predictions[example[C.ID_COLUMN_NAME]] = best_non_null_pred["text"]

            # Make `predictions` JSON-serializable by casting numpy floats back to float.
            all_nbest_json[example[C.ID_COLUMN_NAME]] = [
                {k: (float(v) if isinstance(v, (np.float16, np.float32, np.float64)) else v) for k, v in pred.items()}
                for pred in predictions
            ]

        # If we have an output_dir, let's save all those dicts.
        if output_dir is not None:
            if not os.path.isdir(output_dir):
                raise EnvironmentError(f"{output_dir} is not a directory.")
            prediction_file = os.path.join(output_dir, C.INTENSIVE_PRED_FILE_NAME)
            nbest_file = os.path.join(output_dir, C.NBEST_PRED_FILE_NAME)
            if version_2_with_negative:
                null_odds_file = os.path.join(output_dir, C.SCORE_DIFF_FILE_NAME)

            logger.info(f"Saving predictions to {prediction_file}.")
            with open(prediction_file, "w") as writer:
                writer.write(json.dumps(all_predictions, indent=4) + "\n")
            logger.info(f"Saving nbest_preds to {nbest_file}.")
            with open(nbest_file, "w") as writer:
                writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
            if version_2_with_negative:
                logger.info(f"Saving null_odds to {null_odds_file}.")
                with open(null_odds_file, "w") as writer:
                    writer.write(json.dumps(scores_diff_json, indent=4) + "\n")

        return all_predictions, all_nbest_json, scores_diff_json
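
# --- Illustrative sketch (not part of the original module) ----------------
# The Threshold-based Answerable Verification (TAV) step above boils down to
# one comparison per example: keep the best non-null span unless the null
# score beats it by more than `null_score_diff_threshold`.  The function name
# and arguments below are illustrative only.
def _tav_decision_sketch(null_score, best_start_logit, best_end_logit,
                         best_text, null_score_diff_threshold=0.0):
    score_diff = null_score - best_start_logit - best_end_logit
    return "" if score_diff > null_score_diff_threshold else best_text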

class RearVerifier:

    def __init__(
        self,
        beta1: int = 1,
        beta2: int = 1,
        best_cof: int = 1,
        thresh: float = 0.0,
    ):
        self.beta1 = beta1
        self.beta2 = beta2
        self.best_cof = best_cof
        self.thresh = thresh

    def __call__(
        self,
        score_ext: Dict[str, float],
        score_diff: Dict[str, float],
        nbest_preds: Dict[str, List[Dict[str, Any]]],
    ):
        all_scores = collections.OrderedDict()
        assert score_ext.keys() == score_diff.keys()
        for key in score_ext.keys():
            if key not in all_scores:
                all_scores[key] = []
            all_scores[key].extend(
                [self.beta1 * score_ext[key],
                 self.beta2 * score_diff[key]]
            )
        output_scores = {}
        for key, scores in all_scores.items():
            mean_score = sum(scores) / float(len(scores))
            output_scores[key] = mean_score

        all_nbest = collections.OrderedDict()
        for key, entries in nbest_preds.items():
            if key not in all_nbest:
                all_nbest[key] = collections.defaultdict(float)
            for entry in entries:
                prob = self.best_cof * entry["probability"]
                all_nbest[key][entry["text"]] += prob

        output_predictions = {}
        for key, entry_map in all_nbest.items():
            sorted_texts = sorted(
                entry_map.keys(), key=lambda x: entry_map[x], reverse=True
            )
            best_text = sorted_texts[0]
            output_predictions[key] = best_text

        for qid in output_predictions.keys():
            if output_scores[qid] > self.thresh:
                output_predictions[qid] = ""

        return output_predictions, output_scores
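
# --- Illustrative usage (not part of the original module) -----------------
# The rear verifier averages the weighted sketch score (beta1 * score_ext)
# and intensive score (beta2 * score_diff) per question, sums n-best
# probabilities per answer text, and blanks out answers whose combined score
# exceeds `thresh`.  A toy call with hypothetical inputs:
def _rear_verifier_usage_sketch():
    verifier = RearVerifier(beta1=1, beta2=1, best_cof=1, thresh=0.0)
    preds, scores = verifier(
        score_ext={"q1": -2.0},
        score_diff={"q1": -1.5},
        nbest_preds={"q1": [{"text": "Denver", "probability": 0.8},
                            {"text": "Denver, CO", "probability": 0.2}]},
    )
    # -> preds == {"q1": "Denver"}, scores == {"q1": -1.75}
    return preds, scores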

class RetroReader:

    def __init__(
        self,
        args,
        sketch_reader: SketchReader,
        intensive_reader: IntensiveReader,
        rear_verifier: RearVerifier,
        prep_fn: Tuple[Callable, Callable],
    ):
        self.args = args
        # Set submodules
        self.sketch_reader = sketch_reader
        self.intensive_reader = intensive_reader
        self.rear_verifier = rear_verifier
        # Set preprocessing functions for inference
        self.sketch_prep_fn, self.intensive_prep_fn = prep_fn

    @classmethod
    def load(
        cls,
        train_examples=None,
        sketch_train_dataset=None,
        intensive_train_dataset=None,
        eval_examples=None,
        sketch_eval_dataset=None,
        intensive_eval_dataset=None,
        config_file: str = C.DEFAULT_CONFIG_FILE,
    ):
        # Get arguments from the yaml file
        parser = HfArgumentParser([RetroArguments, TrainingArguments])
        retro_args, training_args = parser.parse_yaml_file(yaml_file=config_file)
        if training_args.run_name is not None and "," in training_args.run_name:
            sketch_run_name, intensive_run_name = training_args.run_name.split(",")
        else:
            sketch_run_name, intensive_run_name = None, None
        if training_args.metric_for_best_model is not None and "," in training_args.metric_for_best_model:
            sketch_best_metric, intensive_best_metric = training_args.metric_for_best_model.split(",")
        else:
            sketch_best_metric, intensive_best_metric = None, None
        sketch_training_args = copy.deepcopy(training_args)
        intensive_training_args = training_args

        sketch_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=retro_args.sketch_tokenizer_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.sketch_revision,
        )

        # If `train_examples` is fed, perform preprocessing
        if train_examples is not None and sketch_train_dataset is None:
            sketch_prep_fn, is_batched = get_sketch_features(sketch_tokenizer, "train", retro_args)
            sketch_train_dataset = train_examples.map(
                sketch_prep_fn,
                batched=is_batched,
                remove_columns=train_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )
        # If `eval_examples` is fed, perform preprocessing
        if eval_examples is not None and sketch_eval_dataset is None:
            sketch_prep_fn, is_batched = get_sketch_features(sketch_tokenizer, "eval", retro_args)
            sketch_eval_dataset = eval_examples.map(
                sketch_prep_fn,
                batched=is_batched,
                remove_columns=eval_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )
        # Get preprocessing function for inference
        sketch_prep_fn, _ = get_sketch_features(sketch_tokenizer, "test", retro_args)

        # Get model for sketch reader
        sketch_model_cls = retro_args.sketch_model_cls
        sketch_model = sketch_model_cls.from_pretrained(
            pretrained_model_name_or_path=retro_args.sketch_model_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.sketch_revision,
        )

        # Get sketch reader
        sketch_training_args.run_name = sketch_run_name
        sketch_training_args.output_dir += "/sketch"
        sketch_training_args.metric_for_best_model = sketch_best_metric
        sketch_reader = SketchReader(
            model=sketch_model,
            args=sketch_training_args,
            train_dataset=sketch_train_dataset,
            eval_dataset=sketch_eval_dataset,
            eval_examples=eval_examples,
            data_args=retro_args,
            tokenizer=sketch_tokenizer,
            compute_metrics=compute_classification_metric,
        )

        intensive_tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path=retro_args.intensive_tokenizer_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.intensive_revision,
        )

        # If `train_examples` is fed, perform preprocessing
        if train_examples is not None and intensive_train_dataset is None:
            intensive_prep_fn, is_batched = get_intensive_features(intensive_tokenizer, "train", retro_args)
            intensive_train_dataset = train_examples.map(
                intensive_prep_fn,
                batched=is_batched,
                remove_columns=train_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )
        # If `eval_examples` is fed, perform preprocessing
        if eval_examples is not None and intensive_eval_dataset is None:
            intensive_prep_fn, is_batched = get_intensive_features(intensive_tokenizer, "eval", retro_args)
            intensive_eval_dataset = eval_examples.map(
                intensive_prep_fn,
                batched=is_batched,
                remove_columns=eval_examples.column_names,
                num_proc=retro_args.preprocessing_num_workers,
                load_from_cache_file=not retro_args.overwrite_cache,
            )
        # Get preprocessing function for inference
        intensive_prep_fn, _ = get_intensive_features(intensive_tokenizer, "test", retro_args)

        # Get model for intensive reader
        intensive_model_cls = retro_args.intensive_model_cls
        intensive_model = intensive_model_cls.from_pretrained(
            pretrained_model_name_or_path=retro_args.intensive_model_name,
            use_auth_token=retro_args.use_auth_token,
            revision=retro_args.intensive_revision,
        )

        # Get intensive reader
        intensive_training_args.run_name = intensive_run_name
        intensive_training_args.output_dir += "/intensive"
        intensive_training_args.metric_for_best_model = intensive_best_metric
        intensive_reader = IntensiveReader(
            model=intensive_model,
            args=intensive_training_args,
            train_dataset=intensive_train_dataset,
            eval_dataset=intensive_eval_dataset,
            eval_examples=eval_examples,
            data_args=retro_args,
            tokenizer=intensive_tokenizer,
            compute_metrics=compute_squad_v2,
        )

        # Get rear verifier
        rear_verifier = RearVerifier(
            beta1=retro_args.beta1,
            beta2=retro_args.beta2,
            best_cof=retro_args.best_cof,
            thresh=retro_args.rear_threshold,
        )

        return cls(
            args=retro_args,
            sketch_reader=sketch_reader,
            intensive_reader=intensive_reader,
            rear_verifier=rear_verifier,
            prep_fn=(sketch_prep_fn, intensive_prep_fn),
        )

    def __call__(
        self,
        query: str,
        context: Union[str, List[str]],
        return_submodule_outputs: bool = False,
    ) -> Tuple[Any]:
        if isinstance(context, list):
            context = " ".join(context)
        predict_examples = datasets.Dataset.from_dict({
            "example_id": ["0"],
            C.ID_COLUMN_NAME: ["id-01"],
            C.QUESTION_COLUMN_NAME: [query],
            C.CONTEXT_COLUMN_NAME: [context],
        })
        return self.inference(predict_examples)

    def train(self, module: str = "all"):

        def wandb_finish(module):
            for callback in module.callback_handler.callbacks:
                if "wandb" in str(type(callback)).lower():
                    callback._wandb.finish()
                    callback._initialized = False

        # Train sketch reader
        if module.lower() in ["all", "sketch"]:
            self.sketch_reader.train()
            self.sketch_reader.save_model()
            self.sketch_reader.save_state()
            self.sketch_reader.free_memory()
            wandb_finish(self.sketch_reader)
        # Train intensive reader
        if module.lower() in ["all", "intensive"]:
            self.intensive_reader.train()
            self.intensive_reader.save_model()
            self.intensive_reader.save_state()
            self.intensive_reader.free_memory()
            wandb_finish(self.intensive_reader)

    def inference(self, predict_examples: datasets.Dataset) -> Tuple[Any]:
        # Add an `example_id` column if it is missing so the readers can map
        # strided features back to their source example.
        if "example_id" not in predict_examples.column_names:
            predict_examples = predict_examples.map(
                lambda _, i: {"example_id": str(i)},
                with_indices=True,
            )
        sketch_features = predict_examples.map(
            self.sketch_prep_fn,
            batched=True,
            remove_columns=predict_examples.column_names,
        )
        intensive_features = predict_examples.map(
            self.intensive_prep_fn,
            batched=True,
            remove_columns=predict_examples.column_names,
        )
        # self.sketch_reader.to(self.sketch_reader.args.device)
        score_ext = self.sketch_reader.predict(sketch_features, predict_examples)
        # self.sketch_reader.to("cpu")
        # self.intensive_reader.to(self.intensive_reader.args.device)
        nbest_preds, score_diff = self.intensive_reader.predict(
            intensive_features, predict_examples, mode="retro_inference")
        # self.intensive_reader.to("cpu")
        predictions, scores = self.rear_verifier(score_ext, score_diff, nbest_preds)
        outputs = (predictions, scores)
        # if self.return_submodule_outputs:
        #     outputs += (score_ext, nbest_preds, score_diff)
        return outputs

    @property
    def null_score_diff_threshold(self):
        return self.args.null_score_diff_threshold

    @null_score_diff_threshold.setter
    def null_score_diff_threshold(self, val):
        self.args.null_score_diff_threshold = val

    @property
    def n_best_size(self):
        return self.args.n_best_size

    @n_best_size.setter
    def n_best_size(self, val):
        self.args.n_best_size = val

    @property
    def beta1(self):
        return self.rear_verifier.beta1

    @beta1.setter
    def beta1(self, val):
        self.rear_verifier.beta1 = val

    @property
    def beta2(self):
        return self.rear_verifier.beta2

    @beta2.setter
    def beta2(self, val):
        self.rear_verifier.beta2 = val

    @property
    def best_cof(self):
        return self.rear_verifier.best_cof

    @best_cof.setter
    def best_cof(self, val):
        self.rear_verifier.best_cof = val

    @property
    def rear_threshold(self):
        return self.rear_verifier.thresh

    @rear_threshold.setter
    def rear_threshold(self, val):
        self.rear_verifier.thresh = val
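
# --- Illustrative end-to-end usage (not part of the original module) ------
# A minimal sketch of the intended inference flow, assuming the default yaml
# config and the checkpoints it references are actually available.
# `RetroReader.load` builds both readers plus the rear verifier; calling the
# instance runs sketch prediction (E-FV), intensive prediction (I-FV / TAV),
# and rear verification in sequence.  Query and context are made-up inputs.
def _retro_reader_usage_sketch():
    reader = RetroReader.load(config_file=C.DEFAULT_CONFIG_FILE)
    predictions, scores = reader(
        query="When did Beethoven die?",
        context="Ludwig van Beethoven died on 26 March 1827 in Vienna.",
    )
    # `predictions` maps the question id to the final (possibly empty) answer
    # string; `scores` maps it to the combined verification score.
    return predictions, scores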