import torch
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig
from transformers.generation import GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2MLP
|
|
class TRMConfig(PretrainedConfig):
    """Configuration for a tiny recursive GPT: a small stack of physical
    GPT-2 blocks that is reused across several loops."""

    model_type = "recursive_gpt"

    def __init__(
        self,
        vocab_size=50257,
        n_positions=1024,
        n_embd=512,
        n_physical_layers=3,
        n_loops=8,
        n_head=8,
        activation_function="gelu_new",
        resid_pdrop=0.1,
        embd_pdrop=0.1,
        attn_pdrop=0.1,
        layer_norm_epsilon=1e-5,
        scale_attn_weights=True,
        scale_attn_by_inverse_layer_idx=False,
        reorder_and_upcast_attn=False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_physical_layers = n_physical_layers
        self.n_loops = n_loops
        self.n_head = n_head
        self.activation_function = activation_function
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.scale_attn_weights = scale_attn_weights
        self.scale_attn_by_inverse_layer_idx = scale_attn_by_inverse_layer_idx
        self.reorder_and_upcast_attn = reorder_and_upcast_attn

        # Aliases expected by the reused GPT-2 modules and the generation utilities.
        self.hidden_size = n_embd
        self.num_attention_heads = n_head
        self.num_hidden_layers = n_physical_layers
        # GPT2Attention reads max_position_embeddings to build its causal-mask buffer.
        self.max_position_embeddings = n_positions
        self.n_inner = None
        self.is_encoder_decoder = False
|
|
class TinyRecursiveModel(PreTrainedModel, GenerationMixin):
    """Causal LM that cycles through n_physical_layers GPT-2 blocks for n_loops passes."""

    config_class = TRMConfig
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        # Token and position embeddings.
        self.wte = nn.Embedding(config.vocab_size, config.n_embd)
        self.wpe = nn.Embedding(config.n_positions, config.n_embd)
        self.drop = nn.Dropout(config.embd_pdrop)

        # The small set of physical blocks that is reused on every loop.
        self.physical_blocks = nn.ModuleList([
            nn.ModuleDict({
                "ln_1": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "attn": GPT2Attention(config, layer_idx=i),
                "ln_2": nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon),
                "mlp": GPT2MLP(4 * config.n_embd, config),
            })
            for i in range(config.n_physical_layers)
        ])

        self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        self.post_init()
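
    # Embedding accessors (an addition, assuming the standard PreTrainedModel
    # hooks): without get_input_embeddings / get_output_embeddings overrides,
    # post_init() finds no output embeddings to tie, so the lm_head / wte
    # weight sharing implied by _tied_weights_keys would silently never happen.
    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings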
|
    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        batch_size, seq_len = input_ids.shape
        device = input_ids.device

        # Embed tokens and positions.
        token_embeds = self.wte(input_ids)
        pos_ids = torch.arange(0, seq_len, dtype=torch.long, device=device)
        pos_embeds = self.wpe(pos_ids)
        hidden_states = self.drop(token_embeds + pos_embeds)

        # GPT2Attention applies the causal mask internally but expects any
        # padding mask as an additive bias broadcastable over attention scores,
        # so convert the usual (batch, seq_len) 1/0 mask here.
        if attention_mask is not None:
            attention_mask = attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
            attention_mask = (1.0 - attention_mask) * torch.finfo(hidden_states.dtype).min

        # Recursive core: cycle through the small stack of physical blocks
        # n_loops times, sharing parameters across loops.
        for loop in range(self.config.n_loops):
            block_idx = loop % self.config.n_physical_layers
            block = self.physical_blocks[block_idx]

            # Pre-norm attention with residual connection.
            ln_output = block["ln_1"](hidden_states)
            attn_output = block["attn"](ln_output, attention_mask=attention_mask)[0]
            hidden_states = hidden_states + attn_output

            # Pre-norm MLP with residual connection.
            ln_output = block["ln_2"](hidden_states)
            mlp_output = block["mlp"](ln_output)
            hidden_states = hidden_states + mlp_output

        hidden_states = self.ln_f(hidden_states)
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Standard causal-LM loss: predict token t+1 from tokens <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            hidden_states=(hidden_states,),
            attentions=None,
            cross_attentions=None,
        )
|
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **kwargs):
        # No KV cache is kept, so every generation step re-runs the full prefix.
        return {"input_ids": input_ids, "attention_mask": attention_mask}

    def _reorder_cache(self, past, beam_idx):
        # Nothing to reorder: forward never returns past key/values.
        return past
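

# A minimal smoke test (illustrative only, not part of the model definition):
# build a small config, run a forward pass on dummy token ids, and greedily
# generate a few tokens. The sizes below are arbitrary assumptions chosen to
# keep the check fast.
if __name__ == "__main__":
    config = TRMConfig(n_embd=128, n_head=4, n_physical_layers=2, n_loops=4, n_positions=128)
    model = TinyRecursiveModel(config)
    model.eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        out = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
    print("logits:", tuple(out.logits.shape), "loss:", float(out.loss))

    generated = model.generate(
        input_ids[:, :4],
        attention_mask=attention_mask[:, :4],
        max_new_tokens=8,
        do_sample=False,
        pad_token_id=0,
    )
    print("generated shape:", tuple(generated.shape))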