Martijn van Beers committed
Commit 9f74b46 · 1 Parent(s): 8821877
Initial implementation
- app.py +154 -0
- description.md +4 -0
- examples.csv +3 -0
- notice.md +2 -0
- requirements.txt +7 -0
app.py
ADDED
@@ -0,0 +1,154 @@
import re
import pandas
import seaborn
import gradio
import pathlib

import torch
import matplotlib
import matplotlib.pyplot as plt
import numpy
from sklearn.metrics.pairwise import cosine_distances

from transformers import (
    AutoConfig,
    AutoTokenizer,
    AutoModelForSequenceClassification, AutoModelForMaskedLM
)

## Rollout Helper Function
def compute_joint_attention(att_mat, res=True):
    if res:
        residual_att = numpy.eye(att_mat.shape[1])[None, ...]
        att_mat = att_mat + residual_att
        att_mat = att_mat / att_mat.sum(axis=-1)[..., None]

    joint_attentions = numpy.zeros(att_mat.shape)
    layers = joint_attentions.shape[0]
    joint_attentions[0] = att_mat[0]
    for i in numpy.arange(1, layers):
        joint_attentions[i] = att_mat[i].dot(joint_attentions[i - 1])

    return joint_attentions

def create_plot(all_tokens, score_data):
    LAYERS = list(range(12))
    fig, axs = plt.subplots(6, 2, figsize=(8, 24))
    plt.subplots_adjust(top=0.98, bottom=0.05, hspace=0.5, wspace=0.5)
    for layer in LAYERS:
        a = layer // 2
        b = layer % 2
        seaborn.heatmap(
            ax=axs[a, b],
            data=pandas.DataFrame(score_data[layer], index=all_tokens, columns=all_tokens),
            cmap="Blues",
            annot=False,
            cbar=False
        )
        axs[a, b].set_title(f"Layer: {layer+1}")
    return fig

matplotlib.use('agg')

DISTANCE_FUNC = {
    'cosine': cosine_distances
}
MODEL_PATH = {
    'bert': 'bert-base-uncased',
    'roberta': 'roberta-base',
}

MODEL_NAME = 'bert'
#MODEL_NAME = 'roberta'
METRIC = 'cosine'

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)


def run(mname, sent):
    global MODEL_NAME, config, model, tokenizer
    if mname != MODEL_NAME:
        MODEL_NAME = mname
        config = AutoConfig.from_pretrained(MODEL_PATH[MODEL_NAME])
        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH[MODEL_NAME])
        model = AutoModelForMaskedLM.from_pretrained(MODEL_PATH[MODEL_NAME], config=config).to(device)
    # replace the [MASK] placeholder from the examples with the selected model's own mask token
    sent = re.sub(r".MASK.", tokenizer.mask_token, sent)
    inputs = tokenizer(sent, return_token_type_ids=True, return_tensors="pt")

    ## Compute: layerwise value zeroing
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = model(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        token_type_ids=inputs['token_type_ids'],
                        output_hidden_states=True, output_attentions=False)

    org_hidden_states = torch.stack(outputs['hidden_states']).squeeze(1)
    input_shape = inputs['input_ids'].size()
    batch_size, seq_length = input_shape

    score_matrix = numpy.zeros((config.num_hidden_layers, seq_length, seq_length))
    for l, layer_module in enumerate(getattr(model, MODEL_NAME).encoder.layer):
        for t in range(seq_length):
            extended_blanking_attention_mask: torch.Tensor = getattr(model, MODEL_NAME).get_extended_attention_mask(inputs['attention_mask'], input_shape, device)
            with torch.no_grad():
                layer_outputs = layer_module(org_hidden_states[l].unsqueeze(0),  # previous layer's original output
                                             attention_mask=extended_blanking_attention_mask,
                                             output_attentions=False,
                                             zero_value_index=t,  # custom kwarg provided by the pinned transformers fork
                                             )
            hidden_states = layer_outputs[0].squeeze().detach().cpu().numpy()
            # compute similarity between original and new outputs
            # cosine
            x = hidden_states
            y = org_hidden_states[l + 1].detach().cpu().numpy()

            distances = DISTANCE_FUNC[METRIC](x, y).diagonal()
            score_matrix[l, :, t] = distances

    valuezeroing_scores = score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True)
    rollout_valuezeroing_scores = compute_joint_attention(valuezeroing_scores, res=False)


    # Plot:
    cmap = "Blues"
    all_tokens = [tokenizer.convert_ids_to_tokens(t) for t in inputs['input_ids']]
    rollout_fig = create_plot(all_tokens, rollout_valuezeroing_scores)
    value_fig = create_plot(all_tokens, valuezeroing_scores)

    return rollout_fig, value_fig

examples = pandas.read_csv("examples.csv").to_numpy().tolist()

with gradio.Blocks(
        title="Differences with/without zero-valuing",
        css=".output-image > img {height: 2000px !important; max-height: none !important;} "
) as iface:
    gradio.Markdown(pathlib.Path("description.md").read_text)
    with gradio.Row(equal_height=True):
        with gradio.Column(scale=4):
            sent = gradio.Textbox(label="Input sentence")
        with gradio.Column(scale=1):
            model_choice = gradio.Dropdown(choices=['bert', 'roberta'], value="bert")
            but = gradio.Button("Submit")
    gradio.Examples(examples, [sent])
    with gradio.Row(equal_height=True):
        with gradio.Column():
            gradio.Markdown("### With Rollout")
            rollout_result = gradio.Plot()
        with gradio.Column():
            gradio.Markdown("### Without Rollout")
            value_result = gradio.Plot()
    with gradio.Accordion("Some more details"):
        gradio.Markdown(pathlib.Path("notice.md").read_text)

    but.click(run,
              inputs=[model_choice, sent],
              outputs=[rollout_result, value_result]
              )


iface.launch()
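A minimal sketch of what compute_joint_attention does, on a toy two-layer stack of row-normalised score matrices (the values are illustrative only, and the snippet assumes the function defined in app.py above is in scope):

import numpy

# Toy stack: 2 layers x 3 tokens x 3 tokens, each row already sums to 1.
toy = numpy.array([
    [[0.6, 0.2, 0.2],
     [0.1, 0.8, 0.1],
     [0.3, 0.3, 0.4]],
    [[0.5, 0.25, 0.25],
     [0.2, 0.6, 0.2],
     [0.1, 0.1, 0.8]],
])

# res=False mirrors the call in run(): the value-zeroing scores are already
# normalised, so no identity/residual term is added before composing layers.
rolled = compute_joint_attention(toy, res=False)

# Layer 0 passes through unchanged; layer 1 is composed with layer 0.
assert numpy.allclose(rolled[0], toy[0])
assert numpy.allclose(rolled[1], toy[1].dot(toy[0]))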
description.md
ADDED
@@ -0,0 +1,4 @@
# Value Zeroing

Demo of the effect of Value Zeroing (Mohebbi et al., 2022) both with Attention Rollout (Abnar & Zuidema, 2020)
and without.
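Concretely, for each layer the demo zeroes the value vector of one token at a time, re-runs that layer, and measures how far every token's output moves (cosine distance). The per-row normalisation that turns those distances into scores looks like this (d is a hypothetical distance matrix for one layer, not taken from the demo):

import numpy

# Hypothetical distances for one layer: d[i, t] is how much token i's output
# changes when token t's value vector is zeroed.
d = numpy.array([
    [0.05, 0.30, 0.15],
    [0.20, 0.10, 0.10],
    [0.25, 0.05, 0.20],
])

# Row-normalise over the zeroed token t, mirroring the
# score_matrix / numpy.sum(score_matrix, axis=-1, keepdims=True) step in app.py.
scores = d / d.sum(axis=-1, keepdims=True)
print(scores.round(2))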
examples.csv
ADDED
@@ -0,0 +1,3 @@
sentence
"You either win the game or you [MASK] the game."
"The author talked to Sarah about [MASK] book."
notice.md
ADDED
@@ -0,0 +1,2 @@
* Shown on the left are the results after applying attention rollout, as defined by Abnar & Zuidema (2020).
* On the right are the results before rollout is applied.
requirements.txt
ADDED
@@ -0,0 +1,7 @@
torch
git+https://github.com/martijnvanbeers/transformers@feature/transformer-explainability
pandas
seaborn
matplotlib
numpy
scikit-learn