# GPT4-dev-177M-1511 / modeling_gpt4dev.py
import math, torch, torch.nn as nn, torch.nn.functional as F
import torch.utils.checkpoint  # explicit import for torch.utils.checkpoint.checkpoint used in Block
from transformers import PretrainedConfig, PreTrainedModel, GenerationMixin
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
from typing import Optional, Tuple, List
class GPT4DevConfig(PretrainedConfig):
model_type = "gpt4dev"
def __init__(
self,
vocab_size=50257,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
max_position_embeddings=1024,
rope_theta=10000.0,
qkv_bias=True,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
multi_query=True,
architectures=None,
tie_word_embeddings=False,
compat_prefill_tokens: int = 0,
**kwargs,
):
super().__init__(
vocab_size=vocab_size,
hidden_size=hidden_size,
num_hidden_layers=num_hidden_layers,
num_attention_heads=num_attention_heads,
intermediate_size=intermediate_size,
max_position_embeddings=max_position_embeddings,
rope_theta=rope_theta,
qkv_bias=qkv_bias,
layer_norm_epsilon=layer_norm_epsilon,
initializer_range=initializer_range,
multi_query=multi_query,
architectures=architectures,
tie_word_embeddings=tie_word_embeddings,
compat_prefill_tokens=compat_prefill_tokens,
**kwargs,
)
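# Usage sketch (comments only, nothing executed at import). The defaults above are the
# ones shipped with this file; the override values shown below are purely illustrative
# assumptions for a hypothetical smaller variant.
#
#   config = GPT4DevConfig()                     # file defaults
#   tiny = GPT4DevConfig(hidden_size=256,        # hypothetical smaller variant
#                        num_hidden_layers=4,
#                        num_attention_heads=4,
#                        intermediate_size=1024)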
def rope_cache(seq_len, dim, theta, device, dtype=torch.float32):
    # Note: kept float32 to match training-time math used in early checkpoints.
    # With a real dtype (the default), the .to(dtype) below discards the imaginary
    # part of the complex rotation, i.e. the legacy "cos-only" path that the
    # apply_rope comment refers to.
inv = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
t = torch.arange(seq_len, device=device, dtype=torch.float32)
freqs = torch.outer(t, inv)
return torch.polar(torch.ones_like(freqs), freqs).to(dtype)
def apply_rope(x, rope):
# x: (..., D) with D even; rope: (T, D/2). In legacy math this can be float (cos-only)
xc = torch.view_as_complex(x.to(torch.float32).reshape(*x.shape[:-1], -1, 2))
yc = xc * rope.to(xc.dtype)
y = torch.view_as_real(yc).reshape(*x.shape[:-1], -1)
return y.to(x.dtype)
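# Shape sketch for the two helpers above: for head dimension d, rope_cache returns a
# (seq_len, d // 2) table; apply_rope views the last dim of x as d // 2 complex pairs,
# so a (B, h, T, d) query or a (B, 1, T, d) key broadcasts against a (T, d // 2) slice.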
class MQA(nn.Module):
def __init__(self, config: GPT4DevConfig):
super().__init__()
h, d = config.num_attention_heads, config.hidden_size // config.num_attention_heads
self.h, self.d = h, d
self.qkv = nn.Linear(config.hidden_size, h * d + 2 * d, bias=config.qkv_bias)
self.out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
def forward(
self,
x: torch.Tensor,
rope: torch.Tensor,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
B, T, _ = x.shape
qkv = self.qkv(x)
        q, kv = qkv.split([self.h * self.d, 2 * self.d], dim=-1)
k_new, v_new = kv.split(self.d, dim=-1) # (B, T, d)
# queries to head dim; apply RoPE
q = q.view(B, T, self.h, self.d).transpose(1, 2) # (B, h, T, d)
q = apply_rope(q, rope)
# rotate new k
k_new = apply_rope(k_new.unsqueeze(1), rope).squeeze(1) # (B, T, d)
# concat cache
if past_kv is not None and past_kv[0] is not None:
k_cat = torch.cat([past_kv[0], k_new], dim=1)
v_cat = torch.cat([past_kv[1], v_new], dim=1)
else:
k_cat, v_cat = k_new, v_new
# expand KV
k_exp = k_cat.unsqueeze(1).expand(-1, self.h, -1, -1) # (B, h, S, d)
v_exp = v_cat.unsqueeze(1).expand(-1, self.h, -1, -1) # (B, h, S, d)
B, h, T, d = q.shape
S = k_exp.size(2)
past_len = S - T
attn = torch.matmul(q, k_exp.transpose(-2, -1)) / math.sqrt(d)
# Offset-aware causal mask
idx_t = torch.arange(T, device=q.device)[:, None]
idx_s = torch.arange(S, device=q.device)[None, :]
mask = idx_s > idx_t + past_len
attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float('-inf'))
attn = F.softmax(attn, dim=-1)
y = torch.matmul(attn, v_exp)
y = y.transpose(1, 2).reshape(B, T, -1)
return self.out(y), (k_cat, v_cat)
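    # Cache layout note: multi-query attention keeps one shared KV head, so the cache
    # returned above is a pair of (B, S, d) tensors; they are only expanded to
    # (B, h, S, d) views inside the attention computation itself.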
def forward_compat(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
B, T, _ = x.shape
qkv = self.qkv(x)
        q, kv = qkv.split([self.h * self.d, 2 * self.d], dim=-1)
k, v = kv.split(self.d, dim=-1)
q = q.view(B, T, self.h, self.d).transpose(1, 2) # (B,h,T,d)
k = k.unsqueeze(1).expand(-1, self.h, -1, -1) # (B,h,T,d)
v = v.unsqueeze(1).expand(-1, self.h, -1, -1) # (B,h,T,d)
q = apply_rope(q, rope)
k = apply_rope(k, rope)
y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
return self.out(y.transpose(1, 2).reshape(B, T, -1))
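    # Note: the forward_compat paths (here and in Block) are not invoked by
    # GPT4DevForCausalLM.forward in this file; they keep a cache-free
    # scaled_dot_product_attention variant of the same computation available.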
class SwiGLU(nn.Module):
def __init__(self, hidden_dim, intermediate_dim):
super().__init__()
self.w1 = nn.Linear(hidden_dim, intermediate_dim * 2, bias=True)
self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)
def forward(self, x):
x_g, x_v = self.w1(x).chunk(2, dim=-1)
return self.w2(F.silu(x_g) * x_v)
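    # i.e. SwiGLU(x) = w2(SiLU(x @ W_gate + b_gate) * (x @ W_value + b_value)),
    # with w1 packing the gate and value projections into a single matmul.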
class Block(nn.Module):
def __init__(self, config):
super().__init__()
self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.attn = MQA(config) if config.multi_query else nn.MultiheadAttention(
config.hidden_size, config.num_attention_heads, bias=config.qkv_bias, batch_first=True)
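        # NOTE: forward()/forward_compat() below call self.attn with (hidden, rope, past_kv)
        # or via .forward_compat(hidden, rope), which only the MQA module supports; the
        # nn.MultiheadAttention fallback is kept for config compatibility and is not
        # wired into those paths.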
self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
self.gradient_checkpointing = False
def forward(
self,
x: torch.Tensor,
rope: torch.Tensor,
past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_checkpoint: bool = False,
) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
def custom_forward(x_, rope_):
a, new_kv = self.attn(self.ln1(x_), rope_, past_kv)
x_ = x_ + a
x_ = x_ + self.mlp(self.ln2(x_))
return x_, new_kv
if use_checkpoint and self.training:
y, new_kv = torch.utils.checkpoint.checkpoint(custom_forward, x, rope, use_reentrant=False)
return y, new_kv
else:
return custom_forward(x, rope)
def forward_compat(self, x: torch.Tensor, rope: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
def custom_forward(x_, rope_):
a = self.attn.forward_compat(self.ln1(x_), rope_)
x_ = x_ + a
x_ = x_ + self.mlp(self.ln2(x_))
return x_
if use_checkpoint and self.training:
return torch.utils.checkpoint.checkpoint(custom_forward, x, rope, use_reentrant=False)
else:
return custom_forward(x, rope)
class GPT4DevPreTrained(PreTrainedModel):
config_class = GPT4DevConfig
base_model_prefix = "transformer"
supports_gradient_checkpointing = True
_no_split_modules = ["Block"]
def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Embedding)):
nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
if isinstance(module, nn.Linear) and module.bias is not None:
nn.init.zeros_(module.bias)
class GPT4DevForCausalLM(GPT4DevPreTrained, GenerationMixin):
def __init__(self, config):
super().__init__(config)
self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self.rope_cache = None
self.post_init()
# embeddings tie helpers
def get_input_embeddings(self):
return self.embed
def set_input_embeddings(self, new_embeddings):
self.embed = new_embeddings
if getattr(self.config, "tie_word_embeddings", True) and self.get_output_embeddings() is not None:
with torch.no_grad():
self.get_output_embeddings().weight = self.embed.weight
def get_output_embeddings(self):
return self.head
def set_output_embeddings(self, new_lm_head):
self.head = new_lm_head
def tie_weights(self):
if getattr(self.config, "tie_word_embeddings", True):
self.head.weight = self.embed.weight
# generation helpers (legacy tuple KV-cache)
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        # While the running prefix is shorter than compat_prefill_tokens, drop the cache
        # and re-process the full prefix (no slicing) to mirror legacy behavior.
cutoff = int(getattr(self.config, "compat_prefill_tokens", 0) or 0)
if past_key_values is not None and input_ids is not None and input_ids.size(1) < cutoff:
past_key_values = None # drop cache, process full prefix
elif past_key_values is not None:
# normal cached decode path
input_ids = input_ids[:, -1:]
if attention_mask is not None and attention_mask.dim() == 2 and torch.all(attention_mask == 1):
attention_mask = None
return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, "use_cache": True}
def _reorder_cache(self, past_key_values, beam_idx):
if isinstance(past_key_values, (tuple, list)):
reordered = []
for k, v in past_key_values:
if k is None or v is None:
reordered.append((k, v))
else:
reordered.append((k.index_select(0, beam_idx), v.index_select(0, beam_idx)))
return tuple(reordered)
return past_key_values
# RoPE utilities (kept float32 behavior to mirror training)
def _rope_slice(self, past_len: int, T: int, device, dtype):
if self.rope_cache is None or self.rope_cache.device != device:
self.rope_cache = rope_cache(
self.config.max_position_embeddings,
self.config.hidden_size // self.config.num_attention_heads,
self.config.rope_theta, device, dtype=torch.float32
)
        need = past_len + T
        if need > self.rope_cache.size(0):
            # Rebuild with enough positions so the slice below always covers past_len + T
            self.rope_cache = rope_cache(
                max(need, self.config.max_position_embeddings),
                self.config.hidden_size // self.config.num_attention_heads,
                self.config.rope_theta, device, dtype=torch.float32
            )
return self.rope_cache[past_len: past_len + T]
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, Block):
module.gradient_checkpointing = value
def forward(
self,
input_ids,
labels=None,
attention_mask=None,
past_key_values=None,
use_cache=None,
**kwargs,
):
B, T = input_ids.shape
x = self.embed(input_ids)
past = past_key_values
use_cache = True if (use_cache is None) else use_cache
new_past: List[Tuple[torch.Tensor, torch.Tensor]] = [] if use_cache else None
past_len = 0
if past is not None and isinstance(past, (tuple, list)) and past and past[0] is not None:
past_len = past[0][0].size(1)
rope = self._rope_slice(past_len, T, x.device, x.dtype)
for i, blk in enumerate(self.blocks):
pkv = None if past is None else (past[i] if i < len(past) else None)
x, new_kv = blk(x, rope, past_kv=pkv, use_checkpoint=(self.is_gradient_checkpointing and self.training))
if use_cache and new_past is not None:
new_past.append(new_kv)
logits = self.head(self.ln_f(x))
loss = None
if labels is not None:
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
loss = F.cross_entropy(
shift_logits.view(-1, shift_logits.size(-1)),
shift_labels.view(-1),
ignore_index=-100,
)
return CausalLMOutputWithCrossAttentions(
loss=loss,
logits=logits,
past_key_values=tuple(new_past) if use_cache else None,
)
GPT4DevConfig.auto_map = {
"AutoConfig": "modeling_gpt4dev.GPT4DevConfig",
"AutoModel": "modeling_gpt4dev.GPT4DevForCausalLM",
"AutoModelForCausalLM": "modeling_gpt4dev.GPT4DevForCausalLM",
}
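if __name__ == "__main__":
    # Minimal smoke test: a sketch for local sanity checks, not part of the released
    # training or evaluation code. The tiny sizes below are illustrative assumptions
    # so the test runs quickly on CPU; the real checkpoint uses the defaults above.
    torch.manual_seed(0)
    cfg = GPT4DevConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=64,
    )
    model = GPT4DevForCausalLM(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (2, 10))
    with torch.no_grad():
        prefill = model(ids[:, :-1])  # full-prefix forward, builds the tuple KV cache
        decode = model(ids[:, -1:], past_key_values=prefill.past_key_values)  # cached step
    print("prefill logits:", tuple(prefill.logits.shape))  # expected (2, 9, 128)
    print("decode logits:", tuple(decode.logits.shape))    # expected (2, 1, 128)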