duzx16 committed · c57e892
1 parent: fc442f7

Fix prefix prompt in evaluation

Files changed (1): modeling_chatglm.py (+8, -5)
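To evaluate against exactly this state of the code, the commit hash above can be passed as the revision when loading the model with transformers. The repository id below is an assumption (this page does not name the Hub repo), so substitute the repository that actually hosts this modeling_chatglm.py:

from transformers import AutoModel

# Hypothetical repo id -- replace with the actual Hub repository hosting this file.
REPO_ID = "THUDM/chatglm2-6b"

# Pin the download to this commit so evaluation picks up the prefix-prompt fix.
model = AutoModel.from_pretrained(
    REPO_ID,
    trust_remote_code=True,   # modeling_chatglm.py is custom code shipped in the repo
    revision="c57e892",       # commit shown above
)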
modeling_chatglm.py CHANGED
@@ -803,6 +803,14 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)
 
+        if self.pre_seq_len is not None:
+            if past_key_values is None:
+                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
+                                                  dtype=inputs_embeds.dtype)
+            if attention_mask is not None:
+                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
+                                            attention_mask], dim=-1)
+
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                 full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
@@ -815,11 +823,6 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         rotary_pos_emb = rotary_pos_emb[None, :seq_length]
         rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
 
-        if past_key_values is None:
-            if self.pre_seq_len is not None:
-                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
-                                                  dtype=inputs_embeds.dtype)
-
         # Run encoder.
         hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
             inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
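For reference, the added block prepends ones to the padding mask so the learned prefix positions produced by get_prompt are never treated as padding when get_masks later builds the full attention mask. A minimal sketch of that mask extension as a standalone helper (the function name and example shapes are illustrative, not part of modeling_chatglm.py):

import torch

def extend_mask_for_prefix(attention_mask: torch.Tensor, pre_seq_len: int) -> torch.Tensor:
    # Prefix positions are real tokens (never padding), so they get mask value 1.
    batch_size = attention_mask.size(0)
    prefix_mask = attention_mask.new_ones((batch_size, pre_seq_len))
    # Prepend along the sequence dimension, matching the torch.cat call in the diff above.
    return torch.cat([prefix_mask, attention_mask], dim=-1)

# Example: batch of 2, sequence length 4 with one padded position, prefix length 3.
mask = torch.tensor([[1, 1, 1, 1],
                     [1, 1, 1, 0]])
print(extend_mask_for_prefix(mask, pre_seq_len=3))
# tensor([[1, 1, 1, 1, 1, 1, 1],
#         [1, 1, 1, 1, 1, 1, 0]])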