Skip to content

Commit

Permalink
Modified truncate_normal with jax.random.truncated_normal
Browse files Browse the repository at this point in the history
  • Loading branch information
mhliu0001 committed Jul 18, 2024
1 parent cff12aa commit 8e9ba5a
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions appletree/randgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,38 @@ def normal(key, mean, std, shape=()):
rvs = random.normal(seed, shape=shape) * std + mean
return key, rvs.astype(FLOAT)


@export
@partial(jit, static_argnums=(5,))
def truncate_normal(key, mean, std, vmin=None, vmax=None, shape=()):
"""Truncated normal distribution random sampler.
Args:
key: seed for random generator.
mean: <jnp.array>-like mean in normal distribution.
std: <jnp.array>-like std in normal distribution.
vmin: <jnp.array>-like min value to clip. By default it's None.
vmin and vmax cannot be both None.
vmax: <jnp.array>-like max value to clip. By default it's None.
vmin and vmax cannot be both None.
shape: parameter passed to normal(..., shape=shape)
Returns:
an updated seed, random variables.
"""
lower_norm, upper_norm = (vmin - mean) / std, (vmax - mean) / std
key, rvs = random.truncated_normal(
key, mean, std, lower_norm, upper_norm, shape=shape
)
rvs = rvs * std + mean
return key, rvs.astype(FLOAT)

@export
@partial(jit, static_argnums=(5,))
def truncate_normal_naive(key, mean, std, vmin=None, vmax=None, shape=()):
"""Truncated normal distribution random sampler, with naive clipping.
This is DEPRECATED because this does not yield a continuous distribution.
Args:
key: seed for random generator.
mean: <jnp.array>-like mean in normal distribution.
Expand All @@ -155,7 +181,6 @@ def truncate_normal(key, mean, std, vmin=None, vmax=None, shape=()):
rvs = jnp.clip(rvs, a_min=vmin, a_max=vmax)
return key, rvs.astype(FLOAT)


@export
@partial(jit, static_argnums=(4,))
def skewnormal(key, a, loc, scale, shape=()):
Expand Down

0 comments on commit 8e9ba5a

Please sign in to comment.