diff --git a/src/lightning/pytorch/callbacks/throughput_monitor.py b/src/lightning/pytorch/callbacks/throughput_monitor.py
index a2d73d83184b1..5b4e9d39b4a06 100644
--- a/src/lightning/pytorch/callbacks/throughput_monitor.py
+++ b/src/lightning/pytorch/callbacks/throughput_monitor.py
@@ -87,6 +87,7 @@ def __init__(
         self._throughputs: Dict[RunningStage, Throughput] = {}
         self._t0s: Dict[RunningStage, float] = {}
         self._lengths: Dict[RunningStage, int] = {}
+        self._samples: Dict[RunningStage, int] = {}

     @override
     def setup(self, trainer: "Trainer", pl_module: "LightningModule", stage: str) -> None:
@@ -109,6 +110,7 @@ def _start(self, trainer: "Trainer") -> None:
         self._throughputs[stage].reset()
         self._lengths[stage] = 0
         self._t0s[stage] = time.perf_counter()
+        self._samples[stage] = 0

     @torch.inference_mode()  # in case `length_fn` or `batch_size_fn` computes grads
     def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any, iter_num: int) -> None:
@@ -133,12 +135,13 @@ def _update(self, trainer: "Trainer", pl_module: "LightningModule", batch: Any,
             )
             flops_per_batch = None

-        batch_size = self.batch_size_fn(batch)
+        self._samples[stage] += self.batch_size_fn(batch)
+
         throughput.update(
             time=elapsed,
             batches=iter_num,
             # this assumes that all iterations used the same batch size
-            samples=iter_num * batch_size,
+            samples=self._samples[stage],
             lengths=None if self.length_fn is None else self._lengths[stage],
             flops=flops_per_batch,
         )
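
Why the accumulator: the old computation `samples=iter_num * batch_size` multiplied the iteration count by the size of the current batch, which miscounts whenever batch sizes vary (for example, a smaller final batch in an epoch). Below is a minimal standalone sketch of the difference; it is not part of the patch, and the batch sizes and variable names are illustrative assumptions only.

# Illustrative only: compare the old per-iteration estimate with a running total.
batch_sizes = [32, 32, 16]  # e.g. the last batch of the epoch is partial

# Old behaviour: assumes every iteration used the size of the current batch.
iter_num = len(batch_sizes)
old_samples = iter_num * batch_sizes[-1]  # 3 * 16 = 48 (undercounts)

# Patched behaviour: accumulate the actual size of each batch, as
# `self._samples[stage] += self.batch_size_fn(batch)` does in the callback.
new_samples = sum(batch_sizes)  # 32 + 32 + 16 = 80

print(old_samples, new_samples)  # 48 80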