Skip to content

Commit

Permalink
Add legacy loop back in.
Browse files Browse the repository at this point in the history
  • Loading branch information
alicjapolanska committed May 1, 2024
1 parent e16efd8 commit 2a1af73
Showing 1 changed file with 48 additions and 2 deletions.
50 changes: 48 additions & 2 deletions harmonic/evidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,22 +261,25 @@ def add_chains(self, chains):
lnargs = lnargs.at[jnp.isinf(lnargs)].set(jnp.nan)

else:
lnpred = np.zeros_like(Y)
# lnpred = np.zeros_like(Y)
lnargs = np.zeros_like(Y)
for i_chains in range(nchains):
i_samples_start = chains.start_indices[i_chains]
i_samples_end = chains.start_indices[i_chains + 1]

for i, i_samples in enumerate(range(i_samples_start, i_samples_end)):
lnpredict = self.model.predict(X[i_samples, :])
lnpred[i_samples] = lnpredict
# lnpred[i_samples] = lnpredict

lnprob = Y[i_samples]
lnargs[i_samples] = lnpredict - lnprob

if np.isinf(lnargs[i_samples]):
lnargs[i_samples] = np.nan

# if np.isinf(lnpred[i_samples]):
# lnpred[i_samples] = np.nan

# The following performs a shift in log-space to avoid overflow or float
# rounding errors in realspace.
if not self.shift_set:
Expand Down Expand Up @@ -331,6 +334,49 @@ def get_nans_per_chain(lnargs, mask):
self.lnpredictmax = jnp.nanmax(lnpred)
self.lnpredictmin = jnp.nanmin(lnpred)

else:
for i_chains in range(nchains):
i_samples_start = chains.start_indices[i_chains]
i_samples_end = chains.start_indices[i_chains + 1]

for i, i_samples in enumerate(range(i_samples_start, i_samples_end)):
# Apply shifting term to avoid overflow.
lnarg = lnargs[i_samples] + self.shift_value
# Store realspace or logspace sum depending on choice.
term = np.exp(lnarg)
nsamples_per_chain[i_chains] += 1

if not np.isnan(lnargs[i_samples]):
# Count number of samples used.
nsamples_eff_per_chain[i_chains] += 1

# Add contribution to running sum.
running_sum[i_chains] += term

# Log diagnostic terms.
self.lnargmax = (
lnarg if lnarg > self.lnargmax else self.lnargmax
)
self.lnargmin = (
lnarg if lnarg < self.lnargmin else self.lnargmin
)
self.lnprobmax = (
lnprob if lnprob > self.lnprobmax else self.lnprobmax
)
self.lnprobmin = (
lnprob if lnprob < self.lnprobmin else self.lnprobmin
)
self.lnpredictmax = (
lnpredict
if lnpredict > self.lnpredictmax
else self.lnpredictmax
)
self.lnpredictmin = (
lnpredict
if lnpredict < self.lnpredictmin
else self.lnpredictmin
)

self.process_run()
self.chains_added = True
self.check_basic_diagnostic()
Expand Down

0 comments on commit 2a1af73

Please sign in to comment.