# luminar-nano / modelling_trm.py
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig
# ---------------------------
# Configuration Class
# ---------------------------
class TRMConfig(PretrainedConfig):
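    """Configuration for the TRM model.

    Args:
        vocab_size: size of the token vocabulary.
        hidden_size: embedding / hidden dimension.
        seq_len: maximum sequence length covered by the learned positional embeddings.
        depth_L: number of stacked TRMLayer blocks.
        depth_H: number of HaltingBlock modules inside each TRMLayer.
        act_threshold: halting-probability threshold for the ACT-style loop.
        act_epsilon: tolerance for the early-exit check in the halting loop.
    """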
model_type = "trm"
def __init__(self,
vocab_size=32000,
hidden_size=256,
seq_len=128,
depth_L=2,
depth_H=2,
act_threshold=0.9,
act_epsilon=1e-2,
**kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.seq_len = seq_len
self.depth_L = depth_L
self.depth_H = depth_H
self.act_threshold = act_threshold
self.act_epsilon = act_epsilon
# ---------------------------
# Model Architecture
# ---------------------------
class HaltingBlock(nn.Module):
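    """Adaptive-computation block: weights a tanh-projected update of the input
    with ACT-style halting probabilities, looping until (almost) all halting
    probability mass has been assigned to every position."""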
def __init__(self, hidden_size, act_threshold, act_epsilon):
super().__init__()
self.proj = nn.Linear(hidden_size, hidden_size)
self.act_proj = nn.Linear(hidden_size, 1)
self.act_threshold = act_threshold
self.act_epsilon = act_epsilon
    def forward(self, x):
        # ACT-style halting: each position accumulates halting probability until it
        # crosses act_threshold, then receives the remaining probability mass.
        halting_probs = torch.sigmoid(self.act_proj(x))
        n_updates = torch.zeros_like(halting_probs)  # per-position step count (not used in the loss)
        still_running = torch.ones_like(halting_probs, dtype=torch.bool)
        accumulated_output = torch.zeros_like(x)
        accumulated_prob = torch.zeros_like(halting_probs)
        while still_running.any():
            # Probability contributed this step by positions that are still running.
            p = torch.where(still_running, halting_probs, torch.zeros_like(halting_probs))
            new_accum = accumulated_prob + p
            still_running = new_accum < self.act_threshold
            # Positions that just halted receive the remainder so their weights sum to 1.
            remainder = torch.where(still_running, torch.zeros_like(halting_probs), 1 - accumulated_prob)
            update_weights = torch.where(still_running, p, remainder)
            # Note: x is not updated inside the loop, so proj(x) is constant across
            # iterations; the loop only decides how much weight each position places on it.
            accumulated_output += update_weights * torch.tanh(self.proj(x))
            accumulated_prob += update_weights
            n_updates += still_running.float()
            # Early exit once almost all probability mass has been assigned.
            if (1 - accumulated_prob).mean() < self.act_epsilon:
                break
        return accumulated_output, accumulated_prob.mean()
class TRMLayer(nn.Module):
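    """A stack of depth_H HaltingBlocks followed by a LayerNorm."""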
def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
super().__init__()
self.blocks = nn.ModuleList([
HaltingBlock(hidden_size, act_threshold, act_epsilon) for _ in range(depth_H)
])
self.norm = nn.LayerNorm(hidden_size)
def forward(self, x):
for block in self.blocks:
x, _ = block(x)
return self.norm(x)
class TRM(PreTrainedModel):
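    """TRM language model: token + learned positional embeddings, a stack of
    depth_L TRMLayers, a final LayerNorm, and an untied linear LM head.
    forward() returns a dict with "loss" (when labels are given) and "logits"."""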
config_class = TRMConfig
def __init__(self, config):
super().__init__(config)
self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, config.hidden_size))
self.layers = nn.ModuleList([
TRMLayer(config.hidden_size, config.depth_H, config.act_threshold, config.act_epsilon)
for _ in range(config.depth_L)
])
self.norm = nn.LayerNorm(config.hidden_size)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.post_init()
    def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
        # attention_mask is accepted for compatibility with the Trainer API but is
        # not used: this architecture has no attention mechanism.
        # Learned positional embeddings; sequences longer than config.seq_len are not supported.
        x = self.emb(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
        for layer in self.layers:
            x = layer(x)
        x = self.norm(x)
        logits = self.lm_head(x)
        loss = None
        if labels is not None:
            # Standard causal-LM shift: predict token t+1 from positions <= t.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
        return {"loss": loss, "logits": logits}
# ---------------------------
# Utility: Register to AutoClasses
# ---------------------------
from transformers import AutoConfig, AutoModel
AutoConfig.register("trm", TRMConfig)
AutoModel.register(TRMConfig, TRM)
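# ---------------------------
# Example (sketch): quick smoke test
# ---------------------------
# A minimal usage sketch, not part of the original repository: the batch size,
# sequence length, and random token ids below are illustrative assumptions.
if __name__ == "__main__":
    config = TRMConfig(vocab_size=32000, hidden_size=256, seq_len=128)
    model = TRM(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seq_len)
    out = model(input_ids, labels=input_ids)
    print(out["logits"].shape)  # expected: torch.Size([2, 16, 32000])
    print(out["loss"])          # scalar cross-entropy loss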