diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index df48e35..fba9823 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -389,7 +389,7 @@ def update_anndata(self, adata): adata.layers["scatrex_mean"] = mean_mat def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees=True, allow_root_subtrees=False, root_cells=None, - batch_size=None, seed=42, weights_concentration=1e6, update_outer_ass=False, + batch_size=None, seed=42, weights_variance=1., update_outer_ass=False, n_epochs=100, mc_samples=10, step_size=0.01, n_iters=10, n_merges=10, n_swaps=10, memoized=True, dp_alpha=.1, dp_gamma=.1): """ Complete NTSSB learning procedure. @@ -401,7 +401,7 @@ def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees self.ntssb = NTSSB(self.observed_tree, node_hyperparams=self.model_args, seed=seed, - weights_concentration=weights_concentration) + weights_variance=weights_variance) self.ntssb.add_data(np.array(adata.layers[counts_layer])) self.ntssb.make_batches(batch_size, seed) self.ntssb.reset_variational_parameters() @@ -1098,7 +1098,7 @@ def plot_tree( mapper = self.ntssb.gene_node_colormaps[kwargs["genemode"]][ "mapper" ][gene_pos] - cbar = plt.colorbar(mapper, label=cbtitle) + cbar = plt.colorbar(mapper, ax=g, label=cbtitle) if kwargs["genemode"] == "observed": n_discrete_levels = self.ntssb.gene_node_colormaps["observed"][ "mapper"