Fix: AttributeError when `input_ids` is None during multimodal LLM training
#77
by lyulumos · opened
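Context for the fix: during multimodal training the language model is typically driven by pre-built embeddings (projected image features concatenated with token embeddings), so the forward pass receives `inputs_embeds` and an attention mask while `input_ids` is legitimately `None`. The unpatched code still evaluates `input_ids.shape`, which is the AttributeError in the title. Below is a minimal, self-contained sketch of the failure and of the fallback this patch applies; it uses toy tensors only and is not tied to any specific checkpoint.

```python
import torch

# Multimodal path: the caller supplies embeddings directly, so there are no token ids.
input_ids = None
inputs_embeds = torch.randn(2, 7, 4096)            # [batch, seq, hidden], toy sizes
attention_mask = torch.ones(2, 7, dtype=torch.long)

try:
    # What the unpatched code does unconditionally:
    batch_size, seq_length = input_ids.shape
except AttributeError as e:
    print(e)  # 'NoneType' object has no attribute 'shape'

# The patch falls back to the mask / embeddings instead:
batch_size, seq_length = input_ids.shape if input_ids is not None else attention_mask.shape
print(batch_size, seq_length)  # 2 7
```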
modeling_chatglm.py  CHANGED  (+5 -4)

@@ -771,15 +771,16 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
             if padding_mask is not None and not padding_mask.all():
                 return padding_mask
             return None
-        batch_size, seq_length = input_ids.shape
-        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=input_ids.device)
+        batch_size, seq_length = input_ids.shape if input_ids is not None else padding_mask.shape
+        device = input_ids.device if input_ids is not None else padding_mask.device
+        full_attention_mask = torch.ones(batch_size, seq_length, seq_length, device=device)
         full_attention_mask.tril_()
         past_length = 0
         if past_key_values:
             past_length = past_key_values[0][0].shape[2]
         if past_length:
             full_attention_mask = torch.cat((torch.ones(batch_size, seq_length, past_length,
-                                                         device=input_ids.device), full_attention_mask), dim=-1)
+                                                         device=device), full_attention_mask), dim=-1)
         if padding_mask is not None:
             full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1)
         if not past_length and padding_mask is not None:
@@ -872,7 +873,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-        batch_size, seq_length = input_ids.shape
+        batch_size, seq_length = (input_ids.shape if input_ids is not None else inputs_embeds.shape[:2] if inputs_embeds is not None else (None, None))

         if inputs_embeds is None:
             inputs_embeds = self.embedding(input_ids)