Merge pull request #20 from MLShukai/feature/#18/models
Geson-anko authored Jul 23, 2024
2 parents ced214e + e2d3167 commit ba53310
Showing 2 changed files with 202 additions and 0 deletions.
101 changes: 101 additions & 0 deletions ami/models/components/mixture_desity_network.py
@@ -0,0 +1,101 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from torch.distributions import Distribution, Normal, constraints


class NormalMixture(Distribution):
"""Computes Mixture Density Distribution of Normal distributions."""

SQRT_2_PI = (2 * torch.pi) ** 0.5
arg_constraints = {
"log_pi": constraints.less_than(0.0),
"mu": constraints.real,
"sigma": constraints.positive,
}

def __init__(
self, log_pi: Tensor, mu: Tensor, sigma: Tensor, eps: float = 1e-6, validate_args: bool | None = None
) -> None:
"""Constructor for the NormalMixture class.
This constructor initializes the parameters of the mixture normal distribution and calls the parent class constructor.
log_pi, mu, sigma are must be same shape.
Args:
log_pi: Tensor representing the mixture log ratios of each normal distribution.
mu: Tensor representing the means of each normal distribution.
sigma: Tensor representing the standard deviations of each normal distribution.
eps: A small value for numerical stability.
validate_args: Whether to validate the arguments.
Shape:
log_pi, mu, sigma: (*, Components)
"""
assert log_pi.shape == mu.shape == sigma.shape
batch_shape = log_pi.shape[:-1]
self.num_components = log_pi.size(-1)
self.log_pi = log_pi
self.mu = mu
self.sigma = sigma
self.eps = eps

super().__init__(batch_shape, validate_args=validate_args)

def _get_expand_shape(self, shape: Size) -> tuple[int, ...]:
return *shape, *self.batch_shape, self.num_components

def rsample(self, sample_shape: Size = Size()) -> Tensor:
shape = self._get_expand_shape(sample_shape)

pi = self.log_pi.exp().expand(shape).contiguous()
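        # Draw one component index per output element from the categorical
        # distribution defined by the mixing weights pi.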
samples = torch.multinomial(
pi.view(-1, pi.size(-1)),
1,
).view(*pi.shape[:-1], 1)
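        # Gather the selected component's parameters and apply the
        # reparameterization trick: mu + sigma * N(0, 1).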
sample_mu = self.mu.expand(shape).gather(-1, samples).squeeze(-1)
sample_sigma = self.sigma.expand(shape).gather(-1, samples).squeeze(-1)
return torch.randn_like(sample_mu) * sample_sigma + sample_mu

def sample(self, sample_shape: Size = Size()) -> Tensor:
return self.rsample(sample_shape).detach()

def log_prob(self, value: Tensor) -> Tensor:
shape = *value.shape, self.num_components
mu = self.mu.expand(shape)
sigma = self.sigma.expand(shape)
log_pi = self.log_pi.expand(shape)
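        # Per-component Gaussian log-density; eps guards against division by
        # zero and log(0).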
normal_prob = -0.5 * ((value.unsqueeze(-1) - mu) / (sigma + self.eps)) ** 2 - torch.log(
self.SQRT_2_PI * sigma + self.eps
)
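        # Mixture log-density: logsumexp over components of
        # (log_pi + per-component log-density).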
return torch.logsumexp(log_pi + normal_prob, -1)


class NormalMixtureDensityNetwork(nn.Module):
"""A neural network that outputs parameters for a mixture of normal
distributions.
This network takes an input tensor and produces the parameters
(mixture weights, means, and standard deviations) for a mixture of
normal distributions. It can be used as the output layer in a neural
network for tasks that require modeling complex, multi-modal
distributions.
"""

def __init__(self, in_features: int, out_features: int, num_components: int) -> None:
"""
Args:
in_feature: The number of input features.
out_features: The number of output features (dimensionality of each normal distribution).
num_components: The number of mixture components.
"""
super().__init__()
self.mu_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
self.sigma_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
self.logits_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))

def forward(self, x: Tensor) -> NormalMixture:
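        # Each parameter is predicted per component and stacked along a trailing
        # component axis; softplus keeps sigma positive and log_softmax yields
        # normalized log mixing weights.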
mu = torch.stack([lyr(x) for lyr in self.mu_layers], dim=-1)
sigma = torch.stack([F.softplus(lyr(x)) for lyr in self.sigma_layers], dim=-1)
log_pi = torch.stack([lyr(x) for lyr in self.logits_layers], dim=-1).log_softmax(-1)
return NormalMixture(log_pi, mu, sigma)
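
A minimal usage sketch of the new module (not part of this commit; sizes and variable names are illustrative, and the import path follows the test file below):

import torch

from ami.models.components.mixture_desity_network import NormalMixtureDensityNetwork

head = NormalMixtureDensityNetwork(in_features=10, out_features=5, num_components=3)
x = torch.randn(32, 10)      # batch of input features
target = torch.randn(32, 5)  # observed targets to model

dist = head(x)               # NormalMixture with batch_shape (32, 5)
nll = -dist.log_prob(target).mean()
nll.backward()               # gradients reach all three per-component heads

sample = dist.rsample()      # differentiable sample, shape (32, 5)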
101 changes: 101 additions & 0 deletions tests/models/components/test_mixture_density_network.py
@@ -0,0 +1,101 @@
import pytest
import torch
import torch.nn as nn
from torch.distributions import Normal

from ami.models.components.mixture_desity_network import (
NormalMixture,
NormalMixtureDensityNetwork,
)


class TestNormalMixture:
@pytest.mark.parametrize("batch_shape", [(), (2,), (2, 3)])
@pytest.mark.parametrize("num_components", [2, 3])
def test_normal_mixture(self, batch_shape, num_components):
# Create test data
log_pi = torch.randn(*batch_shape, num_components).log_softmax(-1)
mu = torch.randn(*batch_shape, num_components)
sigma = torch.rand(*batch_shape, num_components).add_(0.1) # Ensure positive values

# Create NormalMixture instance
mixture = NormalMixture(log_pi, mu, sigma)

# Test batch_shape
assert mixture.batch_shape == torch.Size(batch_shape)

# Test sampling
sample = mixture.sample()
assert sample.shape == torch.Size(batch_shape)

# Test log_prob
log_prob = mixture.log_prob(sample)
assert log_prob.shape == torch.Size(batch_shape)

sample_shape = (10, 5)
samples = mixture.sample(sample_shape)
assert samples.shape == torch.Size(sample_shape + batch_shape)
assert mixture.log_prob(samples).shape == sample_shape + batch_shape

# Test rsample
rsample = mixture.rsample()
assert rsample.shape == torch.Size(batch_shape)

# Test consistency with individual normal components
components = [Normal(mu[..., i], sigma[..., i]) for i in range(num_components)]
mixture_log_prob = mixture.log_prob(sample)
component_log_probs = torch.stack([comp.log_prob(sample) for comp in components], dim=-1)
component_log_probs += log_pi
expected_log_prob = torch.logsumexp(component_log_probs, dim=-1)
assert torch.allclose(mixture_log_prob, expected_log_prob, atol=1e-5)

def test_normal_mixture_invalid_args(self):
# Test error handling for invalid arguments
with pytest.raises(AssertionError):
NormalMixture(torch.randn(3, 2), torch.randn(3, 3), torch.rand(3, 2).add_(0.1))


class TestNormalMixtureDensityNetwork:
@pytest.mark.parametrize("in_features", [10])
@pytest.mark.parametrize("out_features", [5])
@pytest.mark.parametrize("num_components", [2])
@pytest.mark.parametrize("batch_size", [1, 32])
def test_normal_mixture_density_network(self, in_features, out_features, num_components, batch_size):
# Create NormalMixtureDensityNetwork instance
network = NormalMixtureDensityNetwork(in_features, out_features, num_components)

# Create input tensor
x = torch.randn(batch_size, in_features)

# Forward pass
output = network(x)

# Check output type
assert isinstance(output, NormalMixture)

# Check output shapes
assert output.batch_shape == torch.Size([batch_size, out_features])
assert output.event_shape == torch.Size([])
assert output.log_pi.shape == (batch_size, out_features, num_components)
assert output.mu.shape == (batch_size, out_features, num_components)
assert output.sigma.shape == (batch_size, out_features, num_components)

# Check that sigma is positive
assert (output.sigma > 0).all()

# Check that log_pi is a valid log probability
assert torch.allclose(output.log_pi.exp().sum(dim=-1), torch.ones(batch_size, out_features))

def test_normal_mixture_density_network_gradients(self):
in_features, out_features, num_components = 10, 5, 3
network = NormalMixtureDensityNetwork(in_features, out_features, num_components)
x = torch.randn(32, in_features, requires_grad=True)

output = network(x)
sample = output.rsample()
loss = sample.sum()
loss.backward()

# Check that gradients are computed
assert x.grad is not None
assert output.sample().grad is None
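
Beyond these unit tests, a quick end-to-end sketch (not part of this commit; hyperparameters are illustrative) of fitting the mixture head to a bimodal target, the situation the docstring motivates:

import torch

from ami.models.components.mixture_desity_network import NormalMixtureDensityNetwork

torch.manual_seed(0)
head = NormalMixtureDensityNetwork(in_features=1, out_features=1, num_components=2)
optim = torch.optim.Adam(head.parameters(), lr=1e-2)

x = torch.zeros(256, 1)                                # constant input
y = torch.randint(0, 2, (256, 1)).float() * 4.0 - 2.0  # two modes, at -2 and +2

for _ in range(500):
    loss = -head(x).log_prob(y).mean()  # negative log-likelihood of the mixture
    optim.zero_grad()
    loss.backward()
    optim.step()

samples = head(x[:1]).sample((1000,))  # samples should spread across both modes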
