Lorenzob committed · verified
Commit 216cb97 · 1 Parent(s): d44efe6

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,8 @@
- ---
- license: apache-2.0
- ---
+
+ # Trained TRM Model
+
+ This is a TRM model trained using the provided datasets.
+
+ ## How to use
+
+ [More detailed usage instructions can be added here]
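
The README's usage section is still a placeholder, so here is a minimal loading sketch. It is not part of the commit and rests on two assumptions: the local path `./trm-model` is a placeholder for wherever this repository is downloaded, and `modelling_trm.py` (added in this commit) must be imported first, because the `AutoConfig`/`AutoModel` registration at the bottom of that file is what lets the Auto classes resolve `"model_type": "trm"`.

```python
# Hypothetical loading sketch -- "./trm-model" is a placeholder for a local
# copy of this repository (e.g. from git clone or snapshot_download).
import modelling_trm  # noqa: F401 -- importing registers TRMConfig/TRM with the Auto classes

from transformers import AutoConfig, AutoModel

repo_path = "./trm-model"  # placeholder path

config = AutoConfig.from_pretrained(repo_path)  # reads config.json ("model_type": "trm")
model = AutoModel.from_pretrained(repo_path)    # loads model.safetensors into the TRM class
model.eval()
```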
config.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "H_cycles": 1,
+   "H_layers": 8,
+   "L_cycles": 1,
+   "L_layers": 2,
+   "act_epsilon": 0.01,
+   "act_threshold": 0.9,
+   "architectures": [
+     "TRM"
+   ],
+   "depth_H": 2,
+   "depth_L": 2,
+   "dropout": 0.1,
+   "dtype": "float32",
+   "expansion": 4,
+   "halt_epsilon": 0.01,
+   "halt_max_steps": 4,
+   "hidden_size": 32,
+   "model_type": "trm",
+   "num_heads": 4,
+   "pad_token_id": 0,
+   "seq_len": 4096,
+   "transformers_version": "4.57.0",
+   "vocab_size": 1183855
+ }
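
Of the keys above, only `vocab_size`, `hidden_size`, `seq_len`, `depth_L`, `depth_H`, `act_threshold`, and `act_epsilon` are explicit arguments of `TRMConfig` in `modelling_trm.py` below; the remaining keys (`H_layers`, `num_heads`, `dropout`, ...) fall through `**kwargs` and are simply stored as attributes by `PretrainedConfig`. A small sketch of that behaviour, again with a placeholder path:

```python
# Sketch (placeholder path): how the keys in config.json surface on TRMConfig.
import modelling_trm  # registers "trm" with AutoConfig
from transformers import AutoConfig

config = AutoConfig.from_pretrained("./trm-model")  # placeholder path
print(type(config).__name__)  # TRMConfig
print(config.hidden_size)     # 32  (explicit __init__ argument)
print(config.H_layers)        # 8   (extra key kept as a plain attribute via **kwargs)
```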
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:89b827be0651807c55d7d4d3fcd1236efd8d9bcc1ff5ac64cd516718cede1383
+ size 303611768
modelling_trm.py ADDED
@@ -0,0 +1,123 @@
+ import math
+ import torch
+ import torch.nn as nn
+ from einops import rearrange, repeat
+ from einops.layers.torch import EinMix
+ from transformers import PreTrainedModel, PretrainedConfig
+
+ # ---------------------------
+ # Configuration Class
+ # ---------------------------
+ class TRMConfig(PretrainedConfig):
+     model_type = "trm"
+
+     def __init__(self,
+                  vocab_size=32000,
+                  hidden_size=256,
+                  seq_len=128,
+                  depth_L=2,
+                  depth_H=2,
+                  act_threshold=0.9,
+                  act_epsilon=1e-2,
+                  **kwargs):
+         super().__init__(**kwargs)
+         self.vocab_size = vocab_size
+         self.hidden_size = hidden_size
+         self.seq_len = seq_len
+         self.depth_L = depth_L
+         self.depth_H = depth_H
+         self.act_threshold = act_threshold
+         self.act_epsilon = act_epsilon
+
+
+ # ---------------------------
+ # Model Architecture
+ # ---------------------------
+ class HaltingBlock(nn.Module):
+     def __init__(self, hidden_size, act_threshold, act_epsilon):
+         super().__init__()
+         self.proj = nn.Linear(hidden_size, hidden_size)
+         self.act_proj = nn.Linear(hidden_size, 1)
+         self.act_threshold = act_threshold
+         self.act_epsilon = act_epsilon
+
+     def forward(self, x):
+         halting_probs = torch.sigmoid(self.act_proj(x))
+         remainders = torch.zeros_like(halting_probs)
+         n_updates = torch.zeros_like(halting_probs)
+         still_running = torch.ones_like(halting_probs, dtype=torch.bool)
+         accumulated_output = torch.zeros_like(x)
+         accumulated_prob = torch.zeros_like(halting_probs)
+
+         while still_running.any():
+             p = torch.where(still_running, halting_probs, torch.zeros_like(halting_probs))
+             new_accum = accumulated_prob + p
+
+             still_running = new_accum < self.act_threshold
+             remainder = torch.where(still_running, torch.zeros_like(halting_probs), 1 - accumulated_prob)
+
+             update_weights = torch.where(still_running, p, remainder)
+             accumulated_output += update_weights * torch.tanh(self.proj(x))
+             accumulated_prob += update_weights
+             n_updates += still_running.float()
+
+             if (1 - accumulated_prob).mean() < self.act_epsilon:
+                 break
+
+         return accumulated_output, accumulated_prob.mean()
+
+
+ class TRMLayer(nn.Module):
+     def __init__(self, hidden_size, depth_H, act_threshold, act_epsilon):
+         super().__init__()
+         self.blocks = nn.ModuleList([
+             HaltingBlock(hidden_size, act_threshold, act_epsilon) for _ in range(depth_H)
+         ])
+         self.norm = nn.LayerNorm(hidden_size)
+
+     def forward(self, x):
+         for block in self.blocks:
+             x, _ = block(x)
+         return self.norm(x)
+
+
+ class TRM(PreTrainedModel):
+     config_class = TRMConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
+         self.pos_emb = nn.Parameter(torch.zeros(1, config.seq_len, config.hidden_size))
+         self.layers = nn.ModuleList([
+             TRMLayer(config.hidden_size, config.depth_H, config.act_threshold, config.act_epsilon)
+             for _ in range(config.depth_L)
+         ])
+         self.norm = nn.LayerNorm(config.hidden_size)
+         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+         self.post_init()
+
+     def forward(self, input_ids, labels=None):
+         x = self.emb(input_ids) + self.pos_emb[:, :input_ids.size(1), :]
+         for layer in self.layers:
+             x = layer(x)
+         x = self.norm(x)
+         logits = self.lm_head(x)
+
+         loss = None
+         if labels is not None:
+             shift_logits = logits[..., :-1, :].contiguous()
+             shift_labels = labels[..., 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
+
+         return {"loss": loss, "logits": logits}
+
+
+ # ---------------------------
+ # Utility: Register to AutoClasses
+ # ---------------------------
+ from transformers import AutoConfig, AutoModel
+
+ AutoConfig.register("trm", TRMConfig)
+ AutoModel.register(TRMConfig, TRM)
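
To see the shapes this model produces, the following is a hypothetical smoke test: it builds `TRM` directly from a tiny `TRMConfig` (the sizes are arbitrary, chosen only so the example runs instantly, not the trained checkpoint's values) and runs one forward pass with random token ids, exercising the `{"loss", "logits"}` dict returned by `TRM.forward`. It assumes `modelling_trm.py` is importable from the current directory.

```python
# Hypothetical smoke test with an arbitrary tiny config (not the shipped checkpoint).
import torch
from modelling_trm import TRM, TRMConfig

config = TRMConfig(vocab_size=1000, hidden_size=32, seq_len=64, depth_L=2, depth_H=2)
model = TRM(config)

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # batch of 2, sequence length 16
outputs = model(input_ids=input_ids, labels=input_ids)

print(outputs["logits"].shape)  # torch.Size([2, 16, 1000])
print(outputs["loss"])          # scalar shifted cross-entropy loss
```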
special_tokens_map.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "cls_token": "[CLS]",
+   "mask_token": "[MASK]",
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "unk_token": "[UNK]"
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,58 @@
+ {
+   "added_tokens_decoder": {
+     "0": {
+       "content": "[PAD]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "100": {
+       "content": "[UNK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "101": {
+       "content": "[CLS]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "102": {
+       "content": "[SEP]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     },
+     "103": {
+       "content": "[MASK]",
+       "lstrip": false,
+       "normalized": false,
+       "rstrip": false,
+       "single_word": false,
+       "special": true
+     }
+   },
+   "clean_up_tokenization_spaces": true,
+   "cls_token": "[CLS]",
+   "do_basic_tokenize": true,
+   "do_lower_case": true,
+   "extra_special_tokens": {},
+   "mask_token": "[MASK]",
+   "model_max_length": 512,
+   "never_split": null,
+   "pad_token": "[PAD]",
+   "sep_token": "[SEP]",
+   "strip_accents": null,
+   "tokenize_chinese_chars": true,
+   "tokenizer_class": "BertTokenizer",
+   "unk_token": "[UNK]"
+ }
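
Because `tokenizer_config.json` sets `"tokenizer_class": "BertTokenizer"`, the tokenizer needs no custom code: `AutoTokenizer` can load it from a local copy of the repository (placeholder path in the sketch below), reading `vocab.txt`, `tokenizer.json`, and `special_tokens_map.json` from this commit.

```python
# Sketch: loading the tokenizer shipped in this commit (placeholder path).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("./trm-model")  # placeholder path

encoded = tokenizer("example input", return_tensors="pt")
print(encoded["input_ids"])                         # [CLS] ... [SEP] token ids
print(tokenizer.pad_token, tokenizer.pad_token_id)  # "[PAD]" 0, matching added_tokens_decoder
```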
vocab.txt ADDED
The diff for this file is too large to render. See raw diff