diff --git a/src/multigrate/model/_multivae.py b/src/multigrate/model/_multivae.py index 5498baf..9b6caef 100644 --- a/src/multigrate/model/_multivae.py +++ b/src/multigrate/model/_multivae.py @@ -7,6 +7,7 @@ import scipy import torch from matplotlib import pyplot as plt +from pytorch_lightning.callbacks import ModelCheckpoint from scvi import REGISTRY_KEYS from scvi.data import AnnDataManager, fields from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY @@ -17,7 +18,6 @@ from scvi.model.base._utils import _initialize_model from scvi.train import AdversarialTrainingPlan, TrainRunner from scvi.train._callbacks import SaveBestState -from pytorch_lightning.callbacks import ModelCheckpoint from ..dataloaders import GroupDataSplitter from ..module import MultiVAETorch @@ -292,16 +292,19 @@ def train( if save_checkpoint_every_n_epochs is not None: if path_to_checkpoints is not None: - kwargs["callbacks"].append(ModelCheckpoint( - dirpath = path_to_checkpoints, - save_top_k = -1, - monitor = 'epoch', - every_n_epochs = save_checkpoint_every_n_epochs, - verbose = True, - )) + kwargs["callbacks"].append( + ModelCheckpoint( + dirpath=path_to_checkpoints, + save_top_k=-1, + monitor="epoch", + every_n_epochs=save_checkpoint_every_n_epochs, + verbose=True, + ) + ) else: - raise ValueError(f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}.") - + raise ValueError( + f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}." + ) if self.group_column is not None: data_splitter = GroupDataSplitter(