brianling16 committed on
Commit fe7385b · verified · 1 Parent(s): f8f3daa

Upload RecursiveGPT2Model.py with huggingface_hub

Files changed (1)
  1. RecursiveGPT2Model.py +353 -1
RecursiveGPT2Model.py CHANGED
@@ -1,6 +1,358 @@
  from transformers import GPT2LMHeadModel, GPT2Config
- from shared_attention import convert_to_recursive
  import torch
+ import copy
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+ from typing import Optional, List
+
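+ # Reference Transformer components: a small from-scratch stack used by the Shared*
+ # wrappers further below (the GPT-2 conversion path at the bottom of the file works
+ # on Conv1D layers instead).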
+ class MultiheadSelfAttention(nn.Module):
+     def __init__(self, d_model: int, n_heads: int):
+         super().__init__()
+         assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.d_head = d_model // n_heads
+
+         # Standard projections
+         self.q_proj = nn.Linear(d_model, d_model)
+         self.k_proj = nn.Linear(d_model, d_model)
+         self.v_proj = nn.Linear(d_model, d_model)
+         self.out_proj = nn.Linear(d_model, d_model)
+
+     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+         B, T, C = x.shape
+         H = self.n_heads
+         D = self.d_head
+
+         q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)  # (B, H, T, D)
+         k = self.k_proj(x).view(B, T, H, D).transpose(1, 2)
+         v = self.v_proj(x).view(B, T, H, D).transpose(1, 2)
+
+         att = (q @ k.transpose(-2, -1)) / math.sqrt(D)  # (B, H, T, T)
+         if attn_mask is not None:
+             att = att + attn_mask  # mask should be broadcastable; use -inf on masked positions
+         att = F.softmax(att, dim=-1)
+         y = att @ v  # (B, H, T, D)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         y = self.out_proj(y)
+         return y
+
+ class MLP(nn.Module):  # Fixed: Now inherits from nn.Module
+     def __init__(self, d_model: int, d_ff: int):
+         super().__init__()
+         self.fc1 = nn.Linear(d_model, d_ff)
+         self.fc2 = nn.Linear(d_ff, d_model)
+         self.activation = nn.ReLU()
+
+     def forward(self, x: torch.Tensor):
+         return self.fc2(self.activation(self.fc1(x)))
+
+ class TransformerLayer(nn.Module):
+     def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
+         super().__init__()
+         self.ln1 = nn.LayerNorm(d_model)
+         self.ln2 = nn.LayerNorm(d_model)
+         self.dropout = nn.Dropout(dropout)
+         self.self_attn = MultiheadSelfAttention(d_model, n_heads)
+         self.mlp = MLP(d_model, d_ff)
+
+     def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+         y = self.self_attn(self.ln1(x), attn_mask)
+         x = x + self.dropout(y)
+         y = self.mlp(self.ln2(x))
+         return x + self.dropout(y)
+
+ class Transformer(nn.Module):
+     def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int, vocab_size: int, dropout: float = 0.1):
+         super().__init__()
+         self.d_model = d_model
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.d_ff = d_ff
+         self.tok_emb = nn.Embedding(vocab_size, d_model)
+         self.pos_emb = nn.Embedding(2048, d_model)  # simple fixed max length
+         self.layers = nn.ModuleList([
+             TransformerLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
+         ])
+         self.ln_f = nn.LayerNorm(d_model)  # Added missing final LayerNorm
+         self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
+         self.lm_head.weight = self.tok_emb.weight  # weight tying
+
+     def forward(self, idx: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+         B, T = idx.shape
+         pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
+         x = self.tok_emb(idx) + self.pos_emb(pos)
+         for layer in self.layers:
+             x = layer(x, attn_mask)
+         x = self.ln_f(x)
+         return self.lm_head(x)
+
+ # ---- LoRA ----
+ class LoRAAdapter(nn.Module):
+     def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0,
+                  weight: Optional[torch.Tensor] = None):
+         super().__init__()
+         self.rank = rank
+         self.alpha = alpha
+         if rank > 0:
+             self.A = nn.Parameter(torch.zeros((rank, in_features)))
+             self.B = nn.Parameter(torch.zeros((out_features, rank)))
+
+             # Initialize with SVD if base weight is provided
+             if weight is not None:
+                 U, S, Vh = torch.linalg.svd(weight, full_matrices=False)
+                 U = U[:, :rank]
+                 S = S[:rank]
+                 Vh = Vh[:rank, :]
+                 self.A.data = Vh  # (rank, in_features)
+                 self.B.data = U @ torch.diag(S)  # (out_features, rank)
+             else:
+                 nn.init.normal_(self.A, std=1/rank)
+                 nn.init.zeros_(self.B)
+         else:
+             self.register_parameter('A', None)
+             self.register_parameter('B', None)
+
+     def delta(self) -> Optional[torch.Tensor]:
+         if self.rank == 0 or self.A is None or self.B is None:
+             return None
+         return (self.B @ self.A) * (self.alpha / self.rank)  # (out, in)
+
+     def lora_parameters(self):
+         if self.A is not None:
+             yield self.A
+         if self.B is not None:
+             yield self.B
+
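+ # Wrappers that attach per-repeat LoRA adapters to a shared base projection. Only the
+ # adapters' A/B matrices are exposed via lora_parameters(); the base layer's weights
+ # are meant to stay frozen and shared across repeats.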
+ class LoRALinear(nn.Module):
+     def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0, num_repeats: int = 1):
+         super().__init__()
+         self.linear = linear  # base frozen linear
+         self.rank = rank
+         self.num_repeats = num_repeats
+
+         if rank > 0:
+             self.loras = nn.ModuleList([
+                 LoRAAdapter(linear.in_features, linear.out_features, rank, alpha)
+                 for _ in range(num_repeats)
+             ])
+         else:
+             self.loras = nn.ModuleList([])
+
+     def forward(self, x, repeat_idx: int = 0):
+         out = self.linear(x)  # [batch, ..., out_features]
+         if self.rank == 0:
+             return out
+         delta = self.loras[repeat_idx].delta()  # (out, in)
+         if delta is not None:
+             delta_t = delta  # nn.Linear expects (out, in)
+             return out + F.linear(x, delta_t)
+         return out
+
+     def lora_parameters(self):
+         for lora in self.loras:
+             yield from lora.lora_parameters()
+
+
+ class LoRAConv1D(nn.Module):
+     """GPT-2 style Conv1D with LoRA support."""
+     def __init__(self, conv1d, rank: int, alpha: float = 1.0, num_repeats: int = 1):
+         super().__init__()
+         self.conv1d = conv1d  # base GPT-2 Conv1D
+         self.rank = rank
+         self.num_repeats = num_repeats
+         in_features, out_features = conv1d.weight.shape  # GPT-2 Conv1D: [in, out]
+
+         # Special handling for the c_attn layer, whose fused Q/K/V projection has 3x output features.
+         # Heuristic: only the fused QKV projection maps d_model -> 3 * d_model.
+         self.is_c_attn = (out_features == 3 * in_features)
+         self.split_size = out_features // 3 if self.is_c_attn else out_features
+
+         if rank > 0:
+             if self.is_c_attn:
+                 # Create separate LoRA adapters for Q, K, V projections
+                 self.loras = nn.ModuleList([
+                     nn.ModuleList([
+                         LoRAAdapter(in_features, self.split_size, rank, alpha)
+                         for _ in range(3)  # Q, K, V
+                     ]) for _ in range(num_repeats)
+                 ])
+             else:
+                 self.loras = nn.ModuleList([
+                     LoRAAdapter(in_features, out_features, rank, alpha)
+                     for _ in range(num_repeats)
+                 ])
+         else:
+             self.loras = nn.ModuleList([])
+
+     def forward(self, x, repeat_idx: int = 0):
+         """
+         x: [batch, seq_len, in_features]
+         returns: [batch, seq_len, out_features]
+         """
+         out = self.conv1d(x)
+         if self.rank == 0 or len(self.loras) == 0:
+             return out
+
+         if self.is_c_attn:
+             # Handle Q, K, V projections separately
+             deltas = []
+             for i in range(3):
+                 delta = self.loras[repeat_idx][i].delta()  # (split_size, in)
+                 if delta is not None:
+                     delta_t = delta.T  # (in, split_size)
+                     deltas.append(torch.matmul(x, delta_t))
+             if deltas:
+                 return out + torch.cat(deltas, dim=-1)
+             return out
+         else:
+             delta = self.loras[repeat_idx].delta()  # (out, in)
+             if delta is not None:
+                 delta_t = delta.T  # (in, out)
+                 return out + torch.matmul(x, delta_t)
+             return out
+
+     def lora_parameters(self):
+         if self.is_c_attn:
+             for lora_group in self.loras:
+                 for lora in lora_group:
+                     yield from lora.lora_parameters()
+         else:
+             for lora in self.loras:
+                 yield from lora.lora_parameters()
+
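+ # Shared ("recursive") layer wrappers: one set of base weights is reused on every
+ # repeat, and the repeat index selects which LoRA adapter is applied on top.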
+ class SharedAttention(nn.Module):
+     def __init__(self, base_attn, num_repeats: int, lora_rank: int, lora_alpha: float):
+         super().__init__()
+         self.n_heads = base_attn.n_heads
+         self.d_head = base_attn.d_head
+         self.d_model = base_attn.d_model
+
+         self.q_proj = LoRALinear(base_attn.q_proj, lora_rank, lora_alpha, num_repeats)
+         self.k_proj = LoRALinear(base_attn.k_proj, lora_rank, lora_alpha, num_repeats)
+         self.v_proj = LoRALinear(base_attn.v_proj, lora_rank, lora_alpha, num_repeats)
+         self.out_proj = LoRALinear(base_attn.out_proj, lora_rank, lora_alpha, num_repeats)
+
+     def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
+         B, T, C = x.shape
+         H, D = self.n_heads, self.d_head
+
+         q = self.q_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
+         k = self.k_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
+         v = self.v_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
+
+         att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
+         if attn_mask is not None:
+             att = att + attn_mask
+         att = F.softmax(att, dim=-1)
+         y = att @ v
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+         return self.out_proj(y, repeat_idx)
+
+ class SharedMLP(nn.Module):
+     def __init__(self, base_mlp, num_repeats: int, lora_rank: int, lora_alpha: float):
+         super().__init__()
+         self.fc1 = LoRALinear(base_mlp.fc1, lora_rank, lora_alpha, num_repeats)
+         self.fc2 = LoRALinear(base_mlp.fc2, lora_rank, lora_alpha, num_repeats)
+         self.act = base_mlp.activation  # MLP defines `activation`, not `act`
+
+     def forward(self, x, repeat_idx: int):
+         return self.fc2(self.act(self.fc1(x, repeat_idx)), repeat_idx)
+
+ class SharedTransformerLayer(nn.Module):
+     def __init__(self, base_layer, num_repeats: int, lora_rank: int, lora_alpha: float):
+         super().__init__()
+         self.ln1 = base_layer.ln1
+         self.ln2 = base_layer.ln2
+         # TransformerLayer defines a single dropout module; reuse it for both residual branches
+         self.dropout1 = base_layer.dropout
+         self.dropout2 = base_layer.dropout
+         self.attn = SharedAttention(base_layer.self_attn, num_repeats, lora_rank, lora_alpha)
+         self.mlp = SharedMLP(base_layer.mlp, num_repeats, lora_rank, lora_alpha)
+
+     def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
+         y = self.attn(self.ln1(x), repeat_idx, attn_mask)
+         x = x + self.dropout1(y)
+         y = self.mlp(self.ln2(x), repeat_idx)
+         x = x + self.dropout2(y)
+         return x
+
+ # ---- Conversion Utilities ----
+ def average_weights(layers, attr):
+     weights = [getattr(layer, attr).weight.data for layer in layers]
+     return torch.stack(weights, dim=0).mean(dim=0)
+
+
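+ # SVD-based initialization: each repeat's adapter is seeded with the top-`rank`
+ # singular directions of (original layer weight - shared averaged weight), so each
+ # repeat starts close to the layer it replaces. Written against LoRALinear's
+ # `linear` / `loras` attributes; note that GPT-2 Conv1D weights are stored transposed.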
+ def initialize_lora_with_svd(lora_layer, original_weights, repeat_indices, rank):
+     """
+     original_weights: list of original (out, in) weight tensors, one per repeat index
+     repeat_indices: which repeat indices these weights correspond to
+     """
+     shared_weight = lora_layer.linear.weight.data.clone()
+
+     for idx, orig_weight in zip(repeat_indices, original_weights):
+         residual = orig_weight - shared_weight
+         U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
+
+         # Truncate to rank
+         U = U[:, :rank]
+         S = S[:rank]
+         Vh = Vh[:rank, :]
+
+         # Initialize LoRA weights
+         lora_layer.loras[idx].A.data = Vh  # A = Vᵣᵀ
+         lora_layer.loras[idx].B.data = U @ torch.diag(S)  # B = UᵣΣᵣ
+
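+ # Folds every K consecutive GPT-2 blocks into one shared block: the Conv1D weights are
+ # averaged over the group and wrapped in LoRAConv1D with K per-repeat adapters. After
+ # conversion, model.transformer.h holds n_layers // K blocks; GPT-2's forward runs each
+ # block once with repeat_idx=0 unless the caller cycles repeat_idx itself, and any
+ # leftover layers (when n_layers % K != 0) are dropped.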
+ def convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0):
+     n_layers = len(model.transformer.h)
+     new_blocks = []
+
+     for b in range(n_layers // K):
+         block_layers = model.transformer.h[b*K:(b+1)*K]
+         base_layer = copy.deepcopy(block_layers[0])
+
+         # Average weights across the block for shared parameters
+         with torch.no_grad():
+             if hasattr(base_layer.attn, 'c_attn'):
+                 shared_weight = average_weights([l.attn for l in block_layers], 'c_attn')
+                 base_layer.attn.c_attn.weight.data = shared_weight
+
+             if hasattr(base_layer.attn, 'c_proj'):
+                 shared_weight = average_weights([l.attn for l in block_layers], 'c_proj')
+                 base_layer.attn.c_proj.weight.data = shared_weight
+
+             if hasattr(base_layer.mlp, 'c_fc'):
+                 shared_weight = average_weights([l.mlp for l in block_layers], 'c_fc')
+                 base_layer.mlp.c_fc.weight.data = shared_weight
+
+             if hasattr(base_layer.mlp, 'c_proj'):
+                 shared_weight = average_weights([l.mlp for l in block_layers], 'c_proj')
+                 base_layer.mlp.c_proj.weight.data = shared_weight
+
+         # Convert to LoRA
+         if hasattr(base_layer.attn, 'c_attn'):
+             base_layer.attn.c_attn = LoRAConv1D(
+                 base_layer.attn.c_attn, rank, lora_alpha, K
+             )
+
+         if hasattr(base_layer.attn, 'c_proj'):
+             base_layer.attn.c_proj = LoRAConv1D(
+                 base_layer.attn.c_proj, rank, lora_alpha, K
+             )
+
+         if hasattr(base_layer.mlp, 'c_fc'):
+             base_layer.mlp.c_fc = LoRAConv1D(
+                 base_layer.mlp.c_fc, rank, lora_alpha, K
+             )
+
+         if hasattr(base_layer.mlp, 'c_proj'):
+             base_layer.mlp.c_proj = LoRAConv1D(
+                 base_layer.mlp.c_proj, rank, lora_alpha, K
+             )
+
+         new_blocks.append(base_layer)
+
+     model.transformer.h = nn.ModuleList(new_blocks)
+     return model
 
  class RecursiveGPT2Config(GPT2Config):
      model_type = "recursive_gpt2"
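
A minimal usage sketch of the conversion above, assuming this file is importable as RecursiveGPT2Model and a standard 12-block gpt2 checkpoint:

    from transformers import GPT2LMHeadModel
    from RecursiveGPT2Model import convert_to_recursive  # assumed import path for this file

    model = GPT2LMHeadModel.from_pretrained("gpt2")  # 12 transformer blocks
    model = convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0)
    print(len(model.transformer.h))  # 6 shared blocks, each carrying K=2 LoRA adapter sets

    # Collect only the adapter parameters for fine-tuning; the averaged base weights stay shared.
    lora_params = [p for block in model.transformer.h
                   for module in (block.attn.c_attn, block.attn.c_proj,
                                  block.mlp.c_fc, block.mlp.c_proj)
                   for p in module.lora_parameters()]

Note that a plain model(input_ids) forward applies each shared block once with repeat_idx=0; cycling through the K adapters within a block requires a custom forward loop.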