|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" BERTScore metric. """ |
|
|
|
|
|
import functools
from contextlib import contextmanager

import bert_score
import datasets
from packaging import version

import evaluate
|
|
|
|
|
|
|
|
@contextmanager
def filter_logging_context():
    """Temporarily hide the expected weight-initialization notice.

    While the context is active, a filter is attached to the logger named
    ``transformers.modeling_utils`` (obtained through
    ``datasets.utils.logging.get_logger``). The filter drops every record
    whose message contains ``"This IS expected if you are initializing"`` and
    lets all other records through. The filter is removed again when the
    context exits, including when the body raises.
    """

    def filter_log(record):
        # ``record.msg`` may be any object (e.g. an exception instance);
        # str() keeps the substring test from raising TypeError inside the
        # logging machinery for such records.
        return "This IS expected if you are initializing" not in str(record.msg)

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)
    try:
        yield
    finally:
        # Always detach the filter so it does not leak past the context.
        logger.removeFilter(filter_log)
|
|
|
|
|
|
|
|
_CITATION = """\ |
|
|
@inproceedings{bert-score, |
|
|
title={BERTScore: Evaluating Text Generation with BERT}, |
|
|
author={Tianyi Zhang* and Varsha Kishore* and Felix Wu* and Kilian Q. Weinberger and Yoav Artzi}, |
|
|
booktitle={International Conference on Learning Representations}, |
|
|
year={2020}, |
|
|
url={https://openreview.net/forum?id=SkeHuCVFDr} |
|
|
} |
|
|
""" |
|
|
|
|
|
_DESCRIPTION = """\ |
|
|
BERTScore leverages the pre-trained contextual embeddings from BERT and matches words in candidate and reference |
|
|
sentences by cosine similarity. |
|
|
It has been shown to correlate with human judgment on sentence-level and system-level evaluation. |
|
|
Moreover, BERTScore computes precision, recall, and F1 measure, which can be useful for evaluating different language |
|
|
generation tasks. |
|
|
|
|
|
See the project's README at https://github.com/Tiiiger/bert_score#readme for more information. |
|
|
""" |
|
|
|
|
|
_KWARGS_DESCRIPTION = """ |
|
|
BERTScore Metrics with the hashcode from a source against one or more references. |
|
|
|
|
|
Args: |
|
|
predictions (list of str): Prediction/candidate sentences. |
|
|
references (list of str or list of list of str): Reference sentences. |
|
|
lang (str): Language of the sentences; required (e.g. 'en'). |
|
|
model_type (str): Bert specification, default using the suggested |
|
|
model for the target language; has to specify at least one of |
|
|
`model_type` or `lang`. |
|
|
num_layers (int): The layer of representation to use, |
|
|
default using the number of layers tuned on WMT16 correlation data. |
|
|
verbose (bool): Turn on intermediate status update. |
|
|
idf (bool or dict): Use idf weighting; can also be a precomputed idf_dict. |
|
|
device (str): On which the contextual embedding model will be allocated on. |
|
|
If this argument is None, the model lives on cuda:0 if cuda is available. |
|
|
nthreads (int): Number of threads. |
|
|
batch_size (int): Bert score processing batch size, |
|
|
at least one of `model_type` or `lang`. `lang` needs to be |
|
|
specified when `rescale_with_baseline` is True. |
|
|
rescale_with_baseline (bool): Rescale bertscore with pre-computed baseline. |
|
|
baseline_path (str): Customized baseline file. |
|
|
use_fast_tokenizer (bool): `use_fast` parameter passed to HF tokenizer. New in version 0.3.10. |
|
|
|
|
|
Returns: |
|
|
precision: Precision. |
|
|
recall: Recall. |
|
|
f1: F1 score. |
|
|
hashcode: Hashcode of the library. |
|
|
|
|
|
Examples: |
|
|
|
|
|
>>> predictions = ["hello there", "general kenobi"] |
|
|
>>> references = ["hello there", "general kenobi"] |
|
|
>>> bertscore = evaluate.load("bertscore") |
|
|
>>> results = bertscore.compute(predictions=predictions, references=references, lang="en") |
|
|
>>> print([round(v, 2) for v in results["f1"]]) |
|
|
[1.0, 1.0] |
|
|
""" |
|
|
|
|
|
|
|
|
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class BERTScore(evaluate.Metric):
    """Metric wrapper that delegates scoring to ``bert_score.BERTScorer``.

    The scorer instance is cached on the metric object between calls and is
    rebuilt only when the settings it was constructed with change.
    """

    def _info(self):
        """Return the ``evaluate.MetricInfo`` describing this metric.

        Two input schemas are declared: one where each prediction has a
        sequence of reference strings, and one where it has a single
        reference string.
        """
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="https://github.com/Tiiiger/bert_score",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/Tiiiger/bert_score"],
            reference_urls=[
                "https://github.com/Tiiiger/bert_score",
                "https://arxiv.org/abs/1904.09675",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        lang=None,
        model_type=None,
        num_layers=None,
        verbose=False,
        idf=False,
        device=None,
        batch_size=64,
        nthreads=4,
        all_layers=False,
        rescale_with_baseline=False,
        baseline_path=None,
        use_fast_tokenizer=False,
    ):
        """Score ``predictions`` against ``references`` with BERTScore.

        Args:
            predictions: Candidate sentences, passed to the scorer as ``cands``.
            references: Either one string per prediction or a list of strings
                per prediction; plain strings are wrapped in one-element lists.
            lang: Language code used to pick a default model when
                ``model_type`` is not given; also forwarded to the scorer.
            model_type: Model name; defaults to
                ``bert_score.utils.lang2model[lang.lower()]``.
            num_layers: Layer to use; defaults to
                ``bert_score.utils.model2layers[model_type]``.
            verbose: Forwarded to ``BERTScorer.score``.
            idf: When truthy, all reference sentences are handed to the
                scorer as ``idf_sents``.
            device, nthreads, all_layers, rescale_with_baseline,
            baseline_path: Forwarded to the ``BERTScorer`` constructor.
            batch_size: Forwarded to both the constructor and ``score``.
            use_fast_tokenizer: Forwarded to ``get_hash`` and ``BERTScorer``
                when ``bert_score>=0.3.10`` is installed.

        Returns:
            dict with keys ``"precision"``, ``"recall"`` and ``"f1"`` (each
            the scorer output converted with ``.tolist()``) and
            ``"hashcode"`` (the value returned by ``get_hash``).

        Raises:
            ImportWarning: ``use_fast_tokenizer`` is truthy but the installed
                ``bert_score`` is older than 0.3.10.
            ValueError: Neither ``lang`` nor ``model_type`` is given.
        """
        # Normalise the single-reference form to the multi-reference form.
        if isinstance(references[0], str):
            references = [[ref] for ref in references]

        # Sentences used for IDF weighting: every reference, flattened.
        idf_sents = [r for ref in references for r in ref] if idf else None

        get_hash = bert_score.utils.get_hash
        scorer = bert_score.BERTScorer

        if version.parse(bert_score.__version__) >= version.parse("0.3.10"):
            get_hash = functools.partial(get_hash, use_fast_tokenizer=use_fast_tokenizer)
            scorer = functools.partial(scorer, use_fast_tokenizer=use_fast_tokenizer)
        elif use_fast_tokenizer:
            raise ImportWarning(
                "To use a fast tokenizer, the module `bert-score>=0.3.10` is required, and the current version of "
                "`bert-score` doesn't match this condition.\n"
                'You can install it with `pip install "bert-score>=0.3.10"`.'
            )

        if model_type is None:
            if lang is None:
                raise ValueError(
                    "Either 'lang' (e.g. 'en') or 'model_type' (e.g. 'microsoft/deberta-xlarge-mnli')"
                    " must be specified"
                )
            model_type = bert_score.utils.lang2model[lang.lower()]

        if num_layers is None:
            num_layers = bert_score.utils.model2layers[model_type]

        hashcode = get_hash(
            model=model_type,
            num_layers=num_layers,
            idf=idf,
            rescale_with_baseline=rescale_with_baseline,
            use_custom_baseline=baseline_path is not None,
        )

        # The hashcode is built without these constructor inputs, so on its
        # own it cannot tell that a cached scorer was created with a
        # different language, device, baseline file, layer mode, thread
        # count or IDF corpus. Track them separately so such a change forces
        # a rebuild instead of silently reusing a stale scorer.
        scorer_config = (
            lang,
            all_layers,
            device,
            baseline_path,
            nthreads,
            None if idf_sents is None else tuple(idf_sents),
        )

        with filter_logging_context():
            cached = getattr(self, "cached_bertscorer", None)
            if (
                cached is None
                or cached.hash != hashcode
                or getattr(self, "_cached_bertscorer_config", None) != scorer_config
            ):
                self.cached_bertscorer = scorer(
                    model_type=model_type,
                    num_layers=num_layers,
                    batch_size=batch_size,
                    nthreads=nthreads,
                    all_layers=all_layers,
                    idf=idf,
                    idf_sents=idf_sents,
                    device=device,
                    lang=lang,
                    rescale_with_baseline=rescale_with_baseline,
                    baseline_path=baseline_path,
                )
                self._cached_bertscorer_config = scorer_config

        (P, R, F) = self.cached_bertscorer.score(
            cands=predictions,
            refs=references,
            verbose=verbose,
            batch_size=batch_size,
        )
        output_dict = {
            "precision": P.tolist(),
            "recall": R.tolist(),
            "f1": F.tolist(),
            "hashcode": hashcode,
        }
        return output_dict
|
|
|