|
|
import torch |
|
|
import copy |
|
|
from tqdm import tqdm |
|
|
import gc |
|
|
|
|
|
def _get_grafting_threshold(scores_iterator, k): |
|
|
""" |
|
|
Finds the exact threshold by collecting all scores on the CPU and using a single, |
|
|
efficient `kthvalue` operation. |
|
|
""" |
|
|
print("Collecting all scores on CPU to determine threshold...") |
|
|
|
|
|
all_scores = torch.cat([scores.cpu().flatten() for scores in scores_iterator]) |
|
|
|
|
|
if k == 0: |
|
|
return torch.finfo(torch.float32).max |
|
|
|
|
|
if k >= len(all_scores): |
|
|
return all_scores.min() |
|
|
|
|
|
|
|
|
|
|
|
threshold = torch.kthvalue(all_scores, all_scores.numel() - k + 1).values |
|
|
|
|
|
del all_scores |
|
|
gc.collect() |
|
|
|
|
|
return threshold |
|
|
|
|
|
def fast_fisher_graft(pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio):
    """Graft the fine-tuning updates with the highest Fisher-style sensitivity.

    Every entry is scored as ``(w_t - w_0) ** 2 * v_t``, where ``w_t`` and
    ``w_0`` are the fine-tuned and pretrained values and ``v_t`` is the
    matching entry of ``optimizer_v_state``. The ``int(total * sparsity_ratio)``
    highest-scoring entries (plus any ties at the threshold) take their
    fine-tuned value; every other entry keeps its pretrained value.

    Args:
        pretrained_model: Source of ``w_0``. It is deep-copied to build the
            result and is not modified itself.
        finetuned_model: Source of ``w_t``. Scoring runs on the device of its
            first parameter.
        optimizer_v_state: Mapping from state-dict name to a tensor shaped like
            that entry (presumably the optimizer's second-moment estimate, e.g.
            Adam ``exp_avg_sq`` -- confirm with the caller). Only names present
            in both state dicts and in this mapping are grafted.
        sparsity_ratio: Fraction of graftable entries whose fine-tuned value is
            kept (larger values keep more).

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``. ``stats_dict`` holds
        ``kept_params`` (int), ``total_params``, ``final_sparsity`` (fraction
        actually kept) and ``threshold``; ``masks_dict`` maps each grafted name
        to a CPU bool tensor that is True where the fine-tuned value was kept.
    """
    print("Performing Fast Fisher Grafting...")
    device = next(finetuned_model.parameters()).device

    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()

    # Only entries present in all three mappings can be scored and grafted.
    param_names = [
        name
        for name in finetuned_state.keys()
        if name in pretrained_state and name in optimizer_v_state
    ]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)

    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def load(name):
        """Return ``(w_t, w_0, score)`` for one entry, as float32 on ``device``."""
        w_t = finetuned_state[name].to(device, dtype=torch.float32)
        w_0 = pretrained_state[name].to(device, dtype=torch.float32)
        v_t = optimizer_v_state[name].to(device, dtype=torch.float32)
        return w_t, w_0, (w_t - w_0) ** 2 * v_t

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            # Yield outside no_grad so grad mode is not left disabled in the
            # consumer while this generator is suspended.
            with torch.no_grad():
                score = load(name)[2]
            yield score

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated sensitivity threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0

    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t, w_0, score = load(name)
            keep = score >= threshold
            # torch.where copies w_t exactly; w_0 + mask * (w_t - w_0) is subject
            # to float rounding and turns an inf in the unselected operand into nan.
            grafted_param = torch.where(keep, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = keep.cpu()
            # Summing the bool mask accumulates in int64; float32 cannot represent
            # every integer above 2**24, so a float sum of a large mask can be off.
            total_kept += int(keep.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }

    return grafted_model, stats_dict, masks_dict
|
|
|
|
|
def magnitude_graft(pretrained_model, finetuned_model, sparsity_ratio):
    """Graft the fine-tuning updates with the largest absolute change.

    Every entry is scored as ``|w_t - w_0|``, where ``w_t`` and ``w_0`` are the
    fine-tuned and pretrained values. The ``int(total * sparsity_ratio)``
    highest-scoring entries (plus any ties at the threshold) take their
    fine-tuned value; every other entry keeps its pretrained value.

    Args:
        pretrained_model: Source of ``w_0``. It is deep-copied to build the
            result and is not modified itself.
        finetuned_model: Source of ``w_t``. Scoring runs on the device of its
            first parameter.
        sparsity_ratio: Fraction of graftable entries whose fine-tuned value is
            kept (larger values keep more).

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``. ``stats_dict`` holds
        ``kept_params`` (int), ``total_params``, ``final_sparsity`` (fraction
        actually kept) and ``threshold``; ``masks_dict`` maps each grafted name
        to a CPU bool tensor that is True where the fine-tuned value was kept.

    Note:
        Every state-dict name shared by both models is graftable. Because
        ``state_dict()`` can also contain buffers, those are scored and
        grafted here as well.
    """
    print("Performing Magnitude-Based Grafting...")
    device = next(finetuned_model.parameters()).device

    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()

    param_names = [name for name in finetuned_state.keys() if name in pretrained_state]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)

    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def load(name):
        """Return ``(w_t, w_0, score)`` for one entry, as float32 on ``device``."""
        w_t = finetuned_state[name].to(device, dtype=torch.float32)
        w_0 = pretrained_state[name].to(device, dtype=torch.float32)
        return w_t, w_0, torch.abs(w_t - w_0)

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            # Yield outside no_grad so grad mode is not left disabled in the
            # consumer while this generator is suspended.
            with torch.no_grad():
                score = load(name)[2]
            yield score

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated magnitude threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0

    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t, w_0, score = load(name)
            keep = score >= threshold
            # torch.where copies w_t exactly; w_0 + mask * (w_t - w_0) is subject
            # to float rounding and turns an inf in the unselected operand into nan.
            grafted_param = torch.where(keep, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = keep.cpu()
            # Summing the bool mask accumulates in int64; float32 cannot represent
            # every integer above 2**24, so a float sum of a large mask can be off.
            total_kept += int(keep.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }

    return grafted_model, stats_dict, masks_dict
|
|
|
|
|
def fish_mask_graft(pretrained_model, finetuned_model, optimizer_v_state, sparsity_ratio):
    """Graft fine-tuning updates chosen purely by the ``optimizer_v_state`` value.

    Every entry is scored by ``v_t`` alone (the matching entry of
    ``optimizer_v_state``), independent of how far the weight moved. The
    ``int(total * sparsity_ratio)`` highest-scoring entries (plus any ties at
    the threshold) take their fine-tuned value ``w_t``; every other entry keeps
    its pretrained value ``w_0``.

    Args:
        pretrained_model: Source of ``w_0``. It is deep-copied to build the
            result and is not modified itself.
        finetuned_model: Source of ``w_t``. Scoring runs on the device of its
            first parameter.
        optimizer_v_state: Mapping from state-dict name to a tensor shaped like
            that entry (presumably the optimizer's second-moment estimate, e.g.
            Adam ``exp_avg_sq`` -- confirm with the caller). Only names present
            in both state dicts and in this mapping are grafted.
        sparsity_ratio: Fraction of graftable entries whose fine-tuned value is
            kept (larger values keep more).

    Returns:
        ``(grafted_model, stats_dict, masks_dict)``. ``stats_dict`` holds
        ``kept_params`` (int), ``total_params``, ``final_sparsity`` (fraction
        actually kept) and ``threshold``; ``masks_dict`` maps each grafted name
        to a CPU bool tensor that is True where the fine-tuned value was kept.
    """
    print("Performing Fish-Mask Grafting (v_t-only)...")
    device = next(finetuned_model.parameters()).device

    pretrained_state = pretrained_model.state_dict()
    finetuned_state = finetuned_model.state_dict()

    # Only entries present in all three mappings can be scored and grafted.
    param_names = [
        name
        for name in finetuned_state.keys()
        if name in pretrained_state and name in optimizer_v_state
    ]
    total_params = sum(finetuned_state[name].numel() for name in param_names)
    k = int(total_params * sparsity_ratio)

    print(f"Total graftable parameters: {total_params}")
    print(f"Sparsity: {sparsity_ratio:.2%}, keeping top {k} parameters.")

    def v_score(name):
        """Return the score (``v_t`` itself) for one entry, as float32 on ``device``."""
        return optimizer_v_state[name].to(device, dtype=torch.float32)

    def score_generator():
        for name in tqdm(param_names, desc="Calculating scores on GPU"):
            # Yield outside no_grad so grad mode is not left disabled in the
            # consumer while this generator is suspended.
            with torch.no_grad():
                score = v_score(name)
            yield score

    threshold = _get_grafting_threshold(score_generator(), k)
    print(f"Calculated fish-mask threshold: {threshold}")

    grafted_model = copy.deepcopy(pretrained_model)
    grafted_state = grafted_model.state_dict()
    masks_dict = {}
    total_kept = 0

    for name in tqdm(param_names, desc="Applying graft"):
        with torch.no_grad():
            w_t = finetuned_state[name].to(device, dtype=torch.float32)
            w_0 = pretrained_state[name].to(device, dtype=torch.float32)
            keep = v_score(name) >= threshold
            # torch.where copies w_t exactly; w_0 + mask * (w_t - w_0) is subject
            # to float rounding and turns an inf in the unselected operand into nan.
            grafted_param = torch.where(keep, w_t, w_0)
            grafted_state[name].copy_(grafted_param.to(grafted_state[name].dtype))
            masks_dict[name] = keep.cpu()
            # Summing the bool mask accumulates in int64; float32 cannot represent
            # every integer above 2**24, so a float sum of a large mask can be off.
            total_kept += int(keep.sum().item())

    stats_dict = {
        'kept_params': total_kept,
        'total_params': total_params,
        'final_sparsity': total_kept / total_params if total_params > 0 else 0,
        'threshold': threshold.item() if isinstance(threshold, torch.Tensor) else threshold,
    }

    return grafted_model, stats_dict, masks_dict
|
|
|