Skip to content

Commit

Permalink
Plot smoothed means before fitting
Browse files Browse the repository at this point in the history
  • Loading branch information
emdupre committed May 6, 2024
1 parent 208205c commit 3856450
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions experiments/naturalistic_data.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
from functools import partial

import jax
import click
from scipy import stats
import jax.numpy as jnp
Expand All @@ -14,7 +11,7 @@
tfd = tfp.distributions

import bayes_ca.inference as core
from bayes_ca.prox_grad import pgd, pgd_jaxopt
from bayes_ca.prox_grad import pgd_jaxopt
from bayes_ca.data import naturalistic_data
from bayes_ca._utils import _safe_handling_params

Expand Down Expand Up @@ -133,6 +130,14 @@ def main(seed, data_dir, mu_pri, sigmasq_pri, sigmasq_subj, hazard_prob, max_dur
# create train split and initialize globals
global_means = jnp.mean(pca_train, axis=0)

_, _, _, smooth_means = core.gaussian_cp_smoother(
global_means, hazard_rates, mu_pri, sigmasq_pri, sigmasq_obs
)
plt.plot(global_means, label=f"1D input data after FactorAnalysis")
plt.plot(smooth_means, label="smoothed means")
plt.legend()
plt.show()

lps = []
for _ in progress_bar(range(2)):
this_key, key = jr.split(key)
Expand Down

0 comments on commit 3856450

Please sign in to comment.