import torch
from torch import nn

from tools import *


class GroupedQueryAttention(nn.Module):
    """Multi-head attention in which several query heads share one key/value head."""

    def __init__(self, d_in, d_out, num_heads, num_kv_groups, dtype=None):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        assert num_heads % num_kv_groups == 0, "num_heads must be divisible by num_kv_groups"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.num_kv_groups = num_kv_groups
        self.group_size = num_heads // num_kv_groups  # query heads per key/value group

        # Queries get one projection per head; keys and values are projected
        # once per group and shared by all query heads in that group.
        self.W_query = nn.Linear(d_in, d_out, bias=False, dtype=dtype)
        self.W_key = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.W_value = nn.Linear(d_in, num_kv_groups * self.head_dim, bias=False, dtype=dtype)
        self.out_proj = nn.Linear(d_out, d_out, bias=False, dtype=dtype)

    def forward(self, x, mask, cos, sin):
        b, num_tokens, d_in = x.shape

        # Project the input into queries, keys, and values.
        queries = self.W_query(x)  # (b, num_tokens, num_heads * head_dim)
        keys = self.W_key(x)       # (b, num_tokens, num_kv_groups * head_dim)
        values = self.W_value(x)   # (b, num_tokens, num_kv_groups * head_dim)

        # Split into per-head queries and per-group keys/values.
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(b, num_tokens, self.num_kv_groups, self.head_dim)
        values = values.view(b, num_tokens, self.num_kv_groups, self.head_dim)

        # Move the head/group axis ahead of the token axis:
        # (b, num_heads_or_groups, num_tokens, head_dim).
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Apply rotary position embeddings to queries and keys.
        queries = apply_rope(queries, cos, sin)
        keys = apply_rope(keys, cos, sin)

        # Expand each key/value group so every query head has a matching
        # key/value head: (b, num_kv_groups, ...) -> (b, num_heads, ...).
        keys = keys.repeat_interleave(self.group_size, dim=1)
        values = values.repeat_interleave(self.group_size, dim=1)

        # Scaled dot-product attention with a causal mask.
        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores = attn_scores.masked_fill(mask, -torch.inf)

        assert keys.shape[-1] == self.head_dim
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        # Weighted sum of the values, then merge the heads back into d_out.
        context_vec = (attn_weights @ values).transpose(1, 2)
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)

        return context_vec
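

# --- Minimal usage sketch (not part of the original module) ---
# A quick smoke test under two assumptions: `apply_rope` (imported from
# `tools` above) accepts `cos`/`sin` of shape (num_tokens, head_dim), and the
# mask is a boolean tensor where True marks positions to exclude. Adjust both
# if the actual helpers use different conventions.
if __name__ == "__main__":
    torch.manual_seed(0)

    b, num_tokens, d_in, d_out = 2, 8, 64, 64
    num_heads, num_kv_groups = 4, 2
    head_dim = d_out // num_heads

    gqa = GroupedQueryAttention(d_in, d_out, num_heads, num_kv_groups)

    # Precompute RoPE angles (assumed layout: one row per position, the
    # frequency pattern repeated to fill head_dim).
    theta_base = 10_000
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    positions = torch.arange(num_tokens).float()
    angles = positions[:, None] * inv_freq[None, :]   # (num_tokens, head_dim // 2)
    angles = torch.cat([angles, angles], dim=-1)      # (num_tokens, head_dim)
    cos, sin = torch.cos(angles), torch.sin(angles)

    # Causal mask: True above the diagonal marks future tokens to mask out.
    mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)

    x = torch.randn(b, num_tokens, d_in)
    out = gqa(x, mask, cos, sin)
    print(out.shape)  # expected: torch.Size([2, 8, 64])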