Upload 10 files

- LMConfig.py (+4, -5)
- config.json (+3, -3)
- model.py (+21, -8)
LMConfig.py

@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
 
     def __init__(
             self,
-            dim: int =
-            n_layers: int =
+            dim: int = 896,
+            n_layers: int = 24,
             tie_word_embeddings: bool = True,
             ###########################################
             attention:str='GQA',
@@ -27,7 +27,7 @@ class LMConfig(PretrainedConfig):
             hidden_dim: int = None,
             multiple_of: int = 64,
             norm_eps: float = 1e-5,
-            max_seq_len: int =
+            max_seq_len: int = 512,
             rope_theta: int = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
@@ -46,10 +46,9 @@ class LMConfig(PretrainedConfig):
             norm_topk_prob: bool = True,
             **kwargs,
     ):
-        super().__init__(
+        super().__init__(tie_word_embeddings=tie_word_embeddings,**kwargs)
         self.dim = dim
         self.n_layers = n_layers
-        self.tie_word_embeddings = tie_word_embeddings
         self.vocab_size = vocab_size
         self.hidden_dim = hidden_dim
         self.multiple_of = multiple_of
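The key change in LMConfig.py is that tie_word_embeddings is no longer stored by hand on the config; it is forwarded to PretrainedConfig.__init__, so the base class records it, serializes it into config.json, and the model's weight-tying hooks can see it. A minimal usage sketch, assuming the new defaults above; the save directory name is a placeholder:

    from LMConfig import LMConfig

    # Values mirror the new defaults; the output directory is illustrative only.
    config = LMConfig(dim=896, n_layers=24, max_seq_len=512, tie_word_embeddings=True)

    # tie_word_embeddings is now handled by PretrainedConfig, so it is part of
    # the serialized config rather than a manually assigned attribute.
    print(config.tie_word_embeddings)               # True
    config.save_pretrained("./NanoChat-0.3B-base")  # writes config.json including the flag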
config.json

@@ -1,13 +1,13 @@
 {
-  "_name_or_path": "
+  "_name_or_path": "NanoChat-0.3B-base/",
   "architectures": [
     "NanoChatLM"
   ],
   "attention": "GQA",
   "auto_map": {
     "AutoConfig": "LMConfig.LMConfig",
-    "
-    "
+    "AutoModel": "model.NanoChatLM",
+    "AutoModelForCausalLM": "model.NanoChatLM"
   },
   "aux_loss_alpha": 0.1,
   "dim": 896,
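With AutoModel and AutoModelForCausalLM added to auto_map, the Auto classes can resolve the custom NanoChatLM class shipped in model.py, so the checkpoint loads without importing the repo code by hand. A loading sketch, assuming the files sit in a local directory (the path is a placeholder); trust_remote_code=True is required because the architecture lives in repo code rather than in transformers itself:

    from transformers import AutoConfig, AutoModelForCausalLM

    path = "./NanoChat-0.3B-base"  # placeholder: directory holding config.json, LMConfig.py, model.py

    config = AutoConfig.from_pretrained(path, trust_remote_code=True)           # resolves LMConfig.LMConfig
    model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True)  # resolves model.NanoChatLM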
model.py

@@ -397,11 +397,18 @@ class NanoChatLM(PreTrainedModel):
         self.layers = nn.ModuleList([NanoChatBlock(l, params) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-        if params.tie_word_embeddings:
-
+        # if params.tie_word_embeddings:
+        #     self.output.weight = self.tok_embeddings.weight
         self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
                                                            theta=params.rope_theta), persistent=False)
         self.OUT = CausalLMOutputWithPast()
+
+        self.post_init()
+
+    def tie_weights(self):
+        super().tie_weights()
+        if self.params.tie_word_embeddings:
+            self.output.weight = self.tok_embeddings.weight
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
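Rather than assigning the shared weight inline in __init__, the model now overrides tie_weights() and calls self.post_init(). post_init() runs the standard PreTrainedModel initialization, which invokes tie_weights() once the submodules exist, and from_pretrained re-ties after loading, so the embedding and output projections stay shared. A stripped-down sketch of the same pattern with toy module names, not the repo's actual code:

    import torch.nn as nn
    from transformers import PretrainedConfig, PreTrainedModel

    class ToyConfig(PretrainedConfig):
        def __init__(self, vocab_size=100, dim=32, tie_word_embeddings=True, **kwargs):
            super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
            self.vocab_size, self.dim = vocab_size, dim

    class ToyLM(PreTrainedModel):
        config_class = ToyConfig

        def __init__(self, config):
            super().__init__(config)
            self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
            self.output = nn.Linear(config.dim, config.vocab_size, bias=False)
            self.post_init()  # runs weight init and calls tie_weights()

        def tie_weights(self):
            if self.config.tie_word_embeddings:
                self.output.weight = self.tok_embeddings.weight  # share one matrix

    model = ToyLM(ToyConfig())
    assert model.output.weight is model.tok_embeddings.weight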
@@ -429,16 +436,16 @@ class NanoChatLM(PreTrainedModel):
 
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=151643, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, do_sample=True, **args):
         # 流式生成 (streaming generation)
         if stream:
-            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample)
 
         # 直接生成 (direct generation)
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
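Because do_sample is threaded through both paths, a caller can flip between sampling and greedy decoding, and between batched and streaming output, from the public generate() call alone. A usage sketch, assuming a model and a matching tokenizer are already loaded (for example via the Auto classes shown earlier); the prompt text and token budget are placeholders:

    prompt_ids = tokenizer("Hello", return_tensors="pt").input_ids  # tokenizer is assumed to exist

    # Batched path: returns a tensor of completed sequences.
    out = model.generate(prompt_ids, max_new_tokens=64, temperature=0.75, top_p=0.90,
                         do_sample=True)  # do_sample=False switches to greedy decoding
    print(tokenizer.decode(out[0]))

    # Streaming path: returns a generator yielding the tokens produced so far.
    for partial in model.generate(prompt_ids, max_new_tokens=64, stream=True, do_sample=False):
        print(tokenizer.decode(partial[0]), end="\r")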
@@ -452,14 +459,14 @@ class NanoChatLM(PreTrainedModel):
         ]
         return torch.cat(generated, dim=0)
 
-    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
                 out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
             else:
                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-
+                           start_pos=input_ids.shape[1] - 1)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
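The restored start_pos argument matters only on the cached path: once past_key_values holds the keys and values for the prefix, each step feeds just the last token, and start_pos tells the model which absolute position that token occupies so the rotary embeddings stay aligned. A condensed sketch of the same decode step (standalone, not the repo's forward code):

    def decode_step(model, ids, past_kvs, use_cache=True):
        # First step (or no cache): run the whole prefix through the model.
        if past_kvs is None or not use_cache:
            out = model(ids, past_key_values=None, use_cache=use_cache)
        else:
            # Later steps: feed only the newest token, at its absolute offset.
            out = model(ids[:, -1:], past_key_values=past_kvs, use_cache=True,
                        start_pos=ids.shape[1] - 1)
        return out.logits[:, -1, :], out.past_key_values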
@@ -472,7 +479,13 @@ class NanoChatLM(PreTrainedModel):
             sorted_indices_to_remove[:, 0] = False
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             logits[indices_to_remove] = -float('Inf')
-
+
+            if do_sample:
+                input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+            else:
+                # Greedy decoding: choose the token with the highest probability
+                input_ids_next = torch.argmax(F.softmax(logits, dim=-1), dim=-1).unsqueeze(-1)
+
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
             yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
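After the top-p filter has masked out the low-probability tail, the new branch either samples from what remains (do_sample=True) or takes the argmax (greedy). A self-contained sketch of that selection step on a dummy logits tensor; the top-p filtering shown is the usual sorted-cumulative-probability recipe, written independently of the repo's exact code:

    import torch
    import torch.nn.functional as F

    def pick_next_token(logits, top_p=0.90, do_sample=True):
        # logits: (batch, vocab) -> returns (batch, 1) token ids.
        # Nucleus (top-p) filtering: drop tokens outside the smallest set whose
        # cumulative probability exceeds top_p.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_remove = cum_probs > top_p
        sorted_remove[:, 1:] = sorted_remove[:, :-1].clone()  # always keep the top token
        sorted_remove[:, 0] = False
        remove = sorted_remove.scatter(1, sorted_indices, sorted_remove)
        logits = logits.masked_fill(remove, -float('Inf'))

        if do_sample:
            return torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)  # sample the nucleus
        return torch.argmax(logits, dim=-1, keepdim=True)                       # greedy pick

    next_id = pick_next_token(torch.randn(1, 32), top_p=0.90, do_sample=False)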