fix_rope_scaling
#28
by
haipingwu
- opened
- modeling_phi3_v.py +8 -1
modeling_phi3_v.py
CHANGED
|
@@ -441,7 +441,7 @@ class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding):
|
|
| 441 |
|
| 442 |
@torch.no_grad()
|
| 443 |
def forward(self, x, position_ids, seq_len=None):
|
| 444 |
-
seq_len = torch.max(position_ids) + 1
|
| 445 |
if seq_len > self.original_max_position_embeddings:
|
| 446 |
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
| 447 |
else:
|
|
@@ -1647,6 +1647,13 @@ class Phi3VForCausalLM(Phi3VPreTrainedModel):
|
|
| 1647 |
def prepare_inputs_for_generation(
|
| 1648 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
|
| 1649 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1650 |
if past_key_values is not None:
|
| 1651 |
if isinstance(past_key_values, Cache):
|
| 1652 |
cache_length = past_key_values.get_seq_length()
|
|
|
|
| 441 |
|
| 442 |
@torch.no_grad()
|
| 443 |
def forward(self, x, position_ids, seq_len=None):
|
| 444 |
+
seq_len = seq_len or torch.max(position_ids) + 1
|
| 445 |
if seq_len > self.original_max_position_embeddings:
|
| 446 |
ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device)
|
| 447 |
else:
|
|
|
|
| 1647 |
def prepare_inputs_for_generation(
|
| 1648 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, pixel_values=None, image_sizes=None, **kwargs
|
| 1649 |
):
|
| 1650 |
+
# When the first time input length reached long and short factor switching point, enforce re-compute cache
|
| 1651 |
+
# It will cause downside of slower at this single token position, however, better than current failure.
|
| 1652 |
+
if past_key_values and self.config.rope_scaling and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1:
|
| 1653 |
+
past_length = past_key_values.seen_tokens if isinstance(past_key_values, Cache) else past_key_values[0][0].shape[2]
|
| 1654 |
+
if past_length <= self.config.original_max_position_embeddings:
|
| 1655 |
+
past_key_values = None
|
| 1656 |
+
|
| 1657 |
if past_key_values is not None:
|
| 1658 |
if isinstance(past_key_values, Cache):
|
| 1659 |
cache_length = past_key_values.get_seq_length()
|