pmahdavi's picture
Deploy FFG Mask Explorer initial version
48a55a5
import torch
import copy
from tqdm import tqdm
import gc
def _get_grafting_threshold(scores_iterator, k):
    """Return the score of the k-th largest entry across all score tensors.

    Every tensor yielded by ``scores_iterator`` is moved to the CPU, flattened
    and concatenated, then a single ``torch.kthvalue`` call picks the
    threshold. This is memory-hungry on the CPU but much faster than an
    iterative/streaming top-k.

    Callers keep entries with ``score >= threshold``, so ties at the threshold
    can keep slightly more than ``k`` entries.

    Args:
        scores_iterator: Iterable of score tensors (any shape, any device).
            It is not consumed at all when ``k <= 0``.
        k: Number of top-scoring entries that should pass the threshold.

    Returns:
        ``torch.finfo(torch.float32).max`` (a Python float) when ``k <= 0`` or
        when there are no scores at all; otherwise a 0-d CPU tensor holding the
        minimum score (if ``k`` covers every entry) or the k-th largest score.
    """
    max_float32 = torch.finfo(torch.float32).max
    if k <= 0:
        # Nothing is to be kept: skip the expensive collection entirely. This
        # also avoids torch.cat([]) raising when there are no graftable params.
        return max_float32

    print("Collecting all scores on CPU to determine threshold...")
    # This will be memory-intensive on the CPU, but is much faster than iterative updates.
    chunks = [scores.cpu().flatten() for scores in scores_iterator]
    all_scores = torch.cat(chunks) if chunks else torch.empty(0)
    del chunks
    n = all_scores.numel()
    if n == 0:
        # No scores to rank (torch.min on an empty tensor would raise).
        return max_float32

    if k >= n:
        threshold = all_scores.min()
    else:
        # kthvalue returns the k-th SMALLEST value and is 1-indexed, so the
        # k-th LARGEST value is the (n - k + 1)-th smallest.
        threshold = torch.kthvalue(all_scores, n - k + 1).values
    del all_scores
    gc.collect()
    return threshold
def fast_fisher_graft(pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio):
    """Graft the highest-scoring fine-tuning updates onto a copy of the pretrained model.

    Each shared parameter entry is scored as ``(w_t - w_0)**2 * v_t``, where
    ``w_t``/``w_0`` are the fine-tuned/pretrained values and ``v_t`` comes from
    ``optimizer_v_state``. Entries whose score is >= the threshold for the top
    ``int(total_params * sparsity_ratio)`` scores (ties included) take their
    fine-tuned value; all other entries keep their pretrained value.

    Args:
        pretrained_model: Base model; it is deep-copied, never modified.
        finetuned_model: Source of ``w_t``; the device of its first parameter is
            used for scoring.
        optimizer_v_state: Mapping from state-dict name to a tensor shaped like
            that parameter (presumably the optimizer's second-moment estimate —
            confirm with caller). Only names present in all three sources are
            grafted.
        sparsity_ratio: Fraction of graftable entries to keep.

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``: ``masks_dict`` maps each
        grafted name to a CPU bool mask (True = fine-tuned value kept);
        ``stats_dict`` holds ``kept_params`` (int), ``total_params``,
        ``final_sparsity`` (fraction of entries kept) and ``threshold``.
    """
    print("Performing Fast Fisher Grafting...")
    device = next(finetuned_model.parameters()).device
    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()
    param_names = [
        name for name in finetuned_state
        if name in pretrained_state and name in optimizer_v_state
    ]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)
    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def load_fp32(name):
        # float32 working copies of (w_t, w_0, v_t) on the scoring device.
        return (
            finetuned_state[name].to(device, dtype=torch.float32),
            pretrained_state[name].to(device, dtype=torch.float32),
            optimizer_v_state[name].to(device, dtype=torch.float32),
        )

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            with torch.no_grad():
                w_t, w_0, v_t = load_fp32(name)
                yield (w_t - w_0) ** 2 * v_t

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated sensitivity threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    # state_dict() tensors share storage with the module's parameters/buffers,
    # so copy_() below writes directly into grafted_model.
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0
    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t, w_0, v_t = load_fp32(name)
            mask = (w_t - w_0) ** 2 * v_t >= threshold
            # Select instead of blending (w_0 + mask * (w_t - w_0)): kept entries
            # are exactly w_t, dropped ones exactly w_0, and non-finite values
            # cannot leak across the mask as NaN.
            grafted_param = torch.where(mask, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = mask.cpu()
            # Integer sum of a bool mask is exact; a float32 sum of ones loses
            # precision beyond 2**24 entries in one tensor.
            total_kept += int(mask.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }
    return grafted_model, stats_dict, masks_dict
def magnitude_graft(pretrained_model, finetuned_model, sparsity_ratio):
    """Graft the largest-magnitude fine-tuning updates onto a copy of the pretrained model.

    Each shared parameter entry is scored as ``|w_t - w_0|``. Entries whose
    score is >= the threshold for the top ``int(total_params * sparsity_ratio)``
    scores (ties included) take their fine-tuned value; all other entries keep
    their pretrained value.

    Args:
        pretrained_model: Base model; it is deep-copied, never modified.
        finetuned_model: Source of ``w_t``; the device of its first parameter is
            used for scoring.
        sparsity_ratio: Fraction of graftable entries to keep.

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``: ``masks_dict`` maps each
        grafted name to a CPU bool mask (True = fine-tuned value kept);
        ``stats_dict`` holds ``kept_params`` (int), ``total_params``,
        ``final_sparsity`` (fraction of entries kept) and ``threshold``.
    """
    print("Performing Magnitude-Based Grafting...")
    device = next(finetuned_model.parameters()).device
    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()
    param_names = [name for name in finetuned_state if name in pretrained_state]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)
    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def load_fp32(name):
        # float32 working copies of (w_t, w_0) on the scoring device.
        return (
            finetuned_state[name].to(device, dtype=torch.float32),
            pretrained_state[name].to(device, dtype=torch.float32),
        )

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            with torch.no_grad():
                w_t, w_0 = load_fp32(name)
                yield torch.abs(w_t - w_0)

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated magnitude threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    # state_dict() tensors share storage with the module's parameters/buffers,
    # so copy_() below writes directly into grafted_model.
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0
    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t, w_0 = load_fp32(name)
            mask = torch.abs(w_t - w_0) >= threshold
            # Select instead of blending (w_0 + mask * (w_t - w_0)): kept entries
            # are exactly w_t, dropped ones exactly w_0, and non-finite values
            # cannot leak across the mask as NaN.
            grafted_param = torch.where(mask, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = mask.cpu()
            # Integer sum of a bool mask is exact; a float32 sum of ones loses
            # precision beyond 2**24 entries in one tensor.
            total_kept += int(mask.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }
    return grafted_model, stats_dict, masks_dict
def fish_mask_graft(pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio):
    """Graft fine-tuning updates selected by ``v_t`` alone onto a copy of the pretrained model.

    Each shared parameter entry is scored by its ``optimizer_v_state`` value
    only (the weight change is not part of the score). Entries whose score is
    >= the threshold for the top ``int(total_params * sparsity_ratio)`` scores
    (ties included) take their fine-tuned value; all other entries keep their
    pretrained value.

    Args:
        pretrained_model: Base model; it is deep-copied, never modified.
        finetuned_model: Source of ``w_t``; the device of its first parameter is
            used for scoring.
        optimizer_v_state: Mapping from state-dict name to a tensor shaped like
            that parameter (presumably the optimizer's second-moment estimate —
            confirm with caller). Only names present in all three sources are
            grafted.
        sparsity_ratio: Fraction of graftable entries to keep.

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``: ``masks_dict`` maps each
        grafted name to a CPU bool mask (True = fine-tuned value kept);
        ``stats_dict`` holds ``kept_params`` (int), ``total_params``,
        ``final_sparsity`` (fraction of entries kept) and ``threshold``.
    """
    print("Performing Fish-Mask Grafting (v_t-only)...")
    device = next(finetuned_model.parameters()).device
    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()
    param_names = [
        name for name in finetuned_state
        if name in pretrained_state and name in optimizer_v_state
    ]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)
    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            with torch.no_grad():
                yield optimizer_v_state[name].to(device, dtype=torch.float32)

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated fish-mask threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    # state_dict() tensors share storage with the module's parameters/buffers,
    # so copy_() below writes directly into grafted_model.
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0
    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t = finetuned_state[name].to(device, dtype=torch.float32)
            w_0 = pretrained_state[name].to(device, dtype=torch.float32)
            v_t = optimizer_v_state[name].to(device, dtype=torch.float32)
            mask = v_t >= threshold
            # Select instead of blending (w_0 + mask * (w_t - w_0)): kept entries
            # are exactly w_t, dropped ones exactly w_0, and non-finite values
            # cannot leak across the mask as NaN.
            grafted_param = torch.where(mask, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = mask.cpu()
            # Integer sum of a bool mask is exact; a float32 sum of ones loses
            # precision beyond 2**24 entries in one tensor.
            total_kept += int(mask.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }
    return grafted_model, stats_dict, masks_dict