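# Evaluation script for topic-controlled generation: reads a CSV log of
# (category, generation) pairs and reports test-wordlist success, distinctness
# (dist-1/2/3), CoLA-based grammaticality, and GPT / Transformer-XL perplexity.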
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import defaultdict
import string
import csv

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead, AutoModelForSequenceClassification

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params, pad_mask
from predict import predict
from constants import *
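# data, model, util, predict, and constants are local project modules; the
# wildcard import from constants is assumed to supply EOT_TOKEN, used below.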

def tw_topic_eval(sentences, category, tw_dir, cap=None):
    # num matches of distinct words
    words = []
    with open(os.path.join(tw_dir, category + '.txt'), 'r') as rf:
        for line in rf:
            words.append(line.strip().lower())
    num_match = 0
    for sent in sentences:
        sent_match = 0
        sent = sent.strip().lower().split()
        sent = [tok.strip(string.punctuation) for tok in sent]
        for word in words:
            if word in sent:
                sent_match += 1
        if cap is None:
            num_match += sent_match
        else:
            num_match += min(cap, sent_match)
    return num_match
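
# Hypothetical example for tw_topic_eval: with a wordlist file space.txt
# containing "nasa" and "rocket", the sentence "NASA built a rocket. The
# rocket flew." scores 2 (distinct wordlist words matched, not occurrences);
# with cap=1 it would score 1.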

def perplexity(sentences, tokenizer, model, device='cuda'):
    # calculate perplexity
    with torch.no_grad():
        ppl = []
        sos_token = tokenizer.decode([0])
        for sentence in tqdm(sentences, total=len(sentences)):
            full_tensor_input = tokenizer.encode(sos_token + sentence.replace(EOT_TOKEN, ' ').strip(), return_tensors='pt').to(device)
            full_loss = model(full_tensor_input, labels=full_tensor_input)[0].mean()
            ppl.append(torch.exp(full_loss).flatten().cpu().item())
    return np.mean(ppl), np.std(ppl)
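
# Minimal sketch (illustrative only; never called in this script): the same
# per-sentence perplexity computed by hand from raw logits, assuming a standard
# HF causal LM whose first output is the logits tensor. It skips the sos-token
# prefix and EOT_TOKEN cleanup that perplexity() does above.
def _perplexity_one_sentence_sketch(sentence, tokenizer, model, device='cuda'):
    with torch.no_grad():
        ids = tokenizer.encode(sentence, return_tensors='pt').to(device)  # (1, seq_len)
        logits = model(ids)[0]  # (1, seq_len, vocab_size)
        log_probs = F.log_softmax(logits[:, :-1], dim=-1)  # position i predicts token i+1
        target_lp = log_probs.gather(2, ids[:, 1:].unsqueeze(-1)).squeeze(-1)  # log p of actual next tokens
        return torch.exp(-target_lp.mean()).item()  # perplexity = exp(mean NLL)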

def grammaticality(sentences, tokenizer, model, device='cuda'):
    with torch.no_grad():
        total_good = 0
        for sent in tqdm(sentences, total=len(sentences)):
            good_prob = F.softmax(model(tokenizer.encode(sent, return_tensors='pt').to(device))[0].flatten(), dim=0)[1]
            total_good += good_prob
        return total_good / len(sentences)  # avg probability of grammaticality according to model
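
# Note: index [1] of the softmaxed CoLA logits is assumed to be the
# "grammatically acceptable" class; the value returned by grammaticality()
# is the mean acceptability probability over all sentences (a 0-dim tensor).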

def distinctness(results):
    d1, d2, d3 = defaultdict(lambda: set()), defaultdict(lambda: set()), defaultdict(lambda: set())
    total_words = defaultdict(lambda: 0)
    for cw, outputs in results.items():
        for o in outputs:
            o = o.replace(EOT_TOKEN, ' ').strip().split(' ')
            o = [str(x) for x in o]
            total_words[cw] += len(o)
            d1[cw].update(o)
            for i in range(len(o) - 1):
                d2[cw].add(o[i] + ' ' + o[i+1])
            for i in range(len(o) - 2):
                d3[cw].add(o[i] + ' ' + o[i+1] + ' ' + o[i+2])
    return_info = []
    avg_d1, avg_d2, avg_d3 = 0, 0, 0
    for cw in total_words.keys():
        return_info.append((cw, 'DISTINCTNESS', len(d1[cw]) / total_words[cw], len(d2[cw]) / total_words[cw], len(d3[cw]) / total_words[cw]))
        avg_d1 += len(d1[cw]) / total_words[cw]
        avg_d2 += len(d2[cw]) / total_words[cw]
        avg_d3 += len(d3[cw]) / total_words[cw]
    avg_d1, avg_d2, avg_d3 = avg_d1 / len(total_words.keys()), avg_d2 / len(total_words.keys()), avg_d3 / len(total_words.keys())
    return return_info, (avg_d1, avg_d2, avg_d3)
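
# Worked example for distinctness (hypothetical output): for the single
# generation "the cat saw the cat", total words = 5, so dist-1 = 3/5
# ({the, cat, saw}), dist-2 = 3/5 ({the cat, cat saw, saw the}), and
# dist-3 = 3/5 ({the cat saw, cat saw the, saw the cat}).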

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--log_file', type=str, required=True, help='where to load results from')
    parser.add_argument('--tw_dir', type=str, default='test_wordlists', help='test wordlists')
    parser.add_argument('--batch_size', type=int, default=8, help='max samples at a time')
    parser.add_argument('--cap_per_example', type=int, default=None, help='max matches to count per sentence')
    parser.add_argument('--device', type=str, default='cuda', choices=['cpu', 'cuda'])
    args = parser.parse_args()

    tw_topic_match_c_total = 0
    category_totals_c = defaultdict(lambda: 0)
    results = defaultdict(lambda: [])

    with open(args.log_file, 'r') as rf:
        data = list(csv.DictReader(rf))
        for line in data:
            results[line['category']].append(line['generation'])
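    # Assumed log format: a CSV whose header includes 'category' and
    # 'generation' columns; each generation is bucketed under its category.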

    all_c_sents = []
    for category, condition_results in results.items():
        tw_topic_match_c = tw_topic_eval(condition_results, category, args.tw_dir, cap=args.cap_per_example)
        tw_topic_match_c_total += tw_topic_match_c
        category_totals_c[category] += tw_topic_match_c
        all_c_sents += condition_results

    print('Test wordlist matches (divide by num outputs to get the Success metric):', tw_topic_match_c_total)
    print('per category:', category_totals_c)

    dist_info_by_category, dist_overall = distinctness(results)
    print('Overall avg distinctness:', dist_overall)
    print('per category:', dist_info_by_category)

    grammar_tokenizer = AutoTokenizer.from_pretrained('textattack/roberta-base-CoLA')
    grammar_model = AutoModelForSequenceClassification.from_pretrained('textattack/roberta-base-CoLA').to(args.device)
    grammar_model.eval()
    print('grammaticality:', grammaticality(all_c_sents, grammar_tokenizer, grammar_model, device=args.device))

    eval_tokenizer = AutoTokenizer.from_pretrained('openai-gpt')
    eval_model = AutoModelWithLMHead.from_pretrained('openai-gpt').to(args.device)
    eval_model.eval()
    # pass device explicitly so --device cpu runs don't hit the 'cuda' default
    print('GPT perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))

    eval_tokenizer = AutoTokenizer.from_pretrained('transfo-xl-wt103')
    eval_model = AutoModelWithLMHead.from_pretrained('transfo-xl-wt103').to(args.device)
    eval_model.eval()
    print('TFXL perplexity:', perplexity(all_c_sents, eval_tokenizer, eval_model, device=args.device))
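
# Example invocation (script name and file paths are placeholders):
#   python evaluate.py --log_file generations_log.csv --tw_dir test_wordlists --device cuda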