Patch summary (free text ahead of the first "diff --git" header).

Touches learn_roots_and_noise in scatrex/scatrex.py:
  * calls reset_sufficient_statistics() before the sufficient statistics
    are recomputed;
  * in four places, replaces the single update_sufficient_statistics() call
    with one call per entry of batch_indices (batch_idx=...);
  * drops the update_sufficient_statistics()/compute_elbo() pair that ran on
    self.ntssb right before the StructureSearch is built;
  * drops the trailing comment on the final learn_model(...) call.

Review fix relative to the patch as submitted: in the last hunk the new
reset/update lines addressed searcher.tree although self.ntssb had just been
rebound to a deepcopy of it and is the object used by the following
compute_elbo()/learn_model() calls. They now address self.ntssb, as the
removed line did. One explanatory comment line was added there, so that hunk
header reads +277,16 instead of +277,15.

NOTE(review): leading whitespace and blank-line positions were not
recoverable from the collapsed source and are reconstructed here; check them
against the target file (or apply with whitespace-insensitive matching).
The "index" line is carried over unchanged from the submitted patch.

diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py
index 881cfa2..b3f222a 100644
--- a/scatrex/scatrex.py
+++ b/scatrex/scatrex.py
@@ -218,6 +218,7 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
         searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,)
         searcher.tree.set_node_hyperparams(direction_shape=1.)
         searcher.tree.sample_variational_distributions(n_samples=mc_samples)
+        searcher.tree.reset_sufficient_statistics()
         searcher.tree.update_sufficient_statistics()
         searcher.tree.compute_elbo(memoized=memoized)
         searcher.proposed_tree = deepcopy(searcher.tree)
@@ -239,13 +240,12 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
                                update_roots=True, step_size=step_size, mc_samples=mc_samples,
                                memoized=False)
 
-        self.ntssb.update_sufficient_statistics()
-        self.ntssb.compute_elbo(memoized=memoized)
-
         # Propose merges to account for new noise
         searcher = StructureSearch(self.ntssb)
         searcher.tree.sample_variational_distributions(n_samples=mc_samples)
-        searcher.tree.update_sufficient_statistics()
+        searcher.tree.reset_sufficient_statistics()
+        for batch_idx in range(len(searcher.tree.batch_indices)):
+            searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
         searcher.tree.compute_elbo(memoized=memoized)
         searcher.proposed_tree = deepcopy(searcher.tree)
         key = jax.random.PRNGKey(seed)
@@ -253,7 +253,9 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
             key, subkey = jax.random.split(key)
             searcher.merge(subkey, moves_per_tssb=1, memoized=memoized)
 
-        searcher.tree.update_sufficient_statistics()
+        searcher.tree.reset_sufficient_statistics()
+        for batch_idx in range(len(searcher.tree.batch_indices)):
+            searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
         searcher.tree.compute_elbo(memoized=memoized)
         searcher.proposed_tree = deepcopy(searcher.tree)
         # Propose root merges with opt after these merges
@@ -261,7 +263,9 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
             key, subkey = jax.random.split(key)
             searcher.merge_root(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size)
 
-        searcher.tree.update_sufficient_statistics()
+        searcher.tree.reset_sufficient_statistics()
+        for batch_idx in range(len(searcher.tree.batch_indices)):
+            searcher.tree.update_sufficient_statistics(batch_idx=batch_idx)
         searcher.tree.compute_elbo(memoized=memoized)
         searcher.proposed_tree = deepcopy(searcher.tree)
         # Propose root swaps after these merges
@@ -273,13 +277,16 @@ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=1
         # Re-learn noise and parameters except scales
         self.ntssb = deepcopy(searcher.tree)
         self.ntssb.sample_variational_distributions(n_samples=mc_samples)
-        self.ntssb.update_sufficient_statistics()
+        # Refresh the statistics of the copy that is trained below, not of searcher.tree
+        self.ntssb.reset_sufficient_statistics()
+        for batch_idx in range(len(self.ntssb.batch_indices)):
+            self.ntssb.update_sufficient_statistics(batch_idx=batch_idx)
         self.ntssb.compute_elbo(memoized=memoized)
         self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True,
                                locals_names=['obs_weights'],
                                globals_names=['factor_weights', 'factor_precisions'],
                                update_roots=True, step_size=step_size, mc_samples=mc_samples,
-                               memoized=False) # If I do this with memoization the noise ends up explaining too much and messing up the scales! Best not to mix
+                               memoized=False)
 
     def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, step_size=0.01, dp_alpha=.1, dp_gamma=.1, prune=True, seed=42):
         logger.info("Learning augmented tree")