File size: 6,267 Bytes
22c93a7 |
|
import regex
from copy import deepcopy
from eval.eval_utils import math_equal
from eval.ocwcourses_eval_utils import (
normalize_numeric,
numeric_equality,
normalize_symbolic_equation,
SymbolicMathMixin,
)
def is_correct(item, pred_key="prediction", prec=1e-3):
    """Return whether ``item[pred_key]`` matches ``item["answer"]``.

    Both values must be strings, or both must be lists:

    * Lists: every (prediction, answer) pair is compared recursively with
      this function. The lists match when every prediction matches at least
      one answer and every answer matches at least one prediction.
    * Strings that both contain the LaTeX token ``\\cup`` are split on it
      and compared as lists.
    * Other strings match when any of these holds:
      - with commas removed, both parse as floats whose absolute difference
        is below ``prec``;
      - they are equal and the answer is non-empty;
      - ``math_equal`` accepts them.

    Args:
        item: dict holding the prediction under ``pred_key`` and the
            reference under ``"answer"``. It is never modified.
        pred_key: key of the prediction in ``item``.
        prec: absolute tolerance for the numeric comparison.

    Returns:
        True if the prediction is judged correct, else False.

    Raises:
        NotImplementedError: if the two values are not both ``str`` or
            both ``list``.
    """
    pred = item[pred_key]
    ans = item["answer"]

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, p in enumerate(pred):
            for j, a in enumerate(ans):
                # A shallow copy is enough: this function only rebinds the two
                # keys on its own copy and never mutates the stored values.
                pair = {**item, pred_key: p, "answer": a}
                if is_correct(pair, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Compare the operands of a union as (order-insensitive) lists.
            split_item = {
                **item,
                pred_key: pred.split("\\cup"),
                "answer": ans.split("\\cup"),
            }
            return is_correct(split_item, pred_key=pred_key, prec=prec)
        try:
            # Remove commas (e.g. "1,000" -> "1000") before parsing.
            numeric_match = (
                abs(float(pred.replace(",", "")) - float(ans.replace(",", "")))
                < prec
            )
        except ValueError:
            # Not plain numbers; fall through to the exact/symbolic checks.
            numeric_match = False
        return bool(numeric_match or (ans and pred == ans) or math_equal(pred, ans))

    raise NotImplementedError(
        f"cannot compare {pred_key}={pred!r} ({type(pred).__name__}) "
        f"with answer={ans!r} ({type(ans).__name__})"
    )
def eval_math(item, pred_key="prediction", prec=1e-3):
    """Score a MATH-style item, tolerating repeated and extra boxed answers.

    A string ``program_output`` prediction is wrapped into a one-element
    list. When both the prediction and the answer are lists:

    * duplicates are removed from each (first occurrence kept, order
      preserved);
    * only the last ``len(answer)`` distinct predictions are kept.

    ``item`` is updated in place with the processed prediction and answer,
    then passed to ``is_correct``.
    """

    def unique_in_order(values):
        # Equality-based de-duplication (works for unhashable values too).
        kept = []
        for value in values:
            if value not in kept:
                kept.append(value)
        return kept

    prediction = item[pred_key]
    if isinstance(prediction, str) and pred_key == "program_output":
        prediction = [prediction]
    reference = item["answer"]

    if isinstance(prediction, list) and isinstance(reference, list):
        # MATH references sometimes repeat answers; so do predictions.
        reference = unique_in_order(reference)
        # Some predictions box non-answer strings: keep only the trailing ones.
        prediction = unique_in_order(prediction)[-len(reference):]

    item.update({pred_key: prediction, "answer": reference})
    return is_correct(item, pred_key=pred_key, prec=prec)
def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Score an item whose prediction and answer are each a single string.

    Raises AssertionError (via ``assert``) if either value is not a ``str``;
    otherwise defers to ``is_correct``.
    """
    for key in (pred_key, "answer"):
        value = item[key]
        assert isinstance(value, str), f"{key} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
def _split_cloze_answers(text):
    """Split one prediction string into its non-empty, stripped answer pieces.

    A ``;`` always separates pieces. A ``,`` separates pieces only outside
    ``()``/``[]``/``{}`` brackets, so tuples and intervals stay whole.
    """
    pieces = []
    depth = 0
    start = 0
    for i, ch in enumerate(text):
        if ch == ";" or (ch == "," and depth == 0):
            pieces.append(text[start:i].strip())
            start = i + 1
            # Bracket depth is counted per piece: restart at every separator.
            depth = 0
        elif ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
    pieces.append(text[start:].strip())
    # A trailing or doubled separator ("1;", "1,,2") would otherwise yield an
    # empty piece that can displace a real answer when trailing pieces are kept.
    return [piece for piece in pieces if piece]


def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao math cloze item with one or more blanks.

    Each prediction string is split into pieces (see ``_split_cloze_answers``)
    and duplicate pieces are dropped. The last ``len(answer)`` pieces are
    then compared position-by-position with the answers using ``is_correct``.

    Args:
        item: dict with a list of prediction strings under ``pred_key`` and
            a list of reference strings under ``"answer"``. A string
            ``program_output`` is wrapped into a one-element list in place.
        pred_key: key of the prediction in ``item``.
        prec: absolute tolerance forwarded to ``is_correct``.

    Returns:
        True only if the piece count equals the answer count and every
        piece matches its answer.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
    ans = item["answer"]

    pieces = []
    for text in item[pred_key]:
        for piece in _split_cloze_answers(text):
            if piece not in pieces:
                pieces.append(piece)
    # Keep only the last len(ans) distinct pieces.
    pred = pieces[-len(ans):]
    if len(pred) != len(ans):
        return False
    # Compare on per-pair copies so the caller's item is not overwritten with
    # the last compared pair.
    return all(
        is_correct({**item, pred_key: p, "answer": a}, pred_key=pred_key, prec=prec)
        for p, a in zip(pred, ans)
    )
def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao multiple-choice item (options A-D).

    The prediction strings are joined with spaces. The chosen option is the
    letter among "ABCD" whose *first* occurrence appears latest in that
    string; with no letter present the choice is None. A string
    ``program_output`` is wrapped into a one-element list in place.

    NOTE(review): ranking letters by first occurrence (not last) means a
    letter repeated later does not win over one first seen in between —
    confirm this is intended.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    joined = " ".join(item[pred_key])
    expected = item["answer"]

    first_seen = {letter: joined.find(letter) for letter in "ABCD"}
    present = [letter for letter in "ABCD" if first_seen[letter] >= 0]
    chosen = max(present, key=first_seen.__getitem__) if present else None
    return chosen == expected
def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Case-insensitive exact match between prediction and answer strings.

    ``prec`` is accepted for interface parity with the other scorers and is
    unused. Raises AssertionError (via ``assert``) if either value is not a
    ``str``.
    """
    for key in (pred_key, "answer"):
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    predicted, expected = (item[k].lower() for k in (pred_key, "answer"))
    return predicted == expected
def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Score an MMLU-STEM item using the same rule as ``eval_math_sat``."""
    return eval_math_sat(item, pred_key, prec)
def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Score an OCW Courses item; returns 1 if correct, else 0.

    The reference answer selects how both sides are normalized and compared:

    * it parses as a float -> ``normalize_numeric`` / ``numeric_equality``;
    * it contains "=" -> ``normalize_symbolic_equation`` / exact ``==``;
    * otherwise -> ``SymbolicMathMixin().normalize_tex`` /
      ``SymbolicMathMixin().is_tex_equiv``.

    A prediction scores 0 when it is empty, equals "[invalidanswer]", or
    normalizes to "[invalidanswer]". ``prec`` is unused.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    prediction = item[pred_key]
    reference = item["answer"]

    def exact_match(left, right):
        return left == right

    try:
        float(reference)
    except ValueError:
        if "=" in reference:
            normalize_fn, is_equiv = normalize_symbolic_equation, exact_match
        else:
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        normalize_fn, is_equiv = normalize_numeric, numeric_equality

    correct_answer = normalize_fn(reference)
    raw_prediction = prediction or INVALID_ANSWER
    model_answer = normalize_fn(raw_prediction)
    if raw_prediction == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Placeholder scorer: accepts every item without inspecting it.

    NOTE(review): no proof checking happens here; presumably Isabelle
    verification is done elsewhere — confirm before relying on this score.
    """
    accepted = True
    return accepted
|