fix sliding window merging
modeling_phi4flash.py (+7 -7)
@@ -573,7 +573,7 @@ class SambaYFlashAttention2(SambaYAttention):
         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

-        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx]
+        use_sliding_windows = self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention"

         if past_key_value is not None:
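The flag above now checks the per-layer attention type explicitly: a layer only uses the flash-attention sliding window when `sliding_window` is set and its entry in `config.layer_types` is exactly "sliding_attention". A minimal sketch of that per-layer decision, using an illustrative stand-in config rather than the real SambaY config class:

# Minimal sketch (not the model code): deciding per layer whether the
# sliding window applies. DummyConfig is an illustrative stand-in for the
# real model config, which carries the same two fields.
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class DummyConfig:
    sliding_window: Optional[int] = 2048
    layer_types: Tuple[str, ...] = ("sliding_attention", "full_attention")

config = DummyConfig()
for layer_idx in range(len(config.layer_types)):
    use_sliding_windows = (
        config.sliding_window is not None
        and config.layer_types[layer_idx] == "sliding_attention"
    )
    print(layer_idx, use_sliding_windows)
# 0 True   (windowed attention)
# 1 False  (full attention)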
@@ -710,8 +710,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.
-                    self.config.
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )

@@ -735,8 +735,8 @@ class SambaYFlashAttention2(SambaYAttention):
                 softmax_scale=softmax_scale,
                 causal=causal,
                 window_size=(
-                    self.config.
-                    self.config.
+                    self.config.sliding_window -1,
+                    self.config.sliding_window -1,
                 ),
             )
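Both call sites pass the window to flash-attn's window_size=(left, right) argument, which (as I read the flash-attn semantics) lets position i attend to keys in roughly [i - left, i + right], with the current token always included. Under that reading, a window meant to span `sliding_window` tokens including the current one needs a left bound of `sliding_window - 1`, which is what the change passes at both call sites. A small mask-based sketch of that arithmetic in plain PyTorch; it reproduces the windowing rule, not the flash-attn kernel itself:

# Sketch of the windowing arithmetic in plain PyTorch (not the flash-attn kernel).
# With causal attention plus a left window of (sliding_window - 1), each query
# sees at most `sliding_window` keys: itself and sliding_window - 1 predecessors.
import torch

def sliding_causal_mask(seq_len: int, sliding_window: int) -> torch.Tensor:
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    causal = j <= i
    within_window = (i - j) <= (sliding_window - 1)  # mirrors window_size[0] in the diff
    return causal & within_window

mask = sliding_causal_mask(seq_len=6, sliding_window=3)
print(mask.int())
# Every row has at most 3 ones, matching
# window_size=(sliding_window - 1, sliding_window - 1) with causal=True.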
@@ -1085,9 +1085,9 @@ class SambaYDecoderLayer(nn.Module):
             residual = residual.to(torch.float32)
             self_attn_weights = None
         else:
-            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx]
+            if self.config.sliding_window is not None and self.config.layer_types[self.layer_idx] == "sliding_attention" and attention_mask is not None: # efficient SDPA and no padding
                 if past_key_value is not None and cache_position[0] > 0: # when decoding
-                    attention_mask = attention_mask[:, -self.config.
+                    attention_mask = attention_mask[:, -self.config.sliding_window:]
             #hidden_states = self.input_layernorm2(hidden_states.to(dtype=self.input_layernorm2.weight.dtype))
             # Self Attention
             attn_outputs, self_attn_weights, yoco_key_values = self.attn(
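The decoder-layer change guards the mask trimming with the same layer-type check: during decoding, a sliding-window layer's key/value cache holds at most `sliding_window` positions, so a 2D padding mask built over everything seen so far has to be cut down to its last `sliding_window` columns to line up with the cached keys. A sketch of that trimming with illustrative shapes only; the real code works on the config, cache and `cache_position` objects shown in the diff:

# Sketch of the mask trimming during decoding (illustrative shapes only).
# With a sliding-window cache, at most `sliding_window` key positions survive,
# so the per-batch padding mask must be cut down to match before attention.
import torch

sliding_window = 4
seen_tokens = 10  # stands in for cache_position[0] > 0, i.e. past the prefill step

# 2D padding mask over everything seen so far: (batch, seen_tokens)
attention_mask = torch.ones(2, seen_tokens, dtype=torch.long)
attention_mask[1, :3] = 0  # second sequence has 3 padding tokens at the front

if seen_tokens > 0:  # decoding: keep only the columns that still have cached keys
    attention_mask = attention_mask[:, -sliding_window:]

print(attention_mask.shape)  # torch.Size([2, 4]), one column per cached key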