From 497ffa40635c93b04e9b378d3511f31f110478d8 Mon Sep 17 00:00:00 2001
From: joseph-arulraj
Date: Tue, 26 Nov 2024 17:07:09 +0000
Subject: [PATCH] Lazy loading fix and other minor fixes

Enabled lazy loading when non-uniform sampling is not required.
Added a flag to denote whether normalisation is needed.
Added CSAI to the PyPOTS CLI.
---
 pypots/classification/csai/data.py  |   6 +-
 pypots/classification/csai/model.py |  67 ++++++++---
 pypots/cli/tuning.py                |   4 +
 pypots/imputation/csai/data.py      | 176 ++++++++++++++++++++++------
 pypots/imputation/csai/model.py     | 118 ++++++++++++-------
 tests/classification/csai.py        |   4 +-
 tests/imputation/csai.py            |   4 +-
 7 files changed, 274 insertions(+), 105 deletions(-)

diff --git a/pypots/classification/csai/data.py b/pypots/classification/csai/data.py
index 3b93765c..a24d33fb 100644
--- a/pypots/classification/csai/data.py
+++ b/pypots/classification/csai/data.py
@@ -22,7 +22,8 @@ def __init__(
         replacement_probabilities=None,
         normalise_mean: list = [],
         normalise_std: list = [],
-        training: bool = True,
+        is_normalise: bool = False,
+        non_uniform: bool = False,
     ):
         super().__init__(
             data=data,
@@ -35,5 +36,6 @@ def __init__(
             replacement_probabilities=replacement_probabilities,
             normalise_mean=normalise_mean,
             normalise_std=normalise_std,
-            training=training,
+            is_normalise=is_normalise,
+            non_uniform=non_uniform,
         )
diff --git a/pypots/classification/csai/model.py b/pypots/classification/csai/model.py
index 3419c5bb..c6e29e95 100644
--- a/pypots/classification/csai/model.py
+++ b/pypots/classification/csai/model.py
@@ -97,6 +97,12 @@ class CSAI(BaseNNClassifier):
         The "better" strategy will automatically save the model during training whenever the model performs
         better than in previous epochs.
         The "all" strategy will save every model after each epoch training.
+
+    is_normalise :
+        Whether to normalise the input data. Set this flag to False if the data is already normalised.
+
+    non_uniform :
+        Whether to apply non-uniform sampling to simulate missing values. If `True`, non-uniform sampling will be applied.
 
     verbose :
         Whether to print out the training logs during the training process.
@@ -125,6 +131,8 @@ def __init__(
         device: Optional[Union[str, torch.device, list]] = None,
         saving_path: str = None,
         model_saving_strategy: Union[str, None] = "best",
+        is_normalise: bool = False,
+        non_uniform: bool = False,
         verbose: bool = True,
     ):
         super().__init__(
@@ -154,6 +162,8 @@ def __init__(
         self.replacement_probabilities = None
         self.mean_set = None
         self.std_set = None
+        self.is_normalise = is_normalise
+        self.non_uniform = non_uniform
 
         # Initialise empty model
         self.model = _BCSAI(
@@ -174,13 +184,11 @@ def __init__(
 
         # set up the optimizer
         self.optimizer = optimizer
-        self.optimizer.init_optimizer(self.model.parameters())
 
-    def _assemble_input_for_training(self, data: list, training=True) -> dict:
+    def _assemble_input_for_training(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs, labels) = (
-            self._send_data_to_given_device(sample)
+            self._send_data_to_given_device(data)
         )
 
         inputs = {
@@ -206,7 +214,6 @@ def _assemble_input_for_validating(self, data: list) -> dict:
 
     def _assemble_input_for_testing(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (
             indices,
             X,
@@ -217,9 +224,7 @@ def _assemble_input_for_testing(self, data: list) -> dict:
             back_missing_mask,
             back_deltas,
             back_last_obs,
-            X_ori,
-            indicating_mask,
-        ) = self._send_data_to_given_device(sample)
+        ) = self._send_data_to_given_device(data)
 
         # assemble input data
         inputs = {
@@ -249,10 +254,10 @@ def fit(
         file_type: str = "hdf5",
     ) -> None:
         # Create dataset
-        if isinstance(train_set, str):
+        if isinstance(train_set, str) and self.non_uniform:
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                "Hence the whole train set will be loaded into memory."
+                "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                "need to be calculated ahead. Hence the whole train set will be loaded into memory."
             )
             train_set = load_dict_from_h5(train_set)
         training_set = DatasetForCSAI(
@@ -262,6 +267,8 @@ def fit(
             removal_percent=self.removal_percent,
             increase_factor=self.increase_factor,
             compute_intervals=self.compute_intervals,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
 
         self.intervals = training_set.intervals
@@ -277,10 +284,10 @@ def fit(
         )
         val_loader = None
         if val_set is not None:
-            if isinstance(val_set, str):
+            if isinstance(val_set, str) and self.non_uniform:
                 logger.warning(
-                    "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                    "Hence the whole val set will be loaded into memory."
+                    "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                    "need to be calculated ahead. Hence the whole val set will be loaded into memory."
                )
                val_set = load_dict_from_h5(val_set)
@@ -296,6 +303,9 @@ def fit(
                 replacement_probabilities=self.replacement_probabilities,
                 normalise_mean=self.mean_set,
                 normalise_std=self.std_set,
+                is_normalise=self.is_normalise,
+                non_uniform=self.non_uniform,
+
             )
             val_loader = DataLoader(
                 val_set,
@@ -304,6 +314,26 @@ def fit(
             num_workers=self.num_workers,
         )
 
+        # set up the model
+        self.model = _BCSAI(
+            n_steps=self.n_steps,
+            n_features=self.n_features,
+            rnn_hidden_size=self.rnn_hidden_size,
+            imputation_weight=self.imputation_weight,
+            consistency_weight=self.consistency_weight,
+            classification_weight=self.classification_weight,
+            n_classes=self.n_classes,
+            step_channels=self.step_channels,
+            dropout=self.dropout,
+            intervals=self.intervals,
+        )
+
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
         # train the model
         self._train_model(train_loader, val_loader)
         self.model.load_state_dict(self.best_model_dict)
@@ -319,10 +349,10 @@ def predict(
 
         self.model.eval()
 
-        if isinstance(test_set, str):
+        if isinstance(test_set, str) and self.non_uniform:
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                "Hence the whole test set will be loaded into memory."
+                "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                "need to be calculated ahead. Hence the whole test set will be loaded into memory."
             )
             test_set = load_dict_from_h5(test_set)
         test_set = DatasetForCSAI(
@@ -335,7 +365,8 @@ def predict(
             replacement_probabilities=self.replacement_probabilities,
             normalise_mean=self.mean_set,
             normalise_std=self.std_set,
-            training=False,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
         test_loader = DataLoader(
             test_set,
diff --git a/pypots/cli/tuning.py b/pypots/cli/tuning.py
index 2af0a863..1987fb3c 100644
--- a/pypots/cli/tuning.py
+++ b/pypots/cli/tuning.py
@@ -16,6 +16,7 @@
 from .utils import load_package_from_path
 from ..classification import BRITS as BRITS_classification
 from ..classification import GRUD as GRUD_classification
+from ..classification import CSAI as CSAI_classification
 from ..classification import Raindrop
 from ..clustering import CRLI, VaDER
 from ..data.saving.h5 import load_dict_from_h5
@@ -47,6 +48,7 @@
     TiDE,
     Reformer,
     RevIN_SCINet,
+    CSAI,
 )
 from ..optim import Adam
 from ..utils.logging import logger
@@ -89,10 +91,12 @@
     "pypots.imputation.TiDE": TiDE,
     "pypots.imputation.Transformer": Transformer,
     "pypots.imputation.USGAN": USGAN,
+    "pypots.imputation.CSAI": CSAI,
     # classification models
     "pypots.classification.BRITS": BRITS_classification,
     "pypots.classification.GRUD": GRUD_classification,
     "pypots.classification.Raindrop": Raindrop,
+    "pypots.classification.CSAI": CSAI_classification,
     # clustering models
     "pypots.clustering.CRLI": CRLI,
     "pypots.clustering.VaDER": VaDER,
diff --git a/pypots/imputation/csai/data.py b/pypots/imputation/csai/data.py
index a72fc322..fdb073be 100644
--- a/pypots/imputation/csai/data.py
+++ b/pypots/imputation/csai/data.py
@@ -12,10 +12,11 @@
 import numpy as np
 import torch
 from sklearn.preprocessing import StandardScaler
+from pygrinder import fill_and_get_mask_torch
 
 from ...data.dataset import BaseDataset
-from ...data.utils import parse_delta
-
+from ...data.utils import _parse_delta_torch, parse_delta
+from ...utils.logging import logger
 
 def normalize_csai(
     data,
@@ -343,10 +344,13 @@ class DatasetForCSAI(BaseDataset):
 
     normalise_std :
        A list
        of standard deviation values for normalizing the input features. If not provided, they will be
        computed during initialization.
+
+    is_normalise :
+        Whether to normalise the input data. Set this flag to False if the data is already normalised.
 
-    training :
-        Whether the dataset is used for training.
-        If `False`, it will adjust how data is processed, particularly for evaluation and testing phases.
+    non_uniform :
+        Whether to apply non-uniform sampling to simulate missing values. If set to `True`, non-uniform
+        sampling will be applied.
 
     Notes
     -----
@@ -371,7 +375,8 @@ def __init__(
         replacement_probabilities=None,
         normalise_mean=None,
         normalise_std=None,
-        training: bool = True,
+        is_normalise: bool = False,
+        non_uniform: bool = False,
     ):
         super().__init__(
             data=data, return_X_ori=return_X_ori, return_X_pred=False, return_y=return_y, file_type=file_type
@@ -388,7 +393,8 @@ def __init__(
         self.replacement_probabilities = replacement_probabilities
         self.normalise_mean = normalise_mean
         self.normalise_std = normalise_std
-        self.training = training
+        self.is_normalise = is_normalise
+        self.non_uniform = non_uniform
 
         self.normalized_data = None
         self.mean_set = None
@@ -396,27 +402,68 @@ def __init__(
         self.intervals = None
 
         if not isinstance(self.data, str):
-            self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai(
-                self.data["X"],
-                self.normalise_mean,
-                self.normalise_std,
-                compute_intervals,
-            )
-
-            self.processed_data, self.replacement_probabilities = non_uniform_sample(
-                self.normalized_data,
-                removal_percent,
-                replacement_probabilities,
-                increase_factor,
-            )
-            self.forward_X = self.processed_data["values"]
-            self.forward_missing_mask = self.processed_data["masks"]
-            self.backward_X = torch.flip(self.forward_X, dims=[1])
-            self.backward_missing_mask = torch.flip(self.forward_missing_mask, dims=[1])
-
-            self.X_ori = self.processed_data["evals"]
-            self.indicating_mask = self.processed_data["eval_masks"]
-
+            if self.is_normalise:
+                self.normalized_data, self.mean_set, self.std_set, self.intervals = normalize_csai(
+                    self.data["X"],
+                    self.normalise_mean,
+                    self.normalise_std,
+                    compute_intervals,
+                )
+            else:
+                self.normalized_data = self.data["X"]
+                self.mean_set = self.normalise_mean
+                self.std_set = self.normalise_std
+                self.intervals = {}
+                if compute_intervals:
+                    for v in range(self.normalized_data.shape[2]):
+                        all_intervals = []
+                        for p in range(self.normalized_data.shape[0]):
+                            valid_time_points = np.where(~np.isnan(self.normalized_data[p, :, v]))[0]
+                            if len(valid_time_points) > 1:
+                                intervals = np.diff(valid_time_points)
+                                all_intervals.extend(intervals)
+                        self.intervals[v] = np.median(all_intervals) if all_intervals else np.nan
+
+            if self.non_uniform:
+                self.processed_data, self.replacement_probabilities = non_uniform_sample(
+                    self.normalized_data,
+                    removal_percent,
+                    replacement_probabilities,
+                    increase_factor,
+                )
+                self.forward_X = self.processed_data["values"]
+                self.forward_missing_mask = self.processed_data["masks"]
+                self.backward_X = torch.flip(self.forward_X, dims=[1])
+                self.backward_missing_mask = torch.flip(self.forward_missing_mask, dims=[1])
+
+                self.X_ori = self.processed_data["evals"]
+                self.indicating_mask = self.processed_data["eval_masks"]
+            else:
+                if self.return_X_ori:
+                    self.forward_missing_mask = self.missing_mask
+                    self.forward_X = self.X
+                else:
+                    self.forward_X, self.forward_missing_mask = fill_and_get_mask_torch(self.X)
+
+                deltas_f = parse_delta(self.forward_missing_mask)
+                self.backward_X = torch.flip(self.forward_X, dims=[1])
+                self.backward_missing_mask = torch.flip(self.forward_missing_mask, dims=[1])
+                deltas_b = parse_delta(self.backward_missing_mask)
+
+                B, _, _ = self.forward_X.shape
+
+                last_obs_f = np.array(np.nan_to_num([compute_last_obs(self.forward_X[b], self.forward_missing_mask[b].bool()) for b in range(B)])).astype(np.float32)
+                last_obs_b = np.array(np.nan_to_num([compute_last_obs(self.backward_X[b], self.backward_missing_mask[b].bool()) for b in range(B)])).astype(np.float32)
+
+                self.processed_data = {
+                    "deltas_f": deltas_f,
+                    "deltas_b": deltas_b,
+                    "last_obs_f": torch.FloatTensor(last_obs_f),
+                    "last_obs_b": torch.FloatTensor(last_obs_b),
+                }
+
     def _fetch_data_from_array(self, idx: int) -> Iterable:
         """Fetch data from self.X if it is given.
 
@@ -460,21 +507,72 @@ def _fetch_data_from_array(self, idx: int) -> Iterable:
             self.processed_data["last_obs_b"][idx],
         ]
 
-        if not self.training:
+        if self.return_X_ori:
             sample.extend([self.X_ori[idx], self.indicating_mask[idx]])
 
         if self.return_y:
             sample.append(self.y[idx].to(torch.long))
 
-        return {
-            "sample": sample,
-            "replacement_probabilities": self.replacement_probabilities,
-            "mean_set": self.mean_set,
-            "std_set": self.std_set,
-            "intervals": self.intervals,
-        }
+        return sample
 
     def _fetch_data_from_file(self, idx: int) -> Iterable:
-        raise NotImplementedError(
-            "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead."
-        )
+        """Fetch data with the lazy-loading strategy, i.e. only loading data from the file when samples are requested.
+
+        Here the opened file handle doesn't load the entire dataset into RAM but only loads the currently accessed slice.
+
+        Parameters
+        ----------
+        idx :
+            The index of the sample to be returned.
+
+        Returns
+        -------
+        sample :
+            The collated data sample, a list including all necessary sample info.
+ """ + + if self.file_handle is None: + self.file_handle = self._open_file_handle() + + X = torch.from_numpy(self.file_handle["X"][idx]).to(torch.float32) + X, missing_mask = fill_and_get_mask_torch(X) + + forward = { + "X": X, + "missing_mask": missing_mask, + "deltas": _parse_delta_torch(missing_mask), + } + + backward = { + "X": torch.flip(forward["X"], dims=[0]), + "missing_mask": torch.flip(forward["missing_mask"], dims=[0]), + } + backward["deltas"] = _parse_delta_torch(backward["missing_mask"]) + # B, _, _ = forward['X'].shape + + last_obs_f = np.nan_to_num(compute_last_obs(forward['X'], forward['missing_mask'].bool())).astype(np.float32) + last_obs_b = np.nan_to_num(compute_last_obs(backward['X'], backward['missing_mask'].bool())).astype(np.float32) + + sample = [ + torch.tensor(idx), + # for forward + forward["X"], + forward["missing_mask"], + forward["deltas"], + torch.FloatTensor(last_obs_f), + # for backward + backward["X"], + backward["missing_mask"], + backward["deltas"], + torch.FloatTensor(last_obs_b), + ] + + if self.return_X_ori: + X_ori = torch.from_numpy(self.file_handle["X_ori"][idx]).to(torch.float32) + X_ori, X_ori_missing_mask = fill_and_get_mask_torch(X_ori) + indicating_mask = X_ori_missing_mask - missing_mask + sample.extend([X_ori, indicating_mask]) + + # if the dataset has labels and is for training, then fetch it from the file + if self.return_y: + sample.append(torch.tensor(self.file_handle["y"][idx], dtype=torch.long)) + + return sample + diff --git a/pypots/imputation/csai/model.py b/pypots/imputation/csai/model.py index a579fd2c..3a6afc58 100644 --- a/pypots/imputation/csai/model.py +++ b/pypots/imputation/csai/model.py @@ -91,6 +91,13 @@ class CSAI(BaseNNImputer): The "better" strategy will automatically save the model during training whenever the model performs better than in previous epochs. The "all" strategy will save every model after each epoch training. + + is_normalise: + Whether to normalize the input data. Set this flag to False if the the data is already normalised. + + non_uniform : + Whether to apply non-uniform sampling to simulate missing values. If `True`, non-uniform sampling will be + applied. verbose : Whether to print out the training logs during the training process. 
@@ -124,6 +131,8 @@ def __init__(
         device: Union[str, torch.device, list, None] = None,
         saving_path: str = None,
         model_saving_strategy: Union[str, None] = "best",
+        is_normalise: bool = False,
+        non_uniform: bool = False,
         verbose: bool = True,
     ):
         super().__init__(
@@ -150,6 +159,8 @@ def __init__(
         self.replacement_probabilities = None
         self.mean_set = None
         self.std_set = None
+        self.is_normalise = is_normalise
+        self.non_uniform = non_uniform
 
         # Initialise model
         self.model = _BCSAI(
@@ -167,14 +178,12 @@ def __init__(
 
         # set up the optimizer
         self.optimizer = optimizer
-        self.optimizer.init_optimizer(self.model.parameters())
 
-    def _assemble_input_for_training(self, data: list, training=True) -> dict:
+    def _assemble_input_for_training(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (indices, X, missing_mask, deltas, last_obs, back_X, back_missing_mask, back_deltas, back_last_obs) = (
-            self._send_data_to_given_device(sample)
+            self._send_data_to_given_device(data)
         )
 
         # assemble input data
@@ -198,7 +207,6 @@ def _assemble_input_for_validating(self, data: list) -> dict:
         # extract data
-        sample = data["sample"]
         (
             indices,
             X,
@@ -211,7 +219,7 @@ def _assemble_input_for_validating(self, data: list) -> dict:
             back_last_obs,
             X_ori,
             indicating_mask,
-        ) = self._send_data_to_given_device(sample)
+        ) = self._send_data_to_given_device(data)
 
         # assemble input data
         inputs = {
@@ -243,21 +251,24 @@ def fit(
         file_type: str = "hdf5",
     ) -> None:
 
-        if isinstance(train_set, str):
+        if isinstance(train_set, str) and self.non_uniform:
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                "Hence the whole train set will be loaded into memory."
+                "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                "need to be calculated ahead. Hence the whole train set will be loaded into memory."
            )
+            train_set = load_dict_from_h5(train_set)
         training_set = DatasetForCSAI(
-            train_set,
-            False,
-            False,
-            file_type,
-            self.removal_percent,
-            self.increase_factor,
-            self.compute_intervals,
+            data=train_set,
+            return_X_ori=False,
+            return_y=False,
+            file_type=file_type,
+            removal_percent=self.removal_percent,
+            increase_factor=self.increase_factor,
+            compute_intervals=self.compute_intervals,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
         self.intervals = training_set.intervals
         self.replacement_probabilities = training_set.replacement_probabilities
@@ -273,27 +284,28 @@ def fit(
         )
         val_loader = None
         if val_set is not None:
-            if isinstance(val_set, str):
+            if isinstance(val_set, str) and self.non_uniform:
                 logger.warning(
-                    "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                    "Hence the whole val set will be loaded into memory."
+                    "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                    "need to be calculated ahead. Hence the whole val set will be loaded into memory."
                )
                val_set = load_dict_from_h5(val_set)
 
             if not key_in_data_set("X_ori", val_set):
                 raise ValueError("val_set must contain 'X_ori' for model validation.")
 
             val_set = DatasetForCSAI(
-                val_set,
-                True,
-                False,
-                file_type,
-                self.removal_percent,
-                self.increase_factor,
-                self.compute_intervals,
-                self.replacement_probabilities,
-                self.mean_set,
-                self.std_set,
-                False,
+                data=val_set,
+                return_X_ori=True,
+                return_y=False,
+                file_type=file_type,
+                removal_percent=self.removal_percent,
+                increase_factor=self.increase_factor,
+                compute_intervals=self.compute_intervals,
+                replacement_probabilities=self.replacement_probabilities,
+                normalise_mean=self.mean_set,
+                normalise_std=self.std_set,
+                is_normalise=self.is_normalise,
+                non_uniform=self.non_uniform,
             )
             val_loader = DataLoader(
                 val_set,
@@ -303,6 +315,23 @@ def fit(
                 # collate_fn=collate_fn_bidirectional
             )
 
+        # Reset the model
+        self.model = _BCSAI(
+            self.n_steps,
+            self.n_features,
+            self.rnn_hidden_size,
+            self.step_channels,
+            self.consistency_weight,
+            self.imputation_weight,
+            self.intervals,
+        )
+
+        self._send_model_to_given_device()
+        self._print_model_size()
+
+        # set up the optimizer
+        self.optimizer.init_optimizer(self.model.parameters())
+
         # train the model
         self._train_model(training_loader, val_loader)
         self.model.load_state_dict(self.best_model_dict)
@@ -319,24 +348,25 @@ def predict(
 
         self.model.eval()
 
-        if isinstance(test_set, str):
+        if isinstance(test_set, str) and self.non_uniform:
             logger.warning(
-                "CSAI does not support lazy loading because normalise mean and std need to be calculated ahead. "
-                "Hence the whole test set will be loaded into memory."
+                "CSAI does not support lazy loading with non-uniform sampling because normalise mean and std "
+                "need to be calculated ahead. Hence the whole test set will be loaded into memory."
            )
            test_set = load_dict_from_h5(test_set)
         test_set = DatasetForCSAI(
-            test_set,
-            True,
-            False,
-            file_type,
-            self.removal_percent,
-            self.increase_factor,
-            self.compute_intervals,
-            self.replacement_probabilities,
-            self.mean_set,
-            self.std_set,
-            False,
+            data=test_set,
+            return_X_ori=True,
+            return_y=False,
+            file_type=file_type,
+            removal_percent=self.removal_percent,
+            increase_factor=self.increase_factor,
+            compute_intervals=self.compute_intervals,
+            replacement_probabilities=self.replacement_probabilities,
+            normalise_mean=self.mean_set,
+            normalise_std=self.std_set,
+            is_normalise=self.is_normalise,
+            non_uniform=self.non_uniform,
         )
         test_loader = DataLoader(
diff --git a/tests/classification/csai.py b/tests/classification/csai.py
index 4a2bbf5f..287ab541 100644
--- a/tests/classification/csai.py
+++ b/tests/classification/csai.py
@@ -58,6 +58,8 @@ class TestCSAI(unittest.TestCase):
         device=DEVICE,
         saving_path=saving_path,
         model_saving_strategy="better",
+        is_normalise=False,
+        non_uniform=False,
         verbose=True,
     )
 
@@ -106,7 +108,7 @@ def test_3_saving_path(self):
 
         # Save the trained model to file, and verify the file existence
         saved_model_path = os.path.join(self.saving_path, self.model_save_name)
-        self.csai.save(saved_model_path)
+        self.csai.save(saved_model_path, overwrite=True)
 
         # Test loading the saved model
         self.csai.load(saved_model_path)
diff --git a/tests/imputation/csai.py b/tests/imputation/csai.py
index f5c4873b..d8c914e2 100644
--- a/tests/imputation/csai.py
+++ b/tests/imputation/csai.py
@@ -57,6 +57,8 @@ class TestCSAI(unittest.TestCase):
         device=DEVICE,
         saving_path=saving_path,
         model_saving_strategy="best",
+        is_normalise=False,
+        non_uniform=False,
         verbose=True,
     )
 
@@ -97,7 +99,7 @@ def test_3_saving_path(self):
 
         # Save the trained model to file, and verify the file existence
         saved_model_path = os.path.join(self.saving_path, self.model_save_name)
-        self.csai.save(saved_model_path)
+        self.csai.save(saved_model_path, overwrite=True)
 
         # Test loading the saved model
         self.csai.load(saved_model_path)
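
Note: a minimal, illustrative usage sketch of the two flags this patch introduces, on the imputation model. The constructor argument names mirror those exercised in the diff above; the hyperparameter values, the toy data, and the "imputation" result key follow the usual PyPOTS conventions and are assumptions for illustration, not something this patch defines.

    import numpy as np
    from pypots.imputation import CSAI

    # Toy data: 32 series, 24 time steps, 5 features; knock out ~20% of values.
    X_ori = np.random.randn(32, 24, 5).astype(np.float32)  # complete ground truth
    X = X_ori.copy()
    X[np.random.rand(*X.shape) < 0.2] = np.nan             # observed data with missingness

    csai = CSAI(
        n_steps=24,
        n_features=5,
        rnn_hidden_size=64,
        imputation_weight=0.3,
        consistency_weight=0.1,
        removal_percent=10,
        increase_factor=0.1,
        step_channels=512,
        batch_size=16,
        epochs=5,
        is_normalise=True,   # raw data is not normalised yet, so let CSAI normalise it
        non_uniform=False,   # no non-uniform sampling, so HDF5 inputs can stay lazily loaded
    )
    csai.fit(train_set={"X": X})
    # predict() builds its dataset with return_X_ori=True, so X_ori is supplied here
    results = csai.predict({"X": X, "X_ori": X_ori})
    imputation = results["imputation"]

With non_uniform=False, passing an HDF5 file path instead of the in-memory dicts should keep the dataset lazily loaded, which is the point of this patch.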