Add VAE implementation in PyTorch #567

Open · wants to merge 1 commit into base: development
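
A minimal usage sketch (illustrative, not part of the diff) assuming the standard PyOD detector API; the layer sizes and toy data below are hypothetical:

import numpy as np
from pyod.models.vae import VAE

X_train = np.random.rand(500, 16).astype(np.float32)  # toy data

# device='cuda' also works if a GPU is available
clf = VAE(encoder_neurons=[16, 8], decoder_neurons=[8, 16],
          latent_dim=2, epochs=5, verbose=0, device='cpu')
clf.fit(X_train)

scores = clf.decision_function(X_train)  # reconstruction-error outlier scores
labels = clf.labels_  # binary labels derived from the contamination threshold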
268 changes: 166 additions & 102 deletions pyod/models/vae.py
@@ -19,6 +19,7 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from itertools import chain

import numpy as np
from sklearn.preprocessing import StandardScaler
@@ -30,20 +31,40 @@
from ..utils.stat_models import pairwise_distances_no_broadcast
from ..utils.utility import check_parameter

# if tensorflow 2, import from tf directly
if _get_tensorflow_version() == 1:
from keras.models import Model
from keras.layers import Lambda, Input, Dense, Dropout
from keras.regularizers import l2
from keras.losses import mse
from keras import backend as K
else:
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Lambda, Input, Dense, Dropout
from tensorflow.keras.regularizers import l2
from tensorflow.keras.losses import mse
from tensorflow.keras import backend as K

import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader, random_split
from tqdm import tqdm


_activation_classes = {
'relu': nn.ReLU,
'sigmoid': nn.Sigmoid,
'tanh': nn.Tanh
}

def _resolve_activation(activation):
if isinstance(activation, str) and activation in _activation_classes:
return _activation_classes[activation]
elif isinstance(activation, nn.Module):
# Return the class so a fresh instance can be created per layer
return type(activation)
else:
raise ValueError(f'Activation must be an nn.Module instance or one of {list(_activation_classes)}')

def _resolve_loss(loss):
if loss == 'mse':
return nn.MSELoss
elif callable(loss):
return loss
else:
raise ValueError('Loss must be "mse" or a callable')

def _resolve_optim(optimizer):
if optimizer == 'adam':
return optim.Adam
else:
raise ValueError('Only "adam" is supported as optimizer')
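
# Example (illustrative): each resolver maps a config string to a torch
# class that is instantiated later, e.g.
#   _resolve_activation('relu')  -> nn.ReLU
#   _resolve_loss('mse')         -> nn.MSELoss
#   _resolve_optim('adam')       -> optim.Adam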

class VAE(BaseDetector):
""" Variational auto encoder
@@ -139,6 +160,9 @@ class VAE(BaseDetector):
The amount of contamination of the data set, i.e.
the proportion of outliers in the data set. Used when fitting
to define the threshold on the decision function.

device : str, optional (default='cpu')
Torch device on which the autoencoder is trained and run.

Attributes
----------
@@ -149,11 +173,14 @@ class VAE(BaseDetector):
The ratio between the original feature and
the number of neurons in the encoding layer.

model_ : Keras Object
The underlying AutoEncoder in Keras.
encoder_ : nn.Module
The underlying encoder network of the VAE.

decoder_ : nn.Module
The underlying decoder network of the VAE.

history_: Keras Object
The AutoEncoder training history.
history_ : List[float]
The validation loss recorded after each training epoch.

decision_scores_ : numpy array of shape (n_samples,)
The outlier scores of the training data.
@@ -175,18 +202,22 @@

def __init__(self, encoder_neurons=None, decoder_neurons=None,
latent_dim=2, hidden_activation='relu',
output_activation='sigmoid', loss=mse, optimizer='adam',
output_activation='sigmoid', loss='mse', optimizer='adam',
epochs=100, batch_size=32, dropout_rate=0.2,
l2_regularizer=0.1, validation_size=0.1, preprocessing=True,
verbose=1, random_state=None, contamination=0.1,
gamma=1.0, capacity=0.0):
verbose=1, random_state=42, contamination=0.1,
gamma=1.0, capacity=0.0, device='cpu'):
super(VAE, self).__init__(contamination=contamination)
self.encoder_neurons = encoder_neurons
self.decoder_neurons = decoder_neurons
self.hidden_activation = hidden_activation
self.hidden_activation_cls = _resolve_activation(hidden_activation)
self.output_activation = output_activation
self.output_activation_cls = _resolve_activation(output_activation)
self.loss = loss
self.loss_fn = _resolve_loss(loss)()
self.optimizer = optimizer
self.optimizer_cls = _resolve_optim(optimizer)
self.epochs = epochs
self.batch_size = batch_size
self.dropout_rate = dropout_rate
@@ -198,6 +229,7 @@ def __init__(self, encoder_neurons=None, decoder_neurons=None,
self.latent_dim = latent_dim
self.gamma = gamma
self.capacity = capacity
self.device = device

# default values
if self.encoder_neurons is None:
@@ -212,94 +244,125 @@ def __init__(self, encoder_neurons=None, decoder_neurons=None,
check_parameter(dropout_rate, 0, 1, param_name='dropout_rate',
include_left=True)

def sampling(self, args):
"""Reparametrisation by sampling from Gaussian, N(0,I)
To sample from epsilon = Norm(0,I) instead of from likelihood Q(z|X)
with latent variables z: z = z_mean + sqrt(var) * epsilon

Parameters
----------
args : tensor
Mean and log of variance of Q(z|X).

Returns
-------
z : tensor
Sampled latent variable.
"""

z_mean, z_log = args
batch = K.shape(z_mean)[0] # batch size
dim = K.int_shape(z_mean)[1] # latent dimension
epsilon = K.random_normal(shape=(batch, dim)) # mean=0, std=1.0

return z_mean + K.exp(0.5 * z_log) * epsilon

def vae_loss(self, inputs, outputs, z_mean, z_log):
def _vae_loss(self, inputs, outputs, z_mean, z_log):
""" Loss = Recreation loss + Kullback-Leibler loss
for probability function divergence (ELBO).
gamma > 1 and capacity != 0 for beta-VAE
"""

reconstruction_loss = self.loss(inputs, outputs)
reconstruction_loss = self.loss_fn(inputs, outputs)
reconstruction_loss *= self.n_features_
kl_loss = 1 + z_log - K.square(z_mean) - K.exp(z_log)
kl_loss = -0.5 * K.sum(kl_loss, axis=-1)
kl_loss = self.gamma * K.abs(kl_loss - self.capacity)
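# Closed form: KL(N(mu, sigma^2) || N(0, I))
#            = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)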
kl_loss = 1 + z_log - z_mean ** 2 - torch.exp(z_log)
kl_loss = -0.5 * kl_loss.sum(axis=-1)
kl_loss = self.gamma * (kl_loss - self.capacity).abs()

return (reconstruction_loss + kl_loss).mean()

def _forward_encoder(self, X):
# Accumulate the L2 activity regularization term over the outputs of
# every Linear layer, mirroring Keras' activity_regularizer=l2(...)
activity_regularization = 0
for layer in self.encoder_:
X = layer(X)
if isinstance(layer, nn.Linear):
activity_regularization = activity_regularization + (X ** 2).sum()
return X, activity_regularization

def _forward_vae(self, X):
X, activity_regularization = self._forward_encoder(X)
latent_stats = X.reshape(-1, 2, self.latent_dim)
z_mean = latent_stats[:, 0]
z_log = latent_stats[:, 1]

# reparametrization trick
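# Sampling z directly from N(z_mean, exp(z_log)) would block gradients;
# drawing epsilon separately keeps z differentiable w.r.t. z_mean and z_log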
epsilon = torch.randn_like(z_mean) # mean=0, std=1.0
sampled_latent = z_mean + torch.exp(0.5 * z_log) * epsilon

return K.mean(reconstruction_loss + kl_loss)
return self.decoder_(sampled_latent), z_mean, z_log, activity_regularization

def _build_model(self):
"""Build VAE = encoder + decoder + vae_loss"""
"""Build VAE = encoder + decoder"""

# Build Encoder
inputs = Input(shape=(self.n_features_,))
encoder = []
# Input layer
layer = Dense(self.n_features_, activation=self.hidden_activation)(
inputs)
encoder.append(nn.Linear(self.n_features_, self.n_features_))
encoder.append(self.hidden_activation_cls())
# Hidden layers
prev_neurons = self.n_features_
for neurons in self.encoder_neurons:
layer = Dense(neurons, activation=self.hidden_activation,
activity_regularizer=l2(self.l2_regularizer))(layer)
layer = Dropout(self.dropout_rate)(layer)
# Activity (L2) regularization for these layers is applied in _forward_encoder
encoder.append(nn.Linear(prev_neurons, neurons))
encoder.append(self.hidden_activation_cls())
encoder.append(nn.Dropout(self.dropout_rate))
prev_neurons = neurons
# Create mu and log-variance of the latent variables
# (emitted jointly as 2 * latent_dim units and split in _forward_vae)
z_mean = Dense(self.latent_dim)(layer)
z_log = Dense(self.latent_dim)(layer)
# Use parametrisation sampling
z = Lambda(self.sampling, output_shape=(self.latent_dim,))(
[z_mean, z_log])
# Instantiate encoder
encoder = Model(inputs, [z_mean, z_log, z])
encoder.append(nn.Linear(prev_neurons, 2 * self.latent_dim))
encoder = nn.Sequential(*encoder)
encoder.to(self.device)

if self.verbose >= 1:
encoder.summary()

# Build Decoder
latent_inputs = Input(shape=(self.latent_dim,))
if self.verbose >= 1:
print(encoder)

decoder = []
# Latent input layer
layer = Dense(self.latent_dim, activation=self.hidden_activation)(
latent_inputs)
decoder.append(nn.Linear(self.latent_dim, self.latent_dim))
decoder.append(self.hidden_activation_cls())
# Hidden layers
prev_neurons = self.latent_dim
for neurons in self.decoder_neurons:
layer = Dense(neurons, activation=self.hidden_activation)(layer)
layer = Dropout(self.dropout_rate)(layer)
# Output layer
outputs = Dense(self.n_features_, activation=self.output_activation)(
layer)
# Instatiate decoder
decoder = Model(latent_inputs, outputs)
if self.verbose >= 1:
decoder.summary()
# Generate outputs
outputs = decoder(encoder(inputs)[2])

# Instantiate VAE
vae = Model(inputs, outputs)
vae.add_loss(self.vae_loss(inputs, outputs, z_mean, z_log))
vae.compile(optimizer=self.optimizer)
if self.verbose >= 1:
vae.summary()
return vae

decoder.append(nn.Linear(prev_neurons, neurons))
decoder.append(self.hidden_activation_cls())
decoder.append(nn.Dropout(self.dropout_rate))
prev_neurons = neurons
# Output layer reconstructing the input features
decoder.append(nn.Linear(prev_neurons, self.n_features_))
decoder.append(self.output_activation_cls())
decoder = nn.Sequential(*decoder)
decoder.to(self.device)

if self.verbose >= 1:
print(decoder)

return encoder, decoder

def _fit_vae(self, X, epochs, batch_size, shuffle, validation_split, verbose=False):
dataset = TensorDataset(X)

generator = torch.Generator().manual_seed(self.random_state)
# Note: fractional split lengths require torch >= 1.13
train_ds, val_ds = random_split(dataset, [1 - validation_split, validation_split], generator=generator)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

optimizer = self.optimizer_cls(chain(self.encoder_.parameters(), self.decoder_.parameters()))
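# Adam's default learning rate (1e-3) matches the Keras default used previously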
val_losses = []

for i in range(epochs):
# Training phase: dropout layers active
self.encoder_.train()
self.decoder_.train()
iterable = tqdm(train_dl) if verbose else train_dl
for X_batch in iterable:
X_batch = X_batch[0].to(self.device).float()
X_reconstructed, z_mean, z_log, activ_reg = self._forward_vae(X_batch)
loss = self._vae_loss(X_batch, X_reconstructed, z_mean, z_log)
# Add activity regularization from encoder Linear layers
# Regularization is divided by batch size to conform with Keras
# https://keras.io/api/layers/regularizers/
loss += activ_reg * self.l2_regularizer / X_batch.shape[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Validation phase: dropout layers disabled
self.encoder_.eval()
self.decoder_.eval()
iterable = tqdm(val_dl) if verbose else val_dl
losses = []
for X_batch in iterable:
X_batch = X_batch[0].to(self.device).float()
with torch.no_grad():
X_reconstructed, z_mean, z_log, activ_reg = self._forward_vae(X_batch)
loss = self._vae_loss(X_batch, X_reconstructed, z_mean, z_log)
loss += activ_reg * self.l2_regularizer / X_batch.shape[0]
losses.append(loss.item())
val_losses.append(np.mean(losses))
if verbose:
print(f'Epoch: {i}, Val loss: {val_losses[-1]:.6f}')

return val_losses

def fit(self, X, y=None):
"""Fit detector. y is optional for unsupervised methods.

@@ -335,23 +398,21 @@ def fit(self, X, y=None):
"the number of features")

# Build VAE model & fit with X
self.model_ = self._build_model()
self.history_ = self.model_.fit(X_norm,
epochs=self.epochs,
batch_size=self.batch_size,
shuffle=True,
validation_split=self.validation_size,
verbose=self.verbose).history
self.encoder_, self.decoder_ = self._build_model()
self.history_ = self._fit_vae(torch.from_numpy(X_norm),
epochs=self.epochs,
batch_size=self.batch_size,
shuffle=True,
validation_split=self.validation_size,
verbose=self.verbose)
# Predict on X itself and calculate the reconstruction error as
# the outlier scores. Note that X_norm was shuffled during training,
# so it is recreated here before scoring
if self.preprocessing:
X_norm = self.scaler_.transform(X)
else:
X_norm = np.copy(X)

pred_scores = self.model_.predict(X_norm)
self.decision_scores_ = pairwise_distances_no_broadcast(X_norm,
pred_scores)
self.decision_scores_ = self.decision_function(X)
self._process_decision_scores()
return self

@@ -373,7 +434,7 @@ def decision_function(self, X):
anomaly_scores : numpy array of shape (n_samples,)
The anomaly score of the input samples.
"""
check_is_fitted(self, ['model_', 'history_'])
check_is_fitted(self, ['encoder_', 'decoder_', 'history_'])
X = check_array(X)

if self.preprocessing:
@@ -382,5 +443,8 @@ def decision_function(self, X):
X_norm = np.copy(X)

# Predict on X and return the reconstruction errors
pred_scores = self.model_.predict(X_norm)
# Disable dropout for deterministic inference
self.encoder_.eval()
self.decoder_.eval()
X_norm_pt = torch.from_numpy(X_norm).float().to(self.device)
with torch.no_grad():
pred_scores, _, _, _ = self._forward_vae(X_norm_pt)
pred_scores = pred_scores.cpu().numpy()
return pairwise_distances_no_broadcast(X_norm, pred_scores)
6 changes: 4 additions & 2 deletions pyod/test/test_vae.py
@@ -46,8 +46,10 @@ def test_parameters(self):
self.clf._mu is not None)
assert (hasattr(self.clf, '_sigma') and
self.clf._sigma is not None)
assert (hasattr(self.clf, 'model_') and
self.clf.model_ is not None)
assert (hasattr(self.clf, 'encoder_') and
self.clf.encoder_ is not None)
assert (hasattr(self.clf, 'decoder_') and
self.clf.decoder_ is not None)

def test_train_scores(self):
assert_equal(len(self.clf.decision_scores_), self.X_train.shape[0])