Upload 2 files
- nlp_gradio_llm.py  +180 -0
- requirements.txt   +14 -0

nlp_gradio_llm.py
ADDED
@@ -0,0 +1,180 @@
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from captum.attr import IntegratedGradients
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import gradio as gr  # Added for interactive UI
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# ---------------- Load Smaller LLaMA ----------------
+model_name = "meta-llama/Llama-3.2-1B"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
+model.eval()
+
+# ---------------- Modularized Functions ----------------
+def compute_attributions(prompt, target_token):
+    """
+    Compute Integrated Gradients attributions for a given prompt and target token.
+    Appeals to devs/ML: Shows model interpretability; business: Builds trust by explaining AI decisions.
+    """
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    target_id = tokenizer(target_token, add_special_tokens=False)["input_ids"][0]
+
+    def forward_func(embeds):
+        outputs = model(inputs_embeds=embeds)
+        logits = outputs.logits[:, -1, :]
+        probs = torch.softmax(logits, dim=-1)
+        return probs[:, target_id]
+
+    embeddings = model.get_input_embeddings()(inputs["input_ids"])
+    embeddings.requires_grad_(True)
+
+    ig = IntegratedGradients(forward_func)
+    attributions, delta = ig.attribute(
+        embeddings, n_steps=30, return_convergence_delta=True
+    )
+
+    token_attr = attributions.sum(-1).squeeze().detach().cpu()
+    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
+
+    # Normalize safely
+    token_attr_np = token_attr.numpy()
+    norm_denom = (abs(token_attr_np).max() + 1e-8)
+    token_attr_np = token_attr_np / norm_denom
+
+    return tokens, token_attr_np
+
+def create_visualization(tokens, token_attr_np, prompt, target_token):
+    """
+    Generate an appealing visualization: Grid of colored token boxes.
+    Enhanced for mixed audience: Clean design, simple explanations, professional look.
+    """
+    num_tokens = max(1, len(tokens))
+    cols = min(max(3, int(num_tokens**0.5)), 8)
+    rows = (num_tokens + cols - 1) // cols
+
+    box_w = 1.0 / cols
+    box_h = 0.18
+
+    fig_h = max(4, rows * 0.7 + 2.0)  # Increased height for more spacing
+    fig = plt.figure(figsize=(12, fig_h))
+
+    # Add title for context
+    fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
+                 fontsize=14, y=0.95, ha='center')
+
+    ax = fig.add_axes([0, 0.30, 1, 0.60])  # Shift grid higher for more bottom space
+    ax.set_xlim(0, cols)
+    ax.set_ylim(0, rows)
+    ax.axis('off')
+
+    # Normalize for colormap (0-1 range)
+    minv, maxv = token_attr_np.min(), token_attr_np.max()
+    norm = (token_attr_np - minv) / (maxv - minv + 1e-8)
+    cmap = plt.get_cmap('RdYlGn')  # Green positive, red negative
+
+    from matplotlib.patches import FancyBboxPatch
+    for idx, (tok, score_norm) in enumerate(zip(tokens, norm)):
+        r = idx // cols
+        c = idx % cols
+        x = c
+        y = rows - 1 - r
+        color = cmap(score_norm)
+        pad = 0.08
+        rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
+                              boxstyle='round,pad=0.02', linewidth=0.8,
+                              facecolor=color, edgecolor='gray', alpha=0.95)  # Softer edges
+        ax.add_patch(rect)
+        # Improved text: Larger font, wrap long tokens
+        display_tok = tok.replace('Ġ', ' ') if isinstance(tok, str) else str(tok)  # Space for subwords
+        ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
+                fontsize=10, fontweight='bold')  # Bold for readability
+
+    # Enhanced colorbar - lowered position
+    sm = plt.cm.ScalarMappable(cmap=cmap)
+    sm.set_array([0, 1])
+    cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])  # Lowered from 0.18
+    cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
+    cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')
+
+    # Markers for audience-friendly explanation - lowered
+    fig.text(0.05, 0.16, 'Green Positive (helps prediction)', fontsize=10, ha='left')
+    fig.text(0.75, 0.16, 'Red Negative (hinders prediction)', fontsize=10, ha='right')
+
+    # Engaging caption for mixed audience - shortened and lowered with wrap
+    caption = (
+        "How input tokens influence the model's target prediction: Green supports (builds AI trust), "
+        "red opposes. For debugging (devs), reasoning insights (ML), reliable decisions (business). Normalized."
+    )
+    fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)  # Smaller font, lower pos
+
+    # Save with higher quality
+    out_path = 'token_attributions.png'
+    fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
+    plt.close(fig)  # Clean up
+    return out_path
+
+# ---------------- Gradio Interface for Interactivity ----------------
+def generate_attribution(prompt, target_token):
+    """
+    Gradio wrapper: Compute and visualize for custom inputs.
+    Default example: France capital for quick demo.
+    """
+    if not prompt.strip():
+        prompt = "The capital of France is"
+    if not target_token.strip():
+        target_token = " Paris"
+
+    # Reject overly long prompts to prevent overload; raising gr.Error surfaces the
+    # message in the UI (returning a plain string would break the gr.Image output).
+    if len(prompt.split()) > 50:
+        raise gr.Error("Prompt too long (over 50 words). Shorten it for better performance.")
+
+    try:
+        tokens, token_attr_np = compute_attributions(prompt, target_token)
+        img_path = create_visualization(tokens, token_attr_np, prompt, target_token)
+        return img_path
+    except Exception as e:
+        raise gr.Error(f"Attribution failed: {e}")
+
+# Launch interactive app
+iface = gr.Interface(
+    fn=generate_attribution,
+    inputs=[
+        gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
+        gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')")
+    ],
+    outputs=gr.Image(label="Token Attribution Visualization"),
+    title="AI Interpretability Explorer: See How Tokens Influence Predictions",
+    description="Input a prompt and target token to visualize token contributions using Integrated Gradients on LLaMA. "
+                "Explore model reasoning interactively.",
+    # The `article` markdown below is rendered beneath the interface and gives a
+    # short Feynman-style explanation of Integrated Gradients so non-expert users
+    # can interpret the green/red token scores.
+    article="""
+### How it works (Feynman-style)
+
+This tool explains which input tokens most influence the model's next-token prediction using Integrated Gradients.
+
+- What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
+- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
+- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
+- Watch-outs: IG depends on the baseline choice and number of interpolation steps. Subword tokens (e.g., Ġ) are shown with spaces; long prompts may be noisy.
+""",
+    examples=[
+        ["The capital of France is", " Paris"],
+        ["I love this product because", " it's amazing"],
+        ["The weather today is", " sunny"]
+    ]
+)
+
+if __name__ == "__main__":
+    # Run the original example for backward compatibility, then launch Gradio
+    print("Generating default example...")
+    default_img = generate_attribution("", "")
+    print("Default plot saved to: token_attributions.png")
+    print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
+    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
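
For reference, the quantity that compute_attributions approximates is the standard Integrated Gradients attribution (Sundararajan et al., 2017). With F the softmax probability of the target token returned by forward_func, x the prompt's input embeddings, and x' the baseline (Captum falls back to an all-zeros baseline when none is passed, as here), the attribution for embedding dimension i is

\mathrm{IG}_i(x) = (x_i - x'_i) \int_0^1 \frac{\partial F\left(x' + \alpha\,(x - x')\right)}{\partial x_i}\, d\alpha

approximated numerically with n_steps=30 interpolation points. The script then sums these per-dimension attributions over the embedding axis to get one score per token and divides by the largest absolute score, so the plotted values lie in [-1, 1].
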
requirements.txt
ADDED
@@ -0,0 +1,14 @@
+torch
+torchvision
+matplotlib
+captum
+ipython
+transformers
+pillow
+lime
+numpy
+scikit-image
+timm
+streamlit
+gradio
+accelerate
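
Once the requirements above are installed, the uploaded module can also be reused programmatically without the Gradio UI. A minimal sketch, assuming the script keeps its uploaded name nlp_gradio_llm.py and sits on the Python path; note that importing it loads the 1B checkpoint (the gated meta-llama weights typically require a logged-in Hugging Face token), while the app itself only launches from the __main__ guard:

# Minimal sketch: reuse the attribution pipeline without launching the Gradio UI.
# Assumes this file sits next to nlp_gradio_llm.py; the import loads the model.
from nlp_gradio_llm import compute_attributions, create_visualization

prompt = "The capital of France is"
target = " Paris"

tokens, scores = compute_attributions(prompt, target)   # normalized per-token IG scores
for tok, score in zip(tokens, scores):
    print(f"{tok!r:>12}  {score:+.3f}")                  # inspect contributions in the console

image_path = create_visualization(tokens, scores, prompt, target)
print("Plot written to", image_path)                     # token_attributions.png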