|
|
import torch |
|
|
import os |
|
|
import json |
|
|
import io |
|
|
import matplotlib.pyplot as plt |
|
|
import matplotlib.colors as mcolors |
|
|
import matplotlib.patches as mpatches |
|
|
from matplotlib_venn import venn2, venn3 |
|
|
import seaborn as sns |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
from tqdm import tqdm |
|
|
from typing import List, Dict, Any, Optional |
|
|
from PIL import Image |
|
|
from safetensors import safe_open |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
|
|
|
def _set_publication_fonts(scale_factor=1.0):
    """
    Configure matplotlib's global rcParams with serif, LaTeX-like publication fonts.

    Args:
        scale_factor: Multiplier applied to every base font size listed below.
            Use values above 1.0 for subplots or small figures where text must
            be larger to stay readable. Recommended: 1.0 for full-page plots,
            1.5-2.0 for subplots.
    """
    # Unscaled point sizes, in the order they are written to rcParams.
    font_points = (
        ('font.size', 14),
        ('axes.labelsize', 16),
        ('axes.titlesize', 18),
        ('xtick.labelsize', 14),
        ('ytick.labelsize', 14),
        ('legend.fontsize', 14),
    )

    rc = plt.rcParams

    # Serif family with a fallback list; the first entry is tried first.
    rc['font.family'] = 'serif'
    rc['font.serif'] = ['Computer Modern Roman', 'DejaVu Serif', 'Times New Roman']

    for rc_key, points in font_points:
        rc[rc_key] = points * scale_factor

    # Computer Modern glyphs for math text.
    rc['mathtext.fontset'] = 'cm'
|
|
|
|
|
|
|
|
def _get_scaled_fontsize(base_size, scale_factor=1.5): |
|
|
""" |
|
|
Returns a scaled font size for specific plot elements. |
|
|
Default scale_factor of 1.5 ensures readability in subplots. |
|
|
""" |
|
|
return int(base_size * scale_factor) |
|
|
|
|
|
|
|
|
def _optimize_png_for_heatmap(png_path: str, num_colors: int = 256, resize_factor: float = 1.0) -> None: |
|
|
""" |
|
|
Aggressively optimize a PNG file for minimal size while maintaining acceptable quality. |
|
|
|
|
|
Args: |
|
|
png_path: Path to the PNG file to optimize |
|
|
num_colors: Maximum number of colors in the palette (default 256) |
|
|
resize_factor: Factor to resize image (1.0 = no resize, 0.5 = half size) |
|
|
""" |
|
|
try: |
|
|
from PIL import Image |
|
|
import subprocess |
|
|
import shutil |
|
|
|
|
|
|
|
|
img = Image.open(png_path) |
|
|
|
|
|
|
|
|
if img.mode == 'RGBA': |
|
|
background = Image.new('RGB', img.size, (255, 255, 255)) |
|
|
background.paste(img, mask=img.split()[3]) |
|
|
img = background |
|
|
elif img.mode != 'RGB': |
|
|
img = img.convert('RGB') |
|
|
|
|
|
|
|
|
if resize_factor < 1.0: |
|
|
new_size = (int(img.width * resize_factor), int(img.height * resize_factor)) |
|
|
img = img.resize(new_size, Image.Resampling.NEAREST) |
|
|
|
|
|
|
|
|
|
|
|
actual_colors = min(num_colors, 16) |
|
|
img_indexed = img.quantize(colors=actual_colors, method=2, dither=0) |
|
|
|
|
|
|
|
|
img_indexed.save(png_path, 'PNG', optimize=True, compress_level=9) |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
if shutil.which('pngquant'): |
|
|
subprocess.run([ |
|
|
'pngquant', |
|
|
'--force', |
|
|
'--skip-if-larger', |
|
|
'--quality=50-90', |
|
|
'--speed=1', |
|
|
str(actual_colors), |
|
|
png_path |
|
|
], capture_output=True, check=False) |
|
|
|
|
|
|
|
|
elif shutil.which('optipng'): |
|
|
subprocess.run([ |
|
|
'optipng', |
|
|
'-o7', |
|
|
'-quiet', |
|
|
png_path |
|
|
], capture_output=True, check=False) |
|
|
|
|
|
except Exception: |
|
|
pass |
|
|
|
|
|
except Exception as e: |
|
|
print(f" Warning: PNG optimization failed: {e}") |
|
|
|
|
|
def _calculate_optimal_dpi(data_shape: tuple, target_pixels: int = 200000, is_per_model: bool = False) -> int: |
|
|
""" |
|
|
Calculate optimal DPI based on data dimensions to minimize file size. |
|
|
More aggressive settings since quality is confirmed to be good. |
|
|
|
|
|
Args: |
|
|
data_shape: Shape of the heatmap data (height, width) |
|
|
target_pixels: Target number of pixels in the output image |
|
|
is_per_model: Whether this is for per-model heatmaps (use more aggressive compression) |
|
|
|
|
|
Returns: |
|
|
Optimal DPI value |
|
|
""" |
|
|
|
|
|
if is_per_model: |
|
|
target_pixels = 100000 |
|
|
|
|
|
|
|
|
if data_shape[0] < 50 and data_shape[1] < 50: |
|
|
return 120 if is_per_model else 150 |
|
|
|
|
|
|
|
|
figure_width_inches = 8 |
|
|
data_pixels = data_shape[0] * data_shape[1] |
|
|
|
|
|
if data_pixels > 5000: |
|
|
|
|
|
scale_factor = np.sqrt(target_pixels / data_pixels) |
|
|
optimal_dpi = int(80 * scale_factor) if is_per_model else int(100 * scale_factor) |
|
|
return max(60 if is_per_model else 72, min(120 if is_per_model else 150, optimal_dpi)) |
|
|
|
|
|
return 100 if is_per_model else 120 |
|
|
|
|
|
def _save_heatmap_pdf(fig, output_path: str, data_shape: tuple) -> str:
    """
    Save a heatmap figure as a PDF next to ``output_path``.

    PDF rendering of pixel-exact data can show artifacts, so a note
    recommending the PNG version is printed after saving.

    Args:
        fig: The matplotlib figure to save. When None, pyplot's current figure
            is used instead.
        output_path: Target path; its extension (e.g. .png or .pdf) is
            replaced by '.pdf'.
        data_shape: Shape of the heatmap data (height, width). Not used by the
            current implementation; kept for interface compatibility.

    Returns:
        str: Path to the saved PDF file.
    """
    import matplotlib as mpl

    pdf_output_path = os.path.splitext(output_path)[0] + '.pdf'

    # Save the figure that was handed in rather than whatever pyplot currently
    # considers active; the two differ once another figure has been created.
    target_fig = fig if fig is not None else plt.gcf()

    # rc_context restores the previous rcParams on exit, including on error.
    with mpl.rc_context({
        'image.interpolation': 'none',
        'image.interpolation_stage': 'rgba',
    }):
        target_fig.savefig(
            pdf_output_path,
            format='pdf',
            dpi=1200,
            bbox_inches='tight',
            facecolor='white',
            edgecolor='none',
            pad_inches=0.1,
            transparent=False,
            # None values ask the PDF backend to omit these metadata entries.
            metadata={'Creator': None, 'Producer': None, 'CreationDate': None},
        )

    print(" ⚠️ Note: PDF format may show artifacts with pixel-based heatmaps.")
    print(" For publication-quality heatmaps, consider using the PNG versions.")

    return pdf_output_path
|
|
|
|
|
def _shorten_name(name: str) -> str: |
|
|
"""Shortens run names for legends, e.g., 'sft_if_magnitude' -> 'if'.""" |
|
|
parts = name.split('_') |
|
|
|
|
|
if len(parts) > 1: |
|
|
|
|
|
return parts[1] |
|
|
return name |
|
|
|
|
|
|
|
|
def _extract_parameter_info(layer_name: str) -> str: |
|
|
""" |
|
|
Extracts parameter type and layer number from layer name for display. |
|
|
E.g., 'model.layers.15.self_attn.q_proj.weight' -> 'Layer 15 - q_proj' |
|
|
""" |
|
|
import re |
|
|
|
|
|
pattern = re.compile(r"model\.layers\.(\d+)\..*\.([^.]+)\.weight") |
|
|
match = pattern.match(layer_name) |
|
|
if match: |
|
|
layer_num = match.group(1) |
|
|
param_type = match.group(2) |
|
|
return f"Layer {layer_num} - {param_type}" |
|
|
|
|
|
return layer_name |
|
|
|
|
|
def load_masks_from_run(run_dir: str) -> Dict[str, Any]:
    """
    Load ``masks.pt`` from an experiment output directory onto the CPU.

    Args:
        run_dir (str): Path to the experiment output directory.

    Returns:
        dict: The dictionary of masks stored in the file.

    Raises:
        FileNotFoundError: If ``<run_dir>/masks.pt`` does not exist.
    """
    masks_path = os.path.join(run_dir, "masks.pt")

    if os.path.exists(masks_path):
        print(f"Loading masks from {masks_path}...")
        # NOTE(review): torch.load deserializes arbitrary saved objects; only
        # point this at run directories you trust.
        loaded = torch.load(masks_path, map_location='cpu')
        print(f"✓ Loaded {len(loaded)} masks.")
        return loaded

    raise FileNotFoundError(f"Mask file not found at {masks_path}")
|
|
|
|
|
def calculate_mask_overlap(masks1_dict, masks2_dict):
    """
    Compute the pooled Jaccard index between two sets of masks.

    Only parameters present in both dictionaries are compared. Intersection
    and union counts are summed over all of them before the ratio is taken.

    Args:
        masks1_dict (dict): The first dictionary of masks.
        masks2_dict (dict): The second dictionary of masks.

    Returns:
        dict: 'jaccard_index', 'intersection_size', 'union_size' and
        'total_common_params'. The index is 0.0 when the union is empty.
    """
    print("Calculating mask overlap...")

    shared_names = set(masks1_dict.keys()) & set(masks2_dict.keys())
    print(f"Found {len(shared_names)} common parameters between the two mask sets.")

    total_intersection = 0
    total_union = 0
    for param_name in shared_names:
        # Normalise to boolean so & and | act as set operations.
        first = masks1_dict[param_name].bool()
        second = masks2_dict[param_name].bool()
        total_intersection += (first & second).sum().item()
        total_union += (first | second).sum().item()

    overlap_ratio = total_intersection / total_union if total_union != 0 else 0.0

    print("✓ Overlap calculation complete.")
    return {
        'jaccard_index': overlap_ratio,
        'intersection_size': total_intersection,
        'union_size': total_union,
        'total_common_params': len(shared_names),
    }
|
|
|
|
|
def _visualize_grafting_analysis(pretrained_model, finetuned_model, optimizer_v_state,
                                 selected_layers, sparsity_ratio, global_threshold, grafting_method):
    """
    Compute per-layer importance scores and threshold masks for visualization.

    Scores per method (w_t = fine-tuned weight, w_0 = pretrained weight,
    v_t = entry of ``optimizer_v_state``):
        - 'fast_fisher' / 'ffg': (w_t - w_0)**2 * v_t
        - 'magnitude' / 'mag':   |w_t - w_0|
        - 'fish_mask' / 'fmg':   v_t

    Args:
        pretrained_model: Object whose ``state_dict()`` holds the base weights.
        finetuned_model: Object whose ``state_dict()`` holds the fine-tuned weights.
        optimizer_v_state: Mapping of parameter name to the ``v_t`` tensor.
            Only read for the Fisher-type methods; may be None otherwise.
        selected_layers: Parameter names to analyse. Names missing from either
            state dict, or (for Fisher-type methods) from
            ``optimizer_v_state``, are skipped.
        sparsity_ratio: Not used here; kept for interface compatibility.
        global_threshold: Entries with score >= this value are kept in the mask.
        grafting_method: One of the method names listed above.

    Returns:
        dict: layer name -> dict with 'scores', 'flat_scores', 'shape',
        'mask', 'kept_params', 'sparsity' and 'mean_score'. Note that
        'sparsity' holds the fraction of entries KEPT by the mask.

    Raises:
        ValueError: For an unsupported method, or when a Fisher-type method is
            requested without optimizer states.
    """
    fisher_methods = ('fast_fisher', 'ffg')
    magnitude_methods = ('magnitude', 'mag')
    fish_mask_methods = ('fish_mask', 'fmg')
    needs_v_state = grafting_method in fisher_methods + fish_mask_methods

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()

    layer_stats = {}
    print(f"🔍 Computing scores for {len(selected_layers)} layers for visualization...")
    for layer_name in tqdm(selected_layers, desc="Analyzing layers"):
        # A layer must exist in BOTH checkpoints; checking only the pretrained
        # one would raise KeyError on the fine-tuned lookup below.
        if layer_name not in pretrained_state or layer_name not in finetuned_state:
            continue

        if needs_v_state:
            if optimizer_v_state is None:
                raise ValueError(
                    f"Grafting method '{grafting_method}' requires optimizer states, "
                    "which were not provided."
                )
            if layer_name not in optimizer_v_state:
                continue

        w_t = finetuned_state[layer_name].to(device).to(torch.float32)
        w_0 = pretrained_state[layer_name].to(device).to(torch.float32)

        if grafting_method in fisher_methods:
            v_t = optimizer_v_state[layer_name].to(device).to(torch.float32)
            scores = (w_t - w_0) ** 2 * v_t
        elif grafting_method in magnitude_methods:
            scores = torch.abs(w_t - w_0)
        elif grafting_method in fish_mask_methods:
            scores = optimizer_v_state[layer_name].to(device).to(torch.float32)
        else:
            raise ValueError(f"Unsupported grafting method: {grafting_method}")

        flat_scores = scores.flatten()
        mask = (scores >= global_threshold).reshape(w_t.shape)
        kept_params = mask.sum().item()
        total_params_layer = mask.numel()
        # Guard the ratio so an empty tensor does not divide by zero.
        kept_fraction = kept_params / total_params_layer if total_params_layer else 0.0

        layer_stats[layer_name] = {
            'scores': scores.cpu(),
            'flat_scores': flat_scores.cpu(),
            'shape': w_t.shape,
            'mask': mask.cpu(),
            'kept_params': kept_params,
            'sparsity': kept_fraction,
            'mean_score': float(flat_scores.mean()),
        }
    return layer_stats
|
|
|
|
|
def _create_grafting_visualizations(layer_stats, global_threshold, sparsity_ratio, grafting_method, save_path):
    """
    Render and save a 2x3 overview figure of grafting scores and masks.

    The top row shows score histograms for the first three layers of
    ``layer_stats``; the bottom row shows mask heatmaps for the next three.
    Layers beyond the sixth are not drawn, and panels without a layer are hidden.

    Args:
        layer_stats: Mapping of layer name to a stats dict providing
            'flat_scores', 'mask' and 'sparsity'.
        global_threshold: Score cutoff, drawn as a vertical line on each histogram.
        sparsity_ratio: Not used here; kept for interface compatibility.
        grafting_method: Method name, shown (title-cased) in the figure title.
        save_path: Destination path of the saved image.
    """
    print("🎨 Creating grafting visualizations...")
    plt.style.use('seaborn-v0_8-whitegrid')
    sns.set_palette("husl")

    fig, axes = plt.subplots(2, 3, figsize=(20, 12))
    fig.suptitle(f'Grafting Analysis ({grafting_method.replace("_", " ").title()})', y=1.02)

    layer_names = list(layer_stats.keys())
    n_cols = axes.shape[1]
    # Bound both rows to the number of columns; an open-ended slice for the
    # second row would index past the last axis when more than six layers exist.
    hist_layers = layer_names[:n_cols]
    mask_layers = layer_names[n_cols:2 * n_cols]

    # Row 1: score distributions with the global threshold marked.
    for ax, layer_name in zip(axes[0], hist_layers):
        stats = layer_stats[layer_name]
        sns.histplot(stats['flat_scores'].numpy(), ax=ax, bins=50, log_scale=True, kde=True)
        ax.axvline(global_threshold, color='r', linestyle='--', label=f'Global Thr: {global_threshold:.2e}')
        ax.set_title(f'{layer_name}\nSparsity: {stats["sparsity"]:.2%}')
        ax.set_xlabel("Importance Score")
        ax.legend()

    # Row 2: binary masks, centre-cropped when too large to draw in full.
    crop = 100
    for ax, layer_name in zip(axes[1], mask_layers):
        stats = layer_stats[layer_name]
        mask = stats['mask'].numpy()

        if mask.ndim == 2 and (mask.shape[0] > crop or mask.shape[1] > crop):
            # Clamp the window start at 0: when one dimension is smaller than
            # the crop, "centre - 50" is negative and would be read as an
            # offset from the end of the axis.
            top = max(0, mask.shape[0] // 2 - crop // 2)
            left = max(0, mask.shape[1] // 2 - crop // 2)
            mask_sample = mask[top:top + crop, left:left + crop]
            title = f'{layer_name}\nSparsity: {stats["sparsity"]:.2%}\n(100x100 center crop)'
        else:
            mask_sample = mask
            title = f'{layer_name}\nSparsity: {stats["sparsity"]:.2%}\n(Full matrix)'

        hm = sns.heatmap(mask_sample, ax=ax, cbar=False, cmap="viridis")
        for c in hm.collections:
            # Rasterize the cells so vector outputs stay small.
            c.set_rasterized(True)
        ax.set_title(title)
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide panels that received no layer instead of leaving empty axes.
    for ax in list(axes[0][len(hist_layers):]) + list(axes[1][len(mask_layers):]):
        ax.set_visible(False)

    fig.tight_layout(rect=[0, 0, 1, 0.98])
    fig.savefig(save_path, dpi=150)
    print(f"✓ Visualization saved to {save_path}")
    plt.close(fig)
|
|
|
|
|
def generate_single_run_visualizations(run_dir):
    """
    Load artifacts from a single experiment run and generate its visualizations.

    Reads ``config.yml`` and ``statistics.json`` from ``run_dir``, loads the
    models through the package's ``load_assets``, computes per-layer stats for
    a fixed set of layers, and writes ``grafting_analysis.png`` into ``run_dir``.

    Args:
        run_dir: Path to the experiment output directory.

    Raises:
        ValueError: If the configured grafting method needs optimizer states
            and ``load_assets`` returned none.
    """
    print(f"--- Generating visualizations for run: {run_dir} ---")

    import yaml
    from .models import load_assets

    config_path = os.path.join(run_dir, "config.yml")
    stats_path = os.path.join(run_dir, "statistics.json")

    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)
    with open(stats_path, 'r', encoding='utf-8') as f:
        stats = json.load(f)

    grafting_method = config['grafting_config']['method']
    global_threshold = stats['threshold']
    sparsity_ratio = config['grafting_config']['sparsity_ratio']

    pretrained_model, finetuned_model, optimizer_v_state, _ = load_assets(config['model_config'])

    # Every method that reads optimizer states in _visualize_grafting_analysis
    # must be covered here, not only the 'fast_fisher' spelling.
    methods_needing_optimizer_state = ('fast_fisher', 'ffg', 'fish_mask', 'fmg')
    if grafting_method in methods_needing_optimizer_state and optimizer_v_state is None:
        raise ValueError(
            f"Grafting method '{grafting_method}' requires optimizer states, which were not found."
        )

    # Fixed sample of attention and MLP weights at layers 0, 15 and 31.
    # NOTE(review): presumably chosen for a 32-layer model; names the model
    # does not contain are skipped by the analysis step.
    selected_layers = [
        'model.layers.0.self_attn.q_proj.weight',
        'model.layers.15.self_attn.q_proj.weight',
        'model.layers.31.self_attn.q_proj.weight',
        'model.layers.0.mlp.gate_proj.weight',
        'model.layers.15.mlp.gate_proj.weight',
        'model.layers.31.mlp.gate_proj.weight',
    ]

    layer_stats = _visualize_grafting_analysis(
        pretrained_model, finetuned_model, optimizer_v_state,
        selected_layers, sparsity_ratio, global_threshold, grafting_method
    )

    save_path = os.path.join(run_dir, "grafting_analysis.png")
    _create_grafting_visualizations(
        layer_stats, global_threshold, sparsity_ratio, grafting_method, save_path
    )
|
|
|
|
|
|
|
|
def _calculate_layerwise_jaccard(masks1_dict, masks2_dict): |
|
|
""" |
|
|
Calculates layer-wise Jaccard Index. |
|
|
""" |
|
|
layer_jaccard_scores = {} |
|
|
common_params = set(masks1_dict.keys()) & set(masks2_dict.keys()) |
|
|
|
|
|
for name in common_params: |
|
|
mask1 = masks1_dict[name].bool() |
|
|
mask2 = masks2_dict[name].bool() |
|
|
|
|
|
intersection = (mask1 & mask2).sum().item() |
|
|
union = (mask1 | mask2).sum().item() |
|
|
|
|
|
jaccard = intersection / union if union > 0 else 0 |
|
|
layer_jaccard_scores[name] = jaccard |
|
|
|
|
|
return layer_jaccard_scores |
|
|
|
|
|
def _create_jaccard_barchart(layer_jaccard_scores, output_path, names, font_scale=1.0):
    """
    Draw a horizontal bar chart of per-layer Jaccard scores and save it as an image and a PDF.

    Args:
        layer_jaccard_scores: Mapping of layer name to Jaccard index.
        output_path: Image path; a PDF with the same stem is written next to it.
        names: The two run names shown in the chart title.
        font_scale: Font scale forwarded to ``_set_publication_fonts``.
    """
    # Highest overlap first.
    ranked = sorted(layer_jaccard_scores.items(), key=lambda pair: pair[1], reverse=True)
    chart_title = f"Layer-wise Jaccard Scores ({names[0]} vs {names[1]})"

    labels = [label for label, _ in ranked]
    scores = [score for _, score in ranked]

    plt.style.use('seaborn-v0_8-whitegrid')
    _set_publication_fonts(scale_factor=font_scale)

    # Grow the figure with the number of bars, but never below 12 inches tall.
    fig_height = max(12, len(scores) * 0.2)
    fig, ax = plt.subplots(figsize=(12, fig_height))

    bar_colors = sns.color_palette("viridis", len(scores))
    bars = ax.barh(labels, scores, color=bar_colors)

    ax.set_xlabel("Jaccard Index (Overlap)")
    ax.set_title(chart_title)
    ax.set_xlim(0, 1)
    # Put the first (highest) bar at the top.
    ax.invert_yaxis()
    ax.tick_params(axis='both', which='major', labelsize=plt.rcParams['xtick.labelsize'])

    # Annotate every bar with its value just right of the bar end.
    for bar in bars:
        bar_length = bar.get_width()
        bar_middle = bar.get_y() + bar.get_height()/2
        ax.text(bar_length + 0.01, bar_middle, f'{bar_length:.2f}', ha='left', va='center', fontsize=plt.rcParams['font.size'])

    plt.tight_layout()
    plt.savefig(output_path, dpi=300)

    pdf_output_path = os.path.splitext(output_path)[0] + '.pdf'
    plt.savefig(pdf_output_path, format='pdf')

    print(f"✓ Layer-wise overlap chart saved to {output_path} and {pdf_output_path}")
    plt.close()
|
|
|
|
|
|
|
|
def _create_comparison_heatmap(masks1_dict: Dict[str, Any], masks2_dict: Dict[str, Any], layer_name: str, output_path: str, names: List[str], font_scale: float = 1.0): |
|
|
""" |
|
|
Creates and saves a comparison heatmap for a specific layer. |
|
|
""" |
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
mask1 = masks1_dict[layer_name].bool() |
|
|
|
|
|
|
|
|
if len(mask1.shape) != 2: |
|
|
return |
|
|
|
|
|
mask2 = masks2_dict[layer_name].bool() |
|
|
total_params = mask1.numel() |
|
|
|
|
|
|
|
|
kept_1_only = (mask1 & ~mask2).sum().item() |
|
|
kept_2_only = (~mask1 & mask2).sum().item() |
|
|
intersection = (mask1 & mask2).sum().item() |
|
|
pruned_both = (~mask1 & ~mask2).sum().item() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
comparison_map = torch.zeros_like(mask1, dtype=torch.int8) |
|
|
comparison_map[mask1 & ~mask2] = 1 |
|
|
comparison_map[~mask1 & mask2] = 2 |
|
|
comparison_map[mask1 & mask2] = 3 |
|
|
|
|
|
comparison_map = comparison_map.numpy() |
|
|
|
|
|
|
|
|
if comparison_map.shape[0] > 256 or comparison_map.shape[1] > 256: |
|
|
|
|
|
center_i, center_j = comparison_map.shape[0] // 2, comparison_map.shape[1] // 2 |
|
|
map_sample = comparison_map[center_i-128:center_i+128, center_j-128:center_j+128] |
|
|
|
|
|
else: |
|
|
map_sample = comparison_map |
|
|
|
|
|
|
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
|
|
|
|
|
|
|
cmap = mcolors.ListedColormap(['#e0e0e0', '#6495ED', '#DC143C', '#9932CC']) |
|
|
bounds = [-0.5, 0.5, 1.5, 2.5, 3.5] |
|
|
norm = mcolors.BoundaryNorm(bounds, cmap.N) |
|
|
|
|
|
|
|
|
|
|
|
fig_width_inches = fig.get_figwidth() |
|
|
data_width_pixels = map_sample.shape[1] |
|
|
dpi_for_data = max(300, (data_width_pixels * 3) / fig_width_inches) |
|
|
|
|
|
cax = ax.imshow(map_sample, cmap=cmap, norm=norm, interpolation='nearest', |
|
|
aspect='auto', rasterized=True) |
|
|
|
|
|
|
|
|
short_names = [_shorten_name(n) for n in names] |
|
|
patches = [ |
|
|
mpatches.Patch(color='#e0e0e0', label=f'Pruned in Both ({pruned_both/total_params:.2%})'), |
|
|
mpatches.Patch(color='#6495ED', label=f'Kept in {short_names[0]} Only ({kept_1_only/total_params:.2%})'), |
|
|
mpatches.Patch(color='#DC143C', label=f'Kept in {short_names[1]} Only ({kept_2_only/total_params:.2%})'), |
|
|
mpatches.Patch(color='#9932CC', label=f'Kept in Both (Intersection) ({intersection/total_params:.2%})') |
|
|
] |
|
|
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., fontsize=plt.rcParams['legend.fontsize']) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(layer_name) |
|
|
ax.set_title(param_info, pad=20) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
|
|
|
|
|
|
optimal_dpi = _calculate_optimal_dpi(map_sample.shape) |
|
|
|
|
|
|
|
|
plt.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight') |
|
|
|
|
|
|
|
|
resize_factor = 1.0 |
|
|
if map_sample.shape[0] > 512 or map_sample.shape[1] > 512: |
|
|
resize_factor = 0.75 |
|
|
|
|
|
|
|
|
|
|
|
_optimize_png_for_heatmap(output_path, num_colors=8, resize_factor=resize_factor) |
|
|
|
|
|
|
|
|
import os |
|
|
file_size_mb = os.path.getsize(output_path) / (1024 * 1024) |
|
|
|
|
|
print(f"✓ Comparison heatmap for {layer_name} saved to {output_path} ({file_size_mb:.2f} MB)") |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _create_rgb_heatmap(masks: List[Dict[str, Any]], layer_name: str, output_path: str, names: List[str], font_scale: float = 1.0): |
|
|
""" |
|
|
Creates and saves a 3-way RGB heatmap for a specific layer. |
|
|
""" |
|
|
mask1 = masks[0][layer_name].bool() |
|
|
|
|
|
|
|
|
if len(mask1.shape) != 2: |
|
|
return |
|
|
|
|
|
mask2, mask3 = masks[1][layer_name].bool(), masks[2][layer_name].bool() |
|
|
total_params = mask1.numel() |
|
|
|
|
|
|
|
|
intersect_1_only = (mask1 & ~mask2 & ~mask3).sum().item() |
|
|
intersect_2_only = (~mask1 & mask2 & ~mask3).sum().item() |
|
|
intersect_3_only = (~mask1 & ~mask2 & mask3).sum().item() |
|
|
intersect_1_2 = (mask1 & mask2 & ~mask3).sum().item() |
|
|
intersect_1_3 = (mask1 & ~mask2 & mask3).sum().item() |
|
|
intersect_2_3 = (~mask1 & mask2 & mask3).sum().item() |
|
|
intersect_1_2_3 = (mask1 & mask2 & mask3).sum().item() |
|
|
pruned_all = (~mask1 & ~mask2 & ~mask3).sum().item() |
|
|
|
|
|
|
|
|
rgb_image = torch.stack([mask1, mask2, mask3], dim=-1).numpy().astype(float) |
|
|
|
|
|
if rgb_image.shape[0] > 256 or rgb_image.shape[1] > 256: |
|
|
center_i, center_j = rgb_image.shape[0] // 2, rgb_image.shape[1] // 2 |
|
|
map_sample = rgb_image[center_i-128:center_i+128, center_j-128:center_j+128, :] |
|
|
|
|
|
else: |
|
|
map_sample = rgb_image |
|
|
|
|
|
|
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
|
|
|
|
|
|
|
ax.imshow(map_sample, interpolation='nearest', aspect='auto', rasterized=True) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(layer_name) |
|
|
ax.set_title(param_info, pad=20) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
|
|
|
|
|
|
short_names = [_shorten_name(n) for n in names] |
|
|
patches = [ |
|
|
mpatches.Patch(color='red', label=f'{short_names[0]} Only ({intersect_1_only/total_params:.2%})'), |
|
|
mpatches.Patch(color='green', label=f'{short_names[1]} Only ({intersect_2_only/total_params:.2%})'), |
|
|
mpatches.Patch(color='blue', label=f'{short_names[2]} Only ({intersect_3_only/total_params:.2%})'), |
|
|
mpatches.Patch(color='yellow', label=f'({short_names[0]})+({short_names[1]}) ({intersect_1_2/total_params:.2%})'), |
|
|
mpatches.Patch(color='cyan', label=f'({short_names[1]})+({short_names[2]}) ({intersect_2_3/total_params:.2%})'), |
|
|
mpatches.Patch(color='magenta', label=f'({short_names[0]})+({short_names[2]}) ({intersect_1_3/total_params:.2%})'), |
|
|
mpatches.Patch(color='white', label=f'All Three ({intersect_1_2_3/total_params:.2%})'), |
|
|
mpatches.Patch(color='black', label=f'Pruned in All ({pruned_all/total_params:.2%})') |
|
|
] |
|
|
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., fontsize=plt.rcParams['legend.fontsize']) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimal_dpi = _calculate_optimal_dpi(map_sample.shape) |
|
|
|
|
|
|
|
|
plt.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight') |
|
|
|
|
|
|
|
|
|
|
|
_optimize_png_for_heatmap(output_path, num_colors=32) |
|
|
|
|
|
|
|
|
import os |
|
|
file_size_mb = os.path.getsize(output_path) / (1024 * 1024) |
|
|
|
|
|
print(f"✓ 3-way RGB heatmap for {layer_name} saved to {output_path} ({file_size_mb:.2f} MB)") |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _create_sparsity_distribution_plot(mask: torch.Tensor, layer_name: str, output_path: str, font_scale: float = 1.0): |
|
|
""" |
|
|
Generates and saves a visualization of row and column sparsity distributions for a given layer mask. |
|
|
""" |
|
|
if not isinstance(mask, torch.Tensor): |
|
|
print(f"Skipping sparsity distribution for {layer_name}: mask is not a tensor.") |
|
|
return |
|
|
|
|
|
if mask.dim() != 2: |
|
|
return |
|
|
|
|
|
|
|
|
mask = mask.cpu().float() |
|
|
|
|
|
|
|
|
|
|
|
row_sparsity = 1.0 - mask.mean(dim=1) |
|
|
col_sparsity = 1.0 - mask.mean(dim=0) |
|
|
|
|
|
|
|
|
if row_sparsity.numel() <= 1 or col_sparsity.numel() <= 1: |
|
|
return |
|
|
|
|
|
|
|
|
plt.style.use('seaborn-v0_8-whitegrid') |
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
fig, axes = plt.subplots(2, 1, figsize=(12, 10), sharex=True) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(layer_name) |
|
|
fig.suptitle(f'Structural Sparsity Distribution: {param_info}', y=0.99) |
|
|
|
|
|
|
|
|
sns.histplot(row_sparsity.numpy(), ax=axes[0], bins=50, kde=True) |
|
|
axes[0].set_title(f'Row-wise Sparsity (Avg: {row_sparsity.mean():.2%})') |
|
|
axes[0].set_ylabel('Number of Rows') |
|
|
axes[0].tick_params(axis='both', which='major', labelsize=plt.rcParams['xtick.labelsize']) |
|
|
axes[0].grid(True, which='both', linestyle='--', linewidth=0.5) |
|
|
|
|
|
|
|
|
sns.histplot(col_sparsity.numpy(), ax=axes[1], bins=50, kde=True) |
|
|
axes[1].set_title(f'Column-wise Sparsity (Avg: {col_sparsity.mean():.2%})') |
|
|
|
|
|
axes[1].set_ylabel('Number of Columns') |
|
|
axes[1].tick_params(axis='both', which='major', labelsize=plt.rcParams['xtick.labelsize']) |
|
|
axes[1].grid(True, which='both', linestyle='--', linewidth=0.5) |
|
|
|
|
|
plt.tight_layout(rect=[0, 0, 1, 0.96]) |
|
|
|
|
|
|
|
|
plt.savefig(output_path, dpi=120, bbox_inches='tight') |
|
|
|
|
|
|
|
|
|
|
|
_optimize_png_for_heatmap(output_path, num_colors=16) |
|
|
|
|
|
|
|
|
pdf_output_path = os.path.splitext(output_path)[0] + '.pdf' |
|
|
plt.savefig(pdf_output_path, format='pdf') |
|
|
|
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _create_n_way_count_heatmap(masks_list: List[Dict[str, Any]], layer_name: str, output_path: str, names: List[str], font_scale: float = 1.0) -> None: |
|
|
""" |
|
|
Creates and saves an N-way (N>=4) count heatmap for a specific layer. |
|
|
Each pixel value indicates how many runs (0..N) kept that parameter (mask==True). |
|
|
""" |
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
num_models = len(masks_list) |
|
|
if num_models < 4: |
|
|
return |
|
|
|
|
|
|
|
|
mask0 = masks_list[0][layer_name].bool() |
|
|
if len(mask0.shape) != 2: |
|
|
return |
|
|
|
|
|
|
|
|
stacked = torch.stack([m[layer_name].bool() for m in masks_list], dim=0) |
|
|
keep_counts = stacked.sum(dim=0).to(torch.int16) |
|
|
|
|
|
total_params = keep_counts.numel() |
|
|
|
|
|
|
|
|
keep_counts_np = keep_counts.numpy() |
|
|
if keep_counts_np.shape[0] > 256 or keep_counts_np.shape[1] > 256: |
|
|
ci, cj = keep_counts_np.shape[0] // 2, keep_counts_np.shape[1] // 2 |
|
|
map_sample = keep_counts_np[ci-128:ci+128, cj-128:cj+128] |
|
|
else: |
|
|
map_sample = keep_counts_np |
|
|
|
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
|
|
|
|
|
|
|
discrete_colors = plt.cm.viridis(np.linspace(0.05, 0.95, num_models + 1)) |
|
|
cmap = mcolors.ListedColormap(discrete_colors) |
|
|
bounds = np.arange(-0.5, num_models + 1.5, 1) |
|
|
norm = mcolors.BoundaryNorm(bounds, cmap.N) |
|
|
|
|
|
|
|
|
ax.imshow(map_sample, cmap=cmap, norm=norm, interpolation='nearest', zorder=1, rasterized=True) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(layer_name) |
|
|
ax.set_title(param_info, pad=20) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
|
|
|
|
|
|
counts, _ = np.histogram(keep_counts_np, bins=np.arange(-0.5, num_models + 1.5, 1)) |
|
|
short_names = [_shorten_name(n) for n in names] |
|
|
summary_patches = [] |
|
|
for k in range(num_models + 1): |
|
|
frac = counts[k] / total_params if total_params > 0 else 0.0 |
|
|
label = 'Pruned in All' if k == 0 else f'Kept in {k} of {num_models}' |
|
|
summary_patches.append(mpatches.Patch(color=cmap(k), label=f'{label} ({frac:.2%})')) |
|
|
|
|
|
ax.legend(handles=summary_patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., fontsize=plt.rcParams['legend.fontsize']) |
|
|
|
|
|
|
|
|
optimal_dpi = _calculate_optimal_dpi(map_sample.shape) |
|
|
|
|
|
|
|
|
plt.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight') |
|
|
|
|
|
|
|
|
|
|
|
_optimize_png_for_heatmap(output_path, num_colors=min(8, num_models + 1)) |
|
|
|
|
|
|
|
|
import os |
|
|
file_size_mb = os.path.getsize(output_path) / (1024 * 1024) |
|
|
|
|
|
print(f"✓ {num_models}-way count heatmap for {layer_name} saved to {output_path} ({file_size_mb:.2f} MB)") |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _create_n_way_subset_heatmap(masks_list: List[Dict[str, Any]], layer_name: str, output_path: str, names: List[str], legend_style: Optional[str] = 'auto', legend_max_rows: Optional[int] = None, font_scale: float = 1.0) -> None: |
|
|
""" |
|
|
Creates and saves an N-way (N>=4) subset-categorical heatmap for a specific layer. |
|
|
Each pixel is assigned to one of 2^N categories (bitmask across experts). |
|
|
Legend behavior is controlled via legend_style: |
|
|
- 'auto' : UpSet-style legend for N>=4, regular list otherwise |
|
|
- 'upset' : UpSet-style legend (dot-matrix + proportion bars) |
|
|
- 'list' : Original 2^N textual legend entries |
|
|
- 'none' : No legend |
|
|
""" |
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
num_models = len(masks_list) |
|
|
if num_models < 4: |
|
|
return |
|
|
|
|
|
|
|
|
mask0 = masks_list[0][layer_name].bool() |
|
|
if len(mask0.shape) != 2: |
|
|
return |
|
|
|
|
|
|
|
|
bitmask = torch.zeros_like(mask0, dtype=torch.int32) |
|
|
for i in range(num_models): |
|
|
m_i = masks_list[i][layer_name].bool() |
|
|
bitmask |= (m_i.to(torch.int32) << i) |
|
|
|
|
|
|
|
|
bitmask_np = bitmask.numpy() |
|
|
if bitmask_np.shape[0] > 256 or bitmask_np.shape[1] > 256: |
|
|
ci, cj = bitmask_np.shape[0] // 2, bitmask_np.shape[1] // 2 |
|
|
map_sample = bitmask_np[ci-128:ci+128, cj-128:cj+128] |
|
|
else: |
|
|
map_sample = bitmask_np |
|
|
|
|
|
|
|
|
|
|
|
base_colors_hex = ['#FF0000', '#00AA00', '#0000FF', '#FF8C00', '#800080', '#00CED1', '#FFD700', '#8B4513'] |
|
|
if num_models > len(base_colors_hex): |
|
|
extra = num_models - len(base_colors_hex) |
|
|
for k in range(extra): |
|
|
hue = (k + 1) / (extra + 1) |
|
|
col = plt.cm.hsv(hue) |
|
|
base_colors_hex.append(mcolors.to_hex(col)) |
|
|
base_rgbs = [np.array(mcolors.to_rgb(h)) for h in base_colors_hex[:num_models]] |
|
|
|
|
|
num_categories = 1 << num_models |
|
|
colors = [] |
|
|
for cat in range(num_categories): |
|
|
if cat == 0: |
|
|
colors.append('#000000') |
|
|
continue |
|
|
|
|
|
|
|
|
if cat == (num_categories - 1) and num_models > 1: |
|
|
colors.append('#FFFFFF') |
|
|
continue |
|
|
|
|
|
indices = [i for i in range(num_models) if (cat >> i) & 1] |
|
|
mix = np.mean([base_rgbs[i] for i in indices], axis=0) |
|
|
mix = np.clip(mix ** 0.9, 0, 1) |
|
|
colors.append(mcolors.to_hex(mix)) |
|
|
|
|
|
cmap = mcolors.ListedColormap(colors) |
|
|
|
|
|
|
|
|
full_counts, _ = np.histogram(bitmask_np, bins=np.arange(-0.5, num_categories + 0.5, 1)) |
|
|
total_params = bitmask_np.size if bitmask_np.size > 0 else 1 |
|
|
short_names = [_shorten_name(n) for n in names] |
|
|
|
|
|
|
|
|
style = (legend_style or 'auto').lower() |
|
|
if style == 'auto': |
|
|
style = 'upset' |
|
|
|
|
|
|
|
|
if style == 'upset': |
|
|
|
|
|
fig = plt.figure(figsize=(12, 8)) |
|
|
from matplotlib import gridspec as _gs |
|
|
gs = _gs.GridSpec(1, 2, width_ratios=[1.0, 1.25], wspace=0.3) |
|
|
ax = fig.add_subplot(gs[0]) |
|
|
ax_leg = fig.add_subplot(gs[1]) |
|
|
else: |
|
|
fig, ax = plt.subplots(figsize=(8, 8)) |
|
|
ax_leg = None |
|
|
|
|
|
|
|
|
|
|
|
ax.imshow(map_sample.astype(float), cmap=cmap, vmin=0, vmax=num_categories - 1, |
|
|
interpolation='nearest', zorder=1, rasterized=True) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(layer_name) |
|
|
ax.set_title(param_info, pad=20) |
|
|
ax.set_xticks([]) |
|
|
ax.set_yticks([]) |
|
|
|
|
|
|
|
|
if style == 'list': |
|
|
patches = [] |
|
|
for cat in range(num_categories): |
|
|
frac = full_counts[cat] / total_params |
|
|
|
|
|
if cat == 0: |
|
|
label = f'Pruned in All ({frac:.2%})' |
|
|
patches.append(mpatches.Patch(color=colors[cat], label=label)) |
|
|
|
|
|
elif cat == num_categories - 1 and num_models > 1: |
|
|
label = f'Kept in All ({frac:.2%})' |
|
|
|
|
|
patches.append(mpatches.Patch(color=colors[cat], label=label, edgecolor='black', linewidth=0.75)) |
|
|
|
|
|
else: |
|
|
included = [short_names[i] for i in range(num_models) if (cat >> i) & 1] |
|
|
label = "+".join(included) + f" ({frac:.2%})" |
|
|
patches.append(mpatches.Patch(color=colors[cat], label=label)) |
|
|
|
|
|
ax.legend(handles=patches, bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0., fontsize=plt.rcParams['legend.fontsize']) |
|
|
|
|
|
elif style == 'upset' and ax_leg is not None: |
|
|
|
|
|
cats = [cat for cat in range(1, num_categories) if full_counts[cat] > 0] |
|
|
|
|
|
cats.sort(key=lambda c: full_counts[c], reverse=True) |
|
|
if legend_max_rows is not None and legend_max_rows > 0: |
|
|
cats = cats[:legend_max_rows] |
|
|
|
|
|
num_rows = len(cats) |
|
|
y_positions = np.arange(num_rows)[::-1] |
|
|
|
|
|
|
|
|
for r, cat in enumerate(cats): |
|
|
y = y_positions[r] |
|
|
for i in range(num_models): |
|
|
on = ((cat >> i) & 1) == 1 |
|
|
ax_leg.scatter(i, y, s=36, c='k' if on else 'white', edgecolors='k', linewidths=0.75, zorder=3) |
|
|
|
|
|
|
|
|
x_bar0 = num_models + 0.8 |
|
|
max_bar_width = 2.2 |
|
|
for r, cat in enumerate(cats): |
|
|
y = y_positions[r] |
|
|
frac = full_counts[cat] / total_params |
|
|
w = max_bar_width * frac |
|
|
rect = mpatches.Rectangle((x_bar0, y - 0.3), w, 0.6, color=colors[cat], zorder=2) |
|
|
ax_leg.add_patch(rect) |
|
|
ax_leg.text(x_bar0 + w + 0.05, y, f"{frac:.2%}", va='center', fontsize=plt.rcParams['font.size']) |
|
|
|
|
|
|
|
|
ax_leg.set_ylim(-0.5, num_rows - 0.5) |
|
|
ax_leg.set_xlim(-0.5, x_bar0 + max_bar_width + 1.1) |
|
|
ax_leg.set_yticks([]) |
|
|
ax_leg.set_xticks(list(range(num_models)) + [x_bar0]) |
|
|
ax_leg.set_xticklabels(short_names + [' ']) |
|
|
ax_leg.tick_params(axis='x', labelrotation=45) |
|
|
ax_leg.axvline(x=x_bar0 - 0.4, color='gray', linewidth=1) |
|
|
ax_leg.set_title('Intersections (UpSet-style)', pad=10) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
optimal_dpi = _calculate_optimal_dpi(map_sample.shape) |
|
|
|
|
|
|
|
|
plt.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight') |
|
|
|
|
|
|
|
|
|
|
|
max_colors = min(32, 1 << num_models) |
|
|
_optimize_png_for_heatmap(output_path, num_colors=max_colors) |
|
|
|
|
|
|
|
|
import os |
|
|
file_size_mb = os.path.getsize(output_path) / (1024 * 1024) |
|
|
|
|
|
print(f"✓ {num_models}-way subset heatmap for {layer_name} saved to {output_path} ({file_size_mb:.2f} MB)") |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _load_preconditioner_map(file_path: str) -> Dict[str, torch.Tensor]: |
|
|
""" |
|
|
Loads a safetensors file, attempting to download from HF hub if not found locally. |
|
|
""" |
|
|
if not os.path.exists(file_path): |
|
|
try: |
|
|
|
|
|
parts = file_path.split('/') |
|
|
if len(parts) < 3: |
|
|
raise ValueError(f"Invalid Hugging Face path format: '{file_path}'") |
|
|
|
|
|
repo_id = f"{parts[0]}/{parts[1]}" |
|
|
filename = "/".join(parts[2:]) |
|
|
|
|
|
print(f" -> Preconditioner '{file_path}' not found locally.") |
|
|
print(f" Attempting download from repo='{repo_id}', filename='{filename}'...") |
|
|
|
|
|
resolved_path = hf_hub_download(repo_id=repo_id, filename=filename) |
|
|
file_path = resolved_path |
|
|
print(f" Successfully downloaded to: {file_path}") |
|
|
|
|
|
except Exception as e: |
|
|
print(f" -> ERROR: Failed to download from Hugging Face Hub: {e}") |
|
|
raise FileNotFoundError(f"Could not find or download preconditioner file: {file_path}") from e |
|
|
|
|
|
tensors: Dict[str, torch.Tensor] = {} |
|
|
with safe_open(file_path, framework="pt", device="cpu") as f: |
|
|
for key in f.keys(): |
|
|
tensors[key] = f.get_tensor(key) |
|
|
return tensors |
|
|
|
|
|
|
|
|
def _map_layer_to_precond_key(layer_name: str, precond_map: Dict[str, torch.Tensor]) -> Optional[str]: |
|
|
""" |
|
|
Try mapping a mask layer name like '...weight' to a preconditioner key like '...exp_avg_sq'. |
|
|
Handles presence/absence of 'model.' prefix. |
|
|
""" |
|
|
candidates: List[str] = [layer_name] |
|
|
|
|
|
if layer_name.endswith('.weight'): |
|
|
candidates.append(layer_name[:-len('.weight')] + '.exp_avg_sq') |
|
|
else: |
|
|
candidates.append(layer_name + '.exp_avg_sq') |
|
|
|
|
|
|
|
|
more: List[str] = [] |
|
|
for c in candidates: |
|
|
if c.startswith('model.'): |
|
|
more.append(c[len('model.'):]) |
|
|
else: |
|
|
more.append('model.' + c) |
|
|
candidates.extend(more) |
|
|
|
|
|
for key in candidates: |
|
|
if key in precond_map: |
|
|
return key |
|
|
return None |
|
|
|
|
|
|
|
|
def _create_n_way_winner_tiebreak_heatmap(
    masks_list: List[Dict[str, Any]],
    preconds_list: List[Dict[str, torch.Tensor]],
    layer_name: str,
    output_path: str,
    names: List[str],
    threshold: float,
    font_scale: float = 1.0,
) -> None:
    """
    Plot, for one layer, which of N experts "wins" each parameter element.

    Per element, using the masks and the matching second-moment tensors:
      - no mask keeps it            -> category ``num_models``     ("Pruned by All")
      - exactly one mask keeps it   -> that expert's index
      - two or more masks keep it   -> ratio = max / (min + eps) of the
        second moments over the keeping experts; if ratio >= ``threshold`` the
        argmax expert wins, otherwise category ``num_models + 1`` ("Tie").

    The function returns without plotting when fewer than two models are
    given, when the list lengths differ, when the layer is not 2-D, or when a
    preconditioner tensor is missing or has a different shape than the mask.

    Args:
        masks_list: One ``{layer_name: mask}`` mapping per expert.
        preconds_list: One ``{key: tensor}`` mapping per expert, resolved via
            ``_map_layer_to_precond_key``.
        layer_name: Layer to visualise.
        output_path: Destination image path.
        names: Colorbar labels, one per expert.
        threshold: Minimum max/min ratio for a clear winner.
        font_scale: Scale factor forwarded to ``_set_publication_fonts``.
    """
    _set_publication_fonts(scale_factor=font_scale)

    num_models = len(masks_list)
    if num_models != len(preconds_list) or num_models < 2:
        return

    mask0 = masks_list[0][layer_name].bool()
    if mask0.dim() != 2:
        return

    H, W = mask0.shape
    masks_stack = torch.stack([m[layer_name].bool() for m in masks_list], dim=0)

    # Collect the second-moment tensor of every expert; bail out on any mismatch.
    pre_stack_list: List[torch.Tensor] = []
    for precond in preconds_list:
        key = _map_layer_to_precond_key(layer_name, precond)
        if key is None:
            return
        t = precond[key]
        if t.dim() != 2 or t.shape != (H, W):
            return
        pre_stack_list.append(t.to(torch.float32))
    pre_stack = torch.stack(pre_stack_list, dim=0)

    # Number of experts that keep each element.
    candidate_counts = masks_stack.sum(dim=0)

    # Exclude non-keeping experts from the max/min by substituting -inf / +inf.
    neg_inf = torch.tensor(float('-inf'), dtype=pre_stack.dtype)
    pos_inf = torch.tensor(float('inf'), dtype=pre_stack.dtype)
    pre_for_max = torch.where(masks_stack, pre_stack, neg_inf)
    pre_for_min = torch.where(masks_stack, pre_stack, pos_inf)

    max_vals, max_idx = torch.max(pre_for_max, dim=0)
    min_vals, _ = torch.min(pre_for_min, dim=0)

    pruned_by_all_idx = num_models
    tie_idx = num_models + 1

    # -1 marks "not yet assigned"; anything left at -1 triggers the warning below.
    winner = torch.full((H, W), -1, dtype=torch.int64)

    pruned_mask = (candidate_counts == 0)
    winner[pruned_mask] = pruned_by_all_idx

    # With a single keeper the masked argmax is that keeper.
    single_mask = (candidate_counts == 1)
    winner[single_mask] = max_idx[single_mask]

    multi_mask = (candidate_counts >= 2)
    eps = torch.tensor(1e-28, dtype=pre_stack.dtype)  # avoids division by zero
    ratio = max_vals / (min_vals + eps)

    strong_dom_mask = (ratio >= threshold) & multi_mask
    winner[strong_dom_mask] = max_idx[strong_dom_mask]

    tie_mask = (ratio < threshold) & multi_mask
    winner[tie_mask] = tie_idx

    if (winner == -1).any():
        print(f"Warning: some pixels in layer {layer_name} were not assigned a category.")

    # Show at most a 256x256 window around the centre. The start index is
    # clamped at 0: when only one side exceeds 256, `c - 128` on the short side
    # is negative and an unclamped slice would wrap around and drop most of it.
    display = winner
    if H > 256 or W > 256:
        ci, cj = H // 2, W // 2
        display = display[max(0, ci - 128):ci + 128, max(0, cj - 128):cj + 128]

    # Pick a qualitative palette large enough for the number of experts.
    if num_models <= 10:
        palette_name = 'tab10'
    elif num_models <= 12:
        palette_name = 'Paired'
    elif num_models <= 20:
        palette_name = 'tab20'
    else:
        palette_name = 'viridis'
    # pyplot.get_cmap is used because matplotlib.cm.get_cmap was removed in Matplotlib 3.9.
    model_colors = plt.get_cmap(palette_name, num_models).colors

    colors = [mcolors.to_hex(c) for c in model_colors]
    colors.append('#000000')  # "Pruned by All"
    colors.append('#808080')  # "Tie"

    cmap = mcolors.ListedColormap(colors)
    bounds = [i - 0.5 for i in range(num_models + 3)]
    norm = mcolors.BoundaryNorm(bounds, cmap.N)

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        im = ax.imshow(display.numpy(), cmap=cmap, norm=norm, interpolation='nearest', zorder=1, rasterized=True)

        param_info = _extract_parameter_info(layer_name)
        ax.set_title(param_info, pad=20)
        ax.set_xticks([])
        ax.set_yticks([])

        ticks = list(range(num_models + 2))
        cbar = plt.colorbar(im, ticks=ticks, spacing='proportional')
        labels = names[:] + ["Pruned by All", "Tie"]
        cbar.set_ticklabels(labels)
        cbar.ax.tick_params(labelsize=plt.rcParams['legend.fontsize'])

        optimal_dpi = _calculate_optimal_dpi(display.shape)
        fig.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight')
    finally:
        # Always release the figure, even if drawing or saving raised; the
        # caller loops over many layers and swallows per-layer exceptions.
        plt.close(fig)

    _optimize_png_for_heatmap(output_path, num_colors=num_models + 2)

    file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
    print(f"✓ Winner tie-break heatmap for {layer_name} saved to {output_path} ({file_size_mb:.2f} MB)")
|
|
|
|
|
|
|
|
def generate_comparison_visualizations(
    dirs: List[str],
    names: List[str],
    output_dir: str,
    precond_paths: Optional[List[str]] = None,
    winner_tie_break_threshold: Optional[float] = None,
    legend_style: str = 'auto',
    legend_max_rows: int = 16,
    font_scaling: Optional[Dict[str, float]] = None,
    plots_to_generate: Optional[Dict[str, bool]] = None,
):
    """
    Generate and save visualizations comparing the masks of N runs.

    What is produced depends on ``len(dirs)``:
      - 2 runs: a layer-wise Jaccard bar chart plus 2-way comparison heatmaps,
      - 3 runs: RGB heatmaps,
      - any other count: N-way subset heatmaps.
    If ``precond_paths`` (one per run) and ``winner_tie_break_threshold`` are
    both given, winner tie-break heatmaps are generated as well.

    Args:
        dirs: Run directories passed to ``load_masks_from_run``.
        names: Display names for the runs.
        output_dir: Root directory for all outputs (created if missing).
        precond_paths: Optional preconditioner files, one per entry in ``dirs``.
        winner_tie_break_threshold: Ratio threshold for the tie-break heatmaps.
        legend_style: Forwarded to the N-way subset heatmap.
        legend_max_rows: Forwarded to the N-way subset heatmap.
        font_scaling: Per-plot font scale overrides; the ``'default'`` entry
            (1.0 if absent) is used for plots without their own entry.
        plots_to_generate: Flags ``'n_way_comparison_plots'`` and
            ``'winner_tiebreak_heatmap'``; both default to True.
    """
    num_dirs = len(dirs)
    print(f"--- Generating {num_dirs}-way comparison visualizations for: {', '.join(names)} ---")

    if font_scaling is None:
        font_scaling = {}
    default_scale = font_scaling.get('default', 1.0)

    if plots_to_generate is None:
        plots_to_generate = {}

    masks = [load_masks_from_run(d) for d in dirs]

    if not masks:
        print("No masks loaded, skipping comparison.")
        return

    # Sorted so that processing order (and progress output) is reproducible.
    common_layers = sorted(set.intersection(*(set(m.keys()) for m in masks)))

    # The bar chart below is written directly into output_dir, so make sure it exists.
    os.makedirs(output_dir, exist_ok=True)

    want_comparison_plots = plots_to_generate.get('n_way_comparison_plots', True)

    if num_dirs == 2 and want_comparison_plots:
        print("\n📊 Calculating layer-wise Jaccard overlap...")
        layer_jaccard = _calculate_layerwise_jaccard(masks[0], masks[1])

        print("\n🎨 Generating comparison bar chart...")
        barchart_path = os.path.join(output_dir, "layerwise_jaccard_comparison.png")
        jaccard_scale = font_scaling.get('jaccard_barchart', default_scale)
        _create_jaccard_barchart(layer_jaccard, barchart_path, names, font_scale=jaccard_scale)

    if want_comparison_plots:
        print(f"\n🎨 Generating heatmaps for all {len(common_layers)} common layers...")

        if num_dirs == 2:
            heatmap_dir = os.path.join(output_dir, "heatmaps_2way")
            os.makedirs(heatmap_dir, exist_ok=True)
            print(f"Saving 2-way heatmaps to: {heatmap_dir}")

            comp_scale = font_scaling.get('comparison_heatmap', default_scale)
            for layer_name in tqdm(common_layers, desc="Generating 2-way heatmaps"):
                heatmap_path = os.path.join(heatmap_dir, f"comparison_heatmap_{layer_name}.png")
                _create_comparison_heatmap(masks[0], masks[1], layer_name, heatmap_path, names, font_scale=comp_scale)

        elif num_dirs == 3:
            heatmap_dir = os.path.join(output_dir, "heatmaps_3way_rgb")
            os.makedirs(heatmap_dir, exist_ok=True)
            print(f"Saving 3-way RGB heatmaps to: {heatmap_dir}")

            rgb_scale = font_scaling.get('rgb_heatmap', default_scale)
            for layer_name in tqdm(common_layers, desc="Generating 3-way heatmaps"):
                heatmap_path = os.path.join(heatmap_dir, f"rgb_heatmap_{layer_name}.png")
                _create_rgb_heatmap(masks, layer_name, heatmap_path, names, font_scale=rgb_scale)

        else:
            heatmap_dir = os.path.join(output_dir, f"heatmaps_{num_dirs}way_subsets")
            os.makedirs(heatmap_dir, exist_ok=True)
            print(f"Saving {num_dirs}-way subset heatmaps to: {heatmap_dir}")

            subset_scale = font_scaling.get('subset_heatmap', default_scale)
            for layer_name in tqdm(common_layers, desc=f"Generating {num_dirs}-way subset heatmaps"):
                heatmap_path = os.path.join(heatmap_dir, f"subsets_heatmap_{layer_name}.png")
                _create_n_way_subset_heatmap(masks, layer_name, heatmap_path, names, legend_style=legend_style, legend_max_rows=legend_max_rows, font_scale=subset_scale)

    tiebreak_requested = precond_paths is not None and winner_tie_break_threshold is not None
    if tiebreak_requested and len(precond_paths) != len(dirs):
        # Previously skipped without any message; say why nothing is produced.
        print(f"Warning: expected {len(dirs)} preconditioner paths but got {len(precond_paths)}. Skipping winner tie-break heatmaps.")
    elif tiebreak_requested and plots_to_generate.get('winner_tiebreak_heatmap', True):
        try:
            precond_maps = [_load_preconditioner_map(p) for p in precond_paths]
        except Exception as e:
            print(f"Warning: failed to load preconditioners: {e}. Skipping winner tie-break heatmaps.")
            precond_maps = None

        if precond_maps is not None:
            winner_dir = os.path.join(output_dir, f"heatmaps_{len(dirs)}way_winner_tiebreak")
            os.makedirs(winner_dir, exist_ok=True)
            print(f"Saving winner tie-break heatmaps to: {winner_dir}")
            winner_scale = font_scaling.get('winner_tiebreak_heatmap', default_scale)
            for layer_name in tqdm(common_layers, desc="Generating tie-break winner heatmaps"):
                out_path = os.path.join(winner_dir, f"winner_tiebreak_{layer_name}.png")
                try:
                    _create_n_way_winner_tiebreak_heatmap(masks, precond_maps, layer_name, out_path, names, threshold=float(winner_tie_break_threshold), font_scale=winner_scale)
                except Exception as e:
                    # Best effort per layer: one failing layer must not stop the rest.
                    print(f"Skipping winner heatmap for {layer_name}: {e}")

    print("\nComparison visualizations finished for this set.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_precond_heatmap_optimized(fig, base_path: str, data_shape: tuple, plot_format: str = "png",
                                    compression_level: int = 9, is_per_model: bool = False) -> str:
    """
    Save a preconditioner heatmap figure, choosing DPI and compression per format.

    Args:
        fig: Matplotlib figure to save. ``None`` falls back to the current
            figure, which is what earlier versions always saved.
        base_path: Output path without extension.
        data_shape: Shape of the plotted data; drives the DPI choice.
        plot_format: ``"png"``, ``"jpg"``/``"jpeg"``, ``"pdf"``, ``"auto"``
            (JPEG above ~5M estimated pixels, PNG otherwise), or any other
            format name understood by ``savefig``.
        compression_level: PNG compression level (0-9, 9 is max compression).
        is_per_model: Forwarded to ``_calculate_optimal_dpi``.

    Returns:
        Path of the written file.
    """
    target_fig = fig if fig is not None else plt.gcf()

    optimal_dpi = _calculate_optimal_dpi(data_shape, is_per_model=is_per_model)

    # Rough size estimate used to choose the format and the JPEG quality.
    estimated_pixels = (data_shape[0] * data_shape[1] * optimal_dpi**2) / (100**2)

    if plot_format == "auto":
        plot_format = "jpg" if estimated_pixels > 5_000_000 else "png"

    output_path = f"{base_path}.{plot_format}"

    if plot_format == "png":
        try:
            import io
            from PIL import Image
        except ImportError:
            # Pillow unavailable: let Matplotlib write the PNG directly.
            target_fig.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight',
                               pad_inches=0.05, format='png')
        else:
            # Render to memory, then re-encode with Pillow's optimiser.
            with io.BytesIO() as buf:
                target_fig.savefig(buf, format='png', dpi=optimal_dpi, bbox_inches='tight',
                                   pad_inches=0.05, facecolor='white')
                buf.seek(0)
                with Image.open(buf) as img:
                    img.save(output_path, 'PNG', optimize=True, compress_level=compression_level)

    elif plot_format == "jpg" or plot_format == "jpeg":
        quality = 85 if estimated_pixels > 10_000_000 else 90
        # The bare `quality=` keyword is no longer accepted by savefig
        # (removed in Matplotlib 3.5); it must go through `pil_kwargs`.
        target_fig.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight',
                           pad_inches=0.05, format='jpeg', pil_kwargs={'quality': quality})

    elif plot_format == "pdf":
        output_path = _save_heatmap_pdf(fig, base_path, data_shape)

    else:
        target_fig.savefig(output_path, dpi=optimal_dpi, bbox_inches='tight',
                           pad_inches=0.05, format=plot_format)

    if os.path.exists(output_path):
        file_size_mb = os.path.getsize(output_path) / (1024 * 1024)
        if file_size_mb > 10:
            print(f" ⚠️ Large file: {output_path} ({file_size_mb:.1f} MB)")

    return output_path
|
|
|
|
|
|
|
|
def _adaptive_downsample_precond(data: torch.Tensor, max_side: int = 256, |
|
|
preserve_patterns: bool = True) -> torch.Tensor: |
|
|
""" |
|
|
Adaptively downsample preconditioner data while preserving important patterns. |
|
|
|
|
|
Args: |
|
|
data: 2D tensor to downsample |
|
|
max_side: Maximum dimension for output |
|
|
preserve_patterns: Whether to use max pooling to preserve high-value regions |
|
|
|
|
|
Returns: |
|
|
Downsampled tensor |
|
|
""" |
|
|
if data.shape[0] <= max_side and data.shape[1] <= max_side: |
|
|
return data |
|
|
|
|
|
|
|
|
factor_h = max(1, data.shape[0] // max_side) |
|
|
factor_w = max(1, data.shape[1] // max_side) |
|
|
|
|
|
if preserve_patterns and factor_h > 1 and factor_w > 1: |
|
|
|
|
|
|
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
data_4d = data.unsqueeze(0).unsqueeze(0) |
|
|
|
|
|
|
|
|
pooled = F.max_pool2d(data_4d, kernel_size=(factor_h, factor_w), |
|
|
stride=(factor_h, factor_w)) |
|
|
|
|
|
|
|
|
result = pooled.squeeze(0).squeeze(0) |
|
|
|
|
|
|
|
|
if result.shape[0] > max_side or result.shape[1] > max_side: |
|
|
step_h = max(1, result.shape[0] // max_side) |
|
|
step_w = max(1, result.shape[1] // max_side) |
|
|
result = result[::step_h, ::step_w] |
|
|
|
|
|
return result |
|
|
else: |
|
|
|
|
|
step_h = max(1, int(torch.ceil(torch.tensor(data.shape[0] / max_side)).item())) |
|
|
step_w = max(1, int(torch.ceil(torch.tensor(data.shape[1] / max_side)).item())) |
|
|
return data[::step_h, ::step_w] |
|
|
|
|
|
|
|
|
def _plot_precond_histogram(data_tensor: torch.Tensor, title_prefix: str, base_filename: str, out_dir: str,
                            use_log_x_scale_heuristic: bool = False, force_linear_x_scale: bool = False,
                            plot_format: str = "png", font_scale: float = 1.0) -> None:
    """
    Draw a histogram of all values in ``data_tensor`` and save it under
    ``<out_dir>/histograms/``.

    A log-scaled x axis (positive values only) is used when
    ``use_log_x_scale_heuristic`` is set, ``force_linear_x_scale`` is not, and
    the largest positive value exceeds 1000; otherwise a linear 100-bin
    histogram of all values is drawn.
    """
    _set_publication_fonts(scale_factor=font_scale)

    plt.figure(figsize=(10, 6))
    values = data_tensor.detach().cpu().flatten().numpy()
    positives = values[values > 0]
    x_label = title_prefix
    bar_style = {'edgecolor': 'black', 'linewidth': 0.5}

    log_axis = (
        not force_linear_x_scale
        and use_log_x_scale_heuristic
        and len(positives) > 0
        and positives.max() > 1000
    )

    if not log_axis:
        plt.hist(values, bins=100, **bar_style)
    else:
        # Log-spaced bins over the positive range; fall back to 50 plain bins
        # when the range is degenerate (all positive values equal).
        low_exp = np.log10(max(1e-30, positives.min()))
        high_exp = np.log10(positives.max())
        bin_spec = np.logspace(low_exp, high_exp, 50) if high_exp > low_exp else 50

        if (values == 0).any():
            # Zeros cannot be shown on a log axis; label the bars so that is visible.
            plt.hist(positives, bins=bin_spec, label=f'>0 values (max {positives.max():.2e})', **bar_style)
            plt.legend(fontsize=plt.rcParams['legend.fontsize'])
        else:
            plt.hist(positives, bins=bin_spec, **bar_style)

        plt.xscale('log')
        x_label = f"{title_prefix} (Log Scale for x > 0)"

    param_info = _extract_parameter_info(base_filename)
    plt.title(f"Histogram of {title_prefix}\n{param_info}")
    plt.xlabel(x_label)
    plt.ylabel("Frequency")
    plt.grid(True, linestyle='--', alpha=0.7)

    # File-name-safe version of the title.
    slug = title_prefix.lower()
    for old, new in ((' ', '_'), ('/', '_'), ('(', ''), (')', ''), ('>', 'gt')):
        slug = slug.replace(old, new)

    target_dir = os.path.join(out_dir, "histograms")
    os.makedirs(target_dir, exist_ok=True)
    target_file = os.path.join(target_dir, f"{base_filename}_{slug}_histogram.{plot_format}")

    plt.tight_layout()
    plt.savefig(target_file, bbox_inches='tight', pad_inches=0.05, dpi=300)
    plt.close()
|
|
|
|
|
|
|
|
def _plot_precond_heatmap(data_tensor: torch.Tensor, title_prefix: str, base_filename: str, out_dir: str, |
|
|
force_linear_scale: bool = False, plot_format: str = "png", font_scale: float = 1.0, |
|
|
max_side: int = 256) -> None: |
|
|
""" |
|
|
Creates heatmap for preconditioner data with publication-ready styling. |
|
|
""" |
|
|
if data_tensor.ndim != 2: |
|
|
return |
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
|
|
|
data = data_tensor.detach().cpu() |
|
|
if data.shape[0] > max_side or data.shape[1] > max_side: |
|
|
|
|
|
data = _adaptive_downsample_precond(data, max_side, preserve_patterns=True) |
|
|
|
|
|
plt.figure(figsize=(12, 10)) |
|
|
numpy_tensor = data.numpy() |
|
|
|
|
|
|
|
|
if not np.isfinite(numpy_tensor).all(): |
|
|
numpy_tensor = np.nan_to_num(numpy_tensor, nan=0.0, |
|
|
posinf=np.finfo(numpy_tensor.dtype if np.issubdtype(numpy_tensor.dtype, np.floating) else np.float32).max, |
|
|
neginf=0.0) |
|
|
|
|
|
norm = None |
|
|
scale_type = "linear" |
|
|
imshow_vmin = None |
|
|
imshow_vmax = None |
|
|
|
|
|
if force_linear_scale: |
|
|
imshow_vmin = 0 |
|
|
data_max = np.max(numpy_tensor) if numpy_tensor.size > 0 else 1.0 |
|
|
data_min = np.min(numpy_tensor) if numpy_tensor.size > 0 else 0.0 |
|
|
imshow_vmax = data_max |
|
|
if not np.isfinite(imshow_vmax): |
|
|
imshow_vmax = 1.0 |
|
|
if imshow_vmax <= imshow_vmin: |
|
|
imshow_vmax = imshow_vmin + 1.0 |
|
|
else: |
|
|
positive_values = numpy_tensor[np.isfinite(numpy_tensor) & (numpy_tensor > 1e-30)] |
|
|
if positive_values.size > 0: |
|
|
min_positive_val_for_norm = np.min(positive_values) |
|
|
max_val_for_norm = np.max(positive_values) |
|
|
else: |
|
|
min_positive_val_for_norm = 1e-30 |
|
|
max_val_for_norm = 1e-30 * 10 |
|
|
|
|
|
if positive_values.size > 0 and max_val_for_norm > min_positive_val_for_norm * 100 and np.isfinite(min_positive_val_for_norm) and np.isfinite(max_val_for_norm): |
|
|
vmin_candidate = max(min_positive_val_for_norm, 1e-30) |
|
|
vmax_candidate = max_val_for_norm |
|
|
if vmax_candidate <= vmin_candidate or np.isclose(vmax_candidate, vmin_candidate, rtol=1e-5, atol=1e-30): |
|
|
vmax_candidate = vmin_candidate * 10.0 |
|
|
norm = mcolors.LogNorm(vmin=vmin_candidate, vmax=vmax_candidate) |
|
|
scale_type = "logscale" |
|
|
else: |
|
|
finite_vals = numpy_tensor[np.isfinite(numpy_tensor)] |
|
|
if finite_vals.size > 0: |
|
|
imshow_vmin = np.min(finite_vals) |
|
|
imshow_vmax = np.max(finite_vals) |
|
|
if imshow_vmax <= imshow_vmin: |
|
|
imshow_vmax = imshow_vmin + 1.0 |
|
|
else: |
|
|
imshow_vmin = 0.0 |
|
|
imshow_vmax = 1.0 |
|
|
|
|
|
aspect_ratio = numpy_tensor.shape[1] / numpy_tensor.shape[0] |
|
|
aspect = 'auto' if aspect_ratio > 10 or aspect_ratio < 0.1 else 'equal' |
|
|
display_tensor = numpy_tensor |
|
|
if scale_type == "logscale" and norm is not None: |
|
|
display_tensor = np.maximum(display_tensor, norm.vmin) |
|
|
|
|
|
im = plt.imshow(display_tensor, aspect=aspect, cmap='viridis', norm=norm, vmin=imshow_vmin, vmax=imshow_vmax) |
|
|
cbar = plt.colorbar(im) |
|
|
cbar.ax.tick_params(labelsize=plt.rcParams['ytick.labelsize']) |
|
|
|
|
|
|
|
|
param_info = _extract_parameter_info(base_filename) |
|
|
plt.title(f"Heatmap of {title_prefix}\n{param_info}", pad=14) |
|
|
plt.xlabel("Dimension 1") |
|
|
plt.ylabel("Dimension 0") |
|
|
|
|
|
clean_title_prefix = title_prefix.lower().replace(' ', '_').replace('/', '_').replace('(', '').replace(')', '').replace('>', 'gt') |
|
|
heatmaps_dir = os.path.join(out_dir, "heatmaps") |
|
|
os.makedirs(heatmaps_dir, exist_ok=True) |
|
|
base_path = os.path.join(heatmaps_dir, f"{base_filename}_{clean_title_prefix}_heatmap_{scale_type}") |
|
|
|
|
|
|
|
|
fig = plt.gcf() |
|
|
output_path = _save_precond_heatmap_optimized(fig, base_path, data.shape, plot_format) |
|
|
plt.close() |
|
|
|
|
|
|
|
|
if os.path.exists(output_path): |
|
|
file_size_mb = os.path.getsize(output_path) / (1024 * 1024) |
|
|
print(f" Saved heatmap: {os.path.basename(output_path)} ({file_size_mb:.2f} MB, shape={data.shape})") |
|
|
|
|
|
|
|
|
def _plot_single_model_precond_heatmap(tensor: torch.Tensor, model_idx: int, base_filename: str, out_dir: str, |
|
|
model_names: Optional[List[str]] = None, max_side: int = 256, |
|
|
plot_format: str = "png", threshold: Optional[float] = None, |
|
|
zero_ratio: Optional[float] = None, heatmap_floor_log_offset: Optional[float] = None, |
|
|
font_scale: float = 1.0, compression_level: int = 9) -> None: |
|
|
""" |
|
|
Creates a heatmap for a single model's preconditioner values with publication-ready styling. |
|
|
""" |
|
|
if tensor.numel() == 0 or tensor.ndim != 2: |
|
|
return |
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
data = tensor.detach().abs().cpu() |
|
|
if data.shape[0] > max_side or data.shape[1] > max_side: |
|
|
|
|
|
data = _adaptive_downsample_precond(data, max_side, preserve_patterns=True) |
|
|
|
|
|
if zero_ratio is not None and 0 < zero_ratio < 1: |
|
|
flat_data = data.flatten() |
|
|
if flat_data.numel() > 0: |
|
|
threshold_val = torch.quantile(flat_data, zero_ratio) |
|
|
values_to_keep = flat_data[flat_data > threshold_val] |
|
|
if values_to_keep.numel() > 0: |
|
|
min_val_to_keep = torch.min(values_to_keep) |
|
|
new_floor = min_val_to_keep |
|
|
if heatmap_floor_log_offset is not None and heatmap_floor_log_offset > 0 and min_val_to_keep > 0: |
|
|
new_floor = min_val_to_keep / (10**heatmap_floor_log_offset) |
|
|
data[data <= threshold_val] = new_floor |
|
|
else: |
|
|
data[data <= threshold_val] = 0.0 |
|
|
|
|
|
arr = data.numpy() |
|
|
if threshold is not None: |
|
|
arr[arr < threshold] = 0.0 |
|
|
|
|
|
eps = 1e-30 |
|
|
plt.figure(figsize=(6, 5)) |
|
|
img = plt.imshow(np.log10(np.maximum(arr, eps)), cmap='viridis', aspect='auto') |
|
|
|
|
|
chosen_name = model_names[model_idx] if model_names and model_idx < len(model_names) else f'Model {model_idx}' |
|
|
param_info = _extract_parameter_info(base_filename) |
|
|
plot_title = f"{chosen_name}\n{param_info}" |
|
|
plt.title(plot_title, fontsize=plt.rcParams['axes.titlesize'], pad=12) |
|
|
plt.axis("off") |
|
|
|
|
|
cbar = plt.colorbar(img, fraction=0.046, pad=0.04) |
|
|
cbar.set_label("log10(exp_avg_sq)", fontsize=plt.rcParams['axes.labelsize']) |
|
|
cbar.ax.tick_params(labelsize=plt.rcParams['ytick.labelsize']) |
|
|
|
|
|
plt.tight_layout() |
|
|
safe_model_name = chosen_name.replace("/", "-").replace("\\", "-") |
|
|
threshold_str = f"_thresh{threshold:.0e}" if threshold is not None else "" |
|
|
zero_ratio_str = f"_zero{zero_ratio:.2f}" if zero_ratio is not None and 0 < zero_ratio < 1 else "" |
|
|
heatmap_filename = f"{base_filename}_model_{model_idx}_{safe_model_name}_weights_heatmap{threshold_str}{zero_ratio_str}" |
|
|
per_model_dir = os.path.join(out_dir, "per_model_weight_heatmaps") |
|
|
os.makedirs(per_model_dir, exist_ok=True) |
|
|
base_path = os.path.join(per_model_dir, heatmap_filename) |
|
|
|
|
|
|
|
|
fig = plt.gcf() |
|
|
|
|
|
output_path = _save_precond_heatmap_optimized(fig, base_path, data.shape, plot_format, |
|
|
compression_level=compression_level, is_per_model=True) |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _plot_dominant_model_precond_heatmap(display_tensor: torch.Tensor, num_models: int, title_prefix: str, |
|
|
base_filename: str, out_dir: str, threshold_value: float, |
|
|
model_names: Optional[List[str]] = None, plot_format: str = "png", |
|
|
font_scale: float = 1.0, max_side: int = 256) -> None: |
|
|
""" |
|
|
Creates a heatmap showing which model has dominant preconditioner values with publication-ready styling. |
|
|
""" |
|
|
if display_tensor.ndim != 2: |
|
|
return |
|
|
|
|
|
_set_publication_fonts(scale_factor=font_scale) |
|
|
|
|
|
|
|
|
data = display_tensor.detach().cpu() |
|
|
if data.shape[0] > max_side or data.shape[1] > max_side: |
|
|
|
|
|
|
|
|
step_h = max(1, int(torch.ceil(torch.tensor(data.shape[0] / max_side)).item())) |
|
|
step_w = max(1, int(torch.ceil(torch.tensor(data.shape[1] / max_side)).item())) |
|
|
data = data[::step_h, ::step_w] |
|
|
|
|
|
plt.figure(figsize=(12, 10)) |
|
|
numpy_display_tensor = data.numpy() |
|
|
|
|
|
|
|
|
if num_models <= 10: |
|
|
model_colors = plt.cm.get_cmap('tab10', num_models).colors |
|
|
elif num_models <= 12: |
|
|
model_colors = plt.cm.get_cmap('Paired', num_models).colors |
|
|
elif num_models <= 20: |
|
|
model_colors = plt.cm.get_cmap('tab20', num_models).colors |
|
|
else: |
|
|
model_colors = plt.cm.get_cmap('viridis', num_models).colors |
|
|
|
|
|
colors = ['black'] + [mcolors.to_hex(c) for c in model_colors] |
|
|
cmap = mcolors.ListedColormap(colors) |
|
|
bounds = [-1.5] + [i - 0.5 for i in range(num_models + 1)] |
|
|
norm = mcolors.BoundaryNorm(bounds, cmap.N) |
|
|
|
|
|
aspect_ratio = numpy_display_tensor.shape[1] / numpy_display_tensor.shape[0] |
|
|
aspect = 'auto' if aspect_ratio > 10 or aspect_ratio < 0.1 else 'equal' |
|
|
im = plt.imshow(numpy_display_tensor, aspect=aspect, cmap=cmap, norm=norm) |
|
|
|
|
|
ticks = list(range(-1, num_models)) |
|
|
cbar = plt.colorbar(im, ticks=ticks, spacing='proportional') |
|
|
base_tick_labels = [f'< {threshold_value:.1f}'] |
|
|
for i in range(num_models): |
|
|
name_i = model_names[i] if model_names and i < len(model_names) else f'Model {i}' |
|
|
base_tick_labels.append(name_i) |
|
|
cbar.set_ticklabels(base_tick_labels) |
|
|
cbar.ax.tick_params(labelsize=plt.rcParams['ytick.labelsize']) |
|
|
|
|
|
param_info = _extract_parameter_info(base_filename) |
|
|
plt.title(f"{title_prefix}\n{param_info}", fontsize=plt.rcParams['axes.titlesize'], pad=18) |
|
|
plt.xlabel("Dimension 1") |
|
|
plt.ylabel("Dimension 0") |
|
|
|
|
|
clean_title_prefix = title_prefix.lower().replace(' ', '_').replace('/', '_').replace('(', '').replace(')', '').replace('>', 'gt') |
|
|
dom_dir = os.path.join(out_dir, "dominant_model_heatmaps") |
|
|
os.makedirs(dom_dir, exist_ok=True) |
|
|
base_path = os.path.join(dom_dir, f"{base_filename}_{clean_title_prefix}_dominant_model_heatmap_thresh{threshold_value}") |
|
|
|
|
|
|
|
|
fig = plt.gcf() |
|
|
output_path = _save_precond_heatmap_optimized(fig, base_path, data.shape, plot_format) |
|
|
plt.close() |
|
|
|
|
|
|
|
|
def _looks_like_hf_repo_id(s: str) -> bool: |
|
|
"""Check if string looks like a HuggingFace repo ID.""" |
|
|
import re |
|
|
return bool(re.match(r'^[^/\s]+/[^/\s]+$', s)) |
|
|
|
|
|
|
|
|
def _split_repo_and_file(path: str) -> Optional[tuple]: |
|
|
"""Split HF repo path into repo ID and file path.""" |
|
|
import re |
|
|
m = re.match(r'^([^/\s]+/[^/\s]+)/(.*)$', path) |
|
|
if m: |
|
|
return m.group(1), m.group(2) |
|
|
return None |
|
|
|
|
|
|
|
|
def _resolve_preconditioner_file(model_id: str, precond_spec: Optional[str]) -> tuple:
    """
    Resolves preconditioner file path from model ID and optional spec.

    Resolution strategies are tried in order; the first that succeeds wins:

    1. If ``precond_spec`` has the form ``namespace/name/<file>``, download
       ``<file>`` from the HF repo ``namespace/name``.
    2. If ``model_id`` looks like an HF repo ID, download the relative file
       (``precond_spec``, or ``export/exp_avg_sq.safetensors`` when no spec is
       given) from that repo.
    3. Fall back to the local path ``Path(model_id) / <relative file>``.

    A failed download is not fatal: the next strategy is attempted, so a
    relative local path that merely resembles ``namespace/name/<file>`` is
    still found on disk.

    Args:
        model_id: HF repo ID or local directory of the model.
        precond_spec: Optional file spec; None selects the default
            ``export/exp_avg_sq.safetensors``.

    Returns (display_name, local_file_path), where display_name is
    ``model_id`` unchanged.

    Raises:
        FileNotFoundError: If no strategy yields an existing file. The last
            download error, if any, is attached as the cause.
    """
    from pathlib import Path

    display_name = model_id
    rel = "export/exp_avg_sq.safetensors" if precond_spec is None else precond_spec
    last_download_error = None

    # Strategy 1: the spec itself names "<repo_id>/<file>".
    if precond_spec is not None:
        split = _split_repo_and_file(precond_spec)
        if split:
            repo_id, file_path = split
            try:
                return display_name, hf_hub_download(repo_id, file_path)
            except Exception as err:
                # Could equally be a local relative path; keep trying.
                last_download_error = err

    # Strategy 2: the file lives inside the model's own HF repo.
    if _looks_like_hf_repo_id(model_id):
        try:
            return display_name, hf_hub_download(model_id, rel)
        except Exception as err:
            last_download_error = err

    # Strategy 3: local filesystem.
    local_candidate = Path(model_id) / rel
    if not local_candidate.exists():
        raise FileNotFoundError(
            f"Preconditioner not found: {local_candidate}"
        ) from last_download_error
    return display_name, str(local_candidate)
|
|
|
|
|
|
|
|
def compare_preconditioners(model_entries: List[Dict[str, Any]], output_dir: str,
                            threshold: float = 2.0, only_layers_containing: Optional[str] = None,
                            max_heatmap_side: int = 256, no_per_model_heatmaps: bool = False,
                            param_limit: Optional[int] = None, plot_format: str = "png",
                            single_model_heatmap_threshold: Optional[float] = None,
                            single_model_heatmap_zero_ratio: Optional[float] = None,
                            heatmap_floor_log_offset: Optional[float] = None,
                            compression_level: int = 9, adaptive_format: bool = True,
                            preserve_patterns: bool = True,
                            font_scaling: Optional[Dict[str, float]] = None,
                            plots_to_generate: Optional[Dict[str, bool]] = None) -> None:
    """
    Compare preconditioners across multiple models with professional visualization.

    For every 2-D tensor key shared by all resolved preconditioner files, the
    per-model tensors are stacked and up to four plot families are produced:
    element-wise standard deviation, per-model heatmaps, max/min ratio, and a
    dominant-model map.

    Args:
        model_entries: Entries describing the models. Each entry is either a
            string (model ID) or a dict with keys ``'model'``,
            ``'preconditioner_path'`` (optionally nested under
            ``'parameters'``) and ``'name'`` (display name). Entries of other
            types, or with neither a model nor a preconditioner path, are
            skipped.
        output_dir: Directory for outputs; created if missing. A
            ``model_manifest.json`` listing the display names is written here.
        threshold: An element is attributed to the model holding the maximum
            when its max/mean ratio across models is >= this value; other
            elements are marked -1.
        only_layers_containing: If set, only keys containing this substring
            are compared.
        max_heatmap_side: Forwarded to the heatmap helpers as ``max_side``.
        no_per_model_heatmaps: If True, per-model heatmaps are not generated.
        param_limit: If set, only the first ``param_limit`` keys (in sorted
            order) are processed.
        plot_format: Format forwarded to the per-model and dominant-model
            helpers. With ``'auto'`` and more than 100 per-model heatmaps,
            per-model heatmaps use ``'jpg'``.
        single_model_heatmap_threshold: Forwarded to the per-model heatmap
            helper as ``threshold``.
        single_model_heatmap_zero_ratio: Forwarded to the per-model heatmap
            helper as ``zero_ratio``.
        heatmap_floor_log_offset: Forwarded to the per-model heatmap helper.
        compression_level: Forwarded to the per-model heatmap helper.
        adaptive_format: Accepted for interface compatibility; it does not
            currently change the selected format.
        preserve_patterns: Accepted for interface compatibility; not used by
            this function.
        font_scaling: Optional font scale factors keyed by ``'default'``,
            ``'precond_histogram'``, ``'precond_heatmap'``,
            ``'precond_per_model'`` and ``'precond_dominant'``. Missing keys
            fall back to ``'default'`` (1.0 if absent).
        plots_to_generate: Optional switches keyed by ``'precond_stddev'``,
            ``'precond_per_model'``, ``'precond_max_min_ratio'`` and
            ``'precond_dominant_model'``. Missing keys default to True.

    Raises:
        ValueError: If no model entry could be resolved.

    Exceptions raised while resolving or opening preconditioner files
    propagate to the caller.
    """
    os.makedirs(output_dir, exist_ok=True)

    if font_scaling is None:
        font_scaling = {}
    default_scale = font_scaling.get('default', 1.0)

    if plots_to_generate is None:
        plots_to_generate = {}

    # The original branch on `adaptive_format` assigned `plot_format` on both
    # paths ('auto' only when plot_format already was 'auto'), so the
    # effective format is simply the requested one.
    actual_plot_format = plot_format

    # ---- Resolve every entry to a local preconditioner file ----
    resolved_files: List[str] = []
    display_names: List[str] = []

    for entry in model_entries:
        if isinstance(entry, str):
            model_id = entry
            precond = None
            friendly_name = None
        elif isinstance(entry, dict):
            model_id = entry.get('model')
            precond = entry.get('preconditioner_path')
            params = entry.get('parameters') or {}
            if precond is None and isinstance(params, dict):
                precond = params.get('preconditioner_path')
            friendly_name = entry.get('name')
        else:
            continue

        if not model_id and not precond:
            continue

        disp, local = _resolve_preconditioner_file(model_id or "", precond)
        used_name = friendly_name if friendly_name else (disp if disp else (model_id or ""))
        display_names.append(used_name)
        resolved_files.append(local)

    if len(resolved_files) < 1:
        raise ValueError("Need at least one model to visualize preconditioners")

    print(f"\n{'='*60}")
    print(f"Comparing preconditioners for {len(resolved_files)} models")
    print(f"Models: {', '.join(display_names)}")
    print(f"{'='*60}")

    # Record which index corresponds to which model.
    manifest_path = os.path.join(output_dir, "model_manifest.json")
    with open(manifest_path, 'w', encoding="utf-8") as f:
        json.dump(display_names, f, indent=2)

    # ---- Determine the parameter keys shared by all files ----
    per_model_keys: List[set] = []
    for fp in resolved_files:
        with safe_open(fp, framework="pt", device="cpu") as f:
            per_model_keys.append(set(f.keys()))

    common_keys = set.intersection(*per_model_keys) if per_model_keys else set()
    if only_layers_containing:
        common_keys = {k for k in common_keys if only_layers_containing in k}

    sorted_keys = sorted(common_keys)
    if param_limit is not None:
        sorted_keys = sorted_keys[:param_limit]

    if not sorted_keys:
        print("No common parameter keys found across models.")
        return

    print(f"Found {len(sorted_keys)} common parameters to compare")

    # ---- Loop-invariant plot settings ----
    hist_scale = font_scaling.get('precond_histogram', default_scale)
    heatmap_scale = font_scaling.get('precond_heatmap', default_scale)
    per_model_scale = font_scaling.get('precond_per_model', default_scale)
    dom_scale = font_scaling.get('precond_dominant', default_scale)

    stddev_enabled = plots_to_generate.get('precond_stddev', True)
    if stddev_enabled and len(resolved_files) < 2:
        # std over a single sample is NaN everywhere (zero degrees of
        # freedom), so the StdDev plots would carry no information.
        print(" Note: Skipping Element-wise StdDev plots - at least two models are required")
        stddev_enabled = False

    per_model_enabled = (not no_per_model_heatmaps
                         and plots_to_generate.get('precond_per_model', True))
    total_heatmaps = len(sorted_keys) * len(resolved_files)
    if total_heatmaps > 100 and actual_plot_format == 'auto':
        per_model_format = 'jpg'
        if per_model_enabled:
            print(f" Note: Using JPEG format for {total_heatmaps} per-model heatmaps to save space")
    else:
        per_model_format = actual_plot_format

    # ---- Per-parameter comparison ----
    for key_idx, k in enumerate(sorted_keys):
        print(f"\n[{key_idx+1}/{len(sorted_keys)}] Processing: {k}")

        # Build a filesystem-friendly name from the parameter key.
        display_key = k.replace('.weight', '.exp_avg_sq')
        if display_key.startswith('model.'):
            display_key = display_key[len('model.'):]

        base_filename = display_key.replace('.', '_').replace('/', '_')
        if len(base_filename) > 180:
            base_filename = base_filename[:180]

        # Load this parameter from every model; give up on the key as soon
        # as one model holds a non-2D tensor.
        tensors = []
        shapes = []
        for fp in resolved_files:
            with safe_open(fp, framework="pt", device="cpu") as f:
                t = f.get_tensor(k)
            if t.ndim != 2:
                print(f" Skipping non-2D tensor with shape {t.shape}")
                break
            tensors.append(t)
            shapes.append(t.shape)

        if len(tensors) != len(resolved_files):
            continue

        if len(set(shapes)) != 1:
            print(f" Skipping - shapes don't match: {shapes}")
            continue

        # Stack along a new leading "model" dimension.
        weights_stack = torch.stack(tensors, dim=0)

        # 1. Element-wise standard deviation across models.
        if stddev_enabled:
            stddev_tensor = weights_stack.std(dim=0)
            _plot_precond_histogram(stddev_tensor, "Element-wise StdDev", base_filename, output_dir,
                                    font_scale=hist_scale)
            _plot_precond_heatmap(stddev_tensor, "Element-wise StdDev", base_filename, output_dir,
                                  font_scale=heatmap_scale, max_side=max_heatmap_side)

        # 2. One heatmap per model.
        if per_model_enabled:
            for i in range(weights_stack.shape[0]):
                _plot_single_model_precond_heatmap(
                    weights_stack[i, :, :], i, base_filename, output_dir,
                    model_names=display_names, max_side=max_heatmap_side,
                    plot_format=per_model_format, threshold=single_model_heatmap_threshold,
                    zero_ratio=single_model_heatmap_zero_ratio,
                    heatmap_floor_log_offset=heatmap_floor_log_offset,
                    font_scale=per_model_scale,
                    compression_level=compression_level
                )

        # 3. Ratio between the largest and smallest value across models.
        if plots_to_generate.get('precond_max_min_ratio', True):
            max_weights = torch.max(weights_stack, dim=0).values
            min_weights = torch.min(weights_stack, dim=0).values
            # Small epsilon avoids division by zero; clamp and NaN
            # replacement keep the values plottable.
            max_min_ratio = max_weights / (min_weights + 1e-28)
            max_min_ratio = torch.clamp(max_min_ratio, max=1e12)
            max_min_ratio = torch.nan_to_num(max_min_ratio, nan=0.0)

            _plot_precond_histogram(max_min_ratio, "Max-Min Weight Ratio", base_filename, output_dir,
                                    use_log_x_scale_heuristic=True, font_scale=hist_scale)
            _plot_precond_heatmap(max_min_ratio, "Max-Min Weight Ratio", base_filename, output_dir,
                                  font_scale=heatmap_scale, max_side=max_heatmap_side)

        # 4. Which model holds the maximum, where it clearly exceeds the mean.
        if plots_to_generate.get('precond_dominant_model', True):
            max_weights_op = torch.max(weights_stack, dim=0)
            max_weights = max_weights_op.values
            mean_weights = torch.mean(weights_stack, dim=0)
            max_mean_ratio = max_weights / (mean_weights + 1e-28)
            max_indices = max_weights_op.indices

            # -1 marks elements where no model reaches the threshold.
            dominant_model_display = torch.full_like(max_indices, -1, dtype=torch.long)
            above_threshold_mask = max_mean_ratio >= threshold
            dominant_model_display[above_threshold_mask] = max_indices[above_threshold_mask]

            _plot_dominant_model_precond_heatmap(
                dominant_model_display, weights_stack.shape[0],
                f"Dominant Model (Max-Mean Ratio > {threshold})",
                base_filename, output_dir, threshold, model_names=display_names,
                plot_format=actual_plot_format, font_scale=dom_scale, max_side=max_heatmap_side
            )

    print(f"\n{'='*60}")
    print("Preconditioner comparison completed!")
    print(f"Results saved to: {output_dir}")
    print(f"{'='*60}\n")
|
|
|