Update modeling_rwkv6qwen2.py
modeling_rwkv6qwen2.py CHANGED (+19 -66)

@@ -423,7 +423,7 @@ class RWKV6Attention(nn.Module):
 
         # dealing with left-padding
         if attention_mask is not None:
-            v = v * attention_mask[:,
+            v = v * attention_mask[:, -v.shape[-2]:, None]
 
         r = r.view(B,T,-1,N).to(v.dtype)
         k = k.view(B,T,-1,N).to(v.dtype)
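
The left-padding fix above multiplies `v` by the trailing slice of the attention mask so that padded positions contribute nothing to the recurrent state. A minimal broadcasting sketch, with illustrative shapes that are not taken from the model:

import torch

# Sketch only: a [B, T] left-padding mask broadcast over a [B, T, D] value tensor
# zeroes the padded timesteps, so they cannot leak into the recurrent state.
B, T, D = 2, 4, 8
v = torch.randn(B, T, D)
attention_mask = torch.tensor([[0, 0, 1, 1],   # first sequence is left-padded by 2 tokens
                               [1, 1, 1, 1]])  # second sequence has no padding

# Take the last T mask positions (as in the diff) and add a feature axis for broadcasting.
v = v * attention_mask[:, -v.shape[-2]:, None]
assert torch.all(v[0, :2] == 0)  # the padded timesteps now carry zero values
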
@@ -436,9 +436,6 @@ class RWKV6Attention(nn.Module):
         output_final_state = not self.training and use_cache and past_key_values is not None
         attn_output, output_kv_state = fused_recurrent_gla(r, k, v, log_w, None, scale, input_kv_state, output_final_state)
 
-        if output_final_state:
-            past_key_values.update(output_kv_state, output_shift_state, T, self.layer_idx)
-
         attn_output = attn_output.view(B, T, -1)
         if self.config.groupnorm_att:
             attn_output = self.ln_x(attn_output.view(B * T, -1)).view(B, T, -1)

@@ -446,6 +443,9 @@ class RWKV6Attention(nn.Module):
         attn_output = attn_output * g
         attn_output = self.o_proj(attn_output)
 
+        if output_final_state:
+            past_key_values.update(output_kv_state, output_shift_state, self.layer_idx, T)
+
         return attn_output, attn_weights
 
 class RWKV6Qwen2DecoderLayer(Qwen2DecoderLayer):
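
These two hunks move the recurrent-state write-back to the very end of forward, after the output projection, and swap the last two arguments: the call is now update(output_kv_state, output_shift_state, self.layer_idx, T) instead of update(..., T, self.layer_idx). The real RWKV6State class is not shown in this diff; the toy container below is only a hypothetical illustration of the argument order the new call site expects (names and bookkeeping are assumptions):

from typing import Dict, Tuple
import torch

class ToyRecurrentState:
    """Hypothetical stand-in for a layer-indexed recurrent cache."""
    def __init__(self):
        self.layers: Dict[int, Tuple[torch.Tensor, torch.Tensor]] = {}
        self.seen_tokens = 0

    def update(self, kv_state: torch.Tensor, shift_state: torch.Tensor,
               layer_idx: int, token_count: int) -> None:
        # Store this layer's recurrent state; count new tokens once (on layer 0).
        self.layers[layer_idx] = (kv_state, shift_state)
        if layer_idx == 0:
            self.seen_tokens += token_count

    def get_seq_length(self) -> int:
        return self.seen_tokens
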
@@ -680,36 +680,23 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
-        if self.gradient_checkpointing and self.training:
-            if use_cache:
-                logger.warning_once(
-                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
-                )
-                use_cache = False
-
-        # kept for BC (non `Cache` `past_key_values` inputs)
-        #return_legacy_cache = False
-        if use_cache and not isinstance(past_key_values, RWKV6State):
-            #return_legacy_cache = True
-            past_key_values = RWKV6State()
-            # if past_key_values is None:
-            #     past_key_values = DynamicCache()
-            # else:
-            #     past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            #     logger.warning_once(
-            #         "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and "
-            #         "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class "
-            #         "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)"
-            #     )
+        if self.gradient_checkpointing and self.training and use_cache:
+            logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+            )
+            use_cache = False
 
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
-        if cache_position is None:
-            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-            cache_position = torch.arange(
-                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
-            )
+        if use_cache and not isinstance(past_key_values, RWKV6State):
+            past_key_values = RWKV6State()
+
+        #if cache_position is None:
+        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+        cache_position = torch.arange(
+            past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
+        )
 
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
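
With the guard commented out, cache_position is now always recomputed from how many tokens the cache has already absorbed. An illustration of the arithmetic in plain PyTorch:

import torch

# Illustration only: derive absolute positions for a new chunk of tokens, given
# a cache that reports how many tokens it has already seen.
past_seen_tokens = 10          # e.g. past_key_values.get_seq_length()
new_chunk_len = 4              # inputs_embeds.shape[1]
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + new_chunk_len)
print(cache_position)          # tensor([10, 11, 12, 13])
position_ids = cache_position.unsqueeze(0)  # add a batch dimension, as the model does
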
@@ -723,9 +710,10 @@ class RWKV6Qwen2Model(RWKV6Qwen2PreTrainedModel):
         hidden_states = inputs_embeds
 
         # create position embeddings to be shared across the decoder layers
-        position_embeddings = None
         if self.config.use_rope:
             position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        else:
+            position_embeddings = None
 
         # decoder layers
         all_hidden_states = () if output_hidden_states else None

@@ -902,41 +890,6 @@ class RWKV6Qwen2ForCausalLM(RWKV6Qwen2PreTrainedModel, GenerationMixin):
             attentions=outputs.attentions,
         )
 
-    def prepare_inputs_for_generation(
-        self,
-        input_ids: torch.LongTensor,
-        past_key_values: Optional[Cache] = None,
-        attention_mask: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        cache_position: Optional[torch.LongTensor] = None,
-        **kwargs,
-    ):
-        # only last token for `inputs_ids` if the `past_key_values` is not empty.
-        if past_key_values is not None and len(past_key_values) > 0:
-            input_ids = input_ids[:, -1:]
-
-        model_inputs = {
-            'past_key_values': past_key_values,
-            'attention_mask': attention_mask,
-            'cache_position': cache_position,
-        }
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and past_key_values is None:
-            model_inputs['inputs_embeds'] = inputs_embeds
-        else:
-            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
-            # recompiles graphs as the stride of the inputs is a guard.
-            # Ref: https://github.com/huggingface/transformers/pull/29114
-            # TODO: use `next_tokens` directly instead.
-            model_inputs['input_ids'] = input_ids.contiguous()
-
-        model_inputs.update(**kwargs)
-
-        # 8. Remove unexpected `generate` inputs (TODO @joao: fix trainer and examples)
-        model_inputs.pop("labels", None)
-
-        return model_inputs
-
 @add_start_docstrings(
     """
     The RWKV6Qwen2 Model transformer with a sequence classification head on top (linear layer).
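
This last hunk deletes the hand-rolled prepare_inputs_for_generation, so RWKV6Qwen2ForCausalLM now relies on the standard GenerationMixin path. A usage sketch from the caller's side (the checkpoint path is a placeholder):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/rwkv6qwen2-checkpoint")
model = AutoModelForCausalLM.from_pretrained("path/to/rwkv6qwen2-checkpoint", trust_remote_code=True)

inputs = tokenizer("The capital of France is", return_tensors="pt")
output_ids = model.generate(**inputs, max_new_tokens=16, use_cache=True)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))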