# relaxed-recursive-transformer / RecursiveGPT2Model.py
# Uploaded by brianling16 with huggingface_hub (commit fe7385b).
from transformers import GPT2LMHeadModel, GPT2Config
import torch
import copy
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, List
class MultiheadSelfAttention(nn.Module):
    """Scaled dot-product self-attention split across ``n_heads`` heads.

    Args:
        d_model: input/output width; must be divisible by ``n_heads``.
        n_heads: number of attention heads (each of width ``d_model // n_heads``).

    No mask is applied unless ``attn_mask`` is passed to ``forward``.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads
        # Independent Q/K/V/output projections; creation order fixed so seeded init is stable.
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: tensor of shape ``(batch, seq_len, d_model)``.
            attn_mask: optional additive mask broadcastable to ``(batch, heads, seq_len, seq_len)``;
                use ``-inf`` at positions that must not be attended to.

        Returns:
            Tensor of shape ``(batch, seq_len, d_model)``.
        """
        batch, seq_len, width = x.shape
        heads, head_dim = self.n_heads, self.d_head

        def to_heads(t):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return t.view(batch, seq_len, heads, head_dim).transpose(1, 2)

        query, key, value = (to_heads(proj(x)) for proj in (self.q_proj, self.k_proj, self.v_proj))

        scores = query @ key.transpose(-2, -1) / math.sqrt(head_dim)  # (batch, heads, seq, seq)
        if attn_mask is not None:
            scores = scores + attn_mask
        weights = scores.softmax(dim=-1)

        # Merge heads back into a single (batch, seq, width) tensor before the output projection.
        mixed = (weights @ value).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(mixed)
class MLP(nn.Module):
    """Position-wise feed-forward network computing ``fc2(ReLU(fc1(x)))``.

    Args:
        d_model: input/output width.
        d_ff: hidden (expansion) width.
    """

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)  # expand to the hidden width
        self.fc2 = nn.Linear(d_ff, d_model)  # project back to the model width
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor):
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)
class TransformerLayer(nn.Module):
    """Pre-LayerNorm residual layer: self-attention sub-block followed by an MLP sub-block.

    Args:
        d_model: model width.
        n_heads: number of attention heads.
        d_ff: MLP hidden width.
        dropout: dropout probability applied to each sub-block's output before the residual add.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        # One dropout module is reused for both residual branches.
        self.dropout = nn.Dropout(dropout)
        self.self_attn = MultiheadSelfAttention(d_model, n_heads)
        self.mlp = MLP(d_model, d_ff)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        # Normalise -> sub-block -> dropout -> residual add, first attention then MLP.
        x = x + self.dropout(self.self_attn(self.ln1(x), attn_mask))
        return x + self.dropout(self.mlp(self.ln2(x)))
class Transformer(nn.Module):
    """Token language model: embeddings -> stacked ``TransformerLayer``s -> final LN -> LM head.

    The LM head shares its weight with the token embedding (weight tying). No causal mask is
    built here; pass one through ``attn_mask`` if autoregressive attention is wanted.

    Args:
        n_layers: number of transformer layers.
        d_model: model width.
        n_heads: attention heads per layer.
        d_ff: MLP hidden width.
        vocab_size: size of the token vocabulary (and of the output logits).
        dropout: dropout probability used inside each layer.
        max_seq_len: number of learned positions; longer inputs are rejected. Defaults to the
            previously hard-coded 2048.
    """

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int, vocab_size: int,
                 dropout: float = 0.1, max_seq_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)  # learned absolute positions
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff, dropout) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)  # final norm before the LM head
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

    def forward(self, idx: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Return logits of shape ``(batch, seq_len, vocab_size)`` for token ids ``idx``.

        Args:
            idx: integer token ids of shape ``(batch, seq_len)``.
            attn_mask: optional additive attention mask forwarded unchanged to every layer.

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_seq_len`` (there is no position embedding
                for such positions).
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
            )
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)  # (1, T), broadcast over batch
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for layer in self.layers:
            x = layer(x, attn_mask)
        x = self.ln_f(x)
        return self.lm_head(x)
# ---- LoRA ----
class LoRAAdapter(nn.Module):
    """Low-rank additive update ``delta = (B @ A) * (alpha / rank)`` of shape ``(out, in)``.

    Args:
        in_features: input width of the adapted projection.
        out_features: output width of the adapted projection.
        rank: adapter rank; ``<= 0`` disables the adapter (``A``/``B`` are ``None``).
        alpha: scaling numerator; the update is multiplied by ``alpha / rank``.
        weight: optional ``(out_features, in_features)`` matrix. When given, ``A``/``B`` are set
            from its truncated SVD so that ``delta()`` reproduces the best rank-``rank``
            approximation of ``weight`` (the ``alpha / rank`` scaling is compensated in ``B``).
            Otherwise ``A`` is drawn from N(0, (1/rank)^2) and ``B`` is zero, so ``delta()``
            starts at zero.
    """

    def __init__(self, in_features: int, out_features: int, rank: int, alpha: float = 1.0,
                 weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.rank = rank
        self.alpha = alpha
        if rank > 0:
            self.A = nn.Parameter(torch.zeros((rank, in_features)))   # (rank, in)
            self.B = nn.Parameter(torch.zeros((out_features, rank)))  # (out, rank)
            if weight is not None:
                self._init_from_weight(weight, in_features, out_features)
            else:
                nn.init.normal_(self.A, std=1/rank)
                nn.init.zeros_(self.B)
        else:
            self.register_parameter('A', None)
            self.register_parameter('B', None)

    @torch.no_grad()
    def _init_from_weight(self, weight: torch.Tensor, in_features: int, out_features: int):
        """Fill ``A``/``B`` in place from the truncated SVD of ``weight`` (shapes are preserved)."""
        if tuple(weight.shape) != (out_features, in_features):
            raise ValueError(
                f"weight must have shape ({out_features}, {in_features}), got {tuple(weight.shape)}"
            )
        # float32 copy: detaches from any autograd graph and avoids SVD on low-precision dtypes.
        U, S, Vh = torch.linalg.svd(weight.detach().to(torch.float32), full_matrices=False)
        # The SVD yields at most min(out, in) components; extra rank slots stay zero.
        k = min(self.rank, S.shape[0])
        b_init = U[:, :k] * S[:k]  # == U_k @ diag(S_k)
        scaling = self.alpha / self.rank
        if scaling != 0:
            # delta() multiplies by alpha/rank, so pre-divide to reproduce U_k S_k Vh_k exactly.
            b_init = b_init / scaling
        self.A.zero_()
        self.B.zero_()
        self.A[:k].copy_(Vh[:k])
        self.B[:, :k].copy_(b_init)

    def delta(self) -> Optional[torch.Tensor]:
        """Return the scaled ``(out, in)`` update, or ``None`` when the adapter is disabled."""
        if self.rank == 0 or self.A is None or self.B is None:
            return None
        return (self.B @ self.A) * (self.alpha / self.rank)  # (out, in)

    def lora_parameters(self):
        """Yield the trainable adapter parameters (``A`` then ``B``) that exist."""
        if self.A is not None:
            yield self.A
        if self.B is not None:
            yield self.B
class LoRALinear(nn.Module):
    """``nn.Linear`` wrapper holding one LoRA adapter per recursion step.

    The base layer is stored as-is (this class does not freeze it). For ``repeat_idx = i``
    the output is ``linear(x) + x @ loras[i].delta().T``; with ``rank == 0`` only the base
    output is returned.

    Args:
        linear: base projection.
        rank: adapter rank; ``<= 0`` creates no adapters.
        alpha: LoRA scaling numerator passed to each adapter.
        num_repeats: number of independent adapters.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0, num_repeats: int = 1):
        super().__init__()
        self.linear = linear
        self.rank = rank
        self.num_repeats = num_repeats
        adapters = []
        if rank > 0:
            for _ in range(num_repeats):
                adapters.append(
                    LoRAAdapter(linear.in_features, linear.out_features, rank, alpha)
                )
        self.loras = nn.ModuleList(adapters)

    def forward(self, x, repeat_idx: int = 0):
        base_out = self.linear(x)
        if self.rank == 0:
            return base_out
        update = self.loras[repeat_idx].delta()  # (out, in): already the F.linear weight layout
        return base_out if update is None else base_out + F.linear(x, update)

    def lora_parameters(self):
        """Iterate over the parameters of every adapter, in repeat order."""
        return (param for adapter in self.loras for param in adapter.lora_parameters())
class LoRAConv1D(nn.Module):
    """GPT-2 style Conv1D with per-repeat LoRA adapters.

    ``conv1d.weight`` is read as ``(in_features, out_features)`` (the GPT-2 ``Conv1D`` layout)
    and ``conv1d(x)`` must map ``(..., in_features)`` to ``(..., out_features)``. The adapter
    update for ``repeat_idx`` is added on top of the base output. The base module is stored
    as-is (not frozen here).

    Args:
        conv1d: base projection.
        rank: adapter rank; ``<= 0`` creates no adapters.
        alpha: LoRA scaling numerator.
        num_repeats: number of independent adapters (one per recursion step).
        is_c_attn: keyword-only. ``True`` uses three adapters per repeat, one per Q/K/V slice of
            the fused output (``out_features`` must be divisible by 3); ``False`` uses a single
            adapter. ``None`` (default) keeps the original heuristic
            ``out_features % 3 == 0 and "c_attn" in str(conv1d)``.

    NOTE(review): the heuristic looks at the module's repr, which presumably does not contain
    the attribute name the module is stored under (HF's ``Conv1D`` repr looks like
    ``Conv1D(nf=..., nx=...)``), so by default the Q/K/V split is likely never selected. The
    default is kept so existing parameter layouts / checkpoints stay loadable; pass
    ``is_c_attn=True`` explicitly to opt in. TODO confirm.

    Raises:
        ValueError: if ``is_c_attn=True`` but ``out_features`` is not divisible by 3.
    """

    def __init__(self, conv1d, rank: int, alpha: float = 1.0, num_repeats: int = 1, *,
                 is_c_attn: Optional[bool] = None):
        super().__init__()
        self.conv1d = conv1d
        self.rank = rank
        self.num_repeats = num_repeats
        in_features, out_features = conv1d.weight.shape  # GPT-2 Conv1D: [in, out]
        if is_c_attn is None:
            is_c_attn = (out_features % 3 == 0) and ("c_attn" in str(conv1d))
        elif is_c_attn and out_features % 3 != 0:
            raise ValueError(
                f"is_c_attn=True requires out_features divisible by 3, got {out_features}"
            )
        self.is_c_attn = bool(is_c_attn)
        # Width of each Q/K/V slice for the fused projection; full width otherwise.
        self.split_size = out_features // 3 if self.is_c_attn else out_features
        if rank > 0 and self.is_c_attn:
            # loras[repeat][i] adapts slice i (Q, K, V) of the fused output.
            self.loras = nn.ModuleList([
                nn.ModuleList([
                    LoRAAdapter(in_features, self.split_size, rank, alpha)
                    for _ in range(3)
                ]) for _ in range(num_repeats)
            ])
        elif rank > 0:
            self.loras = nn.ModuleList([
                LoRAAdapter(in_features, out_features, rank, alpha)
                for _ in range(num_repeats)
            ])
        else:
            self.loras = nn.ModuleList([])

    def forward(self, x, repeat_idx: int = 0):
        """Return ``conv1d(x)`` plus the LoRA update selected by ``repeat_idx``.

        x: [..., in_features] (presumably [batch, seq_len, in_features])
        returns: [..., out_features]
        """
        out = self.conv1d(x)
        if self.rank == 0 or len(self.loras) == 0:
            return out
        if self.is_c_attn:
            # Each adapter delta is (split_size, in); x @ delta.T gives that slice's update.
            updates = [
                torch.matmul(x, delta.T)
                for delta in (lora.delta() for lora in self.loras[repeat_idx])
                if delta is not None
            ]
            return (out + torch.cat(updates, dim=-1)) if updates else out
        delta = self.loras[repeat_idx].delta()  # (out, in)
        if delta is None:
            return out
        return out + torch.matmul(x, delta.T)

    def lora_parameters(self):
        """Yield every adapter parameter, in repeat order (and Q, K, V order when split)."""
        if self.is_c_attn:
            for lora_group in self.loras:
                for lora in lora_group:
                    yield from lora.lora_parameters()
        else:
            for lora in self.loras:
                yield from lora.lora_parameters()
class SharedAttention(nn.Module):
    """Multi-head self-attention whose Q/K/V/output projections carry per-repeat LoRA adapters.

    Head layout and base projections come from ``base_attn``, which must expose ``n_heads``,
    ``d_head``, ``d_model``, ``q_proj``, ``k_proj``, ``v_proj`` and ``out_proj``. The base
    ``nn.Linear`` modules are wrapped, not copied.
    """

    def __init__(self, base_attn, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.n_heads = base_attn.n_heads
        self.d_head = base_attn.d_head
        self.d_model = base_attn.d_model
        # Wrap in q, k, v, out order (keeps adapter creation / seeded init order stable).
        for proj_name in ("q_proj", "k_proj", "v_proj", "out_proj"):
            setattr(
                self,
                proj_name,
                LoRALinear(getattr(base_attn, proj_name), lora_rank, lora_alpha, num_repeats),
            )

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` of shape ``(batch, seq_len, d_model)`` using adapters ``repeat_idx``.

        ``attn_mask`` is added to the attention scores (use ``-inf`` to block positions).
        """
        batch, seq_len, width = x.shape

        def project_heads(proj):
            # (batch, seq, width) -> (batch, heads, seq, head_dim)
            return proj(x, repeat_idx).view(batch, seq_len, self.n_heads, self.d_head).transpose(1, 2)

        q = project_heads(self.q_proj)
        k = project_heads(self.k_proj)
        v = project_heads(self.v_proj)

        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_head)
        if attn_mask is not None:
            scores = scores + attn_mask
        probs = scores.softmax(dim=-1)

        merged = (probs @ v).transpose(1, 2).contiguous().view(batch, seq_len, width)
        return self.out_proj(merged, repeat_idx)
class SharedMLP(nn.Module):
    """Feed-forward block ``fc2(act(fc1(x)))`` with per-repeat LoRA adapters on fc1 and fc2.

    ``base_mlp`` must expose ``fc1`` and ``fc2`` linear layers plus an activation module named
    either ``act`` or ``activation`` (this file's ``MLP`` uses ``activation``). The base modules
    are wrapped / reused, not copied.

    Raises:
        AttributeError: if ``base_mlp`` has neither ``act`` nor ``activation``.
    """

    def __init__(self, base_mlp, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.fc1 = LoRALinear(base_mlp.fc1, lora_rank, lora_alpha, num_repeats)
        self.fc2 = LoRALinear(base_mlp.fc2, lora_rank, lora_alpha, num_repeats)
        # `act` keeps the previous lookup; `activation` is what MLP in this file defines.
        if hasattr(base_mlp, "act"):
            self.act = base_mlp.act
        elif hasattr(base_mlp, "activation"):
            self.act = base_mlp.activation
        else:
            raise AttributeError(
                f"{type(base_mlp).__name__} has neither 'act' nor 'activation'"
            )

    def forward(self, x, repeat_idx: int):
        return self.fc2(self.act(self.fc1(x, repeat_idx)), repeat_idx)
class SharedTransformerLayer(nn.Module):
    """Pre-LayerNorm residual layer whose attention/MLP projections have per-repeat LoRA adapters.

    Layer norms and dropout modules are reused from ``base_layer`` (shared, not copied).
    ``base_layer`` may expose ``attn`` / ``dropout1`` / ``dropout2`` or, like this file's
    ``TransformerLayer``, ``self_attn`` and a single ``dropout`` used for both residual
    branches; the former names take precedence when present.

    Raises:
        AttributeError: if ``base_layer`` lacks both spellings of a required attribute.
    """

    def __init__(self, base_layer, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.ln1 = base_layer.ln1
        self.ln2 = base_layer.ln2
        self.dropout1 = self._first_attr(base_layer, "dropout1", "dropout")
        self.dropout2 = self._first_attr(base_layer, "dropout2", "dropout")
        base_attn = self._first_attr(base_layer, "attn", "self_attn")
        self.attn = SharedAttention(base_attn, num_repeats, lora_rank, lora_alpha)
        self.mlp = SharedMLP(base_layer.mlp, num_repeats, lora_rank, lora_alpha)

    @staticmethod
    def _first_attr(obj, *names):
        """Return the first attribute of *obj* that exists among *names*."""
        for name in names:
            if hasattr(obj, name):
                return getattr(obj, name)
        raise AttributeError(f"{type(obj).__name__} has none of the attributes {names}")

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        y = self.attn(self.ln1(x), repeat_idx, attn_mask)
        x = x + self.dropout1(y)
        y = self.mlp(self.ln2(x), repeat_idx)
        x = x + self.dropout2(y)
        return x
# ---- Conversion Utilities ----
def average_weights(layers, attr):
    """Return the element-wise mean of ``getattr(layer, attr).weight`` over *layers*.

    Args:
        layers: iterable of modules, each holding a sub-module named *attr* that has a
            ``weight`` tensor (all of the same shape).
        attr: name of that sub-module, e.g. ``"c_attn"``.

    Returns:
        A new tensor of the common weight shape. It is built from ``.data``, so it carries
        no autograd history.
    """
    return torch.stack(
        [getattr(module, attr).weight.data for module in layers]
    ).mean(dim=0)
@torch.no_grad()
def _svd_fill_adapter(adapter, residual, rank):
    """Write the rank-``rank`` truncated SVD of ``residual`` (out, in) into ``adapter.A``/``adapter.B``."""
    U, S, Vh = torch.linalg.svd(residual.to(torch.float32), full_matrices=False)
    # Never exceed the adapter's capacity or the number of available singular values.
    k = min(rank, adapter.A.shape[0], S.shape[0])
    b_vals = U[:, :k] * S[:k]  # == U_k @ diag(S_k)
    scaling = adapter.alpha / adapter.rank
    if scaling != 0:
        # The adapter multiplies B @ A by alpha/rank; pre-divide so the update equals the residual.
        b_vals = b_vals / scaling
    adapter.A.zero_()
    adapter.B.zero_()
    adapter.A[:k].copy_(Vh[:k])
    adapter.B[:, :k].copy_(b_vals)


def _init_wrapper_loras_with_svd(lora_layer, original_weights, repeat_indices, rank):
    """SVD-initialise adapters of a LoRALinear / LoRAConv1D-style wrapper (has ``loras``)."""
    if len(lora_layer.loras) == 0:
        return  # rank <= 0: nothing to initialise
    if hasattr(lora_layer, "conv1d"):
        # GPT-2 Conv1D stores its weight as (in, out); adapters are expressed in (out, in).
        shared = lora_layer.conv1d.weight.detach()
        transpose = True
    else:
        shared = lora_layer.linear.weight.detach()
        transpose = False
    for idx, orig_weight in zip(repeat_indices, original_weights):
        residual = orig_weight.detach().to(shared.device) - shared
        if transpose:
            residual = residual.T
        group = lora_layer.loras[idx]
        # Fused Q/K/V projections keep one adapter per output slice.
        adapters = list(group) if getattr(lora_layer, "is_c_attn", False) else [group]
        chunks = residual.split(residual.shape[0] // len(adapters), dim=0)
        for adapter, chunk in zip(adapters, chunks):
            _svd_fill_adapter(adapter, chunk, rank)


def initialize_lora_with_svd(lora_layer, original_weights, repeat_indices, rank):
    """Initialise per-repeat LoRA factors from the truncated SVD of each layer's residual.

    For each ``(idx, W_orig)`` pair, the residual ``W_orig - W_shared`` is factorised with a
    rank-``rank`` truncated SVD and written into the adapter(s) for repeat ``idx``.

    Two layouts of ``lora_layer`` are supported:

    * wrappers with a ``loras`` attribute (``LoRALinear`` / ``LoRAConv1D`` in this file): the
      shared weight is ``lora_layer.linear.weight`` (out, in) or ``lora_layer.conv1d.weight``
      (in, out, transposed internally); ``original_weights`` use the same layout as that
      shared weight. The adapters' ``alpha / rank`` scaling is compensated so the adapter
      update reproduces the truncated residual. For Q/K/V-split wrappers (``is_c_attn``) the
      residual is split into equal output slices, one per adapter.
    * the previous layout with ``base_layer``, ``lora_A[idx].weight`` and ``lora_B[idx].weight``,
      handled exactly as before (no scaling compensation).

    Args:
        lora_layer: the LoRA-wrapped layer to initialise in place.
        original_weights: list of original (pre-sharing) weights, one per repeat index.
        repeat_indices: which repeat index each entry of ``original_weights`` belongs to.
        rank: number of singular components to keep.
    """
    if hasattr(lora_layer, "loras"):
        _init_wrapper_loras_with_svd(lora_layer, original_weights, repeat_indices, rank)
        return
    # Fallback: original behaviour for the base_layer / lora_A / lora_B layout.
    shared_weight = lora_layer.base_layer.weight.data.clone()
    for idx, orig_weight in zip(repeat_indices, original_weights):
        residual = orig_weight - shared_weight
        U, S, Vh = torch.linalg.svd(residual, full_matrices=False)
        # Truncate to rank
        U = U[:, :rank]
        S = S[:rank]
        Vh = Vh[:rank, :]
        lora_layer.lora_A[idx].weight.data = Vh  # A = V_r^T
        lora_layer.lora_B[idx].weight.data = U @ torch.diag(S)  # B = U_r Sigma_r
def convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0):
    """Collapse each run of ``K`` consecutive GPT-2 blocks into one shared block with LoRA.

    For group ``b`` (layers ``b*K .. b*K + K - 1`` of ``model.transformer.h``), a deep copy
    of the group's first layer becomes the shared layer. For each of its ``attn.c_attn``,
    ``attn.c_proj``, ``mlp.c_fc`` and ``mlp.c_proj`` projections that exists:

    * the weight and bias are set to the element-wise mean over the group; and
    * the projection is wrapped in ``LoRAConv1D`` with ``K`` adapters (one per repeat).

    Other sub-modules (e.g. layer norms) are taken from the group's first layer unchanged.
    ``model.config`` is not modified.

    Args:
        model: a GPT-2 style model exposing ``model.transformer.h``; modified in place.
        K: number of consecutive layers sharing one block (must be >= 1 and divide the
            layer count).
        rank: LoRA rank for every wrapped projection.
        lora_alpha: LoRA scaling numerator.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: if ``K < 1`` or the number of layers is not a multiple of ``K``. Previously
            trailing layers were silently dropped in the latter case.

    NOTE(review): the wrapped projections are invoked from the GPT-2 block's own forward,
    which presumably calls them with a single argument, so ``repeat_idx`` stays 0. Each shared
    block would then run once: the model ends up with ``n_layers // K`` layers instead of
    looping each block ``K`` times, and adapters ``1..K-1`` go unused. TODO confirm against
    the installed ``transformers`` version and add a looping wrapper if recursion is intended.
    """
    if K < 1:
        raise ValueError(f"K must be a positive integer, got {K}")
    layers = model.transformer.h
    n_layers = len(layers)
    if n_layers % K != 0:
        raise ValueError(f"number of layers ({n_layers}) must be divisible by K ({K})")

    # (parent sub-module, projection) pairs that are averaged and LoRA-wrapped, in wrap order.
    projections = (("attn", "c_attn"), ("attn", "c_proj"), ("mlp", "c_fc"), ("mlp", "c_proj"))

    new_blocks = []
    for b in range(n_layers // K):
        block_layers = layers[b * K:(b + 1) * K]
        base_layer = copy.deepcopy(block_layers[0])
        for parent_name, proj_name in projections:
            base_parent = getattr(base_layer, parent_name)
            if not hasattr(base_parent, proj_name):
                continue
            base_proj = getattr(base_parent, proj_name)
            group_parents = [getattr(layer, parent_name) for layer in block_layers]
            with torch.no_grad():
                base_proj.weight.data = average_weights(group_parents, proj_name)
                # Average biases too, so weight and bias come from the same initialisation.
                if getattr(base_proj, "bias", None) is not None:
                    base_proj.bias.data = torch.stack(
                        [getattr(parent, proj_name).bias.data for parent in group_parents]
                    ).mean(dim=0)
            setattr(base_parent, proj_name, LoRAConv1D(base_proj, rank, lora_alpha, K))
        new_blocks.append(base_layer)
    model.transformer.h = nn.ModuleList(new_blocks)
    return model
class RecursiveGPT2Config(GPT2Config):
    """``GPT2Config`` extended with the recursion / LoRA hyper-parameters.

    Args:
        K: number of consecutive layers collapsed into each shared block (also the number of
            LoRA adapters per projection).
        rank: LoRA rank of every adapter.
        lora_alpha: LoRA scaling numerator (``convert_to_recursive``'s ``lora_alpha``); the
            default 1.0 matches that function's default.
        **kwargs: forwarded unchanged to ``GPT2Config``.
    """
    model_type = "recursive_gpt2"

    def __init__(self, K=2, rank=8, lora_alpha=1.0, **kwargs):
        super().__init__(**kwargs)
        self.K = K
        self.rank = rank
        self.lora_alpha = lora_alpha
class RecursiveGPT2LMHeadModel(GPT2LMHeadModel):
    """GPT-2 LM head model whose blocks are collapsed by ``convert_to_recursive`` at construction.

    The conversion uses ``config.K`` and ``config.rank``, plus ``config.lora_alpha`` when the
    config defines it (1.0 otherwise).

    NOTE(review): because conversion happens in ``__init__``, projection parameters are
    renamed (e.g. ``...attn.c_attn.conv1d.weight``) and the block count is halved for K=2.
    A plain GPT-2 checkpoint therefore presumably does not map onto this model through
    ``from_pretrained``. To start from pretrained GPT-2 weights, load ``GPT2LMHeadModel``
    and call ``convert_to_recursive`` on it. TODO confirm.
    """
    config_class = RecursiveGPT2Config

    def __init__(self, config):
        # Build the regular GPT-2 model first, then collapse its blocks in place.
        super().__init__(config)
        convert_to_recursive(
            self,
            K=config.K,
            rank=config.rank,
            lora_alpha=getattr(config, "lora_alpha", 1.0),
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Delegate to ``GPT2LMHeadModel.from_pretrained`` unchanged.

        No conversion is done here. The recursive structure comes from ``__init__``, which
        the base implementation presumably runs (via ``cls(config)``) before loading weights.
        """
        model = super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        return model