from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn

from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, logging

from transformers.models.gemma2.modeling_gemma2 import (
    repeat_kv,
    apply_rotary_pos_emb,
    eager_attention_forward,
)

logger = logging.get_logger(__name__)


class MLAAttention(nn.Module):
    """
    Multi-head Latent Attention (MLA), modified from
    `transformers.models.deepseek_v3.modeling_deepseek_v3.DeepseekV3Attention` to add support for
    attention bias and logit softcapping. A minimal usage sketch follows the class definition.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx

        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.attention_dropout = config.attention_dropout
        self.num_heads = config.num_attention_heads
        self.rope_theta = config.rope_theta
        self.q_lora_rank = config.q_lora_rank
        self.kv_lora_rank = config.kv_lora_rank
        self.qk_rope_head_dim = config.qk_rope_head_dim
        self.qk_nope_head_dim = config.qk_nope_head_dim
        self.v_head_dim = config.v_head_dim
        self.qk_head_dim = config.qk_head_dim
        self.softcap = config.softcap

        self.is_causal = True

        # Query projection: either a single dense projection, or a low-rank down/up
        # projection pair when `q_lora_rank` is set.
        if self.q_lora_rank is None:
            self.q_proj = nn.Linear(config.hidden_size, self.num_heads * self.qk_head_dim, bias=config.attention_bias)
        else:
            self.q_a_proj = nn.Linear(config.hidden_size, config.q_lora_rank, bias=False)
            self.q_b_proj = nn.Linear(config.q_lora_rank, self.num_heads * self.qk_head_dim, bias=config.attention_bias)

        # Keys/values are compressed into a `kv_lora_rank`-dim latent plus one shared
        # `qk_rope_head_dim`-dim rotary key, then expanded per head by `kv_b_proj`.
        self.kv_a_proj_with_mqa = nn.Linear(
            config.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=config.attention_bias,
        )
        self.kv_b_proj = nn.Linear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=config.attention_bias,
        )

        self.o_proj = nn.Linear(
            self.num_heads * self.v_head_dim,
            config.hidden_size,
            bias=False,
        )

        # Gemma2-style scaling based on `query_pre_attn_scalar` rather than the head dim.
        self.scaling = self.config.query_pre_attn_scalar ** (-0.5)
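
    # Forward pass overview:
    #   1. Project hidden states to per-head queries (optionally through the low-rank path).
    #   2. Project hidden states to the compressed KV latent plus a single shared rotary key.
    #   3. Apply rotary embeddings to the rotary sub-dimensions only, then concatenate them
    #      with the non-rotary ("nope") sub-dimensions to form the full queries and keys.
    #   4. Update the KV cache (if any) and dispatch to the configured attention backend.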
    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        batch_size, seq_length = hidden_states.shape[:-1]
        query_shape = (batch_size, seq_length, -1, self.qk_head_dim)
        key_shape = (batch_size, seq_length, -1, self.qk_nope_head_dim + self.v_head_dim)

        if self.q_lora_rank is None:
            q_states = self.q_proj(hidden_states)
        else:
            q_states = self.q_b_proj(self.q_a_proj(hidden_states))
        q_states = q_states.view(query_shape).transpose(1, 2)
        q_pass, q_rot = torch.split(q_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

        compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
        k_pass, k_rot = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)

        # Expand the compressed latent into per-head non-rotary keys and values.
        k_pass = self.kv_b_proj(k_pass).view(key_shape).transpose(1, 2)
        k_pass, value_states = torch.split(k_pass, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # The rotary key has a single head dimension and is broadcast across heads below.
        k_rot = k_rot.view(batch_size, 1, seq_length, self.qk_rope_head_dim)

        cos, sin = position_embeddings
        q_rot, k_rot = apply_rotary_pos_emb(q_rot, k_rot, cos, sin)
        k_rot = k_rot.expand(*k_pass.shape[:-1], -1)

        query_states = torch.cat((q_pass, q_rot), dim=-1)
        key_states = torch.cat((k_pass, k_rot), dim=-1)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
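
        # With flash_attention_2 the value heads are zero-padded from v_head_dim up to
        # qk_head_dim so that all attention inputs share one head dimension; the padded
        # columns are sliced off again right after the attention call.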
        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim])

        attention_interface = eager_attention_forward
        if self.config._attn_implementation != "eager":
            if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
                logger.warning_once(
                    "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                    "Falling back to eager attention. This warning can be removed using the argument "
                    '`attn_implementation="eager"` when loading the model.'
                )
            else:
                attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            softcap=self.softcap,
            **kwargs,
        )

        if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim:
            attn_output = attn_output[:, :, :, : self.v_head_dim]

        attn_output = attn_output.reshape(batch_size, seq_length, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights
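

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the model). The config below is a
# hypothetical stand-in containing only the attributes MLAAttention actually reads;
# real checkpoints provide these fields through their own configuration class.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    config = SimpleNamespace(
        hidden_size=256,
        num_attention_heads=8,
        num_key_value_heads=8,
        attention_dropout=0.0,
        rope_theta=10000.0,
        q_lora_rank=None,               # skip the low-rank query path
        kv_lora_rank=64,
        qk_rope_head_dim=16,
        qk_nope_head_dim=32,
        qk_head_dim=48,                 # qk_nope_head_dim + qk_rope_head_dim
        v_head_dim=32,
        softcap=50.0,
        attention_bias=False,
        query_pre_attn_scalar=48,
        _attn_implementation="eager",
    )

    attn = MLAAttention(config, layer_idx=0)
    batch, seq = 2, 5
    hidden = torch.randn(batch, seq, config.hidden_size)
    # Dummy rotary tables over the rotary sub-dimension; cos=1 / sin=0 leaves q_rot/k_rot unchanged.
    cos = torch.ones(batch, seq, config.qk_rope_head_dim)
    sin = torch.zeros(batch, seq, config.qk_rope_head_dim)

    # No attention mask here, so the eager path computes full (non-causal) attention,
    # which is enough for a quick shape check.
    out, weights = attn(hidden, position_embeddings=(cos, sin), attention_mask=None)
    print(out.shape)      # torch.Size([2, 5, 256])
    print(weights.shape)  # torch.Size([2, 8, 5, 5])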