import torch
import torch.nn as nn
from transformers import AutoConfig, AutoModel, PreTrainedModel, PretrainedConfig


# ---------------------------
# Configuration Class
# ---------------------------
class TRMConfig(PretrainedConfig):
    model_type = "trm"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=256,
        seq_len=128,
        depth_L=2,
        depth_H=2,
        act_threshold=0.9,
        act_epsilon=1e-2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.seq_len = seq_len
        self.depth_L = depth_L
        self.depth_H = depth_H
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon


# ---------------------------
# Model Architecture
# ---------------------------
class HaltingBlock(nn.Module):
    """Position-wise block with ACT-style adaptive halting.

    Each position accumulates weighted updates until its cumulative halting
    probability crosses ``act_threshold``; the final update for a position
    receives the remainder, so the weights sum to 1.
    """

    def __init__(self, hidden_size, act_threshold, act_epsilon):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act_proj = nn.Linear(hidden_size, 1)
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon

    def forward(self, x, max_steps=16):
        state = x
        accumulated_output = torch.zeros_like(x)
        accumulated_prob = x.new_zeros(*x.shape[:-1], 1)
        still_running = torch.ones_like(accumulated_prob, dtype=torch.bool)

        # A hard step cap guards against positions whose halting probability
        # is too small to ever reach the threshold (the original while-loop
        # could spin forever in that case).
        for _ in range(max_steps):
            update = torch.tanh(self.proj(state))
            # Recompute the halting probability from the evolving state, so
            # different positions can halt at different steps; with a fixed
            # input the loop would be degenerate.
            p = torch.sigmoid(self.act_proj(state))
            p = torch.where(still_running, p, torch.zeros_like(p))
            halting_now = still_running & (accumulated_prob + p >= self.act_threshold)
            # Positions halting this step get the remainder; the rest get p.
            update_weights = torch.where(halting_now, 1 - accumulated_prob, p)
            accumulated_output = accumulated_output + update_weights * update
            accumulated_prob = accumulated_prob + update_weights
            still_running = still_running & ~halting_now
            state = torch.where(still_running, update, state)
            if not still_running.any() or (1 - accumulated_prob).mean() < self.act_epsilon:
                break

        # The mean accumulated probability doubles as a ponder-cost signal;
        # TRMLayer currently discards it.
        return accumulated_output, accumulated_prob.mean()


class TRMLayer(nn.Module):
    def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
        super().__init__()
        self.blocks = nn.ModuleList([
            HaltingBlock(hidden_size, act_threshold, act_epsilon)
            for _ in range(depth_H)
        ])
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        for block in self.blocks:
            x, _ = block(x)
        return self.norm(x)


class TRM(PreTrainedModel):
    config_class = TRMConfig

    def __init__(self, config):
        super().__init__(config)
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        # Learned absolute positions, sized to the configured maximum length.
        self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, config.hidden_size))
        self.layers = nn.ModuleList([
            TRMLayer(config.hidden_size, config.depth_H,
                     config.act_threshold, config.act_epsilon)
            for _ in range(config.depth_L)
        ])
        self.norm = nn.LayerNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def forward(self, input_ids, labels=None):
        x = self.emb(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from tokens <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )
        return {"loss": loss, "logits": logits}


# ---------------------------
# Utility: Register to AutoClasses
# ---------------------------
AutoConfig.register("trm", TRMConfig)
AutoModel.register(TRMConfig, TRM)
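

# ---------------------------
# Usage Example (sketch)
# ---------------------------
# A minimal smoke test, assuming the classes above. The small sizes and the
# random token ids are illustrative and not part of the original module.
if __name__ == "__main__":
    config = TRMConfig(vocab_size=1000, hidden_size=64, seq_len=32)
    model = TRM(config)

    input_ids = torch.randint(0, config.vocab_size, (2, config.seq_len))
    out = model(input_ids, labels=input_ids)
    print("loss:", out["loss"].item())            # scalar next-token loss
    print("logits:", tuple(out["logits"].shape))  # (2, 32, 1000)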
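
    # Round trip through the AutoClasses registered above: saving to a
    # temporary directory and reloading via AutoModel should recover a TRM
    # instance from the "trm" model_type recorded in config.json.
    import tempfile

    with tempfile.TemporaryDirectory() as ckpt_dir:
        model.save_pretrained(ckpt_dir)
        reloaded = AutoModel.from_pretrained(ckpt_dir)
        assert isinstance(reloaded, TRM)
        print("AutoModel round trip OK:", type(reloaded).__name__)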