This bug has already been reported in several issues and on Discourse: #2314, #2214, #221, #426.
SCVI training fails if ceil(n_cells * 0.9) % 128 == 1, where 0.9 is the training split and 128 is the batch size.
I thought this would be fixed with 1.2.0, but unfortunately it still occurs.
It would be very cool if you could find a fix for this, as it would allow users to remove the unnecessary try/except blocks that change the batch size when the error occurs.
Thanks a lot!
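For reference, a minimal sketch that should hit the failing split (the cell and gene counts are arbitrary; 143 cells gives ceil(143 * 0.9) = 129 training cells, and 129 % 128 == 1):

import anndata
import numpy as np
import scvi

# 143 cells -> 129 training cells with the default 0.9 split, so a full batch
# of 128 leaves exactly 1 cell in the last training batch.
adata = anndata.AnnData(np.random.poisson(1.0, size=(143, 200)).astype(np.float32))
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)
model.train()  # fails in BatchNorm with the error shown in the traceback below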
Traceback (most recent call last):
  File "/nfs/home/students/l.hafner/nf-core/scvi_test/test_scvi.py", line 21, in <module>
    model.train()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/model/base/_training_mixin.py", line 145, in train
    return runner()
           ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainrunner.py", line 96, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainer.py", line 201, in fit
    super().fit(*args, **kwargs)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/adam.py", line 205, in step
    loss = closure()
           ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 344, in training_step
    _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 278, in forward
    return self.module(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 208, in forward
    return _generic_forward(
           ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 748, in _generic_forward
    inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 312, in inference
    return self._regular_inference(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/_vae.py", line 377, in _regular_inference
    qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 283, in forward
    q = self.encoder(x, *cat_list)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 173, in forward
    x = layer(x)
        ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
           ^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2510, in batch_norm
    _verify_batch_size(input.size())
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2478, in _verify_batch_size
    raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])
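The underlying failure is PyTorch's BatchNorm1d refusing a training-mode batch that contains a single sample; the same error can be reproduced in isolation:

import torch

bn = torch.nn.BatchNorm1d(128)
bn.train()  # training mode needs more than one sample per channel to compute batch statistics
bn(torch.randn(1, 128))  # ValueError: Expected more than 1 value per channel when training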
Versions:
scvi: 1.2.0
anndata: 0.10.9
numpy: 1.26.4
Hi @LeonHafner, thanks for posting this.
Currently we just issue a warning in such cases (see #2916 and the changelog); there is no fix in the splitting logic itself. Will this not be enough for your case?
Otherwise, see my suggested fix here: #3036
The idea is that in this case we would artificially add 1 to n_train and remove 1 from n_val, if possible (i.e., if n_val itself is larger than 2).
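Roughly, the adjustment would look like this (a sketch only; the function name and signature are made up for illustration and this is not the actual #3036 implementation):

import math

def adjusted_split(n_cells, train_size=0.9, batch_size=128):
    # Move one cell from validation to training when the last training batch
    # would otherwise contain a single cell (hypothetical helper, not scvi-tools code).
    n_train = math.ceil(n_cells * train_size)
    n_val = n_cells - n_train
    if n_train % batch_size == 1 and n_val > 2:
        n_train += 1
        n_val -= 1
    return n_train, n_val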
Hi @ori-kron-wis, thanks for the quick reply.
For me the quickest fix was to pass datasplitter_kwargs={"drop_last": True} to the model.train function; this simply drops the single cell that remains in the last batch.
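For reference, the workaround call looks roughly like this (reusing the model from the snippet at the top of the issue; the exact keyword forwarding may differ between scvi-tools versions):

# Drop the trailing 1-cell batch instead of training on it.
model.train(datasplitter_kwargs={"drop_last": True})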
But as this is not a very nice solution, I would appreciate some logic being implemented in scVI to fix this. Your suggested fix is a great idea; I hope you will be able to get it merged!
@ori-kron-wis I didn't get around to fixing it. I think we should add the leftover cells to validation instead (and I would do it whenever fewer than 3 cells are left over, which sounds safer). We should set train_size and validation_size to None by default. If they are None, we adjust these small batches; if the user sets a custom value like 0.9 (the old behavior), we don't change the training cells and it still fails.
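A hypothetical sketch of that behavior (names and the 3-cell threshold are illustrative, not scvi-tools code):

import math

def propose_split(n_cells, train_size=None, batch_size=128, default_train=0.9):
    # With the proposed None default, hand a tiny leftover training batch
    # (fewer than 3 cells) to validation; an explicit train_size keeps the old behavior.
    frac = default_train if train_size is None else train_size
    n_train = math.ceil(n_cells * frac)
    n_val = n_cells - n_train
    leftover = n_train % batch_size
    if train_size is None and 0 < leftover < 3:
        n_train -= leftover
        n_val += leftover
    return n_train, n_val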