Training fails with mini-batch size of one sample #3035

LeonHafner opened this issue Oct 31, 2024 · 3 comments · May be fixed by #3036

@LeonHafner

This bug has already been reported in several issues and on Discourse: #2314 #2214 #221 #426
SCVI training fails if ceil(n_cells * 0.9) % 128 == 1, where 0.9 is the training split and 128 is the batch size.
I thought this bug would be fixed in 1.2.0, but unfortunately it still occurs.

It would be great if you could find a fix for this, as it would let users remove the unnecessary try-except blocks that currently change the batch size when the error occurs.

Thanks a lot!
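
Following the condition above, the reproducer below hits it exactly (a quick check, assuming the default 0.9 train split and batch size 128):

import math

n_cells = 143
n_train = math.ceil(n_cells * 0.9)  # 129 cells end up in the training split
print(n_train % 128)                # 1 -> the final training mini-batch holds a single cell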

import numpy as np
import anndata as ad
from scvi.model import SCVI


num_cells = 143
num_genes = 1000

shape_param = 2.0
scale_param = 1.0

gamma_rates = np.random.gamma(shape=shape_param, scale=scale_param, size=(num_cells, num_genes))
data = np.random.poisson(gamma_rates)

adata = ad.AnnData(X=data)

print(adata.shape)

SCVI.setup_anndata(adata)
model = SCVI(adata)
model.train()
Traceback (most recent call last):
  File "/nfs/home/students/l.hafner/nf-core/scvi_test/test_scvi.py", line 21, in <module>
    model.train()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/model/base/_training_mixin.py", line 145, in train
    return runner()
           ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainrunner.py", line 96, in __call__
    self.trainer.fit(self.training_plan, self.data_splitter)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainer.py", line 201, in fit
    super().fit(*args, **kwargs)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 47, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/trainer.py", line 1025, in _run_stage
    self.fit_loop.run()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 205, in run
    self.advance()
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/fit_loop.py", line 363, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 140, in run
    self.advance(data_fetcher)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/training_epoch_loop.py", line 250, in advance
    batch_output = self.automatic_optimization.run(trainer.optimizers[0], batch_idx, kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 190, in run
    self._optimizer_step(batch_idx, closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 268, in _optimizer_step
    call._call_lightning_module_hook(
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 167, in _call_lightning_module_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/module.py", line 1306, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/core/optimizer.py", line 153, in step
    step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 238, in optimizer_step
    return self.precision_plugin.optimizer_step(optimizer, model=model, closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 122, in optimizer_step
    return optimizer.step(closure=closure, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 484, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/optimizer.py", line 89, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/optim/adam.py", line 205, in step
    loss = closure()
           ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/plugins/precision/precision.py", line 108, in _wrap_closure
    closure_result = closure()
                     ^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 144, in __call__
    self._result = self.closure(*args, **kwargs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 129, in closure
    step_output = self._step_fn()
                  ^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/loops/optimization/automatic.py", line 317, in _training_step
    training_step_output = call._call_strategy_hook(trainer, "training_step", *kwargs.values())
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/lightning/pytorch/strategies/strategy.py", line 390, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 344, in training_step
    _, _, scvi_loss = self.forward(batch, loss_kwargs=self.loss_kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/train/_trainingplans.py", line 278, in forward
    return self.module(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 208, in forward
    return _generic_forward(
           ^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 748, in _generic_forward
    inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_base_module.py", line 312, in inference
    return self._regular_inference(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/base/_decorators.py", line 32, in auto_transfer_args
    return fn(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/module/_vae.py", line 377, in _regular_inference
    qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 283, in forward
    q = self.encoder(x, *cat_list)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/scvi/nn/_base_components.py", line 173, in forward
    x = layer(x)
        ^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/modules/batchnorm.py", line 176, in forward
    return F.batch_norm(
           ^^^^^^^^^^^^^
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2510, in batch_norm
    _verify_batch_size(input.size())
  File "/nfs/home/students/l.hafner/miniforge3/envs/scvi_test/lib/python3.12/site-packages/torch/nn/functional.py", line 2478, in _verify_batch_size
    raise ValueError(f"Expected more than 1 value per channel when training, got input size {size}")
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 128])

Versions:

scvi: 1.2.0
anndata: 0.10.9
numpy: 1.26.4

@ori-kron-wis
Collaborator

Hi @LeonHafner, thanks for posting this.
Currently we apparently only issue a warning in such cases (see #2916 and the changelog); there is no fix in the splitting logic. Will this not be enough for your case?

Otherwise, see my suggested fix here: #3036
The idea is that in this case we artificially add 1 to n_train and remove 1 from n_val, if possible (i.e. if n_val itself is larger than 2).
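
A minimal sketch of that adjustment (a hypothetical helper, not the actual code in #3036; assumes the default batch size of 128):

def adjust_split(n_train: int, n_val: int, batch_size: int = 128) -> tuple[int, int]:
    # If the last training mini-batch would contain exactly one cell,
    # borrow one cell from validation (only when validation can spare it).
    if n_train % batch_size == 1 and n_val > 2:
        n_train += 1
        n_val -= 1
    return n_train, n_val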

@LeonHafner
Author

Hi @ori-kron-wis, thanks for the quick reply.
For me the quickest fix was to pass datasplitter_kwargs={"drop_last": True} to model.train(). This simply drops the single cell that remains in the last batch.
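
Applied to the reproducer above, that workaround looks like this:

# drop the incomplete final mini-batch instead of feeding a single cell to BatchNorm
model.train(datasplitter_kwargs={"drop_last": True})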

But as this is not a very nice solution, I would appreciate some logic being implemented in scVI to fix it. Your suggested fix is a great idea; I hope you will be able to get it merged!

Best,
Leon

@canergen
Member

canergen commented Nov 2, 2024

@ori-kron-wis I didn't get around to fixing it. I think we should add the leftover cells to validation instead (and I would do it for fewer than 3 cells, which sounds safer). We should set train_size and validation_size to None by default; if they are None, we adjust these small batches. If the user sets a custom value like 0.9 (the old behavior), we don't change the training cells and it still fails.
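
A rough sketch of that default behavior (hypothetical helper names; the threshold of 3 cells follows the comment above):

import math

def split_sizes(n_cells: int, train_size: float | None = None, batch_size: int = 128) -> tuple[int, int]:
    # Only auto-adjust when the user did not pin train_size explicitly.
    auto = train_size is None
    n_train = math.ceil(n_cells * (0.9 if auto else train_size))
    n_val = n_cells - n_train
    remainder = n_train % batch_size
    if auto and 0 < remainder < 3:
        # Move the stranded cells into validation rather than training on a tiny batch.
        n_train -= remainder
        n_val += remainder
    return n_train, n_val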
