adityastomar committed
Commit 0c28d1a · verified · 1 Parent(s): d699e90

Add support for greedy decoding


The current sampling implementation always draws tokens with `torch.multinomial` and has no greedy path for the deterministic setting (temperature = 0.0, top-k = 0, top-p = 1.0). This PR adds support for greedy decoding in that case.
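For context, a minimal standalone sketch of the two behaviours (hypothetical logits tensor; not the model's actual sampling method):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(2, 32000)  # hypothetical batch of next-token logits
    probs = F.softmax(logits, dim=-1)

    # Current behaviour: always sample stochastically from the distribution.
    sampled = torch.multinomial(probs, num_samples=1)

    # Behaviour added by this PR: deterministic argmax when sampling is disabled.
    greedy = logits.argmax(dim=-1, keepdim=True)
    greedy_prob = probs.gather(-1, greedy)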

Files changed (1): modeling_llada2_moe.py (+8 −0)
modeling_llada2_moe.py CHANGED

@@ -1240,6 +1240,14 @@ class LLaDA2MoeModelLM(LLaDA2MoePreTrainedModel, GenerationMixin):
         orig_shape = logits.shape[:-1]
         vocab_size = logits.shape[-1]
         logits = logits.reshape(-1, vocab_size)
+
+        # Greedy mode: temperature = 0, no top-k/p
+        if temperature == 0.0 and (top_k in (None, 0)) and (top_p is None or top_p >= 1.0):
+            probs = F.softmax(logits, dim=-1)
+            token = logits.argmax(dim=-1, keepdim=True)
+            token_prob = probs.gather(-1, token)
+            return token.view(*orig_shape), token_prob.view(*orig_shape)
+
         if temperature > 0 and temperature != 1.0:
             logits = logits / temperature
         logits = self._top_k_logits(logits, top_k)
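
The added branch can be exercised standalone; a minimal sketch with dummy tensors (the enclosing method's name and signature are not visible in the hunk, so `greedy_pick` here is hypothetical):

    import torch
    import torch.nn.functional as F

    def greedy_pick(logits):
        # Re-implements the added greedy branch: argmax token and its probability.
        orig_shape = logits.shape[:-1]
        logits = logits.reshape(-1, logits.shape[-1])
        probs = F.softmax(logits, dim=-1)
        token = logits.argmax(dim=-1, keepdim=True)
        token_prob = probs.gather(-1, token)
        return token.view(*orig_shape), token_prob.view(*orig_shape)

    tokens, token_probs = greedy_pick(torch.randn(2, 5, 32000))
    assert tokens.shape == (2, 5) and token_probs.shape == (2, 5)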