File size: 6,267 Bytes
22c93a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
import regex
from copy import deepcopy
from eval.eval_utils import math_equal
from eval.ocwcourses_eval_utils import (
normalize_numeric,
numeric_equality,
normalize_symbolic_equation,
SymbolicMathMixin,
)
def _strings_match(pred, ans, prec):
    """Return True if two answer strings are considered equal.

    Checks, in order (short-circuiting):
      1. both parse as ``float`` after removing ``,`` and differ by < ``prec``;
      2. ``ans`` is non-empty and ``pred == ans``;
      3. ``math_equal(pred, ans)`` is truthy.
    """
    try:
        # Commas are removed so thousands separators such as "1,000" parse.
        pred_num = float(pred.replace(",", ""))
        ans_num = float(ans.replace(",", ""))
    except ValueError:
        close = False
    else:
        close = abs(pred_num - ans_num) < prec
    # The `ans and` guard keeps an empty reference from matching an empty
    # prediction by plain string equality.
    return bool(close or (ans and pred == ans) or math_equal(pred, ans))


def _answers_match(pred, ans, prec):
    """Recursively compare a prediction value with a reference value.

    Lists match when every prediction element matches at least one answer
    element and every answer element matches at least one prediction element
    (order-insensitive; no one-to-one assignment is enforced). Strings that
    both contain ``\\cup`` are split on it and compared as lists.

    Raises:
        NotImplementedError: if the pair is not (list, list) or (str, str).
    """
    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, p in enumerate(pred):
            for j, a in enumerate(ans):
                if i in pred_matched and j in ans_matched:
                    # Both sides already covered; this pair cannot change
                    # the outcome, so skip the (possibly costly) comparison.
                    continue
                if _answers_match(p, a, prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)
    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Unions (of sets/intervals) are compared part by part.
            return _answers_match(pred.split("\\cup"), ans.split("\\cup"), prec)
        return _strings_match(pred, ans, prec)
    raise NotImplementedError(
        f"cannot compare prediction {pred!r} ({type(pred).__name__}) "
        f"with answer {ans!r} ({type(ans).__name__})"
    )


def is_correct(item, pred_key="prediction", prec=1e-3):
    """Return whether ``item[pred_key]`` matches ``item["answer"]``.

    Args:
        item: mapping holding the prediction under ``pred_key`` and the
            reference under ``"answer"``; both are str or (nested) lists of
            str. ``item`` is not modified.
        pred_key: key of the prediction inside ``item``.
        prec: absolute tolerance used when both sides parse as numbers.

    Returns:
        bool: True when the prediction matches the reference.

    Raises:
        KeyError: if either key is missing.
        NotImplementedError: if the prediction/answer types are not both str
            or both list (at any nesting level).
    """
    return _answers_match(item[pred_key], item["answer"], prec)
def eval_math(item, pred_key="prediction", prec=1e-3):
    """Score a MATH-style item, tolerating repeated boxed answers.

    A str ``program_output`` prediction is first wrapped into a one-element
    list. When both prediction and answer are lists, duplicates are removed
    from each (first occurrence kept, order preserved) and only the last
    ``len(answer)`` predictions are kept; ``item`` is then updated in place
    with these cleaned lists before delegating to ``is_correct``.
    """

    def _dedupe(values):
        # Equality-based (not hash-based), so unhashable values also work.
        unique = []
        for value in values:
            if value not in unique:
                unique.append(value)
        return unique

    prediction = item[pred_key]
    if isinstance(prediction, str) and pred_key == "program_output":
        prediction = [prediction]
    reference = item["answer"]
    if not (isinstance(prediction, list) and isinstance(reference, list)):
        return is_correct(item, pred_key=pred_key, prec=prec)
    # MATH references sometimes repeat the same answer.
    reference = _dedupe(reference)
    # Predictions can repeat answers too, and sometimes box non-answer text
    # earlier on, so only the trailing predictions are kept.
    prediction = _dedupe(prediction)[-len(reference):]
    item.update({pred_key: prediction, "answer": reference})
    return is_correct(item, pred_key=pred_key, prec=prec)
def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Score an item whose prediction and answer are both single strings.

    Raises AssertionError (when assertions are enabled) if either field is
    not a str; otherwise delegates to ``is_correct``.
    """
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
def _split_cloze_prediction(text):
    """Split one raw cloze prediction string into its answer pieces.

    A ``;`` always ends a piece; a ``,`` ends a piece only when it is not
    nested inside ``()``, ``[]`` or ``{}`` (bracket depth restarts at zero for
    each piece). Pieces are whitespace-stripped and empty pieces are dropped.
    """
    pieces = []
    chars = []
    depth = 0
    for ch in text + ";":  # the trailing ";" flushes the final piece
        if ch == ";" or (ch == "," and depth == 0):
            piece = "".join(chars).strip()
            # Trailing/doubled separators would otherwise yield "" pieces
            # that displace real answers when the tail is kept below.
            if piece:
                pieces.append(piece)
            chars = []
            depth = 0
            continue
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            depth -= 1
        chars.append(ch)
    return pieces


def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao math cloze item with one or more blanks.

    Each prediction string is split into pieces (see
    ``_split_cloze_prediction``); unique pieces are collected in order and
    the last ``len(answer)`` of them are compared pairwise, in order, with
    the reference answers via ``is_correct``.

    Note: ``item`` is mutated. A str ``program_output`` is wrapped into a
    list, and during comparison ``item[pred_key]`` / ``item["answer"]`` are
    overwritten with the pair being compared.

    Returns:
        bool: True only if the piece count equals the answer count and every
        pair matches.

    Raises:
        AssertionError: if the prediction or answer is not a list.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"
    ans = item["answer"]
    candidates = []
    for raw in item[pred_key]:
        for piece in _split_cloze_prediction(raw):
            if piece not in candidates:
                candidates.append(piece)
    pred = candidates[-len(ans):]
    if len(pred) != len(ans):
        return False
    for p, a in zip(pred, ans):
        item.update({pred_key: p, "answer": a})
        if not is_correct(item, pred_key=pred_key, prec=prec):
            return False
    return True
def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao multiple-choice item (options A-D).

    The prediction pieces are joined with spaces. Among the letters A-D that
    occur, the one whose FIRST occurrence is furthest right is taken as the
    chosen option (None if none occur) and compared with ``item["answer"]``.
    A str ``program_output`` is wrapped into a list in place first.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    text = " ".join(item[pred_key])
    answer = item["answer"]
    first_seen = {option: text.find(option) for option in "ABCD"}
    present = [option for option in "ABCD" if first_seen[option] != -1]
    chosen = max(present, key=first_seen.__getitem__) if present else None
    return chosen == answer
def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Case-insensitive exact-match scoring of a single-string prediction.

    ``prec`` is accepted only for signature compatibility and is unused.
    Raises AssertionError (when assertions are enabled) if either field is
    not a str.
    """

    def _require_str(field):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
        return value

    prediction = _require_str(pred_key)
    answer = _require_str("answer")
    return prediction.lower() == answer.lower()
def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Score an MMLU-STEM item; uses the same rule as ``eval_math_sat``."""
    return eval_math_sat(item, pred_key, prec)
def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Score an OCW Courses prediction against its reference answer.

    The reference decides the comparison mode:
      * parseable by ``float`` -> ``normalize_numeric`` + ``numeric_equality``;
      * contains ``=`` -> ``normalize_symbolic_equation`` + exact equality;
      * otherwise -> TeX normalization/equivalence from ``SymbolicMathMixin``.
    ``prec`` is accepted for signature compatibility and is unused here.

    Returns:
        int: 1 if the prediction is judged equivalent, else 0. An empty
        prediction, or one that normalizes to ``"[invalidanswer]"``, scores 0.

    Raises:
        AssertionError: if the prediction or answer is not a str.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], str), f"{key} = `{item[key]}` is not a str"
    pred = item[pred_key]
    ans = item["answer"]
    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            normalize_fn = normalize_symbolic_equation

            def is_equiv(x, y):
                return x == y

        else:
            # One instance serves both normalization and equivalence.
            tex = SymbolicMathMixin()
            normalize_fn = tex.normalize_tex
            is_equiv = tex.is_tex_equiv
    else:
        normalize_fn = normalize_numeric
        is_equiv = numeric_equality
    # The reference is normalized first so a malformed reference still
    # surfaces as an error even when the prediction is empty.
    correct_answer = normalize_fn(ans)
    if not pred or pred == INVALID_ANSWER:
        return 0
    model_answer = normalize_fn(pred)
    if model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Placeholder scorer for miniF2F (Isabelle): accepts every prediction.

    No arguments are inspected; they exist only to match the other
    ``eval_*`` signatures.

    NOTE(review): nothing is verified here; presumably proofs are checked
    elsewhere — confirm before reporting accuracy computed with this.
    """
    accepted = True
    return accepted
|