Update modeling_qwen.py
#5
by ctranslate2-4you
- opened
- modeling_qwen.py +55 -15
modeling_qwen.py
CHANGED
@@ -17,6 +17,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+# includes edits by https://github.com/BBC-Esq to fix cache errors following transformers version post 4.53.3 major cache refactor
+
 """ PyTorch Qwen2 model."""
 from transformers import Qwen2Config
 import inspect
@@ -274,7 +277,9 @@ class Qwen2Attention(nn.Module):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
@@ -378,7 +383,9 @@ class Qwen2FlashAttention2(Qwen2Attention):
                     "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
                     "with a layer index."
                 )
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
         rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1
@@ -676,7 +683,9 @@ class Qwen2SdpaAttention(Qwen2Attention):
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
-            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            # kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+            past_len = past_key_value.get_seq_length(self.layer_idx) if past_key_value is not None else 0
+            kv_seq_len += past_len
         cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
 
         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
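All three attention hunks above make the same substitution: instead of `get_usable_length`, the already-cached length is read with `Cache.get_seq_length` and added to the length of the incoming keys. A standalone sketch of that arithmetic, not taken from the PR, assuming a recent transformers release and made-up tensor shapes:

import torch
from transformers import DynamicCache

layer_idx = 0
cache = DynamicCache()

# Pretend 5 tokens were already processed for this layer: shape [batch, heads, seq, head_dim]
cache.update(torch.zeros(1, 2, 5, 16), torch.zeros(1, 2, 5, 16), layer_idx)

key_states = torch.zeros(1, 2, 1, 16)  # one new decoding step
kv_seq_len = key_states.shape[-2]      # 1 new position
past_len = cache.get_seq_length(layer_idx) if cache is not None else 0
kv_seq_len += past_len                 # total positions fed to the rotary embedding
print(kv_seq_len)                      # -> 6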
@@ -972,7 +981,6 @@ class Qwen2Model(Qwen2PreTrainedModel):
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         use_cache = use_cache if use_cache is not None else self.config.use_cache
-
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         # retrieve input_ids and inputs_embeds
@@ -993,12 +1001,28 @@ class Qwen2Model(Qwen2PreTrainedModel):
                 use_cache = False
 
         past_key_values_length = 0
+        use_legacy_cache = False
 
         if use_cache:
-            use_legacy_cache = not isinstance(past_key_values, Cache)
-            if use_legacy_cache:
+            # OLD behavior (removed in HF >= 4.55): treat anything not Cache as "legacy" but then
+            # directly used legacy methods on it (would crash if None or new API).
+            # use_legacy_cache = not isinstance(past_key_values, Cache)
+            # if use_legacy_cache:
+            #     # past_key_values_length = past_key_values.get_seq_length()
+            #     past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            # NEW behavior: if a legacy tuple is passed, convert it to the new Cache API,
+            # compute length via .get_seq_length(), and remember to return legacy if that’s what came in.
+            if past_key_values is not None and not isinstance(past_key_values, Cache):
+                use_legacy_cache = True  # remember input format for return
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
-            past_key_values_length = past_key_values.get_usable_length(seq_length)
+
+            if isinstance(past_key_values, Cache):
+                # Layer-agnostic total length; cache_position is handled deeper if needed
+                past_key_values_length = past_key_values.get_seq_length()
+            else:
+                # No cache given on first forward, keep length at 0
+                past_key_values_length = 0
 
         if position_ids is None:
             device = input_ids.device if input_ids is not None else inputs_embeds.device
@@ -1104,7 +1128,10 @@ class Qwen2Model(Qwen2PreTrainedModel):
 
         next_cache = None
         if use_cache:
-            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
+            # If the caller passed legacy, return legacy. Otherwise return the Cache object.
+            next_cache = (
+                next_decoder_cache.to_legacy_cache() if (use_legacy_cache and next_decoder_cache is not None) else next_decoder_cache
+            )
 
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
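The two `Qwen2Model.forward` hunks above (cache setup and cache return) amount to a legacy round trip: convert an incoming tuple to `DynamicCache`, measure it with `get_seq_length()`, and convert back only if a tuple came in. A hedged, self-contained sketch of that flow, with an illustrative one-layer fake cache:

import torch
from transformers import Cache, DynamicCache

# A fake legacy cache: one layer of (key, value) tensors holding 5 cached tokens
legacy = ((torch.zeros(1, 2, 5, 16), torch.zeros(1, 2, 5, 16)),)

use_legacy_cache = False
past_key_values = legacy
if past_key_values is not None and not isinstance(past_key_values, Cache):
    use_legacy_cache = True  # remember the input format for the return value
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

past_key_values_length = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0
print(past_key_values_length)  # -> 5

# ... the attention layers would call past_key_values.update(...) here ...

# Return in the same format the caller used
next_decoder_cache = past_key_values
next_cache = (
    next_decoder_cache.to_legacy_cache()
    if (use_legacy_cache and next_decoder_cache is not None)
    else next_decoder_cache
)
print(type(next_cache))  # -> <class 'tuple'> because a legacy tuple was passed in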
@@ -1116,6 +1143,7 @@ class Qwen2Model(Qwen2PreTrainedModel):
         )
 
 
+
 class Qwen2ForCausalLM(Qwen2PreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -1243,21 +1271,32 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         # Omit tokens covered by past_key_values
         if past_key_values is not None:
             if isinstance(past_key_values, Cache):
+                # NEW API (HF >= 4.55): use Cache methods
                 cache_length = past_key_values.get_seq_length()
-                past_length = past_key_values.seen_tokens
-                max_cache_length = past_key_values.get_max_length()
+                past_length = cache_length  # `seen_tokens` removed; use total seq length instead
+                try:
+                    max_cache_length = past_key_values.get_max_cache_shape()
+                except Exception:
+                    max_cache_length = None
+
+                # OLD API (deprecated/removed):
+                # cache_length = past_key_values.get_seq_length()
+                # past_length = past_key_values.seen_tokens
+                # max_cache_length = past_key_values.get_max_length()
+
             else:
+                # Legacy tuple format: keep computing lengths directly from tensors
+                # (We keep it compatible without forcing a conversion here)
                 cache_length = past_length = past_key_values[0][0].shape[2]
                 max_cache_length = None
 
             # Keep only the unprocessed tokens:
             # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
-            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
-            #     input)
+            #     some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
-            #     input_ids based on the past_length.
+            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens.
+            #     We can discard input_ids based on the past_length.
             elif past_length < input_ids.shape[1]:
                 input_ids = input_ids[:, past_length:]
             # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
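The `prepare_inputs_for_generation` hunk above can be exercised in isolation. A small sketch, assuming a recent transformers release; the try/except mirrors the patch because `get_max_cache_shape()` is not guaranteed to exist on every installed version:

import torch
from transformers import Cache, DynamicCache

past_key_values = DynamicCache.from_legacy_cache(
    ((torch.zeros(1, 2, 7, 16), torch.zeros(1, 2, 7, 16)),)  # 7 cached tokens, 1 layer
)
input_ids = torch.arange(10).unsqueeze(0)  # 10 token ids, 7 of them already cached

if isinstance(past_key_values, Cache):
    cache_length = past_key_values.get_seq_length()
    past_length = cache_length  # `seen_tokens` no longer exists on Cache
    try:
        max_cache_length = past_key_values.get_max_cache_shape()  # typically None for DynamicCache
    except Exception:
        max_cache_length = None
else:
    cache_length = past_length = past_key_values[0][0].shape[2]
    max_cache_length = None

# Keep only the unprocessed tokens (case 2 of the comments above)
if past_length < input_ids.shape[1]:
    input_ids = input_ids[:, past_length:]
print(cache_length, max_cache_length, input_ids.tolist())  # -> 7 None [[7, 8, 9]]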
@@ -1287,13 +1326,14 @@ class Qwen2ForCausalLM(Qwen2PreTrainedModel):
         model_inputs.update(
             {
                 "position_ids": position_ids,
-                "past_key_values": past_key_values,
+                "past_key_values": past_key_values,  # pass through unchanged (legacy or new Cache object)
                 "use_cache": kwargs.get("use_cache"),
                 "attention_mask": attention_mask,
             }
         )
         return model_inputs
 
+
     @staticmethod
     def _reorder_cache(past_key_values, beam_idx):
         reordered_past = ()
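`_reorder_cache`, whose opening lines close the hunk above, is untouched by this PR. For reference only, a standalone sketch of what reordering a legacy tuple cache during beam search looks like, with made-up shapes and beam indices:

import torch

# 3 beams, 1 layer, 5 cached tokens per beam
past_key_values = ((torch.randn(3, 2, 5, 16), torch.randn(3, 2, 5, 16)),)
beam_idx = torch.tensor([2, 0, 0])  # which beam each new hypothesis continues from

reordered_past = ()
for layer_past in past_key_values:
    reordered_past += (
        tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
    )
print(reordered_past[0][0].shape)  # -> torch.Size([3, 2, 5, 16]); batch rows now follow beam_idx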