
Mixture Density Network in PyTorch with full covariance support.


haimengzhao/full-cov-mdn


Mixture Density Network in PyTorch with Full Covariance


An implementation of a Mixture Density Network (MDN) in PyTorch with full covariance matrix support.

The full covariance matrix is implemented via Cholesky decomposition with torch.distributions.MultivariateNormal. See this document for details.
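As a minimal sketch of one common way to obtain a valid full covariance from a network output, assume the network emits an unconstrained vector of D(D+1)/2 entries per component. These entries are scattered into a lower-triangular matrix whose diagonal is forced positive, and the result is passed to `MultivariateNormal` as `scale_tril`. The helper name `build_scale_tril` and the softplus-on-diagonal choice are illustrative assumptions, not necessarily what this repository does.

```python
import torch
import torch.nn.functional as F
from torch.distributions import MultivariateNormal


def build_scale_tril(raw: torch.Tensor, dim: int) -> torch.Tensor:
    """Hypothetical helper: turn an unconstrained tensor of shape
    (..., dim*(dim+1)//2) into a lower-triangular matrix of shape
    (..., dim, dim) with a strictly positive diagonal."""
    rows, cols = torch.tril_indices(dim, dim)
    L = raw.new_zeros((*raw.shape[:-1], dim, dim))
    # Scatter the raw entries into the lower triangle (row-major order).
    L[..., rows, cols] = raw
    # Softplus keeps the diagonal positive, so L @ L^T is positive definite.
    diag = F.softplus(torch.diagonal(L, dim1=-2, dim2=-1)) + 1e-6
    return torch.tril(L, diagonal=-1) + torch.diag_embed(diag)


torch.manual_seed(0)
batch, n_components, dim_out = 5, 2, 2
# Stand-ins for the outputs of a network head (hypothetical shapes).
raw = torch.randn(batch, n_components, dim_out * (dim_out + 1) // 2)
loc = torch.randn(batch, n_components, dim_out)

scale_tril = build_scale_tril(raw, dim_out)
normal = MultivariateNormal(loc=loc, scale_tril=scale_tril)
print(normal.batch_shape, normal.event_shape)  # torch.Size([5, 2]) torch.Size([2])
```

Passing `scale_tril` directly means `MultivariateNormal` does not need to run its own Cholesky factorization of a covariance matrix, and the positive diagonal guarantees that the implied covariance `L @ L.T` is positive definite.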

Citation

If you find this repository useful, please cite it using the citation button that GitHub provides in the right-hand column.

Usage

import torch
from mdn import MixtureDensityNetwork

x = torch.randn(5, 1)     # 5 samples of 1D input
data = torch.randn(5, 2)  # 5 samples of 2D target

# 1D input, 2D output, 2 mixture components
model = MixtureDensityNetwork(
    dim_in=1, dim_out=2, n_components=2,
    full_cov=True,  # whether to use a full covariance (default: full_cov=True)
)

# returns the predicted mixture weights pi and the normal distributions
pi, normal = model(x)

# compute the negative log-likelihood
# as the loss function for backprop
loss = model.loss(x, data).mean()

# use this to sample from a trained model
samples = model.sample(x)
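The usage above does not spell out what `model.loss` and `model.sample` compute. As a rough, self-contained sketch (the helpers `mixture_nll` and `mixture_sample` are hypothetical and not necessarily how the `mdn` package implements them), a Gaussian-mixture negative log-likelihood and an ancestral sampler can be written with `torch.distributions`, assuming log mixture weights of shape (batch, K) and a `MultivariateNormal` with batch shape (batch, K):

```python
import torch
from torch.distributions import Categorical, MultivariateNormal


def mixture_nll(log_pi: torch.Tensor, normal: MultivariateNormal, y: torch.Tensor) -> torch.Tensor:
    """Per-sample NLL: -log sum_k pi_k N(y | mu_k, Sigma_k).

    log_pi: (batch, K) log mixture weights
    normal: MultivariateNormal with batch_shape (batch, K), event_shape (D,)
    y:      (batch, D) targets
    """
    # Broadcast each target against all K components -> log probs of shape (batch, K).
    log_prob = normal.log_prob(y.unsqueeze(1))
    # logsumexp keeps the mixture sum numerically stable.
    return -torch.logsumexp(log_pi + log_prob, dim=-1)


def mixture_sample(log_pi: torch.Tensor, normal: MultivariateNormal) -> torch.Tensor:
    """One sample per input: pick component k ~ pi, then draw y ~ N(mu_k, Sigma_k)."""
    k = Categorical(logits=log_pi).sample()        # (batch,)
    all_samples = normal.sample()                  # (batch, K, D): one draw per component
    idx = k.view(-1, 1, 1).expand(-1, 1, all_samples.shape[-1])
    return all_samples.gather(1, idx).squeeze(1)   # (batch, D)


torch.manual_seed(0)
B, K, D = 5, 2, 2
log_pi = torch.log_softmax(torch.randn(B, K), dim=-1)
normal = MultivariateNormal(loc=torch.randn(B, K, D),
                            scale_tril=torch.eye(D).repeat(B, K, 1, 1))
y = torch.randn(B, D)

loss = mixture_nll(log_pi, normal, y).mean()
samples = mixture_sample(log_pi, normal)
```

Drawing from every component and then gathering the selected one keeps the sampler vectorized, at the cost of K times more Gaussian draws than strictly needed.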

Example

See example.ipynb for training a 2-component, full-covariance MDN on the following data:

$$x \sim \text{Uniform}(0, 1)$$

and

$$\mathbb{R}^2 \ni \text{data} \sim \text{the following figure}$$

[Figure: Data]

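Note that an MDN with 2 diagonal-covariance components can never recover such data: each diagonal Gaussian has axis-aligned contours, so it cannot represent correlation between the two output dimensions within a component.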
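To make the limitation concrete for a single Gaussian, here is a small sketch on a hypothetical toy dataset (strongly correlated 2D points, not the data from example.ipynb). It fits the maximum-likelihood full-covariance and diagonal-covariance Gaussians and compares their average log-likelihood:

```python
import torch
from torch.distributions import MultivariateNormal

torch.manual_seed(0)
n = 2000
# Hypothetical toy data: the second coordinate is a noisy copy of the first,
# so the point cloud is a thin ellipse tilted at roughly 45 degrees.
x1 = torch.randn(n)
x2 = x1 + 0.1 * torch.randn(n)
data = torch.stack([x1, x2], dim=-1)  # (n, 2)

mean = data.mean(dim=0)
centered = data - mean
# Maximum-likelihood covariance estimates (divide by n).
full_cov = centered.T @ centered / n
diag_cov = torch.diag(centered.pow(2).mean(dim=0))  # same variances, zero correlation

full_ll = MultivariateNormal(mean, covariance_matrix=full_cov).log_prob(data).mean()
diag_ll = MultivariateNormal(mean, covariance_matrix=diag_cov).log_prob(data).mean()
print(f"mean log-likelihood: full={full_ll:.3f}, diagonal={diag_ll:.3f}")
```

Because every diagonal covariance is also a valid full covariance, the full fit can never score lower; with a correlation this close to 1, the diagonal model loses roughly -0.5 * log(1 - rho^2) nats per point.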

Reference

The code structure follows this repo, which only supports diagonal covariances.
