diff --git a/pyproject.toml b/pyproject.toml
index 91c20d4..e78fca4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -21,7 +21,7 @@ urls.Source = "https://github.com/theislab/multigrate"
 urls.Home-page = "https://github.com/theislab/multigrate"
 dependencies = [
     "scanpy",
-    "scvi-tools>=0.19.0",
+    "scvi-tools<1.0.0",
     "matplotlib"
 ]
 
diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py
index 4419a11..5498baf 100644
--- a/src/multigrate/model/_multivae.py
+++ b/src/multigrate/model/_multivae.py
@@ -17,6 +17,7 @@
 from scvi.model.base._utils import _initialize_model
 from scvi.train import AdversarialTrainingPlan, TrainRunner
 from scvi.train._callbacks import SaveBestState
+from pytorch_lightning.callbacks import ModelCheckpoint
 
 from ..dataloaders import GroupDataSplitter
 from ..module import MultiVAETorch
@@ -224,6 +225,8 @@ def train(
         n_steps_kl_warmup: Optional[int] = None,
         adversarial_mixing: bool = False,
         plan_kwargs: Optional[dict] = None,
+        save_checkpoint_every_n_epochs: Optional[int] = None,
+        path_to_checkpoints: Optional[str] = None,
         **kwargs,
     ):
         """Train the model using amortized variational inference.
@@ -287,6 +290,19 @@
                 kwargs["callbacks"] = []
             kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))
 
+        if save_checkpoint_every_n_epochs is not None:
+            if path_to_checkpoints is not None:
+                # periodically save full checkpoints and keep all of them (save_top_k=-1)
+                kwargs["callbacks"].append(ModelCheckpoint(
+                    dirpath=path_to_checkpoints,
+                    save_top_k=-1,
+                    monitor="epoch",
+                    every_n_epochs=save_checkpoint_every_n_epochs,
+                    verbose=True,
+                ))
+            else:
+                raise ValueError(f"`save_checkpoint_every_n_epochs` is set to {save_checkpoint_every_n_epochs}, so `path_to_checkpoints` must also be provided.")
+
         if self.group_column is not None:
             data_splitter = GroupDataSplitter(
                 self.adata_manager,
@@ -325,6 +341,7 @@
             check_val_every_n_epoch=check_val_every_n_epoch,
             early_stopping_monitor="reconstruction_loss_validation",
             early_stopping_patience=50,
+            enable_checkpointing=True,
             **kwargs,
         )
         return runner()
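
Usage sketch (illustrative, not part of the diff): the two new `train()` arguments
hook PyTorch Lightning's ModelCheckpoint callback into multigrate's training loop.
Everything here other than the two new keyword arguments is an assumption for the
example: `adata` stands for an AnnData already organized for multigrate (e.g. via
mtg.data.organize_multiome_anndatas and MultiVAE.setup_anndata), and the epoch
count and checkpoint directory are arbitrary.

    import multigrate as mtg

    # Assumed setup (see the multigrate docs): `adata` is an AnnData that has
    # already been organized and registered with MultiVAE.setup_anndata.
    model = mtg.model.MultiVAE(adata)

    # New in this diff: write a checkpoint to `checkpoints/` every 10 epochs.
    # With save_top_k=-1 inside the callback, no checkpoint is ever deleted.
    # Setting the interval without `path_to_checkpoints` raises a ValueError.
    model.train(
        max_epochs=200,
        save_checkpoint_every_n_epochs=10,
        path_to_checkpoints="checkpoints/",
    )

The diff only covers writing checkpoints out; passing `enable_checkpointing=True`
to the TrainRunner ensures Lightning does not disable the callback.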