|
|
import importlib |
|
|
import os |
|
|
import tempfile |
|
|
from unittest import TestCase |
|
|
|
|
|
import pytest |
|
|
from datasets import DownloadConfig |
|
|
|
|
|
import evaluate |
|
|
from evaluate.loading import ( |
|
|
CachedEvaluationModuleFactory, |
|
|
HubEvaluationModuleFactory, |
|
|
LocalEvaluationModuleFactory, |
|
|
evaluation_module_factory, |
|
|
) |
|
|
|
|
|
from .utils import OfflineSimulationMode, offline |
|
|
|
|
|
|
|
|
# Hub identifier of a community-hosted module used by the remote-loading tests.
SAMPLE_METRIC_IDENTIFIER = "lvwerra/test"

# Base name shared by the dummy script's directory and its ``.py`` file.
METRIC_LOADING_SCRIPT_NAME = "__dummy_metric1__"

# Source of a minimal evaluation module that is written to disk by the
# ``metric_loading_script_dir`` fixture. Its ``_compute`` counts the positions
# where prediction and reference are equal.
# The feature dtype is spelled "int64": a bare "int" is not among the dtype
# names documented for ``datasets.Value``, so ``_info()`` would be expected to
# fail if the dummy metric were ever instantiated.
METRIC_LOADING_SCRIPT_CODE = """
import evaluate
from evaluate import EvaluationModuleInfo
from datasets import Features, Value

class __DummyMetric1__(evaluate.EvaluationModule):

    def _info(self):
        return EvaluationModuleInfo(features=Features({"predictions": Value("int64"), "references": Value("int64")}))

    def _compute(self, predictions, references):
        return {"__dummy_metric1__": sum(int(p == r) for p, r in zip(predictions, references))}
"""
|
|
|
|
|
|
|
|
@pytest.fixture
def metric_loading_script_dir(tmp_path):
    """Materialise the dummy metric script on disk and return its directory.

    A sub-directory named ``METRIC_LOADING_SCRIPT_NAME`` is created under
    ``tmp_path`` and ``METRIC_LOADING_SCRIPT_CODE`` is written into
    ``<name>.py`` inside it.

    Returns:
        The directory path as a ``str`` (not the path of the script file).
    """
    module_name = METRIC_LOADING_SCRIPT_NAME
    target_dir = tmp_path / module_name
    target_dir.mkdir()
    with open(target_dir / f"{module_name}.py", "w") as handle:
        handle.write(METRIC_LOADING_SCRIPT_CODE)
    return str(target_dir)
|
|
|
|
|
|
|
|
class ModuleFactoryTest(TestCase):
    """Check that the evaluation-module factories produce importable modules.

    Every test obtains a module through one of the factories (Hub, local
    script, cache) and, where a factory result is returned, verifies that its
    ``module_path`` can be imported. The cache tests additionally repeat the
    lookup under every ``OfflineSimulationMode``.
    """

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, metric_loading_script_dir):
        """Store the ``metric_loading_script_dir`` pytest fixture on the test instance."""
        self._metric_loading_script_dir = metric_loading_script_dir

    def setUp(self):
        """Create fresh cache directories and initialise dynamic modules for one test.

        Both temporary directories are registered for removal so repeated test
        runs do not accumulate leftover directories in the system temp folder.
        """
        import shutil  # local import: only needed for temp-dir cleanup

        self.hf_modules_cache = tempfile.mkdtemp()
        # ignore_errors: cleanup is best-effort and must never fail a test.
        self.addCleanup(shutil.rmtree, self.hf_modules_cache, ignore_errors=True)
        self.cache_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.cache_dir, ignore_errors=True)

        self.download_config = DownloadConfig(cache_dir=self.cache_dir)
        # The name embeds the basename of the fresh temp dir, so it differs per test.
        self.dynamic_modules_path = evaluate.loading.init_dynamic_modules(
            name="test_datasets_modules_" + os.path.basename(self.hf_modules_cache),
            hf_modules_cache=self.hf_modules_cache,
        )

    def test_HubEvaluationModuleFactory_with_internal_import(self):
        """Load ``evaluate-metric/squad_v2`` from the Hub and import the resulting module."""
        factory = HubEvaluationModuleFactory(
            "evaluate-metric/squad_v2",
            module_type="metric",
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_HubEvaluationModuleFactory_with_external_import(self):
        """Load ``evaluate-metric/bleu`` from the Hub and import the resulting module."""
        factory = HubEvaluationModuleFactory(
            "evaluate-metric/bleu",
            module_type="metric",
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_HubEvaluationModuleFactoryWithScript(self):
        """Load the community module ``SAMPLE_METRIC_IDENTIFIER`` and import it."""
        factory = HubEvaluationModuleFactory(
            SAMPLE_METRIC_IDENTIFIER,
            download_config=self.download_config,
            dynamic_modules_path=self.dynamic_modules_path,
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_LocalMetricModuleFactory(self):
        """Load the dummy metric from its on-disk script and import it."""
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        factory = LocalEvaluationModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        module_factory_result = factory.get_module()
        assert importlib.import_module(module_factory_result.module_path) is not None

    def test_CachedMetricModuleFactory(self):
        """Resolve the dummy metric by name in every offline mode after a local load."""
        path = os.path.join(self._metric_loading_script_dir, f"{METRIC_LOADING_SCRIPT_NAME}.py")
        factory = LocalEvaluationModuleFactory(
            path, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )
        # Result intentionally discarded: this call runs first so the cached
        # lookups below have something to find (presumably it fills the
        # dynamic-modules cache -- confirm against the loader).
        factory.get_module()

        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                factory = CachedEvaluationModuleFactory(
                    METRIC_LOADING_SCRIPT_NAME,
                    dynamic_modules_path=self.dynamic_modules_path,
                )
                module_factory_result = factory.get_module()
                assert importlib.import_module(module_factory_result.module_path) is not None

    def test_cache_with_remote_canonical_module(self):
        """Resolve ``accuracy`` once online, then again in every offline mode.

        There are no explicit assertions; the test passes when none of the
        ``evaluation_module_factory`` calls raises.
        """
        metric = "accuracy"
        evaluation_module_factory(
            metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )

        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                evaluation_module_factory(
                    metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
                )

    def test_cache_with_remote_community_module(self):
        """Resolve the community module once online, then again in every offline mode.

        There are no explicit assertions; the test passes when none of the
        ``evaluation_module_factory`` calls raises.
        """
        # Same identifier as before ("lvwerra/test"), taken from the shared constant.
        metric = SAMPLE_METRIC_IDENTIFIER
        evaluation_module_factory(
            metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
        )

        for offline_mode in OfflineSimulationMode:
            with offline(offline_mode):
                evaluation_module_factory(
                    metric, download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
                )
|
|
|