Update modeling_motif.py (#6)
Browse files- Update modeling_motif.py (c5875d85cfee6820be55525739a8aec458661613)
- modeling_motif.py +2 -2
modeling_motif.py
CHANGED
|
@@ -399,7 +399,7 @@ class MotifAttention(nn.Module):
|
|
| 399 |
"removed and `position_embeddings` will be mandatory.")
|
| 400 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 401 |
else:
|
| 402 |
-
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(
|
| 403 |
if use_cache else position_embeddings)
|
| 404 |
|
| 405 |
query_states, key_states = apply_rotary_pos_emb(query_states,
|
|
@@ -534,7 +534,7 @@ class MotifFlashAttention2(MotifAttention):
|
|
| 534 |
"removed and `position_embeddings` will be mandatory.")
|
| 535 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 536 |
else:
|
| 537 |
-
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(
|
| 538 |
if use_cache else position_embeddings)
|
| 539 |
|
| 540 |
query_states, key_states = apply_rotary_pos_emb(query_states,
|
|
|
|
| 399 |
"removed and `position_embeddings` will be mandatory.")
|
| 400 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 401 |
else:
|
| 402 |
+
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(self.layer_idx))
|
| 403 |
if use_cache else position_embeddings)
|
| 404 |
|
| 405 |
query_states, key_states = apply_rotary_pos_emb(query_states,
|
|
|
|
| 534 |
"removed and `position_embeddings` will be mandatory.")
|
| 535 |
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
| 536 |
else:
|
| 537 |
+
cos, sin = (self.rotary_emb(value_states, q_len + past_key_value.get_seq_length(self.layer_idx))
|
| 538 |
if use_cache else position_embeddings)
|
| 539 |
|
| 540 |
query_states, key_states = apply_rotary_pos_emb(query_states,
|