Merge pull request #278 from WenjieDu/dev
Updating GP-VAE, adding load_dict_from_h5, etc.
WenjieDu authored Dec 21, 2023
2 parents b6adbac + 022ee07 commit 62f67e1
Showing 25 changed files with 509 additions and 147 deletions.
3 changes: 2 additions & 1 deletion pypots/base.py
@@ -543,7 +543,8 @@ def _print_model_size(self) -> None:
"""Print the number of trainable parameters in the initialized NN model."""
num_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
logger.info(
f"Model initialized successfully with the number of trainable parameters: {num_params:,}"
f"A {self.__class__.__name__} model initialized with the given hyperparameters, "
f"the number of trainable parameters: {num_params:,}"
)

@abstractmethod
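The counting idiom in `_print_model_size` is plain PyTorch; a self-contained sketch of what the reworded log line reports (toy module, not part of this PR):

```python
import torch.nn as nn

# toy stand-in for self.model
model = nn.Linear(10, 5)
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"number of trainable parameters: {num_params:,}")  # 55 = 10*5 weights + 5 biases
```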
10 changes: 6 additions & 4 deletions pypots/classification/raindrop/modules/core.py
@@ -25,14 +25,16 @@
except ImportError as e:
logger.error(
f"{e}\n"
"Note torch_geometric is missing, "
"please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)
except NameError as e:
logger.error(
f"{e}\n"
"Note torch_geometric is missing, "
"please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)


5 changes: 3 additions & 2 deletions pypots/classification/raindrop/modules/submodules.py
@@ -34,8 +34,9 @@
except ImportError as e:
logger.error(
f"{e}\n"
"torch_geometric is missing, "
"please install it with 'pip install torch_geometric' or 'conda install -c pyg pyg'"
"Note torch_geometric is missing, please install it with "
"'pip install torch_geometric torch_scatter torch_sparse' or "
"'conda install -c pyg pyg pytorch-scatter pytorch-sparse'"
)


47 changes: 42 additions & 5 deletions pypots/cli/tuning.py
@@ -8,12 +8,14 @@

import os
from argparse import ArgumentParser, Namespace
import torch

from .base import BaseCommand
from .utils import load_package_from_path
from ..classification import BRITS as BRITS_classification
from ..classification import Raindrop, GRUD
from ..clustering import CRLI, VaDER
from ..data.saving.h5 import load_dict_from_h5
from ..imputation import SAITS, Transformer, CSDI, USGAN, GPVAE, MRNN, BRITS, TimesNet
from ..optim import Adam
from ..utils.logging import logger
@@ -26,15 +28,14 @@
"but is missing in the current environment."
)


NN_MODELS = {
# imputation models
"pypots.imputation.SAITS": SAITS,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.TimesNet": TimesNet,
"pypots.imputation.CSDI": CSDI,
"pypots.imputation.US_GAN": USGAN,
"pypots.imputation.GP_VAE": GPVAE,
"pypots.imputation.USGAN": USGAN,
"pypots.imputation.GPVAE": GPVAE,
"pypots.imputation.BRITS": BRITS,
"pypots.imputation.MRNN": MRNN,
# classification models
Expand All @@ -53,6 +54,8 @@ def env_command_factory(args: Namespace):
args.model_package_path,
args.train_set,
args.val_set,
args.lazy_load,
args.torch_n_threads,
)


@@ -105,6 +108,21 @@ def register_subcommand(parser: ArgumentParser):
required=True,
help="",
)
sub_parser.add_argument(
"--lazy_load",
dest="lazy_load",
action="store_true",
help="Whether to use lazy loading for the dataset. If `True`, the dataset will be lazy loaded for model "
"training, i.e. only the current batch will be fetched from the file. Lazy loading needs less memory but "
"more time and CPU rate to read data each time.",
)
sub_parser.add_argument(
"--torch_n_threads",
dest="torch_n_threads",
type=int,
default=1,
help="The input value for torch.set_num_threads()",
)
sub_parser.set_defaults(func=env_command_factory)

def __init__(
@@ -113,25 +131,33 @@ def __init__(
model_package_path: str,
train_set: str,
val_set: str,
lazy_load: bool = False,
torch_n_threads: int = 1,
):
self._model = model
self._model_package_path = model_package_path
self._train_set = train_set
self._val_set = val_set
self._lazy_load = lazy_load
self._torch_n_threads = torch_n_threads

def checkup(self):
"""Run some checks on the arguments to avoid error usages"""
pass

def run(self):
"""Execute the given command."""

# set the number of threads for torch to avoid taking up too many CPU cores
torch.set_num_threads(self._torch_n_threads)

if os.getenv("enable_tuning", False):
# fetch a new set of hyperparameters from NNI tuner
tuner_params = nni.get_next_parameter()
# get the specified model class
if self._model not in NN_MODELS:
logger.info(
f"The specified model {self._model} is not in PyPOTS. "
f"The specified model {self._model} is not in PyPOTS. Available models are {NN_MODELS.keys()}. "
f"Trying to fetch it from the given model package {self._model_package_path}."
)
assert self._model_package_path is not None, (
@@ -187,7 +213,18 @@ def run(self):

# init an instance with the given hyperparameters for the model class
model = model_class(**tuner_params)

# load the dataset
if self._lazy_load:
train_set, val_set = self._train_set, self._val_set
else:
logger.info(
f"Lazy loading is {self._lazy_load}, loading all data from file..."
)
train_set = load_dict_from_h5(self._train_set)
val_set = load_dict_from_h5(self._val_set)

# train the model and report to NNI
model.fit(train_set=self._train_set, val_set=self._val_set)
model.fit(train_set=train_set, val_set=val_set)
else:
raise RuntimeError("Argument `enable_tuning` is not set. Aborting...")
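Outside the CLI, the new `--lazy_load` / `--torch_n_threads` behavior boils down to the pattern below. A minimal sketch, assuming hypothetical dataset files and illustrative GP-VAE hyperparameters (PyPOTS `fit()` accepts either an in-memory dict or an h5 file path):

```python
import torch

from pypots.data.saving import load_dict_from_h5
from pypots.imputation import GPVAE

torch.set_num_threads(1)  # what --torch_n_threads controls

lazy_load = False  # what --lazy_load toggles
train_path, val_path = "train.h5", "val.h5"  # hypothetical files

if lazy_load:
    # hand the file paths straight to fit(); only the current batch is read from disk
    train_set, val_set = train_path, val_path
else:
    # load everything into memory up front: less I/O per epoch, more RAM
    train_set = load_dict_from_h5(train_path)
    val_set = load_dict_from_h5(val_path)

# hyperparameters are illustrative, not values from this PR
model = GPVAE(n_steps=48, n_features=37, latent_size=16, epochs=5)
model.fit(train_set=train_set, val_set=val_set)
```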
3 changes: 2 additions & 1 deletion pypots/data/saving/__init__.py
@@ -5,11 +5,12 @@
# Created by Wenjie Du <[email protected]>
# License: BSD-3-Clause

from .h5 import save_dict_into_h5
from .h5 import save_dict_into_h5, load_dict_from_h5
from .pickle import pickle_dump, pickle_load

__all__ = [
"save_dict_into_h5",
"load_dict_from_h5",
"pickle_dump",
"pickle_load",
]
53 changes: 51 additions & 2 deletions pypots/data/saving/h5.py
@@ -7,12 +7,14 @@


import os
from datetime import datetime
from typing import Optional

import h5py
import yaml

from pypots.utils.file import extract_parent_dir, create_dir_if_not_exist
from pypots.utils.logging import logger
from ...utils.file import extract_parent_dir, create_dir_if_not_exist
from ...utils.logging import logger


def save_dict_into_h5(
@@ -84,3 +86,50 @@ def save_set(handle, name, data):
save_set(hf, k, v)

logger.info(f"Successfully saved the given data into {saving_path}.")


def load_dict_from_h5(
file_path: str,
) -> dict:
"""Load the data from the given h5 file and return as a Python dictionary.
Parameters
----------
file_path : str,
The path to the h5 file.
Returns
-------
data : dict,
The data loaded from the given h5 file.
"""
assert isinstance(
file_path, str
), f"`file_path` should be a string, but got {type(file_path)}."
assert os.path.exists(file_path), "`file_path` does not exist."

def load_set(handle, datadict):
for key, item in handle.items():
if isinstance(item, h5py.Group):
datadict[key] = {}
datadict[key] = load_set(item, datadict[key])
elif isinstance(item, h5py.Dataset):
value = item[()]
if "_type_" in item.attrs:
if item.attrs["_type_"].astype(str) == "datetime":
if hasattr(value, "__iter__"):
value = [datetime.fromtimestamp(ts) for ts in value]
else:
value = datetime.fromtimestamp(value)
elif item.attrs["_type_"].astype(str) == "yaml":
value = yaml.safe_load(value.decode())
datadict[key] = value

return datadict

data = {}
with h5py.File(file_path, "r") as hf:
data = load_set(hf, data)

return data
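Together with `save_dict_into_h5`, the new loader gives a symmetric round trip: nested dicts become h5 groups, and values tagged with a `_type_` attribute (`datetime`, `yaml`) are restored on the way back. A quick sketch, assuming `save_dict_into_h5(data, path)` with an `.h5` path and an illustrative file name:

```python
import numpy as np

from pypots.data.saving import save_dict_into_h5, load_dict_from_h5

data = {
    "train": {"X": np.random.randn(100, 48, 37)},  # nested dict -> h5 group
    "n_classes": 2,
}
save_dict_into_h5(data, "toy_dataset.h5")

reloaded = load_dict_from_h5("toy_dataset.h5")
assert reloaded["train"]["X"].shape == (100, 48, 37)
assert reloaded["n_classes"] == 2
```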
6 changes: 6 additions & 0 deletions pypots/imputation/brits/model.py
@@ -34,6 +34,12 @@ class BRITS(BaseNNImputer):
Parameters
----------
n_steps :
The number of time steps in the time-series data sample.
n_features :
The number of features in the time-series data sample.
rnn_hidden_size :
The size of the RNN hidden state, also the number of hidden units in the RNN cell.
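The newly documented parameters map directly onto the expected data layout: `fit()` takes a dictionary whose `"X"` entry has shape `(n_samples, n_steps, n_features)`. A minimal end-to-end check with random data (all sizes illustrative):

```python
import numpy as np

from pypots.imputation import BRITS

n_samples, n_steps, n_features = 100, 48, 37
X = np.random.randn(n_samples, n_steps, n_features)
X[X < -1.5] = np.nan  # introduce missing values to impute

model = BRITS(n_steps=n_steps, n_features=n_features, rnn_hidden_size=108, epochs=3)
model.fit(train_set={"X": X})
imputation = model.impute({"X": X})  # NaNs filled, same shape as X
```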