import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint  # explicit import so torch.utils.checkpoint.checkpoint(...) below always resolves

from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from typing import Optional, Tuple, List


class GPT4DevConfig(PretrainedConfig):
    model_type = "gpt4dev"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=1024,
        rope_theta=10000.0,
        qkv_bias=True,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        multi_query=True,
        architectures=None,
        tie_word_embeddings=False,
        compat_prefill_tokens: int = 0,
        **kwargs,
    ):
        super().__init__(
            vocab_size=vocab_size,
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            max_position_embeddings=max_position_embeddings,
            rope_theta=rope_theta,
            qkv_bias=qkv_bias,
            layer_norm_epsilon=layer_norm_epsilon,
            initializer_range=initializer_range,
            multi_query=multi_query,
            architectures=architectures,
            tie_word_embeddings=tie_word_embeddings,
            compat_prefill_tokens=compat_prefill_tokens,
            **kwargs,
        )


def rope_cache(seq_len, dim, theta, device, dtype=torch.float32):
    # Note: kept float32 to match training-time math used in early checkpoints
    inv = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv)
    return torch.polar(torch.ones_like(freqs), freqs).to(dtype)


def apply_rope(x, rope):
    # x: (..., D) with D even; rope: (T, D/2). In legacy math this can be float (cos-only)
    xc = torch.view_as_complex(x.to(torch.float32).reshape(*x.shape[:-1], -1, 2))
    yc = xc * rope.to(xc.dtype)
    y = torch.view_as_real(yc).reshape(*x.shape[:-1], -1)
    return y.to(x.dtype)

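# Illustrative sketch (not called anywhere in this module): shows the shapes the two RoPE
# helpers above work with, and the legacy behavior that a float32 cache keeps only the
# cosine term of the rotation table. The sizes below (T=4, d=8) are arbitrary assumptions
# made only for this demonstration.
def _rope_example():
    T, d = 4, 8                                   # sequence length and per-head dim (d must be even)
    rope_f32 = rope_cache(T, d, 10000.0, "cpu")   # default dtype=float32 -> cos-only "legacy" table, shape (T, d/2)
    rope_c64 = rope_cache(T, d, 10000.0, "cpu", dtype=torch.complex64)  # full complex rotation table
    x = torch.randn(2, 3, T, d)                   # e.g. (batch, heads, T, d)
    y = apply_rope(x, rope_c64)                   # output has the same shape and dtype as x
    return rope_f32.shape, rope_c64.dtype, y.shape
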
class MQA(nn.Module):
    def __init__(self, config: GPT4DevConfig):
        super().__init__()
        h, d = config.num_attention_heads, config.hidden_size // config.num_attention_heads
        self.h, self.d = h, d
        self.qkv = nn.Linear(config.hidden_size, h * d + 2 * d, bias=config.qkv_bias)
        self.out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        B, T, _ = x.shape
        qkv = self.qkv(x)
        q, kv = qkv.split(self.h * self.d, dim=-1)
        k_new, v_new = kv.split(self.d, dim=-1)  # (B, T, d)

        # queries to head dim; apply RoPE
        q = q.view(B, T, self.h, self.d).transpose(1, 2)  # (B, h, T, d)
        q = apply_rope(q, rope)

        # rotate new k
        k_new = apply_rope(k_new.unsqueeze(1), rope).squeeze(1)  # (B, T, d)

        # concat cache
        if past_kv is not None and past_kv[0] is not None:
            k_cat = torch.cat([past_kv[0], k_new], dim=1)
            v_cat = torch.cat([past_kv[1], v_new], dim=1)
        else:
            k_cat, v_cat = k_new, v_new

        # expand KV
        k_exp = k_cat.unsqueeze(1).expand(-1, self.h, -1, -1)  # (B, h, S, d)
        v_exp = v_cat.unsqueeze(1).expand(-1, self.h, -1, -1)  # (B, h, S, d)

        B, h, T, d = q.shape
        S = k_exp.size(2)
        past_len = S - T

        attn = torch.matmul(q, k_exp.transpose(-2, -1)) / math.sqrt(d)

        # Offset-aware causal mask
        idx_t = torch.arange(T, device=q.device)[:, None]
        idx_s = torch.arange(S, device=q.device)[None, :]
        mask = idx_s > idx_t + past_len
        attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
        attn = F.softmax(attn, dim=-1)

        y = torch.matmul(attn, v_exp)
        y = y.transpose(1, 2).reshape(B, T, -1)
        return self.out(y), (k_cat, v_cat)

    def forward_compat(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        B, T, _ = x.shape
        qkv = self.qkv(x)
        q, kv = qkv.split(self.h * self.d, dim=-1)
        k, v = kv.split(self.d, dim=-1)
        q = q.view(B, T, self.h, self.d).transpose(1, 2)  # (B,h,T,d)
        k = k.unsqueeze(1).expand(-1, self.h, -1, -1)  # (B,h,T,d)
        v = v.unsqueeze(1).expand(-1, self.h, -1, -1)  # (B,h,T,d)
        q = apply_rope(q, rope)
        k = apply_rope(k, rope)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(B, T, -1))


class SwiGLU(nn.Module):
    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim * 2, bias=True)
        self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x):
        x_g, x_v = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.silu(x_g) * x_v)


class Block(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.attn = MQA(config) if config.multi_query else nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, bias=config.qkv_bias, batch_first=True)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_checkpoint: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        def custom_forward(x_, rope_):
            a, new_kv = self.attn(self.ln1(x_), rope_, past_kv)
            x_ = x_ + a
            x_ = x_ + self.mlp(self.ln2(x_))
            return x_, new_kv

        if use_checkpoint and self.training:
            y, new_kv = torch.utils.checkpoint.checkpoint(custom_forward, x, rope, use_reentrant=False)
            return y, new_kv
        else:
            return custom_forward(x, rope)

    def forward_compat(self, x: torch.Tensor, rope: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        def custom_forward(x_, rope_):
            a = self.attn.forward_compat(self.ln1(x_), rope_)
            x_ = x_ + a
            x_ = x_ + self.mlp(self.ln2(x_))
            return x_

        if use_checkpoint and self.training:
            return torch.utils.checkpoint.checkpoint(custom_forward, x, rope, use_reentrant=False)
        else:
            return custom_forward(x, rope)

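# Illustrative sketch (not called anywhere in this module): decoding token by token with the
# tuple KV-cache returned by MQA.forward should reproduce a single full-sequence pass. The
# tiny config values and the 1e-5 tolerance are assumptions made only for this demonstration.
def _mqa_cache_example():
    torch.manual_seed(0)
    cfg = GPT4DevConfig(hidden_size=64, num_attention_heads=4)
    attn = MQA(cfg).eval()
    T = 6
    x = torch.randn(2, T, cfg.hidden_size)
    rope = rope_cache(T, cfg.hidden_size // cfg.num_attention_heads, cfg.rope_theta, x.device)
    with torch.no_grad():
        full, _ = attn(x, rope)                                   # all T positions at once
        past, steps = None, []
        for t in range(T):                                        # one new token per step, reusing the cache
            y, past = attn(x[:, t:t + 1], rope[t:t + 1], past_kv=past)
            steps.append(y)
    return torch.allclose(full, torch.cat(steps, dim=1), atol=1e-5)
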
class GPT4DevPreTrained(PreTrainedModel):
    config_class = GPT4DevConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                nn.init.zeros_(module.bias)


class GPT4DevForCausalLM(GPT4DevPreTrained, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.rope_cache = None
        self.post_init()

    # embeddings tie helpers
    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, new_embeddings):
        self.embed = new_embeddings
        if getattr(self.config, "tie_word_embeddings", True) and self.get_output_embeddings() is not None:
            with torch.no_grad():
                self.get_output_embeddings().weight = self.embed.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_lm_head):
        self.head = new_lm_head

    def tie_weights(self):
        if getattr(self.config, "tie_word_embeddings", True):
            self.head.weight = self.embed.weight

    # generation helpers (legacy tuple KV-cache)
    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        # Until compat_prefill_tokens, avoid slicing and ignore cache to mirror legacy behavior
        cutoff = int(getattr(self.config, "compat_prefill_tokens", 0) or 0)
        if past_key_values is not None and input_ids is not None and input_ids.size(1) < cutoff:
            past_key_values = None  # drop cache, process full prefix
        elif past_key_values is not None:
            # normal cached decode path
            input_ids = input_ids[:, -1:]
        if attention_mask is not None and attention_mask.dim() == 2 and torch.all(attention_mask == 1):
            attention_mask = None
        return {"input_ids": input_ids, "attention_mask": attention_mask,
                "past_key_values": past_key_values, "use_cache": True}

    def _reorder_cache(self, past_key_values, beam_idx):
        if isinstance(past_key_values, (tuple, list)):
            reordered = []
            for k, v in past_key_values:
                if k is None or v is None:
                    reordered.append((k, v))
                else:
                    reordered.append((k.index_select(0, beam_idx), v.index_select(0, beam_idx)))
            return tuple(reordered)
        return past_key_values

    # RoPE utilities (kept float32 behavior to mirror training)
    def _rope_slice(self, past_len: int, T: int, device, dtype):
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.rope_cache is None or self.rope_cache.device != device:
            self.rope_cache = rope_cache(
                self.config.max_position_embeddings, head_dim,
                self.config.rope_theta, device, dtype=torch.float32,
            )
        need = past_len + T
        if need > self.rope_cache.size(0):
            # rebuild a longer table so the slice below always covers past_len + T positions
            self.rope_cache = rope_cache(
                max(need, self.config.max_position_embeddings), head_dim,
                self.config.rope_theta, device, dtype=torch.float32,
            )
        return self.rope_cache[past_len: past_len + T]

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, Block):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        **kwargs,
    ):
        B, T = input_ids.shape
        x = self.embed(input_ids)
        past = past_key_values
        use_cache = True if (use_cache is None) else use_cache
        new_past: List[Tuple[torch.Tensor, torch.Tensor]] = [] if use_cache else None

        past_len = 0
        if past is not None and isinstance(past, (tuple, list)) and past and past[0] is not None:
            past_len = past[0][0].size(1)
        rope = self._rope_slice(past_len, T, x.device, x.dtype)

        for i, blk in enumerate(self.blocks):
            pkv = None if past is None else (past[i] if i < len(past) else None)
            x, new_kv = blk(x, rope, past_kv=pkv,
                            use_checkpoint=(self.is_gradient_checkpointing and self.training))
            if use_cache and new_past is not None:
                new_past.append(new_kv)

        logits = self.head(self.ln_f(x))

        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=tuple(new_past) if use_cache else None,
        )


GPT4DevConfig.auto_map = {
    "AutoConfig": "modeling_gpt4dev.GPT4DevConfig",
    "AutoModel": "modeling_gpt4dev.GPT4DevForCausalLM",
    "AutoModelForCausalLM": "modeling_gpt4dev.GPT4DevForCausalLM",
}

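# Minimal usage sketch, guarded so it only runs when this file is executed directly. The tiny
# config, random weights, and greedy loop below are assumptions for illustration only; a real
# checkpoint would normally be loaded through AutoModelForCausalLM with trust_remote_code=True.
if __name__ == "__main__":
    torch.manual_seed(0)
    cfg = GPT4DevConfig(
        vocab_size=128, hidden_size=64, num_hidden_layers=2,
        num_attention_heads=4, intermediate_size=128, max_position_embeddings=64,
    )
    model = GPT4DevForCausalLM(cfg).eval()

    ids = torch.randint(0, cfg.vocab_size, (1, 8))
    out = model(ids, labels=ids)                       # logits plus shifted next-token loss
    print("loss:", float(out.loss))

    # Greedy decoding with the legacy tuple KV-cache returned by forward()
    past, cur = None, ids
    with torch.no_grad():
        for _ in range(8):
            step = model(cur, past_key_values=past, use_cache=True)
            past = step.past_key_values
            nxt = step.logits[:, -1:].argmax(dim=-1)   # (B, 1) next-token ids
            ids = torch.cat([ids, nxt], dim=1)
            cur = nxt                                  # after the first step only the new token is fed
    print("generated ids:", ids.tolist())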