Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dustinvtran committed Oct 7, 2017
1 parent c2b61a2 commit 70b6edb
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 35 deletions.
1 change: 1 addition & 0 deletions edward/inferences/conjugacy/conjugacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def normal_from_natural_params(p1, p2):
return {'loc': loc, 'scale': tf.sqrt(sigmasq)}


# TODO
_suff_stat_to_dist = defaultdict(dict)
_suff_stat_to_dist['binary'][(('#x',),)] = (
Bernoulli, lambda p1: {'logits': p1})
Expand Down
16 changes: 0 additions & 16 deletions edward/models/random_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,22 +26,6 @@ def __init__(self, *args, **kwargs):

del _candidate

# Add supports; these are used, e.g., in conjugacy.
Bernoulli.support = 'binary'
Beta.support = '01'
Binomial.support = 'onehot'
Categorical.support = 'categorical'
Chi2.support = 'nonnegative'
Dirichlet.support = 'simplex'
Exponential.support = 'nonnegative'
Gamma.support = 'nonnegative'
InverseGamma.support = 'nonnegative'
Laplace.support = 'real'
Multinomial.support = 'onehot'
MultivariateNormalDiag.support = 'multivariate_real'
Normal.support = 'real'
Poisson.support = 'countable'

del absolute_import
del division
del print_function
62 changes: 43 additions & 19 deletions edward/util/random_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,28 +753,52 @@ def transform(x, *args, **kwargs):
if len(args) != 0 or kwargs.get('bijector', None) is not None:
return TransformedDistribution(x, *args, **kwargs)

try:
support = x.support
except AttributeError as e:
msg = """'{}' object has no 'support'
so cannot be transformed.""".format(type(x).__name__)
raise AttributeError(msg)

if support == '01':
bij = bijectors.Invert(bijectors.Sigmoid())
new_support = 'real'
elif support == 'nonnegative':
bij = bijectors.Invert(bijectors.Softplus())
new_support = 'real'
elif support == 'simplex':
bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1))
new_support = 'multivariate_real'
elif support in ('real', 'multivariate_real'):
real = (Gumbel,
Laplace,
Logistic,
Normal,
StudentT,
MultivariateNormalDiag,
MultivariateNormalFullCovariance,
MultivariateNormalTriL,
MultivariateNormalDiagPlusLowRank)
if isinstance(x, real):
# Determine if distribution has real support at construction time
# via hard-coded distributions. This prevents adding unnecessary
# ops via a transformation with identity bijector.
return x
else:

if x.support is None or len(x.support) > 1:
msg = "'transform' does not handle supports of type '{}'".format(support)
raise ValueError(msg)

interval, measure = x.support[0]
if measure == 'simplex':
# TODO
pass
elif measure != 'real':
raise

# TODO get event_shape
# TODO compatible dtypes
# TODO tf.fill_like
is_real = tf.logical_and(tf.is_equal(interval[0], tf.constant(-np.inf)),
tf.is_equal(interval[1], tf.constant(np.inf)))
is_01 = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)),
tf.is_equal(interval[1], tf.constant(1)))
is_nonnegative = tf.logical_and(tf.is_equal(interval[0], tf.constant(0)),
tf.is_equal(interval[1], tf.constant(np.inf)))
# TODO
tf.where(is_real, x, tf.where()...)
elif interval == '01':
bij = bijectors.Invert(bijectors.Sigmoid())
elif interval == 'nonnegative':
bij = bijectors.Invert(bijectors.Softplus())
elif interval == 'simplex':
bij = bijectors.Invert(bijectors.SoftmaxCentered(event_ndims=1))
# TODO identity

new_x = TransformedDistribution(x, bij, *args, **kwargs)
new_x.support = new_support
# TODO
new_x.support = [([], 'real')]
return new_x

0 comments on commit 70b6edb

Please sign in to comment.