Skip to content

Commit

Permalink
updated pyproject
Browse files Browse the repository at this point in the history
  • Loading branch information
alitinet committed Oct 31, 2023
1 parent cc634f9 commit 7ac3556
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ urls.Source = "https://github.com/theislab/multigrate"
urls.Home-page = "https://github.com/theislab/multigrate"
dependencies = [
"scanpy",
"scvi-tools>=0.19.0",
"scvi-tools<1.0.0",
"matplotlib"
]

Expand Down
17 changes: 17 additions & 0 deletions src/multigrate/model/_multivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from scvi.model.base._utils import _initialize_model
from scvi.train import AdversarialTrainingPlan, TrainRunner
from scvi.train._callbacks import SaveBestState
from pytorch_lightning.callbacks import ModelCheckpoint

from ..dataloaders import GroupDataSplitter
from ..module import MultiVAETorch
Expand Down Expand Up @@ -224,6 +225,8 @@ def train(
n_steps_kl_warmup: Optional[int] = None,
adversarial_mixing: bool = False,
plan_kwargs: Optional[dict] = None,
save_checkpoint_every_n_epochs: Optional[int] = None,
path_to_checkpoints: Optional[str] = None,
**kwargs,
):
"""Train the model using amortized variational inference.
Expand Down Expand Up @@ -287,6 +290,19 @@ def train(
kwargs["callbacks"] = []
kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))

# Optional periodic checkpointing: when the caller asks for a checkpoint
# every N epochs, register a ModelCheckpoint callback for the trainer.
if save_checkpoint_every_n_epochs is not None:
    # Guard clause: a destination directory is mandatory when periodic
    # checkpointing is enabled.
    if path_to_checkpoints is None:
        raise ValueError(
            f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}."
        )
    checkpoint_callback = ModelCheckpoint(
        dirpath=path_to_checkpoints,
        save_top_k=-1,  # keep every saved checkpoint, never prune old ones
        monitor="epoch",
        every_n_epochs=save_checkpoint_every_n_epochs,
        verbose=True,
    )
    kwargs["callbacks"].append(checkpoint_callback)


if self.group_column is not None:
data_splitter = GroupDataSplitter(
self.adata_manager,
Expand Down Expand Up @@ -325,6 +341,7 @@ def train(
check_val_every_n_epoch=check_val_every_n_epoch,
early_stopping_monitor="reconstruction_loss_validation",
early_stopping_patience=50,
enable_checkpointing=True,
**kwargs,
)
return runner()
Expand Down

0 comments on commit 7ac3556

Please sign in to comment.