| """ | |
| This script provides an interactive Gradio web application for visualizing token-level attributions in language model predictions using Integrated Gradients. It loads a small LLaMA model, computes how each input token contributes to the probability of a specified target token, and generates a color-coded visualization to explain model reasoning. | |
| Features: | |
| - Loads a causal language model and tokenizer (LLaMA). | |
| - Computes Integrated Gradients attributions for a prompt and target token. | |
| - Visualizes token contributions with a grid of colored boxes (green = positive, red = negative). | |
| - Interactive Gradio UI for custom prompts and target tokens. | |
| - Includes a Feynman-style explanation for interpretability concepts. | |
| How to run: | |
| 1. Ensure Python dependencies are installed: torch, transformers, captum, matplotlib, gradio. | |
| 2. Place this file in your project directory. | |
| 3. Run the script from the command line: | |
| python app.py | |
| 4. The app will launch locally (default port 7860). Open the provided URL in your browser. | |
| 5. Enter a prompt and target token to see the visualization and interpret model predictions. | |
| Notes: | |
| - The script saves the visualization as 'token_attributions.png'. | |
| - For long prompts (>50 tokens), a warning is shown to prevent performance issues. | |
| - Example prompts are provided for quick testing. | |
| """ | |
import os
import logging
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from captum.attr import IntegratedGradients
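# Select the non-interactive Agg backend before importing pyplot so figures
# render in headless environments (e.g., a server without a display).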
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.patches import FancyBboxPatch
import gradio as gr  # interactive web UI

device = "cuda" if torch.cuda.is_available() else "cpu"
# Basic logger for helpful messages when loading gated models
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# ---------------- Load model (gated models handled safely) ----------------
# We first try LLaMA-3.2-1B, which is gated on the Hugging Face Hub: we use
# HUGGINGFACE_HUB_TOKEN if available, otherwise fall back to a small public
# model for demo purposes.
requested_model = "meta-llama/Llama-3.2-1B"
fallback_model = "distilgpt2"
hf_token = os.environ.get("HUGGINGFACE_HUB_TOKEN")
model_name = requested_model
try:
    load_kwargs = {}
    if hf_token:
        # `token` replaces the deprecated `use_auth_token` argument in recent
        # versions of transformers.
        load_kwargs["token"] = hf_token
    tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)
    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs).to(device)
    model.eval()
    logger.info(f"Loaded requested model: {model_name}")
except Exception as e:
    logger.warning(f"Could not load requested model '{requested_model}': {e}")
    logger.info(f"Falling back to public model: {fallback_model} for demo purposes.")
    model_name = fallback_model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
    model.eval()

# ---------------- Modularized Functions ----------------
def compute_attributions(prompt, target_token):
    """
    Compute Integrated Gradients attributions for a given prompt and target token.

    Returns the prompt's tokens and a per-token attribution score normalized to
    [-1, 1]. Useful for developers and ML practitioners inspecting model behavior,
    and for building trust by explaining model decisions.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    target_id = tokenizer(target_token, add_special_tokens=False)["input_ids"][0]
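    # If the target string tokenizes into multiple subtokens, only the first
    # subtoken's probability is attributed.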
    def forward_func(embeds):
        # Probability of the target token at the final (next-token) position.
        outputs = model(inputs_embeds=embeds)
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        return probs[:, target_id]
    embeddings = model.get_input_embeddings()(inputs["input_ids"])
    embeddings.requires_grad_(True)
    ig = IntegratedGradients(forward_func)
    attributions, delta = ig.attribute(
        embeddings, n_steps=30, return_convergence_delta=True
    )
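    # With no explicit `baselines` argument, Captum defaults to an all-zero
    # baseline; n_steps sets the resolution of the Riemann approximation of
    # the path integral.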
    token_attr = attributions.sum(-1).squeeze().detach().cpu()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"].squeeze())
    # Normalize to [-1, 1] by the maximum absolute attribution; the epsilon
    # guards against division by zero.
    token_attr_np = token_attr.numpy()
    norm_denom = abs(token_attr_np).max() + 1e-8
    token_attr_np = token_attr_np / norm_denom
    return tokens, token_attr_np
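
# A minimal usage sketch (illustrative only; exact tokens depend on the tokenizer):
#   tokens, attrs = compute_attributions("The capital of France is", " Paris")
#   tokens -> e.g. ['The', 'Ġcapital', 'Ġof', 'ĠFrance', 'Ġis']; attrs in [-1, 1]
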
def create_visualization(tokens, token_attr_np, prompt, target_token):
    """
    Render the attributions as a grid of colored token boxes and save it as a PNG.

    Designed for a mixed audience: clean layout, a simple legend, and a short
    explanatory caption.
    """
    num_tokens = max(1, len(tokens))
    cols = min(max(3, int(num_tokens**0.5)), 8)
    rows = (num_tokens + cols - 1) // cols
    box_h = 0.18
    fig_h = max(4, rows * 0.7 + 2.0)  # extra height for spacing
    fig = plt.figure(figsize=(12, fig_h))
    # Title for context
    fig.suptitle(f"Token Contributions to Predicting '{target_token}' in: '{prompt}'",
                 fontsize=14, y=0.95, ha='center')
    ax = fig.add_axes([0, 0.30, 1, 0.60])  # grid sits high to leave room below
    ax.set_xlim(0, cols)
    ax.set_ylim(0, rows)
    ax.axis('off')
    # Map scores to [0, 1] for the colormap while keeping 0 at the neutral
    # midpoint, so positive scores are always green and negative always red
    # (plain min-max scaling would color the weakest positive score red).
    norm = (token_attr_np + 1.0) / 2.0
    cmap = plt.get_cmap('RdYlGn')  # green = positive, red = negative
    for idx, (tok, score_norm) in enumerate(zip(tokens, norm)):
        r = idx // cols
        c = idx % cols
        x = c
        y = rows - 1 - r
        color = cmap(score_norm)
        pad = 0.08
        rect = FancyBboxPatch((x + pad*0.15, y + pad*0.15), 1 - pad, box_h - pad*0.3,
                              boxstyle='round,pad=0.02', linewidth=0.8,
                              facecolor=color, edgecolor='gray', alpha=0.95)
        ax.add_patch(rect)
        # Render subword markers ('Ġ' for BPE, '▁' for SentencePiece) as spaces.
        display_tok = tok.replace('Ġ', ' ').replace('▁', ' ') if isinstance(tok, str) else str(tok)
        ax.text(x + 0.5, y + box_h/2, display_tok, ha='center', va='center',
                fontsize=10, fontweight='bold')
    # Horizontal colorbar below the grid
    sm = plt.cm.ScalarMappable(cmap=cmap)
    sm.set_array([0, 1])
    cax = fig.add_axes([0.1, 0.22, 0.8, 0.04])
    cb = fig.colorbar(sm, cax=cax, orientation='horizontal')
    cb.set_label('Contribution Strength', fontsize=11, fontweight='bold')
    # Plain-language legend
    fig.text(0.05, 0.16, 'Green = positive (helps the prediction)', fontsize=10, ha='left')
    fig.text(0.75, 0.16, 'Red = negative (hinders the prediction)', fontsize=10, ha='right')
    # Caption for a mixed audience
    caption = (
        "Each box shows how strongly an input token pushed the model toward (green) "
        "or away from (red) the target prediction. Scores are normalized per example. "
        "Useful for debugging, reasoning analysis, and trustworthy AI decisions."
    )
    fig.text(0.5, 0.08, caption, fontsize=9, ha='center', va='top', wrap=True)
    # Save at high resolution
    out_path = 'token_attributions.png'
    fig.savefig(out_path, dpi=300, bbox_inches='tight', facecolor='white')
    plt.close(fig)  # release figure memory
    return out_path
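
# Example (hypothetical values): given tokens/attrs from compute_attributions,
#   path = create_visualization(tokens, attrs, "The capital of France is", " Paris")
#   -> saves and returns 'token_attributions.png'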

# ---------------- Gradio Interface for Interactivity ----------------
def generate_attribution(prompt, target_token):
    """
    Gradio wrapper: compute attributions and build the visualization for custom inputs.
    Defaults to the France-capital example for a quick demo.
    """
    if not prompt.strip():
        prompt = "The capital of France is"
    if not target_token.strip():
        target_token = " Paris"
    # Guard against long prompts, which make IG slow and noisy. Count actual
    # tokens rather than whitespace-separated words.
    if len(tokenizer(prompt)["input_ids"]) > 50:
        raise gr.Error("Prompt too long (>50 tokens). Shorten it for better performance.")
    try:
        tokens, token_attr_np = compute_attributions(prompt, target_token)
        return create_visualization(tokens, token_attr_np, prompt, target_token)
    except Exception as e:
        # gr.Error surfaces the message in the UI; returning a plain string
        # would fail because the output component expects an image path.
        raise gr.Error(str(e))

# Launch interactive app
iface = gr.Interface(
    fn=generate_attribution,
    inputs=[
        gr.Textbox(label="Prompt", value="The capital of France is", placeholder="Enter your prompt..."),
        gr.Textbox(label="Target Token", value=" Paris", placeholder="Enter target token (e.g., ' Paris')")
    ],
    outputs=gr.Image(label="Token Attribution Visualization"),
    title="AI Interpretability Explorer: See How Tokens Influence Predictions",
    description="Input a prompt and target token to visualize token contributions using "
                f"[Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients) on {model_name}. "
                "Explore model reasoning interactively.",
    # The `article` markdown is rendered below the app as an extended explanation.
    article="""
### How it works — Feynman-style
This tool explains which input tokens most influence the model's next-token prediction using [Integrated Gradients](https://captum.ai/docs/extension/integrated_gradients).
- What it does: Interpolates from a baseline to the actual input in embedding space, accumulates gradients along the path, and attributes importance to each input token.
- Why it helps: Highlights which tokens push the model toward (green) or away from (red) the chosen target token. Useful for debugging, bias detection, and model transparency.
- How to read results: Higher positive values (green) mean the token increases the probability of the target; negative values (red) mean the token reduces it. Values are normalized per example.
- Watch-outs: IG depends on the baseline choice and the number of interpolation steps. Subword tokens (e.g., Ġ) are shown with spaces; long prompts may be noisy.
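- The math in one line: for input dimension *i*, IG_i(x) = (x_i − x′_i) · ∫₀¹ ∂F(x′ + α(x − x′))/∂x_i dα; this app approximates the integral with a 30-step Riemann sum from an all-zero embedding baseline x′.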
| """ | |
| , | |
| examples=[ | |
| ["The capital of France is", " Paris"], | |
| ["I love this product because", " it's amazing"], | |
| ["The weather today is", " sunny"] | |
| ] | |
| ) | |

if __name__ == "__main__":
    # Generate the default example for backward compatibility, then launch Gradio.
    print("Generating default example...")
    try:
        generate_attribution("", "")
        print("Default plot saved to: token_attributions.png")
    except Exception as e:
        print(f"Could not generate default example: {e}")
    print("\nLaunching interactive Gradio app... Open in browser for custom examples.")
    iface.launch(share=True, server_name="0.0.0.0", server_port=7860)