# tiny-recursive-model / modeling_tiny_recursive.py
from transformers import PreTrainedModel, PretrainedConfig
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
import torch
import torch.nn as nn

class TRMConfig(PretrainedConfig):
    """Configuration for a small weight-shared ("recursive") GPT-2 style decoder.

    Only `n_physical_layers` transformer blocks are instantiated; they are reused
    for `n_loops` passes, so the effective depth is `n_loops` while the parameter
    count stays that of `n_physical_layers` blocks.
    """

    model_type = "recursive_gpt"
def __init__(
self,
vocab_size=50257,
n_positions=1024,
n_embd=512,
n_physical_layers=3,
n_loops=8,
n_head=8,
activation_function="gelu_new",
resid_pdrop=0.1,
embd_pdrop=0.1,
attn_pdrop=0.1,
layer_norm_epsilon=1e-5,
scale_attn_weights=True,
scale_attn_by_inverse_layer_idx=False,
reorder_and_upcast_attn=False,
**kwargs,
):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.n_positions = n_positions
self.n_embd = n_embd
self.n_physical_layers = n_physical_layers
self.n_loops = n_loops
self.n_head = n_head
self.activation_function = activation_function
self.resid_pdrop = resid_pdrop
self.embd_pdrop = embd_pdrop
self.attn_pdrop = attn_pdrop
self.layer_norm_epsilon = layer_norm_epsilon
self.scale_attn_weights = scale_attn_weights
self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
self.reorder_and_upcast_attn = reorder_and_upcast_attn
        # Aliases required by transformers internals (GPT2Attention reads
        # hidden_size, num_attention_heads and max_position_embeddings).
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        self.max_position_embeddings = n_positions
        self.n_inner = None
        self.is_encoder_decoder = False


class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """GPT-2 style causal language model that loops over a small stack of shared blocks."""

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.config = config
# 1. Embeddings
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
self.wpe = nn.Embedding(config.n_positions, config.n_embd)
self.drop = nn.Dropout(config.embd_pdrop)
        # 2. Physical transformer blocks, shared across the recursive loops;
        #    the module names match the layout of the saved checkpoint
self.physical_blocks = nn.ModuleList([
nn.ModuleDict({
"ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
"attn": GPT2Attention(config, layer_idx=i),
"ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
"mlp": GPT2MLP(4 * config.n_embd, config)
}) for i in range(config.n_physical_layers)
])
# 3. Final layer norm
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
# 4. Language modeling head
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
# Initialize weights
self.post_init()
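
    # Minimal sketch of the standard transformers embedding accessor hooks,
    # assuming wte / lm_head are the tied embedding pair; these let weight tying
    # via _tied_weights_keys and resize_token_embeddings resolve the right modules.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings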

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        batch_size, seq_len = input_ids.shape
        device = input_ids.device
        # Token + learned position embeddings
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention expects an additive 4D mask (0.0 for visible tokens, a large
        # negative value for padding); the causal mask itself is applied inside the block.
        if attention_mask is not None and attention_mask.dim() == 2:
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min
        # Recursive weight sharing: run n_loops passes, cycling through the
        # n_physical_layers blocks so the effective depth exceeds the parameter budget
        for loop in range(self.config.n_loops):
            block_idx = loop % self.config.n_physical_layers
block = self.physical_blocks[block_idx]
# Attention
ln_output = block["ln_1"](hidden_states)
attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
hidden_states = hidden_states + attn_output
# MLP
ln_output = block["ln_2"](hidden_states)
mlp_output = block["mlp"](ln_output)
hidden_states = hidden_states + mlp_output
# Final layer norm and projection
hidden_states = self.ln_f(hidden_states)
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss_fct = nn.CrossEntropyLoss()
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
hidden_states=hidden_states,
attentions=None,
cross_attentions=None
)

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so the full sequence is re-encoded at every generation step.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past, beam_idx):
        # Nothing is cached, so there is nothing to reorder for beam search.
        return past
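

# Minimal smoke-test sketch: instantiates a randomly initialized model and runs a
# short greedy generation. The token ids and pad_token_id below are illustrative
# placeholders, and exact generate() behavior depends on the installed transformers version.
if __name__ == "__main__":
    config = TRMConfig(n_embd=512, n_head=8, n_physical_layers=3, n_loops=8)
    model = TinyRecursiveModel(config).eval()

    dummy_ids = torch.randint(0, config.vocab_size, (1, 8))  # placeholder token ids
    with torch.no_grad():
        out = model(dummy_ids, labels=dummy_ids)
        generated = model.generate(dummy_ids, max_new_tokens=4, do_sample=False, pad_token_id=0)

    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape), "generated:", tuple(generated.shape))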