|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from einops import rearrange, repeat |
|
|
from einops.layers.torch import EinMix |
|
|
from transformers import PreTrainedModel, PretrainedConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TRMConfig(PretrainedConfig):
    """Hyper-parameters for the ``TRM`` model.

    Args:
        vocab_size: Number of token ids; sizes both the input embedding table
            and the output projection of ``TRM``.
        hidden_size: Width of token embeddings and hidden states.
        seq_len: Maximum sequence length; sets the size of the learned
            positional-embedding table.
        depth_L: Number of ``TRMLayer`` modules stacked in ``TRM``.
        depth_H: Number of halting blocks inside each ``TRMLayer``.
        act_threshold: Cumulative halting probability at which a position
            stops iterating.
        act_epsilon: Early-exit tolerance on the mean remaining halting mass.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "trm"

    def __init__(
        self,
        vocab_size=32000,
        hidden_size=256,
        seq_len=128,
        depth_L=2,
        depth_H=2,
        act_threshold=0.9,
        act_epsilon=1e-2,
        **kwargs,
    ):
        model_fields = {
            "vocab_size": vocab_size,
            "hidden_size": hidden_size,
            "seq_len": seq_len,
            "depth_L": depth_L,
            "depth_H": depth_H,
            "act_threshold": act_threshold,
            "act_epsilon": act_epsilon,
        }
        # The base class handles generic options first; model-specific
        # fields are attached afterwards, exactly as plain assignments would.
        super().__init__(**kwargs)
        for field_name, field_value in model_fields.items():
            setattr(self, field_name, field_value)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class HaltingBlock(nn.Module):
    """Adaptive-computation-time style gating of a single tanh projection.

    A per-position halting probability ``p = sigmoid(act_proj(x))`` is
    accumulated step by step. Each step adds ``p`` as the weight of the
    candidate ``tanh(proj(x))``. On the step where the running total would
    reach ``act_threshold``, the position instead receives the remainder
    ``1 - accumulated`` and stops.

    Args:
        hidden_size: Size of the last dimension of ``x``.
        act_threshold: Cumulative halting mass at which a position stops.
        act_epsilon: The loop also stops early once the mean remaining mass
            ``(1 - accumulated).mean()`` over *all* positions (every batch and
            sequence element together) drops below this value.
        max_steps: Upper bound on loop iterations. When it is reached, every
            position that is still running halts and receives its remainder.
            This prevents an endless loop when halting probabilities underflow
            to zero. ``None`` disables the bound.

    Raises:
        ValueError: If ``max_steps`` is not ``None`` and is smaller than 1.
    """

    def __init__(self, hidden_size, act_threshold, act_epsilon, max_steps=100):
        super().__init__()
        if max_steps is not None and max_steps < 1:
            raise ValueError(
                f"max_steps must be a positive integer or None, got {max_steps!r}"
            )
        self.proj = nn.Linear(hidden_size, hidden_size)
        self.act_proj = nn.Linear(hidden_size, 1)
        self.act_threshold = act_threshold
        self.act_epsilon = act_epsilon
        self.max_steps = max_steps

    def forward(self, x):
        """Return ``(output, mean_halting_mass)``.

        ``output`` has the same shape as ``x``. ``mean_halting_mass`` is the
        scalar mean of the per-position accumulated halting weights.
        """
        halting_probs = torch.sigmoid(self.act_proj(x))
        # The candidate does not depend on the step, so compute it once
        # instead of re-running the projection on every iteration.
        candidate = torch.tanh(self.proj(x))
        zeros = torch.zeros_like(halting_probs)

        still_running = torch.ones_like(halting_probs, dtype=torch.bool)
        accumulated_output = torch.zeros_like(x)
        accumulated_prob = torch.zeros_like(halting_probs)

        step = 0
        while still_running.any():
            step += 1
            # Positions that already halted contribute no new probability.
            p = torch.where(still_running, halting_probs, zeros)
            still_running = (accumulated_prob + p) < self.act_threshold
            if self.max_steps is not None and step >= self.max_steps:
                # Last allowed step: force every remaining position to halt so
                # the loop terminates even if halting_probs are (near) zero.
                still_running = torch.zeros_like(still_running)
            # A position that halts this step takes whatever mass is left.
            # For positions that halted on an earlier step this is ~0.
            remainder = torch.where(still_running, zeros, 1 - accumulated_prob)
            update_weights = torch.where(still_running, p, remainder)
            accumulated_output = accumulated_output + update_weights * candidate
            accumulated_prob = accumulated_prob + update_weights
            # Early exit on the *global* mean remaining mass, so individual
            # positions may stop short of a total weight of 1.
            if (1 - accumulated_prob).mean() < self.act_epsilon:
                break

        return accumulated_output, accumulated_prob.mean()
|
|
|
|
|
|
|
|
class TRMLayer(nn.Module):
    """Chain of ``depth_H`` halting blocks followed by a LayerNorm.

    Each block's output replaces its input (there is no residual path), and
    the per-block halting statistic returned alongside it is discarded.
    """

    def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
        super().__init__()
        halting_blocks = []
        for _ in range(depth_H):
            halting_blocks.append(HaltingBlock(hidden_size, act_threshold, act_epsilon))
        self.blocks = nn.ModuleList(halting_blocks)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        """Run every halting block in order, then normalize the result."""
        hidden = x
        for halting_block in self.blocks:
            hidden = halting_block(hidden)[0]
        return self.norm(hidden)
|
|
|
|
|
|
|
|
class TRM(PreTrainedModel):
    """Token model: embeddings + learned positions, stacked ``TRMLayer``s, LM head.

    ``forward`` returns a plain dict with keys ``"loss"`` (``None`` unless
    ``labels`` is given) and ``"logits"``.
    """

    config_class = TRMConfig

    def __init__(self, config):
        super().__init__(config)
        width = config.hidden_size
        self.emb = nn.Embedding(config.vocab_size, width)
        # Learned positional table, zero-initialised, one slot per position.
        self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, width))
        self.layers = nn.ModuleList(
            TRMLayer(width, config.depth_H, config.act_threshold, config.act_epsilon)
            for _ in range(config.depth_L)
        )
        self.norm = nn.LayerNorm(width)
        self.lm_head = nn.Linear(width, config.vocab_size, bias=False)

        self.post_init()

    def forward(self, input_ids, labels=None):
        """Compute logits and, when ``labels`` is given, a next-token loss.

        Args:
            input_ids: Integer token ids; dimension 1 is treated as the
                sequence axis. Lengths beyond ``config.seq_len`` fail because
                the positional table is only that long.
            labels: Optional target ids aligned with ``input_ids``. Logits at
                position ``t`` are scored against ``labels`` at ``t + 1`` using
                ``nn.CrossEntropyLoss`` with its default settings.
        """
        seq_len = input_ids.size(1)
        hidden = self.emb(input_ids) + self.pos_emb[:, :seq_len, :]
        for layer in self.layers:
            hidden = layer(hidden)
        logits = self.lm_head(self.norm(hidden))

        loss = None
        if labels is not None:
            vocab = logits.size(-1)
            # Drop the last prediction and the first target so position t
            # predicts token t + 1.
            predictions = logits[..., :-1, :].contiguous().view(-1, vocab)
            targets = labels[..., 1:].contiguous().view(-1)
            loss = nn.CrossEntropyLoss()(predictions, targets)

        return {"loss": loss, "logits": logits}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from transformers import AutoConfig, AutoModel |
|
|
|
|
|
# Register the custom config/model pair with the Transformers Auto* factories
# so the model type "trm" maps to TRMConfig, and TRMConfig maps to TRM.
AutoConfig.register(model_type="trm", config=TRMConfig)
AutoModel.register(config_class=TRMConfig, model_class=TRM)
|
|
|