from torch import nn

from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer,
    LlamaForCausalLM,
    LlamaModel,
    LlamaPreTrainedModel,
)

from .configuration_llamamla import LlamaMLAConfig
from .mla import MLAAttention


class LlamaMLADecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer with the standard self-attention swapped for MLA."""

    def __init__(self, config: LlamaMLAConfig, layer_idx: int):
        super().__init__(config, layer_idx)
        self.self_attn = MLAAttention(config, layer_idx)


class LlamaMLAPreTrainedModel(LlamaPreTrainedModel):
    config_class = LlamaMLAConfig
    _no_split_modules = ["LlamaMLADecoderLayer"]


class LlamaMLAModel(LlamaMLAPreTrainedModel, LlamaModel):
    def __init__(self, config: LlamaMLAConfig):
        super().__init__(config)
        # Replace the stock Llama decoder layers built by LlamaModel.__init__
        # with MLA-equipped decoder layers.
        self.layers = nn.ModuleList(
            [LlamaMLADecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )


class LlamaMLAForCausalLM(LlamaMLAPreTrainedModel, LlamaForCausalLM):
    def __init__(self, config: LlamaMLAConfig):
        super().__init__(config)
        self.model = LlamaMLAModel(config)


__all__ = [
    "LlamaMLAForCausalLM",
    "LlamaMLAModel",
    "LlamaMLAPreTrainedModel",
]
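

if __name__ == "__main__":
    # Minimal smoke-test sketch, illustrative only. It assumes LlamaMLAConfig
    # accepts the standard LlamaConfig keyword arguments shown below and that
    # any MLA-specific fields have usable defaults; adjust to your configuration
    # if MLA-specific parameters are required.
    import torch

    config = LlamaMLAConfig(
        vocab_size=1000,
        hidden_size=256,
        intermediate_size=512,
        num_hidden_layers=2,
        num_attention_heads=4,
    )
    model = LlamaMLAForCausalLM(config)

    # Run a tiny forward pass to check that the MLA decoder layers wire up.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    with torch.no_grad():
        outputs = model(input_ids=input_ids)
    print(outputs.logits.shape)  # expected: torch.Size([1, 16, 1000])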