diff --git a/experiments/bifurcation.py b/experiments/bifurcation.py new file mode 100644 index 0000000..0628f31 --- /dev/null +++ b/experiments/bifurcation.py @@ -0,0 +1,83 @@ +from jax import vmap +import jax.numpy as jnp +import jax.random as jr +import matplotlib.pyplot as plt +from tensorflow_probability.substrates import jax as tfp + +from bayes_ca.prox_grad import pgd_jaxopt + +tfd = tfp.distributions + + +def stagger_data(gap, num_timesteps, num_features): + """ + Hardcoding a two subject model with a single, staggered jump between + two Gaussian states with means at -1 and +1. + """ + offset_one = (num_timesteps // 2) - (gap.astype(int) // 2) + means_one = jnp.ones((num_timesteps, num_features)) + mask = jnp.arange(num_timesteps) <= offset_one + means_one = jnp.where(mask[:, None], means_one, -1) + + offset_two = (num_timesteps // 2) + (gap.astype(int) // 2) + means_two = jnp.ones((num_timesteps, num_features)) + mask = jnp.arange(num_timesteps) <= offset_two + means_two = jnp.where(mask[:, None], means_two, -1) + + subj_means = jnp.stack((means_one, means_two), axis=0) + + return subj_means + + +def sample_mu0(gap, x0, params): + """ """ + (num_timesteps, num_features, mu_pri, sigmasq_pri, sigmasq_subj, hazard_rates) = params + means, _ = stagger_data(gap, num_timesteps, num_features) + results = pgd_jaxopt(x0, means, mu_pri, sigmasq_pri, sigmasq_subj, hazard_rates) + return results + + +key = jr.PRNGKey(0) + +# data settings +num_features = 1 +num_subjects = 2 +num_timesteps = 300 + +mu_pri = 0.0 +sigmasq_pri = 1.5**2 +sigmasq_subj = 1.5**2 + +# temporal params +num_states = num_timesteps - 1 +max_duration = num_timesteps +hazard_prob = 0.01 + +hazard_rates = hazard_prob * jnp.ones(max_duration) +hazard_rates = hazard_rates.at[-1].set(1.0) + +# the true changepoint +x0 = jnp.concatenate( + ( + -1 * jnp.ones((num_timesteps // 2, num_features)), + jnp.ones((num_timesteps // 2, num_features)), + ) +) + +samples = 25 +gaps = jnp.linspace(0, 50, samples) +sigmas = jnp.linspace(0.01, 3.0, samples) + +for sigma in sigmas: + params = (num_timesteps, num_features, mu_pri, sigma**2, sigma**2, hazard_rates) + for gap in gaps: + results = vmap(sample_mu0, in_axes=(0, None, None))(gaps, x0, params) + +# fig = plt.figure() +# ax = plt.subplot(111) +# colors = plt.cm.viridis(jnp.linspace(0, 1, samples)) +# for i, mu0 in enumerate(mu0s): +# ax.plot(mu0, c=colors[i], alpha=0.8, label=f"sampled $\mu_0$, {gaps[i]} stagger") +# plt.legend() + +# plt.show()