Spaces:
Sleeping
Sleeping
Add text lowercasing and refactor with AI.
Browse files- test_of_time_accuracy.py +302 -69
- tests/test_arithmetic_scoring.py +4 -2
- tests/test_arithmetic_type_casting.py +12 -12
test_of_time_accuracy.py
CHANGED
|
@@ -15,7 +15,7 @@
|
|
| 15 |
|
| 16 |
import ast
|
| 17 |
import json
|
| 18 |
-
from typing import Literal
|
| 19 |
|
| 20 |
import datasets
|
| 21 |
import evaluate
|
|
@@ -65,7 +65,8 @@ class TestOfTimeAccuracy(evaluate.Metric):
|
|
| 65 |
|
| 66 |
__test__ = False
|
| 67 |
|
| 68 |
-
def _info(self):
|
|
|
|
| 69 |
return evaluate.MetricInfo(
|
| 70 |
module_type="metric",
|
| 71 |
description=_DESCRIPTION,
|
|
@@ -86,52 +87,124 @@ class TestOfTimeAccuracy(evaluate.Metric):
|
|
| 86 |
)
|
| 87 |
|
| 88 |
@staticmethod
|
| 89 |
-
def _extract_first_json_object(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
decoder = json.JSONDecoder()
|
| 91 |
-
idx, end = 0, len(
|
|
|
|
| 92 |
while idx < end:
|
| 93 |
try:
|
| 94 |
-
obj, next_idx = decoder.raw_decode(
|
| 95 |
-
idx = next_idx
|
| 96 |
if isinstance(obj, dict):
|
| 97 |
return obj
|
|
|
|
| 98 |
except ValueError:
|
| 99 |
idx += 1
|
| 100 |
return None
|
| 101 |
|
| 102 |
@staticmethod
|
| 103 |
-
def
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
return d
|
| 107 |
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
if isinstance(d, dict):
|
| 111 |
-
return d.get("answer", None)
|
| 112 |
-
return d
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
|
|
|
|
|
|
|
|
|
| 117 |
try:
|
| 118 |
-
|
| 119 |
-
return ast.literal_eval(s)
|
| 120 |
except (ValueError, SyntaxError):
|
| 121 |
return None
|
| 122 |
|
| 123 |
@staticmethod
|
| 124 |
-
def
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
@staticmethod
|
| 130 |
-
def
|
|
|
|
|
|
|
| 131 |
"""
|
| 132 |
-
Casts
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
"""
|
|
|
|
|
|
|
|
|
|
| 135 |
casted_prediction = {}
|
| 136 |
|
| 137 |
try:
|
|
@@ -142,12 +215,12 @@ class TestOfTimeAccuracy(evaluate.Metric):
|
|
| 142 |
reference_type = type(ref_value)
|
| 143 |
pred_value = prediction[ref_key]
|
| 144 |
|
| 145 |
-
#
|
| 146 |
-
# We don't want to turn strings into character lists
|
| 147 |
-
if reference_type
|
| 148 |
return None
|
| 149 |
|
| 150 |
-
#
|
| 151 |
casted_prediction[ref_key] = reference_type(pred_value)
|
| 152 |
|
| 153 |
return casted_prediction
|
|
@@ -156,47 +229,207 @@ class TestOfTimeAccuracy(evaluate.Metric):
|
|
| 156 |
return None
|
| 157 |
|
| 158 |
@staticmethod
|
| 159 |
-
def
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
def _compute(
|
| 167 |
self,
|
| 168 |
-
predictions,
|
| 169 |
-
references,
|
| 170 |
subset: Literal["arithmetic", "semantic"],
|
| 171 |
return_average: bool = True,
|
| 172 |
-
):
|
| 173 |
-
"""
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
if return_average:
|
| 201 |
-
return {"accuracy": sum(
|
| 202 |
-
return {"accuracy":
|
|
|
|
| 15 |
|
| 16 |
import ast
|
| 17 |
import json
|
| 18 |
+
from typing import Any, Literal
|
| 19 |
|
| 20 |
import datasets
|
| 21 |
import evaluate
|
|
|
|
| 65 |
|
| 66 |
__test__ = False
|
| 67 |
|
| 68 |
+
def _info(self) -> evaluate.MetricInfo:
|
| 69 |
+
"""Returns metadata about this metric."""
|
| 70 |
return evaluate.MetricInfo(
|
| 71 |
module_type="metric",
|
| 72 |
description=_DESCRIPTION,
|
|
|
|
| 87 |
)
|
| 88 |
|
| 89 |
@staticmethod
|
| 90 |
+
def _extract_first_json_object(text: str) -> dict | None:
|
| 91 |
+
"""
|
| 92 |
+
Extracts the first valid JSON object from a string.
|
| 93 |
+
|
| 94 |
+
Scans through the text and returns the first valid JSON dictionary found.
|
| 95 |
+
This is useful for parsing LLM outputs that may contain JSON mixed with
|
| 96 |
+
other text or markdown formatting.
|
| 97 |
+
|
| 98 |
+
Args:
|
| 99 |
+
text: String that may contain JSON objects
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
The first JSON dictionary found, or None if no valid JSON dict exists
|
| 103 |
+
"""
|
| 104 |
decoder = json.JSONDecoder()
|
| 105 |
+
idx, end = 0, len(text)
|
| 106 |
+
|
| 107 |
while idx < end:
|
| 108 |
try:
|
| 109 |
+
obj, next_idx = decoder.raw_decode(text, idx)
|
|
|
|
| 110 |
if isinstance(obj, dict):
|
| 111 |
return obj
|
| 112 |
+
idx = next_idx
|
| 113 |
except ValueError:
|
| 114 |
idx += 1
|
| 115 |
return None
|
| 116 |
|
| 117 |
@staticmethod
|
| 118 |
+
def _parse_reference_label(label_str: str) -> dict | None:
|
| 119 |
+
"""
|
| 120 |
+
Parses a reference label string into a dictionary.
|
|
|
|
| 121 |
|
| 122 |
+
Handles Python dict strings (e.g., "{'key': 'value'}") by
|
| 123 |
+
evaluating them as literals.
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
Args:
|
| 126 |
+
label_str: String representation of a dictionary
|
| 127 |
+
|
| 128 |
+
Returns:
|
| 129 |
+
Parsed dictionary, or None if parsing fails
|
| 130 |
+
"""
|
| 131 |
try:
|
| 132 |
+
return ast.literal_eval(label_str)
|
|
|
|
| 133 |
except (ValueError, SyntaxError):
|
| 134 |
return None
|
| 135 |
|
| 136 |
@staticmethod
|
| 137 |
+
def _remove_explanation_field(data: Any) -> Any:
|
| 138 |
+
"""
|
| 139 |
+
Removes the 'explanation' field from a dictionary.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
data: Dictionary or other data type
|
| 143 |
+
|
| 144 |
+
Returns:
|
| 145 |
+
The data with explanation field removed (if it was a dict),
|
| 146 |
+
or the original data unchanged
|
| 147 |
+
"""
|
| 148 |
+
if isinstance(data, dict):
|
| 149 |
+
data.pop("explanation", None)
|
| 150 |
+
return data
|
| 151 |
+
|
| 152 |
+
@staticmethod
|
| 153 |
+
def _extract_answer_field(data: Any) -> Any:
|
| 154 |
+
"""
|
| 155 |
+
Extracts the 'answer' field from a dictionary.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
data: Dictionary or other data type
|
| 159 |
+
|
| 160 |
+
Returns:
|
| 161 |
+
The value of the 'answer' field if data is a dict,
|
| 162 |
+
otherwise returns the data unchanged
|
| 163 |
+
"""
|
| 164 |
+
if isinstance(data, dict):
|
| 165 |
+
return data.get("answer", None)
|
| 166 |
+
return data
|
| 167 |
+
|
| 168 |
+
@staticmethod
|
| 169 |
+
def _sort_unordered_list_field(data: Any) -> Any:
|
| 170 |
+
"""
|
| 171 |
+
Sorts the 'unordered_list' field in a dictionary.
|
| 172 |
+
|
| 173 |
+
This enables comparison of unordered lists by converting them to
|
| 174 |
+
a canonical sorted form.
|
| 175 |
+
|
| 176 |
+
Args:
|
| 177 |
+
data: Dictionary potentially containing an 'unordered_list' field
|
| 178 |
+
|
| 179 |
+
Returns:
|
| 180 |
+
Sorted list if data is a dict with 'unordered_list',
|
| 181 |
+
otherwise returns data unchanged
|
| 182 |
+
"""
|
| 183 |
+
if isinstance(data, dict) and "unordered_list" in data:
|
| 184 |
+
return sorted(data["unordered_list"])
|
| 185 |
+
return data
|
| 186 |
|
| 187 |
@staticmethod
|
| 188 |
+
def _cast_prediction_to_reference_types(
|
| 189 |
+
reference: dict, prediction: dict
|
| 190 |
+
) -> dict | None:
|
| 191 |
"""
|
| 192 |
+
Casts prediction values to match reference types.
|
| 193 |
+
|
| 194 |
+
Ensures that predictions can be compared with references even when
|
| 195 |
+
the types differ (e.g., string "123" vs int 123, int 5 vs float 5.0).
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
reference: Reference dictionary with expected types
|
| 199 |
+
prediction: Prediction dictionary to cast
|
| 200 |
+
|
| 201 |
+
Returns:
|
| 202 |
+
Dictionary with casted values, or None if casting fails or
|
| 203 |
+
prediction is missing required keys
|
| 204 |
"""
|
| 205 |
+
if not isinstance(prediction, dict) or not isinstance(reference, dict):
|
| 206 |
+
return None
|
| 207 |
+
|
| 208 |
casted_prediction = {}
|
| 209 |
|
| 210 |
try:
|
|
|
|
| 215 |
reference_type = type(ref_value)
|
| 216 |
pred_value = prediction[ref_key]
|
| 217 |
|
| 218 |
+
# Safeguard: Python allows list("abc") -> ['a', 'b', 'c']
|
| 219 |
+
# We don't want to turn strings into character lists
|
| 220 |
+
if reference_type is list and not isinstance(pred_value, list):
|
| 221 |
return None
|
| 222 |
|
| 223 |
+
# Cast to reference type: int("123") -> 123, float(12) -> 12.0, etc.
|
| 224 |
casted_prediction[ref_key] = reference_type(pred_value)
|
| 225 |
|
| 226 |
return casted_prediction
|
|
|
|
| 229 |
return None
|
| 230 |
|
| 231 |
@staticmethod
|
| 232 |
+
def _normalise_list_field_casing(data: dict | None) -> dict | None:
|
| 233 |
+
"""
|
| 234 |
+
Converts all list items to lowercase for case-insensitive comparison.
|
| 235 |
+
|
| 236 |
+
Applied to 'ordered_list' and 'unordered_list' fields to handle variations
|
| 237 |
+
in capitalization (e.g., "Skating" vs "skating").
|
| 238 |
+
|
| 239 |
+
Args:
|
| 240 |
+
data: Dictionary potentially containing list fields
|
| 241 |
+
|
| 242 |
+
Returns:
|
| 243 |
+
Dictionary with lowercased list items, or None if data is None
|
| 244 |
+
"""
|
| 245 |
+
if data is None or not isinstance(data, dict):
|
| 246 |
+
return data
|
| 247 |
+
|
| 248 |
+
# Process the first key if it's a list field
|
| 249 |
+
if data:
|
| 250 |
+
first_key = next(iter(data.keys()))
|
| 251 |
+
if first_key in ["ordered_list", "unordered_list"]:
|
| 252 |
+
data[first_key] = [item.lower() for item in data[first_key]]
|
| 253 |
+
|
| 254 |
+
return data
|
| 255 |
+
|
| 256 |
+
@staticmethod
|
| 257 |
+
def _fix_age_field_conflict(prediction: dict | None) -> dict | None:
|
| 258 |
+
"""
|
| 259 |
+
Fixes a known conflict in the dataset regarding the 'age' field.
|
| 260 |
+
|
| 261 |
+
In some dataset samples, the instruction asks for an 'age' field but
|
| 262 |
+
the reference uses 'answer'. This method normalises the prediction
|
| 263 |
+
to match the expected format.
|
| 264 |
+
|
| 265 |
+
Args:
|
| 266 |
+
prediction: Prediction dictionary potentially with 'age' field
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
Dictionary with 'age' converted to 'answer', or unchanged if
|
| 270 |
+
'age' field not present
|
| 271 |
+
"""
|
| 272 |
+
if prediction is not None and isinstance(prediction, dict):
|
| 273 |
+
if "age" in prediction:
|
| 274 |
+
prediction = {"answer": prediction["age"]}
|
| 275 |
+
return prediction
|
| 276 |
+
|
| 277 |
+
def _process_arithmetic_prediction(
    self, prediction: dict | None, reference: dict | None
) -> tuple[Any, Any]:
    """
    Processes a prediction-reference pair for the arithmetic subset.

    Applies arithmetic-specific transformations:
    1. Fixes age field conflicts
    2. Normalises list casing
    3. Casts prediction types to match reference
    4. Sorts unordered lists for comparison

    Args:
        prediction: Raw prediction dictionary
        reference: Raw reference dictionary

    Returns:
        Tuple of (processed_prediction, processed_reference)
    """
    prediction = self._fix_age_field_conflict(prediction)
    prediction = self._normalise_list_field_casing(prediction)
    reference = self._normalise_list_field_casing(reference)
    prediction = self._cast_prediction_to_reference_types(reference, prediction)

    # Sort unordered lists for order-independent comparison.
    # The isinstance check matters: a reference that parsed to a non-dict
    # literal (e.g. an int) would make the bare `in` test raise TypeError.
    if isinstance(reference, dict) and "unordered_list" in reference:
        prediction = self._sort_unordered_list_field(prediction)
        reference = self._sort_unordered_list_field(reference)

    return prediction, reference
|
| 307 |
+
|
| 308 |
+
def _process_semantic_prediction(
|
| 309 |
+
self, prediction: Any, reference: Any
|
| 310 |
+
) -> tuple[str, str]:
|
| 311 |
+
"""
|
| 312 |
+
Processes a prediction-reference pair for the semantic subset.
|
| 313 |
+
|
| 314 |
+
Converts both to strings for comparison since semantic answers
|
| 315 |
+
may have type mismatches (e.g., int in JSON vs string in reference).
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
prediction: Raw prediction value
|
| 319 |
+
reference: Raw reference value
|
| 320 |
+
|
| 321 |
+
Returns:
|
| 322 |
+
Tuple of (str(prediction), str(reference))
|
| 323 |
+
"""
|
| 324 |
+
return str(prediction), str(reference)
|
| 325 |
+
|
| 326 |
+
def _extract_predictions(
|
| 327 |
+
self, raw_predictions: list[str], subset: str
|
| 328 |
+
) -> list[Any]:
|
| 329 |
+
"""
|
| 330 |
+
Extracts and preprocesses predictions based on subset type.
|
| 331 |
+
|
| 332 |
+
Args:
|
| 333 |
+
raw_predictions: List of raw prediction strings (e.g., from LLM output)
|
| 334 |
+
subset: Either 'arithmetic' or 'semantic'
|
| 335 |
+
|
| 336 |
+
Returns:
|
| 337 |
+
List of extracted prediction values
|
| 338 |
+
"""
|
| 339 |
+
predictions = [self._extract_first_json_object(p) for p in raw_predictions]
|
| 340 |
+
|
| 341 |
+
if subset == "semantic":
|
| 342 |
+
# Since labels are not dicts, we need to extract the value from the LLM's answer field.
|
| 343 |
+
predictions = [self._extract_answer_field(p) for p in predictions]
|
| 344 |
+
elif subset == "arithmetic":
|
| 345 |
+
# Labels and LLMs differ only by the explanation field. Thus, remove.
|
| 346 |
+
predictions = [self._remove_explanation_field(p) for p in predictions]
|
| 347 |
+
|
| 348 |
+
return predictions
|
| 349 |
+
|
| 350 |
+
def _extract_references(self, raw_references: list[str], subset: str) -> list[Any]:
|
| 351 |
+
"""
|
| 352 |
+
Extracts and preprocesses references based on subset type.
|
| 353 |
+
|
| 354 |
+
Args:
|
| 355 |
+
raw_references: List of raw reference strings
|
| 356 |
+
subset: Either 'arithmetic' or 'semantic'
|
| 357 |
+
|
| 358 |
+
Returns:
|
| 359 |
+
List of extracted reference values
|
| 360 |
+
"""
|
| 361 |
+
if subset == "arithmetic":
|
| 362 |
+
# Arithmetic references are Python dict strings that need parsing
|
| 363 |
+
return [self._parse_reference_label(r) for r in raw_references]
|
| 364 |
+
else:
|
| 365 |
+
# Semantic references are used as-is
|
| 366 |
+
return raw_references
|
| 367 |
+
|
| 368 |
+
def _compare_pair(self, prediction: Any, reference: Any, subset: str) -> bool:
|
| 369 |
+
"""
|
| 370 |
+
Compares a single prediction-reference pair.
|
| 371 |
+
|
| 372 |
+
Args:
|
| 373 |
+
prediction: Processed prediction value
|
| 374 |
+
reference: Processed reference value
|
| 375 |
+
subset: Either 'arithmetic' or 'semantic'
|
| 376 |
+
|
| 377 |
+
Returns:
|
| 378 |
+
True if prediction matches reference, False otherwise
|
| 379 |
+
"""
|
| 380 |
+
if subset == "arithmetic":
|
| 381 |
+
prediction, reference = self._process_arithmetic_prediction(
|
| 382 |
+
prediction, reference
|
| 383 |
+
)
|
| 384 |
+
elif subset == "semantic":
|
| 385 |
+
prediction, reference = self._process_semantic_prediction(
|
| 386 |
+
prediction, reference
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
return prediction == reference
|
| 390 |
|
| 391 |
def _compute(
|
| 392 |
self,
|
| 393 |
+
predictions: list[str],
|
| 394 |
+
references: list[str],
|
| 395 |
subset: Literal["arithmetic", "semantic"],
|
| 396 |
return_average: bool = True,
|
| 397 |
+
) -> dict[str, float | list[bool]]:
|
| 398 |
+
"""
|
| 399 |
+
Computes accuracy scores for the Test of Time benchmark.
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
predictions: List of prediction strings (LLM outputs)
|
| 403 |
+
references: List of reference answer strings
|
| 404 |
+
subset: Benchmark subset - either 'arithmetic' or 'semantic'
|
| 405 |
+
return_average: If True, returns average accuracy; if False,
|
| 406 |
+
returns per-sample correctness
|
| 407 |
+
|
| 408 |
+
Returns:
|
| 409 |
+
Dictionary with 'accuracy' key containing either:
|
| 410 |
+
- float: average accuracy (if return_average=True)
|
| 411 |
+
- list[bool]: per-sample correctness (if return_average=False)
|
| 412 |
+
|
| 413 |
+
Raises:
|
| 414 |
+
ValueError: If subset is not 'arithmetic' or 'semantic'
|
| 415 |
+
"""
|
| 416 |
+
# Validate subset
|
| 417 |
+
if subset not in ["arithmetic", "semantic"]:
|
| 418 |
+
raise ValueError(
|
| 419 |
+
f"Invalid subset: {subset}. Must be 'arithmetic' or 'semantic'."
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
# Extract and preprocess predictions and references
|
| 423 |
+
predictions = self._extract_predictions(predictions, subset)
|
| 424 |
+
references = self._extract_references(references, subset)
|
| 425 |
+
|
| 426 |
+
# Compare each prediction-reference pair
|
| 427 |
+
accuracy_scores = [
|
| 428 |
+
self._compare_pair(pred, ref, subset)
|
| 429 |
+
for pred, ref in zip(predictions, references)
|
| 430 |
+
]
|
| 431 |
+
|
| 432 |
+
# Return average or per-sample scores
|
| 433 |
if return_average:
|
| 434 |
+
return {"accuracy": sum(accuracy_scores) / len(accuracy_scores)}
|
| 435 |
+
return {"accuracy": accuracy_scores}
|
tests/test_arithmetic_scoring.py
CHANGED
|
@@ -6,6 +6,7 @@ arithmetic_test_cases = {
|
|
| 6 |
"predictions": [
|
| 7 |
'JSON = {"explanation": "The war began in 360 BC. Since BC years count backwards, adding 8 years to 360 BC means subtracting 8 from 360, resulting in 352 BC.", "answer": "352 BC"}',
|
| 8 |
'```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "unordered_list": ["Berlin","London"]\n}\n```',
|
|
|
|
| 9 |
'```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "malformed_unordered_list": ["Berlin","London"]\n}\n```',
|
| 10 |
' "To find the date of the second most important game, we need to subtract 7 days from the date of the most important game. We can do this by counting back 7 days from April 14, 2005. April 14 - 7 days = April 7, 2005", "answer": "2005-04-07"}',
|
| 11 |
'\n```json\n{\n "explanation": "Step 1: Determine the time it takes the robot to carry a single box. The robot takes 4 hours, 34 minutes, and 30 seconds to carry 2 boxes. We divide this time by 2 to find the time per box.\\n- Hours: 4 / 2 = 2 hours\\n- Minutes: 34 / 2 = 17 minutes\\n- Seconds: 30 / 2 = 15 seconds\\nSo, it takes the robot 2 hours, 17 minutes, and 15 seconds to carry one box.\\n\\nStep 2: Calculate the total time to carry 25 boxes. We multiply the time per box by the total number of boxes (25).\\n- Total Hours: 2 hours/box * 25 boxes = 50 hours\\n- Total Minutes: 17 minutes/box * 25 boxes = 425 minutes\\n- Total Seconds: 15 seconds/box * 25 boxes = 375 seconds\\n\\nStep 3: Convert the calculated time into the standard H:M:S format by carrying over excess seconds and minutes.\\n- Convert seconds to minutes: 375 seconds is equal to 6 minutes and 15 seconds (since 375 / 60 = 6 with a remainder of 15). We add the 6 minutes to our minutes total.\\n- New total: 50 hours, (425 + 6) minutes, 15 seconds -> 50 hours, 431 minutes, 15 seconds.\\n- Convert minutes to hours: 431 minutes is equal to 7 hours and 11 minutes (since 431 / 60 = 7 with a remainder of 11). We add the 7 hours to our hours total.\\n- New total: (50 + 7) hours, 11 minutes, 15 seconds -> 57 hours, 11 minutes, 15 seconds.\\n\\nThe final time is 57 hours, 11 minutes, and 15 seconds.",\n "H": 57,\n "M": 11,\n "S": 15\n}\n```',
|
|
@@ -15,12 +16,13 @@ arithmetic_test_cases = {
|
|
| 15 |
'{"answer": "352 BC"}',
|
| 16 |
'{"unordered_list": ["London", "Berlin"]}',
|
| 17 |
'{"unordered_list": ["London", "Berlin"]}',
|
|
|
|
| 18 |
'{"answer": "2005-04-07"}',
|
| 19 |
'{"H": 57.0, "M": 11.0, "S": 15.0}',
|
| 20 |
'{"answer": 3319}',
|
| 21 |
],
|
| 22 |
-
"result": {"accuracy":
|
| 23 |
-
"per_item_accuracy": [True, True, False, False, True, True],
|
| 24 |
}
|
| 25 |
|
| 26 |
|
|
|
|
| 6 |
"predictions": [
|
| 7 |
'JSON = {"explanation": "The war began in 360 BC. Since BC years count backwards, adding 8 years to 360 BC means subtracting 8 from 360, resulting in 352 BC.", "answer": "352 BC"}',
|
| 8 |
'```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "unordered_list": ["Berlin","London"]\n}\n```',
|
| 9 |
+
'```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "unordered_list": ["berlin","london"]\n}\n```',
|
| 10 |
'```json\n{\n "explanation": "The dates provided are March 2012, September 2011, June 2017, September 2019, and June 2015. These correspond to visits to Miami, Sydney, Tokyo, London, and Nairobi respectively. The latest date among these is September 2019, which is associated with London. Therefore, London is the last city visited.",\n "malformed_unordered_list": ["Berlin","London"]\n}\n```',
|
| 11 |
' "To find the date of the second most important game, we need to subtract 7 days from the date of the most important game. We can do this by counting back 7 days from April 14, 2005. April 14 - 7 days = April 7, 2005", "answer": "2005-04-07"}',
|
| 12 |
'\n```json\n{\n "explanation": "Step 1: Determine the time it takes the robot to carry a single box. The robot takes 4 hours, 34 minutes, and 30 seconds to carry 2 boxes. We divide this time by 2 to find the time per box.\\n- Hours: 4 / 2 = 2 hours\\n- Minutes: 34 / 2 = 17 minutes\\n- Seconds: 30 / 2 = 15 seconds\\nSo, it takes the robot 2 hours, 17 minutes, and 15 seconds to carry one box.\\n\\nStep 2: Calculate the total time to carry 25 boxes. We multiply the time per box by the total number of boxes (25).\\n- Total Hours: 2 hours/box * 25 boxes = 50 hours\\n- Total Minutes: 17 minutes/box * 25 boxes = 425 minutes\\n- Total Seconds: 15 seconds/box * 25 boxes = 375 seconds\\n\\nStep 3: Convert the calculated time into the standard H:M:S format by carrying over excess seconds and minutes.\\n- Convert seconds to minutes: 375 seconds is equal to 6 minutes and 15 seconds (since 375 / 60 = 6 with a remainder of 15). We add the 6 minutes to our minutes total.\\n- New total: 50 hours, (425 + 6) minutes, 15 seconds -> 50 hours, 431 minutes, 15 seconds.\\n- Convert minutes to hours: 431 minutes is equal to 7 hours and 11 minutes (since 431 / 60 = 7 with a remainder of 11). We add the 7 hours to our hours total.\\n- New total: (50 + 7) hours, 11 minutes, 15 seconds -> 57 hours, 11 minutes, 15 seconds.\\n\\nThe final time is 57 hours, 11 minutes, and 15 seconds.",\n "H": 57,\n "M": 11,\n "S": 15\n}\n```',
|
|
|
|
| 16 |
'{"answer": "352 BC"}',
|
| 17 |
'{"unordered_list": ["London", "Berlin"]}',
|
| 18 |
'{"unordered_list": ["London", "Berlin"]}',
|
| 19 |
+
'{"unordered_list": ["London", "Berlin"]}',
|
| 20 |
'{"answer": "2005-04-07"}',
|
| 21 |
'{"H": 57.0, "M": 11.0, "S": 15.0}',
|
| 22 |
'{"answer": 3319}',
|
| 23 |
],
|
| 24 |
+
"result": {"accuracy": 5 / 7},
|
| 25 |
+
"per_item_accuracy": [True, True, True,False, False, True, True],
|
| 26 |
}
|
| 27 |
|
| 28 |
|
tests/test_arithmetic_type_casting.py
CHANGED
|
@@ -16,7 +16,7 @@
|
|
| 16 |
# days, hours, minutes, seconds <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'> 12
|
| 17 |
##################################################################################################
|
| 18 |
|
| 19 |
-
from
|
| 20 |
|
| 21 |
|
| 22 |
def test_answer_type_casting():
|
|
@@ -29,7 +29,7 @@ def test_answer_type_casting():
|
|
| 29 |
{"answer": "1032"},
|
| 30 |
]
|
| 31 |
for ref, pred in zip(references_answer_key, predictions_answer_key):
|
| 32 |
-
pred_cast = TestOfTimeAccuracy.
|
| 33 |
assert ref == pred_cast
|
| 34 |
|
| 35 |
|
|
@@ -37,7 +37,7 @@ def test_unordered_list_type_casting():
|
|
| 37 |
references_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
|
| 38 |
predictions_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
|
| 39 |
for ref, pred in zip(references_unordered_list_key, predictions_unordered_list_key):
|
| 40 |
-
pred_cast = TestOfTimeAccuracy.
|
| 41 |
assert ref == pred_cast
|
| 42 |
|
| 43 |
|
|
@@ -45,7 +45,7 @@ def test_date_type_casting():
|
|
| 45 |
references_date_key = [{"date": "12/11/2011"}]
|
| 46 |
predictions_date_key = [{"date": "12/11/2011"}]
|
| 47 |
for ref, pred in zip(references_date_key, predictions_date_key):
|
| 48 |
-
pred_cast = TestOfTimeAccuracy.
|
| 49 |
assert ref == pred_cast
|
| 50 |
|
| 51 |
|
|
@@ -53,7 +53,7 @@ def test_day_time_type_casting():
|
|
| 53 |
references_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
|
| 54 |
predictions_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
|
| 55 |
for ref, pred in zip(references_day_time_keys, predictions_day_time_keys):
|
| 56 |
-
pred_cast = TestOfTimeAccuracy.
|
| 57 |
assert ref == pred_cast
|
| 58 |
|
| 59 |
|
|
@@ -71,7 +71,7 @@ def test_hms_type_casting():
|
|
| 71 |
{"H": "2.0", "M": "13.0", "S": "30.0"},
|
| 72 |
]
|
| 73 |
for ref, pred in zip(references_hms_keys, predictions_hms_keys):
|
| 74 |
-
pred_cast = TestOfTimeAccuracy.
|
| 75 |
assert ref == pred_cast
|
| 76 |
|
| 77 |
|
|
@@ -83,7 +83,7 @@ def test_ordered_list_type_casting():
|
|
| 83 |
{"ordered_list": ["Joe", "Jenny", "Jason", "Dan", "Kyle"]},
|
| 84 |
]
|
| 85 |
for ref, pred in zip(references_ordered_list_key, predictions_ordered_list_key):
|
| 86 |
-
pred_cast = TestOfTimeAccuracy.
|
| 87 |
assert ref == pred_cast
|
| 88 |
|
| 89 |
# TODO: Check if I should treat float strings differently, e.g., int(float("18.0"))
|
|
@@ -101,7 +101,7 @@ def test_abc_type_casting():
|
|
| 101 |
# {"A": "80.0", "B": "22.0", "C": "20.0"},
|
| 102 |
]
|
| 103 |
for ref, pred in zip(references_abc_keys, predictions_abc_keys):
|
| 104 |
-
pred_cast = TestOfTimeAccuracy.
|
| 105 |
assert ref == pred_cast
|
| 106 |
|
| 107 |
|
|
@@ -123,7 +123,7 @@ def test_xyz_type_casting():
|
|
| 123 |
{"X": 2.0, "Y": 4.0, "Z": 36.0},
|
| 124 |
]
|
| 125 |
for ref, pred in zip(references_xyz_keys, predictions_xyz_keys):
|
| 126 |
-
pred_cast = TestOfTimeAccuracy.
|
| 127 |
assert ref == pred_cast
|
| 128 |
|
| 129 |
|
|
@@ -143,7 +143,7 @@ def test_hours_minutes_seconds_type_casting():
|
|
| 143 |
for ref, pred in zip(
|
| 144 |
references_hours_minutes_seconds_keys, predictions_hours_minutes_seconds_keys
|
| 145 |
):
|
| 146 |
-
pred_cast = TestOfTimeAccuracy.
|
| 147 |
assert ref == pred_cast
|
| 148 |
|
| 149 |
|
|
@@ -161,7 +161,7 @@ def test_hours_minutes_type_casting():
|
|
| 161 |
# {"hours": "5.0", "minutes": "0.0"},
|
| 162 |
]
|
| 163 |
for ref, pred in zip(references_hours_minutes_keys, predictions_hours_minutes_keys):
|
| 164 |
-
pred_cast = TestOfTimeAccuracy.
|
| 165 |
assert ref == pred_cast
|
| 166 |
|
| 167 |
|
|
@@ -182,5 +182,5 @@ def test_days_hours_minutes_seconds_type_casting():
|
|
| 182 |
references_days_hours_minutes_seconds_keys,
|
| 183 |
predictions_days_hours_minutes_seconds_keys,
|
| 184 |
):
|
| 185 |
-
pred_cast = TestOfTimeAccuracy.
|
| 186 |
assert ref == pred_cast
|
|
|
|
| 16 |
# days, hours, minutes, seconds <class 'int'>, <class 'int'>, <class 'int'>, <class 'int'> 12
|
| 17 |
##################################################################################################
|
| 18 |
|
| 19 |
+
from test_of_time_accuracy import TestOfTimeAccuracy
|
| 20 |
|
| 21 |
|
| 22 |
def test_answer_type_casting():
|
|
|
|
| 29 |
{"answer": "1032"},
|
| 30 |
]
|
| 31 |
for ref, pred in zip(references_answer_key, predictions_answer_key):
|
| 32 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 33 |
assert ref == pred_cast
|
| 34 |
|
| 35 |
|
|
|
|
| 37 |
references_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
|
| 38 |
predictions_unordered_list_key = [{"unordered_list": ["Kyle", "Jason", "Joe"]}]
|
| 39 |
for ref, pred in zip(references_unordered_list_key, predictions_unordered_list_key):
|
| 40 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 41 |
assert ref == pred_cast
|
| 42 |
|
| 43 |
|
|
|
|
| 45 |
references_date_key = [{"date": "12/11/2011"}]
|
| 46 |
predictions_date_key = [{"date": "12/11/2011"}]
|
| 47 |
for ref, pred in zip(references_date_key, predictions_date_key):
|
| 48 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 49 |
assert ref == pred_cast
|
| 50 |
|
| 51 |
|
|
|
|
| 53 |
references_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
|
| 54 |
predictions_day_time_keys = [{"day": "+2", "time": "21:44:10"}]
|
| 55 |
for ref, pred in zip(references_day_time_keys, predictions_day_time_keys):
|
| 56 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 57 |
assert ref == pred_cast
|
| 58 |
|
| 59 |
|
|
|
|
| 71 |
{"H": "2.0", "M": "13.0", "S": "30.0"},
|
| 72 |
]
|
| 73 |
for ref, pred in zip(references_hms_keys, predictions_hms_keys):
|
| 74 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 75 |
assert ref == pred_cast
|
| 76 |
|
| 77 |
|
|
|
|
| 83 |
{"ordered_list": ["Joe", "Jenny", "Jason", "Dan", "Kyle"]},
|
| 84 |
]
|
| 85 |
for ref, pred in zip(references_ordered_list_key, predictions_ordered_list_key):
|
| 86 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 87 |
assert ref == pred_cast
|
| 88 |
|
| 89 |
# TODO: Check if I should treat float strings differently, e.g., int(float("18.0"))
|
|
|
|
| 101 |
# {"A": "80.0", "B": "22.0", "C": "20.0"},
|
| 102 |
]
|
| 103 |
for ref, pred in zip(references_abc_keys, predictions_abc_keys):
|
| 104 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 105 |
assert ref == pred_cast
|
| 106 |
|
| 107 |
|
|
|
|
| 123 |
{"X": 2.0, "Y": 4.0, "Z": 36.0},
|
| 124 |
]
|
| 125 |
for ref, pred in zip(references_xyz_keys, predictions_xyz_keys):
|
| 126 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 127 |
assert ref == pred_cast
|
| 128 |
|
| 129 |
|
|
|
|
| 143 |
for ref, pred in zip(
|
| 144 |
references_hours_minutes_seconds_keys, predictions_hours_minutes_seconds_keys
|
| 145 |
):
|
| 146 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 147 |
assert ref == pred_cast
|
| 148 |
|
| 149 |
|
|
|
|
| 161 |
# {"hours": "5.0", "minutes": "0.0"},
|
| 162 |
]
|
| 163 |
for ref, pred in zip(references_hours_minutes_keys, predictions_hours_minutes_keys):
|
| 164 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 165 |
assert ref == pred_cast
|
| 166 |
|
| 167 |
|
|
|
|
| 182 |
references_days_hours_minutes_seconds_keys,
|
| 183 |
predictions_days_hours_minutes_seconds_keys,
|
| 184 |
):
|
| 185 |
+
pred_cast = TestOfTimeAccuracy._cast_prediction_to_reference_types(ref, pred)
|
| 186 |
assert ref == pred_cast
|