| """ | |
| Model setup utilities for RosettaModel training/evaluation | |
| """ | |
| import torch | |
| from typing import Dict, Any, List | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from rosetta.model.wrapper import RosettaModel | |
| from rosetta.model.projector import create_projector | |
| """ | |
| Mapping strategies | |
| """ | |

def k_nearest_sources(num_target_layers: int, num_source_layers: int, k: int) -> Dict[int, List[int]]:
    """
    Compute a per-target mapping to the K nearest source layers.

    Target and source layers are placed uniformly in [0, 1]; for each target
    layer the source layers are ranked by absolute distance and the closest K
    are kept.

    Returns:
        Dict[target_idx, List[source_idx]], containing only the targets that
        receive a mapping.
    """
    if num_target_layers <= 1:
        target_positions = [0.0]
    else:
        target_positions = [i / (num_target_layers - 1) for i in range(num_target_layers)]

    if num_source_layers <= 1:
        source_positions = [0.0]
    else:
        source_positions = [j / (num_source_layers - 1) for j in range(num_source_layers)]

    mapping: Dict[int, List[int]] = {}
    for t_idx, t_pos in enumerate(target_positions):
        # Rank source layers by distance to this target position; the stable
        # sort breaks ties in favour of the lower source index.
        sorted_src = sorted(range(num_source_layers), key=lambda j: abs(source_positions[j] - t_pos))
        chosen = sorted_src[:max(0, k)]
        if len(chosen) > 0:
            mapping[t_idx] = chosen
    return mapping
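
# Illustrative example (computed by hand, not exercised anywhere in this
# module): with 4 target layers, 8 source layers and k=2, each target keeps
# the two source layers closest to it on the [0, 1] axis:
#
#     k_nearest_sources(num_target_layers=4, num_source_layers=8, k=2)
#     -> {0: [0, 1], 1: [2, 3], 2: [5, 4], 3: [7, 6]}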

def last_aligned_sources(num_target_layers: int, num_source_layers: int, k: int = 1) -> Dict[int, List[int]]:
    """
    Compute a per-target mapping that aligns the last target layer with the
    last source layer and walks toward the front.

    For each target t, up to K sources are chosen anchored at the aligned
    index, preferring backward indices first and then extending forward to
    satisfy K.

    Returns:
        Dict[target_idx, List[source_idx]], containing only the targets that
        receive a mapping.

    Example (T=11, S=33): target 10 -> [32, 31, ...], target 9 -> [31, 30, ...]
    """
    mapping: Dict[int, List[int]] = {}
    if num_target_layers <= 0 or num_source_layers <= 0:
        return mapping

    # Align ends; offset >= 0 means there are extra source layers at the front.
    offset = num_source_layers - num_target_layers

    def take_k_from(s0: int) -> List[int]:
        result: List[int] = []
        # Prefer moving backward from the anchor (last-to-front).
        for back in range(k):
            idx = s0 - back
            if 0 <= idx < num_source_layers:
                result.append(idx)
        # If the boundary left us short, extend forward past the anchor.
        next_idx = s0 + 1
        while len(result) < k and next_idx < num_source_layers:
            result.append(next_idx)
            next_idx += 1
        return result

    for t in range(num_target_layers):
        s0 = offset + t
        # Clamp to the valid range for edge cases (e.g., fewer source layers).
        if s0 < 0:
            s0 = 0
        elif s0 > num_source_layers - 1:
            s0 = num_source_layers - 1
        chosen = take_k_from(s0)
        if len(chosen) > 0:
            mapping[t] = chosen
    return mapping
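
# Illustrative example (computed by hand): with 11 target layers, 33 source
# layers and k=2, the ends are aligned (offset = 22) and each target walks
# backward from its anchor:
#
#     last_aligned_sources(num_target_layers=11, num_source_layers=33, k=2)
#     -> {0: [22, 21], 1: [23, 22], ..., 9: [31, 30], 10: [32, 31]}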

def setup_models(model_config: Dict[str, Any], device: str = "cuda", dtype: torch.dtype = torch.bfloat16):
    """Set up a RosettaModel with a base model, teacher model, and projector mappings."""
    # Load tokenizer for the base model
    tokenizer = AutoTokenizer.from_pretrained(model_config["base_model"])
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load base and teacher models
    base_model = AutoModelForCausalLM.from_pretrained(
        model_config["base_model"],
        torch_dtype=dtype,
        device_map=device
    )
    teacher_model = AutoModelForCausalLM.from_pretrained(
        model_config["teacher_model"],
        torch_dtype=dtype,
        device_map=device
    )

    # Create the projector that maps teacher head dimensions onto the base model
    projector_config = model_config["projector"]
    projector_params = projector_config["params"].copy()
    projector_params["dtype"] = dtype
    projector = create_projector(
        projector_config["type"],
        source_dim=teacher_model.config.head_dim,
        target_dim=base_model.config.head_dim,
        **projector_params
    )

    # Set up the RosettaModel wrapper around both models
    rosetta_model = RosettaModel(
        model_list=[base_model, teacher_model],
        base_model_idx=0,
        projector_list=[projector]
    ).to(device)

    # Configure a 1:1 projector mapping over the layer range shared by both models
    num_layers_to_map = min(
        base_model.config.num_hidden_layers,
        teacher_model.config.num_hidden_layers
    )
    for layer_idx in range(num_layers_to_map):
        rosetta_model.set_projector_config(
            source_model_idx=1,  # Teacher
            source_model_layer_idx=layer_idx,
            target_model_idx=0,  # Base
            target_model_layer_idx=layer_idx,
            projector_idx=0
        )

    return rosetta_model, tokenizer
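
# Usage sketch (illustrative; the values below are placeholders, not settings
# shipped with this module, so substitute your own training config):
#
#     model_config = {
#         "base_model": "path/or/hub-id-of-base-model",
#         "teacher_model": "path/or/hub-id-of-teacher-model",
#         "projector": {
#             "type": "<registered projector type>",  # must exist in create_projector's registry
#             "params": {},                           # projector-specific kwargs; dtype is injected above
#         },
#     }
#     rosetta_model, tokenizer = setup_models(model_config, device="cuda", dtype=torch.bfloat16)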