diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index c878d2b345cc31..93b67a239fcc8e 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2465,6 +2465,7 @@ def _inner_training_loop( if remainder == 0: remainder = args.gradient_accumulation_steps update_step = -1 + # Note: total_updates can be larger than the dataloader length total_updates = steps_in_epoch // args.gradient_accumulation_steps + 1 for _ in range(total_updates): update_step += 1 @@ -5153,6 +5154,8 @@ def get_batch_samples(self, epoch_iterator, num_batches): pass if self.args.average_tokens_across_devices: + if num_items_in_batch is None: + num_items_in_batch = torch.tensor(0).to(self.args.device) num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum().item() if torch.is_tensor(num_items_in_batch):