Skip to content

Commit

Permalink
optimizing param sweep
Browse files Browse the repository at this point in the history
  • Loading branch information
emdupre committed Apr 25, 2024
1 parent b748c7f commit 7034d96
Showing 1 changed file with 59 additions and 86 deletions.
145 changes: 59 additions & 86 deletions experiments/bifurcation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from functools import partial
from itertools import product

import click
import jax.numpy as jnp
from jax import jit, vmap
Expand Down Expand Up @@ -26,23 +29,28 @@ def stagger_data(gap, num_timesteps, num_features):

subj_means = jnp.stack((means_one, means_two), axis=0)

return subj_means
# calculate and return the average of these timeseries
# to use as the initialization for PGD
x0 = jnp.average(subj_means, axis=0)

return subj_means, x0

@jit
def sample_mu0(
gap, sigmasq_subj, x0, num_timesteps, num_features, mu_pri, sigmasq_pri, hazard_rates
):
""" """
means = stagger_data(gap, num_timesteps, num_features)

# @partial(jit, static_argnums=(1, 2))
def sample_mu0(params, num_timesteps, num_features, mu_pri, sigmasq_pri, hazard_rates):
"""
For provided params, generate sample data and perform PGD to sample $\mu_0$.
"""
sigmasq_subj, gap = params
means, x0 = stagger_data(gap, num_timesteps, num_features)
results = pgd_jaxopt(x0, means, mu_pri, sigmasq_pri, sigmasq_subj, hazard_rates)
return results
return results.params


def plot_mu0s(
x0_strategy,
mu_pri,
sigmasq_pri,
sigma_pri,
num_timesteps,
num_features,
hazard_rates,
Expand All @@ -69,7 +77,7 @@ def plot_mu0s(
elif x0_strategy == "average":
x0 = jnp.average(m, axis=0)

results = pgd(x0, m, mu_pri, sigmasq_pri**2, sigma_val**2, hazard_rates)
results = pgd(x0, m, mu_pri, sigma_pri**2, sigma_val**2, hazard_rates)
mu0s.append(results.x)

fig = plt.figure()
Expand All @@ -89,60 +97,40 @@ def plot_mu0s(


def plot_param_sweep(
x0_strategy, mu_pri, num_timesteps, num_features, hazard_rates, max_gap, max_sigma, n_samples
mu_pri,
sigma_pri,
num_timesteps,
num_features,
hazard_rates,
max_gap,
max_sigmasq,
n_samples,
):
""" """
x0_strategy = "average"
max_sigmasq = 9.0
max_gap = 50
n_samples = 50
gaps = jnp.linspace(0, max_gap, n_samples)
"""
Currently only supports "average" x0_strategy, rather than "true" x0.
"""
gaps = jnp.linspace(1, max_gap, n_samples)
sigmasqs = jnp.linspace(0.01, max_sigmasq, n_samples)
sigmas = [jnp.sqrt(s) for s in sigmasqs]

mu0s = vmap(sample_mu0, in_axes=(0, 0, None, None, None, None, None, None))(
gaps, sigmasqs, x0, num_timesteps, num_features, mu_pri, 1.0**2, hazard_rates
)
params = jnp.asarray(list(product(sigmasqs, gaps)))
mu0s = jit(
vmap(sample_mu0, in_axes=(0, None, None, None, None, None)), static_argnums=(1, 2)
)(params, num_timesteps, num_features, mu_pri, sigma_pri**2, hazard_rates)

mu0s = []
# muns = []
for sigmasq in sigmasqs:
means = vmap(stagger_data, in_axes=(0, None, None))(gaps, num_timesteps, num_features)
# muns.append(means)
for m in means:
if x0_strategy == "true":
# the true changepoint
x0 = jnp.concatenate(
(
-1 * jnp.ones((num_timesteps // 2, num_features)),
jnp.ones((num_timesteps // 2, num_features)),
)
)
elif x0_strategy == "average":
x0 = jnp.average(m, axis=0)

results = pgd(x0, m, mu_pri, 1.0**2, sigmasq, hazard_rates)
mu0s.append(results.x)

# muns = jnp.vstack(muns)
mu0s = jnp.asarray(mu0s)
split_changepoints = jnp.full(n_samples**2, False)

for i, mu0 in enumerate(mu0s):
def _count_changepoints(mu0):
""" """
_, counts = jnp.unique(mu0, return_counts=True)
if len(counts) > 2:
split_changepoints = split_changepoints.at[i].set(True)
return jnp.count_nonzero(counts) > 2

change = jnp.repeat(jnp.inf, n_samples)
reshape_bin = jnp.reshape(split_changepoints, (n_samples, n_samples))
for i, r in enumerate(reshape_bin):
try:
change = change.at[i].set(jnp.where(jnp.diff(r, axis=0))[0][0])
except IndexError: # we never see a switch
pass
count_cp = jnp.asarray([_count_changepoints(mu0) for mu0 in mu0s])

# NOTE : This works but I'm not sure why....
split = [gaps[c.astype(int)] for c in change if c is not jnp.inf]
# check, for sigma values, whether we still have 2 changepoints at increasing
# stagger distance...
cp_match = jnp.diff(jnp.reshape(count_cp, (n_samples, n_samples)), axis=1)

# ...and at which index the change to 3 cp's occurs, if it occurs.
res = [jnp.nonzero(c, size=1, fill_value=False) for c in cp_match]

hazard_prob = hazard_rates[0]
b = -jnp.log(hazard_prob / (1 - hazard_prob))
Expand All @@ -157,61 +145,46 @@ def plot_param_sweep(
ax.spines[["left", "top"]].set_visible(False)
ax.set_title(f"Transition from 1 to 2 changepoints")

# # define the colors
# cmap = mpl.colors.ListedColormap(["w", "k"])
# # create a normalize object the describes the limits of
# # each color
# bounds = [0.0, 0.5, 1.0]
# norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

# ax.imshow(
# jnp.reshape(split_changepoints, (n_samples, n_samples)),
# interpolation="none",
# cmap=cmap,
# norm=norm,
# )

return fig


@click.command()
@click.option("--mu_pri", default=0.0, help="")
@click.option("--sigmasq", default=2.0, help="")
@click.option("--sigma_pri", default=1.0, help="")
@click.option("--sigma", default=2.0, help="")
@click.option("--hazard_prob", default=0.01, help="")
@click.option("--num_features", default=1, help="")
@click.option("--num_timesteps", default=300, help="")
@click.option("--x0_strategy", default="average", help="")
def main(mu_pri, sigmasq, hazard_prob, num_features, num_timesteps, x0_strategy):
def main(mu_pri, sigma_pri, sigma, hazard_prob, num_features, num_timesteps, x0_strategy):
""" """
# hardcoded params
sigmasq_pri = 1.0
max_duration = num_timesteps

hazard_rates = hazard_prob * jnp.ones(max_duration)
hazard_rates = hazard_rates.at[-1].set(1.0)

fig1 = plot_mu0s(
x0_strategy,
mu_pri,
sigmasq_pri,
sigma_pri,
num_timesteps,
num_features,
hazard_rates,
max_gap=50,
sigma_val=sigmasq,
sigma_val=sigma,
n_samples=25,
)

# fig2 = plot_param_sweep(
# x0_strategy,
# mu_pri,
# num_timesteps,
# num_features,
# hazard_rates,
# max_gap=50,
# max_sigma=3.0,
# n_samples=50,
# )
fig2 = plot_param_sweep(
mu_pri,
sigma_pri,
num_timesteps,
num_features,
hazard_rates,
max_gap=50,
max_sigmasq=9.0,
n_samples=50,
)

plt.show()

Expand Down

0 comments on commit 7034d96

Please sign in to comment.