Lazy loading fix and other minor fixes for CSAI #551

Closed · wants to merge 1 commit
6 changes: 4 additions & 2 deletions pypots/classification/csai/data.py
@@ -22,7 +22,8 @@ def __init__(
         replacement_probabilities=None,
         normalise_mean: list = [],
         normalise_std: list = [],
-        training: bool = True,
+        is_normalise: bool = False,
+        non_uniform: bool = False,
     ):
         super().__init__(
             data=data,
@@ -35,5 +36,6 @@ def __init__(
             replacement_probabilities=replacement_probabilities,
             normalise_mean=normalise_mean,
             normalise_std=normalise_std,
-            training=training,
+            is_normalise=is_normalise,
+            non_uniform=non_uniform,
         )
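
This change replaces the old `training` switch on `DatasetForCSAI` with two explicit flags. Below is a minimal sketch of the new call site; the keyword names `is_normalise`, `non_uniform`, `removal_percent`, `increase_factor`, and `compute_intervals` come from this diff, while `file_type` and `return_y` are assumed to follow the usual PyPOTS dataset conventions, and the values and the layout of the `data` dict are illustrative rather than taken from the PR:

import numpy as np
from pypots.classification.csai.data import DatasetForCSAI

# toy data dict: arrays under "X" (NaNs would mark missing values) and "y"
data = {
    "X": np.random.randn(100, 48, 35),       # (n_samples, n_steps, n_features)
    "y": np.random.randint(0, 2, size=100),  # class labels
}

dataset = DatasetForCSAI(
    data=data,
    file_type="hdf5",       # only consulted when `data` is a file path
    return_y=True,
    removal_percent=10,
    increase_factor=0.1,
    compute_intervals=True,
    is_normalise=True,      # set False if the data is already normalised
    non_uniform=False,      # uniform masking when simulating missingness
)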
67 changes: 49 additions & 18 deletions pypots/classification/csai/model.py
@@ -97,6 +97,12 @@ class CSAI(BaseNNClassifier):
The "better" strategy will automatically save the model during training whenever the model performs
better than in previous epochs.
The "all" strategy will save every model after each epoch training.

is_normalise :
Whether to normalize the input data. Set this flag to False if the the data is already normalised.

non_uniform :
Whether to apply non-uniform sampling to simulate missing values. If `True`, non-uniform sampling will be applied.

verbose :
Whether to print out the training logs during the training process.
@@ -125,6 +131,8 @@ def __init__(
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,
         model_saving_strategy: Union[str, None] = "best",
+        is_normalise: bool = False,
+        non_uniform: bool = False,
         verbose: bool = True,
     ):
         super().__init__(
@@ -154,6 +162,8 @@ def __init__(
         self.replacement_probabilities = None
         self.mean_set = None
         self.std_set = None
+        self.is_normalise = is_normalise
+        self.non_uniform = non_uniform
 
         # Initialise empty model
         self.model = _BCSAI(
@@ -174,13 +184,11 @@ def __init__(
 
         # set up the optimizer
         self.optimizer = optimizer
-        self.optimizer.init_optimizer(self.model.parameters())
 
-    def _assemble_input_for_training(self, data: list, training=True) -> dict:
+    def _assemble_input_for_training(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs, labels) = (
-            self._send_data_to_given_device(sample)
+            self._send_data_to_given_device(data)
         )
 
         inputs = {
@@ -206,7 +214,6 @@ def _assemble_input_for_validating(self, data: list) -> dict:
 
     def _assemble_input_for_testing(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (
             indices,
             X,
@@ -217,9 +224,7 @@ def _assemble_input_for_testing(self, data: list) -> dict:
             back_missing_mask,
             back_deltas,
             back_last_obs,
-            X_ori,
-            indicating_mask,
-        ) = self._send_data_to_given_device(sample)
+        ) = self._send_data_to_given_device(data)
 
         # assemble input data
         inputs = {
@@ -249,10 +254,10 @@ def fit(
file_type: str = "hdf5",
) -> None:
# Create dataset
if isinstance(train_set, str):
if isinstance(train_set, str) and self.non_uniform:
logger.warning(
"CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
"Hence the whole train set will be loaded into memory."
"CSAI does not support lazy loading with non uniform sampling because normalise mean and std "
"need to be calculated ahead. Hence the whole train set will be loaded into memory."
)
train_set = load_dict_from_h5(train_set)
training_set = DatasetForCSAI(
@@ -262,6 +267,8 @@ def fit(
             removal_percent=self.removal_percent,
             increase_factor=self.increase_factor,
             compute_intervals=self.compute_intervals,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
 
         self.intervals = training_set.intervals
@@ -277,10 +284,10 @@ def fit(
         )
         val_loader = None
         if val_set is not None:
-            if isinstance(val_set, str):
+            if isinstance(val_set, str) and self.non_uniform:
                 logger.warning(
-                    "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                    "Hence the whole val set will be loaded into memory."
+                    "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                    "need to be calculated ahead. Hence the whole val set will be loaded into memory."
                 )
                 val_set = load_dict_from_h5(val_set)
 
@@ -296,6 +303,9 @@ def fit(
                 replacement_probabilities=self.replacement_probabilities,
                 normalise_mean=self.mean_set,
                 normalise_std=self.std_set,
+                is_normalise=self.is_normalise,
+                non_uniform=self.non_uniform,
+
             )
             val_loader = DataLoader(
                 val_set,
@@ -304,6 +314,26 @@ def fit(
                 num_workers=self.num_workers,
             )
 
+        # set up the model
+        self.model = _BCSAI(
+            n_steps=self.n_steps,
+            n_features=self.n_features,
+            rnn_hidden_size=self.rnn_hidden_size,
+            imputation_weight=self.imputation_weight,
+            consistency_weight=self.consistency_weight,
+            classification_weight=self.classification_weight,
+            n_classes=self.n_classes,
+            step_channels=self.step_channels,
+            dropout=self.dropout,
+            intervals=self.intervals,
+        )
+
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
         # train the model
         self._train_model(train_loader, val_loader)
         self.model.load_state_dict(self.best_model_dict)
@@ -319,10 +349,10 @@ def predict(
 
         self.model.eval()
 
-        if isinstance(test_set, str):
+        if isinstance(test_set, str) and self.non_uniform:
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                "Hence the whole test set will be loaded into memory."
+                "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                "need to be calculated ahead. Hence the whole test set will be loaded into memory."
             )
             test_set = load_dict_from_h5(test_set)
         test_set = DatasetForCSAI(
@@ -335,7 +365,8 @@ def predict(
             replacement_probabilities=self.replacement_probabilities,
             normalise_mean=self.mean_set,
             normalise_std=self.std_set,
-            training=False,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
         test_loader = DataLoader(
             test_set,
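
The net effect in model.py: `_BCSAI` construction and optimizer initialisation move out of `__init__` into `fit()`, where the intervals computed from the training set are available, and lazy loading is now only disabled when non-uniform sampling is requested. Below is a hedged end-to-end sketch under those semantics; the hyperparameter values and file paths are illustrative assumptions, while the keyword names follow this diff and standard PyPOTS arguments:

from pypots.classification import CSAI
from pypots.optim import Adam

model = CSAI(
    n_steps=48,
    n_features=35,
    n_classes=2,
    rnn_hidden_size=108,
    imputation_weight=0.3,
    consistency_weight=0.1,
    classification_weight=1.0,
    removal_percent=10,
    increase_factor=0.1,
    compute_intervals=True,
    step_channels=512,
    batch_size=32,
    epochs=10,
    optimizer=Adam(lr=1e-3),
    is_normalise=True,    # raw inputs still need normalisation
    non_uniform=False,    # uniform sampling, so HDF5 paths can be loaded lazily
)

# With non_uniform=False, the HDF5 paths below are handed to the dataset as-is
# instead of being force-loaded via load_dict_from_h5, so no warning fires.
model.fit(train_set="path/to/train.h5", val_set="path/to/val.h5")
results = model.predict(test_set="path/to/test.h5")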
4 changes: 4 additions & 0 deletions pypots/cli/tuning.py
@@ -16,6 +16,7 @@
 from .utils import load_package_from_path
 from ..classification import BRITS as BRITS_classification
 from ..classification import GRUD as GRUD_classification
+from ..classification import CSAI as CSAI_classification
 from ..classification import Raindrop
 from ..clustering import CRLI, VaDER
 from ..data.saving.h5 import load_dict_from_h5
@@ -47,6 +48,7 @@
     TiDE,
     Reformer,
     RevIN_SCINet,
+    CSAI,
 )
 from ..optim import Adam
 from ..utils.logging import logger
@@ -89,10 +91,12 @@
"pypots.imputation.TiDE": TiDE,
"pypots.imputation.Transformer": Transformer,
"pypots.imputation.USGAN": USGAN,
"pypots.imputation.CSAI": CSAI,
# classification models
"pypots.classification.BRITS": BRITS_classification,
"pypots.classification.GRUD": GRUD_classification,
"pypots.classification.Raindrop": Raindrop,
"pypots.classification.CSAI": CSAI_classification,
# clustering models
"pypots.clustering.CRLI": CRLI,
"pypots.clustering.VaDER": VaDER,
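
Registering CSAI in the tuning registry lets the PyPOTS tuning CLI resolve both CSAI variants by their dotted names. A minimal sketch of the lookup this enables; the registry dict name `NN_MODELS` is an assumption about tuning.py's internals, while the keys and classes mirror the diff:

from pypots.classification import CSAI as CSAI_classification
from pypots.imputation import CSAI

NN_MODELS = {
    "pypots.imputation.CSAI": CSAI,
    "pypots.classification.CSAI": CSAI_classification,
    # ... the other entries registered in pypots/cli/tuning.py
}

# the tuner receives a dotted model name and instantiates the matching class
model_class = NN_MODELS["pypots.classification.CSAI"]
print(model_class.__name__)  # -> "CSAI"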