"""GPT4Dev: a small GPT-style decoder-only language model with multi-query
attention, rotary position embeddings (RoPE), and a SwiGLU MLP, packaged as a
Hugging Face `transformers` remote-code model."""

import math
from typing import List, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions


class GPT4DevConfig(PretrainedConfig):
    """Configuration for GPT4Dev models.

    A `compat_prefill_tokens` value > 0 forces a full (uncached) prefill during
    generation while the running sequence is shorter than that many tokens; see
    `GPT4DevForCausalLM.prepare_inputs_for_generation`.
    """

    model_type = "gpt4dev"

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        max_position_embeddings=1024,
        rope_theta=10000.0,
        qkv_bias=True,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        multi_query=True,
        architectures=None,
        tie_word_embeddings=False,
        compat_prefill_tokens: int = 0,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.max_position_embeddings = max_position_embeddings
        self.rope_theta = rope_theta
        self.qkv_bias = qkv_bias
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.multi_query = multi_query
        self.compat_prefill_tokens = compat_prefill_tokens
        super().__init__(
            architectures=architectures,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )


def rope_cache(seq_len, dim, theta, device, dtype=torch.float32):
    """Precompute complex rotary-embedding factors of shape (seq_len, dim // 2).

    The computation is done in float32 and the result stays complex64; casting
    the result to a real dtype would silently drop the imaginary part, so the
    `dtype` argument is kept only for interface compatibility.
    """
    inv = 1.0 / (theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv)
    return torch.polar(torch.ones_like(freqs), freqs)


def apply_rope(x, rope):
    """Rotate the last dimension of `x` (channels taken in pairs) by the complex factors in `rope`."""
    xc = torch.view_as_complex(x.to(torch.float32).reshape(*x.shape[:-1], -1, 2))
    yc = xc * rope.to(xc.dtype)
    y = torch.view_as_real(yc).reshape(*x.shape[:-1], -1)
    return y.to(x.dtype)
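

# Illustrative shape check for the two helpers above (a sketch, not part of the
# model API; the sizes below are arbitrary):
#
#     rope = rope_cache(seq_len=8, dim=64, theta=10000.0, device="cpu")  # (8, 32), complex64
#     q = torch.randn(2, 12, 8, 64)                                      # (batch, heads, seq, head_dim)
#     q_rot = apply_rope(q, rope)                                        # same shape and dtype as q
#
# The head dimension must be even, since channels are rotated in pairs.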


class MQA(nn.Module):
    """Multi-query attention: per-head queries, a single shared key/value head."""

    def __init__(self, config: GPT4DevConfig):
        super().__init__()
        h, d = config.num_attention_heads, config.hidden_size // config.num_attention_heads
        assert d * h == config.hidden_size and d % 2 == 0, "hidden_size must split into an even head dimension"
        self.h, self.d = h, d
        # Fused projection: h * d query channels plus one d-sized key and one d-sized value head.
        self.qkv = nn.Linear(config.hidden_size, h * d + 2 * d, bias=config.qkv_bias)
        self.out = nn.Linear(config.hidden_size, config.hidden_size, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        B, T, _ = x.shape
        qkv = self.qkv(x)
        q, kv = qkv.split([self.h * self.d, 2 * self.d], dim=-1)
        k_new, v_new = kv.split(self.d, dim=-1)

        # Queries: (B, h, T, d), rotated with the positions covered by `rope`.
        q = q.view(B, T, self.h, self.d).transpose(1, 2)
        q = apply_rope(q, rope)

        # New keys share a single head; rotate them with the same positions.
        k_new = apply_rope(k_new.unsqueeze(1), rope).squeeze(1)

        # Append to the cache (cached keys are already rotated).
        if past_kv is not None and past_kv[0] is not None:
            k_cat = torch.cat([past_kv[0], k_new], dim=1)
            v_cat = torch.cat([past_kv[1], v_new], dim=1)
        else:
            k_cat, v_cat = k_new, v_new

        # Broadcast the shared key/value head across all query heads.
        k_exp = k_cat.unsqueeze(1).expand(-1, self.h, -1, -1)
        v_exp = v_cat.unsqueeze(1).expand(-1, self.h, -1, -1)

        B, h, T, d = q.shape
        S = k_exp.size(2)
        past_len = S - T
        attn = torch.matmul(q, k_exp.transpose(-2, -1)) / math.sqrt(d)

        # Causal mask: the query at absolute position past_len + t may attend to keys s <= past_len + t.
        idx_t = torch.arange(T, device=q.device)[:, None]
        idx_s = torch.arange(S, device=q.device)[None, :]
        mask = idx_s > idx_t + past_len
        attn = attn.masked_fill(mask.unsqueeze(0).unsqueeze(0), float("-inf"))

        attn = F.softmax(attn, dim=-1)
        y = torch.matmul(attn, v_exp)
        y = y.transpose(1, 2).reshape(B, T, -1)
        return self.out(y), (k_cat, v_cat)

    def forward_compat(self, x: torch.Tensor, rope: torch.Tensor) -> torch.Tensor:
        """Cache-free full-sequence path using F.scaled_dot_product_attention."""
        B, T, _ = x.shape
        qkv = self.qkv(x)
        q, kv = qkv.split([self.h * self.d, 2 * self.d], dim=-1)
        k, v = kv.split(self.d, dim=-1)
        q = q.view(B, T, self.h, self.d).transpose(1, 2)
        k = k.unsqueeze(1).expand(-1, self.h, -1, -1)
        v = v.unsqueeze(1).expand(-1, self.h, -1, -1)
        q = apply_rope(q, rope)
        k = apply_rope(k, rope)
        y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return self.out(y.transpose(1, 2).reshape(B, T, -1))
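

# Cache layout note (descriptive only): each layer's past_kv is a (k, v) pair of
# shape (batch, seq_len_so_far, head_dim) with a single shared head, which makes
# the multi-query cache roughly num_attention_heads times smaller than a
# standard multi-head cache of the same length.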


class SwiGLU(nn.Module):
    """SwiGLU MLP: a fused gate/value up-projection followed by a down-projection."""

    def __init__(self, hidden_dim, intermediate_dim):
        super().__init__()
        self.w1 = nn.Linear(hidden_dim, intermediate_dim * 2, bias=True)
        self.w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)

    def forward(self, x):
        x_g, x_v = self.w1(x).chunk(2, dim=-1)
        return self.w2(F.silu(x_g) * x_v)


class Block(nn.Module):
    """Pre-norm transformer block: LayerNorm -> attention -> residual, LayerNorm -> SwiGLU -> residual."""

    def __init__(self, config):
        super().__init__()
        self.ln1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        # Note: the nn.MultiheadAttention fallback does not implement the
        # (x, rope, past_kv) interface used below; only multi_query=True is wired up.
        self.attn = MQA(config) if config.multi_query else nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, bias=config.qkv_bias, batch_first=True)
        self.ln2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.mlp = SwiGLU(config.hidden_size, config.intermediate_size)
        self.gradient_checkpointing = False

    def forward(
        self,
        x: torch.Tensor,
        rope: torch.Tensor,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_checkpoint: bool = False,
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        def custom_forward(x_, rope_):
            a, new_kv = self.attn(self.ln1(x_), rope_, past_kv)
            x_ = x_ + a
            x_ = x_ + self.mlp(self.ln2(x_))
            return x_, new_kv

        if use_checkpoint and self.training:
            y, new_kv = checkpoint(custom_forward, x, rope, use_reentrant=False)
            return y, new_kv
        return custom_forward(x, rope)

    def forward_compat(self, x: torch.Tensor, rope: torch.Tensor, use_checkpoint: bool = False) -> torch.Tensor:
        def custom_forward(x_, rope_):
            a = self.attn.forward_compat(self.ln1(x_), rope_)
            x_ = x_ + a
            x_ = x_ + self.mlp(self.ln2(x_))
            return x_

        if use_checkpoint and self.training:
            return checkpoint(custom_forward, x, rope, use_reentrant=False)
        return custom_forward(x, rope)


class GPT4DevPreTrained(PreTrainedModel):
    config_class = GPT4DevConfig
    base_model_prefix = "transformer"
    supports_gradient_checkpointing = True
    _no_split_modules = ["Block"]

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
        if isinstance(module, nn.Linear) and module.bias is not None:
            nn.init.zeros_(module.bias)


class GPT4DevForCausalLM(GPT4DevPreTrained, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
        self.head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Lazily built complex RoPE table (see _rope_slice); not a registered buffer.
        self.rope_cache = None
        self.post_init()

    def get_input_embeddings(self):
        return self.embed

    def set_input_embeddings(self, new_embeddings):
        self.embed = new_embeddings
        if getattr(self.config, "tie_word_embeddings", True) and self.get_output_embeddings() is not None:
            with torch.no_grad():
                self.get_output_embeddings().weight = self.embed.weight

    def get_output_embeddings(self):
        return self.head

    def set_output_embeddings(self, new_lm_head):
        self.head = new_lm_head

    def tie_weights(self):
        if getattr(self.config, "tie_word_embeddings", True):
            self.head.weight = self.embed.weight

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, past_key_values=None, **kwargs):
        # While the running sequence is shorter than `compat_prefill_tokens`,
        # drop the cache and re-run a full prefill; otherwise perform the usual
        # single-token decode step against the cache.
        cutoff = int(getattr(self.config, "compat_prefill_tokens", 0) or 0)
        if past_key_values is not None and input_ids is not None and input_ids.size(1) < cutoff:
            past_key_values = None
        elif past_key_values is not None:
            input_ids = input_ids[:, -1:]
        # An all-ones 2-D mask carries no information; drop it (attention never applies a padding mask).
        if attention_mask is not None and attention_mask.dim() == 2 and torch.all(attention_mask == 1):
            attention_mask = None
        return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values, "use_cache": True}

    def _reorder_cache(self, past_key_values, beam_idx):
        # Reorder the legacy (key, value) tuples along the batch dimension for beam search.
        if isinstance(past_key_values, (tuple, list)):
            reordered = []
            for k, v in past_key_values:
                if k is None or v is None:
                    reordered.append((k, v))
                else:
                    reordered.append((k.index_select(0, beam_idx), v.index_select(0, beam_idx)))
            return tuple(reordered)
        return past_key_values

    def _rope_slice(self, past_len: int, T: int, device, dtype):
        """Return RoPE factors for absolute positions [past_len, past_len + T)."""
        need = past_len + T
        head_dim = self.config.hidden_size // self.config.num_attention_heads
        if self.rope_cache is None or self.rope_cache.device != device or self.rope_cache.size(0) < need:
            # The table is kept in float32/complex64 regardless of `dtype`;
            # apply_rope casts the rotated activations back to the model dtype.
            self.rope_cache = rope_cache(
                max(need, self.config.max_position_embeddings),
                head_dim,
                self.config.rope_theta,
                device,
                dtype=torch.float32,
            )
        return self.rope_cache[past_len: past_len + T]

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, Block):
            module.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        labels=None,
        attention_mask=None,
        past_key_values=None,
        use_cache=None,
        **kwargs,
    ):
        # `attention_mask` is accepted for interface compatibility but not applied;
        # attention is always dense and causal.
        B, T = input_ids.shape
        x = self.embed(input_ids)

        past = past_key_values
        use_cache = True if use_cache is None else use_cache
        new_past: List[Tuple[torch.Tensor, torch.Tensor]] = [] if use_cache else None

        past_len = 0
        if past is not None and isinstance(past, (tuple, list)) and past and past[0] is not None:
            past_len = past[0][0].size(1)

        rope = self._rope_slice(past_len, T, x.device, x.dtype)
        for i, blk in enumerate(self.blocks):
            pkv = None if past is None else (past[i] if i < len(past) else None)
            x, new_kv = blk(x, rope, past_kv=pkv, use_checkpoint=(self.is_gradient_checkpointing and self.training))
            if use_cache and new_past is not None:
                new_past.append(new_kv)

        logits = self.head(self.ln_f(x))

        loss = None
        if labels is not None:
            # Standard next-token objective: shift so position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
            logits=logits,
            past_key_values=tuple(new_past) if use_cache else None,
        )


GPT4DevConfig.auto_map = {
    "AutoConfig": "modeling_gpt4dev.GPT4DevConfig",
    "AutoModel": "modeling_gpt4dev.GPT4DevForCausalLM",
    "AutoModelForCausalLM": "modeling_gpt4dev.GPT4DevForCausalLM",
}
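

# Minimal smoke test (an illustrative sketch, not part of the library API; the
# tiny hyperparameters below are arbitrary and do not correspond to any
# released checkpoint). It exercises an uncached prefill and one cached decode step.
if __name__ == "__main__":
    cfg = GPT4DevConfig(
        vocab_size=128,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=4,
        intermediate_size=256,
        max_position_embeddings=32,
    )
    model = GPT4DevForCausalLM(cfg).eval()
    ids = torch.randint(0, cfg.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(ids, labels=ids, use_cache=True)                       # full prefill plus loss
        nxt = torch.randint(0, cfg.vocab_size, (1, 1))
        step = model(nxt, past_key_values=out.past_key_values, use_cache=True)
    print(float(out.loss), tuple(step.logits.shape))                       # loss value and (1, 1, 128)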