Update implementation
modeling_chatglm.py  +15 -39
@@ -35,12 +35,12 @@ if sys.platform != 'darwin':
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/
+_CHECKPOINT_FOR_DOC = "THUDM/ChatGLM2-6B"
 _CONFIG_FOR_DOC = "ChatGLM6BConfig"
 
 CHATGLM_6B_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "THUDM/
-    # See all ChatGLM
+    "THUDM/chatglm2-6b",
+    # See all ChatGLM models at https://huggingface.co/models?filter=chatglm
 ]
 
 
@@ -92,7 +92,7 @@ class RotaryEmbedding(nn.Module):
         self.dim = dim
         self.original_impl = original_impl
 
-    def
+    def forward_impl(
             self, seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000
     ):
         """Enhanced Transformer with Rotary Position Embedding.
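
Note: the hunk above only introduces the new forward_impl signature; its body is outside this diff. As a point of reference, below is a minimal, self-contained sketch of how a RoPE cache with that signature is typically built. The function name rope_cache_sketch and every detail of its body are illustrative assumptions, not code from this file.

import torch

def rope_cache_sketch(seq_len: int, n_elem: int,
                      dtype: torch.dtype = torch.float32,
                      device: str = "cpu", base: int = 10000) -> torch.Tensor:
    # One inverse frequency per channel pair, as in the RoFormer formulation.
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=torch.float32, device=device) / n_elem))
    # Outer product of positions and frequencies gives one angle per (position, pair).
    idx_theta = torch.outer(torch.arange(seq_len, dtype=torch.float32, device=device), theta)
    # Last dimension holds (cos, sin); downstream code rotates channel pairs with it.
    cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
    return cache.to(dtype)

cache = rope_cache_sketch(seq_len=8, n_elem=64)   # shape [8, 32, 2]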
@@ -118,14 +118,13 @@ class RotaryEmbedding(nn.Module):
         return cache
 
     def forward(self, max_seq_len, offset=0):
-
-
-
-        )
+        return self.forward_impl(
+            max_seq_len, self.dim, dtype=self.inv_freq.dtype, device=self.inv_freq.device
+        )
 
 
 @torch.jit.script
-def
+def apply_rotary_pos_emb(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
     # x: [sq, b, np, hn]
     sq, b, np, hn = x.size(0), x.size(1), x.size(2), x.size(3)
     rot_dim = rope_cache.shape[-2] * 2
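
Note: this hunk gives apply_rotary_pos_emb its new signature and the first lines of its body (x is [sq, b, np, hn], and rot_dim = rope_cache.shape[-2] * 2 channels get rotated). Below is a hedged, self-contained sketch of how such a cache is usually consumed; apply_rope_sketch is a hypothetical stand-in, not the exact body of the jit-scripted function.

import torch

def apply_rope_sketch(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    # x: [sq, b, np, hn]; rope_cache: [sq (or longer), rot_dim // 2, 2] of (cos, sin).
    sq, b, np_, hn = x.size(0), x.size(1), x.size(2), x.size(3)
    rot_dim = rope_cache.shape[-2] * 2
    # Only the first rot_dim channels are rotated; the rest pass through unchanged.
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    xshaped = x_rot.reshape(sq, b, np_, rot_dim // 2, 2)
    rope = rope_cache[:sq].view(sq, 1, 1, rot_dim // 2, 2)
    # Pairwise 2-D rotation: (x0, x1) -> (x0*cos - x1*sin, x1*cos + x0*sin).
    x_out = torch.stack(
        [xshaped[..., 0] * rope[..., 0] - xshaped[..., 1] * rope[..., 1],
         xshaped[..., 1] * rope[..., 0] + xshaped[..., 0] * rope[..., 1]],
        dim=-1,
    ).flatten(3)
    return torch.cat([x_out, x_pass], dim=-1)

q = torch.randn(8, 1, 4, 64)          # [sq, b, np, hn]
cache = torch.randn(8, 16, 2)         # stand-in cache: rotates the first 32 channels
q_rot = apply_rope_sketch(q, cache)   # -> [8, 1, 4, 64]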
@@ -313,8 +312,6 @@ class SelfAttention(torch.nn.Module):
             device=device, **_config_to_kwargs(config)
         )
 
-        self.interleaved_qkv = config.interleaved_qkv
-
     def _allocate_memory(self, inference_max_sequence_len, batch_size, device=None, dtype=None):
         if self.multi_query_attention:
             num_attention_heads = self.num_multi_query_groups_per_partition
@@ -364,33 +361,18 @@ class SelfAttention(torch.nn.Module):
                 + (self.num_multi_query_groups_per_partition, self.hidden_size_per_attention_head)
             )
         else:
-
-
-
-
-            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
+            new_tensor_shape = mixed_x_layer.size()[:-1] + \
+                               (self.num_attention_heads_per_partition,
+                                3 * self.hidden_size_per_attention_head)
+            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
 
             # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
             (query_layer, key_layer, value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
 
-            if not self.interleaved_qkv:
-                query_layer = query_layer.view(
-                    query_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                key_layer = key_layer.view(
-                    key_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-                value_layer = value_layer.view(
-                    value_layer.size()[:-1] + (
-                        self.num_attention_heads_per_partition, self.hidden_size_per_attention_head)
-                ).contiguous()
-
         # apply relative positional encoding (rotary embedding)
         if rotary_pos_emb is not None:
-            query_layer =
-            key_layer =
+            query_layer = apply_rotary_pos_emb(query_layer, rotary_pos_emb)
+            key_layer = apply_rotary_pos_emb(key_layer, rotary_pos_emb)
 
         # adjust key and value for inference
         if use_cache:
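
Note: the else branch in this hunk now always reshapes the fused QKV activation to [sq, b, np, 3 * hn] and splits it, instead of branching on the removed config.interleaved_qkv flag. The sketch below reproduces that reshape-and-split pattern outside the class, with torch.chunk standing in for the file's split_tensor_along_last_dim helper; sizes and names are illustrative assumptions.

import torch

sq, b, np_, hn = 8, 2, 4, 16                       # illustrative sizes
mixed_x_layer = torch.randn(sq, b, np_ * 3 * hn)   # fused QKV projection output

# View as [sq, b, np, 3 * hn], exactly like the added lines in the hunk.
new_tensor_shape = mixed_x_layer.size()[:-1] + (np_, 3 * hn)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

# [sq, b, np, 3 * hn] --> 3 x [sq, b, np, hn]; torch.chunk stands in for
# the file's split_tensor_along_last_dim helper.
query_layer, key_layer, value_layer = mixed_x_layer.chunk(3, dim=-1)
assert query_layer.shape == (sq, b, np_, hn)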
@@ -713,13 +695,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
             config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
         )
 
-
-        rotary_dim = int(rotary_dim * config.rotary_percent)
-
-        # partial rotary embeddings, which is better than full rotary
-        # Wang and Komatsuzaki et al
-        # https://github.com/kingoflolz/mesh-transformer-jax/
-        self.rotary_pos_emb = RotaryEmbedding(rotary_dim, original_impl=config.original_rope, device=device,
+        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                               dtype=config.torch_dtype)
         self.encoder = init_method(GLMTransformer, config, **init_kwargs)
         self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
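
Note: the constructor now passes rotary_dim // 2 to RotaryEmbedding directly, replacing the removed config.rotary_percent scaling and the partial-rotary comment block. Assuming the usual ChatGLM2-6B configuration values (hidden_size=4096, num_attention_heads=32, kv_channels=128; these are assumptions, not shown in this diff), the arithmetic works out as:

# Head dimension under the assumed config values (not part of this diff):
hidden_size, num_attention_heads, kv_channels = 4096, 32, 128
rotary_dim = hidden_size // num_attention_heads if kv_channels is None else kv_channels
# RotaryEmbedding now receives half of it, i.e. the cache covers half the head channels.
assert rotary_dim // 2 == 64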