from transformers.models.llama.configuration_llama import LlamaConfig


class LlamaMLAConfig(LlamaConfig):
    """Llama configuration extended with Multi-head Latent Attention (MLA) parameters."""

    model_type = "llamamla"

    def __init__(
        self,
        *args,
        kv_lora_rank=512,
        q_lora_rank=None,
        qk_rope_head_dim=64,
        qk_nope_head_dim=128,
        v_head_dim=128,
        query_pre_attn_scalar=128,
        softcap=None,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        # Rank of the compressed (latent) KV projection.
        self.kv_lora_rank = kv_lora_rank
        # Rank of the low-rank query projection; None disables query compression.
        self.q_lora_rank = q_lora_rank
        # Per-head dimensions of the RoPE and non-RoPE query/key components.
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim
        # Total query/key head dimension is the sum of the two components.
        self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
        self.v_head_dim = v_head_dim
        # Divisor used to scale queries before the attention softmax.
        self.query_pre_attn_scalar = query_pre_attn_scalar
        # Optional attention-logit soft-capping value; disabled when None.
        self.softcap = softcap
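

# A minimal usage sketch (an assumption about how this config is meant to be
# consumed, not part of the original file): registering the custom model_type
# with AutoConfig lets it be resolved by name.
#
#     from transformers import AutoConfig
#     AutoConfig.register("llamamla", LlamaMLAConfig)
#     config = LlamaMLAConfig(kv_lora_rank=512, q_lora_rank=None)
#     print(config.qk_head_dim)  # 192 = qk_rope_head_dim + qk_nope_head_dim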