Skip to content

Commit

Permalink
Init bifurcation experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
emdupre committed Apr 19, 2024
1 parent 07b13aa commit 76f61cf
Showing 1 changed file with 83 additions and 0 deletions.
83 changes: 83 additions & 0 deletions experiments/bifurcation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from jax import vmap
import jax.numpy as jnp
import jax.random as jr
import matplotlib.pyplot as plt
from tensorflow_probability.substrates import jax as tfp

from bayes_ca.prox_grad import pgd_jaxopt

tfd = tfp.distributions


def stagger_data(gap, num_timesteps, num_features):
"""
Hardcoding a two subject model with a single, staggered jump between
two Gaussian states with means at -1 and +1.
"""
offset_one = (num_timesteps // 2) - (gap.astype(int) // 2)
means_one = jnp.ones((num_timesteps, num_features))
mask = jnp.arange(num_timesteps) <= offset_one
means_one = jnp.where(mask[:, None], means_one, -1)

offset_two = (num_timesteps // 2) + (gap.astype(int) // 2)
means_two = jnp.ones((num_timesteps, num_features))
mask = jnp.arange(num_timesteps) <= offset_two
means_two = jnp.where(mask[:, None], means_two, -1)

subj_means = jnp.stack((means_one, means_two), axis=0)

return subj_means


def sample_mu0(gap, x0, params):
""" """
(num_timesteps, num_features, mu_pri, sigmasq_pri, sigmasq_subj, hazard_rates) = params
means, _ = stagger_data(gap, num_timesteps, num_features)
results = pgd_jaxopt(x0, means, mu_pri, sigmasq_pri, sigmasq_subj, hazard_rates)
return results


key = jr.PRNGKey(0)

# data settings
num_features = 1
num_subjects = 2
num_timesteps = 300

mu_pri = 0.0
sigmasq_pri = 1.5**2
sigmasq_subj = 1.5**2

# temporal params
num_states = num_timesteps - 1
max_duration = num_timesteps
hazard_prob = 0.01

hazard_rates = hazard_prob * jnp.ones(max_duration)
hazard_rates = hazard_rates.at[-1].set(1.0)

# the true changepoint
x0 = jnp.concatenate(
(
-1 * jnp.ones((num_timesteps // 2, num_features)),
jnp.ones((num_timesteps // 2, num_features)),
)
)

samples = 25
gaps = jnp.linspace(0, 50, samples)
sigmas = jnp.linspace(0.01, 3.0, samples)

for sigma in sigmas:
params = (num_timesteps, num_features, mu_pri, sigma**2, sigma**2, hazard_rates)
for gap in gaps:
results = vmap(sample_mu0, in_axes=(0, None, None))(gaps, x0, params)

# fig = plt.figure()
# ax = plt.subplot(111)
# colors = plt.cm.viridis(jnp.linspace(0, 1, samples))
# for i, mu0 in enumerate(mu0s):
# ax.plot(mu0, c=colors[i], alpha=0.8, label=f"sampled $\mu_0$, {gaps[i]} stagger")
# plt.legend()

# plt.show()

0 comments on commit 76f61cf

Please sign in to comment.