xueyunlong committed on
Commit 48e0294 · verified · 1 Parent(s): 7fdf68d

Upload 10 files

Files changed (3)
  1. LMConfig.py +4 -5
  2. config.json +3 -3
  3. model.py +21 -8
LMConfig.py CHANGED
@@ -7,8 +7,8 @@ class LMConfig(PretrainedConfig):
 
     def __init__(
             self,
-            dim: int = 512,
-            n_layers: int = 8,
+            dim: int = 896,
+            n_layers: int = 24,
             tie_word_embeddings: bool = True,
             ###########################################
             attention:str='GQA',
@@ -27,7 +27,7 @@ class LMConfig(PretrainedConfig):
             hidden_dim: int = None,
             multiple_of: int = 64,
             norm_eps: float = 1e-5,
-            max_seq_len: int = 8192,
+            max_seq_len: int = 512,
             rope_theta: int = 1e6,
             dropout: float = 0.0,
             flash_attn: bool = True,
@@ -46,10 +46,9 @@ class LMConfig(PretrainedConfig):
             norm_topk_prob: bool = True,
             **kwargs,
     ):
-        super().__init__(**kwargs)
+        super().__init__(tie_word_embeddings=tie_word_embeddings,**kwargs)
         self.dim = dim
         self.n_layers = n_layers
-        self.tie_word_embeddings = tie_word_embeddings
         self.vocab_size = vocab_size
         self.hidden_dim = hidden_dim
         self.multiple_of = multiple_of
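Net effect of the LMConfig.py change: the defaults now describe the 0.3B configuration (dim=896, n_layers=24, max_seq_len=512), and tie_word_embeddings is forwarded to PretrainedConfig.__init__ instead of being stored by hand, so the flag lives where the rest of the Transformers stack expects it. A minimal sketch of instantiating the updated config (the import path assumes LMConfig.py sits next to the calling script, as in this repo):

from LMConfig import LMConfig

config = LMConfig()
# New defaults after this commit.
print(config.dim, config.n_layers, config.max_seq_len)  # 896 24 512
# Set by PretrainedConfig.__init__ via the forwarded keyword, not by a manual assignment.
print(config.tie_word_embeddings)  # True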
config.json CHANGED
@@ -1,13 +1,13 @@
 {
-  "_name_or_path": "out/",
+  "_name_or_path": "NanoChat-0.3B-base/",
   "architectures": [
     "NanoChatLM"
   ],
   "attention": "GQA",
   "auto_map": {
     "AutoConfig": "LMConfig.LMConfig",
-    "AutoModelForCausalLM": "model.NanoChatLM",
-    "AutoModel": "model.NanoChatLM"
+    "AutoModel": "model.NanoChatLM",
+    "AutoModelForCausalLM": "model.NanoChatLM"
   },
   "aux_loss_alpha": 0.1,
   "dim": 896,
model.py CHANGED
@@ -397,11 +397,18 @@ class NanoChatLM(PreTrainedModel):
         self.layers = nn.ModuleList([NanoChatBlock(l, params) for l in range(self.n_layers)])
         self.norm = RMSNorm(params.dim, eps=params.norm_eps)
         self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
-        if params.tie_word_embeddings:
-            self.output.weight = self.tok_embeddings.weight
+        # if params.tie_word_embeddings:
+        #     self.output.weight = self.tok_embeddings.weight
         self.register_buffer("pos_cis", precompute_pos_cis(params.dim // params.n_heads, params.max_seq_len,
                                                            theta=params.rope_theta), persistent=False)
         self.OUT = CausalLMOutputWithPast()
+
+        self.post_init()
+
+    def tie_weights(self):
+        super().tie_weights()
+        if self.params.tie_word_embeddings:
+            self.output.weight = self.tok_embeddings.weight
 
     def forward(self,
                 input_ids: Optional[torch.Tensor] = None,
@@ -429,16 +436,16 @@ class NanoChatLM(PreTrainedModel):
 
     @torch.inference_mode()
     def generate(self, input_ids, eos_token_id=151643, max_new_tokens=1024, temperature=0.75, top_p=0.90,
-                 stream=False, rp=1., use_cache=True, pad_token_id=0, **args):
+                 stream=False, rp=1., use_cache=True, pad_token_id=0, do_sample=True, **args):
         # streaming generation
         if stream:
-            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            return self._generate_stream(input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample)
 
         # direct (non-streaming) generation
         generated = []
         for i in range(input_ids.size(0)):
             non_pad = input_ids[i][input_ids[i] != pad_token_id].unsqueeze(0)
-            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache)
+            out = self._generate_stream(non_pad, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample)
             tokens_list = [tokens[:, -1:] for tokens in out]
             gen = torch.cat(tokens_list, dim=-1) if tokens_list else non_pad
             full_sequence = torch.cat([non_pad, gen], dim=-1)
@@ -452,14 +459,14 @@ class NanoChatLM(PreTrainedModel):
         ]
         return torch.cat(generated, dim=0)
 
-    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, **args):
+    def _generate_stream(self, input_ids, eos_token_id, max_new_tokens, temperature, top_p, rp, use_cache, do_sample, **args):
         start, first_seq, past_kvs = input_ids.shape[1], True, None
         while input_ids.shape[1] < max_new_tokens - 1:
             if first_seq or not use_cache:
                 out, first_seq = self(input_ids, past_key_values=past_kvs, use_cache=use_cache), False
             else:
                 out = self(input_ids[:, -1:], past_key_values=past_kvs, use_cache=use_cache,
-                           start_pos=input_ids.shape[1] - 1)
+                           start_pos=input_ids.shape[1] - 1)
             logits, past_kvs = out.logits[:, -1, :], out.past_key_values
             logits[:, list(set(input_ids.tolist()[0]))] /= rp
             logits /= (temperature + 1e-9)
@@ -472,7 +479,13 @@ class NanoChatLM(PreTrainedModel):
             sorted_indices_to_remove[:, 0] = False
             indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
             logits[indices_to_remove] = -float('Inf')
-            input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+
+            if do_sample:
+                input_ids_next = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
+            else:
+                # Greedy decoding: choose the token with the highest probability
+                input_ids_next = torch.argmax(F.softmax(logits, dim=-1), dim=-1).unsqueeze(-1)
+
             input_ids = torch.cat((input_ids, input_ids_next), dim=1)
             yield input_ids[:, start:]
             if input_ids_next.item() == eos_token_id:
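Taken together, the model.py changes move weight tying out of __init__ into a tie_weights() override, which the newly added self.post_init() call triggers (and which from_pretrained invokes again after loading weights), and add a do_sample switch so _generate_stream can fall back to greedy argmax decoding. A usage sketch, reusing the model object from the loading example above; the prompt ids are placeholders, and eos_token_id follows the default (151643) in the generate signature:

import torch

prompt_ids = torch.tensor([[1, 2, 3]])  # placeholder ids; in practice, encode text with the checkpoint's tokenizer

# Sampled generation (default path): temperature / top-p / repetition penalty apply.
sampled = model.generate(prompt_ids, max_new_tokens=64, temperature=0.75, top_p=0.9, do_sample=True)

# Greedy generation (new path): argmax over the filtered logits.
greedy = model.generate(prompt_ids, max_new_tokens=64, do_sample=False)

# Streaming: returns a generator that yields the tokens produced so far.
for chunk in model.generate(prompt_ids, stream=True, max_new_tokens=64):
    print(chunk[:, -1].item())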