Refactor Logits Naming (#15)
Browse files- Refactor Logits Naming (5aed2aa1a7dabed45b57dbde209de31cef94b39f)
Co-authored-by: Xu <[email protected]>
modeling_moonshot_kimia.py
CHANGED
|
@@ -901,15 +901,15 @@ class MoonshotKimiaForCausalLM(Qwen2PreTrainedModel):
|
|
| 901 |
else:
|
| 902 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
| 903 |
|
| 904 |
-
|
| 905 |
-
|
| 906 |
|
| 907 |
if not return_dict:
|
| 908 |
-
output = (
|
| 909 |
return output
|
| 910 |
return CausalLMOutputWithPast(
|
| 911 |
loss=None,
|
| 912 |
-
logits=(
|
| 913 |
past_key_values=outputs.past_key_values,
|
| 914 |
hidden_states=outputs.hidden_states,
|
| 915 |
attentions=outputs.attentions,
|
|
|
|
| 901 |
else:
|
| 902 |
hidden_states, mimo_hidden_states = outputs[0], outputs[1]
|
| 903 |
|
| 904 |
+
text_logits = self.lm_head(hidden_states)
|
| 905 |
+
audio_logits = self.mimo_output(mimo_hidden_states)
|
| 906 |
|
| 907 |
if not return_dict:
|
| 908 |
+
output = (audio_logits, text_logits) + outputs[2:]
|
| 909 |
return output
|
| 910 |
return CausalLMOutputWithPast(
|
| 911 |
loss=None,
|
| 912 |
+
logits=(audio_logits, text_logits),
|
| 913 |
past_key_values=outputs.past_key_values,
|
| 914 |
hidden_states=outputs.hidden_states,
|
| 915 |
attentions=outputs.attentions,
|