"""Recursive (layer-sharing) transformers with per-repeat LoRA adapters.

The module provides:

* a small pure-PyTorch pre-LayerNorm transformer (``Transformer``);
* LoRA wrappers for ``nn.Linear`` (``LoRALinear``) and GPT-2 style ``Conv1D``
  (``LoRAConv1D``) that hold one adapter per repeat of a shared layer;
* shared-layer modules for the pure-PyTorch transformer
  (``SharedTransformerLayer`` and its parts);
* ``convert_to_recursive``, which turns a Hugging Face GPT-2 model into a
  recursive one, plus ``RecursiveGPT2Config`` / ``RecursiveGPT2LMHeadModel``
  (only defined when ``transformers`` is installed).
"""

import copy
import math
from typing import List, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from transformers import GPT2Config, GPT2LMHeadModel
except ImportError:
    # Everything except the two Recursive GPT-2 classes at the bottom of this
    # module works without transformers installed.
    GPT2Config = None
    GPT2LMHeadModel = None


class MultiheadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    No mask is built internally; pass ``attn_mask`` (e.g. a causal mask)
    explicitly if one is needed.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_head = d_model // n_heads

        # Standard projections
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Attend over ``x`` of shape (B, T, d_model); returns (B, T, d_model).

        Args:
            x: input sequence.
            attn_mask: additive mask broadcastable to (B, n_heads, T, T); use
                ``-inf`` at positions that must not be attended to.
        """
        B, T, C = x.shape
        H = self.n_heads
        D = self.d_head

        q = self.q_proj(x).view(B, T, H, D).transpose(1, 2)  # (B, H, T, D)
        k = self.k_proj(x).view(B, T, H, D).transpose(1, 2)
        v = self.v_proj(x).view(B, T, H, D).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)  # (B, H, T, T)
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v  # (B, H, T, D)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y)


class MLP(nn.Module):
    """Two-layer feed-forward block: ``fc2(relu(fc1(x)))``."""

    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)
        self.fc2 = nn.Linear(d_ff, d_model)
        self.activation = nn.ReLU()

    def forward(self, x: torch.Tensor):
        return self.fc2(self.activation(self.fc1(x)))


class TransformerLayer(nn.Module):
    """Pre-LayerNorm transformer layer (attention and MLP residual branches).

    The same ``dropout`` module is applied to both residual branches.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)
        self.self_attn = MultiheadSelfAttention(d_model, n_heads)
        self.mlp = MLP(d_model, d_ff)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        y = self.self_attn(self.ln1(x), attn_mask)
        x = x + self.dropout(y)
        y = self.mlp(self.ln2(x))
        return x + self.dropout(y)


class Transformer(nn.Module):
    """Token + learned positional embeddings, a stack of TransformerLayer,
    a final LayerNorm and an LM head tied to the token embedding.

    Args:
        n_layers, d_model, n_heads, d_ff, vocab_size, dropout: model sizes.
        max_seq_len: number of learned positions (longest accepted input).
    """

    def __init__(self, n_layers: int, d_model: int, n_heads: int, d_ff: int,
                 vocab_size: int, dropout: float = 0.1, max_seq_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.d_ff = d_ff
        self.max_seq_len = max_seq_len

        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, n_heads, d_ff, dropout)
            for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        self.lm_head.weight = self.tok_emb.weight  # weight tying

    def forward(self, idx: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Map token ids (B, T) to logits (B, T, vocab_size).

        Raises:
            ValueError: if T exceeds ``max_seq_len``.
        """
        B, T = idx.shape
        if T > self.max_seq_len:
            raise ValueError(
                f"sequence length {T} exceeds max_seq_len={self.max_seq_len}"
            )
        pos = torch.arange(0, T, device=idx.device).unsqueeze(0)
        x = self.tok_emb(idx) + self.pos_emb(pos)
        for layer in self.layers:
            x = layer(x, attn_mask)
        x = self.ln_f(x)
        return self.lm_head(x)


# ---- LoRA ----

class LoRAAdapter(nn.Module):
    """Low-rank update ``delta = (B @ A) * (alpha / rank)`` of shape (out, in).

    Args:
        in_features, out_features: shape of the weight being adapted.
        rank: adapter rank; ``rank <= 0`` creates no parameters.
        alpha: scaling numerator.
        weight: optional (out_features, in_features) matrix; if given, the
            adapter is initialized so that ``delta()`` equals its rank-``rank``
            truncated SVD. Otherwise A ~ N(0, 1/rank) and B = 0 (zero delta).
    """

    def __init__(self, in_features: int, out_features: int, rank: int,
                 alpha: float = 1.0, weight: Optional[torch.Tensor] = None):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.rank = rank
        self.alpha = alpha

        if rank > 0:
            self.A = nn.Parameter(torch.zeros((rank, in_features)))
            self.B = nn.Parameter(torch.zeros((out_features, rank)))
            if weight is not None:
                self.init_from_weight(weight)
            else:
                nn.init.normal_(self.A, std=1 / rank)
                nn.init.zeros_(self.B)
        else:
            self.register_parameter('A', None)
            self.register_parameter('B', None)

    @property
    def scaling(self) -> float:
        """Factor applied to ``B @ A``; 0.0 for a disabled adapter."""
        return self.alpha / self.rank if self.rank > 0 else 0.0

    @torch.no_grad()
    def init_from_weight(self, weight: torch.Tensor, rank: Optional[int] = None) -> None:
        """Set A and B so that ``delta()`` equals the truncated SVD of ``weight``.

        Args:
            weight: (out_features, in_features) matrix to reproduce.
            rank: optional truncation rank, capped at the adapter's rank.

        Raises:
            ValueError: on a shape mismatch, or if ``alpha`` is zero (the
                update could then never be represented).
        """
        if self.rank <= 0:
            return
        expected = (self.out_features, self.in_features)
        if tuple(weight.shape) != expected:
            raise ValueError(
                f"expected weight of shape {expected}, got {tuple(weight.shape)}"
            )
        if self.scaling == 0:
            raise ValueError("alpha must be non-zero to initialize from a weight")

        k = self.rank if rank is None else max(0, min(rank, self.rank))
        # SVD in float32: low-precision SVD is not supported on all devices.
        U, S, Vh = torch.linalg.svd(weight.detach().float(), full_matrices=False)
        k = min(k, S.shape[0])

        # Rank directions beyond k get random A and zero B: they contribute
        # nothing to delta() but still receive gradients during training.
        nn.init.normal_(self.A, std=1 / self.rank)
        self.B.zero_()
        self.A[:k].copy_(Vh[:k])
        # Divide by the scaling so (B @ A) * scaling == U_k diag(S_k) Vh_k.
        self.B[:, :k].copy_(U[:, :k] * (S[:k] / self.scaling))

    def delta(self) -> Optional[torch.Tensor]:
        """Return the dense (out, in) update, or None for a disabled adapter."""
        if self.rank <= 0 or self.A is None or self.B is None:
            return None
        return (self.B @ self.A) * self.scaling  # (out, in)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x @ delta().T`` without materializing the (out, in) matrix."""
        if self.rank <= 0 or self.A is None or self.B is None:
            return x.new_zeros(*x.shape[:-1], self.out_features)
        return F.linear(F.linear(x, self.A), self.B) * self.scaling

    def lora_parameters(self):
        """Yield the adapter's trainable parameters (A, then B) if present."""
        if self.A is not None:
            yield self.A
        if self.B is not None:
            yield self.B


class LoRALinear(nn.Module):
    """``nn.Linear`` plus one LoRA adapter per repeat of a shared layer.

    The base linear is shared by all repeats and is NOT frozen here; freeze it
    explicitly if only the adapters should train.
    """

    def __init__(self, linear: nn.Linear, rank: int, alpha: float = 1.0,
                 num_repeats: int = 1):
        super().__init__()
        self.linear = linear
        self.rank = rank
        self.num_repeats = num_repeats
        # Repeat used when forward() is called without an explicit repeat_idx.
        self.active_repeat = 0
        self.loras = nn.ModuleList(
            [
                LoRAAdapter(linear.in_features, linear.out_features, rank, alpha)
                for _ in range(num_repeats)
            ]
            if rank > 0
            else []
        )

    def forward(self, x, repeat_idx: Optional[int] = None):
        """Return ``linear(x)`` plus the low-rank update of the selected repeat.

        Args:
            x: (..., in_features) input.
            repeat_idx: adapter to use; ``None`` means ``self.active_repeat``.
        """
        out = self.linear(x)  # [batch, ..., out_features]
        if self.rank <= 0 or len(self.loras) == 0:
            return out
        idx = self.active_repeat if repeat_idx is None else repeat_idx
        return out + self.loras[idx](x)

    def lora_parameters(self):
        for lora in self.loras:
            yield from lora.lora_parameters()


class LoRAConv1D(nn.Module):
    """GPT-2 style Conv1D with one LoRA adapter (or Q/K/V triple) per repeat.

    The wrapped module's weight is laid out as (in_features, out_features).

    Args:
        conv1d: base Conv1D-like module exposing ``weight`` of shape (in, out).
        rank: LoRA rank; ``rank <= 0`` disables the adapters.
        alpha: LoRA scaling numerator.
        num_repeats: number of adapters (one per repeat of the shared layer).
        split_qkv: if True, use three independent adapters for the three equal
            output slices (Q, K, V of a fused ``c_attn`` projection).

    Raises:
        ValueError: if ``split_qkv`` is set and out_features is not divisible by 3.
    """

    def __init__(self, conv1d, rank: int, alpha: float = 1.0, num_repeats: int = 1,
                 split_qkv: bool = False):
        super().__init__()
        self.conv1d = conv1d  # base GPT-2 Conv1D
        self.rank = rank
        self.num_repeats = num_repeats
        # Repeat used when forward() is called without an explicit repeat_idx
        # (e.g. from inside an unmodified Hugging Face block).
        self.active_repeat = 0

        in_features, out_features = conv1d.weight.shape  # GPT-2 Conv1D: [in, out]
        if split_qkv and out_features % 3 != 0:
            raise ValueError(
                f"split_qkv requires out_features divisible by 3, got {out_features}"
            )
        self.is_c_attn = bool(split_qkv)
        self.split_size = out_features // 3 if self.is_c_attn else out_features

        if rank <= 0:
            self.loras = nn.ModuleList([])
        elif self.is_c_attn:
            # Separate adapters for the Q, K and V slices of each repeat.
            self.loras = nn.ModuleList([
                nn.ModuleList([
                    LoRAAdapter(in_features, self.split_size, rank, alpha)
                    for _ in range(3)
                ])
                for _ in range(num_repeats)
            ])
        else:
            self.loras = nn.ModuleList([
                LoRAAdapter(in_features, out_features, rank, alpha)
                for _ in range(num_repeats)
            ])

    def forward(self, x, repeat_idx: Optional[int] = None):
        """Map (..., in_features) to (..., out_features) for the selected repeat.

        ``repeat_idx=None`` uses ``self.active_repeat``.
        """
        out = self.conv1d(x)
        if self.rank <= 0 or len(self.loras) == 0:
            return out
        idx = self.active_repeat if repeat_idx is None else repeat_idx
        if self.is_c_attn:
            updates = [adapter(x) for adapter in self.loras[idx]]
            return out + torch.cat(updates, dim=-1)
        return out + self.loras[idx](x)

    def lora_parameters(self):
        if self.is_c_attn:
            for lora_group in self.loras:
                for lora in lora_group:
                    yield from lora.lora_parameters()
        else:
            for lora in self.loras:
                yield from lora.lora_parameters()


class SharedAttention(nn.Module):
    """MultiheadSelfAttention whose projections carry per-repeat LoRA adapters.

    The base projections are reused (not copied) from ``base_attn``.
    """

    def __init__(self, base_attn, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.n_heads = base_attn.n_heads
        self.d_head = base_attn.d_head
        self.d_model = base_attn.d_model

        self.q_proj = LoRALinear(base_attn.q_proj, lora_rank, lora_alpha, num_repeats)
        self.k_proj = LoRALinear(base_attn.k_proj, lora_rank, lora_alpha, num_repeats)
        self.v_proj = LoRALinear(base_attn.v_proj, lora_rank, lora_alpha, num_repeats)
        self.out_proj = LoRALinear(base_attn.out_proj, lora_rank, lora_alpha, num_repeats)

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        """Same computation as MultiheadSelfAttention, using repeat ``repeat_idx``."""
        B, T, C = x.shape
        H, D = self.n_heads, self.d_head

        q = self.q_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
        k = self.k_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)
        v = self.v_proj(x, repeat_idx).view(B, T, H, D).transpose(1, 2)

        att = (q @ k.transpose(-2, -1)) / math.sqrt(D)
        if attn_mask is not None:
            att = att + attn_mask
        att = F.softmax(att, dim=-1)
        y = att @ v
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        return self.out_proj(y, repeat_idx)


class SharedMLP(nn.Module):
    """MLP whose two linears carry per-repeat LoRA adapters."""

    def __init__(self, base_mlp, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.fc1 = LoRALinear(base_mlp.fc1, lora_rank, lora_alpha, num_repeats)
        self.fc2 = LoRALinear(base_mlp.fc2, lora_rank, lora_alpha, num_repeats)
        # MLP in this module names its activation ``activation``; ``act`` is
        # accepted for other MLP variants.
        self.act = base_mlp.activation if hasattr(base_mlp, "activation") else base_mlp.act

    def forward(self, x, repeat_idx: int):
        return self.fc2(self.act(self.fc1(x, repeat_idx)), repeat_idx)


class SharedTransformerLayer(nn.Module):
    """TransformerLayer that is applied repeatedly, with a LoRA adapter per repeat.

    LayerNorms, dropout and base projections are shared with ``base_layer``.
    """

    def __init__(self, base_layer, num_repeats: int, lora_rank: int, lora_alpha: float):
        super().__init__()
        self.ln1 = base_layer.ln1
        self.ln2 = base_layer.ln2
        if hasattr(base_layer, "dropout1"):
            self.dropout1 = base_layer.dropout1
            self.dropout2 = base_layer.dropout2
        else:
            # TransformerLayer uses one dropout module for both residual branches.
            self.dropout1 = self.dropout2 = base_layer.dropout
        base_attn = base_layer.self_attn if hasattr(base_layer, "self_attn") else base_layer.attn
        self.attn = SharedAttention(base_attn, num_repeats, lora_rank, lora_alpha)
        self.mlp = SharedMLP(base_layer.mlp, num_repeats, lora_rank, lora_alpha)

    def forward(self, x, repeat_idx: int, attn_mask: Optional[torch.Tensor] = None):
        y = self.attn(self.ln1(x), repeat_idx, attn_mask)
        x = x + self.dropout1(y)
        y = self.mlp(self.ln2(x), repeat_idx)
        return x + self.dropout2(y)


# ---- Conversion Utilities ----

def average_weights(layers, attr):
    """Return the element-wise mean of ``getattr(layer, attr).weight`` over ``layers``."""
    weights = [getattr(layer, attr).weight.data for layer in layers]
    return torch.stack(weights, dim=0).mean(dim=0)


def _average_block_parameters(target: nn.Module, sources: List[nn.Module]) -> None:
    """Overwrite each parameter of ``target`` with the mean of the same-named
    parameter across ``sources`` (which must share ``target``'s structure)."""
    source_params = [dict(src.named_parameters()) for src in sources]
    with torch.no_grad():
        for name, param in target.named_parameters():
            stacked = torch.stack([params[name].detach() for params in source_params], dim=0)
            param.copy_(stacked.mean(dim=0))


def initialize_lora_with_svd(lora_layer, original_weights, repeat_indices, rank):
    """Initialize adapters so that base + delta reproduces each original weight.

    For every ``(idx, w)`` in ``zip(repeat_indices, original_weights)``, the
    adapter(s) of repeat ``idx`` are set to the truncated SVD of
    ``w - shared``, where ``shared`` is the wrapped base weight.

    Args:
        lora_layer: a LoRALinear or LoRAConv1D.
        original_weights: per-repeat weights in the base layer's own layout:
            (out, in) for nn.Linear, (in, out) for GPT-2 Conv1D.
        repeat_indices: repeat index that each weight belongs to.
        rank: truncation rank, capped at the adapters' own rank.

    Raises:
        TypeError: if ``lora_layer`` is not a LoRALinear or LoRAConv1D.
    """
    if isinstance(lora_layer, LoRAConv1D):
        shared = lora_layer.conv1d.weight.detach().T  # Conv1D stores (in, out)
        transpose = True
    elif isinstance(lora_layer, LoRALinear):
        shared = lora_layer.linear.weight.detach()
        transpose = False
    else:
        raise TypeError(
            f"expected LoRALinear or LoRAConv1D, got {type(lora_layer).__name__}"
        )
    if len(lora_layer.loras) == 0:
        return

    split_qkv = getattr(lora_layer, "is_c_attn", False)
    for idx, orig_weight in zip(repeat_indices, original_weights):
        orig = orig_weight.detach()
        if transpose:
            orig = orig.T
        residual = orig.to(shared) - shared  # (out, in)
        if split_qkv:
            chunks = residual.split(lora_layer.split_size, dim=0)
            for adapter, chunk in zip(lora_layer.loras[idx], chunks):
                adapter.init_from_weight(chunk, rank)
        else:
            lora_layer.loras[idx].init_from_weight(residual, rank)


class _RecursiveBlock(nn.Module):
    """One position of the unrolled depth: a shared block plus the repeat to use.

    The shared block is held in a tuple so nn.Module does NOT register it here;
    its parameters are registered once elsewhere (``transformer.shared_blocks``)
    instead of once per repeat. Unknown attributes (e.g. ``attn``, ``mlp``) are
    delegated to the shared block.
    """

    def __init__(self, block: nn.Module, repeat_idx: int, layer_idx: int):
        super().__init__()
        self._shared = (block,)
        self.repeat_idx = repeat_idx
        self.layer_idx = layer_idx

    @property
    def block(self) -> nn.Module:
        return self._shared[0]

    def forward(self, *args, **kwargs):
        """Select this repeat's adapters on the shared block, then run it.

        This mutates state on the shared block, so it is not safe for concurrent
        forward passes on the same model from several threads.
        NOTE(review): if gradient checkpointing is applied inside the shared
        block (rather than to this wrapper), recomputation may run after
        ``active_repeat`` has moved on — verify before enabling checkpointing.
        """
        block = self.block
        for module in block.modules():
            if isinstance(module, (LoRALinear, LoRAConv1D)):
                module.active_repeat = self.repeat_idx
        # NOTE(review): some transformers versions key the KV cache (and
        # inverse-layer-idx scaling) on ``attn.layer_idx``; re-point it at this
        # unrolled position so repeats don't share a slot — confirm against the
        # installed version.
        attn = getattr(block, "attn", None)
        if attn is not None and hasattr(attn, "layer_idx"):
            attn.layer_idx = self.layer_idx
        return block(*args, **kwargs)

    def __getattr__(self, name: str):
        try:
            return super().__getattr__(name)
        except AttributeError:
            shared = self.__dict__.get("_shared")
            if shared is None or name.startswith("__"):
                raise
            return getattr(shared[0], name)

    def extra_repr(self) -> str:
        return f"repeat_idx={self.repeat_idx}, layer_idx={self.layer_idx}"


# (parent attribute, projection attribute) pairs wrapped with LoRAConv1D.
_LORA_TARGETS = (("attn", "c_attn"), ("attn", "c_proj"), ("mlp", "c_fc"), ("mlp", "c_proj"))


def convert_to_recursive(model, K=2, rank=8, lora_alpha=1.0,
                         init_lora_from_residual: bool = True):
    """Convert a GPT-2 style model in place into a recursive (layer-sharing) one.

    Consecutive groups of ``K`` layers in ``model.transformer.h`` are replaced
    by one shared block whose parameters are the group's element-wise mean.
    Its ``attn.c_attn``, ``attn.c_proj``, ``mlp.c_fc`` and ``mlp.c_proj``
    projections (where present) are wrapped in LoRAConv1D with one adapter per
    repeat. ``model.transformer.h`` becomes ``n_layers`` parameter-free
    wrappers that run the shared block with the right repeat, so the depth is
    unchanged. The shared blocks are registered as
    ``model.transformer.shared_blocks``.

    Args:
        model: object with a ``transformer.h`` sequence of blocks.
        K: number of consecutive layers folded into each shared block.
        rank: LoRA rank per adapter.
        lora_alpha: LoRA scaling numerator.
        init_lora_from_residual: if True, initialize each repeat's adapter from
            the truncated SVD of (original weight - averaged weight).

    Returns:
        The same ``model``, modified in place.

    Raises:
        ValueError: if ``K < 1`` or the layer count is not divisible by ``K``.
    """
    if K < 1:
        raise ValueError(f"K must be >= 1, got {K}")
    layers = model.transformer.h
    n_layers = len(layers)
    if n_layers % K != 0:
        raise ValueError(f"number of layers ({n_layers}) is not divisible by K={K}")

    shared_blocks = []
    unrolled = []
    for b in range(n_layers // K):
        block_layers = list(layers[b * K:(b + 1) * K])
        base_layer = copy.deepcopy(block_layers[0])

        with torch.no_grad():
            _average_block_parameters(base_layer, block_layers)

            for parent_name, child_name in _LORA_TARGETS:
                parent = getattr(base_layer, parent_name, None)
                if parent is None or not hasattr(parent, child_name):
                    continue
                lora = LoRAConv1D(getattr(parent, child_name), rank, lora_alpha, K)
                if init_lora_from_residual and rank > 0:
                    originals = [
                        getattr(getattr(layer, parent_name), child_name).weight
                        for layer in block_layers
                    ]
                    initialize_lora_with_svd(lora, originals, range(K), rank)
                setattr(parent, child_name, lora)

        shared_blocks.append(base_layer)
        unrolled.extend(_RecursiveBlock(base_layer, r, b * K + r) for r in range(K))

    model.transformer.shared_blocks = nn.ModuleList(shared_blocks)
    model.transformer.h = nn.ModuleList(unrolled)
    return model


if GPT2Config is not None:

    class RecursiveGPT2Config(GPT2Config):
        """GPT-2 config plus recursion settings.

        Args:
            K: consecutive layers folded into each shared block.
            rank: LoRA rank per adapter.
            lora_alpha: LoRA scaling numerator.
        """

        model_type = "recursive_gpt2"

        def __init__(self, K=2, rank=8, lora_alpha=1.0, **kwargs):
            super().__init__(**kwargs)
            self.K = K
            self.rank = rank
            self.lora_alpha = lora_alpha

    class RecursiveGPT2LMHeadModel(GPT2LMHeadModel):
        """GPT-2 LM whose layers are converted by ``convert_to_recursive`` at construction."""

        config_class = RecursiveGPT2Config

        def __init__(self, config):
            super().__init__(config)
            # Layers are freshly initialized here (and are overwritten when
            # loading a checkpoint), so the adapters keep their zero-delta
            # default instead of an SVD of meaningless residuals.
            convert_to_recursive(
                self,
                K=config.K,
                rank=config.rank,
                lora_alpha=getattr(config, "lora_alpha", 1.0),
                init_lora_from_residual=False,
            )

        @classmethod
        def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
            """Load a checkpoint saved from this class.

            The recursive structure is built in ``__init__``, so the checkpoint
            must already use the recursive layout. To convert a plain GPT-2
            checkpoint, load it with GPT2LMHeadModel and call
            ``convert_to_recursive`` on it.
            """
            return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)