Skip to content

Commit

Permalink
Merge pull request #26 from theislab/update_pyproject
Browse files Browse the repository at this point in the history
updated pyproject, added checkpoints
  • Loading branch information
alitinet authored Oct 31, 2023
2 parents e637a32 + e8789de commit 91af5ed
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ urls.Source = "https://github.com/theislab/multigrate"
urls.Home-page = "https://github.com/theislab/multigrate"
dependencies = [
"scanpy",
"scvi-tools>=0.19.0",
"scvi-tools<1.0.0",
"matplotlib"
]

Expand Down
20 changes: 20 additions & 0 deletions src/multigrate/model/_multivae.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import scipy
import torch
from matplotlib import pyplot as plt
from pytorch_lightning.callbacks import ModelCheckpoint
from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager, fields
from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
Expand Down Expand Up @@ -224,6 +225,8 @@ def train(
n_steps_kl_warmup: Optional[int] = None,
adversarial_mixing: bool = False,
plan_kwargs: Optional[dict] = None,
save_checkpoint_every_n_epochs: Optional[int] = None,
path_to_checkpoints: Optional[str] = None,
**kwargs,
):
"""Train the model using amortized variational inference.
Expand Down Expand Up @@ -287,6 +290,22 @@ def train(
kwargs["callbacks"] = []
kwargs["callbacks"].append(SaveBestState(monitor="reconstruction_loss_validation"))

if save_checkpoint_every_n_epochs is not None:
if path_to_checkpoints is not None:
kwargs["callbacks"].append(
ModelCheckpoint(
dirpath=path_to_checkpoints,
save_top_k=-1,
monitor="epoch",
every_n_epochs=save_checkpoint_every_n_epochs,
verbose=True,
)
)
else:
raise ValueError(
f"`save_checkpoint_every_n_epochs` = {save_checkpoint_every_n_epochs} so `path_to_checkpoints` has to be not None but is {path_to_checkpoints}."
)

if self.group_column is not None:
data_splitter = GroupDataSplitter(
self.adata_manager,
Expand Down Expand Up @@ -325,6 +344,7 @@ def train(
check_val_every_n_epoch=check_val_every_n_epoch,
early_stopping_monitor="reconstruction_loss_validation",
early_stopping_patience=50,
enable_checkpointing=True,
**kwargs,
)
return runner()
Expand Down

0 comments on commit 91af5ed

Please sign in to comment.