Skip to content

Commit

Permalink
Fix suff stats updates
Browse files Browse the repository at this point in the history
  • Loading branch information
pedrofale committed Apr 24, 2024
1 parent f0b95b1 commit c9fd280
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions scatrex/scatrex.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,)
searcher.tree.set_node_hyperparams(direction_shape=1.)
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
searcher.tree.reset_sufficient_statistics()
searcher.tree.update_sufficient_statistics()
searcher.tree.compute_elbo(memoized=memoized)
searcher.proposed_tree = deepcopy(searcher.tree)
Expand All @@ -239,29 +240,32 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
update_roots=True, step_size=step_size, mc_samples=mc_samples,
memoized=False)

self.ntssb.update_sufficient_statistics()
self.ntssb.compute_elbo(memoized=memoized)

# Propose merges to account for new noise
searcher = StructureSearch(self.ntssb)
searcher.tree.sample_variational_distributions(n_samples=mc_samples)
searcher.tree.update_sufficient_statistics()
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
searcher.tree.compute_elbo(memoized=memoized)
searcher.proposed_tree = deepcopy(searcher.tree)
key = jax.random.PRNGKey(seed)
for i in range(10):
key, subkey = jax.random.split(key)
searcher.merge(subkey, moves_per_tssb=1, memoized=memoized)

searcher.tree.update_sufficient_statistics()
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
searcher.tree.compute_elbo(memoized=memoized)
searcher.proposed_tree = deepcopy(searcher.tree)
# Propose root merges with opt after these merges
for i in range(n_merges):
key, subkey = jax.random.split(key)
searcher.merge_root(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size)

searcher.tree.update_sufficient_statistics()
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
searcher.tree.compute_elbo(memoized=memoized)
searcher.proposed_tree = deepcopy(searcher.tree)
# Propose root swaps after these merges
Expand All @@ -273,13 +277,15 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
# Re-learn noise and parameters except scales
self.ntssb = deepcopy(searcher.tree)
self.ntssb.sample_variational_distributions(n_samples=mc_samples)
self.ntssb.update_sufficient_statistics()
searcher.tree.reset_sufficient_statistics()
for batch_idx in range(len(searcher.tree.batch_indices)):
searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
self.ntssb.compute_elbo(memoized=memoized)
self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True,
locals_names=['obs_weights'],
globals_names=['factor_weights', 'factor_precisions'],
update_roots=True, step_size=step_size, mc_samples=mc_samples,
memoized=False) # If I do this with memoization the noise ends up explaining too much and messing up the scales! Best not to mix
memoized=False)

def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, step_size=0.01, dp_alpha=.1, dp_gamma=.1, prune=True, seed=42):
logger.info("Learning augmented tree")
Expand Down

0 comments on commit c9fd280

Please sign in to comment.