This repository provides an implementation of a Variational Autoencoder (VAE) using PyTorch. VAEs are a class of deep generative models designed to learn latent representations of data and generate new samples. This README introduces the fundamental concepts of VAEs, details the implementation provided, and guides you through the setup using Poetry.
- To sample or reconstruct images, run `poetry run inference`; the generated images will appear in `/images`.
- To train the model, run `poetry run train`.
A Variational Autoencoder is a type of autoencoder that models the underlying data distribution through latent variables sampled from a probabilistic space. Unlike traditional autoencoders, VAEs are generative models capable of producing new data similar to the training set.
- Encoder:
  - Maps the input data to a latent space represented by a mean vector (μ) and a standard deviation vector (σ).
  - The latent space is regularized using a Kullback-Leibler (KL) divergence term to approximate a normal distribution.
- Reparameterization Trick:
  - Allows backpropagation through the stochastic sampling process: z = μ + σ * ε, where ε is sampled from a standard normal distribution (see the sketch after this list).
- Decoder:
  - Maps the latent variable z back to the data space, reconstructing the input data.
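The reparameterization step is small enough to show in isolation. Below is a minimal sketch; the function name `reparameterize` is illustrative and not necessarily how this repository organizes the step:

```python
import torch

def reparameterize(mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """Sample z = mu + sigma * eps with eps ~ N(0, I).

    Sampling eps (rather than z directly) keeps the randomness outside the
    learned parameters, so gradients can flow through mu and sigma.
    """
    eps = torch.randn_like(sigma)  # standard normal noise, same shape as sigma
    return mu + sigma * eps
```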
The VAE loss comprises two components:
- Reconstruction Loss: Measures how well the decoded output matches the input.
- KL Divergence: Regularizes the latent space to follow a standard normal distribution.
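As a reference, here is one common way to write that loss in PyTorch. This is a sketch under assumptions rather than the exact loss used in this repository: it assumes a binary cross-entropy reconstruction term on pixel values in [0, 1] and the closed-form KL divergence between N(μ, σ²) and N(0, I):

```python
import torch
import torch.nn.functional as F

def vae_loss(x_reconstructed, x, mu, sigma):
    # Reconstruction term: how well the decoder output matches the input.
    # Assumes both tensors contain values in [0, 1]; summed over pixels and batch.
    reconstruction_loss = F.binary_cross_entropy(x_reconstructed, x, reduction="sum")

    # KL divergence between N(mu, sigma^2) and the standard normal N(0, I),
    # in closed form: -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2).
    kl_divergence = -0.5 * torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

    return reconstruction_loss + kl_divergence
```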
The implementation includes:
- Encoder: Processes input data into latent representations (μ, σ).
- Decoder: Reconstructs data from latent variables.
- Reparameterization Trick: Enables sampling from the latent space while maintaining differentiability.
class VariationalAutoEncoder(nn.Module):
def __init__(self, input_dim, h_dim=200, z_dim=20):
# Encoder and decoder architecture
- `encode(x)`: Encodes the input into μ and σ.
- `decode(z)`: Decodes the latent vector z back into the input space.
- `forward(x)`: Encodes, reparameterizes, and decodes.
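For orientation, a minimal end-to-end sketch of such a class is shown below. The layer sizes follow the constructor defaults above, but the layer names and activation choices are assumptions and may differ from the code in `vae.py`:

```python
import torch
import torch.nn as nn

class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, h_dim=200, z_dim=20):
        super().__init__()
        # Encoder: input -> hidden -> (mu, sigma)
        self.img_to_hidden = nn.Linear(input_dim, h_dim)
        self.hidden_to_mu = nn.Linear(h_dim, z_dim)
        self.hidden_to_sigma = nn.Linear(h_dim, z_dim)
        # Decoder: z -> hidden -> reconstructed input
        self.z_to_hidden = nn.Linear(z_dim, h_dim)
        self.hidden_to_img = nn.Linear(h_dim, input_dim)
        self.relu = nn.ReLU()

    def encode(self, x):
        h = self.relu(self.img_to_hidden(x))
        return self.hidden_to_mu(h), self.hidden_to_sigma(h)

    def decode(self, z):
        h = self.relu(self.z_to_hidden(z))
        # Sigmoid keeps reconstructed pixel values in [0, 1].
        return torch.sigmoid(self.hidden_to_img(h))

    def forward(self, x):
        mu, sigma = self.encode(x)
        eps = torch.randn_like(sigma)   # reparameterization trick
        z = mu + sigma * eps
        return self.decode(z), mu, sigma
```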
Follow these steps to set up and run the project.
git clone <repository-url>
cd <repository-name>
This project uses Poetry for dependency management. Ensure you have Poetry installed on your system. To install Poetry, follow the official guide.
poetry install
poetry shell
python vae.py
The script processes a random batch of data through the VAE. Below is a quick demonstration:
if __name__ == "__main__":
input_dim = 28 * 28 # for MNIST data
batch_size = 4
x = torch.randn(batch_size, input_dim) # Randomly generated data
vae = VariationalAutoEncoder(input_dim=input_dim)
x_reconstructed, mu, sigma = vae(x)
print(f"Reconstructed Shape: {x_reconstructed.shape}")
print(f"Latent Mean Shape: {mu.shape}")
print(f"Latent Sigma Shape: {sigma.shape}")
Reconstructed Shape: torch.Size([4, 784])
Latent Mean Shape: torch.Size([4, 20])
Latent Sigma Shape: torch.Size([4, 20])
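The training entry point (`poetry run train`) is not reproduced here, but a typical loop for this model looks roughly like the following. This is a hedged sketch: the torchvision MNIST loader, the Adam optimizer, and the `vae_loss` helper defined earlier are assumptions, not a description of this repository's exact training script.

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = "cuda" if torch.cuda.is_available() else "cpu"
input_dim = 28 * 28

# Assumed data source: MNIST digits scaled to [0, 1] by ToTensor().
dataset = datasets.MNIST(root="data", train=True, download=True,
                         transform=transforms.ToTensor())
loader = DataLoader(dataset, batch_size=128, shuffle=True)

vae = VariationalAutoEncoder(input_dim=input_dim).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=3e-4)

for epoch in range(10):
    for x, _ in loader:
        x = x.view(-1, input_dim).to(device)   # flatten 28x28 images to 784
        x_reconstructed, mu, sigma = vae(x)
        loss = vae_loss(x_reconstructed, x, mu, sigma)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch}: loss {loss.item():.2f}")
```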
- Input Layer: Flattens the input image (28x28 → 784).
- Hidden Layer: Encodes features into a hidden representation.
- Outputs: Latent mean (μ) and standard deviation (σ).
- Input: Sampled latent variable z.
- Hidden Layer: Maps the latent variable back to the feature space.
- Output Layer: Reconstructs the original image.
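Once trained, new images can be generated by sampling z from a standard normal distribution and passing it through the decoder, which is conceptually what the inference command does. A brief sketch, assuming a trained `vae` instance and using torchvision's `save_image` with an illustrative output path:

```python
import os
import torch
from torchvision.utils import save_image

os.makedirs("images", exist_ok=True)            # output directory (assumed path)
vae.eval()
with torch.no_grad():
    z = torch.randn(16, 20)                     # 16 latent samples, z_dim=20
    generated = vae.decode(z)                   # flattened 784-pixel images
    save_image(generated.view(-1, 1, 28, 28), "images/samples.png", nrow=4)
```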
- Fork the repository and create a branch for your feature.
- Follow best practices for Python development.
- Submit a pull request for review.