TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
I am getting the error below while training a model for fine-tuning with a QLoRA configuration.
TypeError                                 Traceback (most recent call last)
Cell In[34], line 3
      1 import time
      2 start = time.time()
----> 3 trainer.train()
      4 print(time.time()- start)
File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2171, in Trainer.train(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)
   2169         hf_hub_utils.enable_progress_bars()
   2170 else:
-> 2171     return inner_training_loop(
   2172         args=args,
   2173         resume_from_checkpoint=resume_from_checkpoint,
   2174         trial=trial,
   2175         ignore_keys_for_eval=ignore_keys_for_eval,
   2176     )
File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:2531, in Trainer._inner_training_loop(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)
   2524 context = (
   2525     functools.partial(self.accelerator.no_sync, model=model)
   2526     if i != len(batch_samples) - 1
   2527     and self.accelerator.distributed_type != DistributedType.DEEPSPEED
   2528     else contextlib.nullcontext
   2529 )
   2530 with context():
-> 2531     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
   2533 if (
   2534     args.logging_nan_inf_filter
   2535     and not is_torch_xla_available()
   2536     and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
   2537 ):
   2538     # if loss is nan or inf simply add the average of previous logged losses
   2539     tr_loss = tr_loss + tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3675, in Trainer.training_step(self, model, inputs, num_items_in_batch)
   3672     return loss_mb.reduce_mean().detach().to(self.args.device)
   3674 with self.compute_loss_context_manager():
-> 3675     loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
   3677 del inputs
   3678 if (
   3679     self.args.torch_empty_cache_steps is not None
   3680     and self.state.global_step % self.args.torch_empty_cache_steps == 0
   3681 ):
File /usr/local/lib/python3.10/dist-packages/transformers/trainer.py:3731, in Trainer.compute_loss(self, model, inputs, return_outputs, num_items_in_batch)
   3729         loss_kwargs["num_items_in_batch"] = num_items_in_batch
   3730     inputs = {**inputs, **loss_kwargs}
-> 3731 outputs = model(**inputs)
   3732 # Save past state if it exists
   3733 # TODO: this needs to be fixed and made cleaner later.
   3734 if self.args.past_index >= 0:
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()
File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:819, in convert_outputs_to_fp32.<locals>.forward(*args, **kwargs)
    818 def forward(*args, **kwargs):
--> 819     return model_forward(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py:807, in ConvertOutputsToFp32.__call__(self, *args, **kwargs)
     806 def __call__(self, *args, **kwargs):
--> 807     return convert_to_fp32(self.model_forward(*args, **kwargs))
File /usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py:44, in autocast_decorator.<locals>.decorate_autocast(*args, **kwargs)
     41 @functools.wraps(func)
     42 def decorate_autocast(*args, **kwargs):
     43     with autocast_instance:
---> 44         return func(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/peft/peft_model.py:1719, in PeftModelForCausalLM.forward(self, input_ids, attention_mask, inputs_embeds, labels, output_attentions, output_hidden_states, return_dict, task_ids, **kwargs)
   1717     with self._enable_peft_forward_hooks(**kwargs):
   1718         kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args}
-> 1719         return self.base_model(
   1720             input_ids=input_ids,
   1721             attention_mask=attention_mask,
   1722             inputs_embeds=inputs_embeds,
   1723             labels=labels,
   1724             output_attentions=output_attentions,
   1725             output_hidden_states=output_hidden_states,
   1726             return_dict=return_dict,
   1727             **kwargs,
   1728         )
   1730 batch_size = _get_batch_size(input_ids, inputs_embeds)
   1731 if attention_mask is not None:
   1732     # concat prompt attention mask
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()
File /usr/local/lib/python3.10/dist-packages/peft/tuners/tuners_utils.py:197, in BaseTuner.forward(self, *args, **kwargs)
    196 def forward(self, *args: Any, **kwargs: Any):
--> 197     return self.model.forward(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/transformers/models/gemma/modeling_gemma.py:832, in GemmaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict, cache_position, num_logits_to_keep, **kwargs)
    829 return_dict = return_dict if return_dict is not None else self.config.use_return_dict
    831 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
--> 832 outputs = self.model(
    833     input_ids=input_ids,
    834     attention_mask=attention_mask,
    835     position_ids=position_ids,
    836     past_key_values=past_key_values,
    837     inputs_embeds=inputs_embeds,
    838     use_cache=use_cache,
    839     output_attentions=output_attentions,
    840     output_hidden_states=output_hidden_states,
    841     return_dict=return_dict,
    842     cache_position=cache_position,
    843     **kwargs,
    844 )
    846 hidden_states = outputs[0]
    847 # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1739, in Module._wrapped_call_impl(self, *args, **kwargs)
   1737     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1738 else:
-> 1739     return self._call_impl(*args, **kwargs)
File /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1750, in Module._call_impl(self, *args, **kwargs)
   1745 # If we don't have any hooks, we want to skip the rest of the logic in
   1746 # this function, and just call forward.
   1747 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1748         or _global_backward_pre_hooks or _global_backward_hooks
   1749         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1750     return forward_call(*args, **kwargs)
   1752 result = None
   1753 called_always_called_hooks = set()
TypeError: GemmaModel.forward() got an unexpected keyword argument 'num_items_in_batch'
Hi @smkhant,
I reproduced the same error; please refer to this gist file. The error occurs because the default Trainer passes num_items_in_batch to the model's forward pass, which Gemma's forward() does not accept. To avoid it, create a CustomTrainer class that inherits from Trainer and override the compute_loss method to handle the inputs properly and drop the problematic num_items_in_batch argument, as in the sketch below. For more details, please refer to the GitHub code.
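For reference, here is a minimal sketch of such an override (my own illustration, not the exact code from the gist). It assumes your batches already contain labels so the model computes its own loss and that no label smoothing is used; the name CustomTrainer is only a placeholder.

from transformers import Trainer

class CustomTrainer(Trainer):
    # Accept num_items_in_batch (newer Trainer versions pass it to compute_loss)
    # but never forward it to the model, since GemmaModel.forward() rejects it.
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        inputs = dict(inputs)                   # avoid mutating the collator's dict
        inputs.pop("num_items_in_batch", None)  # drop the key Gemma does not accept
        outputs = model(**inputs)               # labels in `inputs` -> outputs.loss is set
        loss = outputs.loss
        return (loss, outputs) if return_outputs else loss

Used in place of the regular Trainer (for example, trainer = CustomTrainer(model=model, args=training_args, train_dataset=train_dataset, data_collator=data_collator)), the rest of the QLoRA setup stays unchanged.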
Thank you.
