Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Oct 31, 2023
1 parent 7ac3556 commit e8789de
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions src/multigrate/model/_multivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import scipy
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager, fields
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
Expand All @@ -17,7 +18,6 @@
from scvi.model.base._utils import _initialize_model
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.train._callbacks import SaveBestState
from pytorch_lightning.callbacks import ModelCheckpoint

from ..dataloaders import GroupDataSplitter
from ..module import MultiVAETorch
Expand Down Expand Up @@ -292,16 +292,19 @@ def train(

if save_checkpoint_every_n_epochs is not None:
if path_to_checkpoints is not None:
kwargs["callbacks"].append(ModelCheckpoint(
dirpath = path_to_checkpoints,
save_top_k = -1,
monitor = 'epoch',
every_n_epochs = save_checkpoint_every_n_epochs,
verbose = True,
))
kwargs["callbacks"].append(
ModelCheckpoint(
dirpath=path_to_checkpoints,
save_top_k=-1,
monitor="epoch",
every_n_epochs=save_checkpoint_every_n_epochs,
verbose=True,
)
)
else:
raise ValueError(f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}.")

raise ValueError(
f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}."
)

if self.group_column is not None:
data_splitter = GroupDataSplitter(
Expand Down

0 comments on commit e8789de

Please sign in to comment.