File size: 16,964 Bytes
b965a6e
 
 
 
 
 
 
 
 
 
 
 
 
29a0e42
b965a6e
29a0e42
 
948d61e
b965a6e
29a0e42
 
b965a6e
 
 
29a0e42
 
 
b965a6e
 
 
 
29a0e42
b965a6e
 
 
 
29a0e42
b965a6e
b807233
 
 
 
b965a6e
b807233
b965a6e
b807233
 
 
 
 
 
 
 
 
 
 
b965a6e
b807233
b965a6e
 
 
 
ddf1ba7
29a0e42
 
 
b965a6e
948d61e
 
b965a6e
 
 
 
 
 
29a0e42
 
 
 
 
 
b965a6e
29a0e42
b965a6e
29a0e42
 
b965a6e
 
29a0e42
948d61e
 
3fa3bf9
948d61e
3fa3bf9
 
948d61e
 
 
 
 
3fa3bf9
948d61e
3fa3bf9
 
 
29a0e42
3fa3bf9
 
 
5ec6ad7
3fa3bf9
5ec6ad7
 
3fa3bf9
5ec6ad7
3fa3bf9
29a0e42
 
3fa3bf9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29a0e42
948d61e
 
 
29a0e42
948d61e
 
29a0e42
948d61e
 
 
 
 
 
29a0e42
948d61e
d84d51c
 
b965a6e
3d8a7dc
948d61e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6975d0
948d61e
a1332fa
3d8a7dc
948d61e
 
 
3d8a7dc
948d61e
 
 
 
 
 
 
 
 
 
 
 
3d8a7dc
948d61e
 
 
3d8a7dc
 
 
 
 
 
 
 
 
 
948d61e
 
 
3d8a7dc
 
948d61e
3d8a7dc
 
 
 
 
 
 
3aece72
948d61e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6975d0
 
 
 
948d61e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3aece72
29a0e42
 
948d61e
 
29a0e42
 
948d61e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29a0e42
948d61e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Accuracy metric for the Test of Time benchmark by Bahar et al. (2025)."""

import ast
import json
from typing import Any, Literal

import datasets
import evaluate

# BibTeX entry surfaced through MetricInfo.citation. The field must be spelled
# ``author`` -- BibTeX ignores an unknown ``authors`` field, which would leave
# the entry without an author.
_CITATION = """\
@InProceedings{huggingface:module,
title = {Test of Time Accuracy},
author={Auss Abbood},
year={2025}
}
"""

# Human-readable summary surfaced through MetricInfo.description and prepended
# to the class docstring.
_DESCRIPTION = """\
The Test of Time (ToT) benchmark expects models to format their answers as a JSON object with an explanation field and an answer field that follows a predefined format. The metric extracts JSON objects from the model's output, retains only the first one, drops the explanation field and compares the result with the reference answer.
"""


# Argument/return documentation surfaced through MetricInfo.inputs_description.
# The Examples section is written as a doctest, so keep it in sync with the
# actual behaviour of the metric.
_KWARGS_DESCRIPTION = """
Compares the extracted answer from the model's output with the reference answer.
Args:
    predictions: list of predictions to score. Each prediction should be a string that contains a JSON object (e.g., generated by an LLM).
    references: list of reference answers.
    subset: The subset of the benchmark being evaluated. Must be one of "arithmetic" or "semantic".
    return_average: If True, returns the average accuracy. If False, returns a list of boolean scores (correct/incorrect) for each sample. Defaults to True.
Returns:
    accuracy: The accuracy score (0.0 to 1.0) if return_average=True, or a list of booleans indicating correctness per sample if return_average=False.
Examples:
    >>> import evaluate
    >>> metric = evaluate.load("aauss/test_of_time_accuracy")
    >>> predictions = [
    ...     '{"explanation": "Some explanation...", "unordered_list": ["London"]}',
    ...     ' "Response without opening curly brackets...", "answer": "2005-04-07"}',
    ... ]
    >>> references = [
    ...     '{"unordered_list": ["London"]}',
    ...     "{'answer': '2005-04-07'}",
    ... ]
    >>> results = metric.compute(predictions=predictions, references=references, subset="arithmetic")
    >>> print(results)
    {'accuracy': 0.5}
"""


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class TestOfTimeAccuracy(evaluate.Metric):
    """Accuracy metric for the Test of Time benchmark by Bahar et al. (2025)."""

    __test__ = False

    def _info(self) -> evaluate.MetricInfo:
        """Describe the metric: its documentation strings and input schema."""
        # Both inputs are declared as plain strings; parsing them into
        # structured answers happens later, inside _compute.
        input_schema = datasets.Features(
            {
                "predictions": datasets.Value("string"),
                "references": datasets.Value("string"),
            }
        )
        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=input_schema,
        )

    @staticmethod
    def _extract_first_json_object(text: str) -> dict | None:
        """
        Extract the first valid JSON object from text.

        Handles common LLM output issues like unescaped newlines in string
        values (LLMs produce human-readable output, not strict JSON).

        Args:
            text: String that may contain JSON objects

        Returns:
            The first JSON dictionary found, or None if no valid JSON exists
        """
        # Fix unescaped control chars in strings (common LLM issue)
        text = TestOfTimeAccuracy._escape_control_chars_in_strings(text)

        decoder = json.JSONDecoder()
        idx = 0
        while idx < len(text):
            if text[idx] == '{':
                try:
                    obj, _ = decoder.raw_decode(text, idx)
                    if isinstance(obj, dict):
                        return obj
                except json.JSONDecodeError:
                    pass
            idx += 1
        return None

    @staticmethod
    def _escape_control_chars_in_strings(text: str) -> str:
        """
        Escape literal control characters inside JSON string values.

        LLMs produce newlines/tabs for readability, but JSON requires them
        to be escaped within strings.
        """
        result = []
        in_string = False
        i = 0
        while i < len(text):
            char = text[i]
            if char == '\\' and in_string and i + 1 < len(text):
                # Preserve existing escape sequences
                result.append(char)
                result.append(text[i + 1])
                i += 2
                continue
            if char == '"':
                in_string = not in_string
            if in_string and char == '\n':
                result.append('\\n')
            elif in_string and char == '\r':
                result.append('\\r')
            elif in_string and char == '\t':
                result.append('\\t')
            else:
                result.append(char)
            i += 1
        return ''.join(result)

    @staticmethod
    def _parse_reference_label(label_str: str) -> dict | None:
        """
        Parses a reference label string into a dictionary.

        Handles Python dict strings (e.g., "{'key': 'value'}") by
        evaluating them as literals.

        Args:
            label_str: String representation of a dictionary

        Returns:
            Parsed dictionary, or None if parsing fails
        """
        try:
            return ast.literal_eval(label_str)
        except (ValueError, SyntaxError):
            return None

    @staticmethod
    def _remove_explanation_field(data: Any) -> Any:
        """
        Removes the 'explanation' field from a dictionary.

        Args:
            data: Dictionary or other data type

        Returns:
            The data with explanation field removed (if it was a dict),
            or the original data unchanged
        """
        if isinstance(data, dict):
            data.pop("explanation", None)
        return data

    @staticmethod
    def _extract_answer_field(data: Any) -> Any:
        """
        Extracts the 'answer' field from a dictionary.

        Args:
            data: Dictionary or other data type

        Returns:
            The value of the 'answer' field if data is a dict,
            otherwise returns the data unchanged
        """
        if isinstance(data, dict):
            return data.get("answer", None)
        return data

    @staticmethod
    def _sort_unordered_list_field(data: Any) -> Any:
        """
        Sorts the 'unordered_list' field in a dictionary.

        This enables comparison of unordered lists by converting them to
        a canonical sorted form.

        Args:
            data: Dictionary potentially containing an 'unordered_list' field

        Returns:
            Sorted list if data is a dict with 'unordered_list',
            otherwise returns data unchanged
        """
        if isinstance(data, dict) and "unordered_list" in data:
            return sorted([item for item in data["unordered_list"] if isinstance(item, str)])
        return data

    @staticmethod
    def _cast_prediction_to_reference_types(
        reference: dict, prediction: dict
    ) -> dict | None:
        """
        Casts prediction values to match reference types.

        Ensures that predictions can be compared with references even when
        the types differ (e.g., string "123" vs int 123, int 5 vs float 5.0).

        Args:
            reference: Reference dictionary with expected types
            prediction: Prediction dictionary to cast

        Returns:
            Dictionary with casted values, or None if casting fails or
            prediction is missing required keys
        """
        if not isinstance(prediction, dict) or not isinstance(reference, dict):
            return None

        casted_prediction = {}

        try:
            for ref_key, ref_value in reference.items():
                if ref_key not in prediction:
                    return None

                reference_type = type(ref_value)
                pred_value = prediction[ref_key]

                # Safeguard: Python allows list("abc") -> ['a', 'b', 'c']
                # We don't want to turn strings into character lists
                if reference_type is list and not isinstance(pred_value, list):
                    return None

                # Cast to reference type: int("123") -> 123, float(12) -> 12.0, etc.
                casted_prediction[ref_key] = reference_type(pred_value)

            return casted_prediction

        except (ValueError, TypeError):
            return None

    @staticmethod
    def _normalise_list_field_casing(data: dict | None) -> dict | None:
        """
        Converts all list items to lowercase for case-insensitive comparison.

        Applied to 'ordered_list' and 'unordered_list' fields to handle variations
        in capitalization (e.g., "Skating" vs "skating").

        Args:
            data: Dictionary potentially containing list fields

        Returns:
            Dictionary with lowercased list items, or None if data is None
        """
        if data is None or not isinstance(data, dict):
            return data

        # Process list fields regardless of key order
        for key in ["ordered_list", "unordered_list"]:
            if key in data and isinstance(data[key], list):
                data[key] = [item.lower() for item in data[key] if isinstance(item, str)]

        return data

    @staticmethod
    def _fix_age_field_conflict(prediction: dict | None) -> dict | None:
        """
        Fixes a known conflict in the dataset regarding the 'age' field.

        In some dataset samples, the instruction asks for an 'age' field but
        the reference uses 'answer'. This method normalises the prediction
        to match the expected format.

        Args:
            prediction: Prediction dictionary potentially with 'age' field

        Returns:
            Dictionary with 'age' converted to 'answer', or unchanged if
            'age' field not present
        """
        if prediction is not None and isinstance(prediction, dict):
            if "age" in prediction:
                prediction = {"answer": prediction["age"]}
        return prediction

    def _process_arithmetic_prediction(
        self, prediction: dict | None, reference: dict | None
    ) -> tuple[Any, Any]:
        """
        Processes a prediction-reference pair for the arithmetic subset.

        Applies arithmetic-specific transformations:
        1. Fixes age field conflicts
        2. normalises list casing
        3. Casts prediction types to match reference
        4. Sorts unordered lists for comparison

        Args:
            prediction: Raw prediction dictionary
            reference: Raw reference dictionary

        Returns:
            Tuple of (processed_prediction, processed_reference)
        """
        prediction = self._fix_age_field_conflict(prediction)
        prediction = self._normalise_list_field_casing(prediction)
        reference = self._normalise_list_field_casing(reference)
        prediction = self._cast_prediction_to_reference_types(reference, prediction)

        # Sort unordered lists for order-independent comparison
        if reference and "unordered_list" in reference:
            prediction = self._sort_unordered_list_field(prediction)
            reference = self._sort_unordered_list_field(reference)

        return prediction, reference

    def _process_semantic_prediction(
        self, prediction: Any, reference: Any
    ) -> tuple[str, str]:
        """
        Processes a prediction-reference pair for the semantic subset.

        Converts both to strings for comparison since semantic answers
        may have type mismatches (e.g., int in JSON vs string in reference).

        Args:
            prediction: Raw prediction value
            reference: Raw reference value

        Returns:
            Tuple of (str(prediction), str(reference))
        """
        return str(prediction), str(reference)

    def _extract_predictions(
        self, raw_predictions: list[str], subset: str
    ) -> list[Any]:
        """
        Extracts and preprocesses predictions based on subset type.

        Args:
            raw_predictions: List of raw prediction strings (e.g., from LLM output)
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            List of extracted prediction values
        """
        predictions = [self._extract_first_json_object(p) for p in raw_predictions]

        if subset == "semantic":
            # Since labels are not dicts, we need to extract the value from the LLM's answer field.
            predictions = [self._extract_answer_field(p) for p in predictions]
        elif subset == "arithmetic":
            # Labels and LLMs differ only by the explanation field. Thus, remove.
            predictions = [self._remove_explanation_field(p) for p in predictions]

        return predictions

    def _extract_references(self, raw_references: list[str], subset: str) -> list[Any]:
        """
        Extracts and preprocesses references based on subset type.

        Args:
            raw_references: List of raw reference strings
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            List of extracted reference values
        """
        if subset == "arithmetic":
            # Arithmetic references are Python dict strings that need parsing
            return [self._parse_reference_label(r) for r in raw_references]
        else:
            # Semantic references are used as-is
            return raw_references

    def _compare_pair(self, prediction: Any, reference: Any, subset: str) -> bool:
        """
        Compares a single prediction-reference pair.

        Args:
            prediction: Processed prediction value
            reference: Processed reference value
            subset: Either 'arithmetic' or 'semantic'

        Returns:
            True if prediction matches reference, False otherwise
        """
        if subset == "arithmetic":
            prediction, reference = self._process_arithmetic_prediction(
                prediction, reference
            )
        elif subset == "semantic":
            prediction, reference = self._process_semantic_prediction(
                prediction, reference
            )

        return prediction == reference

    def _compute(
        self,
        predictions: list[str],
        references: list[str],
        subset: Literal["arithmetic", "semantic"],
        return_average: bool = True,
    ) -> dict[str, float | list[bool]]:
        """
        Computes accuracy scores for the Test of Time benchmark.

        Args:
            predictions: List of prediction strings (LLM outputs)
            references: List of reference answer strings
            subset: Benchmark subset - either 'arithmetic' or 'semantic'
            return_average: If True, returns average accuracy; if False,
                          returns per-sample correctness

        Returns:
            Dictionary with 'accuracy' key containing either:
            - float: average accuracy (if return_average=True)
            - list[bool]: per-sample correctness (if return_average=False)

        Raises:
            ValueError: If subset is not 'arithmetic' or 'semantic'
        """
        # Validate subset
        if subset not in ["arithmetic", "semantic"]:
            raise ValueError(
                f"Invalid subset: {subset}. Must be 'arithmetic' or 'semantic'."
            )

        # Extract and preprocess predictions and references
        predictions = self._extract_predictions(predictions, subset)
        references = self._extract_references(references, subset)

        # Compare each prediction-reference pair
        accuracy_scores = [
            self._compare_pair(pred, ref, subset)
            for pred, ref in zip(predictions, references)
        ]

        # Return average or per-sample scores
        if return_average:
            return {"accuracy": sum(accuracy_scores) / len(accuracy_scores)}
        return {"accuracy": accuracy_scores}