forked from MLShukai/ami
Commit
Showing 2 changed files with 202 additions and 0 deletions.
ami/models/components/mixture_desity_network.py (101 additions, 0 deletions)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Size, Tensor
from torch.distributions import Distribution, Normal, constraints


class NormalMixture(Distribution):
    """Mixture density distribution composed of normal distributions."""

    SQRT_2_PI = (2 * torch.pi) ** 0.5
    arg_constraints = {
        "log_pi": constraints.less_than(0.0),
        "mu": constraints.real,
        "sigma": constraints.positive,
    }

    def __init__(
        self, log_pi: Tensor, mu: Tensor, sigma: Tensor, eps: float = 1e-6, validate_args: bool | None = None
    ) -> None:
        """Constructor for the NormalMixture class.

        Initializes the parameters of the mixture normal distribution and
        calls the parent class constructor. log_pi, mu, and sigma must all
        have the same shape.

        Args:
            log_pi: Tensor of mixture log ratios of each normal distribution.
            mu: Tensor of means of each normal distribution.
            sigma: Tensor of standard deviations of each normal distribution.
            eps: A small value for numerical stability.
            validate_args: Whether to validate the arguments.

        Shape:
            log_pi, mu, sigma: (*, Components)
        """
        assert log_pi.shape == mu.shape == sigma.shape
        batch_shape = log_pi.shape[:-1]
        self.num_components = log_pi.size(-1)
        self.log_pi = log_pi
        self.mu = mu
        self.sigma = sigma
        self.eps = eps

        super().__init__(batch_shape, validate_args=validate_args)

    def _get_expand_shape(self, shape: Size) -> tuple[int, ...]:
        return *shape, *self.batch_shape, self.num_components

    def rsample(self, sample_shape: Size = Size()) -> Tensor:
        shape = self._get_expand_shape(sample_shape)

        # Pick a component index for every sample according to the mixture
        # weights, then reparameterize a draw from the selected component.
        pi = self.log_pi.exp().expand(shape).contiguous()
        samples = torch.multinomial(
            pi.view(-1, pi.size(-1)),
            1,
        ).view(*pi.shape[:-1], 1)
        sample_mu = self.mu.expand(shape).gather(-1, samples).squeeze(-1)
        sample_sigma = self.sigma.expand(shape).gather(-1, samples).squeeze(-1)
        return torch.randn_like(sample_mu) * sample_sigma + sample_mu

    def sample(self, sample_shape: Size = Size()) -> Tensor:
        return self.rsample(sample_shape).detach()

    def log_prob(self, value: Tensor) -> Tensor:
        # Per-component normal log densities, combined with the mixture
        # weights via logsumexp for numerical stability.
        shape = *value.shape, self.num_components
        mu = self.mu.expand(shape)
        sigma = self.sigma.expand(shape)
        log_pi = self.log_pi.expand(shape)
        normal_prob = -0.5 * ((value.unsqueeze(-1) - mu) / (sigma + self.eps)) ** 2 - torch.log(
            self.SQRT_2_PI * sigma + self.eps
        )
        return torch.logsumexp(log_pi + normal_prob, -1)
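
# Clarifying note (added here, not part of the committed file): log_prob
# evaluates the mixture density in log space,
#     log p(x) = logsumexp_k( log_pi[k] + log N(x | mu[k], sigma[k]) ),
# where log N(x | mu, sigma) = -0.5 * ((x - mu) / sigma) ** 2 - log(SQRT_2_PI * sigma).
# This is the closed form expanded in `normal_prob` above, with `eps`
# guarding against division by zero and log(0) for very small sigma.
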

class NormalMixtureDensityNetwork(nn.Module):
    """A neural network that outputs the parameters of a mixture of normal
    distributions.

    This network takes an input tensor and produces the parameters
    (mixture weights, means, and standard deviations) of a mixture of
    normal distributions. It can be used as the output layer of a neural
    network for tasks that require modeling complex, multi-modal
    distributions.
    """

    def __init__(self, in_features: int, out_features: int, num_components: int) -> None:
        """
        Args:
            in_features: The number of input features.
            out_features: The number of output features (dimensionality of each normal distribution).
            num_components: The number of mixture components.
        """
        super().__init__()
        self.mu_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
        self.sigma_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))
        self.logits_layers = nn.ModuleList(nn.Linear(in_features, out_features) for _ in range(num_components))

    def forward(self, x: Tensor) -> NormalMixture:
        # One linear head per component; softplus keeps sigma positive and
        # log_softmax normalizes the mixture logits across components.
        mu = torch.stack([lyr(x) for lyr in self.mu_layers], dim=-1)
        sigma = torch.stack([F.softplus(lyr(x)) for lyr in self.sigma_layers], dim=-1)
        log_pi = torch.stack([lyr(x) for lyr in self.logits_layers], dim=-1).log_softmax(-1)
        return NormalMixture(log_pi, mu, sigma)
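
A minimal usage sketch, not part of this commit: the module is typically trained by minimizing the negative log-likelihood of targets under the returned NormalMixture. All tensors and hyperparameters below are illustrative placeholders.

# Illustrative training sketch (placeholder data and hyperparameters).
network = NormalMixtureDensityNetwork(in_features=10, out_features=5, num_components=3)
optimizer = torch.optim.Adam(network.parameters(), lr=1e-3)

inputs = torch.randn(32, 10)   # (batch, in_features)
targets = torch.randn(32, 5)   # (batch, out_features)

for _ in range(100):
    mixture = network(inputs)                  # NormalMixture with batch_shape (32, 5)
    loss = -mixture.log_prob(targets).mean()   # negative log-likelihood
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()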

tests/models/components/test_mixture_density_network.py (101 additions, 0 deletions)

import pytest
import torch
import torch.nn as nn
from torch.distributions import Normal

from ami.models.components.mixture_desity_network import (
    NormalMixture,
    NormalMixtureDensityNetwork,
)

class TestNormalMixture:
    @pytest.mark.parametrize("batch_shape", [(), (2,), (2, 3)])
    @pytest.mark.parametrize("num_components", [2, 3])
    def test_normal_mixture(self, batch_shape, num_components):
        # Create test data
        log_pi = torch.randn(*batch_shape, num_components).log_softmax(-1)
        mu = torch.randn(*batch_shape, num_components)
        sigma = torch.rand(*batch_shape, num_components).add_(0.1)  # Ensure positive values

        # Create NormalMixture instance
        mixture = NormalMixture(log_pi, mu, sigma)

        # Test batch_shape
        assert mixture.batch_shape == torch.Size(batch_shape)

        # Test sampling
        sample = mixture.sample()
        assert sample.shape == torch.Size(batch_shape)

        # Test log_prob
        log_prob = mixture.log_prob(sample)
        assert log_prob.shape == torch.Size(batch_shape)

        sample_shape = (10, 5)
        samples = mixture.sample(sample_shape)
        assert samples.shape == torch.Size(sample_shape + batch_shape)
        assert mixture.log_prob(samples).shape == sample_shape + batch_shape

        # Test rsample
        rsample = mixture.rsample()
        assert rsample.shape == torch.Size(batch_shape)

        # Test consistency with individual normal components
        components = [Normal(mu[..., i], sigma[..., i]) for i in range(num_components)]
        mixture_log_prob = mixture.log_prob(sample)
        component_log_probs = torch.stack([comp.log_prob(sample) for comp in components], dim=-1)
        component_log_probs += log_pi
        expected_log_prob = torch.logsumexp(component_log_probs, dim=-1)
        assert torch.allclose(mixture_log_prob, expected_log_prob, atol=1e-5)

    def test_normal_mixture_invalid_args(self):
        # Test error handling for invalid arguments
        with pytest.raises(AssertionError):
            NormalMixture(torch.randn(3, 2), torch.randn(3, 3), torch.rand(3, 2).add_(0.1))

class TestNormalMixtureDensityNetwork:
    @pytest.mark.parametrize("in_features", [10])
    @pytest.mark.parametrize("out_features", [5])
    @pytest.mark.parametrize("num_components", [2])
    @pytest.mark.parametrize("batch_size", [1, 32])
    def test_normal_mixture_density_network(self, in_features, out_features, num_components, batch_size):
        # Create NormalMixtureDensityNetwork instance
        network = NormalMixtureDensityNetwork(in_features, out_features, num_components)

        # Create input tensor
        x = torch.randn(batch_size, in_features)

        # Forward pass
        output = network(x)

        # Check output type
        assert isinstance(output, NormalMixture)

        # Check output shapes
        assert output.batch_shape == torch.Size([batch_size, out_features])
        assert output.event_shape == torch.Size([])
        assert output.log_pi.shape == (batch_size, out_features, num_components)
        assert output.mu.shape == (batch_size, out_features, num_components)
        assert output.sigma.shape == (batch_size, out_features, num_components)

        # Check that sigma is positive
        assert (output.sigma > 0).all()

        # Check that log_pi is a valid log probability
        assert torch.allclose(output.log_pi.exp().sum(dim=-1), torch.ones(batch_size, out_features))

    def test_normal_mixture_density_network_gradients(self):
        in_features, out_features, num_components = 10, 5, 3
        network = NormalMixtureDensityNetwork(in_features, out_features, num_components)
        x = torch.randn(32, in_features, requires_grad=True)

        output = network(x)
        sample = output.rsample()
        loss = sample.sum()
        loss.backward()

        # Check that gradients are computed
        assert x.grad is not None
        assert output.sample().grad is None
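
A further sketch, also not part of this commit: torch.distributions ships MixtureSameFamily, which models the same density, so it can serve as a reference implementation for log_prob. The two should agree up to the eps stabilizer; the shapes below are arbitrary.

# Illustrative cross-check against torch.distributions.MixtureSameFamily.
import torch
from torch.distributions import Categorical, MixtureSameFamily, Normal

from ami.models.components.mixture_desity_network import NormalMixture

log_pi = torch.randn(4, 3).log_softmax(-1)
mu = torch.randn(4, 3)
sigma = torch.rand(4, 3) + 0.1

ours = NormalMixture(log_pi, mu, sigma)
reference = MixtureSameFamily(Categorical(logits=log_pi), Normal(mu, sigma))

x = torch.randn(4)
assert torch.allclose(ours.log_prob(x), reference.log_prob(x), atol=1e-4)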