[Snorkell.ai] Please review the generated documentation #68

Open · wants to merge 2 commits into main
18 changes: 18 additions & 0 deletions diffusion/__init__.py
@@ -17,6 +17,24 @@ def create_diffusion(
rescale_learned_sigmas=False,
diffusion_steps=1000
):
""" Create a diffusion model for the given parameters.

This function creates a diffusion model based on the provided parameters, including the timestep respacing, noise schedule, use of KL divergence, sigma size, prediction of xstart, learning of sigma, rescaling of learned sigmas, and diffusion steps.

Args:
timestep_respacing (int or list): The respacing of timesteps for diffusion.
noise_schedule (str?): The schedule for noise. Defaults to "linear".
use_kl (bool?): Whether to use KL divergence. Defaults to False.
sigma_small (bool?): Whether sigma is small. Defaults to False.
predict_xstart (bool?): Whether to predict xstart. Defaults to False.
learn_sigma (bool?): Whether to learn sigma. Defaults to True.
rescale_learned_sigmas (bool?): Whether to rescale learned sigmas. Defaults to False.
diffusion_steps (int?): The number of diffusion steps. Defaults to 1000.

Returns:
SpacedDiffusion: A diffusion model based on the provided parameters.
"""

betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
if use_kl:
loss_type = gd.LossType.RESCALED_KL
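Below is a minimal usage sketch for the function documented above; the import path and the concrete timestep_respacing values are assumptions about the surrounding codebase, not part of this diff.

```python
# Minimal usage sketch for create_diffusion; the import path and the
# concrete timestep_respacing values are assumptions, not part of this diff.
from diffusion import create_diffusion

# Full 1000-step process with the default linear noise schedule.
train_diffusion = create_diffusion(timestep_respacing=[1000], diffusion_steps=1000)

# A respaced process with fewer steps for faster sampling.
sample_diffusion = create_diffusion(timestep_respacing=[250], diffusion_steps=1000)

print(type(train_diffusion).__name__)  # expected: SpacedDiffusion
```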
61 changes: 41 additions & 20 deletions diffusion/diffusion_utils.py
@@ -8,10 +8,21 @@


def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
""" Compute the KL divergence between two Gaussian distributions.

This function computes the Kullback-Leibler (KL) divergence between two Gaussian distributions. The shapes of the input parameters are automatically broadcasted, allowing for comparisons between batches and scalars.

Args:
mean1 (Tensor or float): The mean of the first Gaussian distribution.
logvar1 (Tensor or float): The log variance of the first Gaussian distribution.
mean2 (Tensor or float): The mean of the second Gaussian distribution.
logvar2 (Tensor or float): The log variance of the second Gaussian distribution.

Returns:
Tensor: The computed KL divergence between the two Gaussian distributions.

Raises:
AssertionError: If all input arguments are not Tensors.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
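A short sanity-check sketch for normal_kl as documented above; the import path is an assumption based on this diff's layout, and the function body is only partially shown here.

```python
# Sanity-check sketch for normal_kl; import path assumed from this diff.
import torch as th
from diffusion.diffusion_utils import normal_kl

mean1 = th.zeros(4, 3)
logvar1 = th.zeros(4, 3)               # unit variance
# KL between identical Gaussians should be (numerically) zero.
print(normal_kl(mean1, logvar1, mean1, logvar1).abs().max())

# Broadcasting: compare a batch of Gaussians against scalar parameters.
kl = normal_kl(mean1, logvar1, th.tensor(1.0), th.tensor(0.0))
print(kl.shape)                        # torch.Size([4, 3])
```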
@@ -37,20 +48,27 @@ def normal_kl(mean1, logvar1, mean2, logvar2):


def approx_standard_normal_cdf(x):
"""
A fast approximation of the cumulative distribution function of the
standard normal.
""" A fast approximation of the cumulative distribution function of the standard normal.

Args:
x (float): The input value for the standard normal distribution.

Returns:
float: The approximate cumulative distribution function value.
"""
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
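A quick check of approx_standard_normal_cdf against the exact standard normal CDF, 0.5 * (1 + erf(x / sqrt(2))); the import path is an assumption based on this diff's layout.

```python
# Compare the tanh approximation with the exact CDF Phi(x) = 0.5 * (1 + erf(x / sqrt(2))).
import math
import torch as th
from diffusion.diffusion_utils import approx_standard_normal_cdf

x = th.linspace(-4.0, 4.0, steps=9)
approx = approx_standard_normal_cdf(x)
exact = 0.5 * (1.0 + th.erf(x / math.sqrt(2.0)))
print((approx - exact).abs().max())    # small maximum error (well below 1e-2)
```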


def continuous_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a continuous Gaussian distribution.
:param x: the targets
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).
""" Compute the log-likelihood of a continuous Gaussian distribution.

Args:
x (tensor): The targets.
means (tensor): The Gaussian mean Tensor.
log_scales (tensor): The Gaussian log stddev Tensor.

Returns:
tensor: A tensor like x of log probabilities (in nats).
"""
centered_x = x - means
inv_stdv = th.exp(-log_scales)
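A hedged check for continuous_gaussian_log_likelihood: with unit scale (log_scales = 0) the result should agree with torch.distributions.Normal(means, 1).log_prob(x). The import path is an assumption, and the function body is only partially shown in this diff.

```python
# Hedged check: with log_scales = 0 (unit stddev) the result should agree with
# torch.distributions.Normal(means, 1).log_prob(x). Import path assumed.
import torch as th
from diffusion.diffusion_utils import continuous_gaussian_log_likelihood

x = th.randn(2, 5)
means = th.zeros(2, 5)
log_scales = th.zeros(2, 5)            # stddev = exp(0) = 1
log_probs = continuous_gaussian_log_likelihood(x, means=means, log_scales=log_scales)
reference = th.distributions.Normal(means, th.exp(log_scales)).log_prob(x)
print(th.allclose(log_probs, reference, atol=1e-5))  # expected: True
```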
@@ -60,14 +78,17 @@ def continuous_gaussian_log_likelihood(x, *, means, log_scales):


def discretized_gaussian_log_likelihood(x, *, means, log_scales):
"""
Compute the log-likelihood of a Gaussian distribution discretizing to a
""" Compute the log-likelihood of a Gaussian distribution discretizing to a
given image.
:param x: the target images. It is assumed that this was uint8 values,
rescaled to the range [-1, 1].
:param means: the Gaussian mean Tensor.
:param log_scales: the Gaussian log stddev Tensor.
:return: a tensor like x of log probabilities (in nats).

Args:
x (Tensor): The target images. These are assumed to be uint8 values
rescaled to the range [-1, 1].
means (Tensor): The Gaussian mean Tensor.
log_scales (Tensor): The Gaussian log stddev Tensor.

Returns:
Tensor: A tensor like x of log probabilities (in nats).
"""
assert x.shape == means.shape == log_scales.shape
centered_x = x - means
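Finally, a hedged usage sketch for discretized_gaussian_log_likelihood following the docstring above (uint8 images rescaled to [-1, 1], with all three inputs sharing one shape); the import path is an assumption.

```python
# Hedged usage sketch: per the docstring, x holds uint8 image values rescaled
# to [-1, 1], and all three inputs share the same shape. Import path assumed.
import torch as th
from diffusion.diffusion_utils import discretized_gaussian_log_likelihood

imgs_uint8 = th.randint(0, 256, (2, 3, 8, 8), dtype=th.uint8)
x = imgs_uint8.float() / 127.5 - 1.0   # rescale uint8 [0, 255] -> [-1, 1]
means = th.zeros_like(x)
log_scales = th.full_like(x, -1.0)     # stddev = exp(-1)

log_probs = discretized_gaussian_log_likelihood(x, means=means, log_scales=log_scales)
print(log_probs.shape)                 # same shape as x: torch.Size([2, 3, 8, 8])
```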