From 9c790d7502a9da179a791435c0b6a317a4316d6e Mon Sep 17 00:00:00 2001
From: Gregory Leleytner
Date: Fri, 10 May 2024 20:49:43 +0300
Subject: [PATCH] Add VAE implementation in PyTorch

---
 pyod/models/vae.py    | 268 ++++++++++++++++++++++++++----------------
 pyod/test/test_vae.py |   6 +-
 2 files changed, 170 insertions(+), 104 deletions(-)

diff --git a/pyod/models/vae.py b/pyod/models/vae.py
index 4938cbc8c..10e4fd8e7 100644
--- a/pyod/models/vae.py
+++ b/pyod/models/vae.py
@@ -19,6 +19,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from itertools import chain

 import numpy as np
 from sklearn.preprocessing import StandardScaler
@@ -30,20 +31,40 @@
 from ..utils.stat_models import pairwise_distances_no_broadcast
 from ..utils.utility import check_parameter

-# if tensorflow 2, import from tf directly
-if _get_tensorflow_version() == 1:
-    from keras.models import Model
-    from keras.layers import Lambda, Input, Dense, Dropout
-    from keras.regularizers import l2
-    from keras.losses import mse
-    from keras import backend as K
-else:
-    from tensorflow.keras.models import Model
-    from tensorflow.keras.layers import Lambda, Input, Dense, Dropout
-    from tensorflow.keras.regularizers import l2
-    from tensorflow.keras.losses import mse
-    from tensorflow.keras import backend as K
-
+from torch import nn
+import torch
+from torch.utils.data import TensorDataset, DataLoader, random_split
+from torch import optim
+from tqdm import tqdm
+
+
+_activation_classes = {
+    'relu': nn.ReLU,
+    'sigmoid': nn.Sigmoid,
+    'tanh': nn.Tanh
+}
+
+def _resolve_activation(activation):
+    if isinstance(activation, str) and activation in _activation_classes:
+        return _activation_classes[activation]
+    elif isinstance(activation, type) and issubclass(activation, nn.Module):
+        return activation
+    else:
+        raise ValueError(f'Activation must be an nn.Module subclass or one of {list(_activation_classes.keys())}')
+
+def _resolve_loss(loss):
+    if loss == 'mse':
+        return nn.MSELoss
+    elif callable(loss):
+        return loss
+    else:
+        raise ValueError('Loss must be "mse" or a callable')
+
+def _resolve_optim(optimizer):
+    if optimizer == 'adam':
+        return optim.Adam
+    else:
+        raise ValueError('Only "adam" is supported as optimizer')

 class VAE(BaseDetector):
     """ Variational auto encoder
@@ -139,6 +160,9 @@ class VAE(BaseDetector):
         The amount of contamination of the data set, i.e.
         the proportion of outliers in the data set. When fitting this is to
         define the threshold on the decision function.
+
+    device : str, optional (default='cpu')
+        The torch device on which the autoencoder is trained and evaluated.

     Attributes
     ----------
@@ -149,11 +173,14 @@ class VAE(BaseDetector):
         The ratio between the original feature and
         the number of neurons in the encoding layer.

-    model_ : Keras Object
-        The underlying AutoEncoder in Keras.
+    encoder_ : torch.nn.Module
+        The encoder of the underlying autoencoder.
+
+    decoder_ : torch.nn.Module
+        The decoder of the underlying autoencoder.

-    history_: Keras Object
-        The AutoEncoder training history.
+    history_ : list of float
+        The validation loss of the autoencoder for each training epoch.

     decision_scores_ : numpy array of shape (n_samples,)
         The outlier scores of the training data.
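Note on the helpers above: the `_resolve_*` functions return classes (not instances), which the constructor then instantiates. A minimal sketch of the intended behaviour, assuming the module-level names from this patch:

```python
from torch import nn, optim

# Illustrative only; mirrors _resolve_activation / _resolve_loss / _resolve_optim above.
act_cls = _resolve_activation('relu')     # -> nn.ReLU (a class, instantiated later)
out_cls = _resolve_activation('sigmoid')  # -> nn.Sigmoid
loss_fn = _resolve_loss('mse')()          # -> an nn.MSELoss() instance, as done in __init__
optim_cls = _resolve_optim('adam')        # -> optim.Adam (a class, instantiated in _fit_vae)
```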
@@ -175,18 +202,22 @@ class VAE(BaseDetector):

     def __init__(self, encoder_neurons=None, decoder_neurons=None,
                  latent_dim=2, hidden_activation='relu',
-                 output_activation='sigmoid', loss=mse, optimizer='adam',
+                 output_activation='sigmoid', loss='mse', optimizer='adam',
                  epochs=100, batch_size=32, dropout_rate=0.2,
                  l2_regularizer=0.1, validation_size=0.1, preprocessing=True,
-                 verbose=1, random_state=None, contamination=0.1,
-                 gamma=1.0, capacity=0.0):
+                 verbose=1, random_state=42, contamination=0.1,
+                 gamma=1.0, capacity=0.0, device='cpu'):
         super(VAE, self).__init__(contamination=contamination)
         self.encoder_neurons = encoder_neurons
         self.decoder_neurons = decoder_neurons
         self.hidden_activation = hidden_activation
+        self.hidden_activation_cls = _resolve_activation(hidden_activation)
         self.output_activation = output_activation
+        self.output_activation_cls = _resolve_activation(output_activation)
         self.loss = loss
+        self.loss_fn = _resolve_loss(loss)()
         self.optimizer = optimizer
+        self.optimizer_cls = _resolve_optim(optimizer)
         self.epochs = epochs
         self.batch_size = batch_size
         self.dropout_rate = dropout_rate
@@ -198,6 +229,7 @@ def __init__(self, encoder_neurons=None, decoder_neurons=None,
         self.latent_dim = latent_dim
         self.gamma = gamma
         self.capacity = capacity
+        self.device = device

         # default values
         if self.encoder_neurons is None:
@@ -212,94 +244,125 @@ def __init__(self, encoder_neurons=None, decoder_neurons=None,
         check_parameter(dropout_rate, 0, 1, param_name='dropout_rate',
                         include_left=True)

-    def sampling(self, args):
-        """Reparametrisation by sampling from Gaussian, N(0,I)
-        To sample from epsilon = Norm(0,I) instead of from likelihood Q(z|X)
-        with latent variables z: z = z_mean + sqrt(var) * epsilon
-
-        Parameters
-        ----------
-        args : tensor
-            Mean and log of variance of Q(z|X).
-
-        Returns
-        -------
-        z : tensor
-            Sampled latent variable.
-        """
-
-        z_mean, z_log = args
-        batch = K.shape(z_mean)[0]  # batch size
-        dim = K.int_shape(z_mean)[1]  # latent dimension
-        epsilon = K.random_normal(shape=(batch, dim))  # mean=0, std=1.0
-
-        return z_mean + K.exp(0.5 * z_log) * epsilon
-
-    def vae_loss(self, inputs, outputs, z_mean, z_log):
+    def _vae_loss(self, inputs, outputs, z_mean, z_log):
         """ Loss = Recreation loss + Kullback-Leibler loss
         for probability function divergence (ELBO).
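For reference, constructing the detector with the PyTorch-backed signature above might look as follows (the hyperparameter values are illustrative, not recommendations):

```python
from pyod.models.vae import VAE

clf = VAE(encoder_neurons=[128, 64, 32], decoder_neurons=[32, 64, 128],
          latent_dim=2, hidden_activation='relu', output_activation='sigmoid',
          loss='mse', optimizer='adam', epochs=30, batch_size=32,
          contamination=0.1, device='cpu')
```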
            gamma > 1 and capacity != 0 for beta-VAE
         """
-        reconstruction_loss = self.loss(inputs, outputs)
+        reconstruction_loss = self.loss_fn(inputs, outputs)
         reconstruction_loss *= self.n_features_
-        kl_loss = 1 + z_log - K.square(z_mean) - K.exp(z_log)
-        kl_loss = -0.5 * K.sum(kl_loss, axis=-1)
-        kl_loss = self.gamma * K.abs(kl_loss - self.capacity)
+        kl_loss = 1 + z_log - z_mean ** 2 - torch.exp(z_log)
+        kl_loss = -0.5 * kl_loss.sum(axis=-1)
+        kl_loss = self.gamma * (kl_loss - self.capacity).abs()
+
+        return (reconstruction_loss + kl_loss).mean()
+
+    def _forward_encoder(self, X):
+        activity_regularization = 0
+        for layer in self.encoder_:
+            X = layer(X)
+            if isinstance(layer, nn.Linear):
+                activity_regularization = (X ** 2).sum() + activity_regularization
+        return X, activity_regularization
+
+    def _forward_vae(self, X):
+        X, activity_regularization = self._forward_encoder(X)
+        latent_stats = X.reshape(-1, 2, self.latent_dim)
+        z_mean = latent_stats[:, 0]
+        z_log = latent_stats[:, 1]
+
+        # reparametrization trick
+        epsilon = torch.randn_like(z_mean)  # mean=0, std=1.0
+        sampled_latent = z_mean + torch.exp(0.5 * z_log) * epsilon

-        return K.mean(reconstruction_loss + kl_loss)
+        return self.decoder_(sampled_latent), z_mean, z_log, activity_regularization

     def _build_model(self):
-        """Build VAE = encoder + decoder + vae_loss"""
+        """Build VAE = encoder + decoder"""

         # Build Encoder
-        inputs = Input(shape=(self.n_features_,))
+        encoder = []
         # Input layer
-        layer = Dense(self.n_features_, activation=self.hidden_activation)(
-            inputs)
+        encoder.append(nn.Linear(self.n_features_, self.n_features_))
+        encoder.append(self.hidden_activation_cls())
         # Hidden layers
+        prev_neurons = self.n_features_
         for neurons in self.encoder_neurons:
-            layer = Dense(neurons, activation=self.hidden_activation,
-                          activity_regularizer=l2(self.l2_regularizer))(layer)
-            layer = Dropout(self.dropout_rate)(layer)
+            # activity regularization is accumulated in _forward_encoder and added to the loss in _fit_vae
+            encoder.append(nn.Linear(prev_neurons, neurons))
+            encoder.append(self.hidden_activation_cls())
+            encoder.append(nn.Dropout(self.dropout_rate))
+            prev_neurons = neurons

         # Create mu and sigma of latent variables
-        z_mean = Dense(self.latent_dim)(layer)
-        z_log = Dense(self.latent_dim)(layer)
-        # Use parametrisation sampling
-        z = Lambda(self.sampling, output_shape=(self.latent_dim,))(
-            [z_mean, z_log])
-        # Instantiate encoder
-        encoder = Model(inputs, [z_mean, z_log, z])
+        encoder.append(nn.Linear(prev_neurons, 2 * self.latent_dim))
+        encoder = nn.Sequential(*encoder)
+        encoder.to(self.device)
+
         if self.verbose >= 1:
-            encoder.summary()
-
-        # Build Decoder
-        latent_inputs = Input(shape=(self.latent_dim,))
+            print(encoder)
+
+        decoder = []
         # Latent input layer
-        layer = Dense(self.latent_dim, activation=self.hidden_activation)(
-            latent_inputs)
+        decoder.append(nn.Linear(self.latent_dim, self.latent_dim))
+        decoder.append(self.hidden_activation_cls())
         # Hidden layers
+        prev_neurons = self.latent_dim
         for neurons in self.decoder_neurons:
-            layer = Dense(neurons, activation=self.hidden_activation)(layer)
-            layer = Dropout(self.dropout_rate)(layer)
-        # Output layer
-        outputs = Dense(self.n_features_, activation=self.output_activation)(
-            layer)
-        # Instatiate decoder
-        decoder = Model(latent_inputs, outputs)
-        if self.verbose >= 1:
-            decoder.summary()
-        # Generate outputs
-        outputs = decoder(encoder(inputs)[2])
-
-        # Instantiate VAE
-        vae = Model(inputs, outputs)
-        vae.add_loss(self.vae_loss(inputs, outputs, z_mean, z_log))
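As a sanity check on the loss and sampling logic above, here is a standalone sketch of the closed-form KL term and the reparameterization trick that `_vae_loss` and `_forward_vae` implement (`kl_and_sample` is a hypothetical helper, not part of the patch; `gamma`/`capacity` follow the beta-VAE reading in the docstring):

```python
import torch

def kl_and_sample(z_mean, z_log, gamma=1.0, capacity=0.0):
    # Closed-form KL(q(z|x) || N(0, I)), summed over the latent dimension
    kl = -0.5 * (1 + z_log - z_mean ** 2 - torch.exp(z_log)).sum(dim=-1)
    kl = gamma * (kl - capacity).abs()  # beta-VAE style weighting
    # Reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, I)
    z = z_mean + torch.exp(0.5 * z_log) * torch.randn_like(z_mean)
    return kl, z
```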
-        vae.compile(optimizer=self.optimizer)
-        if self.verbose >= 1:
-            vae.summary()
-        return vae
-
+            decoder.append(nn.Linear(prev_neurons, neurons))
+            decoder.append(self.hidden_activation_cls())
+            decoder.append(nn.Dropout(self.dropout_rate))
+            prev_neurons = neurons
+        # Output layer
+        decoder.append(nn.Linear(prev_neurons, self.n_features_))
+        decoder.append(self.output_activation_cls())
+        decoder = nn.Sequential(*decoder)
+        decoder.to(self.device)
+
+        return encoder, decoder
+
+    def _fit_vae(self, X, epochs, batch_size, shuffle, validation_split, verbose=False):
+        dataset = TensorDataset(X)
+
+        generator = torch.Generator().manual_seed(self.random_state)
+        train_ds, val_ds = random_split(dataset, [1 - validation_split, validation_split], generator)
+        train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)
+        val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
+
+        optimizer = self.optimizer_cls(chain(self.encoder_.parameters(), self.decoder_.parameters()))
+        val_losses = []
+
+        for i in range(epochs):
+            iterable = tqdm(train_dl) if verbose else train_dl
+            for X_batch in iterable:
+                X_batch = X_batch[0].to(self.device)
+                X_batch = X_batch.float()
+                X_reconstructed, z_mean, z_log, activ_reg = self._forward_vae(X_batch)
+                loss = self._vae_loss(X_batch, X_reconstructed, z_mean, z_log)
+                # Add activity regularization from encoder Linear layers.
+                # Regularization is divided by batch size to conform with Keras
+                # https://keras.io/api/layers/regularizers/
+                loss += activ_reg * self.l2_regularizer / X_batch.shape[0]
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+
+            iterable = tqdm(val_dl) if verbose else val_dl
+            losses = []
+            for X_batch in iterable:
+                X_batch = X_batch[0].to(self.device)
+                X_batch = X_batch.float()
+                with torch.no_grad():
+                    X_reconstructed, z_mean, z_log, activ_reg = self._forward_vae(X_batch)
+                    loss = self._vae_loss(X_batch, X_reconstructed, z_mean, z_log)
+                    loss += activ_reg * self.l2_regularizer / X_batch.shape[0]
+                losses.append(loss.item())
+            val_losses.append(np.mean(losses))
+            if verbose:
+                print(f'Epoch: {i}, Val loss: {val_losses[-1]}')
+
+        return val_losses
+
     def fit(self, X, y=None):
         """Fit detector. y is optional for unsupervised methods.
@@ -335,13 +398,13 @@ def fit(self, X, y=None):
                              "the number of features")

         # Build VAE model & fit with X
-        self.model_ = self._build_model()
-        self.history_ = self.model_.fit(X_norm,
-                                        epochs=self.epochs,
-                                        batch_size=self.batch_size,
-                                        shuffle=True,
-                                        validation_split=self.validation_size,
-                                        verbose=self.verbose).history
+        self.encoder_, self.decoder_ = self._build_model()
+        self.history_ = self._fit_vae(torch.from_numpy(X_norm),
+                                      epochs=self.epochs,
+                                      batch_size=self.batch_size,
+                                      shuffle=True,
+                                      validation_split=self.validation_size,
+                                      verbose=self.verbose)
         # Predict on X itself and calculate the reconstruction error as
         # the outlier scores. Noted X_norm was shuffled has to recreate
         if self.preprocessing:
@@ -349,9 +412,7 @@ def fit(self, X, y=None):
         else:
             X_norm = np.copy(X)

-        pred_scores = self.model_.predict(X_norm)
-        self.decision_scores_ = pairwise_distances_no_broadcast(X_norm,
-                                                                pred_scores)
+        self.decision_scores_ = self.decision_function(X)

         self._process_decision_scores()
         return self
@@ -373,7 +434,7 @@ def decision_function(self, X):
         anomaly_scores : numpy array of shape (n_samples,)
             The anomaly score of the input samples.
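An end-to-end usage sketch for the `fit` path above (toy data and settings are illustrative; the small layer sizes keep the network valid for 10 input features):

```python
import numpy as np
from pyod.models.vae import VAE

X_train = np.random.rand(200, 10)  # toy data with 10 features

clf = VAE(encoder_neurons=[8, 4], decoder_neurons=[4, 8],
          epochs=5, batch_size=32, verbose=0, device='cpu')
clf.fit(X_train)

print(clf.history_[-1])          # validation loss of the last epoch
print(clf.decision_scores_[:5])  # outlier scores on the training data
print(clf.labels_[:5])           # binary labels from BaseDetector's thresholding
```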
""" - check_is_fitted(self, ['model_', 'history_']) + check_is_fitted(self, ['encoder_', 'decoder_', 'history_']) X = check_array(X) if self.preprocessing: @@ -382,5 +443,8 @@ def decision_function(self, X): X_norm = np.copy(X) # Predict on X and return the reconstruction errors - pred_scores = self.model_.predict(X_norm) + X_norm_pt = torch.from_numpy(X_norm).float().to(self.device) + with torch.no_grad(): + pred_scores, _, _, _ = self._forward_vae(X_norm_pt) + pred_scores = pred_scores.cpu().numpy() return pairwise_distances_no_broadcast(X_norm, pred_scores) diff --git a/pyod/test/test_vae.py b/pyod/test/test_vae.py index 54bb532ea..dff9f3bc4 100644 --- a/pyod/test/test_vae.py +++ b/pyod/test/test_vae.py @@ -46,8 +46,10 @@ def test_parameters(self): self.clf._mu is not None) assert (hasattr(self.clf, '_sigma') and self.clf._sigma is not None) - assert (hasattr(self.clf, 'model_') and - self.clf.model_ is not None) + assert (hasattr(self.clf, 'encoder_') and + self.clf.encoder_ is not None) + assert (hasattr(self.clf, 'decoder_') and + self.clf.decoder_ is not None) def test_train_scores(self): assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])