"""
This script provides an interactive Gradio web application for visualizing token-level attributions in language model predictions using Integrated Gradients. It loads a small LLaMA model, computes how each input token contributes to the probability of a specified target token, and generates a color-coded visualization to explain model reasoning.
Features:
- Loads a causal language model and tokenizer (LLaMA).
- Computes Integrated Gradients attributions for a prompt and target token.
- Visualizes token contributions with a grid of colored boxes (green = positive, red = negative).
- Interactive Gradio UI for custom prompts and target tokens.
- Includes a Feynman-style explanation for interpretability concepts.
How to run:
1. Ensure Python dependencies are installed: torch, transformers, captum, matplotlib, gradio.
2. Place this file in your project directory.
3. Run the script from the command line:
python app.py
4. The app will launch locally (default port 7860). Open the provided URL in your browser.
5. Enter a prompt and target token to see the visualization and interpret model predictions.
Notes:
- The script saves the visualization as 'token_attributions.png'.
- For long prompts (>50 tokens), a warning is shown to prevent performance issues.
- Example prompts are provided for quick testing.
"""
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import IntegratedGradients
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr # Added for interactive UI
device = "cuda" if torch.cuda.is_available() else "cpu"
# Basic logger for helpful messages when loading gated models
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# ---------------- Load model (gated models handled safely) ----------------
# By default we attempt to load LLaMA-3.2-1B, but that model is gated on the Hugging Face Hub.
# We use HUGGINGFACE_HUB_TOKEN if available, otherwise fall back to a small public model for the demo.
requested_model = "meta-llama/Llama-3.2-1B"
fallback_model = "distilgpt2"
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
model_name = requested_model
try:
    load_kwargs = {}
    if hf_token:
        load_kwargs["use_auth_token"] = hf_token
    tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)
    model.eval()
    logger.info(f"Loaded gated model: {model_name}")
except Exception as e:
    logger.warning(f"Could not load requested model '{requested_model}': {e}")
    logger.info(f"Falling back to public model: {fallback_model} for demo purposes.")
    model_name = fallback_model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()
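# Tip: to load the gated LLaMA weights, export your Hugging Face token before launching, e.g.
#   export HUGGINGFACE_HUB_TOKEN=hf_xxx   # placeholder token; without it the public fallback is used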
# ---------------- Modularized Functions ----------------
def compute_attributions(prompt, target_token):
    """
    Compute Integrated Gradients attributions for a given prompt and target token.

    Useful for developers and ML practitioners (model interpretability) and for
    business stakeholders (it builds trust by explaining AI decisions).
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    # Use the first sub-token of the target; multi-token targets are attributed to their first piece.
    target_id = tokenizer(target_token, add_special_tokens=False)["input_ids"][0]

    def forward_func(embeds):
        outputs = model(inputs_embeds=embeds)
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        return probs[:, target_id]

    embeddings = model.get_input_embeddings()(inputs["input_ids"])
    embeddings.requires_grad_(True)
    ig = IntegratedGradients(forward_func)
    attributions, delta = ig.attribute(
        embeddings, n_steps=30, return_convergence_delta=True
    )
    token_attr = attributions.sum(-1).squeeze().detach().cpu()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
    # Normalize safely so the largest-magnitude attribution maps to +/-1
    token_attr_np = token_attr.numpy()
    norm_denom = abs(token_attr_np).max() + 1e-8
    token_attr_np = token_attr_np / norm_denom
    return tokens, token_attr_np
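# Standalone usage sketch (assumes the model/tokenizer above loaded successfully):
#   tokens, scores = compute_attributions("The capital of France is", " Paris")
#   for tok, s in zip(tokens, scores):
#       print(f"{tok:>12s}  {s:+.3f}")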
def create_visualization(tokens, token_attr_np, prompt, target_token):
    """
    Generate an appealing visualization: a grid of colored token boxes.

    Designed for a mixed audience: clean design, simple explanations, professional look.
    """
    num_tokens = max(1, len(tokens))
    cols = min(max(3, int(num_tokens**0.5)), 8)
    rows = (num_tokens + cols - 1) // cols
    box_w = 1.0 / cols
    box_h = 0.18
    fig_h = max(4, rows * 0.7 + 2.0)  # Extra height for more spacing
    fig = plt.figure(figsize=(12, fig_h))

    # Add a title for context
    fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
                 fontsize=14, y=0.95, ha='center')

    ax = fig.add_axes([0, 0.30, 1, 0.60])  # Shift the grid higher to leave bottom space
    ax.set_xlim(0, cols)
    ax.set_ylim(0, rows)
    ax.axis('off')

    # Normalize for the colormap (0-1 range)
    minv, maxv = token_attr_np.min(), token_attr_np.max()
    norm = (token_attr_np - minv) / (maxv - minv + 1e-8)
    cmap = plt.get_cmap('RdYlGn')  # Green positive, red negative

    from matplotlib.patches import FancyBboxPatch
    for idx, (tok, score_norm) in enumerate(zip(tokens, norm)):
        r = idx // cols
        c = idx % cols
        x = c
        y = rows - 1 - r
        color = cmap(score_norm)
        pad = 0.08
        rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
                              boxstyle='round,pad=0.02', linewidth=0.8,
                              facecolor=color, edgecolor='gray', alpha=0.95)  # Softer edges
        ax.add_patch(rect)
        # Larger, bold labels; show BPE space markers ('Ġ') as spaces for readability
        display_tok = tok.replace('Ġ', ' ') if isinstance(tok, str) else str(tok)
        ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
                fontsize=10, fontweight='bold')

    # Colorbar, placed below the grid
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([0, 1])
    cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])
    cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')

    # Audience-friendly legend markers
    fig.text(0.05, 0.16, 'Green = positive (helps the prediction)', fontsize=10, ha='left')
    fig.text(0.75, 0.16, 'Red = negative (hinders the prediction)', fontsize=10, ha='right')

    # Short caption for a mixed audience
    caption = (
        "How input tokens influence the model's target prediction: green supports, red opposes. "
        "Useful for debugging (devs), reasoning insights (ML), and reliable decisions (business). "
        "Values are normalized per example."
    )
    fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)

    # Save with higher quality
    out_path = 'token_attributions.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)  # Clean up the figure
    return out_path
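# Interface sketch (toy tokens and scores for illustration only; requires `import numpy as np`):
#   create_visualization(["The", "Ġcapital", "Ġof", "ĠFrance", "Ġis"],
#                        np.array([0.1, 0.6, 0.05, 1.0, 0.3]),
#                        "The capital of France is", " Paris")   # writes 'token_attributions.png'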
# ---------------- Gradio Interface for Interactivity ----------------
def generate_attribution(prompt, target_token):
    """
    Gradio wrapper: compute and visualize attributions for custom inputs.

    Defaults to the "capital of France" example for a quick demo.
    """
    if not prompt.strip():
        prompt = "The capital of France is"
    if not target_token.strip():
        target_token = " Paris"
    # Guard against very long prompts to prevent overload (rough word-based check)
    if len(prompt.split()) > 50:
        raise gr.Error("Prompt too long (> 50 words). Please shorten it for better performance.")
    try:
        tokens, token_attr_np = compute_attributions(prompt, target_token)
        img_path = create_visualization(tokens, token_attr_np, prompt, target_token)
        return img_path
    except Exception as e:
        # Surface errors in the UI instead of returning a non-image string to gr.Image
        raise gr.Error(f"Error: {e}")
# Build the interactive app (launched in __main__ below)
iface = gr.Interface(
    fn=generate_attribution,
    inputs=[
        gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
        gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')")
    ],
    outputs=gr.Image(label="Token Attribution Visualization"),
    title="AI Interpretability Explorer: See How Tokens Influence Predictions",
    description="Input a prompt and target token to visualize token contributions using "
                "[Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients) on LLaMA. "
                "Explore model reasoning interactively.",
    # A short Feynman-style explanation, rendered as Markdown below the app.
    article="""
### How it works — Feynman-style
This tool explains which input tokens most influence the model's next-token prediction using [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients).
- What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
- Watch-outs: IG depends on the baseline choice and the number of interpolation steps. Subword tokens (e.g., 'Ġ') are shown with spaces; long prompts may be noisy.
""",
    examples=[
        ["The capital of France is", " Paris"],
        ["I love this product because", " it's amazing"],
        ["The weather today is", " sunny"]
    ]
)
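# Math note: Integrated Gradients (Sundararajan et al., 2017) computes, for each input dimension i,
#     IG_i(x) = (x_i - x'_i) * ∫_0^1 ∂F(x' + α(x - x'))/∂x_i dα
# Here the baseline x' is Captum's default all-zeros embedding, F is forward_func (probability of the
# target token), and the script sums IG over each token's embedding dimensions to get one score per token.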
if __name__ == "__main__":
    # Run the original example for backward compatibility, then launch Gradio
    print("Generating default example...")
    default_img = generate_attribution("", "")
    print(f"Default plot saved to: {default_img}")
    print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)
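    # Note: share=True also requests a temporary public gradio.live link; for local-only use, a
    # sketch alternative is iface.launch(server_name="127.0.0.1", server_port=7860).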