Skier8402 committed
Commit 280c08c · verified · 1 Parent(s): 75b68ee

Upload 2 files

Files changed (2)
  1. nlp_gradio_llm.py +180 -0
  2. requirements.txt +14 -0
nlp_gradio_llm.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ from captum.attr import IntegratedGradients
+ import matplotlib
+ matplotlib.use('Agg')  # Headless backend so figures render without a display
+ import matplotlib.pyplot as plt
+ from matplotlib.patches import FancyBboxPatch  # Rounded boxes for the token grid
+ import gradio as gr  # Interactive UI
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ---------------- Load Smaller LLaMA ----------------
+ model_name = "meta-llama/Llama-3.2-1B"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+ model.eval()
+
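+ # Note: Llama 3.2 checkpoints are gated on the Hugging Face Hub, so loading
+ # typically requires an authenticated account that has accepted the license
+ # (e.g. via `huggingface-cli login` or an HF_TOKEN environment variable).
+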
+ # ---------------- Modularized Functions ----------------
+ def compute_attributions(prompt, target_token):
+     """
+     Compute Integrated Gradients attributions for a given prompt and target token.
+     Returns per-token scores showing how strongly each input token pushed the
+     model toward (or away from) predicting `target_token` as the next token.
+     """
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     # If target_token splits into several subwords, only the first one is scored
+     target_id = tokenizer(target_token, add_special_tokens=False)["input_ids"][0]
+
+     def forward_func(embeds):
+         # Probability assigned to the target token at the final position
+         outputs = model(inputs_embeds=embeds)
+         logits = outputs.logits[:, -1, :]
+         probs = torch.softmax(logits, dim=-1)
+         return probs[:, target_id]
+
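+     # Captum evaluates forward_func on all n_steps interpolated inputs in one
+     # batch unless internal_batch_size is set, so memory use grows with n_steps.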
+     # Detach so the embeddings become a leaf tensor that Captum can perturb
+     embeddings = model.get_input_embeddings()(inputs["input_ids"]).detach()
+     embeddings.requires_grad_(True)
+
+     # With no explicit baseline, Captum uses an all-zeros embedding as the start
+     ig = IntegratedGradients(forward_func)
+     attributions, delta = ig.attribute(
+         embeddings, n_steps=30, return_convergence_delta=True
+     )
+
+     # Sum over the embedding dimension to get one score per input token
+     token_attr = attributions.sum(-1).squeeze().detach().cpu()
+     tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
+
+     # Normalize safely to [-1, 1] by the largest absolute score
+     token_attr_np = token_attr.numpy()
+     norm_denom = abs(token_attr_np).max() + 1e-8
+     token_attr_np = token_attr_np / norm_denom
+
+     return tokens, token_attr_np
+
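+ # Illustrative usage (a sketch, not executed here):
+ #   tokens, scores = compute_attributions("The capital of France is", " Paris")
+ #   for tok, s in zip(tokens, scores):
+ #       print(f"{tok:>12} {s:+.3f}")
+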
+ def create_visualization(tokens, token_attr_np, prompt, target_token):
+     """
+     Render the attribution scores as a grid of colored token boxes
+     (green = supports the prediction, red = opposes it) and save a PNG.
+     """
+     # Lay tokens out in a roughly square grid, clamped to 3-8 columns
+     num_tokens = max(1, len(tokens))
+     cols = min(max(3, int(num_tokens**0.5)), 8)
+     rows = (num_tokens + cols - 1) // cols
+
+     box_h = 0.18  # Height of each token box in grid units
+
+     fig_h = max(4, rows * 0.7 + 2.0)  # Extra height for title, colorbar, caption
+     fig = plt.figure(figsize=(12, fig_h))
+
+     # Title gives the prompt and target for context
+     fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
+                  fontsize=14, y=0.95, ha='center')
+
+     ax = fig.add_axes([0, 0.30, 1, 0.60])  # Grid sits above the colorbar and caption
+     ax.set_xlim(0, cols)
+     ax.set_ylim(0, rows)
+     ax.axis('off')
+
+     # Min-max scale scores to the 0-1 range the colormap expects
+     minv, maxv = token_attr_np.min(), token_attr_np.max()
+     norm = (token_attr_np - minv) / (maxv - minv + 1e-8)
+     cmap = plt.get_cmap('RdYlGn')  # Green for positive, red for negative
+
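+     # Caveat: min-max scaling maps the smallest score to red even when every
+     # score is positive; a norm centered on zero would avoid that distortion.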
+     for idx, (tok, score_norm) in enumerate(zip(tokens, norm)):
+         r = idx // cols
+         c = idx % cols
+         x = c
+         y = rows - 1 - r
+         color = cmap(score_norm)
+         pad = 0.08
+         rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
+                               boxstyle='round,pad=0.02', linewidth=0.8,
+                               facecolor=color, edgecolor='gray', alpha=0.95)
+         ax.add_patch(rect)
+         # 'Ġ' marks a leading space in byte-level BPE tokens; display it as a space
+         display_tok = tok.replace('Ġ', ' ') if isinstance(tok, str) else str(tok)
+         ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
+                 fontsize=10, fontweight='bold')
+
+     # Horizontal colorbar below the grid
+     sm = plt.cm.ScalarMappable(cmap=cmap)
+     sm.set_array([0, 1])
+     cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])
+     cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
+     cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')
+
+     # Plain-language legend for a mixed audience
+     fig.text(0.05, 0.16, 'Green = positive (helps prediction)', fontsize=10, ha='left')
+     fig.text(0.75, 0.16, 'Red = negative (hinders prediction)', fontsize=10, ha='right')
+
+     # Short caption explaining how to read the plot
+     caption = (
+         "How input tokens influence the model's target prediction: green tokens support it, "
+         "red tokens oppose it. Useful for debugging, reasoning insights, and trustworthy decisions. "
+         "Scores are normalized per example."
+     )
+     fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)
+
+     # Save at high resolution
+     out_path = 'token_attributions.png'
+     fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
+     plt.close(fig)  # Release the figure's memory
+     return out_path
+
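+ # Note: the output path is fixed, so concurrent Gradio requests overwrite each
+ # other's image; a per-request temporary file (e.g. created with
+ # tempfile.NamedTemporaryFile(suffix='.png', delete=False)) would avoid this.
+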
+ # ---------------- Gradio Interface for Interactivity ----------------
+ def generate_attribution(prompt, target_token):
+     """
+     Gradio wrapper: compute and visualize attributions for custom inputs.
+     Falls back to the France-capital example when fields are left blank.
+     """
+     if not prompt.strip():
+         prompt = "The capital of France is"
+     if not target_token.strip():
+         target_token = " Paris"
+
+     # Guard against very long prompts (word-count heuristic) to keep IG fast
+     if len(prompt.split()) > 50:
+         raise gr.Error("Prompt too long (over 50 words). Shorten it for better performance.")
+
+     try:
+         tokens, token_attr_np = compute_attributions(prompt, target_token)
+         img_path = create_visualization(tokens, token_attr_np, prompt, target_token)
+         return img_path
+     except Exception as e:
+         # Raise in the UI: a gr.Image output cannot render an error-message string
+         raise gr.Error(f"Attribution failed: {e}")
+
+ # Launch interactive app
+ iface = gr.Interface(
+     fn=generate_attribution,
+     inputs=[
+         gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
+         gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')")
+     ],
+     outputs=gr.Image(label="Token Attribution Visualization"),
+     title="AI Interpretability Explorer: See How Tokens Influence Predictions",
+     description="Input a prompt and target token to visualize token contributions using Integrated Gradients on LLaMA. "
+                 "Explore model reasoning interactively.",
+     # The article below is rendered as Markdown beneath the app as a short,
+     # Feynman-style explainer of the method.
+     article="""
+ ### How it works — Feynman-style
+
+ This tool explains which input tokens most influence the model's next-token prediction using Integrated Gradients.
+
+ - What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
+ - Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
+ - How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
+ - Watch-outs: IG depends on the baseline choice and the number of interpolation steps. Subword tokens (e.g., Ġ) are shown with spaces; long prompts may be noisy.
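+
+ Formally, for input x, baseline x', and model output F, Integrated Gradients assigns feature i the score
+ IG_i(x) = (x_i - x'_i) * ∫₀¹ ∂F(x' + α(x - x')) / ∂x_i dα,
+ approximated numerically (here with n_steps=30 interpolation points).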
+     """,
+     examples=[
+         ["The capital of France is", " Paris"],
+         ["I love this product because", " it's amazing"],
+         ["The weather today is", " sunny"]
+     ]
+ )
+
+ if __name__ == "__main__":
+     # Run the original example for backward compatibility, then launch Gradio
+     print("Generating default example...")
+     default_img = generate_attribution("", "")
+     print(f"Default plot saved to: {default_img}")
+     print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
+     iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,14 @@
+ torch
+ torchvision
+ matplotlib
+ captum
+ ipython
+ transformers
+ pillow
+ lime
+ numpy
+ scikit-image
+ timm
+ streamlit
+ gradio
+ accelerate