diff --git a/scatrex/models/__init__.py b/scatrex/models/__init__.py index 78adaa0..b43e0b6 100644 --- a/scatrex/models/__init__.py +++ b/scatrex/models/__init__.py @@ -1 +1,3 @@ -from . import cna +from .gaussian import GaussianTree +from .trajectory import TrajectoryTree +from .cna import CNATree \ No newline at end of file diff --git a/scatrex/models/cna/__init__.py b/scatrex/models/cna/__init__.py index b04881c..388d504 100644 --- a/scatrex/models/cna/__init__.py +++ b/scatrex/models/cna/__init__.py @@ -1,3 +1,3 @@ -from .tree import ObservedTree -from .node import Node -from .opt_funcs import * +from .tree import CNATree +from .node import CNANode +from .node_opt import * diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py index c37f72d..47085af 100644 --- a/scatrex/models/cna/node.py +++ b/scatrex/models/cna/node.py @@ -3,879 +3,1391 @@ from numpy.random import * from functools import partial - -import jax import jax.numpy as jnp -from jax import jit, grad, vmap -from jax import random, ops -from jax.example_libraries import optimizers -from jax.scipy.stats import norm -from jax.scipy.stats import gamma -from jax.scipy.stats import poisson -from jax.scipy.special import logit - -from ...util import * +import jax +import tensorflow_probability.substrates.jax.distributions as tfd + +from .node_opt import * # node optimization functions +from .node_opt import _mc_obs_ll +from ...utils.math_utils import * from ...ntssb.node import * -from ...ntssb.tree import * -MIN_CNV = 1e-6 -MAX_XI = 1 +MIN_ALPHA = jnp.log(0.1) +MAX_BETA = jnp.log(1./0.1) +def update_params(params, params_gradient, step_size): + new_params = [] + for i, param in enumerate(params): + new_params.append(param + step_size * params_gradient[i]) + return new_params -class Node(AbstractNode): +class CNANode(AbstractNode): def __init__( self, - is_observed, - observed_parameters, - log_lib_size_mean=7.1, - log_lib_size_std=0.6, - num_global_noise_factors=4, - global_noise_factors_precisions_shape=2.0, - cell_global_noise_factors_weights_scale=1.0, - unobserved_factors_root_kernel=0.1, - unobserved_factors_kernel=1.0, - unobserved_factors_kernel_concentration=0.01, - unobserved_factors_kernel_rate=1.0, - frac_dosage=1.0, - frac_overlap=0.25, - baseline_shape=0.7, - num_batches=0, + observed_parameters, # copy number state + cell_scale_mean=1e2, + cell_scale_shape=1., + gene_scale_mean=1e2, + gene_scale_shape=1., + direction_shape=.1, + inheritance_strength=1., + n_factors=2, + obs_weight_variance=1., + factor_precision_shape=2., + min_cnv = 1e-6, **kwargs, ): - super(Node, self).__init__(is_observed, observed_parameters, **kwargs) + """ + This model generates nodes in gene expression space combined with observed copy number states + """ + super(CNANode, self).__init__(observed_parameters, **kwargs) # The observed parameters are the CNVs of all genes self.cnvs = np.array(self.observed_parameters) - self.cnvs[np.where(self.cnvs == 0)[0]] = MIN_CNV + self.cnvs[np.where(self.cnvs == 0)[0]] = min_cnv # To avoid zero in Poisson likelihood self.observed_parameters = np.array(self.cnvs) + self.cnvs = jnp.array(self.cnvs) self.n_genes = self.cnvs.size # Node hyperparameters if self.parent() is None: self.node_hyperparams = dict( - log_lib_size_mean=log_lib_size_mean, - log_lib_size_std=log_lib_size_std, - num_global_noise_factors=num_global_noise_factors, - global_noise_factors_precisions_shape=global_noise_factors_precisions_shape, - cell_global_noise_factors_weights_scale=cell_global_noise_factors_weights_scale, - 
unobserved_factors_root_kernel=unobserved_factors_root_kernel, - unobserved_factors_kernel=unobserved_factors_kernel, - unobserved_factors_kernel_concentration=unobserved_factors_kernel_concentration, - unobserved_factors_kernel_rate=unobserved_factors_kernel_rate, - frac_dosage=frac_dosage, - frac_overlap=frac_overlap, - baseline_shape=baseline_shape, - num_batches=num_batches, + cell_scale_mean=cell_scale_mean, + cell_scale_shape=cell_scale_shape, + gene_scale_mean=gene_scale_mean, + gene_scale_shape=gene_scale_shape, + direction_shape=direction_shape, + inheritance_strength=inheritance_strength, + n_factors=n_factors, + obs_weight_variance=obs_weight_variance, + factor_precision_shape=factor_precision_shape, ) else: self.node_hyperparams = self.node_hyperparams_caller() self.reset_parameters(**self.node_hyperparams) - def inherit_parameters(self): - if not self.is_observed: - # Make sure we use the right observed parameters - self.cnvs = self.parent().cnvs - - def get_node_mean(self, log_baseline, unobserved_factors, noise, cnvs): - node_mean = jnp.exp( - log_baseline + unobserved_factors + noise + jnp.log(cnvs / 2) - ) - sum = jnp.sum(node_mean, axis=1).reshape(self.tssb.ntssb.num_data, 1) - node_mean = node_mean / sum - return node_mean - - def init_noise_factors(self): - # Noise - self.variational_parameters["globals"]["cell_noise_mean"] = np.zeros( - (self.tssb.ntssb.num_data, self.num_global_noise_factors) - ) - self.variational_parameters["globals"]["cell_noise_log_std"] = -np.ones( - (self.tssb.ntssb.num_data, self.num_global_noise_factors) - ) - self.variational_parameters["globals"]["noise_factors_mean"] = np.zeros( - (self.num_global_noise_factors, self.n_genes) - ) - self.variational_parameters["globals"]["noise_factors_log_std"] = -np.ones( - (self.num_global_noise_factors, self.n_genes) - ) - self.variational_parameters["globals"]["factor_precision_log_means"] = np.log( - self.global_noise_factors_precisions_shape - ) * np.ones((self.num_global_noise_factors)) - self.variational_parameters["globals"]["factor_precision_log_stds"] = -np.ones( - (self.num_global_noise_factors) - ) - - # Batch effects - self.variational_parameters["globals"]["batch_effects_mean"] = np.zeros( - (self.num_batches, self.n_genes) - ) - self.variational_parameters["globals"]["batch_effects_log_std"] = -np.ones( - (self.num_batches, self.n_genes) - ) - - def reset_variational_parameters(self, means=True, variances=True): - if self.parent() is None: - # Baseline: first value is 1 - if means: - self.variational_parameters["globals"]["log_baseline_mean"] = np.array( - normal_sample(0, 1, self.n_genes - 1) - ) # np.zeros((self.n_genes-1,)) - if self.tssb.ntssb.data is not None: - self.full_data = jnp.array(self.tssb.ntssb.data) - # init_baseline = np.mean(self.tssb.ntssb.data, axis=0) - # init_log_baseline = np.log(init_baseline / init_baseline[0])[1:] - init_baseline = np.mean( - self.tssb.ntssb.data - / np.sum(self.tssb.ntssb.data, axis=1).reshape(-1, 1) - * self.n_genes, - axis=0, - ) - init_baseline = init_baseline / init_baseline[0] - init_log_baseline = np.log(init_baseline[1:] + 1e-6) - self.variational_parameters["globals"][ - "log_baseline_mean" - ] = np.clip(init_log_baseline, -1, 1) - if variances: - self.variational_parameters["globals"][ - "log_baseline_log_std" - ] = np.array(-2*np.ones((self.n_genes - 1,))) - - # Overdispersion - # self.log_od_mean = np.zeros(1) - # self.log_od_log_std = np.zeros(1) - - if self.tssb.ntssb.num_data is not None: - # Noise - if means: - 
self.variational_parameters["globals"][ - "cell_noise_mean" - ] = np.random.normal( - 0, - 0.1, - (self.tssb.ntssb.num_data, self.num_global_noise_factors), - ) - if variances: - self.variational_parameters["globals"][ - "cell_noise_log_std" - ] = -np.ones( - (self.tssb.ntssb.num_data, self.num_global_noise_factors) - ) - if means: - self.variational_parameters["globals"][ - "noise_factors_mean" - ] = np.random.normal( - 0, 0.1, (self.num_global_noise_factors, self.n_genes) - ) - if variances: - self.variational_parameters["globals"][ - "noise_factors_log_std" - ] = -np.ones((self.num_global_noise_factors, self.n_genes)) - if means: - self.variational_parameters["globals"][ - "factor_precision_log_means" - ] = np.log(self.global_noise_factors_precisions_shape) * np.ones( - (self.num_global_noise_factors) - ) - if variances: - self.variational_parameters["globals"][ - "factor_precision_log_stds" - ] = -np.ones((self.num_global_noise_factors)) - if means: - self.variational_parameters["globals"][ - "batch_effects_mean" - ] = np.random.normal(0, 0.1, (self.num_batches, self.n_genes)) - if variances: - self.variational_parameters["globals"][ - "batch_effects_log_std" - ] = -np.ones((self.num_batches, self.n_genes)) - - self.data_ass_logits = np.array([]) - if self.tssb.ntssb.num_data is not None: - self.data_ass_logits = np.zeros((self.tssb.ntssb.num_data)) - try: - if self.parent().ll.shape[0] == self.tssb.ntssb.num_data: - self.ll = np.array(self.parent().ll) - self.data_weights = np.array(self.parent().data_weights) - except: - self.ll = -1e10 * np.ones((self.tssb.ntssb.num_data,)) - self.data_weights = np.zeros((self.tssb.ntssb.num_data,)) - if self.is_observed: - self.tssb.data_weights = np.zeros((self.tssb.ntssb.num_data,)) - self.tssb.unnormalized_data_weights = np.zeros((self.tssb.ntssb.num_data,)) - - # Sticks - if means: - self.variational_parameters["locals"]["nu_log_mean"] = np.array( - np.log(1.0 * self.tssb.dp_alpha) + np.random.normal(0, 0.01) - ) - if variances: - self.variational_parameters["locals"]["nu_log_std"] = np.array( - np.log(1.0 * self.tssb.dp_alpha) - ) - if means: - self.variational_parameters["locals"]["psi_log_mean"] = np.array( - np.log(1.0 * self.tssb.dp_gamma) + np.random.normal(0, 0.01) - ) - if variances: - self.variational_parameters["locals"]["psi_log_std"] = np.array( - np.log(1.0 * self.tssb.dp_gamma) - ) - - # Unobserved factors - if means: - try: - self.variational_parameters["locals"][ - "unobserved_factors_mean" - ] = np.array( - self.parent().variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - except AttributeError: - self.variational_parameters["locals"][ - "unobserved_factors_mean" - ] = np.zeros((self.n_genes,)) - if variances: - self.variational_parameters["locals"][ - "unobserved_factors_log_std" - ] = -2.0 * np.ones((self.n_genes,)) - if means: - try: - kernel_means = np.clip( - self.unobserved_factors_kernel_concentration_caller() - / ( - np.exp( - (self.parent() != None) - * (self.parent().parent() != None) - * self.unobserved_factors_kernel_rate_caller() - * np.abs( - self.parent().variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - ) - ), - self.unobserved_factors_kernel_concentration_caller() / 10, - 1e2, - ) - self.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = (np.log(kernel_means) - (np.exp(-4) ** 2) / 2) - except AttributeError: - self.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.log( - 
self.unobserved_factors_kernel_concentration_caller() - ) * np.ones( - (self.n_genes,) - ) - if variances: - self.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ] = -4.0 * np.ones((self.n_genes,)) - self.set_mean( - self.get_mean( - baseline=np.append(1, np.exp(self.log_baseline_caller())), - unobserved_factors=self.variational_parameters["locals"][ - "unobserved_factors_mean" - ], - ) - ) + if self.tssb is not None: + self.reset_variational_parameters() + self.sample_variational_distributions() + self.reset_sufficient_statistics(self.tssb.ntssb.num_batches) + # For adaptive optimization + self.reset_opt() + + def reset_opt(self): + # For adaptive optimization + self.direction_states = self.initialize_direction_states() + self.state_states = self.initialize_state_states() + + def apply_clip(self, param, minval=-jnp.inf, maxval=jnp.inf): + param = jnp.maximum(param, minval) + param = jnp.minimum(param, maxval) + return param + + def combine_params(self): + return np.exp(self.params[0]) * 0.5 * self.cnvs # params is a list of two: 0 is \qsi and 1 is \chi + + def get_mean(self): + return self.combine_params() + + def set_mean(self, node_mean=None): + if node_mean is not None: + self.node_mean = node_mean + else: + self.node_mean = self.get_mean() + + def get_observed_parameters(self): + return self.cnvs + + def get_params(self): + return self.get_mean() + + def get_param(self, param='mean'): + if param == 'observed': + return self.get_observed_parameters() + elif param == 'mean': + return self.get_mean() + else: + raise ValueError(f"No param available for `{param}`") + + def remove_noise(self, data): + """ + Noise is multiplicative in this model + """ + return data * 1./self.cell_scales_caller() * 1./self.gene_scales_caller() * 1./jnp.exp(self.noise_factors_caller()) # ========= Functions to initialize node. 
========= - def reset_data_parameters(self): - self.full_data = jnp.array(self.tssb.ntssb.data) - self.lib_sizes = np.sum(self.tssb.ntssb.data, axis=1).reshape( - self.tssb.ntssb.num_data, 1 - ) - self.cell_global_noise_factors_weights = normal_sample( - 0, - self.cell_global_noise_factors_weights_scale, - size=[self.tssb.ntssb.num_data, self.num_global_noise_factors], - ) - self.cell_covariates = jnp.array(self.tssb.ntssb.covariates) - self.num_batches = self.cell_covariates.shape[1] - - def generate_data_params(self): - self.lib_sizes = 20.0 + np.exp( - normal_sample( - self.log_lib_size_mean, - self.log_lib_size_std, - size=self.tssb.ntssb.num_data, - ) - ).reshape(self.tssb.ntssb.num_data, 1) - self.lib_sizes = np.ceil(self.lib_sizes) - self.cell_global_noise_factors_weights = normal_sample( - 0, - self.cell_global_noise_factors_weights_scale, - size=[self.tssb.ntssb.num_data, self.num_global_noise_factors], - ) - n_cells = self.tssb.ntssb.num_data - self.cell_covariates = np.zeros((n_cells, self.num_batches)) - if self.num_batches > 1: - batches = np.random.choice(self.num_batches, size=n_cells) - self.cell_covariates[range(n_cells), batches] = 1 - else: - self.cell_covariates = np.zeros((n_cells, 0)) - self.num_batches = 0 + def set_node_hyperparams(self, **kwargs): + self.node_hyperparams.update(**kwargs) def reset_parameters( self, - root_params=True, - down_params=True, - log_lib_size_mean=6, - log_lib_size_std=0.8, - num_global_noise_factors=4, - global_noise_factors_precisions_shape=2.0, - cell_global_noise_factors_weights_scale=1.0, - unobserved_factors_root_kernel=0.1, - unobserved_factors_kernel=1.0, - unobserved_factors_kernel_concentration=0.01, - unobserved_factors_kernel_rate=1.0, - frac_dosage=1.0, - frac_overlap=0.25, - baseline_shape=0.1, - num_batches=0, + cell_scale_mean=1e2, + cell_scale_shape=1., + gene_scale_mean=1e2, + gene_scale_shape=1., + direction_shape=.1, + inheritance_strength=1., + n_factors=2, + obs_weight_variance=1., + factor_precision_shape=2., ): - parent = self.parent() - - if parent is None: # this is the root - self.node_hyperparams = dict( - log_lib_size_mean=log_lib_size_mean, - log_lib_size_std=log_lib_size_std, - num_global_noise_factors=num_global_noise_factors, - global_noise_factors_precisions_shape=global_noise_factors_precisions_shape, - cell_global_noise_factors_weights_scale=cell_global_noise_factors_weights_scale, - unobserved_factors_root_kernel=unobserved_factors_root_kernel, - unobserved_factors_kernel=unobserved_factors_kernel, - unobserved_factors_kernel_concentration=unobserved_factors_kernel_concentration, - unobserved_factors_kernel_rate=unobserved_factors_kernel_rate, - frac_dosage=frac_dosage, - frac_overlap=frac_overlap, - baseline_shape=baseline_shape, - num_batches=num_batches, - ) - - if root_params: - # The root is used to store global parameters: mu - # self.log_baseline = normal_sample(0, 1., size=self.n_genes) - # self.log_baseline[0] = 0. 
- # self.baseline = np.exp(self.log_baseline) - self.baseline_shape = baseline_shape - self.baseline = np.random.gamma( - self.baseline_shape, 1, size=self.n_genes - ) - # self.baseline = np.concatenate([1, self.baseline]) - self.log_baseline = np.log(self.baseline) - - self.overdispersion = np.exp(normal_sample(0, 1)) - self.log_lib_size_mean = log_lib_size_mean - self.log_lib_size_std = log_lib_size_std - - # Structured noise: keep all cells' noise factors at the root - self.num_global_noise_factors = num_global_noise_factors - self.global_noise_factors_precisions_shape = ( - global_noise_factors_precisions_shape - ) - self.cell_global_noise_factors_weights_scale = ( - cell_global_noise_factors_weights_scale - ) - - self.global_noise_factors_precisions = gamma_sample( - global_noise_factors_precisions_shape, - 1.0, - size=self.num_global_noise_factors, - ) - self.global_noise_factors = normal_sample( - 0, - 1.0 / np.sqrt(self.global_noise_factors_precisions), - size=[self.n_genes, self.num_global_noise_factors], - ).T # K x P - - # Batch effects - self.num_batches = num_batches - self.batch_effects_factors = normal_sample( - 0, - 1.0, - size=[self.n_genes, self.num_batches], - ).T # K x P - - self.unobserved_factors_root_kernel = unobserved_factors_root_kernel - self.unobserved_factors_kernel_concentration = ( - unobserved_factors_kernel_concentration - ) - self.unobserved_factors_kernel_rate = unobserved_factors_kernel_rate - - self.frac_dosage = frac_dosage - self.frac_overlap = frac_overlap - - self.inert_genes = np.random.choice( - self.n_genes, - size=int(self.n_genes * (1.0 - self.frac_dosage)), - replace=False, - ) - - cnv_genes = np.arange( - self.n_genes, - ) - if self.tssb is not None: - cnv_genes = self.tssb.ntssb.input_tree.get_affected_genes() - self.unavailable_genes = np.sort( - np.random.choice( - cnv_genes, - size=int(len(cnv_genes) * (1 - self.frac_overlap)), - replace=False, - ) - ) - - # Root should not have unobserved factors - self.unobserved_factors_kernel = 0 * np.array( - [self.unobserved_factors_root_kernel] * self.n_genes - ) - self.unobserved_factors = 0 * normal_sample( - 0.0, self.unobserved_factors_kernel - ) - - self.set_mean() + self.node_hyperparams = dict( + cell_scale_mean=cell_scale_mean, + cell_scale_shape=cell_scale_shape, + gene_scale_mean=gene_scale_mean, + gene_scale_shape=gene_scale_shape, + direction_shape=direction_shape, + inheritance_strength=inheritance_strength, + n_factors=n_factors, + obs_weight_variance=obs_weight_variance, + factor_precision_shape=factor_precision_shape, + ) + parent = self.parent() + + if parent is None: self.depth = 0.0 + self.params = [np.zeros((self.n_genes,)), np.ones((self.n_genes,))] # state and direction + + # Gene scales + rng = np.random.default_rng(seed=self.seed) + self.gene_scales = rng.gamma(self.node_hyperparams['gene_scale_shape'], self.node_hyperparams['gene_scale_mean']/self.node_hyperparams['gene_scale_shape'], size=(1, self.n_genes)) + + # Structured noise + n_factors = self.node_hyperparams['n_factors'] + factor_precision_shape = self.node_hyperparams['factor_precision_shape'] + + self.factor_precisions = rng.gamma(factor_precision_shape, 1., size=(n_factors,1)) + factor_scales = np.sqrt(1./self.factor_precisions) + self.factor_weights = rng.normal(0., factor_scales, size=(n_factors, self.n_genes)) * 1./np.sqrt(factor_precision_shape) + + if n_factors > 0: + n_genes_per_factor = int(2/n_factors) + offset = np.sqrt(factor_precision_shape) + perm = np.random.permutation(self.n_genes) + for factor in 
range(n_factors): + gene_idx = perm[factor*n_genes_per_factor:(factor+1)*n_genes_per_factor] + self.factor_weights[factor,gene_idx] *= offset + + # Set data-dependent parameters + if self.tssb is not None: + num_data = self.tssb.ntssb.num_data + if num_data is not None: + self.reset_data_parameters() else: # Non-root node: inherits everything from upstream node - self.node_hyperparams = self.node_hyperparams_caller() - if down_params: - self.unobserved_factors_kernel = gamma_sample( - self.unobserved_factors_kernel_concentration_caller(), - np.exp(np.abs(parent.unobserved_factors)), - ) - unavailable_genes = self.unavailable_genes_caller() - if len(unavailable_genes) > 0: - self.unobserved_factors_kernel[unavailable_genes] = ( - np.min(self.unobserved_factors_kernel) / 2.0 - ) - - # Make sure some genes are affected in unobserved nodes - if not self.is_observed: - top_genes = np.argsort(self.unobserved_factors_kernel)[::-1][:5] - self.unobserved_factors_kernel[top_genes] = np.max( - [5.0, np.max(self.unobserved_factors_kernel)] - ) - self.unobserved_factors = normal_sample( - parent.unobserved_factors, self.unobserved_factors_kernel - ) - self.unobserved_factors = np.clip( - self.unobserved_factors, -MAX_XI, MAX_XI - ) - if not self.is_observed: - self.unobserved_factors[ - top_genes - ] = MAX_XI # just force an amplification - - # Observation mean - self.set_mean() - self.depth = parent.depth + 1 + rng = np.random.default_rng(seed=self.seed) + sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'], + jnp.exp(-self.node_hyperparams['inheritance_strength']*jnp.abs(parent.params[0]))) + sampled_state = jnp.maximum(rng.normal(parent.params[0], sampled_direction), -1) + sampled_state = jnp.minimum(sampled_state, 1) + self.params = [sampled_state, sampled_direction] + + self.set_mean() + + # Generate cell sizes and structure on the factors + def reset_data_parameters(self): + num_data = self.tssb.ntssb.num_data + rng = np.random.default_rng(seed=self.seed) + self.cell_scales = rng.gamma(self.node_hyperparams['cell_scale_shape'], self.node_hyperparams['cell_scale_mean']/self.node_hyperparams['cell_scale_shape'], size=(num_data, 1)) + + n_factors = self.node_hyperparams['n_factors'] + self.obs_weights = rng.normal(0., 1., size=(num_data, n_factors)) * 1./3. + if n_factors > 0: + n_obs_per_factor = int(num_data/n_factors) + offset = 6. 
+ perm = np.random.permutation(num_data) + for factor in range(n_factors): + obs_idx = perm[factor*n_obs_per_factor:(factor+1)*n_obs_per_factor] + self.obs_weights[obs_idx,factor] *= offset + + self.noise_factors = self.obs_weights.dot(self.factor_weights) + + def reset_data_variational_parameters(self): + if self.parent() is None and self.tssb.parent() is None: + num_data = self.tssb.ntssb.num_data + + # Set priors + lib_sizes = np.sum(self.tssb.ntssb.data, axis=1) + self.lib_ratio = np.ones((self.tssb.ntssb.data.shape[0], 1)) + self.lib_ratio *= np.mean(lib_sizes) / np.var(lib_sizes) + gene_sizes = np.sum(self.tssb.ntssb.data, axis=0) + self.gene_ratio = np.mean(gene_sizes) / np.var(gene_sizes) + + cell_scales_alpha_init = self.node_hyperparams['cell_scale_shape'] * jnp.ones((num_data,1)) + cell_scales_beta_init = self.node_hyperparams['cell_scale_shape'] * jnp.ones((num_data,1)) + gene_scales_alpha_init = self.node_hyperparams['gene_scale_shape'] * jnp.ones((self.n_genes,)) + gene_scales_beta_init = self.node_hyperparams['gene_scale_shape'] * jnp.ones((self.n_genes,)) + + rng = np.random.default_rng(self.seed) + # root stores global parameters + n_factors = self.node_hyperparams['n_factors'] + factor_precision_shape = self.node_hyperparams['factor_precision_shape'] + self.variational_parameters["global"] = { + 'gene_scales': {'log_alpha': jnp.log(gene_scales_alpha_init), + 'log_beta': jnp.log(gene_scales_beta_init)}, + 'factor_precisions': {'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))), + 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1)))}, + 'factor_weights': {'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))), + 'log_std': -2. + jnp.zeros((n_factors, self.n_genes))} + } + rng = np.random.default_rng(self.seed+1) + self.variational_parameters["local"] = { + 'cell_scales': {'log_alpha': jnp.log(cell_scales_alpha_init), + 'log_beta': jnp.log(cell_scales_beta_init)}, + 'obs_weights': {'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))), + 'log_std': -2. + jnp.zeros((num_data, n_factors))} + } + self.cell_scales = jnp.exp(self.variational_parameters["local"]["cell_scales"]["log_alpha"]-self.variational_parameters["local"]["cell_scales"]["log_beta"]) + self.gene_scales = jnp.exp(self.variational_parameters["global"]["gene_scales"]["log_alpha"]-self.variational_parameters["global"]["gene_scales"]["log_beta"]) + self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"] + self.factor_precisions = self.variational_parameters["global"]["factor_weights"]["mean"] + self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"] + self.noise_factors = self.obs_weights.dot(self.factor_weights) + + def reset_variational_parameters(self): + # Assignments + num_data = self.tssb.ntssb.num_data + if num_data is not None: + self.variational_parameters['q_z'] = jnp.ones(num_data,) + + self.variational_parameters['sum_E_log_1_nu'] = 0. + self.variational_parameters['E_log_phi'] = 0. 
- def set_mean(self, node_mean=None, variational=False): - if node_mean is not None: - self.node_mean = node_mean - else: - if variational: - self.node_mean = ( - np.append(1, np.exp(self.log_baseline_caller())) - * self.cnvs - / 2 - * np.exp( - (self.parent() is not None) - * self.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - ) - else: - self.node_mean = ( - self.baseline_caller() - * self.cnvs - / 2 - * np.exp(self.unobserved_factors) - ) - self.node_mean = self.node_mean / np.sum(self.node_mean) - - def get_mean( - self, - baseline=None, - unobserved_factors=None, - noise=0.0, - batch_effects=0.0, - cell_factors=None, - global_factors=None, - cnvs=None, - norm=True, - inert_genes=None, - ): - baseline = ( - np.append(1, np.exp(self.log_baseline_caller())) - if baseline is None - else baseline - ) - unobserved_factors = ( - self.variational_parameters["locals"]["unobserved_factors_mean"] - if unobserved_factors is None - else unobserved_factors - ) - unobserved_factors *= self.parent() is not None - cnvs = self.cnvs if cnvs is None else cnvs - if inert_genes is not None: - cnvs = np.array(cnvs) - inert_genes = np.array(inert_genes) - zero_genes = np.where(cnvs == MIN_CNV)[0] - cnvs[ - inert_genes - ] = 2.0 # set these genes to 2, i.e., act like they have no CNV - cnvs[zero_genes] = MIN_CNV # (except if they are zero) - node_mean = ( - baseline * cnvs / 2 * np.exp(unobserved_factors + noise + batch_effects) - ) - if norm: - if len(node_mean.shape) == 1: - sum = np.sum(node_mean) - else: - sum = np.sum(node_mean, axis=1) - if len(sum.shape) > 0: - sum = sum.reshape(-1, 1) - node_mean = node_mean / sum - return node_mean - - def get_noise(self, variational=False): - return self.cell_global_noise_factors_weights_caller( - variational=variational - ).dot(self.global_noise_factors_caller(variational=variational)) - - def get_batch_effects(self, variational=False): - return self.cell_covariates_caller().dot( - self.batch_effects_factors_caller(variational=variational) - ) - - # ========= Functions to take samples from node. ========= - def sample_observation(self, n): - noise = self.cell_global_noise_factors_weights_caller()[n].dot( - self.global_noise_factors_caller() - ) - batch_effects = self.cell_global_noise_factors_weights_caller()[n].dot( - self.global_noise_factors_caller() - ) - node_mean = self.get_mean( - unobserved_factors=self.unobserved_factors, - baseline=self.baseline_caller(), - noise=noise, - batch_effects=batch_effects, - inert_genes=self.inert_genes_caller(), - ) - s = multinomial_sample(self.lib_sizes_caller()[n], node_mean) - # s = negative_binomial_sample(self.lib_sizes_caller()[n] * node_mean, 0.01) - return s - - # ========= Functions to evaluate node's parameters. ========= - def log_baseline_logprior(self, x=None): - if x is None: - x = np.log(self.baseline) - return normal_lpdf(x, 0, 1) - - def log_overdispersion_logprior(self, x=None): - if x is None: - x = np.log(self.overdispersion) - return normal_lpdf(x, 0, 1) - - def global_noise_factors_logprior(self): - return normal_lpdf( - self.global_noise_factors, - 0.0, - 1.0 / np.sqrt(self.global_noise_factors_precisions), - ) - - def cell_global_noise_factors_logprior(self): - return normal_lpdf( - self.cell_global_noise_factors_weights, - 0.0, - self.cell_global_noise_factors_weights_scale, - ) + # Sticks + self.variational_parameters["delta_1"] = 1. + self.variational_parameters["delta_2"] = 1. + self.variational_parameters["sigma_1"] = 1. + self.variational_parameters["sigma_2"] = 1. 
- def unobserved_factors_logprior(self, x=None): - if x is None: - x = self.unobserved_factors + # Pivots + self.variational_parameters["q_rho"] = np.ones(len(self.tssb.children_root_nodes),) - if self.parent() is not None: - if self.is_observed: - llp = normal_lpdf( - x, - self.parent().unobserved_factors, - self.unobserved_factors_root_kernel, - ) + parent = self.parent() + if parent is None and self.tssb.parent() is None: + self.params = [jnp.zeros((self.n_genes,)), jnp.ones((self.n_genes,))] + + if num_data is not None: + self.reset_data_variational_parameters() else: - llp = normal_lpdf( - x, self.parent().unobserved_factors, self.unobserved_factors_kernel - ) - else: - llp = normal_lpdf(x, 0.0, self.unobserved_factors_root_kernel) - - return llp - - def logprior(self): - # Prior of current node - llp = self.unobserved_factors_logprior() - if self.parent() is None: - llp = ( - llp - + self.global_noise_factors_logprior() - + self.cell_global_noise_factors_logprior() - ) - llp = llp + self.log_baseline_logprior() - - return llp - - def loglh(self, n, variational=False, axis=None): - noise = self.get_noise(variational=variational)[n] - batch_effects = self.get_batch_effects(variational=variational)[n] - unobs_factors = self.unobserved_factors - baseline = self.baseline_caller() - if variational: - unobs_factors = self.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - baseline = np.append(1, np.exp(self.log_baseline_caller(variational=True))) - node_mean = self.get_mean( - baseline=baseline, - unobserved_factors=unobs_factors, - noise=noise, - batch_effects=batch_effects, - ) - lib_sizes = self.lib_sizes_caller()[n] - return partial(jit, static_argnums=2)(poisson_lpmf)( - jnp.array(self.tssb.ntssb.data[n]), lib_sizes * node_mean, axis=axis - ) - - def complete_loglh(self): - return self.loglh(list(self.data)) - - def logprob(self, n): - # Add prior - l = self.loglh(n) + self.logprior() - - # Add prob of children nodes given current node's parameters - for child in self.children(): - l = l + child.unobserved_factors_logprior() - - return l - - # Sum over all data points attached to this node - def complete_logprob(self): - return self.logprob(list(self.data)) - - # Sum over all data - def full_logprob(self): - return self.logprob(list(range(len(self.tssb.ntssb.data)))) - - # ========= Functions to acess root's parameters. 
========= - def node_hyperparams_caller(self): - if self.parent() is None: - return self.node_hyperparams - else: - return self.parent().node_hyperparams_caller() - - def unavailable_genes_caller(self): - if self.parent() is None: - return self.unavailable_genes - else: - return self.parent().unavailable_genes_caller() - - def global_noise_factors_precisions_shape_caller(self): - if self.parent() is None: - return self.global_noise_factors_precisions_shape - else: - return self.parent().global_noise_factors_precisions_shape_caller() - - def cell_covariates_caller(self): - if self.parent() is None: - return self.cell_covariates - else: - return self.parent().cell_covariates_caller() - - def lib_sizes_caller(self): - if self.parent() is None: - return self.lib_sizes - else: - return self.parent().lib_sizes_caller() - - def log_kernel_mean_caller(self): - if self.parent() is None: - return self.logkernel_means - else: - return self.parent().log_kernel_mean_caller() - - def kernel_caller(self): - if self.parent() is None: - return self.kernel - else: - return self.parent().kernel_caller() - - def node_std_caller(self): - if self.parent() is None: - return self.node_std - else: - return self.parent().node_std_caller() - - def inert_genes_caller(self): - if self.parent() is None: - return self.inert_genes - else: - return self.parent().inert_genes_caller() - - def log_baseline_caller(self, variational=True): - if self.parent() is None: - if variational: - return self.variational_parameters["globals"]["log_baseline_mean"] + rng = np.random.default_rng(self.seed) + # root stores global parameters + n_factors = self.node_hyperparams['n_factors'] + factor_precision_shape = self.node_hyperparams['factor_precision_shape'] + self.variational_parameters["global"] = { + 'gene_scales': {'log_alpha': jnp.ones((self.n_genes,)), + 'log_beta': jnp.ones((self.n_genes,))}, + 'factor_precisions': {'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))), + 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1)))}, + 'factor_weights': {'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))), + 'log_std': -2. + jnp.zeros((n_factors, self.n_genes))} + } + else: # only the non-root nodes have kernel variational parameters + # Kernel + parent_param = jnp.zeros((self.n_genes,)) + if parent is not None: + parent_param = parent.params[0] + + rng = np.random.default_rng(self.seed+2) + sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'], + jnp.exp(-self.node_hyperparams['inheritance_strength'] * jnp.abs(parent_param))) + rng = np.random.default_rng(self.seed+3) + if np.all(parent_param == 0): + sampled_state = rng.normal(parent_param*0.1, 0.01) # is root node, so avoid messing with main node attachments else: - return self.log_baseline + sampled_state = jnp.clip(rng.normal(parent_param*0.1, sampled_direction), a_min=-1, a_max=1) # to explore (without numerical explosions) + + init_concentration = 10. 
+ self.variational_parameters["kernel"] = { + 'direction': {'log_alpha': jnp.log(init_concentration*jnp.ones((self.n_genes,))), 'log_beta': jnp.log(init_concentration/self.node_hyperparams['direction_shape'] * jnp.ones((self.n_genes,)))}, + 'state': {'mean': jnp.array(sampled_state), 'log_std': jnp.array(rng.normal(-2., 0.1, size=self.n_genes))} + } + self.params = [self.variational_parameters["kernel"]["state"]["mean"], + jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])] + + def reset_variational_noise_factors(self): + rng = np.random.default_rng(self.seed) + n_factors = self.node_hyperparams['n_factors'] + factor_precision_shape = self.node_hyperparams['factor_precision_shape'] + self.variational_parameters["global"]["factor_precisions"] = { + 'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))), + 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1))) + } + self.variational_parameters["global"]["factor_weights"] = { + 'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))), + 'log_std': -2. + jnp.zeros((n_factors, self.n_genes)) + } + num_data = self.tssb.ntssb.num_data + self.variational_parameters["local"]["obs_weights"] = { + 'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))), + 'log_std': -2. + jnp.zeros((num_data, n_factors)) + } + + def set_learned_parameters(self): + if self.parent() is None and self.tssb.parent() is None: + self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"] + self.factor_precisions = jnp.exp(self.variational_parameters["global"]["factor_precisions"]["log_alpha"] + -self.variational_parameters["global"]["factor_precisions"]["log_beta"]) + self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"] + self.noise_factors = self.obs_weights.dot(self.factor_weights) + self.cell_scales = jnp.exp(self.variational_parameters["local"]["cell_scales"]["log_alpha"] + -self.variational_parameters["local"]["cell_scales"]["log_beta"]) + self.gene_scales = jnp.exp(self.variational_parameters["global"]["gene_scales"]["log_alpha"] + -self.variational_parameters["global"]["gene_scales"]["log_beta"]) else: - return self.parent().log_baseline_caller(variational=variational) - - def baseline_caller(self): - if self.parent() is None: - return self.baseline + self.params = [self.variational_parameters["kernel"]["state"]["mean"], + jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])] + + def reset_sufficient_statistics(self, num_batches=1): + self.suff_stats = { + 'ent': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(c_n = this tree) q(z_n = this node) * log q(z_n = this node) + 'mass': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) + 'A': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng * E[\gamma_n] + 'B_g': {'total': 0, 'batch': np.zeros((num_batches,self.n_genes))}, # \sum_n q(z_n = this node) * x_ng + 'C': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g] + 'D_g': {'total': 0, 'batch': np.zeros((num_batches,self.n_genes))}, # \sum_n q(z_n = this node) * E[\gamma_n] * E[s_nW_g] + 'E': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * lgamma(x_ng+1) + } + if self.parent() is None and self.tssb.parent() is None: + 
self.local_suff_stats = { + 'locals_kl': {'total': 0., 'batch': np.zeros((num_batches,))}, + } + + def init_new_node_kernel(self, **kwargs): + # Get data for which prob of assigning to parent is > 1/n_total_nodes + weights = self.parent().variational_parameters['q_z'] * self.tssb.variational_parameters['q_c'] + idx = np.where(weights > 1./np.sqrt(self.tssb.ntssb.n_total_nodes))[0] + # Initialize prioritizing cells with lowest ll in parent + if len(idx) > 0 and 'll' in self.parent().variational_parameters: # only proceed in this manner if parent has already been evaluated + thres = np.quantile(self.parent().variational_parameters['ll'][idx], q=.1) + idx = idx[np.where(self.parent().variational_parameters['ll'][idx] < thres)[0]] + self.reset_variational_state(idx=idx, **kwargs) + # Sample + if self.parent().samples is not None: + n_samples = self.parent().samples[0].shape[0] + self.sample_kernel(n_samples=n_samples) + + def init_kernel(self, **kwargs): + # Get data for which prob of assigning here is > 1/n_total_nodes + weights = self.variational_parameters['q_z'] * self.tssb.variational_parameters['q_c'] + idx = np.where(weights > 1./np.sqrt(self.tssb.ntssb.n_total_nodes))[0] + # Initialize prioritizing cells with highest ll here + thres = np.quantile(self.variational_parameters['ll'][idx], q=.9) + idx = idx[np.where(self.variational_parameters['ll'][idx] >= thres)[0]] + self.reset_variational_state(idx=idx, **kwargs) + + def reset_variational_state(self, log_std=-2, idx=None, weights=None): + if self.parent() is None and self.tssb.parent() is None: + return else: - return self.parent().baseline_caller() - - def log_overdispersion_caller(self): - if self.parent() is None: - return self.log_od_mean + if idx is None: + idx = jnp.arange(self.tssb.ntssb.num_data) + if weights is None: + weights = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx] + cell_scales_mean = jnp.mean(self.get_cell_scales_sample()[:,idx],axis=0) + gene_scales_mean = jnp.mean(self.get_gene_scales_sample(),axis=0) + noise_factors = jnp.mean(self.get_noise_sample(idx),axis=0) + cnvs_contrib = self.cnvs/2 + init_state = jnp.log(jnp.sum(self.tssb.ntssb.data[idx]/(cell_scales_mean*gene_scales_mean*cnvs_contrib*jnp.exp(noise_factors)) * weights[:,None],axis=0)/jnp.sum(weights[:,None])) + self.variational_parameters['kernel']['state']['mean'] = init_state + self.variational_parameters['kernel']['state']['log_std'] = log_std * jnp.ones((self.n_genes,)) + + def merge_suff_stats(self, suff_stats): + for stat in self.suff_stats: + self.suff_stats[stat]['total'] += suff_stats[stat]['total'] + self.suff_stats[stat]['batch'] += suff_stats[stat]['batch'] + + def update_sufficient_statistics(self, batch_idx=None): + if batch_idx is not None: + idx = self.tssb.ntssb.batch_indices[batch_idx] else: - return self.parent().log_overdispersion_caller() - - def overdispersion_caller(self): - if self.parent() is None: - return self.overdispersion + idx = jnp.arange(self.tssb.ntssb.num_data) + + if self.parent() is None and self.tssb.parent() is None: + locals_kl = self.compute_local_priors(idx) + self.compute_local_entropies(idx) + if batch_idx is not None: + self.local_suff_stats['locals_kl']['total'] -= self.local_suff_stats['locals_kl']['batch'][batch_idx] + self.local_suff_stats['locals_kl']['batch'][batch_idx] = locals_kl + self.local_suff_stats['locals_kl']['total'] += self.local_suff_stats['locals_kl']['batch'][batch_idx] + else: + self.local_suff_stats['locals_kl']['total'] = locals_kl + + ent = 
assignment_entropies(self.variational_parameters['q_z'][idx]) + ent *= self.tssb.variational_parameters['q_c'][idx] + E_ass = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx] + E_loggamma = jnp.mean(jnp.log(self.get_cell_scales_sample()[:,idx]),axis=0) + E_gamma = jnp.mean(self.get_cell_scales_sample()[:,idx],axis=0) + E_sw = jnp.mean(self.get_noise_sample(idx),axis=0) + E_expsw = jnp.mean(jnp.exp(self.get_noise_sample(idx)),axis=0) + x = self.tssb.ntssb.data[idx] + + new_ent = jnp.sum(ent) + new_mass = jnp.sum(E_ass) + new_A = jnp.sum(E_ass * E_loggamma.ravel() * jnp.sum(x, axis=1)) + new_B = jnp.sum(E_ass[:,None] * x, axis=0) + new_C = jnp.sum(E_ass * jnp.sum(x * E_sw, axis=1)) + new_D = jnp.sum(E_ass[:,None] * E_gamma * E_expsw, axis=0) + new_E = jnp.sum(E_ass * jnp.sum(gammaln(x+1), axis=1)) + + if batch_idx is not None: + self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx] + self.suff_stats['ent']['batch'][batch_idx] = new_ent + self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx] + + self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx] + self.suff_stats['mass']['batch'][batch_idx] = new_mass + self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx] + + self.suff_stats['A']['total'] -= self.suff_stats['A']['batch'][batch_idx] + self.suff_stats['A']['batch'][batch_idx] = new_A + self.suff_stats['A']['total'] += self.suff_stats['A']['batch'][batch_idx] + + self.suff_stats['B_g']['total'] -= self.suff_stats['B_g']['batch'][batch_idx] + self.suff_stats['B_g']['batch'][batch_idx] = new_B + self.suff_stats['B_g']['total'] += self.suff_stats['B_g']['batch'][batch_idx] + + self.suff_stats['C']['total'] -= self.suff_stats['C']['batch'][batch_idx] + self.suff_stats['C']['batch'][batch_idx] = new_C + self.suff_stats['C']['total'] += self.suff_stats['C']['batch'][batch_idx] + + self.suff_stats['D_g']['total'] -= self.suff_stats['D_g']['batch'][batch_idx] + self.suff_stats['D_g']['batch'][batch_idx] = new_D + self.suff_stats['D_g']['total'] += self.suff_stats['D_g']['batch'][batch_idx] + + self.suff_stats['E']['total'] -= self.suff_stats['E']['batch'][batch_idx] + self.suff_stats['E']['batch'][batch_idx] = new_E + self.suff_stats['E']['total'] += self.suff_stats['E']['batch'][batch_idx] else: - return self.parent().overdispersion_caller() + self.suff_stats['ent']['total'] = new_ent + self.suff_stats['mass']['total'] = new_mass + self.suff_stats['A']['total'] = new_A + self.suff_stats['B_g']['total'] = new_B + self.suff_stats['C']['total'] = new_C + self.suff_stats['D_g']['total'] = new_D + self.suff_stats['E']['total'] = new_E - def unobserved_factors_damp_caller(self): - if self.parent() is None: - return self.unobserved_factors_damp - else: - return self.parent().unobserved_factors_damp_caller() + # ========= Functions to take samples from node. 
========= + def sample_observation(self, n): + state = self.params[0] + cnvs = self.cnvs + noise_factors = self.noise_factors_caller()[n] + cell_scales = self.cell_scales_caller()[n] + gene_scales = self.gene_scales_caller() + rng = np.random.default_rng(seed=self.seed+n) + s = rng.poisson(cell_scales*gene_scales*cnvs/2*jnp.exp(state + noise_factors)) + return s - def unobserved_factors_kernel_concentration_caller(self): - if self.parent() is None: - return self.unobserved_factors_kernel_concentration - else: - return self.parent().unobserved_factors_kernel_concentration_caller() + def sample_observations(self): + n_obs = len(self.data) + state = self.params[0] + cnvs = self.cnvs + noise_factors = self.noise_factors_caller()[np.array(list(self.data))] + cell_scales = self.cell_scales_caller()[np.array(list(self.data))] + gene_scales = self.gene_scales_caller() + rng = np.random.default_rng(seed=self.seed) + s = rng.poisson(cell_scales*gene_scales*cnvs/2*jnp.exp(state + noise_factors), size=[n_obs, self.n_genes]) + return s - def unobserved_factors_kernel_rate_caller(self): + # ========= Functions to access root's parameters. ========= + def node_hyperparams_caller(self): if self.parent() is None: - return self.unobserved_factors_kernel_rate + return self.node_hyperparams else: - return self.parent().unobserved_factors_kernel_rate_caller() + return self.parent().node_hyperparams_caller() - def unobserved_factors_root_kernel_caller(self): - if self.parent() is None: - return self.unobserved_factors_root_kernel + def noise_factors_caller(self): + return self.tssb.ntssb.root['node'].root['node'].noise_factors + + def cell_scales_caller(self): + return self.tssb.ntssb.root['node'].root['node'].cell_scales + + def get_cell_scales_sample(self): + return self.tssb.ntssb.root['node'].root['node'].cell_scales_sample + + def gene_scales_caller(self): + return self.tssb.ntssb.root['node'].root['node'].gene_scales + + def get_gene_scales_sample(self): + return self.tssb.ntssb.root['node'].root['node'].gene_scales_sample + + def get_obs_weights_sample(self): + return self.tssb.ntssb.root['node'].root['node'].obs_weights_sample + + def set_local_sample(self, sample, idx=None): + """ + obs_weights, cell_scales + """ + if idx is None: + idx = jnp.arange(self.tssb.ntssb.num_data) + self.tssb.ntssb.root['node'].root['node'].obs_weights_sample = self.tssb.ntssb.root['node'].root['node'].obs_weights_sample.at[:,idx].set(sample[0]) + self.tssb.ntssb.root['node'].root['node'].cell_scales_sample = self.tssb.ntssb.root['node'].root['node'].cell_scales_sample.at[:,idx].set(sample[1]) + + def get_factor_weights_sample(self): + return self.tssb.ntssb.root['node'].root['node'].factor_weights_sample + + def set_global_sample(self, sample): + """ + factor_weights, gene_scales + """ + self.tssb.ntssb.root['node'].root['node'].factor_weights_sample = jnp.array(sample[0]) + self.tssb.ntssb.root['node'].root['node'].gene_scales_sample = jnp.array(sample[1]) + + def get_noise_sample(self, idx): + obs_weights = self.get_obs_weights_sample()[:,idx] + factor_weights = self.get_factor_weights_sample() + return jax.vmap(sample_prod, in_axes=(0,0))(obs_weights,factor_weights) + + def get_direction_sample(self): + return self.samples[1] + + def get_state_sample(self): + return self.samples[0] + + # ======== Functions using the variational parameters. 
========= + def compute_loglikelihood(self, idx): + # Use stored samples for loc + node_mean_samples = self.samples[0] + cnvs = self.cnvs + obs_weights_samples = self.get_obs_weights_sample()[:,idx] + factor_weights_samples = self.get_factor_weights_sample() + cell_scales_samples = self.get_cell_scales_sample()[:,idx] + gene_scales_samples = self.get_gene_scales_sample() + # Average over samples for each observation + ll = jnp.mean(jax.vmap(_mc_obs_ll, in_axes=[None,0,None,0,0,0,0])(self.tssb.ntssb.data[idx], + node_mean_samples, + cnvs, + obs_weights_samples, + factor_weights_samples, + cell_scales_samples, + gene_scales_samples, + ), axis=0) # mean over MC samples + return ll + + def compute_loglikelihood_suff(self): + state_samples = self.samples[0] + gene_scales_samples = self.get_gene_scales_sample() + cnv = self.cnvs + ll = jnp.mean(jax.vmap(ll_suffstats, in_axes=[0, None, 0, None, None, None, None, None]) + (state_samples, cnv, gene_scales_samples, self.suff_stats['A']['total'], self.suff_stats['B_g']['total'], + self.suff_stats['C']['total'], self.suff_stats['D_g']['total'], self.suff_stats['E']['total']) + ) + return ll + + def sample_variational_distributions(self, n_samples=10): + if self.parent() is not None: + if self.parent().samples is not None: + n_samples = self.parent().samples[0].shape[0] + if self.parent() is None and self.tssb.parent() is None: + self.sample_locals(n_samples=n_samples, store=True) + self.sample_globals(n_samples=n_samples, store=True) + self.sample_kernel(n_samples=n_samples, store=True) + + def sample_locals(self, n_samples, store=True): + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.local_sample_and_grad(jnp.arange(self.tssb.ntssb.num_data), key, n_samples=n_samples) + sample, _ = sample_grad + obs_weights_sample, cell_scales_sample = sample + if store: + self.obs_weights_sample = obs_weights_sample + self.cell_scales_sample = cell_scales_sample else: - return self.parent().unobserved_factors_root_kernel_caller() - - def unobserved_factors_kernel_caller(self): - if self.parent() is None: - return self.unobserved_factors_kernel + return obs_weights_sample, cell_scales_sample + + def sample_globals(self, n_samples, store=True): + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.global_sample_and_grad(key, n_samples=n_samples) + sample, _ = sample_grad + factor_weights_sample, gene_scales_sample, factor_precisions_sample = sample + if store: + self.factor_weights_sample = factor_weights_sample + self.gene_scales_sample = gene_scales_sample + self.factor_precisions_sample = factor_precisions_sample else: - return self.parent().unobserved_factors_kernel_caller() - - def cell_global_noise_factors_weights_caller(self, variational=False): - if self.parent() is None: - if variational: - return self.variational_parameters["globals"]["cell_noise_mean"] - else: - return self.cell_global_noise_factors_weights + return factor_weights_sample, gene_scales_sample, factor_precisions_sample + + def sample_kernel(self, n_samples=10, store=True): + if self.parent() is None and self.tssb.parent() is None: + return self._sample_root_kernel(n_samples=n_samples, store=store) + + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.direction_sample_and_grad(key, n_samples=n_samples) + sampled_angle, _ = sample_grad + + key, sample_grad = self.state_sample_and_grad(key, n_samples=n_samples) + sampled_loc, _ = sample_grad + + samples = [sampled_loc, sampled_angle] + if store: + self.samples = samples else: - return 
self.parent().cell_global_noise_factors_weights_caller( - variational=variational - ) - - def global_noise_factors_caller(self, variational=False): - if self.parent() is None: - if variational: - return self.variational_parameters["globals"]["noise_factors_mean"] - else: - return self.global_noise_factors - else: - return self.parent().global_noise_factors_caller(variational=variational) - - def batch_effects_factors_caller(self, variational=False): - if self.parent() is None: - if variational: - return self.variational_parameters["globals"]["batch_effects_mean"] - else: - return self.batch_effects_factors + return samples + + def _sample_root_kernel(self, n_samples=10, store=True): + # In this model the complete tree's root parameters are fixed and not learned, so just store n_samples copies of them to mimic a sample + sampled_direction = jnp.vstack(jnp.repeat(jnp.array([self.params[1]]), n_samples, axis=0)) + sampled_state = jnp.vstack(jnp.repeat(jnp.array([self.params[0]]), n_samples, axis=0)) + samples = [sampled_state, sampled_direction] + if store: + self.samples = samples else: - return self.parent().batch_effects_factors_caller(variational=variational) - - def set_event_string( - self, - var_names=None, - estimated=True, - unobs_threshold=1.0, - kernel_threshold=1.0, - max_len=5, - event_fontsize=14, - ): - if var_names is None: - var_names = np.arange(self.n_genes).astype(int).astype(str) - - unobserved_factors = self.unobserved_factors - unobserved_factors_kernel = self.unobserved_factors_kernel - if estimated: - unobserved_factors = self.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - unobserved_factors_kernel = np.exp( - self.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - - # Up-regulated - up_color = "red" - up_list = np.where( - np.logical_and( - unobserved_factors > unobs_threshold, - unobserved_factors_kernel > kernel_threshold, - ) - )[0] - sorted_idx = np.argsort(unobserved_factors[up_list])[:max_len] - up_list = up_list[sorted_idx] - up_str = "" - if len(up_list) > 0: - up_str = ( - f'+' - + ",".join(var_names[up_list]) - ) - - # Down-regulated - down_color = "blue" - down_list = np.where( - np.logical_and( - unobserved_factors < -unobs_threshold, - unobserved_factors_kernel > kernel_threshold, - ) - )[0] - sorted_idx = np.argsort(-unobserved_factors[down_list])[:max_len] - down_list = down_list[sorted_idx] - down_str = "" - if len(down_list) > 0: - down_str = ( - f'-' - + ",".join(var_names[down_list]) - ) - - self.event_str = up_str - sep_str = "" - if len(up_list) > 0 and len(down_list) > 0: - sep_str = "
<br/><br/>
" - self.event_str = self.event_str + sep_str + down_str + return samples + + def compute_global_priors(self): + factor_weights_contrib = jnp.sum(jnp.mean(mc_factor_weights_logp_val_and_grad(self.factor_weights_sample, 0., self.factor_precisions_sample)[0], axis=0)) + log_alpha = jnp.log(self.node_hyperparams['gene_scale_shape']) + log_beta = jnp.log(self.node_hyperparams['gene_scale_shape'] * self.gene_ratio) + gene_scales_contrib = jnp.sum(jnp.mean(mc_gene_scales_logp_val_and_grad(self.gene_scales_sample, log_alpha, log_beta)[0], axis=0)) + log_alpha = jnp.log(self.node_hyperparams['factor_precision_shape']) + log_beta = jnp.log(1.) + factor_precisions_contrib = jnp.sum(jnp.mean(mc_factor_precisions_logp_val_and_grad(self.factor_precisions_sample, log_alpha, log_beta)[0], axis=0)) + return factor_weights_contrib + gene_scales_contrib + factor_precisions_contrib + + def compute_local_priors(self, batch_indices): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance'])) + obs_weights_contrib = jnp.sum(jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0], axis=0)) + log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape']) + log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio) + cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[batch_indices], log_alpha, log_beta)[0], axis=0)) + return obs_weights_contrib + cell_scales_contrib + + def compute_global_entropies(self): + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + factor_weights_contrib = jnp.sum(factor_weights_logq_val_and_grad(mean, log_std)[0]) + + log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha'] + log_beta = self.variational_parameters['global']['gene_scales']['log_beta'] + gene_scales_contrib = jnp.sum(gene_scales_logq_val_and_grad(log_alpha, log_beta)[0]) + + log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha'] + log_beta = self.variational_parameters['global']['factor_precisions']['log_beta'] + factor_precisions_contrib = jnp.sum(factor_precisions_logq_val_and_grad(log_alpha, log_beta)[0]) + + return factor_weights_contrib + gene_scales_contrib + factor_precisions_contrib + + def compute_local_entropies(self, batch_indices): + mean = self.variational_parameters['local']['obs_weights']['mean'][batch_indices] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][batch_indices] + obs_weights_contrib = jnp.sum(obs_weights_logq_val_and_grad(mean, log_std)[0]) + + log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][batch_indices] + log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][batch_indices] + cell_scales_contrib = jnp.sum(cell_scales_logq_val_and_grad(log_alpha, log_beta)[0]) + return obs_weights_contrib + cell_scales_contrib + + def compute_kernel_prior(self): + parent = self.parent() + if parent is None: + return self.compute_root_prior() + + parent_state = self.parent().get_state_sample() + direction_samples = self.get_direction_sample() + + direction_shape = self.node_hyperparams['direction_shape'] + inheritance_strength = self.node_hyperparams['inheritance_strength'] + + direction_logpdf = mc_direction_logp_val_and_grad(direction_samples, parent_state, direction_shape, inheritance_strength)[0] + direction_logpdf = jnp.sum(direction_logpdf,axis=1) + + 
state_samples = self.get_state_sample() + state_logpdf = mc_state_logp_val_and_grad(state_samples, parent_state, direction_samples)[0] + state_logpdf = jnp.sum(state_logpdf, axis=1) + + return jnp.mean(direction_logpdf + state_logpdf) + + def compute_root_direction_prior(self, parent_state): + direction_samples = self.get_direction_sample() + direction_shape = self.node_hyperparams['direction_shape'] + inheritance_strength = self.node_hyperparams['inheritance_strength'] + return jnp.mean(jnp.sum(mc_direction_logp_val_and_grad(direction_samples, parent_state, direction_shape, inheritance_strength)[0], axis=1)) + + def compute_root_state_prior(self, parent_state): + direction_samples = jnp.sqrt(self.get_direction_sample()) + state_samples = self.get_state_sample() + return jnp.mean(jnp.sum(mc_state_logp_val_and_grad(state_samples, parent_state, direction_samples)[0], axis=1)) + + def compute_root_kernel_prior(self, samples): + parent_state = samples[0] + logp = self.compute_root_direction_prior(parent_state) + logp += self.compute_root_state_prior(parent_state) + return logp + + def compute_root_prior(self): + return 0. + + def compute_kernel_entropy(self): + parent = self.parent() + if parent is None: + return self.compute_root_entropy() + + # Direction + direction_logpdf = tfd.Gamma(np.exp(self.variational_parameters['kernel']['direction']['log_alpha']), + jnp.exp(self.variational_parameters['kernel']['direction']['log_beta']) + ).entropy() + direction_logpdf = jnp.sum(direction_logpdf) + + # State + state_logpdf = tfd.Normal(self.variational_parameters['kernel']['state']['mean'], + jnp.exp(self.variational_parameters['kernel']['state']['log_std']) + ).entropy() + state_logpdf = jnp.sum(state_logpdf) # Sum across features + + return direction_logpdf + state_logpdf + + def compute_root_entropy(self): + # In this model the root nodes have no unknown parameters + return 0. + + # ======== Functions for updating the variational parameters. ========= + def local_sample_and_grad(self, idx, key, n_samples): + """Sample and take gradient of local parameters. Must be root""" + mean = self.variational_parameters['local']['obs_weights']['mean'][idx] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx] + key, *sub_keys = jax.random.split(key, n_samples+1) + obs_weights_sample_grad = mc_sample_obs_weights_val_and_grad(jnp.array(sub_keys), mean, log_std) + obs_weights_sample = obs_weights_sample_grad[0] + # obs_weights_sample = obs_weights_sample.at[:,0,:].set(0.) + + log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + key, *sub_keys = jax.random.split(key, n_samples+1) + cell_scales_sample_grad = mc_sample_cell_scales_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta) + + sample = [obs_weights_sample, cell_scales_sample_grad[0]] + grad = [obs_weights_sample_grad[1], cell_scales_sample_grad[1]] + + return key, (sample, grad) + + def global_sample_and_grad(self, key, n_samples): + """Sample and take gradient of global parameters. 
Must be root""" + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + factor_weights_sample_grad = mc_sample_factor_weights_val_and_grad(jnp.array(sub_keys), mean, log_std) + factor_weights_sample = factor_weights_sample_grad[0] + # factor_weights_sample = factor_weights_sample.at[:,:,0].set(0.) + + log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha'] + log_beta = self.variational_parameters['global']['gene_scales']['log_beta'] + key, *sub_keys = jax.random.split(key, n_samples+1) + gene_scales_sample_grad = mc_sample_gene_scales_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta) + # gene_scales_sample = gene_scales_sample_grad[0] + # gene_scales_sample = gene_scales_sample.at[jnp.arange(n_samples),0].set(1.) + + log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha'] + log_beta = self.variational_parameters['global']['factor_precisions']['log_beta'] + key, *sub_keys = jax.random.split(key, n_samples+1) + factor_precisions_sample_grad = mc_sample_factor_precisions_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta) + + sample = [factor_weights_sample, gene_scales_sample_grad[0], factor_precisions_sample_grad[0]] + grad = [factor_weights_sample_grad[1], gene_scales_sample_grad[1], factor_precisions_sample_grad[1]] + + return key, (sample, grad) + + def compute_locals_prior_grad(self, sample): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance'])) + obs_weights_grad = mc_obs_weights_logp_val_and_grad(sample[0], 0., log_std)[1] + log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape']) + log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio) + cell_scales_grad = mc_cell_scales_logp_val_and_grad(sample[1], log_alpha, log_beta)[1] + return obs_weights_grad, cell_scales_grad + + def compute_globals_prior_grad(self, sample): + factor_weights_grad = mc_factor_weights_logp_val_and_grad(sample[0], 0., sample[2])[1] + log_alpha = jnp.log(self.node_hyperparams['gene_scale_shape']) + log_beta = jnp.log(self.node_hyperparams['gene_scale_shape'] * self.gene_ratio) + gene_scales_grad = mc_gene_scales_logp_val_and_grad(sample[1], log_alpha, log_beta)[1] + log_alpha = jnp.log(self.node_hyperparams['factor_precision_shape']) + log_beta = jnp.log(1.) 
+ factor_precisions_grad = mc_factor_precisions_logp_val_and_grad(sample[2], log_alpha, log_beta)[1] + factor_precisions_grad += mc_factor_weights_logp_val_and_grad_wrt_precisions(sample[0], 0., sample[2])[1] + return factor_weights_grad, gene_scales_grad, factor_precisions_grad + + def compute_locals_entropy_grad(self, idx): + mean = self.variational_parameters['local']['obs_weights']['mean'][idx] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx] + obs_weights_grad = obs_weights_logq_val_and_grad(mean, log_std)[1] + + log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + cell_scales_grad = cell_scales_logq_val_and_grad(log_alpha, log_beta)[1] + + return obs_weights_grad, cell_scales_grad + + def compute_globals_entropy_grad(self): + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + factor_weights_grad = factor_weights_logq_val_and_grad(mean, log_std)[1] + + log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha'] + log_beta = self.variational_parameters['global']['gene_scales']['log_beta'] + gene_scales_grad = gene_scales_logq_val_and_grad(log_alpha, log_beta)[1] + + log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha'] + log_beta = self.variational_parameters['global']['factor_precisions']['log_beta'] + factor_precisions_grad = factor_precisions_logq_val_and_grad(log_alpha, log_beta)[1] + + return factor_weights_grad, gene_scales_grad, factor_precisions_grad + + def state_sample_and_grad(self, key, n_samples): + """Sample and take gradient of state""" + mu = self.variational_parameters['kernel']['state']['mean'] + log_std = self.variational_parameters['kernel']['state']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_state_val_and_grad(jnp.array(sub_keys), mu, log_std) + + def direction_sample_and_grad(self, key, n_samples): + """Sample and take gradient of direction""" + log_alpha = self.variational_parameters['kernel']['direction']['log_alpha'] + log_beta = self.variational_parameters['kernel']['direction']['log_beta'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_direction_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta) + + def state_sample_and_grad(self, key, n_samples): + """Sample and take gradient of state""" + mu = self.variational_parameters['kernel']['state']['mean'] + log_std = self.variational_parameters['kernel']['state']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_state_val_and_grad(jnp.array(sub_keys), mu, log_std) + + def compute_direction_prior_grad(self, direction, parent_direction, parent_state): + """Gradient of logp(direction|parent_state) wrt this direction""" + direction_shape = self.node_hyperparams['direction_shape'] + inheritance_strength = self.node_hyperparams['inheritance_strength'] + return mc_direction_logp_val_and_grad(direction, parent_state, direction_shape, inheritance_strength)[1] + + def compute_direction_prior_child_grad_wrt_direction(self, child_direction, direction, state): + """Gradient of logp(child_alpha|alpha) wrt this direction""" + return 0. 
# no influence under this model + + def compute_direction_prior_child_grad_wrt_state(self, child_direction, direction, state): + """Gradient of logp(child_alpha|alpha) wrt this direction""" + direction_shape = self.node_hyperparams['direction_shape'] + inheritance_strength = self.node_hyperparams['inheritance_strength'] + return mc_direction_logp_val_and_grad_wrt_parent(child_direction, state, direction_shape, inheritance_strength)[1] + + def compute_state_prior_grad(self, state, parent_state, direction): + """Gradient of logp(state|parent_state,direction) wrt this psi""" + return mc_state_logp_val_and_grad(state, parent_state, direction)[1] + + def compute_state_prior_child_grad(self, child_state, state, child_direction): + """Gradient of logp(child_state|state,child_direction) wrt this state""" + return mc_state_logp_val_and_grad_wrt_parent(child_state, state, child_direction)[1] + + def compute_root_state_prior_child_grad(self, child_state, state, child_direction): + """Gradient of logp(child_psi|psi,child_alpha) wrt this psi""" + return self.compute_state_prior_child_grad(child_state, state, child_direction) + + def compute_state_prior_grad_wrt_direction(self, state, parent_state, direction): + """Gradient of logp(state|parent_state,direction) wrt this direction""" + return mc_state_logp_val_and_grad_wrt_direction(state, parent_state, direction)[1] + + def compute_direction_entropy_grad(self): + """Gradient of logq(alpha) wrt this alpha""" + log_alpha = self.variational_parameters['kernel']['direction']['log_alpha'] + log_beta = self.variational_parameters['kernel']['direction']['log_beta'] + return direction_logq_val_and_grad(log_alpha, log_beta)[1] + + def compute_state_entropy_grad(self): + """Gradient of logq(psi) wrt this psi""" + mu = self.variational_parameters['kernel']['state']['mean'] + log_std = self.variational_parameters['kernel']['state']['log_std'] + return state_logq_val_and_grad(mu, log_std)[1] + + def compute_ll_state_grad(self, x, weights, state): + """Gradient of logp(x|psi,noise) wrt this psi""" + obs_weights = self.get_obs_weights_sample() + factor_weights = self.get_factor_weights_sample() + gene_scales = self.get_gene_scales_sample() + cell_scales = self.get_cell_scales_sample() + cnvs = self.cnvs + return mc_ll_val_and_grad_state(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1] + + def compute_ll_state_grad_suff(self, state): + """Gradient of logp(x|state,noise) wrt this state using suff stats""" + cnv = self.cnvs + gene_scales = self.get_gene_scales_sample() + return mc_ll_state_suff_val_and_grad(state, cnv, gene_scales, + self.suff_stats['B_g']['total'], self.suff_stats['D_g']['total'])[1] + + def compute_ll_locals_grad(self, x, idx, weights): + """Gradient of logp(x|psi,locals,globals) wrt locals""" + state = self.get_state_sample() + obs_weights = self.get_obs_weights_sample()[:,idx] + factor_weights = self.get_factor_weights_sample() + gene_scales = self.get_gene_scales_sample() + cell_scales = self.get_cell_scales_sample()[:,idx] + cnvs = self.cnvs + obs_weights_grad = mc_ll_val_and_grad_obs_weights(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1] + cell_scales_grad = mc_ll_val_and_grad_cell_scales(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1] + return obs_weights_grad, cell_scales_grad + + def compute_ll_globals_grad(self, x, idx, weights): + """Gradient of logp(x|psi,locals,globals) wrt globals""" + state = self.get_state_sample() + obs_weights = 
self.get_obs_weights_sample()[:,idx] + factor_weights = self.get_factor_weights_sample() + gene_scales = self.get_gene_scales_sample() + cell_scales = self.get_cell_scales_sample()[:,idx] + cnvs = self.cnvs + factor_weights_grad = mc_ll_val_and_grad_factor_weights(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1] + gene_scales_grad = mc_ll_val_and_grad_gene_scales(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1] + return factor_weights_grad, gene_scales_grad + + def update_direction_params(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0) + direction_log_alpha_grad = mc_grad + direction_params_entropy_grad[0] + self.variational_parameters['kernel']['direction']['log_alpha'] += direction_log_alpha_grad * step_size + + mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0) + direction_log_beta_grad = mc_grad + direction_params_entropy_grad[1] + self.variational_parameters['kernel']['direction']['log_beta'] += direction_log_beta_grad * step_size + + self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=MAX_BETA) + + def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0) + loc_mean_grad = mc_grad + state_params_entropy_grad[0] + self.variational_parameters['kernel']['state']['mean'] += loc_mean_grad * step_size + + mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0) + loc_log_std_grad = mc_grad + state_params_entropy_grad[1] + self.variational_parameters['kernel']['state']['log_std'] += loc_log_std_grad * step_size + + self.variational_parameters['kernel']['state']['mean'] = self.apply_clip(self.variational_parameters['kernel']['state']['mean'], minval=-5., maxval=5.) + self.variational_parameters['kernel']['state']['log_std'] = self.apply_clip(self.variational_parameters['kernel']['state']['log_std'], minval=-5., maxval=5.) 
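# Illustrative sketch (not part of this patch): the update_* methods above all share the
# same reparameterized-gradient step. The per-sample gradient of the log-density with
# respect to the sampled value is combined with the gradient of the sample with respect
# to the variational parameter, averaged over the Monte Carlo samples, the entropy
# gradient is added, and a clipped gradient-ascent step is taken. The argument names
# below (logp_grad_wrt_sample, sample_grad_wrt_param, entropy_grad) are hypothetical
# stand-ins for the arrays those methods receive.
import jax.numpy as jnp

def ascent_step(param, logp_grad_wrt_sample, sample_grad_wrt_param, entropy_grad,
                step_size=1e-3, minval=-jnp.inf, maxval=jnp.inf):
    mc_grad = jnp.mean(logp_grad_wrt_sample * sample_grad_wrt_param, axis=0)  # MC average over samples
    return jnp.clip(param + step_size * (mc_grad + entropy_grad), minval, maxval)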
+ + def update_cell_scales_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001): + mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0] + new_param = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + param_grad * step_size + self.variational_parameters['local']['cell_scales']['log_alpha'] = self.variational_parameters['local']['cell_scales']['log_alpha'].at[idx].set(new_param) + + mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1] + new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + param_grad * step_size + self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param) + + self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA) + + def update_obs_weights_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001): + mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0] + new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + param_grad * step_size + self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param) + + mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1] + new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + param_grad * step_size + self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param) + + def update_local_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=.001, param_names=["obs_weights", "cell_scales"], **kwargs): + if param_names is None: + param_names=["obs_weights", "cell_scales"] + if "obs_weights" in param_names: + self.update_obs_weights_params(idx, local_params_grad[0], local_sample_grad[0], local_params_entropy_grad[0], ent_anneal=ent_anneal, step_size=step_size) + if "cell_scales" in param_names: + self.update_cell_scales_params(idx, local_params_grad[1], local_sample_grad[1], local_params_entropy_grad[1], ent_anneal=ent_anneal, step_size=step_size) + + def update_gene_scales_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + self.variational_parameters['global']['gene_scales']['log_alpha'] += param_grad * step_size + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + self.variational_parameters['global']['gene_scales']['log_beta'] += param_grad * step_size + + self.variational_parameters['global']['gene_scales']['log_alpha'] = 
self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA) + + def update_factor_weights_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + self.variational_parameters['global']['factor_weights']['mean'] += param_grad * step_size + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + self.variational_parameters['global']['factor_weights']['log_std'] += param_grad * step_size + + def update_factor_precisions_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + self.variational_parameters['global']['factor_precisions']['log_alpha'] += param_grad * step_size + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + self.variational_parameters['global']['factor_precisions']['log_beta'] += param_grad * step_size + + self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA) + + def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001, + param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs): + if param_names is None: + param_names=["factor_weights", "gene_scales", "factor_precisions"] + if "factor_weights" in param_names: + self.update_factor_weights_params(global_params_grad[0], global_sample_grad[0], global_params_entropy_grad[0], step_size=step_size) + if "gene_scales" in param_names: + self.update_gene_scales_params(global_params_grad[1], global_sample_grad[1], global_params_entropy_grad[1], step_size=step_size) + if "factor_precisions" in param_names: + self.update_factor_precisions_params(global_params_grad[2], global_sample_grad[2], global_params_entropy_grad[2], step_size=step_size) + + def initialize_global_opt_states(self, param_names=["factor_weights", "gene_scales", "factor_precisions"]): + states = dict() + if param_names is None: + param_names=["factor_weights", "gene_scales", "factor_precisions"] + if "factor_weights" in param_names: + factor_weights_states = self.initialize_factor_weights_states() + states["factor_weights"] = factor_weights_states + if "gene_scales" in param_names: + gene_scales_states = self.initialize_gene_scales_states() + states["gene_scales"] = gene_scales_states + if "factor_precisions" in param_names: + factor_precisions_states = self.initialize_factor_precisions_states() + states["factor_precisions"] = factor_precisions_states + return states + + def initialize_factor_weights_states(self): + n_factors = self.node_hyperparams['n_factors'] + m = jnp.zeros((n_factors,self.n_genes)) + v = jnp.zeros((n_factors,self.n_genes)) + state1 = (m,v) + m = jnp.zeros((n_factors,self.n_genes)) 
+ v = jnp.zeros((n_factors,self.n_genes)) + state2 = (m,v) + states = (state1, state2) + return states + + def initialize_gene_scales_states(self): + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state1 = (m,v) + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state2 = (m,v) + states = (state1, state2) + return states + + def initialize_factor_precisions_states(self): + n_factors = self.node_hyperparams['n_factors'] + m = jnp.zeros((n_factors,1)) + v = jnp.zeros((n_factors,1)) + state1 = (m,v) + m = jnp.zeros((n_factors,1)) + v = jnp.zeros((n_factors,1)) + state2 = (m,v) + states = (state1, state2) + return states + + def update_factor_weights_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + + m, v = states[0] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + self.variational_parameters['global']['factor_weights']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + + m, v = states[1] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + self.variational_parameters['global']['factor_weights']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + states = (state1, state2) + return states + + def update_gene_scales_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + + m, v = states[0] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + self.variational_parameters['global']['gene_scales']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + + m, v = states[1] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. 
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + self.variational_parameters['global']['gene_scales']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA) + + states = (state1, state2) + return states + + def update_factor_precisions_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + + m, v = states[0] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + self.variational_parameters['global']['factor_precisions']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + + m, v = states[1] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + self.variational_parameters['global']['factor_precisions']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA) + + states = (state1, state2) + return states + + def update_global_params_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001, param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs): + if param_names is None: + param_names=["factor_weights", "gene_scales", "factor_precisions"] + if "factor_weights" in param_names: + factor_weights_states = self.update_factor_weights_adaptive(global_params_grad[0], global_sample_grad[0], global_params_entropy_grad[0], + i=i, states=states["factor_weights"], b1=b1, b2=b2, eps=eps, step_size=step_size) + states["factor_weights"] = factor_weights_states + if "gene_scales" in param_names: + gene_scales_states = self.update_gene_scales_adaptive(global_params_grad[1], global_sample_grad[1], global_params_entropy_grad[1], + i=i, states=states["gene_scales"], b1=b1, b2=b2, eps=eps, step_size=step_size) + states["gene_scales"] = gene_scales_states + if "factor_precisions" in param_names: + factor_precisions_states = self.update_factor_precisions_adaptive(global_params_grad[2], global_sample_grad[2], global_params_entropy_grad[2], + i=i, states=states["factor_precisions"], b1=b1, b2=b2, eps=eps, step_size=step_size) + states["factor_precisions"] = factor_precisions_states + return 
states + + + def initialize_local_opt_states(self, param_names=["obs_weights", "cell_scales"]): + states = dict() + if param_names is None: + param_names=["obs_weights", "cell_scales"] + if "obs_weights" in param_names: + obs_weights_states = self.initialize_obs_weights_states() + states["obs_weights"] = obs_weights_states + if "cell_scales" in param_names: + cell_scales_states = self.initialize_cell_scales_states() + states["cell_scales"] = cell_scales_states + return states + + def initialize_obs_weights_states(self): + n_obs = self.tssb.ntssb.num_data + n_factors = self.node_hyperparams['n_factors'] + m = jnp.zeros((n_obs,n_factors)) + v = jnp.zeros((n_obs,n_factors)) + state1 = (m,v) + m = jnp.zeros((n_obs,n_factors)) + v = jnp.zeros((n_obs,n_factors)) + state2 = (m,v) + states = (state1, state2) + return states + + def initialize_cell_scales_states(self): + n_obs = self.tssb.ntssb.num_data + m = jnp.zeros((n_obs,1)) + v = jnp.zeros((n_obs,1)) + state1 = (m,v) + m = jnp.zeros((n_obs,1)) + v = jnp.zeros((n_obs,1)) + state2 = (m,v) + states = (state1, state2) + return states + + def update_obs_weights_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001, ent_anneal=1.): + """ + states are not indexed + """ + mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0] + + m, v = states[0] + new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate. + new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate. + m = m.at[idx].set(new_m) + v = v.at[idx].set(new_v) + mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps) + self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param) + + mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1] + + m, v = states[1] + new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate. + new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate. + m = m.at[idx].set(new_m) + v = v.at[idx].set(new_v) + mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps) + self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param) + + states = (state1, state2) + return states + + def update_cell_scales_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001, ent_anneal=1.): + """ + states are already indexed, as are the gradients + """ + mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0] + + m, v = states[0] + new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate. + new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate. 
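+ # Only the optimizer-state rows for the cells in the current minibatch are updated;
+ # the moment estimates are written back functionally with .at[idx].set(...).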
+ m = m.at[idx].set(new_m) + v = v.at[idx].set(new_v) + mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + new_param = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps) + self.variational_parameters['local']['cell_scales']['log_alpha'] = self.variational_parameters['local']['cell_scales']['log_alpha'].at[idx].set(new_param) + + mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1] + + m, v = states[1] + new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate. + new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate. + m = m.at[idx].set(new_m) + v = v.at[idx].set(new_v) + mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps) + self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param) + + self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA) + self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA) + + states = (state1, state2) + return states + + def update_local_params_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, ent_anneal=1., b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001, param_names=["obs_weights", "cell_scales"], **kwargs): + if param_names is None: + param_names = ["obs_weights", "cell_scales"] + + if "obs_weights" in param_names: + obs_weights_states = self.update_obs_weights_adaptive(idx, local_params_grad[0], local_sample_grad[0], local_params_entropy_grad[0], + i=i, states=states["obs_weights"], b1=b1, b2=b2, eps=eps, step_size=step_size, ent_anneal=ent_anneal) + states["obs_weights"] = obs_weights_states + if "cell_scales" in param_names: + cell_scales_states = self.update_cell_scales_adaptive(idx, local_params_grad[1], local_sample_grad[1], local_params_entropy_grad[1], + i=i, states=states["cell_scales"], b1=b1, b2=b2, eps=eps, step_size=step_size, ent_anneal=ent_anneal) + states["cell_scales"] = cell_scales_states + return states + + def initialize_state_states(self): + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state1 = (m,v) + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state2 = (m,v) + states = (state1, state2) + return states + + def update_state_adaptive(self, state_params_grad, state_sample_grad, state_params_entropy_grad, i, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001): + states = self.state_states + + mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0) + param_grad = mc_grad + state_params_entropy_grad[0] + + m, v = states[0] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. 
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + self.variational_parameters['kernel']['state']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0) + param_grad = mc_grad + state_params_entropy_grad[1] + + m, v = states[1] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state2 = (m, v) + self.variational_parameters['kernel']['state']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + states = (state1, state2) + self.state_states = states + + def initialize_direction_states(self): + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state1 = (m,v) + m = jnp.zeros((self.n_genes,)) + v = jnp.zeros((self.n_genes,)) + state2 = (m,v) + states = (state1, state2) + return states + + def update_direction_adaptive(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, i, b1=0.9, + b2=0.999, eps=1e-8, step_size=0.001): + states = self.direction_states + mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0) + param_grad = mc_grad + direction_params_entropy_grad[0] + + m, v = states[0] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. + vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1)) + state1 = (m, v) + self.variational_parameters['kernel']['direction']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps) + + mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0) + param_grad = mc_grad + direction_params_entropy_grad[1] + + m, v = states[1] + m = (1 - b1) * param_grad + b1 * m # First moment estimate. + v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate. + mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction. 
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['kernel']['direction']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ states = (state1, state2)
+ self.direction_states = states
\ No newline at end of file
diff --git a/scatrex/models/cna/node_opt.py b/scatrex/models/cna/node_opt.py
new file mode 100644
index 0000000..fac0fd3
--- /dev/null
+++ b/scatrex/models/cna/node_opt.py
@@ -0,0 +1,233 @@
+import jax
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.distributions as tfd
+
+@jax.jit
+def sample_direction(key, log_alpha, log_beta): # univariate: one sample
+ return jnp.maximum(jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), log_rate=log_beta).sample(seed=key)), 1e-6)
+sample_direction_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(sample_direction, argnums=(1,2)), in_axes=(None, 0, 0))) # per-dimension val and grad
+mc_sample_direction_val_and_grad = jax.jit(jax.vmap(sample_direction_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_state(key, mu, log_std): # univariate: one sample
+ return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_state_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(sample_state, argnums=(1,2)), in_axes=(None, 0, 0))) # per-dimension val and grad
+mc_sample_state_val_and_grad = jax.jit(jax.vmap(sample_state_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def direction_logp(this_direction, parent_state, direction_shape, inheritance_strength): # single sample
+ return tfd.Gamma(direction_shape, log_rate=inheritance_strength*jnp.abs(parent_state)).log_prob(this_direction)
+univ_direction_logp_val_and_grad = jax.jit(jax.value_and_grad(direction_logp, argnums=0)) # Take grad wrt to this
+direction_logp_val_and_grad = jax.jit(jax.vmap(univ_direction_logp_val_and_grad, in_axes=(0,0,None,None))) # Take grad wrt to this
+mc_direction_logp_val_and_grad = jax.jit(jax.vmap(direction_logp_val_and_grad, in_axes=(0,0,None,None))) # Multiple sample value_and_grad
+
+univ_direction_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(direction_logp, argnums=1)) # Take grad wrt to parent
+direction_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(univ_direction_logp_val_and_grad_wrt_parent, in_axes=(0,0,None,None))) # Take grad wrt to parent
+mc_direction_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(direction_logp_val_and_grad_wrt_parent, in_axes=(0,0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def direction_logq(log_alpha, log_beta):
+ return tfd.Gamma(jnp.exp(log_alpha), log_rate=log_beta).entropy()
+direction_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(direction_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters
+
+@jax.jit
+def state_logp(this_state, parent_state, this_direction): # single sample
+ return tfd.Normal(parent_state, this_direction).log_prob(this_state) # sum across dimensions
+state_logp_val = jax.jit(state_logp)
+mc_loc_logp_val = jax.jit(jax.vmap(state_logp_val, in_axes=(0,0,0))) # Multiple sample
+
+univ_state_logp_val_and_grad = jax.jit(jax.value_and_grad(state_logp, argnums=0)) # Take grad wrt to this
+state_logp_val_and_grad = jax.jit(jax.vmap(univ_state_logp_val_and_grad, in_axes=(0,0,0))) # Take grad wrt to this
+mc_state_logp_val_and_grad = jax.jit(jax.vmap(state_logp_val_and_grad, in_axes=(0,0,0))) # Multiple sample value_and_grad
+
+univ_state_logp_val_and_grad_wrt_parent =
jax.jit(jax.value_and_grad(state_logp, argnums=1)) # Take grad wrt to parent +state_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(univ_state_logp_val_and_grad_wrt_parent, in_axes=(0,0,0))) # Take grad wrt to parent +mc_state_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(state_logp_val_and_grad_wrt_parent, in_axes=(0,0,0))) # Multiple sample value_and_grad + +univ_state_logp_val_and_grad_wrt_direction = jax.jit(jax.value_and_grad(state_logp, argnums=2)) # Take grad wrt to angle +state_logp_val_and_grad_wrt_direction = jax.jit(jax.vmap(univ_state_logp_val_and_grad_wrt_direction, in_axes=(0,0,0))) # Take grad wrt to angle +mc_state_logp_val_and_grad_wrt_direction = jax.jit(jax.vmap(state_logp_val_and_grad_wrt_direction, in_axes=(0,0,0))) # Multiple sample value_and_grad + +@jax.jit +def state_logq(mu, log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +state_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(state_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters + + +# Noise + +@jax.jit +def sample_obs_weights(key, mean, log_std): # NxK + return tfd.Normal(mean, jnp.exp(log_std)).sample(seed=key) +sample_obs_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_obs_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_obs_weights_val_and_grad = jax.jit(jax.vmap(sample_obs_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + +@jax.jit +def sample_factor_weights(key, mu, log_std): # KxG + return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key) +sample_factor_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_factor_weights_val_and_grad = jax.jit(jax.vmap(sample_factor_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + + +@jax.jit +def obs_weights_logp(sample, mean, log_std): # single sample, NxK + return tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample) # sum across obs and dimensions +univ_obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt to sample (Nx1) +obs_weights_logp_val_and_grad = jax.jit(jax.vmap(jax.vmap(univ_obs_weights_logp_val_and_grad, in_axes=(0, None, None)), in_axes=(0,None,None))) # Take grad wrt to sample (NxK) +mc_obs_weights_logp_val_and_grad = jax.jit(jax.vmap(obs_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNxK + +@jax.jit +def obs_weights_logq(mu, log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +obs_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(obs_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + + +@jax.jit +def factor_weights_logp(sample, mean, precision): # single sample, KxG + return jnp.sum(tfd.Normal(mean, 1./jnp.sqrt(precision)).log_prob(sample)) # sum over 1 +univ_factor_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_weights_logp, argnums=0)) +factor_weights_logp_val_and_grad = jax.jit(jax.vmap(jax.vmap(univ_factor_weights_logp_val_and_grad, in_axes=(0,None,None)), in_axes=(0,None,0))) # Take grad wrt to sample (KxG) +mc_factor_weights_logp_val_and_grad = jax.jit(jax.vmap(factor_weights_logp_val_and_grad, in_axes=(0,None,0))) # Multiple sample value_and_grad: SxKxG + +@jax.jit +def factor_weights_logp_summed(sample, mean, precision): # single sample, KxG + return 
jnp.sum(tfd.Normal(mean, 1./jnp.sqrt(precision)).log_prob(sample)) # sum over genes +factor_weights_logp_val_and_grad_wrt_precisions = jax.jit(jax.value_and_grad(factor_weights_logp_summed, argnums=2)) +mc_factor_weights_logp_val_and_grad_wrt_precisions = jax.jit(jax.vmap(factor_weights_logp_val_and_grad_wrt_precisions, in_axes=(0,None,0))) # Multiple sample value_and_grad: SxKxG + +@jax.jit +def factor_weights_logq(mu, log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +factor_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + + +# Cell scales +@jax.jit +def sample_cell_scales(key, log_alpha, log_beta): # Nx1 + return jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key)) +sample_cell_scales_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_cell_scales, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_cell_scales_val_and_grad = jax.jit(jax.vmap(sample_cell_scales_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + +@jax.jit +def cell_scales_logp(sample, log_alpha, log_beta): # single sample, Nx1 + return jnp.sum(tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample)) # sum across obs and dimensions +univ_cell_scales_logp_val_and_grad = jax.jit(jax.value_and_grad(cell_scales_logp, argnums=0)) # Take grad wrt to sample (Nx1) +cell_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_cell_scales_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (Nx1) +mc_cell_scales_logp_val_and_grad = jax.jit(jax.vmap(cell_scales_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNx1 + +@jax.jit +def cell_scales_logq(log_alpha, log_beta): + return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy() +cell_scales_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(cell_scales_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + +# Gene scales +@jax.jit +def sample_gene_scales(key, log_alpha, log_beta): # G + return jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key)) +sample_gene_scales_val_and_grad = jax.vmap(jax.value_and_grad(sample_gene_scales, argnums=(1,2)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_gene_scales_val_and_grad = jax.jit(jax.vmap(sample_gene_scales_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + +@jax.jit +def gene_scales_logp(sample, log_alpha, log_beta): # single sample + return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample) # sum across obs and dimensions +univ_gene_scales_logp_val_and_grad = jax.jit(jax.value_and_grad(gene_scales_logp, argnums=0)) # Take grad wrt to sample (G,) +gene_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (G,) +mc_gene_scales_logp_val_and_grad = jax.jit(jax.vmap(gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxG + +@jax.jit +def gene_scales_logq(log_alpha, log_beta): + return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy() +gene_scales_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(gene_scales_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters + + +# Factor variances +@jax.jit +def sample_factor_precisions(key, log_alpha, log_beta): # Kx1 + return 
jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key)) +sample_factor_precisions_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_precisions, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_factor_precisions_val_and_grad = jax.jit(jax.vmap(sample_factor_precisions_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + +@jax.jit +def factor_precisions_logp(sample, log_alpha, log_beta): # single sample + return jnp.sum(tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample)) # sum across obs and dimensions +univ_factor_precisions_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_precisions_logp, argnums=0)) # Take grad wrt to sample (G,) +factor_precisions_logp_val_and_grad = jax.jit(jax.vmap(univ_factor_precisions_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (G,) +mc_factor_precisions_logp_val_and_grad = jax.jit(jax.vmap(factor_precisions_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxG + +@jax.jit +def factor_precisions_logq(log_alpha, log_beta): + return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy() +factor_precisions_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_precisions_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + + +@jax.jit +def _mc_obs_ll(obs, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # For each MC sample: NxG + m = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights)) + return jnp.sum(jax.vmap(jax.scipy.stats.poisson.logpmf, in_axes=[0, 0])(obs, m), axis=1) # sum over dimensions + +@jax.jit +def ll(x, weights, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # single sample + loc = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights)) + return jnp.sum(jnp.sum(tfd.Poisson(loc).log_prob(x),axis=1) * weights) +ll_val_and_grad_state = jax.jit(jax.value_and_grad(ll, argnums=2)) # Take grad wrt to psi +mc_ll_val_and_grad_state = jax.jit(jax.vmap(ll_val_and_grad_state, + in_axes=(None,None,0,None,0,0,0,0))) + +ll_val_and_grad_factor_weights = jax.jit(jax.value_and_grad(ll, argnums=5)) # Take grad wrt to factor_weights +mc_ll_val_and_grad_factor_weights = jax.jit(jax.vmap(ll_val_and_grad_factor_weights, + in_axes=(None,None,0,None,0,0,0,0))) + +ll_val_and_grad_cell_scales = jax.jit(jax.value_and_grad(ll, argnums=6)) # Take grad wrt to cell_scales +mc_ll_val_and_grad_cell_scales = jax.jit(jax.vmap(ll_val_and_grad_cell_scales, + in_axes=(None,None,0,None,0,0,0,0))) + +ll_val_and_grad_gene_scales = jax.jit(jax.value_and_grad(ll, argnums=7)) # Take grad wrt to gene_scales +mc_ll_val_and_grad_gene_scales = jax.jit(jax.vmap(ll_val_and_grad_gene_scales, + in_axes=(None,None,0,None,0,0,0,0))) + + +@jax.jit +def ll_obs(x, weight, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # single obs + loc = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights)) + return jnp.sum(tfd.Poisson(loc).log_prob(x)) * weight + +univ_ll_val_and_grad_obs_weights = jax.jit(jax.value_and_grad(ll_obs, argnums=4)) # Take grad wrt to obs_weights +ll_val_and_grad_obs_weights = jax.jit(jax.vmap(univ_ll_val_and_grad_obs_weights, in_axes=(0,0, None, None, 0, None, 0, None))) +mc_ll_val_and_grad_obs_weights = jax.jit(jax.vmap(ll_val_and_grad_obs_weights, + in_axes=(None,None,0,None,0,0,0,0))) + 
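# Illustrative sketch (not part of this module): all of the likelihood gradients above
# are built from the same Poisson rate, in which the copy number scales expression
# relative to the diploid baseline while the node state and the low-rank noise term act
# on the log scale. Assumed shapes: x and the rate are N x G, obs_weights is N x K,
# factor_weights is K x G, state and gene_scales are length G, cell_scales is N x 1.
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

def poisson_rate(state, cnv, obs_weights, factor_weights, cell_scales, gene_scales):
    return cell_scales * gene_scales * cnv / 2 * jnp.exp(state + obs_weights.dot(factor_weights))

def per_cell_loglik(x, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales):
    rate = poisson_rate(state, cnv, obs_weights, factor_weights, cell_scales, gene_scales)
    return jnp.sum(tfd.Poisson(rate).log_prob(x), axis=1)  # sum over genes, one value per cell
# In the ELBO these per-cell terms are then weighted by the cell-to-node responsibilities
# (the `weights` argument of ll and ll_obs) and summed over cells.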
+univ_ll_val_and_grad_cell_scales = jax.jit(jax.value_and_grad(ll_obs, argnums=6)) # Take grad wrt to cell_scales +ll_val_and_grad_cell_scales = jax.jit(jax.vmap(univ_ll_val_and_grad_cell_scales, in_axes=(0,0, None, None, 0, None, 0, None))) +mc_ll_val_and_grad_cell_scales = jax.jit(jax.vmap(ll_val_and_grad_cell_scales, + in_axes=(None,None,0,None,0,0,0,0))) + +@jax.jit +def ll_suffstats(state, cnv, gene_scales, A, B_g, C, D_g, E): # for a single state sample + """ + A: \sum_n q(z_n = this node) * \sum_g x_ng * E[log\gamma_n] + B_g: \sum_n q(z_n = this node) * x_ng + C: \sum_n q(z_n = this node) * \sum_g x_ng * E[(s_nW_g)] + D_g: \sum_n q(z_n = this node) * E[\gamma_n] * E[exp(s_nW_g)] + E: \sum_n q(z_n = this node) * lgamma(x_ng+1) + """ + ll = A + jnp.sum(B_g * (jnp.log(gene_scales) + jnp.log(cnv/2) + state)) + \ + C - jnp.sum(gene_scales * cnv/2 * jnp.exp(state) * D_g) - E + return ll + +@jax.jit +def ll_state_suff(state, cnv, gene_scales, B_g, D_g): # for a single state sample + """ + B_g: \sum_n q(z_n = this node) * x_ng + D_g: \sum_n q(z_n = this node) * E[\gamma_n] * E[s_nW_g] + """ + ll = jnp.sum(B_g * state) - jnp.sum(gene_scales * cnv/2 * jnp.exp(state) * D_g) + return ll + +ll_state_suff_val_and_grad = jax.jit(jax.value_and_grad(ll_state_suff, argnums=0)) # Take grad wrt to psi +mc_ll_state_suff_val_and_grad = jax.jit(jax.vmap(ll_state_suff_val_and_grad, + in_axes=(0,None,0,None, None))) + +# To get noise sample +sample_prod = jax.jit(lambda mat1, mat2: mat1.dot(mat2)) \ No newline at end of file diff --git a/scatrex/models/cna/opt_funcs.py b/scatrex/models/cna/opt_funcs.py deleted file mode 100644 index 64ca0ba..0000000 --- a/scatrex/models/cna/opt_funcs.py +++ /dev/null @@ -1,2111 +0,0 @@ -import jax -from jax import jit, grad, vmap -from jax import random -from jax.example_libraries import optimizers -import jax.numpy as jnp -import jax.nn as jnn - -from scatrex.util import * -from scatrex.callbacks import elbos_callback - -from jax.scipy.special import digamma, betaln, gammaln - -from functools import partial - -def loggaussian_ent(mean, std): - return jnp.log(std) + mean + 0.5 + 0.5*jnp.log(2*jnp.pi) - -def diag_loggaussian_ent(mean, log_std, axis=None): - return jnp.sum(vmap(loggaussian_ent)(mean, jnp.exp(log_std)), axis=axis) - - -def complete_elbo(self, rng, mc_samples=3): - rngs = random.split(rng, mc_samples) - - # Get data - data = self.data - lib_sizes = self.root["node"].root["node"].lib_sizes - n_cells, n_genes = self.data.shape - - # Get global variational parameters - log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"] - log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"] - cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"] - cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"] - noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"] - noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"] - - # Sample global - baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) - cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) - noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std) - - def sub_descend(root, depth=0): - elbo = 0 - subtree_weight 
= 0 - - if root["node"].parent() is None: - parent_unobserved_samples = jnp.zeros((mc_samples, n_genes)) - unobserved_samples = jnp.zeros((mc_samples, n_genes)) - unobserved_kernel_samples = jnp.zeros((mc_samples, n_genes)) - else: - # Sample parent - parent_unobserved_means = root["node"].parent().variational_parameters["locals"]["unobserved_factors_mean"] - parent_unobserved_log_stds = root["node"].parent().variational_parameters["locals"]["unobserved_factors_log_std"] - parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds) - - # Sample node - unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"] - unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] - unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds) - unobserved_factors_kernel_log_mean = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] - unobserved_factors_kernel_log_std = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"] - unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std) - - psi_not_prev_sum = 0 - for i, child in enumerate(root["children"]): - nu_alpha = child["node"].variational_parameters["locals"]["nu_log_mean"] - nu_beta = child["node"].variational_parameters["locals"]["nu_log_std"] - psi_alpha = child["node"].variational_parameters["locals"]["psi_log_mean"] - psi_beta = child["node"].variational_parameters["locals"]["psi_log_std"] - - # Compute local exact KL divergence term - child["node"].stick_kl = beta_kl(nu_alpha, nu_beta, 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha) + beta_kl(psi_alpha, psi_beta, 1, root["node"].tssb.dp_gamma) - - # Compute expected local NTSSB weight term - E_log_phi = E_q_log_beta(psi_alpha, psi_beta) + psi_not_prev_sum - E_log_nu = E_q_log_beta(nu_alpha, nu_beta) - E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta) - - child["node"].weight_until_here = child["node"].data_weights + root["node"].weight_until_here - child["node"].expected_weights = child["node"].data_weights*E_log_nu + root["node"].weight_until_here*E_log_1_nu + child["node"].weight_until_here*E_log_phi - - if i > 0: - psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta) - - # Go down in the tree - elbo, subtree_weight = sub_descend(child, depth=depth+1) - - - # Compute local approximate expected log likelihood term - root["node"].ell = ell(root["node"].data_weights, baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples, root["node"].cnvs, lib_sizes, data) - - if root["node"].parent() is None: - root["node"].param_kl = 0. 
- else: - # Compute local approximate KL divergence term - root["node"].param_kl = local_paramkl(parent_unobserved_samples, unobserved_samples, unobserved_kernel_samples, - unobserved_means, - unobserved_log_stds, - unobserved_factors_kernel_log_mean, - unobserved_factors_kernel_log_std, - jnp.array([0.1, 1.]), jnp.array(0.)) - - # If this node is the root of its subtree, compute expected weights here - if depth == 0: - nu_alpha = child["node"].variational_parameters["locals"]["nu_log_mean"] - nu_beta = child["node"].variational_parameters["locals"]["nu_log_std"] - psi_alpha = child["node"].variational_parameters["locals"]["psi_log_mean"] - psi_beta = child["node"].variational_parameters["locals"]["psi_log_std"] - - E_log_nu = E_q_log_beta(nu_alpha, nu_beta) - root["node"].expected_weights = root["node"].data_weights*E_log_nu - - # Compute local exact KL divergence term - root["node"].stick_kl = beta_kl(nu_alpha, nu_beta, 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha) + beta_kl(psi_alpha, psi_beta, 1, root["node"].tssb.dp_gamma) - - - # Compute local ELBO quantities - root["node"].local_elbo = root["node"].ell + root["node"].expected_weights - root["node"].data_weights*np.log(root["node"].data_weights) - root["node"].local_elbo = root["node"].local_elbo - root["node"].param_kl - root["node"].stick_kl - - # Remove root contribution - elbo = elbo + root["node"].local_elbo - - # Track total weight in the subtree - subtree_weight += root["node"].data_weights - - return elbo, subtree_weight - - def descend(super_root, elbo=0): - subtree_elbo, subtree_weights = sub_descend(super_root["node"].root) - elbo += np.sum(subtree_elbo + subtree_weights * super_root["node"].weight) - for super_child in super_root["children"]: - elbo += descend(super_child) - return elbo - - elbo = descend(self.root) - - # Compute global KL - global_kl = jnp.sum(baseline_kl(log_baseline_mean, log_baseline_log_std)) - global_kl += jnp.sum(noise_factors_kl(noise_factors_mean, noise_factors_log_std)) - global_kl += jnp.sum(cell_noise_kl(cell_noise_mean, cell_noise_log_std)) - - # Add to ELBO - elbo = elbo - global_kl - self.root["node"].root["node"].local_elbo - - return elbo - -@jax.jit -def beta_kl(a1, b1, a2, b2): - def logbeta_func(a,b): - return gammaln(a) + gammaln(b) - gammaln(a+b) - - kl = logbeta_func(a1,b1) - logbeta_func(a2,b2) - kl += (a1-a2)*digamma(a1) - kl += (b1-b2)*digamma(b1) - kl += (a2-a1 + b2-b1)*digamma(a1+b1) - return kl - -@jax.jit -def E_q_log_beta( - alpha, - beta, -): - return digamma(alpha) - digamma(alpha + beta) - -@jax.jit -def E_q_log_1_beta( - alpha, - beta, -): - return digamma(beta) - digamma(alpha + beta) - -# This computes E_q[p(z_n=\epsilon | \nu, \psi)] -@jax.jit -def compute_expected_weight( - nu_alpha, - nu_beta, - psi_alpha, - psi_beta, - prev_nu_sticks_sum, - prev_psi_sticks_sum, -): - nu_digamma_sum = digamma(nu_alpha + nu_beta) - E_log_nu = digamma(nu_alpha) - nu_digamma_sum - nu_sticks_sum = digamma(nu_beta) - nu_digamma_sum + prev_nu_sticks_sum - psi_sticks_sum = local_psi_sticks_sum(psi_alpha, psi_beta) + prev_psi_sticks_sum - weight = E_log_nu + nu_sticks_sum + psi_sticks_sum - return weight, nu_sticks_sum, psi_sticks_sum - -@jax.jit -def sample_baseline( - rngs, - log_baseline_mean, - log_baseline_log_std, -): - def _sample(rng, log_baseline_mean, log_baseline_log_std): - return jnp.append(1, jnp.exp(diag_gaussian_sample(rng, log_baseline_mean, log_baseline_log_std))) - vectorized_sample = vmap(_sample, in_axes=[0, None, None]) - return 
vectorized_sample(rngs, log_baseline_mean, log_baseline_log_std) - -@jax.jit -def sample_cell_noise_factors( - rngs, - cell_noise_mean, - cell_noise_log_std, -): - def _sample(rng, cell_noise_mean, cell_noise_log_std): - return diag_gaussian_sample(rng, cell_noise_mean, cell_noise_log_std) - vectorized_sample = vmap(_sample, in_axes=[0, None, None]) - return vectorized_sample(rngs, cell_noise_mean, cell_noise_log_std) - -@jax.jit -def sample_noise_factors( - rngs, - noise_factors_mean, - noise_factors_log_std, -): - def _sample(rng, noise_factors_mean, noise_factors_log_std): - return diag_gaussian_sample(rng, noise_factors_mean, noise_factors_log_std) - vectorized_sample = vmap(_sample, in_axes=[0, None, None]) - return vectorized_sample(rngs, noise_factors_mean, noise_factors_log_std) - -@jax.jit -def sample_unobserved( - rngs, - unobserved_means, - unobserved_log_stds, -): - def _sample(rng, unobserved_means, unobserved_log_stds): - return diag_gaussian_sample(rng, unobserved_means, unobserved_log_stds) - vectorized_sample = vmap(_sample, in_axes=[0, None, None]) - return vectorized_sample(rngs, unobserved_means, unobserved_log_stds) - -@jax.jit -def sample_unobserved_kernel( - rngs, - log_unobserved_factors_kernel_means, - log_unobserved_factors_kernel_log_stds, -): - def _sample(rng, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds): - return jnp.exp(diag_gaussian_sample(rng, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds)) - vectorized_sample = vmap(_sample, in_axes=[0, None, None]) - return vectorized_sample(rngs, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds) - -@jax.jit -def baseline_kl( - log_baseline_mean, - log_baseline_log_std, -): - std = jnp.exp(log_baseline_log_std) - return -log_baseline_log_std + .5 * (std**2 + log_baseline_mean**2 - 1.) - - -@jax.jit -def cell_noise_kl( - cell_noise_mean, - cell_noise_log_std, -): - std = jnp.exp(cell_noise_log_std) - return -cell_noise_log_std + .5 * (std**2 + cell_noise_mean**2 - 1.) - - -@jax.jit -def noise_factors_kl( - noise_factors_mean, - noise_factors_log_std, -): - std = jnp.exp(noise_factors_log_std) - return -noise_factors_log_std + .5 * (std**2 + noise_factors_mean**2 - 1.) 
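# Note: baseline_kl, cell_noise_kl and noise_factors_kl above all evaluate the
# same elementwise closed form, the KL divergence between the variational
# Gaussian q = N(mean, exp(log_std)^2) and a standard-normal prior:
#   KL( N(mu, sigma^2) || N(0, 1) ) = -log(sigma) + (sigma^2 + mu^2 - 1) / 2.
# A minimal standalone sketch of that identity (illustrative only; the helper
# name below is hypothetical and not part of the codebase):
def _std_normal_kl(mean, log_std):
    """Elementwise KL( N(mean, exp(log_std)**2) || N(0, 1) )."""
    std = jnp.exp(log_std)
    return -log_std + 0.5 * (std ** 2 + mean ** 2 - 1.0)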
- -@jax.jit -def ell( - data_node_weights, # lambda_{nk} with only the node attachments - baseline_samples, - cell_noise_samples, - noise_factor_samples, - unobserved_samples, # N vector corresponding to node attachments - cnv, # N vector corresponding to node attachments - lib_sizes, - data, - mask, -): - def _ell(data_node_weights, baseline_sample, cell_noise_sample, noise_factor_sample, unobserved_sample): - noise_sample = cell_noise_sample.dot(noise_factor_sample) - node_mean = ( - baseline_sample * cnv/2 * jnp.exp(unobserved_sample + noise_sample) - ) - sum = jnp.sum(node_mean, axis=1).reshape(-1, 1) - node_mean = node_mean / sum - node_mean = node_mean * lib_sizes - pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean) - ell = jnp.sum(pll, axis=1) # N-vector - ell *= data_node_weights - ell *= mask - return ell - vectorized = vmap(_ell, in_axes=[None, 0,0,0,0]) - return jnp.mean(vectorized(data_node_weights, baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples), axis=0) - - -@jax.jit -def ll( - baseline_samples, - cell_noise_samples, - noise_factor_samples, - unobserved_samples, - cnv, - lib_sizes, - data, -): - def _ll(baseline_sample, cell_noise_sample, noise_factor_sample, unobserved_sample): - noise_sample = cell_noise_sample.dot(noise_factor_sample) - node_mean = ( - baseline_sample * cnv/2 * jnp.exp(unobserved_sample + noise_sample) - ) - sum = jnp.sum(node_mean, axis=1).reshape(-1, 1) - node_mean = node_mean / sum - node_mean = node_mean * lib_sizes - pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean) - ll = jnp.sum(pll, axis=1) # N-vector - return ll - vectorized = vmap(_ll, in_axes=[0,0,0,0]) - return jnp.mean(vectorized(baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples), axis=0) - - - -@jax.jit -def local_paramkl( - parent_unobserved_samples, - child_unobserved_samples, - child_unobserved_kernel_samples, - child_unobserved_means, - child_unobserved_log_stds, - child_log_unobserved_factors_kernel_means, - child_log_unobserved_factors_kernel_log_stds, - hyperparams, # [concentration, rate] - child_in_root_subtree, # to make the prior prefer amplifications -): - def _local_paramkl( - parent_unobserved_sample, - child_unobserved_sample, - child_unobserved_kernel_sample, - child_unobserved_means, - child_unobserved_log_stds, - child_log_unobserved_factors_kernel_means, - child_log_unobserved_factors_kernel_log_stds, - hyperparams, - ): - kl = 0.0 - # Kernel - pl = diag_gamma_logpdf( - child_unobserved_kernel_sample, - jnp.log(hyperparams[0] * jnp.ones((child_unobserved_kernel_sample.shape[0],))), - (hyperparams[1]*jnp.abs(parent_unobserved_sample)), - ) - ent = -diag_loggaussian_logpdf( - child_unobserved_kernel_sample, - child_log_unobserved_factors_kernel_means, - child_log_unobserved_factors_kernel_log_stds, - ) -# ent = diag_loggaussian_ent(child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds) - kl += pl + ent - - # Cell state - pl = diag_gaussian_logpdf( - child_unobserved_sample, - parent_unobserved_sample, - jnp.log(0.001 + child_unobserved_kernel_sample), - ) - ent = -diag_gaussian_logpdf( - child_unobserved_sample, child_unobserved_means, child_unobserved_log_stds - ) - kl += pl + ent - - return kl - vectorized = vmap(_local_paramkl, in_axes=[0,0,0,None,None,None,None,None]) - return jnp.mean(vectorized(parent_unobserved_samples, child_unobserved_samples, - child_unobserved_kernel_samples, child_unobserved_means, child_unobserved_log_stds, - 
child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds, hyperparams)) - -@jax.jit -def update_local_parameters( - rngs, - child_unobserved_means, - child_unobserved_log_stds, - child_log_unobserved_factors_kernel_means, - child_log_unobserved_factors_kernel_log_stds, - data_node_weights, - parent_unobserved_samples, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - hyperparams, # [concentration, rate] - child_in_root_subtree, # to make the prior prefer amplifications - cnv, - lib_sizes, - data, - mask, - states, - i, - mb_scaling=1, - lr=0.01 - ): - - def local_loss( - params, - parent_unobserved_samples, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - hyperparams, - child_in_root_subtree - ): - child_unobserved_means = params[0] - child_unobserved_log_stds = params[1] - child_log_unobserved_factors_kernel_means = params[2] - child_log_unobserved_factors_kernel_log_stds = params[3] - - child_unobserved_kernel_samples = sample_unobserved_kernel(rngs, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds) - child_unobserved_samples = sample_unobserved(rngs, child_unobserved_means, child_unobserved_log_stds) - - child_unobserved_kernel_samples = jnp.clip(child_unobserved_kernel_samples, a_min=1e-8, a_max=5) - - loss = 0. - - loss = jnp.sum(ell(data_node_weights, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - child_unobserved_samples, - cnv, - lib_sizes, - data, - mask,)) - kl = local_paramkl(parent_unobserved_samples, - child_unobserved_samples, - child_unobserved_kernel_samples, - child_unobserved_means, - child_unobserved_log_stds, - child_log_unobserved_factors_kernel_means, - child_log_unobserved_factors_kernel_log_stds, - hyperparams, # [concentration, rate] - child_in_root_subtree,) - # Scale by minibatch - loss = loss + kl*mb_scaling - return loss - - child_unobserved_log_stds = jnp.clip(child_unobserved_log_stds, a_min=jnp.log(1e-8)) - child_log_unobserved_factors_kernel_means = jnp.clip(child_log_unobserved_factors_kernel_means, a_min=jnp.log(1e-8), a_max=0.) - child_log_unobserved_factors_kernel_log_stds = jnp.clip(child_log_unobserved_factors_kernel_log_stds, a_min=jnp.log(1e-8), a_max=0.) - - params = jnp.array([child_unobserved_means, child_unobserved_log_stds, - child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds]) - loss, grads = jax.value_and_grad(local_loss)(params, parent_unobserved_samples, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - hyperparams, - child_in_root_subtree) - - state1, state2, state3, state4 = states - - - - m3, v3 = state3 - b1=0.9 - b2=0.999 - eps=1e-8 - m3 = (1 - b1) * grads[2] + b1 * m3 # First moment estimate. - v3 = (1 - b2) * jnp.square(grads[2]) + b2 * v3 # Second moment estimate. - state3 = (m3, v3) - mhat = m3 / (1 - jnp.asarray(b1, m3.dtype) ** (i + 1)) # Bias correction. - vhat = v3 / (1 - jnp.asarray(b2, m3.dtype) ** (i + 1)) - child_log_unobserved_factors_kernel_means = child_log_unobserved_factors_kernel_means + lr * mhat / (jnp.sqrt(vhat) + eps) - - - m4, v4 = state4 - b1=0.9 - b2=0.999 - eps=1e-8 - m4 = (1 - b1) * grads[3] + b1 * m4 # First moment estimate. - v4 = (1 - b2) * jnp.square(grads[3]) + b2 * v4 # Second moment estimate. - state4 = (m4, 4) - mhat = m4 / (1 - jnp.asarray(b1, m4.dtype) ** (i + 1)) # Bias correction. 
- vhat = v4 / (1 - jnp.asarray(b2, m4.dtype) ** (i + 1)) - child_log_unobserved_factors_kernel_log_stds = child_log_unobserved_factors_kernel_log_stds + lr * mhat / (jnp.sqrt(vhat) + eps) - - m1, v1 = state1 - b1=0.9 - b2=0.999 - eps=1e-8 - m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate. - v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate. - state1 = (m1, v1) - mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction. - vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1)) - child_unobserved_means = child_unobserved_means + lr * mhat / (jnp.sqrt(vhat) + eps) - - m2, v2 = state2 - m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate. - v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate. - state2 = (m2, v2) - mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction. - vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1)) - child_unobserved_log_stds = child_unobserved_log_stds + lr * mhat / (jnp.sqrt(vhat) + eps) - - - - states = (state1, state2, state3, state4) - - - return loss, states, child_unobserved_means, child_unobserved_log_stds, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds - -@jax.jit -def baseline_node_grad( - rngs, - log_baseline_mean, - log_baseline_log_std, - data_node_weights, - child_unobserved_samples, - cell_noise_samples, - noise_factor_samples, - cnv, - lib_sizes, - data, - mask,): - def local_loss( - params, - child_unobserved_samples, - cell_noise_samples, - noise_factor_samples, - ): - log_baseline_mean, log_baseline_log_std = params[0], params[1] - baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) - - loss = jnp.sum(ell(data_node_weights, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - child_unobserved_samples, - cnv, - lib_sizes, - data, - mask,)) - return loss - - params = jnp.array([log_baseline_mean, log_baseline_log_std]) - grads = jax.grad(local_loss)(params, child_unobserved_samples, cell_noise_samples, noise_factor_samples,) - return grads - -@jax.jit -def baseline_kl_grad(mean, log_std): - def _kl(mean, log_std): - return jnp.sum(baseline_kl(mean, log_std)) - return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std)) - -@jax.jit -def baseline_step(log_baseline_mean, log_baseline_log_std, grads, states, i, lr=0.01): - state1, state2 = states - - m1, v1 = state1 - b1=0.9 - b2=0.999 - eps=1e-8 - m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate. - v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate. - state1 = (m1, v1) - mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction. - vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1)) - log_baseline_mean = log_baseline_mean + lr * mhat / (jnp.sqrt(vhat) + eps) - - m2, v2 = state2 - m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate. - v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate. - state2 = (m2, v2) - mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction. 
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1)) - log_baseline_log_std = log_baseline_log_std + lr * mhat / (jnp.sqrt(vhat) + eps) - - state = (state1, state2) - - return state, log_baseline_mean, log_baseline_log_std - -@jax.jit -def noise_node_grad( - rngs, - noise_factors_mean, - noise_factors_log_std, - data_node_weights, - child_unobserved_samples, - cell_noise_samples, - baseline_samples, - cnv, - lib_sizes, - data, - mask,): - def local_loss( - params, - child_unobserved_samples, - cell_noise_samples, - baseline_samples, - ): - noise_factors_mean, noise_factors_log_std = params[0], params[1] - noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std) - - loss = jnp.sum(ell(data_node_weights, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - child_unobserved_samples, - cnv, - lib_sizes, - data, - mask,)) - return loss - - params = jnp.array([noise_factors_mean, noise_factors_log_std]) - grads = jnp.array(jax.grad(local_loss)(params, child_unobserved_samples, cell_noise_samples, baseline_samples,)) - return grads - -@jax.jit -def noise_kl_grad(mean, log_std): - def _kl(mean, log_std): - return jnp.sum(noise_factors_kl(mean, log_std)) - return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std)) - -@jax.jit -def noise_step(noise_factors_mean, noise_factors_log_std, grads, states, i, lr=0.01): - state1, state2 = states - - m1, v1 = state1 - b1=0.9 - b2=0.999 - eps=1e-8 - m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate. - v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate. - state1 = (m1, v1) - mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction. - vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1)) - noise_factors_mean = noise_factors_mean + lr * mhat / (jnp.sqrt(vhat) + eps) - - m2, v2 = state2 - m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate. - v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate. - state2 = (m2, v2) - mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction. 
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1)) - noise_factors_log_std = noise_factors_log_std + lr * mhat / (jnp.sqrt(vhat) + eps) - - state = (state1, state2) - - return state, noise_factors_mean, noise_factors_log_std - - -@jax.jit -def cellnoise_node_grad( - rngs, - cell_noise_mean, - cell_noise_log_std, - data_node_weights, - child_unobserved_samples, - noise_factor_samples, - baseline_samples, - cnv, - lib_sizes, - data, - mask,): - def local_loss( - params, - child_unobserved_samples, - noise_factor_samples, - baseline_samples, - ): - cell_noise_mean, cell_noise_log_std = params[0], params[1] - cell_noise_samples = sample_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) - - loss = jnp.sum(ell(data_node_weights, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - child_unobserved_samples, - cnv, - lib_sizes, - data, - mask,)) - return loss - - params = jnp.array([cell_noise_mean, cell_noise_log_std]) - grads = jnp.array(jax.grad(local_loss)(params, child_unobserved_samples, noise_factor_samples, baseline_samples,)) - return grads - -@jax.jit -def cellnoise_kl_grad(mean, log_std): - def _kl(mean, log_std): - return jnp.sum(cell_noise_kl(mean, log_std)) - return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std)) - -@jax.jit -def cellnoise_step(cell_noise_mean, cell_noise_log_std, grads, states, i, lr=0.01): - state1, state2 = states - - m1, v1 = state1 - b1=0.9 - b2=0.999 - eps=1e-8 - m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate. - v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate. - state1 = (m1, v1) - mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction. - vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1)) - cell_noise_mean = cell_noise_mean + lr * mhat / (jnp.sqrt(vhat) + eps) - - m2, v2 = state2 - m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate. - v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate. - state2 = (m2, v2) - mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction. 
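# Note: update_local_parameters, baseline_step, noise_step and cellnoise_step
# all inline the same hand-rolled Adam moment update (b1=0.9, b2=0.999,
# eps=1e-8), applied as gradient *ascent* on the ELBO. A single shared helper,
# sketched below purely for illustration (adam_ascent_step is a hypothetical
# name, not part of the codebase), captures the repeated block and avoids
# copy-paste slips such as `state4 = (m4, 4)` above, where `(m4, v4)` is
# evidently intended:
# def adam_ascent_step(param, grad, state, i, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
#     m, v = state
#     m = (1 - b1) * grad + b1 * m               # first moment estimate
#     v = (1 - b2) * jnp.square(grad) + b2 * v   # second moment estimate
#     mhat = m / (1 - b1 ** (i + 1))             # bias correction
#     vhat = v / (1 - b2 ** (i + 1))
#     return param + lr * mhat / (jnp.sqrt(vhat) + eps), (m, v)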
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1)) - cell_noise_log_std = cell_noise_log_std + lr * mhat / (jnp.sqrt(vhat) + eps) - - state = (state1, state2) - - return state, cell_noise_mean, cell_noise_log_std - - -# -# @jax.jit -# def update_global_parameters( -# hyperparams, # [concentration, rate] -# child_in_root_subtree, # to make the prior prefer amplifications -# cnv, -# lib_sizes, -# data, -# parent_unobserved_sample, -# child_unobserved_sample, -# child_unobserved_kernel_sample, -# ): -# def local_loss(local_params): -# return ell(local_params) + baseline_kl(baseline) -# return loss -# -# -# return child_unobserved_means, child_unobserved_log_stds, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds - -# -# def tree_traversal_compute_elbo(rng, mc_samples=3): -# """ -# This function traverses the tree starting at the root and computes the -# complete ELBO -# """ -# rngs = random.split(rng, mc_samples) -# -# # Get data -# data = self.data -# lib_sizes = self.root["node"].root["node"].lib_sizes -# n_cells, n_genes = self.data.shape -# -# # Get global variational parameters -# log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"] -# log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"] -# cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"] -# cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"] -# noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"] -# noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"] -# -# # Sample global -# baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) -# cell_noise_factors_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) -# noise_factors_samples = sample_noise_factors(rng, noise_factors_mean, noise_factors_log_std) -# noise_samples = cell_noise_factors_samples.dot(noise_factors_samples) -# -# # Traverse tree -# def descend(root): -# total_subtree_weight = 0 -# indices = list(range(len(root["children"]))) -# # indices = indices[::-1] -# -# parent_unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"] -# parent_unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] -# # Sample parent -# parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds) -# psi_not_prev_sum = 0 -# for i in indices: -# child = root["children"][i] -# cnv = child.cnvs * jnp.ones((n_cells,n_genes)) -# data_node_weights = child.data_weights -# -# # Sample node -# unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds) -# unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_means, unobserved_factors_kernel_log_stds) -# -# # Compute node ll -# unobserved_samples = unobserved_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast -# unobserved_kernel_samples = unobserved_kernel_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast -# ll = ell(data_node_weights, baseline_samples, noise_samples, unobserved_samples, cnv, lib_sizes, data) -# child.ell = ll -# child.param_kl = local_paramkl(unobserved_samples, unobserved_kernel_samples, parent_unobserved_samples) -# 
-# E_log_phi = E_q_log_beta(psi_alpha, psi_beta) + psi_not_prev_sum -# E_log_nu = E_q_log_beta(nu_alpha, nu_beta) -# E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta) -# -# child.weight_until_here = child.data_weights + root.weight_until_here -# child.expected_weights = child.data_weights*E_log_nu + root.weight_until_here*E_log_1_nu + child.weight_until_here*E_log_phi -# total_weight += child.data_weights -# -# if i > 0: -# psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta) -# -# node.stick_kl = beta_kl(root["node"].variational_parameters["locals"]["nu_log_mean"], 1, root["node"].variational_parameters["locals"]["nu_log_mean"], root.tssb.dp_alpha) -# node.stick_kl += beta_kl(root["node"].variational_parameters["locals"]["psi_log_mean"], 1, root["node"].variational_parameters["locals"]["psi_log_mean"], root.tssb.dp_gamma) -# -# expected_weight, nu_sticks_sum, psi_sticks_sum = compute_expected_weight(nu_alpha, nu_beta, psi_alpha, psi_beta, prev_nu_sticks_sum, prev_psi_sticks_sum) -# node.ew = expected_weight -# -# return total_weight, lls, expected_weight, nodes -# -# _, lls, expected_w, nodes = descend(self.root) -# -# # Normalize data_node_weights and compute local ELBO contributions -# data_weights = [] -# for node in nodes: -# data_weights.append(node.data_weights) -# data_weights = np.array(data_weights).reshape(-1,mb_size)/np.sum(data_weights) -# elbo = 0 -# for i, node in nodes: -# node.data_weights = data_weights[data_indices,i] -# # Compute ELBO using normalized data_weights -# node_data_elbo_contributions = node.data_weights*node.ell + node.ew - node.data_weights*np.log(node.data_weights)) -# node_kl = node.stick_kl + node.param_kl -# elbo += node_data_elbo_contributions - node_kl -# -# return elbo -# -# -# def tree_traversal_update(root, rng, update_global=True, n_inner_steps=10, mc_samples=3): -# """ -# This function traverses the tree starting at `root` and updates the -# variational parameters and ELBO contributions while doing it -# """ -# rngs = random.split(rng, mc_samples) -# -# # Get data -# data = self.data -# lib_sizes = self.root["node"].root["node"].lib_sizes -# n_cells, n_genes = self.data.shape -# -# # Get global variational parameters -# log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"] -# log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"] -# cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"] -# cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"] -# noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"] -# noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"] -# -# # Sample global -# baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) -# cell_noise_factors_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) -# noise_factors_samples = sample_noise_factors(rng, noise_factors_mean, noise_factors_log_std) -# noise_samples = cell_noise_factors_samples.dot(noise_factors_samples) -# -# # Traverse tree and update variational parameters -# def descend(root, depth=0): -# weight_down = 0 -# indices = list(range(len(root["children"]))) -# indices = indices[::-1] -# -# parent_unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"] -# 
parent_unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] -# for i in indices: -# child = root["children"][i] -# cnv = child.cnvs * jnp.ones((n_cells,n_genes)) -# data_node_weights = child.data_weights -# -# # Sample parent -# parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds) -# -# # Update local parameters -# unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"] -# unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] -# unobserved_factors_kernel_log_mean = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] -# unobserved_factors_kernel_log_std = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"] -# update_local_parameters(baseline_samples, noise_samples, parent_unobserved_samples, steps=n_inner_steps) -# -# # Sample updated node -# unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds) -# unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_means, unobserved_factors_kernel_log_stds) -# -# # Compute node ll -# unobserved_samples = unobserved_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast -# unobserved_kernel_samples = unobserved_kernel_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast -# ll = ell(data_node_weights, baseline_samples, noise_samples, unobserved_samples, cnv, lib_sizes, data) -# child.ell[data_indices] = ll -# child.param_kl = local_paramkl(unobserved_samples, unobserved_kernel_samples, parent_unobserved_samples) -# -# child_weight, _, _, _ = descend(child, depth + 1) -# post_alpha = 1.0 + child_weight -# post_beta = self.dp_gamma + weight_down -# child["node"].variational_parameters["locals"]["psi_log_mean"] = np.log(post_alpha) -# child["node"].variational_parameters["locals"]["psi_log_std"] = np.log(post_beta) -# weight_down += child_weight -# -# prev_psi_sticks_sum += local_psi_sticks_sum(psi_alpha, psi_beta) -# -# weight_here = np.sum(root["node"].data_weights) -# total_weight = weight_here + weight_down -# post_alpha = 1.0 + weight_here -# post_beta = (self.alpha_decay**depth) * self.dp_alpha + weight_down -# root["node"].variational_parameters["locals"]["nu_log_mean"] = np.log(post_alpha) -# root["node"].variational_parameters["locals"]["nu_log_std"] = np.log(post_beta) -# -# node.stick_kl = beta_kl(root["node"].variational_parameters["locals"]["nu_log_mean"], root["node"].variational_parameters["locals"]["nu_log_mean"]) -# node.stick_kl += beta_kl(root["node"].variational_parameters["locals"]["psi_log_mean"], root["node"].variational_parameters["locals"]["psi_log_mean"]) -# -# expected_weight, nu_sticks_sum, psi_sticks_sum = compute_expected_weight(nu_alpha, nu_beta, psi_alpha, psi_beta, prev_nu_sticks_sum, prev_psi_sticks_sum) -# node.ew = expected_weight -# node.data_weights[data_indices] = np.exp(node.ell + node.ew) -# -# return total_weight, lls, expected_weight, nodes -# -# _, lls, expected_w, nodes = descend(root) -# -# # Update global parameters -# if root.parent() is None and update_global: -# # Use samples from tree traversal to update global parameters -# update_global_parameters(steps=n_inner_steps) -# -# # Compute global KL -# global_kl = baseline_kl(log_baseline_mean, log_baseline_log_std) -# -# # Use lls from tree traversal to update data_node_weights -# data_weights = [] -# for node in nodes: -# 
data_weights.append(node.data_weights) -# data_weights = np.array(data_weights).reshape(-1,mb_size)/np.sum(data_weights) -# for i, node in nodes: -# node.data_weights[data_indices] = data_weights[data_indices,i] -# # Compute ELBO using normalized data_weights -# node_data_elbo_contributions = node.data_weights*(node.ell + node.ew - np.log(node.data_weights)) -# node_elbo_contributions = node.stick_kl + node.param_kl -# -# # Compute total elbo: use decomposibility! -# elbo = -# -# return elbo - - -def compute_elbo( - rng, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - cnvs, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - unobserved_means, - unobserved_log_stds, - log_unobserved_factors_kernel_means, - log_unobserved_factors_kernel_log_stds, - log_baseline_mean, - log_baseline_log_std, - cell_noise_mean, - cell_noise_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, -): - elbo, ll, kl, node_kl = _compute_elbo( - rng, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - cnvs, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - unobserved_means, - unobserved_log_stds, - log_unobserved_factors_kernel_means, - log_unobserved_factors_kernel_log_stds, - log_baseline_mean, - log_baseline_log_std, - cell_noise_mean, - cell_noise_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) - return elbo - - -def _compute_elbo( - rng, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - cnvs, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - unobserved_means, - unobserved_log_stds, - log_unobserved_factors_kernel_means, - log_unobserved_factors_kernel_log_stds, - log_baseline_mean, - log_baseline_log_std, - cell_noise_mean, - cell_noise_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, -): - - # single-sample Monte Carlo estimate of the variational lower bound - mb_size = len(indices) - - def stop_global(globals): - ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - 
noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) = ( - globals[0], - globals[1], - globals[2], - globals[3], - globals[4], - globals[5], - globals[6], - globals[7], - ) - log_baseline_mean = jax.lax.stop_gradient(log_baseline_mean) - log_baseline_log_std = jax.lax.stop_gradient(log_baseline_log_std) - noise_factors_mean = jax.lax.stop_gradient(noise_factors_mean) - noise_factors_log_std = jax.lax.stop_gradient(noise_factors_log_std) - factor_precision_log_means = jax.lax.stop_gradient(factor_precision_log_means) - factor_precision_log_stds = jax.lax.stop_gradient(factor_precision_log_stds) - batch_effects_mean = jax.lax.stop_gradient(batch_effects_mean) - batch_effects_log_std = jax.lax.stop_gradient(batch_effects_log_std) - return ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) - - def alt_global(globals): - ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) = ( - globals[0], - globals[1], - globals[2], - globals[3], - globals[4], - globals[5], - globals[6], - globals[7], - ) - return ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) - - ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ) = jax.lax.cond( - do_global, - alt_global, - stop_global, - ( - log_baseline_mean, - log_baseline_log_std, - noise_factors_mean, - noise_factors_log_std, - factor_precision_log_means, - factor_precision_log_stds, - batch_effects_mean, - batch_effects_log_std, - ), - ) - - def stop_node_grad(i): - return ( - jax.lax.stop_gradient(nu_sticks_log_means[i]), - jax.lax.stop_gradient(nu_sticks_log_stds[i]), - jax.lax.stop_gradient(psi_sticks_log_means[i]), - jax.lax.stop_gradient(psi_sticks_log_stds[i]), - jax.lax.stop_gradient(log_unobserved_factors_kernel_log_stds[i]), - jax.lax.stop_gradient(log_unobserved_factors_kernel_means[i]), - jax.lax.stop_gradient(unobserved_log_stds[i]), - jax.lax.stop_gradient(unobserved_means[i]), - ) - - def stop_node_grads(i): - return jax.lax.cond( - node_mask[i] != 1, - stop_node_grad, - lambda i: ( - nu_sticks_log_means[i], - nu_sticks_log_stds[i], - psi_sticks_log_means[i], - psi_sticks_log_stds[i], - log_unobserved_factors_kernel_log_stds[i], - log_unobserved_factors_kernel_means[i], - unobserved_log_stds[i], - unobserved_means[i], - ), - i, - ) # Sample all - - ( - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - ) = vmap(stop_node_grads)(jnp.arange(len(cnvs))) - - def stop_node_params_grads(locals): - return ( - jax.lax.stop_gradient(log_unobserved_factors_kernel_log_stds), - jax.lax.stop_gradient(log_unobserved_factors_kernel_means), - jax.lax.stop_gradient(unobserved_log_stds), - jax.lax.stop_gradient(unobserved_means), - ) - - def alt_node_params(locals): - return ( - log_unobserved_factors_kernel_log_stds, - 
log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - ) - - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - ) = jax.lax.cond( - sticks_only, - stop_node_params_grads, - alt_node_params, - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - ), - ) - - def stop_sticks(locals): - return ( - jax.lax.stop_gradient(nu_sticks_log_means), - jax.lax.stop_gradient(nu_sticks_log_stds), - jax.lax.stop_gradient(psi_sticks_log_means), - jax.lax.stop_gradient(psi_sticks_log_stds), - ) - - def keep_sticks(locals): - return ( - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - ) - - ( - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - ) = jax.lax.cond( - do_sticks, - keep_sticks, - stop_sticks, - ( - nu_sticks_log_means, - nu_sticks_log_stds, - psi_sticks_log_means, - psi_sticks_log_stds, - ), - ) - - def stop_non_global(locals): - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ) = ( - locals[0], - locals[1], - locals[2], - locals[3], - locals[4], - locals[5], - locals[6], - locals[7], - ) - log_unobserved_factors_kernel_log_stds = jax.lax.stop_gradient( - log_unobserved_factors_kernel_log_stds - ) - log_unobserved_factors_kernel_means = jax.lax.stop_gradient( - log_unobserved_factors_kernel_means - ) - unobserved_log_stds = jax.lax.stop_gradient(unobserved_log_stds) - unobserved_means = jax.lax.stop_gradient(unobserved_means) - psi_sticks_log_stds = jax.lax.stop_gradient(psi_sticks_log_stds) - psi_sticks_log_means = jax.lax.stop_gradient(psi_sticks_log_means) - nu_sticks_log_stds = jax.lax.stop_gradient(nu_sticks_log_stds) - nu_sticks_log_means = jax.lax.stop_gradient(nu_sticks_log_means) - return ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ) - - def alt_non_global(locals): - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ) = ( - locals[0], - locals[1], - locals[2], - locals[3], - locals[4], - locals[5], - locals[6], - locals[7], - ) - return ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ) - - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ) = jax.lax.cond( - global_only, - stop_non_global, - alt_non_global, - ( - log_unobserved_factors_kernel_log_stds, - log_unobserved_factors_kernel_means, - unobserved_log_stds, - unobserved_means, - psi_sticks_log_stds, - psi_sticks_log_means, - nu_sticks_log_stds, - nu_sticks_log_means, - ), - ) - - def keep_cell_grad(i): - return cell_noise_mean[indices][i], cell_noise_log_std[indices][i] - - def stop_cell_grad(i): - return jax.lax.stop_gradient( - 
cell_noise_mean[indices][i] - ), jax.lax.stop_gradient(cell_noise_log_std[indices][i]) - - def stop_cell_grads(i): - return jax.lax.cond(data_mask_subset[i] != 1, stop_cell_grad, keep_cell_grad, i) - - cell_noise_mean, cell_noise_log_std = vmap(stop_cell_grads)(jnp.arange(mb_size)) - - def has_children(i): - return jnp.any(ancestor_nodes_indices.ravel() == i) - - def has_next_branch(i): - return jnp.any(previous_branches_indices.ravel() == i) - - ones_vec = jnp.ones(cnvs[0].shape) - zeros_vec = jnp.log(ones_vec) - - log_baseline = diag_gaussian_sample(rng, log_baseline_mean, log_baseline_log_std) - - # noise - factor_precision_log_means = jnp.clip( - factor_precision_log_means, a_min=jnp.log(1e-3), a_max=jnp.log(1e3) - ) - factor_precision_log_stds = jnp.clip( - factor_precision_log_stds, a_min=jnp.log(1e-3), a_max=jnp.log(1e2) - ) - log_factors_precisions = diag_gaussian_sample( - rng, factor_precision_log_means, factor_precision_log_stds - ) - noise_factors_mean = jnp.clip(noise_factors_mean, a_min=-10.0, a_max=10.0) - noise_factors_log_std = jnp.clip( - noise_factors_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2) - ) - noise_factors = diag_gaussian_sample(rng, noise_factors_mean, noise_factors_log_std) - cell_noise_mean = jnp.clip(cell_noise_mean, a_min=-10.0, a_max=10.0) - cell_noise_log_std = jnp.clip( - cell_noise_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2) - ) - cell_noise = diag_gaussian_sample(rng, cell_noise_mean, cell_noise_log_std) - noise = jnp.dot(cell_noise, noise_factors) - - # batch effects - batch_effects_mean = jnp.clip(batch_effects_mean, a_min=-10.0, a_max=10.0) - batch_effects_log_std = jnp.clip( - batch_effects_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2) - ) - batch_effects_factors = diag_gaussian_sample( - rng, batch_effects_mean, batch_effects_log_std - ) - cell_covariates = jnp.array(cell_covariates)[indices] - batch_effects = jnp.dot(cell_covariates, batch_effects_factors) - - # unobserved factors - log_unobserved_factors_kernel_means = jnp.clip( - log_unobserved_factors_kernel_means, a_min=jnp.log(1e-6), a_max=jnp.log(1e2) - ) - log_unobserved_factors_kernel_log_stds = jnp.clip( - log_unobserved_factors_kernel_log_stds, - a_min=jnp.log(1e-6), - a_max=jnp.log(1e2), - ) - - def sample_unobs_kernel(i): - return jnp.clip( - diag_gaussian_sample( - rng, - log_unobserved_factors_kernel_means[i], - log_unobserved_factors_kernel_log_stds[i], - ), - a_min=jnp.log(1e-6), - ) - - def sample_all_unobs_kernel(i): - return jax.lax.cond( - node_mask[i] >= 0, sample_unobs_kernel, lambda i: zeros_vec, i - ) - - nodes_log_unobserved_factors_kernels = vmap(sample_all_unobs_kernel)( - jnp.arange(len(cnvs)) - ) - - unobserved_means = jnp.clip(unobserved_means, a_min=jnp.log(1e-3), a_max=10) - unobserved_log_stds = jnp.clip( - unobserved_log_stds, a_min=jnp.log(1e-6), a_max=jnp.log(1e2) - ) - - def sample_unobs(i): - return diag_gaussian_sample(rng, unobserved_means[i], unobserved_log_stds[i]) - - def sample_all_unobs(i): - return jax.lax.cond(node_mask[i] >= 0, sample_unobs, lambda i: zeros_vec, i) - - nodes_unobserved_factors = vmap(sample_all_unobs)(jnp.arange(len(cnvs))) - - nu_sticks_log_alphas = jnp.clip(nu_sticks_log_means, -3.0, 3.0) - nu_sticks_log_betas = jnp.clip(nu_sticks_log_stds, -3.0, 3.0) - - # def sample_nu(i): - # return jnp.clip( - # diag_gaussian_sample( - # rng, nu_sticks_log_means[i], nu_sticks_log_stds[i] - # ), - # -4.0, - # 4.0, - # ) - # - # def sample_valid_nu(i): - # return jax.lax.cond( - # has_children(i), sample_nu, lambda i: 
jnp.array([logit(1 - 1e-6)]), i - # ) # Sample all valid - # - # def sample_all_nus(i): - # return jax.lax.cond( - # cnvs[i][0] >= 0, sample_valid_nu, lambda i: jnp.array([logit(1e-6)]), i - # ) # Sample all - # - # log_nu_sticks = vmap(sample_all_nus)(jnp.arange(len(cnvs))) - # - # def sigmoid_nus(i): - # return jax.lax.cond( - # cnvs[i][0] >= 0, - # lambda i: jnn.sigmoid(log_nu_sticks[i]), - # lambda i: jnp.array([1e-6]), - # i, - # ) # Sample all - # - # nu_sticks = jnp.clip(vmap(sigmoid_nus)(jnp.arange(len(cnvs))), 1e-6, 1 - 1e-6) - - def sample_nu(i): - return jnp.clip( - beta_sample(rng, nu_sticks_log_alphas[i], nu_sticks_log_betas[i]), - 1e-4, - 1 - 1e-4, - ) - - def sample_valid_nu(i): - return jax.lax.cond( - has_children(i), sample_nu, lambda i: jnp.array([1 - 1e-4]), i - ) # Sample all valid - - def sample_all_nus(i): - return jax.lax.cond( - cnvs[i][0] >= 0, sample_valid_nu, lambda i: jnp.array([1e-4]), i - ) # Sample all - - nu_sticks = vmap(sample_all_nus)(jnp.arange(len(cnvs))) - - psi_sticks_log_alphas = jnp.clip(psi_sticks_log_means, -3.0, 3.0) - psi_sticks_log_betas = jnp.clip(psi_sticks_log_stds, -3.0, 3.0) - - # def sample_psi(i): - # return jnp.clip( - # diag_gaussian_sample( - # rng, psi_sticks_log_means[i], psi_sticks_log_stds[i] - # ), - # -4.0, - # 4.0, - # ) - # - # def sample_valid_psis(i): - # return jax.lax.cond( - # has_next_branch(i), - # sample_psi, - # lambda i: jnp.array([logit(1 - 1e-6)]), - # i, - # ) # Sample all valid - # - # def sample_all_psis(i): - # return jax.lax.cond( - # cnvs[i][0] >= 0, - # sample_valid_psis, - # lambda i: jnp.array([logit(1e-6)]), - # i, - # ) # Sample all - # - # log_psi_sticks = vmap(sample_all_psis)(jnp.arange(len(cnvs))) - # - # def sigmoid_psis(i): - # return jax.lax.cond( - # cnvs[i][0] >= 0, - # lambda i: jnn.sigmoid(log_psi_sticks[i]), - # lambda i: jnp.array([1e-6]), - # i, - # ) # Sample all - # - # psi_sticks = jnp.clip(vmap(sigmoid_psis)(jnp.arange(len(cnvs))), 1e-6, 1 - 1e-6) - - def sample_psi(i): - return jnp.clip( - beta_sample(rng, psi_sticks_log_alphas[i], psi_sticks_log_betas[i]), - 1e-4, - 1 - 1e-4, - ) - - def sample_valid_psis(i): - return jax.lax.cond( - has_next_branch(i), - sample_psi, - lambda i: jnp.array([1 - 1e-4]), - i, - ) # Sample all valid - - def sample_all_psis(i): - return jax.lax.cond( - cnvs[i][0] >= 0, - sample_valid_psis, - lambda i: jnp.array([1e-4]), - i, - ) # Sample all - - psi_sticks = vmap(sample_all_psis)(jnp.arange(len(cnvs))) - - lib_sizes = jnp.array(lib_sizes)[indices] - data = jnp.array(data)[indices] - baseline = jnp.exp(jnp.append(0, log_baseline)) - - def compute_node_ll(i): - unobserved_factors = nodes_unobserved_factors[i] * (parent_vector[i] != -1) - - node_mean = ( - baseline * cnvs[i] / 2 * jnp.exp(unobserved_factors + noise + batch_effects) - ) - sum = jnp.sum(node_mean, axis=1).reshape(mb_size, 1) - node_mean = node_mean / sum - node_mean = node_mean * lib_sizes - pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean) - ll = jnp.sum(pll, axis=1) # N-vector - - # TSSB prior - nu_stick = nu_sticks[i] - psi_stick = psi_sticks[i] - - def prev_branches_psi(idx): - return (idx != -1) * jnp.log(1.0 - psi_sticks[idx]) - - def ancestors_nu(idx): - _log_phi = jnp.log(psi_sticks[idx]) + jnp.sum( - vmap(prev_branches_psi)(previous_branches_indices[idx]) - ) - _log_1_nu = jnp.log(1.0 - nu_sticks[idx]) - total = _log_phi + _log_1_nu - return (idx != -1) * total - - # log_phi = jnp.log(psi_stick) + jnp.sum( - # vmap(prev_branches_psi)(previous_branches_indices[i]) - # ) - # 
log_node_weight = ( - # jnp.log(nu_stick) - # + log_phi - # + jnp.sum(vmap(ancestors_nu)(ancestor_nodes_indices[i])) - # ) - # log_node_weight = log_node_weight + jnp.log(tssb_weights[i]) - # ll = ll + log_node_weight # N-vector - - return ll - - small_ll = -1e30 * jnp.ones((mb_size)) - - def get_node_ll(i): - return jnp.where( - node_mask[i] >= 0, - compute_node_ll(jnp.where(node_mask[i] >= 0, i, 0)), - small_ll, - ) - - out = jnp.array(vmap(get_node_ll)(jnp.arange(len(parent_vector)))) - l = jnp.sum(jnn.logsumexp(out, axis=0) * data_mask_subset) - - log_rate = jnp.log(unobserved_factors_kernel_rate) - log_concentration = jnp.log(unobserved_factors_kernel_concentration) - log_kernel = jnp.log(unobserved_factors_root_kernel) - broadcasted_concentration = log_concentration * ones_vec - broadcasted_rate = log_rate * ones_vec - - def compute_node_kl(i): - kl = 0.0 - pl = diag_gamma_logpdf( - jnp.clip(jnp.exp(nodes_log_unobserved_factors_kernels[i]), a_min=1e-6), - broadcasted_concentration, - (parent_vector[i] != -1) - * (parent_vector[i] != 0) - * unobserved_factors_kernel_rate - * (jnp.abs(nodes_unobserved_factors[parent_vector[i]])), - ) - ent = -diag_loggaussian_logpdf( - jnp.clip(jnp.exp(nodes_log_unobserved_factors_kernels[i]), a_min=1e-6), - log_unobserved_factors_kernel_means[i], - log_unobserved_factors_kernel_log_stds[i], - ) - kl += (parent_vector[i] != -1) * (pl + ent) - - # # Penalize copies in unobserved nodes - # pl = diag_gamma_logpdf(1e-6 * jnp.ones(broadcasted_concentration.shape), broadcasted_concentration, - # (parent_vector[i] != -1)*(log_rate + jnp.abs(nodes_unobserved_factors[parent_vector[i]]))) - # ent = - diag_gaussian_logpdf(jnp.log(1e-6 * jnp.ones(broadcasted_concentration.shape)), log_unobserved_factors_kernel_means[i], log_unobserved_factors_kernel_log_stds[i]) - # kl -= (parent_vector[i] != -1) * jnp.all(tssb_indices[i] == tssb_indices[parent_vector[i]]) * (pl + 0*ent) - - # unobserved_factors - is_root_subtree = jnp.all(tssb_indices[i] == tssb_indices[0]) - pl = diag_gaussian_logpdf( - nodes_unobserved_factors[i], - (parent_vector[i] != -1) - * (parent_vector[i] != 0) - * nodes_unobserved_factors[parent_vector[i]] - + (parent_vector[i] != -1) - * is_root_subtree # promote overexpressing events near the root - * (jnp.exp(nodes_log_unobserved_factors_kernels[i]) > 0.1) - * 0.2, - jnp.clip(nodes_log_unobserved_factors_kernels[i], a_min=jnp.log(1e-6)) - * (parent_vector[i] != -1), - ) - ent = -diag_gaussian_logpdf( - nodes_unobserved_factors[i], unobserved_means[i], unobserved_log_stds[i] - ) - kl += (parent_vector[i] != -1) * (pl + ent) - - # # Penalize copied unobserved_factors - # pl = diag_gaussian_logpdf(nodes_unobserved_factors[parent_vector[i]], - # (parent_vector[i] != -1) * nodes_unobserved_factors[parent_vector[i]], - # jnp.clip(nodes_log_unobserved_factors_kernels[i], a_min=jnp.log(1e-6))*(parent_vector[i] != -1) + log_kernel*(parent_vector[i] == -1)) - # ent = - diag_gaussian_logpdf(nodes_unobserved_factors[parent_vector[i]], unobserved_means[i], unobserved_log_stds[i]) - # kl -= (parent_vector[i] != -1) * jnp.all(tssb_indices[i] == tssb_indices[parent_vector[i]]) * (pl + 0*ent) - - # sticks - nu_pl = has_children(i) * beta_logpdf( - nu_sticks[i], - jnp.log(jnp.array([1.0])), - jnp.log(jnp.array([dp_alphas[i]])), - ) - nu_ent = has_children(i) * -beta_logpdf( - nu_sticks[i], nu_sticks_log_alphas[i], nu_sticks_log_betas[i] - ) - kl += nu_pl + nu_ent - - psi_pl = has_next_branch(i) * beta_logpdf( - psi_sticks[i], - jnp.log(jnp.array([1.0])), - 
jnp.log(jnp.array([dp_gammas[i]])), - ) - psi_ent = has_next_branch(i) * -beta_logpdf( - psi_sticks[i], psi_sticks_log_alphas[i], psi_sticks_log_betas[i] - ) - kl += psi_pl + psi_ent - - return kl - - def get_node_kl(i): - return jnp.where( - node_mask[i] == 1, - compute_node_kl(jnp.where(node_mask[i] == 1, i, 0)), - 0.0, - ) - - node_kls = vmap(get_node_kl)(jnp.arange(len(parent_vector))) - node_kl = jnp.sum(node_kls) - - # Global vars KL - baseline_kl = diag_gaussian_logpdf( - log_baseline, zeros_vec[1:], zeros_vec[1:] - ) - diag_gaussian_logpdf(log_baseline, log_baseline_mean, log_baseline_log_std) - ones_mat = jnp.ones(log_factors_precisions.shape) - zeros_mat = jnp.zeros(log_factors_precisions.shape) - factor_precision_kl = diag_gamma_logpdf( - jnp.exp(log_factors_precisions), - jnp.log(global_noise_factors_precisions_shape) * ones_mat, - zeros_mat, - ) - diag_loggaussian_logpdf( - jnp.exp(log_factors_precisions), - factor_precision_log_means, - factor_precision_log_stds, - ) - noise_factors_kl = diag_gaussian_logpdf( - noise_factors, - jnp.zeros(noise_factors.shape), - jnp.log(jnp.sqrt(1.0 / jnp.exp(log_factors_precisions)).reshape(-1, 1)) - * jnp.ones(noise_factors.shape), - ) - diag_gaussian_logpdf(noise_factors, noise_factors_mean, noise_factors_log_std) - batch_effects_kl = diag_gaussian_logpdf( - batch_effects_factors, - jnp.zeros(batch_effects_factors.shape), - jnp.zeros(batch_effects_factors.shape), - ) - diag_gaussian_logpdf( - batch_effects_factors, batch_effects_mean, batch_effects_log_std - ) - total_kl = ( - node_kl - + baseline_kl - + factor_precision_kl - + noise_factors_kl - + batch_effects_kl - ) - - # Scale the KL by the data size - total_kl = total_kl * jnp.sum(data_mask_subset != 0) / data.shape[0] - - zeros_mat = jnp.zeros(cell_noise.shape) - cell_noise_kl = diag_gaussian_logpdf( - cell_noise, zeros_mat, zeros_mat, axis=1 - ) - diag_gaussian_logpdf(cell_noise, cell_noise_mean, cell_noise_log_std, axis=1) - cell_noise_kl = jnp.sum(cell_noise_kl * data_mask_subset) - total_kl = total_kl + cell_noise_kl - - elbo_val = l + total_kl - - return elbo_val, l, total_kl, node_kls - - -def batch_elbo( - rng, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - params, - num_samples, -): - # Average over a batch of random samples from the var approx. 
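# The averaging below follows the usual JAX Monte Carlo idiom: split the PRNG
# key into num_samples keys, vmap compute_elbo over the key axis only
# (in_axes = [0] followed by None for every other argument), and take the mean
# of the per-sample ELBO estimates. The same idiom in isolation (illustrative
# sketch; mc_average is a hypothetical name, not part of the codebase):
# def mc_average(rng, fn, args, num_samples):
#     rngs = random.split(rng, num_samples)          # one key per MC sample
#     vectorized = vmap(fn, in_axes=[0] + [None] * len(args))
#     return jnp.mean(vectorized(rngs, *args))       # average over samples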
- rngs = random.split(rng, num_samples) - init = [0] - init.extend([None] * (23 + len(params))) - vectorized_elbo = vmap(compute_elbo, in_axes=init) - return jnp.mean( - vectorized_elbo( - rngs, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - *params, - ) - ) - - -@partial(jit, static_argnums=(3, 4, 5, 6, 23)) -def objective( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - params, - t, -): - rng = random.PRNGKey(t) - return -batch_elbo( - rng, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - params, - num_samples, - ) - - -@partial(jit, static_argnums=(3, 4, 5, 6, 21, 22)) -def batch_objective( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - num_data, - params, - t, -): - rng = random.PRNGKey(t) - # Average over a batch of random samples from the var approx. 
- rngs = random.split(rng, num_samples) - init = [0] - init.extend([None] * (23 + len(params))) - vectorized_elbo = vmap(_compute_elbo, in_axes=init) - elbos, lls, kls, node_kls = vectorized_elbo( - rngs, - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - jnp.ones((num_data,)), - jnp.arange(num_data), - do_global, - global_only, - do_sticks, - sticks_only, - *params, - ) - elbo = jnp.mean(elbos) - ll = jnp.mean(lls) - kl = jnp.mean(kls) - node_kl = node_kls - return elbo, ll, kl, node_kl - - -@partial(jit, static_argnums=(3, 4, 5, 6, 23)) -def do_grad( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - params, - i, -): - return jax.value_and_grad(objective, argnums=24)( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - params, - i, - ) - - -def update( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - i, - opt_state, - opt_update, - get_params, -): - # print("Recompiling update!") - params = get_params(opt_state) - value, gradient = do_grad( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - node_mask, - data_mask_subset, - indices, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - params, - i, - ) - opt_state = opt_update(i, gradient, opt_state) - return opt_state, gradient, params, value diff --git a/scatrex/models/cna/tree.py b/scatrex/models/cna/tree.py index 233d87b..0e77309 100644 --- a/scatrex/models/cna/tree.py +++ b/scatrex/models/cna/tree.py @@ -1,11 +1,11 @@ import numpy as np import matplotlib -from ...util import * -from ...ntssb.node import * -from ...ntssb.tree import * +from .node import CNANode +from ...utils.math_utils import * +from ...ntssb.observed_tree import * from ...plotting import * - 
+from ...utils.tree_utils import dict_to_tree def get_cnv_cmap(vmax=4, vmid=2): # Extend amplification colors beyond 4 @@ -31,47 +31,126 @@ def get_cnv_cmap(vmax=4, vmid=2): return cmap -class ObservedTree(Tree): +class CNATree(ObservedTree): def __init__(self, **kwargs): - super(ObservedTree, self).__init__(**kwargs) + super(CNATree, self).__init__(**kwargs) + self.node_constructor = CNANode self.cmap = get_cnv_cmap() self.sign_colors = {"-": "blue", "+": "red"} - def add_node_params( + def sample_kernel(self, parent_params, min_nevents=1, max_nevents_frac=0.67, min_cn=0, seed=None, **kwargs): + n_genes = parent_params.shape[0] + i = 0 + if seed is None: + seed = self.seed + while True: # Rejection sampling + cnvs = np.array(parent_params) + + i += 1 + # Sample number of regions to be affected + rng = np.random.default_rng(seed=seed + i) + n_r = rng.choice( + np.arange( + min_nevents, np.max([int(max_nevents_frac * self.n_regions), 2]) + ) + ) + + # Sample regions to be affected + rng = np.random.default_rng(seed=seed + i + 1) + affected_regions = rng.choice( + np.arange(0, self.n_regions), size=n_r, replace=False + ) + + all_affected_genes = [] + for r in affected_regions: + # Apply event to region + if r > 0: + affected_genes = np.arange(self.region_stops[r - 1], self.region_stops[r]) + elif r == 0: + affected_genes = np.arange(0, self.region_stops[r]) + + all_affected_genes.append(affected_genes) + + if np.any(parent_params[affected_genes]) == 0: + continue + + # Sample event sign + rng = np.random.default_rng(seed=seed + i + 2 + r) + s = rng.choice([-1, 1]) + + if np.any(parent_params[affected_genes] < 2): + s = -1 + elif np.any(parent_params[affected_genes] > 2): + s = 1 + + # Sample event magnitude + rng = np.random.default_rng(seed=seed + i + 3 + r) + m = np.max([1, rng.poisson(0.5)]) + + # Record event + clone_cn_events_genes = np.zeros((n_genes,)) + clone_cn_events_genes[affected_genes] = s * m + + cnvs[affected_genes] = ( + parent_params[affected_genes] + + clone_cn_events_genes[affected_genes] + ) + + + all_affected_genes = np.concatenate(all_affected_genes) + if np.all(cnvs[all_affected_genes] >= min_cn): + break + + return cnvs + + def sample_root(self, n_genes=50, n_regions=5, **kwargs): + self.n_regions = np.max([n_regions, 3]) + # Define regions + rng = np.random.default_rng(self.seed) + self.region_stops = np.sort( + rng.choice(np.arange(n_genes), size=n_regions, replace=False) + ) + return 2*np.ones((n_genes,)) + + def _add_node_params( self, n_genes=50, n_regions=5, min_nevents=1, max_nevents_frac=0.67, min_cn=0 ): C = len(self.tree_dict.keys()) n_regions = np.max([n_regions, 3]) # Define regions + rng = np.random.default_rng(self.seed) region_stops = np.sort( - np.random.choice(np.arange(n_genes), size=n_regions, replace=False) + rng.choice(np.arange(n_genes), size=n_regions, replace=False) ) # Trasverse the tree and generate events for each node for node in self.tree_dict: if self.tree_dict[node]["parent"] == "-1": - self.tree_dict[node]["params"] = np.ones((n_genes,)) * 2 + self.tree_dict[node]["param"] = np.ones((n_genes,)) * 2 self.tree_dict[node]["params_label"] = "" continue parent_params = np.array( - self.tree_dict[self.tree_dict[node]["parent"]]["params"] + self.tree_dict[self.tree_dict[node]["parent"]]["param"] ) - + i = 0 while True: - self.tree_dict[node]["params"] = np.array( - self.tree_dict[self.tree_dict[node]["parent"]]["params"] + i += 1 + self.tree_dict[node]["param"] = np.array( + self.tree_dict[self.tree_dict[node]["parent"]]["param"] ) 
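+                    # Rejection sampling: if the proposed copy-number events leave any affected gene below min_cn, discard them and redraw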
self.tree_dict[node]["params_label"] = "" # Sample number of regions to be affected - n_r = np.random.choice( + rng = np.random.default_rng(seed=self.seed + i) + n_r = rng.choice( np.arange( min_nevents, np.max([int(max_nevents_frac * n_regions), 2]) ) ) # Sample regions to be affected - affected_regions = np.random.choice( + rng = np.random.default_rng(seed=self.seed + i + 1) + affected_regions = rng.choice( np.arange(0, n_regions), size=n_r, replace=False ) @@ -89,7 +168,8 @@ def add_node_params( continue # Sample event sign - s = np.random.choice([-1, 1]) + rng = np.random.default_rng(seed=self.seed + i + 2 + r) + s = rng.choice([-1, 1]) if np.any(parent_params[affected_genes] < 2): s = -1 @@ -97,13 +177,14 @@ def add_node_params( s = 1 # Sample event magnitude - m = np.max([1, np.random.poisson(0.5)]) + rng = np.random.default_rng(seed=self.seed + i + 3 + r) + m = np.max([1, rng.poisson(0.5)]) # Record event clone_cn_events_genes = np.zeros((n_genes,)) clone_cn_events_genes[affected_genes] = s * m - self.tree_dict[node]["params"][affected_genes] = ( + self.tree_dict[node]["param"][affected_genes] = ( parent_params[affected_genes] + clone_cn_events_genes[affected_genes] ) @@ -127,9 +208,11 @@ def add_node_params( ) all_affected_genes = np.concatenate(all_affected_genes) - if np.all(self.tree_dict[node]["params"][all_affected_genes] >= min_cn): + if np.all(self.tree_dict[node]["param"][all_affected_genes] >= min_cn): break + self.tree = dict_to_tree(self.tree_dict) + self.create_adata() def get_affected_genes(self): @@ -138,7 +221,7 @@ def get_affected_genes(self): def set_neutral_nodes(self, thres=0.95, neutral_level=2): for node in self.tree_dict: self.tree_dict[node]["is_neutral"] = False - cnvs = self.tree_dict[node]["params"].ravel() + cnvs = self.tree_dict[node]["param"].ravel() frac_neutral = np.sum(cnvs == neutral_level) / cnvs.size if frac_neutral > thres: self.tree_dict[node]["is_neutral"] = True diff --git a/scatrex/models/trajectory/__init__.py b/scatrex/models/trajectory/__init__.py new file mode 100644 index 0000000..ff72ac8 --- /dev/null +++ b/scatrex/models/trajectory/__init__.py @@ -0,0 +1,2 @@ +from .tree import TrajectoryTree +from .node import TrajectoryNode diff --git a/scatrex/models/trajectory/node.py b/scatrex/models/trajectory/node.py new file mode 100644 index 0000000..10009bc --- /dev/null +++ b/scatrex/models/trajectory/node.py @@ -0,0 +1,772 @@ +from numpy import * +import numpy as np +from numpy.random import * + +from functools import partial +import jax.numpy as jnp +import jax +import tensorflow_probability.substrates.jax.distributions as tfd + +from .node_opt import * # node optimization functions +from .node_opt import _mc_obs_ll +from ...utils.math_utils import * +from ...ntssb.node import * + +def update_params(params, params_gradient, step_size): + new_params = [] + for i, param in enumerate(params): + new_params.append(param + step_size * params_gradient[i]) + return new_params + +class TrajectoryNode(AbstractNode): + def __init__( + self, + observed_parameters, # subtree root location and angle + root_loc_mean=2., + loc_mean=.5, + angle_concentration=10., + loc_variance=.1, + obs_variance=.1, + n_factors=2, + obs_weight_variance=1., + factor_variance=1., + **kwargs, + ): + """ + This model generates nodes in a 2D space by sampling an angle, which roughly follows + the angle at which the parent was generated, and a location around some radius. 
+ TODO: Make the location prior a Gamma instead of a Normal, still centered around some radius + but with a peak at zero and a long tail, to allow for more or less close by nodes + """ + super(TrajectoryNode, self).__init__(observed_parameters, **kwargs) + + self.n_genes = self.observed_parameters[0].size + + # Node hyperparameters + if self.parent() is None: + self.node_hyperparams = dict( + root_loc_mean=root_loc_mean, + angle_concentration=angle_concentration, + loc_mean=loc_mean, + loc_variance=loc_variance, + obs_variance=obs_variance, + n_factors=n_factors, + obs_weight_variance=obs_weight_variance, + factor_variance=factor_variance, + ) + else: + self.node_hyperparams = self.node_hyperparams_caller() + + self.reset_parameters(**self.node_hyperparams) + + if self.tssb is not None: + self.reset_variational_parameters() + self.sample_variational_distributions() + self.reset_sufficient_statistics(self.tssb.ntssb.num_batches) + + def combine_params(self): + return self.params[0] # get loc + + def get_mean(self): + return self.combine_params() + + def set_mean(self, node_mean=None): + if node_mean is not None: + self.node_mean = node_mean + else: + self.node_mean = self.get_mean() + + def get_observed_parameters(self): + return self.observed_parameters[0] # get root loc + + def get_params(self): + return self.get_mean() + + def get_param(self, param='mean'): + if param == 'observed': + return self.get_observed_parameters() + elif param == 'mean': + return self.get_mean() + else: + raise ValueError(f"No param available for `{param}`") + + def remove_noise(self, data): + """ + Noise is additive in this model + """ + return data - self.noise_factors_caller() + + # ========= Functions to initialize node. ========= + def set_node_hyperparams(self, **kwargs): + self.node_hyperparams.update(**kwargs) + + def reset_parameters( + self, + root_loc_mean=2., + loc_mean=.5, + angle_concentration=10., + loc_variance=.1, + obs_variance=.1, + n_factors=2, + obs_weight_variance=1., + factor_variance=1., + ): + self.node_hyperparams = dict( + root_loc_mean=root_loc_mean, + angle_concentration=angle_concentration, + loc_mean=loc_mean, + loc_variance=loc_variance, + obs_variance=obs_variance, + n_factors=n_factors, + obs_weight_variance=obs_weight_variance, + factor_variance=factor_variance, + ) + + parent = self.parent() + + if parent is None: + self.depth = 0.0 + self.params = self.observed_parameters # loc and angle + + n_factors = self.node_hyperparams['n_factors'] + factor_variance = self.node_hyperparams['factor_variance'] + rng = np.random.default_rng(seed=self.seed) + self.factor_weights = rng.normal(0., np.sqrt(factor_variance), size=(n_factors, 2)) * 1./3. + + if n_factors > 0: + n_genes_per_factor = int(2/n_factors) + offset = 6. 
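+                # Amplify each factor's weight on a disjoint subset of the 2 dimensions to give the factors structure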
+ perm = np.random.permutation(2) + for factor in range(n_factors): + gene_idx = perm[factor*n_genes_per_factor:(factor+1)*n_genes_per_factor] + self.factor_weights[factor,gene_idx] *= offset + + + # Set data-dependent parameters + if self.tssb is not None: + num_data = self.tssb.ntssb.num_data + if num_data is not None: + self.reset_data_parameters() + + elif parent.tssb != self.tssb: + self.depth = 0.0 + self.params = self.observed_parameters # loc and angle + else: # Non-root node: inherits everything from upstream node + self.depth = parent.depth + 1 + loc_mean = self.node_hyperparams['loc_mean'] + angle_concentration = self.node_hyperparams['angle_concentration'] * self.depth + rng = np.random.default_rng(seed=self.seed) + sampled_angle = rng.vonmises(parent.params[1], angle_concentration) + sampled_loc = rng.normal( + loc_mean, + self.node_hyperparams['loc_variance'] + ) + sampled_loc = parent.params[0] + np.array([np.cos(sampled_angle)*np.abs(sampled_loc), np.sin(sampled_angle)*np.abs(sampled_loc)]) + self.params = [sampled_loc, sampled_angle] + + self.set_mean() + + # Generate structure on the factors + def reset_data_parameters(self): + num_data = self.tssb.ntssb.num_data + n_factors = self.node_hyperparams['n_factors'] + rng = np.random.default_rng(seed=self.seed) + self.obs_weights = rng.normal(0., 1., size=(num_data, n_factors)) * 1./3. + if n_factors > 0: + n_obs_per_factor = int(num_data/n_factors) + offset = 6. + perm = np.random.permutation(num_data) + for factor in range(n_factors): + obs_idx = perm[factor*n_obs_per_factor:(factor+1)*n_obs_per_factor] + self.obs_weights[obs_idx,factor] *= offset + + self.noise_factors = self.obs_weights.dot(self.factor_weights) + + def reset_variational_parameters(self): + # Assignments + num_data = self.tssb.ntssb.num_data + if num_data is not None: + self.variational_parameters['q_z'] = jnp.ones(num_data,) + + self.variational_parameters['sum_E_log_1_nu'] = 0. + self.variational_parameters['E_log_phi'] = 0. + + # Sticks + self.variational_parameters["delta_1"] = 1. + self.variational_parameters["delta_2"] = 1. + self.variational_parameters["sigma_1"] = 1. + self.variational_parameters["sigma_2"] = 1. + + # Pivots + self.variational_parameters["q_rho"] = np.ones(len(self.tssb.children_root_nodes),) + + parent = self.parent() + if parent is None and self.tssb.parent() is None: + rng = np.random.default_rng(self.seed) + # root stores global parameters + n_factors = self.node_hyperparams['n_factors'] + self.variational_parameters["global"] = { + 'factor_weights': {'mean': jnp.array(self.node_hyperparams['factor_variance']/10.*rng.normal(size=(n_factors, 2))), + 'log_std': -2. + jnp.zeros((n_factors, 2))} + } + if num_data is not None: + rng = np.random.default_rng(self.seed+1) + self.variational_parameters["local"] = { + 'obs_weights': {'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))), + 'log_std': -2. 
+ jnp.zeros((num_data, n_factors))} + } + self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"] + self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"] + self.noise_factors = self.obs_weights.dot(self.factor_weights) + elif parent is None: + return # no variational parameters for root nodes of TSSBs in this model + else: # only the non-root nodes have variational parameters + # Kernel + radius = self.node_hyperparams['loc_mean'] + + if "angle" not in parent.variational_parameters["kernel"]: + mean_angle = jnp.array([parent.observed_parameters[1]]) + parent_loc = jnp.array(parent.observed_parameters[0]) + else: + mean_angle = parent.variational_parameters["kernel"]["angle"]["mean"] + parent_loc = parent.variational_parameters["kernel"]["loc"]["mean"] + + rng = np.random.default_rng(self.seed+2) + mean_angle = rng.vonmises(mean_angle, self.node_hyperparams['angle_concentration'] * self.depth) + mean_loc = parent_loc + jnp.array([np.cos(mean_angle[0])*radius, jnp.sin(mean_angle[0])*radius]) + rng = np.random.default_rng(self.seed+3) + mean_loc = rng.normal(mean_loc, self.node_hyperparams['loc_variance']) + self.variational_parameters["kernel"] = { + 'angle': {'mean': jnp.array(mean_angle), 'log_kappa': jnp.array([-1.])}, + 'loc': {'mean': jnp.array(mean_loc), 'log_std': jnp.array([-1., -1.])} + } + self.params = [self.variational_parameters["kernel"]["loc"]["mean"], + self.variational_parameters["kernel"]["angle"]["mean"]] + + def set_learned_parameters(self): + if self.parent() is None and self.tssb.parent() is None: + self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"] + self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"] + self.noise_factors = self.obs_weights.dot(self.factor_weights) + elif self.parent() is None: + self.params = self.observed_parameters + else: + self.params = [self.variational_parameters["kernel"]["loc"]["mean"], + self.variational_parameters["kernel"]["angle"]["mean"]] + + def reset_sufficient_statistics(self, num_batches=1): + self.suff_stats = { + 'ent': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(c_n = this tree) q(z_n = this node) * log q(z_n = this node) + 'mass': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) + 'A': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng ** 2 + 'B_g': {'total': 0, 'batch': np.zeros((num_batches,2))}, # \sum_n q(z_n = this node) * x_ng + 'C': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g] + 'D_g': {'total': 0, 'batch': np.zeros((num_batches,2))}, # \sum_n q(z_n = this node) * E[s_nW_g] + 'E': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g E[(s_nW_g)**2] + } + if self.parent() is None and self.tssb.parent() is None: + self.local_suff_stats = { + 'locals_kl': {'total': 0., 'batch': np.zeros((num_batches,))}, + } + + def merge_suff_stats(self, suff_stats): + for stat in self.suff_stats: + self.suff_stats[stat]['total'] += suff_stats[stat]['total'] + self.suff_stats[stat]['batch'] += suff_stats[stat]['batch'] + + def update_sufficient_statistics(self, batch_idx=None): + if batch_idx is not None: + idx = self.tssb.ntssb.batch_indices[batch_idx] + else: + idx = jnp.arange(self.tssb.ntssb.num_data) + + if self.parent() is None and self.tssb.parent() is None: + locals_kl = self.compute_local_priors(idx) + self.compute_local_entropies(idx) + if 
batch_idx is not None: + self.local_suff_stats['locals_kl']['total'] -= self.local_suff_stats['locals_kl']['batch'][batch_idx] + self.local_suff_stats['locals_kl']['batch'][batch_idx] = locals_kl + self.local_suff_stats['locals_kl']['total'] += self.local_suff_stats['locals_kl']['batch'][batch_idx] + else: + self.local_suff_stats['locals_kl']['total'] = locals_kl + + ent = assignment_entropies(self.variational_parameters['q_z'][idx]) + ent *= self.tssb.variational_parameters['q_c'][idx] + E_ass = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx] + E_sw = jnp.mean(self.get_noise_sample(idx),axis=0) + E_sw2 = jnp.mean(self.get_noise_sample(idx)**2,axis=0) + x = self.tssb.ntssb.data[idx] + + new_ent = jnp.sum(ent) + new_mass = jnp.sum(E_ass) + new_A = jnp.sum(E_ass * jnp.sum(x**2, axis=1)) + new_B = jnp.sum(E_ass[:,None] * x, axis=0) + new_C = jnp.sum(E_ass * jnp.sum(x * E_sw, axis=1)) + new_D = jnp.sum(E_ass[:,None] * E_sw, axis=0) + new_E = jnp.sum(E_ass * jnp.sum(E_sw2, axis=1)) + + if batch_idx is not None: + self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx] + self.suff_stats['ent']['batch'][batch_idx] = new_ent + self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx] + + self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx] + self.suff_stats['mass']['batch'][batch_idx] = new_mass + self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx] + + self.suff_stats['A']['total'] -= self.suff_stats['A']['batch'][batch_idx] + self.suff_stats['A']['batch'][batch_idx] = new_A + self.suff_stats['A']['total'] += self.suff_stats['A']['batch'][batch_idx] + + self.suff_stats['B_g']['total'] -= self.suff_stats['B_g']['batch'][batch_idx] + self.suff_stats['B_g']['batch'][batch_idx] = new_B + self.suff_stats['B_g']['total'] += self.suff_stats['B_g']['batch'][batch_idx] + + self.suff_stats['C']['total'] -= self.suff_stats['C']['batch'][batch_idx] + self.suff_stats['C']['batch'][batch_idx] = new_C + self.suff_stats['C']['total'] += self.suff_stats['C']['batch'][batch_idx] + + self.suff_stats['D_g']['total'] -= self.suff_stats['D_g']['batch'][batch_idx] + self.suff_stats['D_g']['batch'][batch_idx] = new_D + self.suff_stats['D_g']['total'] += self.suff_stats['D_g']['batch'][batch_idx] + + self.suff_stats['E']['total'] -= self.suff_stats['E']['batch'][batch_idx] + self.suff_stats['E']['batch'][batch_idx] = new_E + self.suff_stats['E']['total'] += self.suff_stats['E']['batch'][batch_idx] + else: + self.suff_stats['ent']['total'] = new_ent + self.suff_stats['mass']['total'] = new_mass + self.suff_stats['A']['total'] = new_A + self.suff_stats['B_g']['total'] = new_B + self.suff_stats['C']['total'] = new_C + self.suff_stats['D_g']['total'] = new_D + self.suff_stats['E']['total'] = new_E + + # ========= Functions to take samples from node. ========= + def sample_observation(self, n): + node_mean = self.get_mean() + noise_factors = self.noise_factors_caller()[n] + rng = np.random.default_rng(seed=self.seed+n) + s = rng.normal(node_mean + noise_factors, self.node_hyperparams['obs_variance']) + return s + + def sample_observations(self): + n_obs = len(self.data) + node_mean = self.get_mean() + noise_factors = self.noise_factors_caller()[np.array(list(self.data))] + rng = np.random.default_rng(seed=self.seed) + s = rng.normal(node_mean + noise_factors, self.node_hyperparams['obs_variance'], size=[n_obs, self.n_genes]) + return s + + # ========= Functions to access root's parameters. 
========= + def node_hyperparams_caller(self): + if self.parent() is None: + return self.node_hyperparams + else: + return self.parent().node_hyperparams_caller() + + def noise_factors_caller(self): + return self.tssb.ntssb.root['node'].root['node'].noise_factors + + def get_obs_weights_sample(self): + return self.tssb.ntssb.root['node'].root['node'].obs_weights_sample + + def set_local_sample(self, sample, idx=None): + if idx is None: + idx = jnp.arange(self.tssb.ntssb.num_data) + self.tssb.ntssb.root['node'].root['node'].obs_weights_sample = self.tssb.ntssb.root['node'].root['node'].obs_weights_sample.at[:,idx].set(sample) + + def get_factor_weights_sample(self): + return self.tssb.ntssb.root['node'].root['node'].factor_weights_sample + + def set_global_sample(self, sample): + self.tssb.ntssb.root['node'].root['node'].factor_weights_sample = jnp.array(sample) + + def get_noise_sample(self, idx): + obs_weights = self.get_obs_weights_sample()[:,idx] + factor_weights = self.get_factor_weights_sample() + return jax.vmap(sample_prod, in_axes=(0,0))(obs_weights,factor_weights) + + def get_direction_sample(self): + return self.samples[1] + + def get_state_sample(self): + return self.samples[0] + + def get_prior_angle_concentration(self, depth=None): + if depth is None: + depth = self.depth + return self.node_hyperparams['angle_concentration'] * jnp.maximum(depth, 1) # Prior hyperparameter + + # ======== Functions using the variational parameters. ========= + def compute_loglikelihood(self, idx): + # Use stored samples for loc + node_mean_samples = self.samples[0] + obs_weights_samples = self.get_obs_weights_sample()[:,idx] + factor_weights_samples = self.get_factor_weights_sample() + std = jnp.sqrt(self.node_hyperparams['obs_variance']) + # Average over samples for each observation + ll = jnp.mean(jax.vmap(_mc_obs_ll, in_axes=[None,0,0,0,None])(self.tssb.ntssb.data[idx], + node_mean_samples, + obs_weights_samples, + factor_weights_samples, + std), axis=0) # mean over MC samples + return ll + + def compute_loglikelihood_suff(self): + node_mean_samples = self.samples[0] + std = jnp.sqrt(self.node_hyperparams['obs_variance']) + ll = jnp.mean(jax.vmap(ll_suffstats, in_axes=[0, None, None, None, None, None, None, None]) + (node_mean_samples,self.suff_stats['mass']['total'],self.suff_stats['A']['total'], + self.suff_stats['B_g']['total'], self.suff_stats['C']['total'], + self.suff_stats['D_g']['total'],self.suff_stats['E']['total'],std) + ) + return ll + + def sample_variational_distributions(self, n_samples=10): + if self.parent() is not None: + if self.parent().samples is not None: + n_samples = self.parent().samples[0].shape[0] + if self.parent() is None and self.tssb.parent() is None: + self.sample_locals(n_samples=n_samples, store=True) + self.sample_globals(n_samples=n_samples, store=True) + self.sample_kernel(n_samples=n_samples, store=True) + + def sample_locals(self, n_samples, store=True): + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.local_sample_and_grad(jnp.arange(self.tssb.ntssb.num_data), key, n_samples=n_samples) + sampled_obs_weights, _ = sample_grad + if store: + self.obs_weights_sample = sampled_obs_weights + else: + return sampled_obs_weights + + def sample_globals(self, n_samples, store=True): + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.global_sample_and_grad(key, n_samples=n_samples) + sampled_factor_weights, _ = sample_grad + if store: + self.factor_weights_sample = sampled_factor_weights + else: + return sampled_factor_weights + + def 
sample_kernel(self, n_samples=10, store=True): + parent = self.parent() + if parent is None: + return self._sample_root_kernel(n_samples=n_samples, store=store) + + key = jax.random.PRNGKey(self.seed) + key, sample_grad = self.direction_sample_and_grad(key, n_samples=n_samples) + sampled_angle, _ = sample_grad + + key, sample_grad = self.state_sample_and_grad(key, n_samples=n_samples) + sampled_loc, _ = sample_grad + + samples = [sampled_loc, sampled_angle] + if store: + self.samples = samples + else: + return samples + + def _sample_root_kernel(self, n_samples=10, store=True): + # In this model the root is just the known parameters, so just store n_samples copies of them to mimic a sample + observed_angle = jnp.array([self.observed_parameters[1]]) # Observed angle + observed_loc = jnp.array([self.observed_parameters[0]]) # Observed location + sampled_angle = jnp.vstack(jnp.repeat(observed_angle, n_samples, axis=0)) + sampled_loc = jnp.vstack(jnp.repeat(observed_loc, n_samples, axis=0)) + samples = [sampled_loc, sampled_angle] + if store: + self.samples = samples + else: + return samples + + def compute_global_priors(self): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['factor_variance'])) + return jnp.mean(mc_factor_weights_logp_val_and_grad(self.factor_weights_sample, 0., log_std)[0]) + + def compute_local_priors(self, batch_indices): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance'])) + return jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0]) + + def compute_global_entropies(self): + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + return jnp.sum(factor_weights_logq_val_and_grad(mean, log_std)[0]) + + def compute_local_entropies(self, batch_indices): + mean = self.variational_parameters['local']['obs_weights']['mean'][batch_indices] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][batch_indices] + return jnp.sum(obs_weights_logq_val_and_grad(mean, log_std)[0]) + + def compute_kernel_prior(self): + parent = self.parent() + if parent is None: + return self.compute_root_prior() + + prior_mean_angle = self.parent().get_direction_sample() + prior_angle_concentration = self.get_prior_angle_concentration() + angle_samples = self.get_direction_sample() + angle_logpdf = mc_angle_logp_val_and_grad(angle_samples, prior_mean_angle, prior_angle_concentration)[0] + + radius = self.node_hyperparams['loc_mean'] + parent_loc = self.parent().get_state_sample() + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + loc_samples = self.get_state_sample() + loc_logpdf = mc_loc_logp_val_and_grad(loc_samples, parent_loc, angle_samples, log_std, radius)[0] + + return jnp.mean(angle_logpdf + loc_logpdf) + + def compute_root_direction_prior(self, parent_alpha): + concentration = self.get_prior_angle_concentration() + alpha = self.get_direction_sample() + return jnp.mean(mc_angle_logp_val_and_grad(alpha, parent_alpha, concentration)[0]) + + def compute_root_state_prior(self, parent_psi): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + radius = self.node_hyperparams['root_loc_mean'] + psi = self.get_state_sample() + alpha = self.get_direction_sample() + return jnp.mean(mc_loc_logp_val_and_grad(psi, parent_psi, alpha, log_std, radius)[0]) + + def compute_root_kernel_prior(self, samples): + parent_alpha = samples[0] + logp = self.compute_root_direction_prior(parent_alpha) + 
parent_psi = samples[1] + logp += self.compute_root_state_prior(parent_psi) + return logp + + def compute_root_prior(self): + return 0. + + def compute_kernel_entropy(self): + parent = self.parent() + if parent is None: + return self.compute_root_entropy() + + # Angle + angle_logpdf = tfd.VonMises(np.exp(self.variational_parameters['kernel']['angle']['mean']), + jnp.exp(self.variational_parameters['kernel']['angle']['log_kappa']) + ).entropy() + angle_logpdf = jnp.sum(angle_logpdf) + + # Location + loc_logpdf = tfd.Normal(self.variational_parameters['kernel']['loc']['mean'], + jnp.exp(self.variational_parameters['kernel']['loc']['log_std']) + ).entropy() + loc_logpdf = jnp.sum(loc_logpdf) # Sum across features + + return angle_logpdf + loc_logpdf + + def compute_root_entropy(self): + # In this model the root nodes have no unknown parameters + return 0. + + # ======== Functions for updating the variational parameters. ========= + def local_sample_and_grad(self, idx, key, n_samples): + """Sample and take gradient of local parameters. Must be root""" + mean = self.variational_parameters['local']['obs_weights']['mean'][idx] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_obs_weights_val_and_grad(jnp.array(sub_keys), mean, log_std) + + def global_sample_and_grad(self, key, n_samples): + """Sample and take gradient of global parameters. Must be root""" + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_factor_weights_val_and_grad(jnp.array(sub_keys), mean, log_std) + + def compute_locals_prior_grad(self, sample): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance'])) + return mc_obs_weights_logp_val_and_grad(sample, 0., log_std)[1] + + def compute_globals_prior_grad(self, sample): + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['factor_variance'])) + return mc_factor_weights_logp_val_and_grad(sample, 0., log_std)[1] + + def compute_locals_entropy_grad(self, idx): + mean = self.variational_parameters['local']['obs_weights']['mean'][idx] + log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx] + return obs_weights_logq_val_and_grad(mean, log_std)[1] + + def compute_globals_entropy_grad(self): + mean = self.variational_parameters['global']['factor_weights']['mean'] + log_std = self.variational_parameters['global']['factor_weights']['log_std'] + return factor_weights_logq_val_and_grad(mean, log_std)[1] + + def state_sample_and_grad(self, key, n_samples): + """Sample and take gradient of state""" + mu = self.variational_parameters['kernel']['loc']['mean'] + log_std = self.variational_parameters['kernel']['loc']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_loc_val_and_grad(jnp.array(sub_keys), mu, log_std) + + def direction_sample_and_grad(self, key, n_samples): + """Sample and take gradient of direction""" + mu = self.variational_parameters['kernel']['angle']['mean'] + log_kappa = self.variational_parameters['kernel']['angle']['log_kappa'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_angle_val_and_grad(jnp.array(sub_keys), mu, log_kappa) + + def state_sample_and_grad(self, key, n_samples): + """Sample and take gradient of state""" + mu = self.variational_parameters['kernel']['loc']['mean'] + 
log_std = self.variational_parameters['kernel']['loc']['log_std'] + key, *sub_keys = jax.random.split(key, n_samples+1) + return key, mc_sample_loc_val_and_grad(jnp.array(sub_keys), mu, log_std) + + def compute_direction_prior_grad_wrt_direction(self, alpha, parent_alpha, parent_loc): + """Gradient of logp(alpha|parent_alpha) wrt this alpha""" + concentration = self.get_prior_angle_concentration() + return mc_angle_logp_val_and_grad(alpha, parent_alpha, concentration)[1] + + def compute_direction_prior_grad_wrt_state(self, alpha, parent_alpha, parent_loc): + """Gradient of logp(alpha|parent_alpha) wrt this alpha""" + return 0. + + def compute_direction_prior_child_grad(self, child_alpha, alpha): + """Gradient of logp(child_alpha|alpha) wrt this alpha""" + concentration = self.get_prior_angle_concentration(depth=self.depth+1) + return mc_angle_logp_val_and_grad_wrt_parent(child_alpha, alpha, concentration)[1] + + def compute_state_prior_grad(self, psi, parent_psi, alpha): + """Gradient of logp(psi|parent_psi,new_alpha) wrt this psi""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + radius = self.node_hyperparams['loc_mean'] + return mc_loc_logp_val_and_grad(psi, parent_psi, alpha, log_std, radius)[1] + + def compute_state_prior_child_grad(self, child_psi, psi, child_alpha): + """Gradient of logp(child_psi|psi,child_alpha) wrt this psi""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + radius = self.node_hyperparams['loc_mean'] + return mc_loc_logp_val_and_grad_wrt_parent(child_psi, psi, child_alpha, log_std, radius)[1] + + def compute_root_state_prior_child_grad(self, child_psi, psi, child_alpha): + """Gradient of logp(child_psi|psi,child_alpha) wrt this psi""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + radius = self.node_hyperparams['root_loc_mean'] + return mc_loc_logp_val_and_grad_wrt_parent(child_psi, psi, child_alpha, log_std, radius)[1] + + def compute_state_prior_grad_wrt_direction(self, psi, parent_psi, alpha): + """Gradient of logp(psi|parent_psi,alpha) wrt this alpha""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance'])) + radius = self.node_hyperparams['loc_mean'] + return mc_loc_logp_val_and_grad_wrt_angle(psi, parent_psi, alpha, log_std, radius)[1] + + def compute_direction_entropy_grad(self): + """Gradient of logq(alpha) wrt this alpha""" + mu = self.variational_parameters['kernel']['angle']['mean'] + log_kappa = self.variational_parameters['kernel']['angle']['log_kappa'] + return angle_logq_val_and_grad(mu, log_kappa)[1] + + def compute_state_entropy_grad(self): + """Gradient of logq(psi) wrt this psi""" + mu = self.variational_parameters['kernel']['loc']['mean'] + log_std = self.variational_parameters['kernel']['loc']['log_std'] + return loc_logq_val_and_grad(mu, log_std)[1] + + def compute_ll_state_grad(self, x, weights, psi): + """Gradient of logp(x|psi,noise) wrt this psi""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance'])) + locals = self.get_obs_weights_sample() + globals = self.get_factor_weights_sample() + return mc_ll_val_and_grad_psi(x, weights, psi, locals, globals, log_std)[1] + + def compute_ll_state_grad_suff(self, psi): + """Gradient of logp(x|psi,noise) wrt this psi using suff stats""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance'])) + return mc_ll_node_mean_suff_val_and_grad(psi, self.suff_stats['mass']['total'], + self.suff_stats['B_g']['total'], + self.suff_stats['D_g']['total'], log_std)[1] + + def compute_ll_locals_grad(self, 
x, idx, weights): + """Gradient of logp(x|psi,locals,globals) wrt locals""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance'])) + psi = self.get_state_sample() + locals = self.get_obs_weights_sample()[:,idx] + globals = self.get_factor_weights_sample() + return mc_ll_val_and_grad_obs_weights(x, weights, psi, locals, globals, log_std)[1] + + def compute_ll_globals_grad(self, x, idx, weights): + """Gradient of logp(x|psi,locals,globals) wrt globals""" + log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance'])) + psi = self.get_state_sample() + locals = self.get_obs_weights_sample()[:,idx] + globals = self.get_factor_weights_sample() + return mc_ll_val_and_grad_factor_weights(x, weights, psi, locals, globals, log_std)[1] + + def update_direction_params(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0) + angle_mean_grad = mc_grad + direction_params_entropy_grad[0] + self.variational_parameters['kernel']['angle']['mean'] += angle_mean_grad * step_size + + mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0) + angle_log_kappa_grad = mc_grad + direction_params_entropy_grad[1] + self.variational_parameters['kernel']['angle']['log_kappa'] += angle_log_kappa_grad * step_size + + def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0) + loc_mean_grad = mc_grad + state_params_entropy_grad[0] + self.variational_parameters['kernel']['loc']['mean'] += loc_mean_grad * step_size + + mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0) + loc_log_std_grad = mc_grad + state_params_entropy_grad[1] + self.variational_parameters['kernel']['loc']['log_std'] += loc_log_std_grad * step_size + + def update_local_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001): + mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0] + new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + param_grad * step_size + self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param) + + mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0) + param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1] + new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + param_grad * step_size + self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param) + + def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001): + mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[0] + self.variational_parameters['global']['factor_weights']['mean'] += param_grad * step_size + + mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0) + param_grad = mc_grad + global_params_entropy_grad[1] + self.variational_parameters['global']['factor_weights']['log_std'] += param_grad * step_size + + def initialize_global_opt_states(self): + n_factors = self.node_hyperparams['n_factors'] + m = jnp.zeros((n_factors,self.n_genes)) + v = 
jnp.zeros((n_factors,self.n_genes))
+        state1 = (m,v)
+        m = jnp.zeros((n_factors,self.n_genes))
+        v = jnp.zeros((n_factors,self.n_genes))
+        state2 = (m,v)
+        states = (state1, state2)
+        return states
+
+    def update_global_params_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+                        b2=0.999, eps=1e-8, step_size=0.001):
+        mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+        param_grad = mc_grad + global_params_entropy_grad[0]
+
+        m, v = states[0]
+        m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+        v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+        state1 = (m, v)
+        self.variational_parameters['global']['factor_weights']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+
+        mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+        param_grad = mc_grad + global_params_entropy_grad[1]
+
+        m, v = states[1]
+        m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+        v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+        mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+        vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+        state2 = (m, v)
+        self.variational_parameters['global']['factor_weights']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+        states = (state1, state2)
+        return states
\ No newline at end of file
diff --git a/scatrex/models/trajectory/node_opt.py b/scatrex/models/trajectory/node_opt.py
new file mode 100644
index 0000000..756b431
--- /dev/null
+++ b/scatrex/models/trajectory/node_opt.py
@@ -0,0 +1,142 @@
+import jax
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.distributions as tfd
+
+@jax.jit
+def sample_angle(key, mu, log_kappa): # univariate: one sample
+    return tfd.VonMises(mu, jnp.exp(log_kappa)).sample(seed=key)
+sample_angle_val_and_grad = jax.vmap(jax.value_and_grad(sample_angle, argnums=(1,2)), in_axes=(None, 0, 0)) # per-dimension val and grad
+mc_sample_angle_val_and_grad = jax.jit(jax.vmap(sample_angle_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_loc(key, mu, log_std): # univariate: one sample
+    return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_loc_val_and_grad = jax.vmap(jax.value_and_grad(sample_loc, argnums=(1,2)), in_axes=(None, 0, 0)) # per-dimension val and grad
+mc_sample_loc_val_and_grad = jax.jit(jax.vmap(sample_loc_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def angle_logp(this_angle, parent_angle, concentration): # single sample
+    return jnp.sum(tfd.VonMises(parent_angle, concentration).log_prob(this_angle))
+angle_logp_val_and_grad = jax.jit(jax.value_and_grad(angle_logp, argnums=0)) # Take grad wrt to this
+mc_angle_logp_val_and_grad = jax.jit(jax.vmap(angle_logp_val_and_grad, in_axes=(0,0, None))) # Multiple sample value_and_grad
+
+angle_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(angle_logp, argnums=1)) # Take grad wrt to parent
+mc_angle_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(angle_logp_val_and_grad_wrt_parent, in_axes=(0,0, None))) # Multiple sample value_and_grad, grad wrt parent
+
+@jax.jit
+def angle_logq(mu, log_kappa):
+    return jnp.sum(tfd.VonMises(mu, jnp.exp(log_kappa)).entropy())
+angle_logq_val_and_grad = jax.jit(jax.value_and_grad(angle_logq, argnums=(0,1))) # Take grad wrt to parameters
+
+@jax.jit
+def 
loc_logp(this_loc, parent_loc, this_angle, log_std, radius): # single sample + mean = parent_loc + jnp.hstack([jnp.cos(this_angle)*radius, jnp.sin(this_angle)*radius]) # Use samples from parent + return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(this_loc)) # sum across dimensions +loc_logp_val = jax.jit(loc_logp) +mc_loc_logp_val = jax.jit(jax.vmap(loc_logp_val, in_axes=(0,0,0, None, None))) # Multiple sample + +loc_logp_val_and_grad = jax.jit(jax.value_and_grad(loc_logp, argnums=0)) # Take grad wrt to this +mc_loc_logp_val_and_grad = jax.jit(jax.vmap(loc_logp_val_and_grad, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad + +loc_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(loc_logp, argnums=1)) # Take grad wrt to parent +mc_loc_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_parent, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad + +loc_logp_val_and_grad_wrt_angle = jax.jit(jax.value_and_grad(loc_logp, argnums=2)) # Take grad wrt to angle +mc_loc_logp_val_and_grad_wrt_angle = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_angle, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad + +@jax.jit +def loc_logq(mu, log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +loc_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(loc_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters + + + + +@jax.jit +def sample_obs_weights(key, mean, log_std): # NxK + return tfd.Normal(mean, jnp.exp(log_std)).sample(seed=key) +sample_obs_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_obs_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_obs_weights_val_and_grad = jax.jit(jax.vmap(sample_obs_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + +@jax.jit +def sample_factor_weights(key, mu, log_std): # KxG + return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key) +sample_factor_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad +mc_sample_factor_weights_val_and_grad = jax.jit(jax.vmap(sample_factor_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad + + +@jax.jit +def obs_weights_logp(sample, mean, log_std): # single sample, NxK + return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample)) # sum across obs and dimensions +obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt to sample (NxK) +mc_obs_weights_logp_val_and_grad = jax.jit(jax.vmap(obs_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNxK + +@jax.jit +def obs_weights_logq(mu, log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +obs_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(obs_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + + +@jax.jit +def factor_weights_logp(sample, mean, log_std): # single sample, KxG + return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample)) # sum across factors and genes +factor_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_weights_logp, argnums=0)) # Take grad wrt to sample (KxG) +mc_factor_weights_logp_val_and_grad = jax.jit(jax.vmap(factor_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxKxG + +@jax.jit +def factor_weights_logq(mu, 
log_std): + return tfd.Normal(mu, jnp.exp(log_std)).entropy() +factor_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters + +@jax.jit +def _mc_obs_ll(obs, node_mean, obs_weights, factor_weights, std): # For each MC sample: Nx2 + m = node_mean + obs_weights.dot(factor_weights) + return jnp.sum(jax.vmap(jax.scipy.stats.norm.logpdf, in_axes=[0, 0, None])(obs, m, std), axis=1) # sum over dimensions + +@jax.jit +def ll(x, weights, psi, obs_weights, factor_weights, log_std): # single sample + loc = psi + obs_weights.dot(factor_weights) + return jnp.sum(jnp.sum(tfd.Normal(loc, jnp.exp(log_std)).log_prob(x),axis=1) * weights) +ll_val_and_grad_psi = jax.jit(jax.value_and_grad(ll, argnums=2)) # Take grad wrt to psi +mc_ll_val_and_grad_psi = jax.jit(jax.vmap(ll_val_and_grad_psi, + in_axes=(None,None,0,0,0,None))) + +ll_val_and_grad_obs_weights = jax.jit(jax.value_and_grad(ll, argnums=3)) # Take grad wrt to obs_weights +mc_ll_val_and_grad_obs_weights = jax.jit(jax.vmap(ll_val_and_grad_obs_weights, + in_axes=(None,None,0,0,0,None))) + +ll_val_and_grad_factor_weights = jax.jit(jax.value_and_grad(ll, argnums=4)) # Take grad wrt to factor_weights +mc_ll_val_and_grad_factor_weights = jax.jit(jax.vmap(ll_val_and_grad_factor_weights, + in_axes=(None,None,0,0,0,None))) + +@jax.jit +def ll_suffstats(node_mean, mass, A, B_g, C, D_g, E, std): # for a single node_mean sample + """ + mass: \sum_n q(z_n = this node) + A: \sum_n q(z_n = this node) * \sum_g x_ng ** 2 + B_g: \sum_n q(z_n = this node) * x_ng + C: \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g] + D_g: \sum_n q(z_n = this node) * E[s_nW_g] + E: \sum_n q(z_n = this node) * \sum_g E[(s_nW_g)**2] + """ + v = std**2 + ll = -jnp.log(2*jnp.pi*v) * mass - A/(2*v) + jnp.sum(B_g*node_mean)/v + C/v - \ + (mass*jnp.sum(node_mean**2))/(2*v) - jnp.sum(D_g*node_mean)/v - E/(2*v) + return ll + +@jax.jit +def ll_node_mean_suff(node_mean, mass, B_g, D_g, log_std): # for a single node_mean sample + """ + mass: \sum_n q(z_n = this node) + B_g: \sum_n q(z_n = this node) * x_ng + D_g: \sum_n q(z_n = this node) * E[s_nW_g] + """ + v = jnp.exp(log_std)**2 + ll = jnp.sum(B_g*node_mean)/v - (mass*jnp.sum(node_mean**2))/(2*v) - jnp.sum(D_g*node_mean)/v + return ll +ll_node_mean_suff_val_and_grad = jax.jit(jax.value_and_grad(ll_node_mean_suff, argnums=0)) # Take grad wrt to psi +mc_ll_node_mean_suff_val_and_grad = jax.jit(jax.vmap(ll_node_mean_suff_val_and_grad, + in_axes=(0,None, None, None, None))) + +# To get noise sample +sample_prod = jax.jit(lambda mat1, mat2: mat1.dot(mat2)) \ No newline at end of file diff --git a/scatrex/models/trajectory/tree.py b/scatrex/models/trajectory/tree.py new file mode 100644 index 0000000..7ba9268 --- /dev/null +++ b/scatrex/models/trajectory/tree.py @@ -0,0 +1,67 @@ +import numpy as np +from .node import TrajectoryNode +from ...ntssb.observed_tree import * +from anndata import AnnData + +class TrajectoryTree(ObservedTree): + def __init__(self, **kwargs): + super(TrajectoryTree, self).__init__(**kwargs) + self.node_constructor = TrajectoryNode + + def sample_kernel(self, parent_params, mean_dist=1., angle_concentration=1., loc_variance=.1, seed=42, depth=1., **kwargs): + rng = np.random.default_rng(seed=seed) + parent_loc = parent_params[0] + parent_angle = parent_params[1] + angle_concentration = angle_concentration * depth + sampled_angle = rng.vonmises(parent_angle, angle_concentration) + sampled_loc = 
rng.normal(mean_dist, loc_variance) + sampled_loc = parent_loc + np.array([np.cos(sampled_angle)*np.abs(sampled_loc), np.sin(sampled_angle)*np.abs(sampled_loc)]) + return [sampled_loc, sampled_angle] + + def sample_root(self, **kwargs): + return [np.array([0., 0.]), 0.] + + def get_param_size(self): + return self.tree["param"][0].size + + def get_params(self): + params = [] + for node in self.tree_dict: + params.append(self.tree_dict[node]["param"][0]) + return np.array(params, dtype=np.float) + + def param_distance(self, paramA, paramB): + return np.sqrt(np.sum((paramA[0]-paramB[0])**2)) + + def create_adata(self, var_names=None): + params = [] + params_labels = [] + for node in self.tree_dict: + if self.tree_dict[node]["size"] != 0: + params_labels.append( + [self.tree_dict[node]["label"]] * self.tree_dict[node]["size"] + ) + params.append( + np.vstack( + [self.tree_dict[node]["param"][0]] * self.tree_dict[node]["size"] + ) + ) + params = pd.DataFrame(np.vstack(params)) + params_labels = np.concatenate(params_labels).tolist() + if var_names is not None: + params.columns = var_names + self.adata = AnnData(params) + self.adata.obs["node"] = params_labels + self.adata.uns["node_colors"] = [ + self.tree_dict[node]["color"] + for node in self.tree_dict + if self.tree_dict[node]["size"] != 0 + ] + self.adata.uns["node_sizes"] = np.array( + [ + self.tree_dict[node]["size"] + for node in self.tree_dict + if self.tree_dict[node]["size"] != 0 + ] + ) + self.adata.var["bulk"] = np.mean(self.adata.X, axis=0) \ No newline at end of file diff --git a/scatrex/ntssb/__init__.py b/scatrex/ntssb/__init__.py index 2648fba..551928f 100644 --- a/scatrex/ntssb/__init__.py +++ b/scatrex/ntssb/__init__.py @@ -1,4 +1,4 @@ from .node import AbstractNode from .ntssb import NTSSB -from .tree import Tree +from .observed_tree import ObservedTree from .search import StructureSearch diff --git a/scatrex/ntssb/node.py b/scatrex/ntssb/node.py index bc0c94c..82d5a48 100644 --- a/scatrex/ntssb/node.py +++ b/scatrex/ntssb/node.py @@ -1,22 +1,33 @@ import numpy as np from numpy.random import * +import jax.numpy as jnp +from abc import ABC, abstractmethod -class AbstractNode(object): +class AbstractNode(ABC): def __init__( - self, is_observed, observed_parameters, parent=None, tssb=None, label="" + self, observed_parameters, parent=None, tssb=None, label="", seed=42, ): self.data = set([]) self._children = set([]) self.tssb = tssb - self.is_observed = is_observed - self.observed_parameters = observed_parameters self.label = label self.event_str = "" + self.seed = seed self.params = dict() + self.observed_parameters = observed_parameters self.variational_parameters = dict(globals=dict(), locals=dict()) - + self.variational_parameters = {'delta_1': 1., 'delta_2': 1., # nu stick + 'sigma_1': 1., 'sigma_2': 1., # psi stick + 'q_z': [], # prob of assigning each cell to this node + 'kernel' : dict(), # model-specific + 'q_rho': [], # prob of assigning the root node of each child TSSB to this node + 'E_log_phi': 0., # auxiliary quantity + 'sum_E_log_1_nu': 0., # auxiliary quantity + } + self.samples = None + self.pivot_prior_prob = 1. # prior prob of assigning the root node of each child TSSB to this node self.data_weights = 0. self.weight_until_here = 0. 
@@ -38,6 +49,34 @@ def __init__( else: self._parent = None + @abstractmethod + def compute_loglikelihood(self): + return + + @abstractmethod + def combine_params(self): + return + + @abstractmethod + def sample_kernel(self): + return + + @abstractmethod + def compute_kernel_prior(self): + return + + @abstractmethod + def compute_root_kernel_prior(self): + return + + @abstractmethod + def compute_kernel_entropy(self): + return + + @abstractmethod + def remove_noise(self): + return + def set_parent(self, parent, reset=False): if self._parent is not None and self._parent._children is not None: self._parent._children.remove(self) @@ -45,20 +84,11 @@ def set_parent(self, parent, reset=False): parent.add_child(self) self._parent = parent - if not self.is_observed: - # Make sure we use the right observed parameters - self.observed_parameters = parent.observed_parameters - self.inherit_parameters() - - self.unobserved_factors = parent.unobserved_factors + self.unobserved_factors self.set_mean() if reset: self.reset_variational_parameters() - def inherit_parameters(self): - pass - def reset_parameters(self): pass @@ -69,11 +99,17 @@ def kill(self): self._parent = None self._children = None - def spawn(self, is_observed, observed_parameters): + def spawn(self, observed_parameters, seed=42): return self.__class__( - is_observed, observed_parameters, parent=self, tssb=self.tssb + observed_parameters, parent=self, tssb=self.tssb, seed=seed, ) + def get_observed_parameters(self): + return self.observed_parameters + + def get_params(self): + return self.params + def has_data(self): if len(self.data): return True @@ -158,7 +194,7 @@ def global_param(self, key): return self.parent().global_param(key) def get_ancestors(self, all=True): - if self._parent is None or (not all and self.is_observed): + if self._parent is None or (not all and self.tssb != self._parent.tssb): return [self] else: ancestors = self._parent.get_ancestors(all=all) @@ -179,11 +215,20 @@ def descend(node): def parameter_log_prior(self): return 0 - def tssb_caller(self): - if self.parent() is None: - return self.tssb - else: - return self.parent().tssb_caller() - def set_event_string(self): pass + + def get_top_obs(self, q=70, idx=None): + """ + Get data which is very well explained by this node's parameter + """ + if idx is None: + idx = jnp.arange(self.tssb.ntssb.num_data) + if len(idx) == 0: + return np.array([]) + lls = self.compute_loglikelihood(idx) + top_obs = idx[np.where(lls > np.percentile(lls, q=q))[0]] + return top_obs + + def reset_variational_state(self, **kwargs): + return \ No newline at end of file diff --git a/scatrex/ntssb/ntssb.py b/scatrex/ntssb/ntssb.py index 225d5fd..2cba92b 100644 --- a/scatrex/ntssb/ntssb.py +++ b/scatrex/ntssb/ntssb.py @@ -8,6 +8,7 @@ from graphviz import Digraph import matplotlib import matplotlib.cm +import matplotlib.pyplot as plt import numpy as np from numpy import * @@ -19,13 +20,10 @@ import jax.numpy as jnp import jax.nn as jnn -from ..util import * +from ..utils.math_utils import * from ..callbacks import elbos_callback from .tssb import TSSB -from ..models import cna -from ..models.cna.opt_funcs import * - -import time +from ..plotting import tree_colors, plot_full_tree import logging @@ -62,7 +60,6 @@ class NTSSB(object): def __init__( self, input_tree, - node_constructor, dp_alpha=1.0, dp_gamma=1.0, alpha_decay=1.0, @@ -70,8 +67,11 @@ def __init__( max_depth=15, fixed_weights_pivot_sampling=True, use_weights=True, + weights_concentration=10., + min_weight=1e-6, verbosity=logging.INFO, 
node_hyperparams=dict(), + seed=42, ): if input_tree is None: raise Exception("Input tree must be specified.") @@ -80,11 +80,14 @@ def __init__( self.max_depth = max_depth self.dp_alpha = dp_alpha # smaller dp_alpha => larger nu => less nodes self.dp_gamma = dp_gamma # smaller dp_gamma => larger psi => less nodes - self.alpha_decay = alpha_decay + self.alpha_decay = alpha_decay # smaller alpha_decay => larger decay with depth => less nodes + + self.seed = seed self.input_tree = input_tree self.input_tree_dict = self.input_tree.tree_dict - self.node_constructor = node_constructor + self.node_constructor = self.input_tree.node_constructor + self.node_hyperparams = node_hyperparams self.fixed_weights_pivot_sampling = fixed_weights_pivot_sampling @@ -99,11 +102,14 @@ def __init__( self.data = None self.num_data = None self.covariates = None + self.num_batches = 1 + self.batch_size = None self.max_nodes = ( len(self.input_tree_dict.keys()) * 1 ) # upper bound on number of nodes self.n_nodes = len(self.input_tree_dict.keys()) + self.n_total_nodes = self.n_nodes self.obs_cmap = self.input_tree.cmap self.exp_cmap = matplotlib.cm.viridis @@ -111,199 +117,331 @@ def __init__( logger.setLevel(verbosity) - self.reset_tree(use_weights=use_weights, node_hyperparams=node_hyperparams) + self.reset_tree(use_weights=use_weights, weights_concentration=weights_concentration, min_weight=min_weight) + + self.set_pivot_priors() + + self.variational_parameters = { + 'LSE_c': [], # normalizing constant for cell-TSSB assignments + } # ========= Functions to initialize tree. ========= - def reset_tree(self, use_weights=False, node_hyperparams=dict()): + def reset_tree(self, use_weights=False, weights_concentration=10., min_weight=1e-6): if use_weights and "weight" not in self.input_tree_dict["A"].keys(): raise KeyError("No weights were specified in the input tree.") # Clear tree self.assignments = [] - input_tree_dict = self.input_tree_dict - - # Get MRCA node - root_node = self.input_tree.mrca() - - obj = self.node_constructor( - True, - input_tree_dict[root_node]["params"], - parent=None, - label=root_node, - **node_hyperparams, - ) - input_tree_dict[root_node]["subtree"] = TSSB( - obj, - root_node, - ntssb=self, - dp_alpha=input_tree_dict[root_node]["dp_alpha_subtree"], - alpha_decay=input_tree_dict[root_node]["alpha_decay_subtree"], - dp_gamma=input_tree_dict[root_node]["dp_gamma_subtree"], - color=input_tree_dict[root_node]["color"], - ) - - main = ( - boundbeta(1.0, self.dp_alpha) if self.min_depth == 0 else 0.0 - ) # if min_depth > 0, no data can be added to the root (main stick is nu) - if use_weights: - main = self.input_tree_dict[root_node]["weight"] - input_tree_dict[root_node]["subtree"].weight = self.input_tree_dict[ - root_node - ]["weight"] - - self.root = { - "node": input_tree_dict[root_node]["subtree"], - "main": main, - "sticks": empty((0, 1)), # psi sticks - "children": [], - "label": root_node, - "super_parent": None, - "parent": None, - } - - # Recursively construct tree of subtrees - def descend(super_tree, label, depth=0): - for i, child in enumerate(input_tree_dict[label]["children"]): - - stick = boundbeta(1, self.dp_gamma) + # Traverse tree in depth first and create TSSBs + def descend(input_root, idx=1, depth=0): + alpha_nu = 1. 
+ beta_nu = (self.alpha_decay**depth) * self.dp_alpha + + children_roots = [] + sticks = [] + psi_priors = [] + for i, child in enumerate(input_root['children']): + child_root, idx = descend(child, idx, depth+1) + children_roots.append(child_root) + + rng = np.random.default_rng(int(self.seed+idx*1e6)) + stick = boundbeta(1, self.dp_gamma, rng) + psi_prior = {"alpha_psi": 1., "beta_psi": self.dp_gamma} if use_weights: - stick = self.input_tree.get_sum_weights_subtree(child) - if i < len(input_tree_dict[label]["children"]) - 1: + stick = self.input_tree.get_sum_weights_subtree(child_root["label"]) + if i < len(input_root["children"]) - 1: sum = 0 - for j, c in enumerate(input_tree_dict[label]["children"][i:]): - sum = sum + self.input_tree.get_sum_weights_subtree(c) + for j, c in enumerate(input_root["children"][i:]): + sum = sum + self.input_tree.get_sum_weights_subtree(c["label"]) stick = stick / sum else: stick = 1.0 - - super_tree["sticks"] = vstack( - [ - super_tree["sticks"], - stick - if i < len(input_tree_dict[label]["children"]) - 1 - else 1.0, - ] + psi_prior["alpha_psi"] = stick * (weights_concentration - 2) + 1 + psi_prior["beta_psi"] = (1-stick) * (weights_concentration -2) + 1 + psi_priors.append(psi_prior) + sticks.append(stick) + if len(sticks) == 0: + sticks = empty((0, 1)) + else: + sticks = vstack(sticks) + + label = input_root["label"] + + # Create node + local_seed = int(self.seed+idx*1e6) + node = self.node_constructor( + input_root["param"], + label=label, + seed=local_seed, + **self.node_hyperparams, + ) + + # Create TSSB with pointers to children root nodes + rng = np.random.default_rng(local_seed) + + children_nodes = [c["node"].root["node"] for c in children_roots] + tssb = TSSB( + node, + label, + ntssb=self, + children_root_nodes=children_nodes, + dp_alpha=input_root["dp_alpha_subtree"], + alpha_decay=input_root["alpha_decay_subtree"], + dp_gamma=input_root["dp_gamma_subtree"], + eta=input_root["eta"], + color=input_root["color"], + seed=local_seed, ) + input_root["subtree"] = tssb + + # Create root dict + if depth >= self.min_depth: + main = boundbeta(1.0, (self.alpha_decay ** depth) * self.dp_alpha, rng) + else: # if depth < min_depth, no data can be added to this node (main stick is nu) + main = 0. 
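+            # Quick check of the weight -> Beta pseudo-count mapping used for the psi priors
+            # above and for alpha_nu/beta_nu below (assuming weights_concentration > 2):
+            #   alpha = w*(c-2) + 1,  beta = (1-w)*(c-2) + 1
+            #   => alpha + beta = c (total pseudo-counts) and the Beta mode
+            #      (alpha-1)/(alpha+beta-2) = w, so the prior peaks at the input weight.
+            #   e.g. w = 0.3, c = 10 gives Beta(3.4, 6.6) with mode 2.4/8 = 0.3.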
+ if use_weights: + main = input_root["weight"] + subtree_weights_sum = self.input_tree.get_sum_weights_subtree(label) + main = main / subtree_weights_sum + input_root["subtree"].weight = input_root["weight"] + if len(input_root["children"]) < 1: + main = 1.0 # stop at leaf node + + if use_weights: + alpha_nu = main * (weights_concentration - 2) + 1 + beta_nu = (1-main) * (weights_concentration - 2) + 1 + + root_dict = { + "node": tssb, + "main": main, + "sticks": sticks, + "children": children_roots, + "label": input_root["label"], + "super_parent": None, # maybe remove + "pivot_node": None, # maybe remove + "pivot_tssb": None, # maybe remove + "color": input_root["color"], + "alpha_nu": alpha_nu, + "beta_nu": beta_nu, + "psi_priors": psi_priors, + } + + return root_dict, idx+1 + + self.root, _ = descend(self.input_tree.tree) - main = boundbeta(1.0, (self.alpha_decay ** (depth + 1)) * self.dp_alpha) - if use_weights: - main = self.input_tree_dict[child]["weight"] - subtree_weights_sum = self.input_tree.get_sum_weights_subtree(child) - main = main / subtree_weights_sum - - if len(input_tree_dict[child]["children"]) < 1: - main = 1.0 # stop at leaf node - - pivot_tssb = input_tree_dict[label]["subtree"] - # pivot_tssb.dp_alpha = input_tree_dict[child]['dp_alpha_parent_edge'] - # pivot_tssb.alpha_decay = input_tree_dict[child]['alpha_decay_parent_edge'] - # pivot_tssb.truncate() - - # pivot_tssb.eta = input_tree_dict[child]['eta'] - pivot_node = super_tree["node"].root["node"] - - obj = self.node_constructor( - True, - input_tree_dict[child]["params"], - parent=pivot_node, - label=child, - ) + def descend(root): + for child in root['children']: + child["node"]._parent = root["node"] + descend(child) + + # Set parents + descend(self.root) - input_tree_dict[child]["subtree"] = TSSB( - obj, - child, - ntssb=self, - dp_alpha=input_tree_dict[child]["dp_alpha_subtree"], - alpha_decay=input_tree_dict[child]["alpha_decay_subtree"], - dp_gamma=input_tree_dict[child]["dp_gamma_subtree"], - color=input_tree_dict[child]["color"], - ) - input_tree_dict[child]["subtree"].eta = input_tree_dict[child]["eta"] + # And add weights keys + self.set_weights() - if use_weights: - input_tree_dict[child]["subtree"].weight = self.input_tree_dict[ - child - ]["weight"] - - super_tree["children"].append( - { - "node": input_tree_dict[child]["subtree"], - "main": main if self.min_depth <= (depth + 1) else 0.0, - "sticks": empty((0, 1)), # psi sticks - "children": [], - "label": child, - "super_parent": super_tree["node"], - "pivot_node": pivot_node, - "pivot_tssb": pivot_tssb, - } - ) + def set_tssb_params(self, dp_alpha=1., alpha_decay=1., dp_gamma=1.): + def descend(root): + root['node'].dp_alpha = dp_alpha + root['node'].alpha_decay = alpha_decay + root['node'].dp_gamma = dp_gamma + for child in root['children']: + descend(child) + descend(self.root) - descend(super_tree["children"][-1], child, depth + 1) + def set_node_hyperparams(self, **kwargs): + def descend(root): + root['node'].set_node_hyperparams(**kwargs) + for child in root['children']: + descend(child) + descend(self.root) - descend(self.root, root_node) + def sample_variational_distributions(self, **kwargs): + def descend(root): + root['node'].sample_variational_distributions(**kwargs) + for child in root['children']: + descend(child) + descend(self.root) - def reset_variational_parameters(self, **kwargs): - # Reset node parameters + def set_learned_parameters(self): + def descend(root): + root['node'].set_learned_parameters() + for child in 
root['children']: + descend(child) + descend(self.root) + + def reset_sufficient_statistics(self): def descend(super_tree): - super_tree["node"].reset_node_variational_parameters(**kwargs) + super_tree["node"].reset_sufficient_statistics(num_batches=self.num_batches) for child in super_tree["children"]: descend(child) + descend(self.root) + def reset_variational_parameters(self, **kwargs): + # Reset node parameters + def descend(super_tree, alpha_psi=1., beta_psi=1.): + alpha_nu = super_tree['alpha_nu'] + beta_nu = super_tree['beta_nu'] + super_tree["node"].reset_variational_parameters(alpha_nu=alpha_nu, beta_nu=beta_nu, + alpha_psi=alpha_psi,beta_psi=beta_psi, + **kwargs) + c_norm = jnp.array(super_tree["node"].variational_parameters['q_c']) + for i, child in enumerate(super_tree["children"]): + alpha_psi = super_tree['psi_priors'][i]["alpha_psi"] + beta_psi = super_tree['psi_priors'][i]["beta_psi"] + c_norm += descend(child, alpha_psi=alpha_psi, beta_psi=beta_psi) + return c_norm + + c_norm = descend(self.root) + + # Apply normalization + def descend(root, alpha_psi=1., beta_psi=1.): + alpha_nu = root['alpha_nu'] + beta_nu = root['beta_nu'] + root["node"].reset_variational_parameters(alpha_nu=alpha_nu, beta_nu=beta_nu, + alpha_psi=alpha_psi,beta_psi=beta_psi, + **kwargs) + root["node"].variational_parameters['q_c'] = root["node"].variational_parameters['q_c'] / c_norm + for i, child in enumerate(root["children"]): + alpha_psi = root['psi_priors'][i]["alpha_psi"] + beta_psi = root['psi_priors'][i]["beta_psi"] + descend(child, alpha_psi=alpha_psi, beta_psi=beta_psi) descend(self.root) + def init_root_kernels(self, **kwargs): + def descend(super_tree): + for child in super_tree["children"]: + child["node"].root["node"].init_kernel(**kwargs) + descend(child) + descend(self.root) + def reset_node_parameters( - self, root_params=True, down_params=True, node_hyperparams=None - ): + self, **node_hyperparams + ): # Reset node parameters def descend(super_tree): - super_tree["node"].reset_node_variational_parameters() - super_tree["node"].reset_node_parameters( - root_params=root_params, - down_params=down_params, - node_hyperparams=node_hyperparams, - ) + super_tree["node"].reset_node_parameters(**node_hyperparams) for child in super_tree["children"]: descend(child) descend(self.root) + def remake_observed_params(self): + def descend(super_tree): + self.input_tree.tree_dict[super_tree["label"]]["param"] = super_tree["node"].root["node"].params + super_tree["node"].root["node"].observed_parameters = self.input_tree.tree_dict[super_tree["label"]]["param"] + for child in super_tree["children"]: + descend(child) + + descend(self.root) + self.input_tree.update_tree() + + def set_radial_positions(self): + """ + Create a radial layout from the full NTSSB and set the node means + as their positions in the layout. + + Make sure the outer params correspond to longer branches than internal ones. 
+ """ + import networkx as nx + self.create_augmented_tree_dict() # create self.node_dict + + G = nx.DiGraph() + for node in self.node_dict: + G.add_node(G, node) + if self.node_dict[node]['parent'] != '-1': + parent = self.node_dict[node]['parent'] + G.add_edge(parent, node) + pos = nx.nx_pydot.graphviz_layout(G, prog="twopi") + + self.set_node_means(pos) # to sample observations from nodes in these positions + + def set_node_means(self, pos): + for node in self.node_dict: + self.node_dict[node].set_node_mean(pos[node]) + def sync_subtrees(self): subtrees = self.get_subtrees() for subtree in subtrees: subtree.ntssb = self - def put_data_in_nodes(self, num_data, eta=0): + def get_node(self, u, key=None, uniform=False, include_leaves=True): + # See in which subtree it lands + subtree, _, u = self.find_node(u) + + # See in which node it lands + if uniform: + _, _, root = subtree.find_node_uniform(key, include_leaves=include_leaves) + else: + _, _, root = subtree.find_node(u, include_leaves=include_leaves) + + return root + + def sample_assignments(self, num_data): + self.num_data = num_data + + node_assignments = [] + obs_node_assignments = [] + self.assignments = [] + self.subtree_assignments = [] subtrees = self.get_subtrees() for tssb in subtrees: tssb.assignments = [] + tssb.remove_data() + + # Draw sticks + rng = np.random.default_rng(self.seed) + u_vector = rng.random(size=num_data) + for n in range(num_data): + u = u_vector[n] + # See in which subtree it lands + subtree, _, u = self.find_node(u) - self.num_data = num_data + # See in which node it lands + node, _, _ = subtree.find_node(u) - # Get mixture weights - nodes, weights = self.get_node_weights(eta=eta) + self.assignments.append(node) + self.subtree_assignments.append(subtree) - for node in nodes: - node.remove_data() + node_assignments.append(node.label) + obs_node_assignments.append(subtree.label) - for n in range(self.num_data): - node = np.random.choice(nodes, p=weights) - node.tssb.assignments.append(node) + subtree.assignments.append(node) + subtree.add_datum(n) node.add_datum(n) - self.assignments.append(node) + + return node_assignments, obs_node_assignments + + def simulate_data(self): + self.data = np.zeros((self.num_data, self.input_tree.get_param_size())) + + # Reset root node parameters to set data-dependent variables if applicable + self.root["node"].root["node"].reset_data_parameters() + + # Sample observations + def super_descend(super_root): + descend(super_root['node'].root) + for super_child in super_root['children']: + super_descend(super_child) - def normalize_data(self): - if self.data is None: - raise Exception("Need to `call add_data(self, data, to_root=False)` first.") + def descend(root): + attached_cells = np.array(list(root['node'].data)) + if len(attached_cells) > 0: + self.data[attached_cells] = root['node'].sample_observations() + for child in root['children']: + descend(child) + + super_descend(self.root) + self.data = jnp.array(self.data) - self.normalized_data = np.log( - 10000 * self.data / np.sum(self.data, axis=1).reshape(self.num_data, 1) + 1 - ) + return self.data - def add_data(self, data, covariates=None, to_root=False): - self.data = data - self.num_data = 0 if data is None else data.shape[0] + def add_data(self, data, covariates=None): + self.data = jnp.array(data) + self.num_data = data.shape[0] if covariates is None: self.covariates = np.zeros((self.num_data, 0)) else: @@ -312,29 +450,15 @@ def add_data(self, data, covariates=None, to_root=False): logger.debug(f"Adding data of shape 
{data.shape} to NTSSB") - self.assignments = [] - - for n in range(self.num_data): - if to_root: - subtree = self.root["node"] - node = self.root["node"].root["node"] - else: - u = rand() - subtree, _, u = self.find_node(u) - - # Now choose the node - node, _, _ = subtree.find_node(u) - - subtree.assignments.append(node) - node.add_datum(n) - self.assignments.append(node) - try: # Reset root node parameters to set data-dependent variables if applicable self.root["node"].root["node"].reset_data_parameters() except AttributeError: pass + # Reset node variational parameters to use this data size + self.reset_variational_parameters() + def clear_data(self): def descend(root): for index, child in enumerate(root["children"]): @@ -363,44 +487,47 @@ def descend(super_tree, label, depth=0): descend(self.root, "A") - def create_new_tree(self, n_extra_per_observed=1, num_data=None): + def create_new_tree(self, n_extra_per_observed=1): # Clear current tree (including subtrees) self.reset_tree( - True, node_hyperparams=self.root["node"].root["node"].node_hyperparams + True ) - self.reset_node_parameters( - node_hyperparams=self.root["node"].root["node"].node_hyperparams - ) - self.plot_tree(super_only=False) # update names + self.set_weights() - # Add nodes to subtrees - subtrees = self.get_subtrees() - for subtree in subtrees: - n_nodes = 0 - while n_nodes < n_extra_per_observed: - _, nodes = subtree.get_mixture() - # Uniformly choose a node from the subtree - snode = np.random.choice(nodes) - self.add_node_to(snode.label, optimal_init=False) - self.plot_tree(super_only=False) # update names - n_nodes = n_nodes + 1 - - # Choose pivots + def get_distance(nodeA, nodeB): + return np.sqrt(np.sum((nodeA.get_mean() - nodeB.get_mean())**2)) + + # Add nodes and set pivots def descend(super_tree): - for child in super_tree["children"]: + if super_tree['weight'] != 0: # Add nodes only if it has some mass + n_nodes = 0 + while n_nodes < n_extra_per_observed: + _, _, nodes_roots = super_tree['node'].get_mixture(get_roots=True) + # Uniformly choose a node from the subtree + rng = np.random.default_rng(super_tree['node'].seed + n_nodes) + snode = rng.choice(nodes_roots) + super_tree['node'].add_node(snode) + n_nodes = n_nodes + 1 + super_tree['node'].reset_node_parameters(**self.node_hyperparams) # adjust parameters to avoid overlapping subnodes + for i, child in enumerate(super_tree["children"]): weights, nodes = super_tree["node"].get_fixed_weights( eta=child["node"].eta ) - pivot_node = np.random.choice(nodes, p=weights) + weights = np.array([w/get_distance(child['node'].root['node'], n) for w, n in zip(weights, nodes)]) + weights = weights / np.sum(weights) + # rng = np.random.default_rng(super_tree['node'].seed + i) + # pivot_node = rng.choice(nodes, p=weights) + pivot_node = nodes[np.argmax(weights)] child["pivot_node"] = pivot_node child["node"].root["node"].set_parent(pivot_node) descend(child) + super_tree['node'].truncate() + super_tree['node'].set_weights() + super_tree['node'].set_pivot_priors() descend(self.root) - - if num_data is not None: - self.put_data_in_nodes(num_data, eta=0) + self.plot_tree(super_only=False) # update names def sample_new_tree(self, num_data, use_weights=False): self.num_data = num_data @@ -408,7 +535,6 @@ def sample_new_tree(self, num_data, use_weights=False): # Clear current tree (including subtrees) self.reset_tree( use_weights, - node_hyperparams=self.root["node"].root["node"].node_hyperparams, ) # Break sticks to assign data @@ -449,6 +575,38 @@ def descend(super_tree): 
descend(self.root) + def get_param_dict(self): + """ + Go from a dictionary where each node is a TSSB to a dictionary where each node is a dictionary, + with `params` and `weight` keys + """ + + param_dict = { + "node": self.root['node'].get_param_dict(), + "weight": self.root['weight'], + "children": [], + "obs_param": self.root['node'].root['node'].get_observed_parameters(), + "label": self.root['label'], + "color": self.root['color'], + "size": len(self.root['node']._data), + } + def descend(root, root_new): + for child in root["children"]: + child_new = { + "node": child['node'].get_param_dict(), + "weight": child['weight'], + "children": [], + "obs_param": child['node'].root['node'].get_observed_parameters(), + "label": child['label'], + "color": child['color'], + "size": len(child['node']._data) + } + root_new['children'].append(child_new) + descend(child, root_new['children'][-1]) + + descend(self.root, param_dict) + return param_dict + # ========= Functions to sample tree parameters. ========= def sample_pivot_main_sticks(self): def descend(super_tree): @@ -577,6 +735,34 @@ def descend(root, mass): return descend(self.root, 1.0) + def set_weights(self): + def descend(root, mass): + root['weight'] = mass * root["main"] + edges = sticks_to_edges(root["sticks"]) + weights = diff(hstack([0.0, edges])) + for i, child in enumerate(root["children"]): + descend(child, mass * (1.0 - root["main"]) * weights[i]) + return descend(self.root, 1.0) + + def set_expected_weights(self): + def descend(root): + logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + logprior += root['node'].variational_parameters['E_log_phi'] + root['weight'] = jnp.exp(logprior) + root['node'].set_expected_weights() + for child in root['children']: + descend(child) + descend(self.root) + + def set_pivot_priors(self): + def descend(root): + for child in root['children']: + root['node'].set_pivot_priors() + descend(child) + descend(self.root) + def get_node_data_sizes(self, normalized=False, super_only=False): nodes, _ = self.get_node_mixture() sizes = [] @@ -663,29 +849,20 @@ def descend(super_root): return descend(self.root) - def get_nodes(self, root_node=None, parent_vector=False): - def descend(root, idx=0, prev_idx=-1): - idx = idx + 1 - node = [root] - parent_idx = [prev_idx] - parent_idx = [prev_idx] - prev_idx = idx - children = list(root.children()) - labs = [c.label for c in children] - children = np.array(children)[np.argsort(labs)] - for i, child in enumerate(children): - nodes, idx, parents_idx = descend(child, idx, prev_idx - 1) - node.extend(nodes) - parent_idx.extend(parents_idx) - return node, idx, parent_idx - - if root_node is None: - root_node = self.root["node"].root["node"] - node_list, _, parent_list = descend(root_node) - if parent_vector: - return node_list, parent_list - else: - return node_list + def get_nodes(self): + def descend(root): + nodes = [root['node']] + for child in root['children']: + nodes.extend(descend(child)) + return nodes + + def super_descend(root): + nodes = descend(root['node'].root) + for child in root['children']: + nodes.extend(super_descend(child)) + return nodes + + return super_descend(self.root) def get_width_distribution(self): def descend(root, depth, width_vec): @@ -812,136 +989,6 @@ def vbis_estimate(self, importance_samples=None, n_samples=1000): # ========= Functions to update tree 
parameters given data. ========= - def get_node_mean(self, log_baseline, unobserved_factors, noise, cnvs): - node_mean = jnp.exp( - log_baseline + unobserved_factors + noise + jnp.log(cnvs / 2) - ) - sum = jnp.sum(node_mean, axis=1).reshape(self.num_data, 1) - node_mean = node_mean / sum - return node_mean - - def get_tssb_indices(self, nodes, tssbs): - # start = time.time() - max_len = self.max_nodes - tssb_indices = [] - for node in nodes: - tssb_indices.append( - np.array([i for i, tssb in enumerate(tssbs) if tssb == node.tssb.label]) - ) - # if len(tssb_indices[-1].shape[0]) > max_len: - # max_len = len(tssb_indices[-1].shape[0]) - - for i, c in enumerate(tssb_indices): - l = c.shape[0] - if l < max_len: - c = np.concatenate([c, np.array([-1] * (max_len - l))]) - tssb_indices[i] = c - tssb_indices = jnp.array(tssb_indices).astype(int) - # end = time.time() - # print(f"get_tssb_indices: {end-start}") - return tssb_indices - - def get_below_root(self, root_idx, children_vector, tssbs=None): - def descend(idx): - below_root = [idx] - for child_idx in children_vector[idx]: - if child_idx > 0: - if tssbs is not None: - if tssbs[child_idx] == tssbs[root_idx]: - aux = descend(child_idx) - below_root.extend(aux) - else: - aux = descend(child_idx) - below_root.extend(aux) - return below_root - - return np.array(descend(root_idx)) - - @partial(jax.jit, static_argnums=0) - def get_children_vector(self, parent_vector): - def f(i): - return jnp.where(parent_vector == i, size=self.max_nodes, fill_value=-1)[0] - - return jax.vmap(f)(jnp.arange(self.max_nodes)) - - def get_ancestor_indices(self, nodes, parent_vector, inclusive=False): - # start = time.time() - ancestor_indices = [] - max_len = self.max_nodes - for i in range(len(nodes)): - # get ancestor nodes in the same subtree - p = i - indices = [] - while p != -1 and nodes[p].tssb == nodes[i].tssb: - if not (not inclusive and p == i): - indices.append(p) - p = parent_vector[p] - - indices = np.array(indices) - ancestor_indices.append(indices) - # if len(indices) > max_len: - # max_len = len(indices) - - for i, c in enumerate(ancestor_indices): - l = c.shape[0] - if l < max_len: - c = np.concatenate([c, np.array([-1] * (max_len - l))]) - ancestor_indices[i] = c - ancestor_indices = jnp.array(ancestor_indices).astype(int) - # end = time.time() - # print(f"get_ancestor_indices: {end-start}") - return ancestor_indices - - def get_previous_branches_indices(self, nodes, within_tssb=True): - # start = time.time() - previous_branches_indices = [] - max_len = self.max_nodes - for node in nodes: - indices = [] - if within_tssb and not node.is_observed: - children = np.array(list(node.parent().children())) - labs = [l.label for l in children] - children = children[np.argsort(labs)] - for j, prev_child in enumerate(children): - if prev_child.is_observed: - continue - if prev_child == node: - break - # Locate prev_child in nodes list - for idx, n_ in enumerate(nodes): - if n_ == prev_child: - indices.append(idx) - break - elif not within_tssb: - if node.parent() is not None: - children = np.array(list(node.parent().children())) - if len(children) > 0: - labs = [l.label for l in children] - children = children[np.argsort(labs)] - for j, prev_child in enumerate(children): - if prev_child.is_observed: - continue - if prev_child == node: - break - # Locate prev_child in nodes list - for idx, n_ in enumerate(nodes): - if n_ == prev_child: - indices.append(idx) - break - previous_branches_indices.append(np.array(indices)) - # if len(indices) > max_len: - # max_len 
= len(indices) - - for i, c in enumerate(previous_branches_indices): - l = c.shape[0] - if l < max_len: - c = np.concatenate([c, np.array([-1] * (max_len - l))]) - previous_branches_indices[i] = c - previous_branches_indices = jnp.array(previous_branches_indices).astype(int) - # end = time.time() - # print(f"get_previous_branches_indices: {end-start}") - return previous_branches_indices - def Eq_log_p_nu(self, dp_alpha, nu_sticks_alpha, nu_sticks_beta): l = 0 aux = digamma(nu_sticks_beta) - digamma(nu_sticks_alpha + nu_sticks_beta) @@ -999,1292 +1046,664 @@ def Eq_log_q_tau(self, tau_alpha, tau_beta): ) return l - def tssb_log_priors(self): - nodes, parent_vector = self.get_nodes(root_node=None, parent_vector=True) - tssb_weights = jnp.array([node.tssb.weight for node in nodes]) - init_nu_log_alphas = jnp.array([node.nu_log_alpha for node in nodes]) - init_nu_log_betas = jnp.array([node.nu_log_beta for node in nodes]) - init_psi_log_alphas = jnp.array([node.psi_log_alpha for node in nodes]) - init_psi_log_betas = jnp.array([node.psi_log_beta for node in nodes]) - ancestor_nodes_indices = self.get_ancestor_indices(nodes, parent_vector) - previous_branches_indices = self.get_previous_branches_indices(nodes) - - rem = self.max_nodes - len(nodes) - init_psi_log_betas = jnp.concatenate( - [init_psi_log_betas, -1 * jnp.ones((rem,))] - ) - init_psi_log_alphas = jnp.concatenate( - [init_psi_log_alphas, -1 * jnp.ones((rem,))] - ) - init_nu_log_betas = jnp.concatenate([init_nu_log_betas, -1 * jnp.ones((rem,))]) - init_nu_log_alphas = jnp.concatenate( - [init_nu_log_alphas, -1 * jnp.ones((rem,))] - ) - tssb_weights = jnp.concatenate([tssb_weights, 10 * jnp.ones((rem,))]) - previous_branches_indices = jnp.concatenate( - [ - previous_branches_indices, - -1 * jnp.ones((rem, previous_branches_indices.shape[1])), - ], - axis=0, - ).astype(int) - ancestor_nodes_indices = jnp.concatenate( - [ - ancestor_nodes_indices, - -1 * jnp.ones((rem, ancestor_nodes_indices.shape[1])), - ], - axis=0, - ).astype(int) - - nu_sticks = jnp.exp(init_nu_log_alphas) / ( - jnp.exp(init_nu_log_alphas) + jnp.exp(init_nu_log_betas) - ) - psi_sticks = jnp.exp(init_psi_log_alphas) / ( - jnp.exp(init_psi_log_alphas) + jnp.exp(init_psi_log_betas) - ) + def make_batches(self, batch_size=None, seed=42): + if batch_size is None: + batch_size = self.num_data + self.batch_size = batch_size - logpis = [] - for i, node in enumerate(nodes): - logpis.append( - self.tssb_log_prior( - i, - nu_sticks, - psi_sticks, - previous_branches_indices, - ancestor_nodes_indices, - tssb_weights, - ) - ) - ws = list(jnp.exp(np.array(logpis))) - return list(np.array(logpis)), ws + rng = np.random.RandomState(seed) + perm = rng.permutation(self.num_data) - def tssb_log_prior( - self, - i, - nu_sticks, - psi_sticks, - previous_branches_indices, - ancestor_nodes_indices, - tssb_weights, - ): - # TSSB prior - nu_stick = nu_sticks[i] - psi_stick = psi_sticks[i] - - def prev_branches_psi(idx): - return (idx != -1) * jnp.log(1.0 - psi_sticks[idx]) + num_complete_batches, leftover = divmod(self.num_data, self.batch_size) + self.num_batches = num_complete_batches + bool(leftover) - def ancestors_nu(idx): - _log_phi = jnp.log(psi_sticks[idx]) + jnp.sum( - vmap(prev_branches_psi)(previous_branches_indices[idx]) - ) - _log_1_nu = jnp.log(1.0 - nu_sticks[idx]) - total = _log_phi + _log_1_nu - return (idx != -1) * total + self.batch_indices = [] + for i in range(self.num_batches): + batch_idx = perm[i * self.batch_size : (i + 1) * self.batch_size] + 
self.batch_indices.append(batch_idx) - log_phi = jnp.log(psi_stick) + jnp.sum( - vmap(prev_branches_psi)(previous_branches_indices[i]) - ) - log_node_weight = ( - jnp.log(nu_stick) - + log_phi - + jnp.sum(vmap(ancestors_nu)(ancestor_nodes_indices[i])) - ) - log_node_weight = log_node_weight + jnp.log(tssb_weights[i]) + self.reset_sufficient_statistics() - return log_node_weight + def get_top_node_obs(self, q=70): + """ + Get data which is very well explained by the node they attach to + """ + def sub_descend(root): + # Get cells attached to this node + idx = np.where(self.assignments == root['node'])[0] + top_obs = root['node'].get_top_obs(q=q, idx=idx) + for child in root['children']: + top_obs = np.concatenate([top_obs,sub_descend(child)]) + return top_obs - def optimize_elbo( - self, - update_all=False, - global_only=False, - sticks_only=False, - n_iters=20, - run=True, - n_inner_steps=50, - n_local_traverses=1, - **update_kwargs, - ): - # Init - if not run: - self.update_elbo(update=False, root=self.root, - sub_root=None, n_traverses=1, update_global=False, compute_global=True, - n_inner_steps=0, mb_size=self.data.shape[0]) - self.assign_to_best() - return [self.elbo] - - # Full: alternate between globals and locals for n_iters - elbos = [] - if update_all: - logger.debug("Updating global and node parameters") - update = True - if sticks_only: - update = False - logger.debug("Updating only sticks") - for i in range(n_iters): - # Globals - self.update_elbo(update=False, n_traverses=3, update_global=True, - compute_global=True, n_inner_steps=0, **update_kwargs) - # Locals - self.update_elbo(update=update, n_traverses=n_local_traverses, update_global=False, - compute_global=False, n_inner_steps=n_inner_steps, **update_kwargs) - elbos.append(self.elbo) + def descend(root): + top_obs = sub_descend(root['node'].root) + for child in root['children']: + top_obs = np.concatenate([top_obs, descend(child)]) + return top_obs + + top_obs = descend(self.root) + top_obs = np.unique(top_obs).astype(int) + return top_obs + + def compute_elbo(self, memoized=True, batch_idx=None, **kwargs): + if memoized: + return self.compute_elbo_suff() else: - if update_kwargs["root"] is not None: - update = True - update_global = False - compute_global = False - logger.debug("Updating from root") - if global_only: - n_inner_steps = 0 - update = False - update_global = True - compute_global = True - logger.debug("Updating only global parameters") - elif sticks_only: - n_inner_steps = 0 - logger.debug("Updating only stick parameters") - elbos = self.update_elbo(update=update, n_traverses=20, update_global=update_global, - compute_global=compute_global, n_inner_steps=n_inner_steps, **update_kwargs) - # Local only - elif update_kwargs["sub_root"] is not None: - if sticks_only: - n_inner_steps = 0 - nlabel = update_kwargs["sub_root"]["node"].label - logger.debug(f"Updating parameters below {nlabel}") - elbos = self.update_elbo(update=True, n_traverses=20, update_global=False, - compute_global=False, n_inner_steps=n_inner_steps, **update_kwargs) - self.assign_to_best() - self.plot_tree(counts=True) - return elbos - - def update_elbo(self, update=True, root=None, sub_root=None, go_down=True, compute_global=False, update_global=False, restricted=False, mb_size=128, n_traverses=1, n_inner_steps=10, mc_samples=3, lr=0.01): - if root is None and sub_root is None: - raise ValueError("`root` and `sub_root` can't both be None!") - - n_cells, n_genes = self.data.shape - data_indices = np.arange(n_cells) - res_data_indices = 
np.arange(n_cells) - mask = np.ones((n_cells,)) - probs = np.ones((n_cells,)) - probs = probs / np.sum(probs) - - n_factors = self.root["node"].root["node"].num_global_noise_factors - bs_grads = jnp.zeros((2,n_genes-1)) - noise_grads = jnp.zeros((2,n_factors,n_genes)) - cellnoise_grads = jnp.zeros((2,n_cells,n_factors)) - - # Get global variational parameters - log_baseline_mean = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"]) - log_baseline_log_std = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"]) - noise_factors_mean = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"]) - noise_factors_log_std = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"]) - - def sub_update_params(root, data_indices, mask, depth=0): - data = self.data[data_indices] - lib_sizes = self.root["node"].root["node"].lib_sizes[data_indices] - def _sub_update_params(root, depth=0): - if root["node"].parent() is None: - parent_unobserved_samples = jnp.zeros((mc_samples, n_genes)) - unobserved_samples = jnp.zeros((mc_samples, n_genes)) - unobserved_kernel_samples = jnp.zeros((mc_samples, n_genes)) - else: - # Sample parent - if root["node"].parent().parent() is None: - parent_unobserved_samples = jnp.zeros((mc_samples, n_genes)) - else: - parent_unobserved_means = root["node"].parent().variational_parameters["locals"]["unobserved_factors_mean"] - parent_unobserved_log_stds = root["node"].parent().variational_parameters["locals"]["unobserved_factors_log_std"] - parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds) - - unobserved_means = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_mean"]) - unobserved_log_stds = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_log_std"]) - unobserved_factors_kernel_log_mean = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]) - unobserved_factors_kernel_log_std = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"]) - - m1 = jnp.zeros_like(unobserved_means) - v1 = jnp.zeros_like(unobserved_means) - state1 = (m1,v1) - - m2 = jnp.zeros_like(unobserved_means) - v2 = jnp.zeros_like(unobserved_means) - state2 = (m2,v2) - - m3 = jnp.zeros_like(unobserved_means) - v3 = jnp.zeros_like(unobserved_means) - state3 = (m3,v3) - - m4 = jnp.zeros_like(unobserved_means) - v4 = jnp.zeros_like(unobserved_means) - state4 = (m4,v4) - states = (state1, state2, state3, state4) - local_grad = obs_ll_grad + parent_dep - for i in range(n_inner_steps): - loss, states, unobserved_means, unobserved_log_stds, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std = update_local_parameters(rngs, - unobserved_means, - unobserved_log_stds, - unobserved_factors_kernel_log_mean, - unobserved_factors_kernel_log_std, - root["node"].data_weights[data_indices] * root["node"].tssb.data_weights[data_indices], - parent_unobserved_samples, - baseline_samples, - cell_noise_samples, - noise_factor_samples, - jnp.array([self.root["node"].root["node"].unobserved_factors_kernel_concentration, self.root["node"].root["node"].unobserved_factors_kernel_rate]), # [concentration, rate] - jnp.array(0.), # to make the prior prefer amplifications - root["node"].cnvs, - lib_sizes, - data, - mask, - states, - i, - mb_scaling=np.sum(mask)/n_cells, - 
lr=lr, - ) - root["node"].variational_parameters["locals"]["unobserved_factors_mean"] = np.array(unobserved_means) - root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] = np.array(unobserved_log_stds) - root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] = np.array(unobserved_factors_kernel_log_mean) - root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"] = np.array(unobserved_factors_kernel_log_std) - - root["node"].set_mean(variational=True) - - unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds) - unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std) + return self.compute_elbo_batch(batch_idx=batch_idx) + def compute_elbo_batch(self, batch_idx=None): + """ + Compute the ELBO of the model in a tree traversal, abstracting away the likelihood and kernel specific functions + for the model. The seed is used for MC sampling from the variational distributions for which Eq[logp] is generally not analytically + available (which is the likelihood and the kernel distribution). - # Compute local approximate expected log likelihood term -- unweighted - updated_indices = data_indices[np.where(mask)[0]] - ll_res = ll(baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples, root["node"].cnvs, lib_sizes, data)[np.where(mask)[0]] - root["node"].ll[updated_indices] = ll_res - - # Compute local approximate KL divergence term - if root["node"].parent() is None: - root["node"].param_kl = 0. - else: - root["node"].param_kl = local_paramkl(parent_unobserved_samples, unobserved_samples, unobserved_kernel_samples, - unobserved_means, - unobserved_log_stds, - unobserved_factors_kernel_log_mean, - unobserved_factors_kernel_log_std, - jnp.array([self.root["node"].root["node"].unobserved_factors_kernel_concentration, self.root["node"].root["node"].unobserved_factors_kernel_rate]), jnp.array(0.)) - - bs_grads = 0. - noise_grads = 0. - cellnoise_grads = 0. 
- if update_global: - # Compute gradient of node ell wrt globals - data_weights = root["node"].data_weights[data_indices] * root["node"].tssb.data_weights[data_indices] - bs_grads = baseline_node_grad(rngs, log_baseline_mean, log_baseline_log_std, data_weights, unobserved_samples, cell_noise_samples, noise_factor_samples, root["node"].cnvs, lib_sizes, data, mask) - noise_grads = noise_node_grad(rngs, noise_factors_mean, noise_factors_log_std, data_weights, unobserved_samples, cell_noise_samples, baseline_samples, root["node"].cnvs, lib_sizes, data, mask) - cellnoise_grads = cellnoise_node_grad(rngs, cell_noise_mean, cell_noise_log_std, data_weights, unobserved_samples, noise_factor_samples, baseline_samples, root["node"].cnvs, lib_sizes, data, mask) - - weight_down = 0 - indices = list(range(len(root["children"]))) - indices = indices[::-1] - - for i in indices: - child = root["children"][i] - # Go down in the tree and get its weight - child_weight, child_bs_grads, child_noise_grads, child_cellnoise_grads = _sub_update_params(child, depth + 1) - if update: - post_alpha = 1.0 + child_weight - post_beta = self.dp_gamma + weight_down - child["node"].variational_parameters["locals"]["psi_log_mean"] = np.log(post_alpha) - child["node"].variational_parameters["locals"]["psi_log_std"] = np.log(post_beta) - weight_down += child_weight - bs_grads += child_bs_grads - noise_grads += child_noise_grads - cellnoise_grads += child_cellnoise_grads - - # Compute local exact KL divergence term - child["node"].psi_stick_kl = -beta_kl(np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"]), np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"]), 1, root["node"].tssb.dp_gamma) - - weight_here = np.sum(root["node"].data_weights) - total_weight = weight_here + weight_down - if update: - post_alpha = 1.0 + weight_here - post_beta = (self.alpha_decay**depth) * self.dp_alpha + weight_down - root["node"].variational_parameters["locals"]["nu_log_mean"] = np.log(post_alpha) - root["node"].variational_parameters["locals"]["nu_log_std"] = np.log(post_beta) - root["node"].total_weight = total_weight - root["node"].weight_down = weight_down - root["node"].weight_here = weight_here - root["node"].nu_stick_kl = -beta_kl(np.exp(root["node"].variational_parameters["locals"]["nu_log_mean"]), np.exp(root["node"].variational_parameters["locals"]["nu_log_std"]), 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha) - - return total_weight, bs_grads, noise_grads, cellnoise_grads - return _sub_update_params(root, depth=depth) - - def sub_update_weights(root): - node = [root["node"]] - - nu_alpha = np.exp(root["node"].variational_parameters["locals"]["nu_log_mean"]) - nu_beta = np.exp(root["node"].variational_parameters["locals"]["nu_log_std"]) - psi_alpha = np.exp(root["node"].variational_parameters["locals"]["psi_log_mean"]) - psi_beta = np.exp(root["node"].variational_parameters["locals"]["psi_log_std"]) - - # Compute expected local NTSSB weight term - E_log_psi = E_q_log_beta(psi_alpha, psi_beta) - E_log_nu = E_q_log_beta(nu_alpha, nu_beta) - E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta) - - if root["node"].is_observed: - ancestors_E_log_1_nu = 0. - ancestors_and_this_E_log_phi = 0. - root["node"].ancestors_and_this_E_log_1_nu = E_log_1_nu - root["node"].ancestors_and_this_E_log_phi = 0. - root["node"].psi_not_prev_sum = 0. + If batch_idx is not None, return an estimate of the ELBO based on just the subset of data in batch_idx. + Otherwise, use sufficient statistics. 
+ """ + if batch_idx is None: + idx = jnp.arange(self.num_data) + else: + idx = self.batch_indices[batch_idx] + def descend(root, depth=0, local_contrib=0, global_contrib=0): + # Traverse inner TSSB + subtree_ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo(idx) + ll_contrib = subtree_ll_contrib * root['node'].variational_parameters['q_c'][idx] + + # Assignments + ## E[log p(c|nu,psi)] + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + eq_logp_c = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + eq_logp_c += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + eq_logp_c += root['node'].variational_parameters['E_log_phi'] + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu'] else: - ancestors_E_log_1_nu = root["node"].parent().ancestors_and_this_E_log_1_nu - root["node"].ancestors_and_this_E_log_1_nu = ancestors_E_log_1_nu + E_log_nu - root["node"].ancestors_and_this_E_log_phi = root["node"].parent().ancestors_and_this_E_log_phi + root["node"].psi_not_prev_sum + E_log_psi - ancestors_and_this_E_log_phi = root["node"].ancestors_and_this_E_log_phi - - # root["node"].weight_until_here = root["node"].data_weights + until_here - # root["node"].prior_weight = root["node"].data_weights*E_log_nu + until_here*E_log_1_nu + root["node"].weight_until_here*E_log_phi - root["node"].prior_weight = E_log_nu + ancestors_E_log_1_nu + ancestors_and_this_E_log_phi - root["node"].ew = root["node"].data_weights*E_log_nu + root["node"].weight_down*E_log_1_nu + root["node"].total_weight*E_log_psi - # root["node"].data_weights = root["node"].ll + root["node"].prior_weight - root["node"].unnormalized_data_weights = root["node"].ll + root["node"].prior_weight - - data_weight = [root["node"].unnormalized_data_weights] - psi_not_prev_sum = 0 - for i, child in enumerate(root["children"]): - nu_alpha = np.exp(child["node"].variational_parameters["locals"]["nu_log_mean"]) - nu_beta = np.exp(child["node"].variational_parameters["locals"]["nu_log_std"]) - psi_alpha = np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"]) - psi_beta = np.exp(child["node"].variational_parameters["locals"]["psi_log_std"]) - - if i > 0: - psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta) - child["node"].psi_not_prev_sum = psi_not_prev_sum - - # Go down in the tree - nodes, data_weights = sub_update_weights(child) - node.extend(nodes) - data_weight.extend(data_weights) - return node, data_weight - - def sub_normalize_update_elbo(nodes): - data_weights = np.vstack([node.unnormalized_data_weights for node in nodes]).T - data_weights = np.exp(data_weights - jnn.logsumexp(data_weights,axis=1).reshape(-1,1)) - tree_ell = 0 - tree_ew = 0 - tree_kl = 0 - for i, node in enumerate(nodes): - node.data_weights = data_weights[:,i] - node.ell = node.data_weights*node.ll - tree_ell += node.ell - tree_ew += node.data_weights*node.prior_weight - tree_kl += node.nu_stick_kl + node.psi_stick_kl + node.param_kl - return tree_ell, tree_ew, tree_kl - - def tssb_normalize_update_elbo(tssbs): - data_weights = np.vstack([tssb.unnormalized_data_weights for tssb in tssbs]).T - data_weights = np.exp(data_weights - jnn.logsumexp(data_weights,axis=1).reshape(-1,1)) - total_elbo = 0 - total_ell = 0 - total_kl = 0 - for i, tssb in enumerate(tssbs): - 
tssb.data_weights = data_weights[:,i] - tssb.elbo = np.sum(tssb.data_weights * tssb.ell + tssb.data_weights * jnp.log(tssb.weight) + tssb.ew) + tssb.kl - total_elbo += tssb.elbo - total_ell += np.sum(tssb.data_weights * tssb.ell + tssb.data_weights * jnp.log(tssb.weight) + tssb.ew) - total_kl += tssb.kl - return total_elbo, total_ell, total_kl - - def descend(super_root, elbo=0): - _, bs_grads, noise_grads, cellnoise_grads = sub_update_params(super_root["node"].root, data_indices, mask)#sub_update_params(super_root["node"].root) - - nodes, data_weights = sub_update_weights(super_root["node"].root) - tree_ell, tree_ew, tree_kl = sub_normalize_update_elbo(nodes) - super_root["node"].ell = tree_ell - super_root["node"].ew = tree_ew - super_root["node"].kl = tree_kl - - # Update TSSB assignment - super_root["node"].unnormalized_data_weights = tree_ell + np.log(super_root["node"].weight) - - for super_child in super_root["children"]: - child_bs_grads, child_noise_grads, child_cellnoise_grads = descend(super_child) - bs_grads += child_bs_grads - noise_grads += child_noise_grads - cellnoise_grads += child_cellnoise_grads - - return bs_grads, noise_grads, cellnoise_grads - + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + ## E[log q(c)] + eq_logq_c = jax.lax.select(root['node'].variational_parameters['q_c'][idx] != 0, + root['node'].variational_parameters['q_c'][idx] * jnp.log(root['node'].variational_parameters['q_c'][idx]), + root['node'].variational_parameters['q_c'][idx]) + ass_contrib = eq_logp_c*root['node'].variational_parameters['q_c'][idx] - eq_logq_c + subtree_ass_contrib * root['node'].variational_parameters['q_c'][idx] + + # Sticks + E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu + nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu + nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth) + psi_kl = 0. + if depth != 0: + E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi + psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi + psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl -= logbeta_func(1, self.dp_gamma) + stick_contrib = nu_kl + psi_kl + + self.n_total_nodes += root['node'].n_nodes + sum_E_log_1_psi = 0. 
+ for child in root['children']: + # Auxiliary quantities + ## Branches + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + local_contrib, global_contrib = descend(child, depth=depth+1, local_contrib=local_contrib, global_contrib=global_contrib) + + local_contrib += ll_contrib + ass_contrib + global_contrib += subtree_node_contrib + stick_contrib + return local_contrib, global_contrib + + self.n_total_nodes = 0 + local_contrib, global_contrib = descend(self.root) + + # Add tree-independent contributions + global_contrib += self.num_data/len(idx) * (self.root['node'].root['node'].compute_local_priors(idx) + self.root['node'].root['node'].compute_local_entropies(idx)) + global_contrib += self.root['node'].root['node'].compute_global_priors() + self.root['node'].root['node'].compute_global_entropies() + + elbo = self.num_data/len(idx) * np.sum(local_contrib) + global_contrib + self.elbo = elbo + return elbo + + def compute_elbo_suff(self): + def descend(root, depth=0, local_contrib=0, global_contrib=0): + # Traverse inner TSSB + ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo_suff() + + # Assignments + ## E[log p(c|nu,psi)] + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + eq_logp_c = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + eq_logp_c += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + eq_logp_c += root['node'].variational_parameters['E_log_phi'] + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu'] + else: + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + ass_contrib = eq_logp_c*root['node'].suff_stats['mass']['total'] + root['node'].suff_stats['ent']['total'] + subtree_ass_contrib + + # Sticks + E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu + nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu + nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth) + psi_kl = 0. 
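+            # As in compute_elbo_batch above: depth 0 is the root TSSB, which has no psi
+            # stick, so only the nu term contributes there. Despite the names, nu_kl and
+            # psi_kl are the negative KL terms E_q[log p] - E_q[log q] of the Beta sticks
+            # (q(nu) = Beta(delta_1, delta_2) vs. the prior Beta(1, dp_alpha * alpha_decay**depth),
+            # q(psi) = Beta(sigma_1, sigma_2) vs. Beta(1, dp_gamma)), so they are added
+            # directly to the ELBO. E_log_beta, E_log_1_beta and logbeta_func are presumably
+            # the digamma/log-Beta-function helpers from utils.math_utils, e.g.
+            # E[log X] = digamma(a) - digamma(a + b) for X ~ Beta(a, b).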
+ if depth != 0: + E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi + psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi + psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl -= logbeta_func(1, self.dp_gamma) + stick_contrib = nu_kl + psi_kl + + self.n_total_nodes += root['node'].n_nodes + sum_E_log_1_psi = 0. + for child in root['children']: + # Auxiliary quantities + ## Branches + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + local_contrib, global_contrib = descend(child, depth=depth+1, local_contrib=local_contrib, global_contrib=global_contrib) + + local_contrib += ll_contrib + ass_contrib + global_contrib += subtree_node_contrib + stick_contrib + return local_contrib, global_contrib + + self.n_total_nodes = 0 + local_contrib, global_contrib = descend(self.root) + + # Add tree-independent contributions + global_contrib += self.root['node'].root['node'].local_suff_stats['locals_kl']['total'] + global_contrib += self.root['node'].root['node'].compute_global_priors() + self.root['node'].root['node'].compute_global_entropies() + + elbo = local_contrib + global_contrib + self.elbo = elbo + return elbo + + def learn_model(self, n_epochs, seed=42, memoized=True, update_roots=True, update_globals=True, adaptive=True, return_trace=False, + locals_names=None, globals_names=None, **kwargs): + key = jax.random.PRNGKey(seed) elbos = [] - ells = [] - kls = [] - if root: - if update_global: - m1 = jnp.zeros((n_genes-1,)) - v1 = jnp.zeros((n_genes-1,)) - state1 = (m1,v1) - m2 = jnp.zeros((n_genes-1,)) - v2 = jnp.zeros((n_genes-1,)) - state2 = (m2,v2) - bs_states = (state1, state2) - - m1 = jnp.zeros((n_factors,n_genes)) - v1 = jnp.zeros((n_factors,n_genes)) - state1 = (m1,v1) - m2 = jnp.zeros((n_factors,n_genes)) - v2 = jnp.zeros((n_factors,n_genes)) - state2 = (m2,v2) - noise_states = (state1, state2) - - m1 = jnp.zeros((n_cells,n_factors)) - v1 = jnp.zeros((n_cells,n_factors)) - state1 = (m1,v1) - m2 = jnp.zeros((n_cells,n_factors)) - v2 = jnp.zeros((n_cells,n_factors)) - state2 = (m2,v2) - cellnoise_states = (state1, state2) - for i in range(n_traverses): - rng = random.PRNGKey(i) - rngs = random.split(rng, mc_samples) - - if update or update_global: - # Setup minibatch optimization - mask = np.ones((n_cells,)) - mask[res_data_indices] = 1 - data_indices = np.sort(random.choice(rng, n_cells, shape=(mb_size,), p=probs, replace=False)) - mask = mask[data_indices] - - # Get cell variational parameters - cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices] - cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices] - cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) - - # Get data - data = self.data[data_indices] - lib_sizes = 
self.root["node"].root["node"].lib_sizes[data_indices] - - # Sample global - baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) - noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std) - - if compute_global: - # Compute global KL - self.global_kl = -jnp.sum(baseline_kl(log_baseline_mean, log_baseline_log_std)) - self.global_kl += -jnp.sum(noise_factors_kl(noise_factors_mean, noise_factors_log_std)) - self.cell_kl = -jnp.sum(cell_noise_kl(cell_noise_mean, cell_noise_log_std)) - self.global_kl += self.cell_kl - - # Traverse tree - bs_grads, noise_grads, cellnoise_grads = descend(root) - - # Normalize across-tssbs and update global ELBOs - self.elbo, ell, kl = tssb_normalize_update_elbo(self.get_subtrees()) - self.elbo += self.global_kl - elbos.append(self.elbo) - - if update_global: - # Update baseline - bs_klgrad = baseline_kl_grad(log_baseline_mean, log_baseline_log_std) - bs_klgrad *= len(data_indices)/n_cells # minibatch scaling - bs_grads += bs_klgrad - bs_states, log_baseline_mean, log_baseline_log_std = baseline_step(log_baseline_mean, log_baseline_log_std, bs_grads, bs_states, i, lr=lr) - # - # # Update noise factors - # noise_klgrad = noise_kl_grad(noise_factors_mean, noise_factors_log_std) - # noise_klgrad *= len(data_indices)/n_cells # minibatch scaling - # noise_grads += noise_klgrad - # noise_states, noise_factors_mean, noise_factors_log_std = noise_step(noise_factors_mean, noise_factors_log_std, noise_grads, noise_states, i, lr=lr) - # - # # Update cell noise - # cellnoise_klgrad = cellnoise_kl_grad(cell_noise_mean, cell_noise_log_std) - # cellnoise_grads += cellnoise_klgrad - # local_state1, local_state2 = cellnoise_states - # local_state1 = (local_state1[0][data_indices], local_state1[1][data_indices]) - # local_state2 = (local_state2[0][data_indices], local_state2[1][data_indices]) - # local_states = (local_state1, local_state2) - # local_states, cell_noise_mean, cell_noise_log_std = cellnoise_step(cell_noise_mean, cell_noise_log_std, cellnoise_grads, local_states, i, lr=lr) - # cellnoise_states[0][0].at[data_indices].set(local_states[0][0]) - # cellnoise_states[0][1].at[data_indices].set(local_states[0][1]) - # cellnoise_states[1][0].at[data_indices].set(local_states[1][0]) - # cellnoise_states[1][1].at[data_indices].set(local_states[1][1]) - - self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"] = log_baseline_mean - self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"] = log_baseline_log_std - - # self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"] = noise_factors_mean - # self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"] = noise_factors_log_std - # - # self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices] = cell_noise_mean - # self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices] = cell_noise_log_std - - if sub_root: - rng = random.PRNGKey(0) - rngs = random.split(rng, mc_samples) - - # Sample global - baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std) - noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std) - - if go_down: - this_root = sub_root["node"].tssb.get_ntssb_root() - - if restricted: - # Get data in nodes below current one - res_data_indices = set() - def data_below(node_obj): - 
res_data_indices.update(node_obj.data) - for child in node_obj.children(): - if child.tssb != sub_root["node"].tssb: - if go_down: - data_below(child) - else: - data_below(child) - data_below(sub_root["node"]) - res_data_indices = np.sort(jnp.array(list(res_data_indices))) - if len(res_data_indices) == 0: - res_data_indices = np.arange(n_cells) - mask[res_data_indices] = 1 - probs = np.ones((n_cells,)) - probs[res_data_indices] = 1e6 - probs = probs / np.sum(probs) - - for i in range(n_traverses): - rng = random.PRNGKey(i) - rngs = random.split(rng, mc_samples) - mask = np.zeros((n_cells,)) - mask[res_data_indices] = 1 - data_indices = np.sort(random.choice(rng, n_cells, shape=(mb_size,), p=probs, replace=False)) - mask = mask[data_indices] - - # Get cell variational parameters - cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices] - cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices] - cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std) - - # Get data - data = self.data[data_indices] - lib_sizes = self.root["node"].root["node"].lib_sizes[data_indices] - - # Update node parameters - sub_update_params(sub_root, data_indices, mask, depth=len(sub_root["node"].label)-1) - - # Update within-tssb data assignment probabilities - nodes, data_weights = sub_update_weights(sub_root) - - # Normalize within-tssb and update local ELBOs - # This can be done approximately: no need to re-normalize across all the others if their relative contributions - # didn't change and we end up with weights close to 0. This applies only to restricted ELBO updates on selected data - tree_ell, tree_ew, tree_kl = sub_normalize_update_elbo(sub_root["node"].tssb.get_mixture()[1]) - sub_root["node"].tssb.ell = tree_ell - sub_root["node"].tssb.ew = tree_ew - sub_root["node"].tssb.kl = tree_kl - - # Update TSSB assignment - sub_root["node"].tssb.unnormalized_data_weights = tree_ell + np.log(sub_root["node"].tssb.weight) - - # Maybe continue down to any children subtrees? 
- if go_down: - for sub_child in this_root["children"]: - # Only continue to children subtrees in sub_root's path - if sub_root["node"].label in sub_child["pivot_node"].label: - logger.debug(f"Also updating parameters of {sub_child['node'].label} and down") - descend(sub_child) - - # Normalize across-tssbs and update global ELBOs - self.elbo, ell, kl = tssb_normalize_update_elbo(self.get_subtrees()) - self.elbo += self.global_kl - elbos.append(self.elbo) - ells.append(ell) - kls.append(kl) - - return elbos, ells, kls - - def _optimize_elbo( - self, - root_node=None, - local_node=None, - global_only=False, - sticks_only=False, - unique_node=None, - num_samples=10, - n_iters=100, - thin=10, - step_size=0.05, - debug=False, - tol=1e-5, - run=True, - max_nodes=5, - init=False, - opt=None, - opt_triplet=None, - mb_size=100, - callback=None, - **callback_kwargs, - ): - # start = time.time() - self.max_nodes = ( - len(self.input_tree_dict.keys()) * max_nodes - ) # upper bound on number of nodes - self.data = jnp.array(self.data, dtype="float32") + states = None + local_states = None + if adaptive: + local_states = self.root['node'].root['node'].initialize_local_opt_states(param_names=locals_names) + states = self.root['node'].root['node'].initialize_global_opt_states(param_names=globals_names) + it = 0 + idx = None + for i in range(n_epochs): + for batch_idx in range(self.num_batches): + key, subkey = jax.random.split(key) + local_states = self.update_local_params(subkey, batch_idx=batch_idx, adaptive=adaptive, states=local_states, i=it, + param_names=locals_names, update_globals=update_globals, **kwargs) + if update_globals: + states = self.update_global_params(subkey, idx=idx, batch_idx=batch_idx, adaptive=adaptive, states=states, i=it, param_names=globals_names, **kwargs) + if memoized: + self.update_sufficient_statistics(batch_idx=batch_idx) + self.update_node_params(subkey, i=it, adaptive=adaptive, memoized=memoized, **kwargs) + if update_roots: + self.update_root_node_params(subkey, memoized=memoized, batch_idx=batch_idx, adaptive=adaptive, i=it, **kwargs) + self.update_pivot_probs() + it += 1 + if return_trace: + elbos.append(self.compute_elbo(memoized=memoized, batch_idx=batch_idx)) + + if return_trace: + return elbos - # Var params of nodes below root - nodes, parent_vector = self.get_nodes(root_node=None, parent_vector=True) + def learn_roots(self, n_epochs, seed=42, adaptive=True, return_trace=False, **kwargs): + key = jax.random.PRNGKey(seed) + elbos = [] + it = 0 + for i in range(n_epochs): + for batch_idx in range(self.num_batches): + key, subkey = jax.random.split(key) + self.update_root_node_params(subkey, batch_idx=batch_idx, adaptive=adaptive, i=it, **kwargs) + if return_trace: + elbos.append(self.compute_elbo(batch_idx=batch_idx, **kwargs)) + it += 1 + + if return_trace: + return elbos - n_nodes = len(nodes) - rem = self.max_nodes - n_nodes + def learn_globals(self, n_epochs, globals_names=None, locals_names=None, ass_anneal=1., ent_anneal=1., update_ass=True, update_locals=True, update_roots=False, subset=None, adaptive=True, seed=42, return_trace=False, **kwargs): + key = jax.random.PRNGKey(seed) + elbos = [] + states = None + local_states = None + if adaptive: + local_states = self.root['node'].root['node'].initialize_local_opt_states(param_names=locals_names) + states = self.root['node'].root['node'].initialize_global_opt_states(param_names=globals_names) + it = 0 + idx = None + for i in range(n_epochs): + for batch_idx in range(self.num_batches): + key, subkey = 
jax.random.split(key) + if subset is not None: + idx = jnp.array(list(set(self.batch_indices[batch_idx]).intersection(set(subset)))) + local_states = self.update_local_params(subkey, idx=idx, batch_idx=batch_idx, ass_anneal=ass_anneal, ent_anneal=ent_anneal, + update_ass=update_ass, update_globals=update_locals, adaptive=adaptive, states=local_states, i=it, + param_names=locals_names, **kwargs) + states = self.update_global_params(subkey, idx=idx, batch_idx=batch_idx, adaptive=adaptive, states=states, i=it, param_names=globals_names, **kwargs) + if update_roots: + self.update_root_node_params(subkey, memoized=False, adaptive=adaptive, i=it, **kwargs) + it += 1 + if return_trace: + elbos.append(self.compute_elbo_batch(batch_idx=batch_idx)) + + if return_trace: + return elbos + + def learn_params(self, n_epochs, seed=42, memoized=True, adaptive=True, update_roots=False, return_trace=False, **kwargs): + key = jax.random.PRNGKey(seed) + elbos = [] + it = 0 + for i in range(n_epochs): + for batch_idx in range(self.num_batches): + key, subkey = jax.random.split(key) + self.update_local_params(subkey, batch_idx=batch_idx, update_globals=False, **kwargs) + if memoized: + self.update_sufficient_statistics(batch_idx=batch_idx) + if update_roots: + self.update_root_node_params(subkey, memoized=memoized, adaptive=adaptive, i=it, **kwargs) + self.update_node_params(subkey, i=it, adaptive=adaptive, memoized=memoized, **kwargs) + self.update_pivot_probs() + it += 1 + if return_trace: + elbos.append(self.compute_elbo(memoized=memoized, batch_idx=batch_idx)) + + if return_trace: + return elbos - # Root node label - data_indices = np.array(list(range(self.num_data))) - do_global = False - if root_node is not None: - root_label = root_node.label - data_indices = set() + def update_sufficient_statistics(self, batch_idx=None): + """ + Go to each node and update its sufficient statistics. 
Set the suff stats of batch_idx, + and update the total suff stats + """ + def descend(root): + root['node'].update_sufficient_statistics(batch_idx=batch_idx) + for child in root['children']: + descend(child) + + descend(self.root) + def update_local_params(self, key, ass_anneal=1., ent_anneal=1., idx=None, batch_idx=None, states=None, adaptive=False, i=0, step_size=0.0001, mc_samples=10, update_ass=True, update_outer_ass=True, update_globals=True, **kwargs): + """ + This performs a tree traversal to update the sample to node attachment probabilities and other sample-specific parameters + """ + if idx is None: + if batch_idx is None: + idx = jnp.arange(self.num_data) + else: + idx = self.batch_indices[batch_idx] + + take_gradients = False + if update_globals: + take_gradients = True + + def descend(root, local_grads=None): + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + logprior += root['node'].variational_parameters['E_log_phi'] + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu'] + else: + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + + # Traverse inner TSSB + ll_node_sum, ent_node_sum, local_grads_down = root['node'].update_local_params(idx, update_ass=update_ass, take_gradients=take_gradients) + new_log_probs = ass_anneal*(ll_node_sum + logprior + ent_node_sum) + if update_ass and update_outer_ass: + root['node'].variational_parameters['q_c'] = root['node'].variational_parameters['q_c'].at[idx].set(new_log_probs) + if local_grads_down is not None: + if local_grads is None: + local_grads = list(local_grads_down) + else: + for ii, grads in enumerate(list(local_grads_down)): + local_grads[ii] += grads + logqs = [new_log_probs] + sum_E_log_1_psi = 0. 
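# The stick expectations used in this traversal (E_log_beta / E_log_1_beta) have closed
# forms under a Beta variational posterior: if nu ~ Beta(a, b), then
# E[log nu] = digamma(a) - digamma(a + b) and E[log(1 - nu)] = digamma(b) - digamma(a + b).
# A small sketch of those identities; the helper names below are illustrative, not
# necessarily the ones defined in this package.
from jax.scipy.special import digamma

def e_log_beta(a, b):      # E[log nu] for nu ~ Beta(a, b)
    return digamma(a) - digamma(a + b)

def e_log_1_beta(a, b):    # E[log (1 - nu)] for nu ~ Beta(a, b)
    return digamma(b) - digamma(a + b)

# e.g. the expected log prior weight of stopping at a node combines e_log_beta of its own
# stick with the sum of e_log_1_beta terms accumulated along its ancestors.
_ = e_log_beta(2.0, 3.0)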
+ for child in root['children']: + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + child_log_probs, child_local_grads = descend(child, local_grads=local_grads) + logqs.extend(child_log_probs) + if child_local_grads is not None: + for ii, grads in enumerate(list(child_local_grads)): + local_grads[ii] += grads + + return logqs, local_grads # batch-sized + + if update_globals: + # Take MC sample and its gradient wrt parameters + key, sample_grad = self.root['node'].root['node'].local_sample_and_grad(idx, key, n_samples=mc_samples) + locals_curr_sample, locals_params_grad = sample_grad + self.root['node'].root['node'].set_local_sample(locals_curr_sample, idx=idx) + + # Traverse tree + logqs, locals_sample_grad = descend(self.root) + + if update_globals: + locals_prior_grads = self.root['node'].root['node'].compute_locals_prior_grad(locals_curr_sample) + for ii, grads in enumerate(locals_prior_grads): + locals_sample_grad[ii] += grads + + # Gradient of entropy wrt parameters + locals_entropies_grad = self.root['node'].root['node'].compute_locals_entropy_grad(idx) + + # Take gradient step + if adaptive: + states = self.root['node'].root['node'].update_local_params_adaptive(idx, locals_params_grad, locals_sample_grad, locals_entropies_grad, ent_anneal=ent_anneal, step_size=step_size, + states=states, i=i, **kwargs) + else: + self.root['node'].root['node'].update_local_params(idx, locals_params_grad, locals_sample_grad, locals_entropies_grad, ent_anneal=ent_anneal, step_size=step_size, **kwargs) + + # Resample and store + key, sample_grad = self.root['node'].root['node'].local_sample_and_grad(idx, key, n_samples=mc_samples) + locals_curr_sample, _ = sample_grad + self.root['node'].root['node'].set_local_sample(locals_curr_sample, idx=idx) + + if update_ass and update_outer_ass: + # Compute LSE + logqs = jnp.array(logqs).T + self.variational_parameters['LSE_c'] = jax.scipy.special.logsumexp(logqs, axis=1) + # Set probs def descend(root): - data_indices.update(root.data) - for child in root.children(): + newvals = jnp.exp(root['node'].variational_parameters['q_c'][idx] - self.variational_parameters['LSE_c']) + root['node'].variational_parameters['q_c'] = root['node'].variational_parameters['q_c'].at[idx].set(newvals) + for child in root['children']: descend(child) + descend(self.root) + + return states - descend(root_node) - data_indices = np.sort(list(data_indices)) - if len(data_indices) == 0 and root_node.parent() is not None: - data_indices = np.sort(list(root_node.parent().data)) - # data_indices = list(root_node.data) - else: - do_global = True - root_label = self.root["node"].label - root_node = self.root["node"].root["node"] - - if local_node is not None: - do_global = False - root_node = local_node.parent() - root_label = root_node.label - init_ass_logits = np.array([node.data_ass_logits for node in nodes]).T - - @jax.jit - def f(a): - return jnn.softmax(a, axis=1) - - init_ass_probs = f(init_ass_logits) - data_indices = list( - np.where( - init_ass_probs[:, np.where(root_node == np.array(nodes))[0][0]] - > 1.0 / np.sqrt(len(nodes)) - )[0] - ) - if len(data_indices) == 0: - data_indices = np.arange(self.num_data) - - data_mask = np.zeros((self.num_data,)) - 
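# The assignment probabilities q_c are normalized in log space for numerical stability:
# collect one unnormalized log score per node, subtract the log-sum-exp over nodes, then
# exponentiate. A minimal sketch with an assumed (cells x nodes) layout; illustrative only.
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def normalize_responsibilities(logqs):
    # logqs: (n_cells, n_nodes) unnormalized log assignment scores
    lse = logsumexp(logqs, axis=1, keepdims=True)
    return jnp.exp(logqs - lse)   # each row sums to 1

_ = normalize_responsibilities(jnp.array([[0.0, -1.0, -2.0],
                                           [1.0,  1.0,  0.0]]))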
data_mask[data_indices] = 1.0 - data_mask = data_mask.astype(int) - - parent_vector = np.array(parent_vector) - parent_vector = jnp.array( - np.concatenate([parent_vector, -2 * np.ones((rem,))]) - ).astype(int) - - tssbs = [node.tssb.label for node in nodes] - tssb_indices = self.get_tssb_indices(nodes, tssbs) - # start3 = time.time() - children_vector = self.get_children_vector(parent_vector) - # end3 = time.time() - # print(f"get_children_vector: {end3-start3}") - ancestor_nodes_indices = self.get_ancestor_indices(nodes, parent_vector) - previous_branches_indices = self.get_previous_branches_indices(nodes) - node_idx = np.where(np.array(nodes) == root_node)[0][0] - node_mask_idx = node_idx - if local_node is None: - node_mask_idx = self.get_below_root(node_idx, children_vector, tssbs=None) + def update_global_params(self, key, idx=None, batch_idx=None, mc_samples=10, adaptive=False, step_size=0.0001, states=None, i=0, **kwargs): + """ + This performs a tree traversal to update the global parameters. + The global parameters are updated using stochastic mini-batch VI. + """ + if idx is None: + if batch_idx is None: + idx = jnp.arange(self.num_data) + else: + idx = self.batch_indices[batch_idx] + def descend(root, globals_grads=None): + globals_grads_down = root['node'].get_global_grads(idx) + if globals_grads_down is not None: + if globals_grads is None: + globals_grads = list(globals_grads_down) + else: + for ii, grads in enumerate(list(globals_grads_down)): + globals_grads[ii] += grads + for child in root['children']: + child_globals_grads = descend(child, globals_grads=globals_grads) + if child_globals_grads is not None: + for ii, grads in enumerate(list(child_globals_grads)): + globals_grads[ii] += grads + return globals_grads + + # Take MC sample and its gradient wrt parameters + key, sample_grad = self.root['node'].root['node'].global_sample_and_grad(key, n_samples=mc_samples) + globals_curr_sample, globals_params_grad = sample_grad + self.root['node'].root['node'].set_global_sample(globals_curr_sample) + + # Get gradient of loss of data likelihood weighted by assignment probability to each node wrt current sample of global params + globals_sample_grad = descend(self.root) + + # Scale gradient by batch size + for ii in range(len(globals_sample_grad)): + globals_sample_grad[ii] *= self.num_data/len(idx) + + # Gradient of prior wrt sample + globals_prior_grads = self.root['node'].root['node'].compute_globals_prior_grad(globals_curr_sample) + len_globals_in_ll = len(globals_sample_grad) + # Add the priors + for ii in range(len_globals_in_ll): + globals_sample_grad[ii] += globals_prior_grads[ii] + + # Extend to hierarchical parameters + for grads in globals_prior_grads[len_globals_in_ll:]: + globals_sample_grad.append(grads) + + # Gradient of entropy wrt parameters + globals_entropies_grad = self.root['node'].root['node'].compute_globals_entropy_grad() + + # Take gradient step + if adaptive: + states = self.root['node'].root['node'].update_global_params_adaptive(globals_params_grad, globals_sample_grad, + globals_entropies_grad, step_size=step_size, states=states, i=i, **kwargs) else: - local_node_idx = np.where(np.array(nodes) == local_node)[0][0] - node_mask_idx = np.array([node_idx, local_node_idx]) - - if unique_node is not None: - node_idx = np.where(np.array(nodes) == unique_node)[0][0] - node_mask_idx = node_idx - do_global = False - data_indices = list(unique_node.data) - data_mask = np.zeros((self.num_data,)) - data_mask[data_indices] = 1.0 - data_mask = data_mask.astype(int) - 
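# In stochastic mini-batch VI, the data-likelihood gradient computed on a batch is only an
# unbiased estimate of the full-data gradient after rescaling by N / |B|; the prior and
# entropy terms are not rescaled because each appears exactly once in the ELBO. A minimal
# sketch of that bookkeeping with hypothetical names; the real updates operate on lists of
# parameter arrays rather than a single value.
def scaled_elbo_gradient(batch_ll_grad, prior_grad, entropy_grad, num_data, batch_size):
    return (num_data / batch_size) * batch_ll_grad + prior_grad + entropy_grad

def ascent_step(param, grad, step_size=1e-4):
    return param + step_size * grad   # gradient ascent on the ELBO estimate

# example with scalar stand-ins for the gradients
_ = ascent_step(0.0, scaled_elbo_gradient(0.5, -0.1, 0.02, num_data=4000, batch_size=256))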
sticks_only = True - - # start2 = time.time() - node_mask = np.zeros((len(nodes),)) - node_mask[node_mask_idx] = 1 - - dp_alphas = np.array([node.tssb.dp_alpha for node in nodes]) - dp_gammas = np.array([node.tssb.dp_gamma for node in nodes]) - tssb_weights = np.array([node.tssb.weight for node in nodes]) - - global_names = list(nodes[0].variational_parameters["globals"].keys()) - global_params = [ - jnp.array(nodes[0].variational_parameters["globals"][param]) - for param in global_names - ] + states = self.root['node'].root['node'].update_global_params(globals_params_grad, globals_sample_grad, + globals_entropies_grad, step_size=step_size, **kwargs) - local_names = list(nodes[0].variational_parameters["locals"].keys()) - local_params = [] - for node in nodes: - local_params.append( - [node.variational_parameters["locals"][param] for param in local_names] - ) + # Resample and store + key, sample_grad = self.root['node'].root['node'].global_sample_and_grad(key, n_samples=mc_samples) + globals_curr_sample, _ = sample_grad + self.root['node'].root['node'].set_global_sample(globals_curr_sample) - obs_params = np.array([node.observed_parameters for node in nodes]) + if adaptive: + return states - tssb_weights = jnp.array(np.concatenate([tssb_weights, 10 * np.ones((rem,))])) - dp_gammas = jnp.array(np.concatenate([dp_gammas, 1 * np.ones((rem,))])) - dp_alphas = jnp.array(np.concatenate([dp_alphas, 1 * np.ones((rem,))])) - previous_branches_indices = jnp.array( - np.concatenate( - [ - previous_branches_indices, - -1 * np.ones((rem, previous_branches_indices.shape[1])), - ], - axis=0, - ) - ).astype(int) - ancestor_nodes_indices = jnp.array( - np.concatenate( - [ - ancestor_nodes_indices, - -1 * np.ones((rem, ancestor_nodes_indices.shape[1])), - ], - axis=0, - ) - ).astype(int) - # children_vector = jnp.concatenate([children_vector, -1*jnp.ones((rem, children_vector.shape[1]))], axis=0).astype(int) - tssb_indices = jnp.array( - np.concatenate( - [tssb_indices, -1 * np.ones((rem, tssb_indices.shape[1]))], axis=0 - ) - ).astype(int) - obs_params = jnp.array( - np.concatenate( - [obs_params, np.zeros((rem, nodes[0].observed_parameters.size))], axis=0 - ) - ) - node_mask = jnp.array(np.concatenate([node_mask, -2 * np.ones((rem,))])).astype( - int - ) - all_nodes_mask = np.ones(len(node_mask)) * -2 - all_nodes_mask[np.where(node_mask >= 0)[0]] = 1 - all_nodes_mask = jnp.array(all_nodes_mask) - local_params_list = [] - for param_idx in range(len(local_params[0])): - l = [] - for node_idx in range(len(nodes)): - l.append(local_params[node_idx][param_idx]) - l = np.vstack(l) - - # Add dummy nodes - param_shape = l[0].shape[0] - l = jnp.array(np.concatenate([l, np.zeros((rem, param_shape))], axis=0)) - - local_params_list.append(l) - # print([node.label for node in nodes]) - # print(parent_vector) - # print(children_vector) - # print([cnv[0] for cnv in cnvs]) - - if not do_global: - logger.debug("Won't take derivatives wrt global parameters") - elif global_only: - logger.debug("Won't take derivatives wrt local parameters") - - do_global = do_global * jnp.array(1.0) - global_only = global_only * jnp.array(1.0) - sticks_only = sticks_only * jnp.array(1.0) - - init_params = local_params_list + global_params - - # end2 = time.time() - # print(f"Getting parameters: {end2-start2}") - - if opt_triplet is None: - if opt is None: - opt = optimizers.adam - opt_init, opt_update, get_params = opt(step_size=step_size) - get_params = jit(get_params) - opt_update = jit(opt_update) - opt_init = jit(opt_init) - else: - 
opt_init, opt_update, get_params = ( - opt_triplet[0], - opt_triplet[1], - opt_triplet[2], - ) - opt_state = opt_init(init_params) - - # print(f"Time to prepare optimizer: {end-start} s") - # n_nodes = jnp.array(n_nodes) - self.n_nodes = n_nodes - # print(n_nodes) - if callback is None: - callback = elbos_callback - # print("Iteration {} lower bound {}".format(t, self.batch_objective(cnvs, parent_vector, children_vector, ancestor_nodes_indices, tssb_indices, previous_branches_indices, tssb_weights, dp_alphas, dp_gammas, params, t))) - - data = jnp.array(self.data) - lib_sizes = jnp.array(self.root["node"].root["node"].lib_sizes) - cell_covariates = jnp.array(self.covariates) - - unobserved_factors_kernel_rate = ( - self.root["node"].root["node"].unobserved_factors_kernel_rate - ) - unobserved_factors_kernel_concentration = ( - self.root["node"].root["node"].unobserved_factors_kernel_concentration - ) - unobserved_factors_root_kernel = ( - self.root["node"].root["node"].unobserved_factors_root_kernel - ) - global_noise_factors_precisions_shape = ( - self.root["node"].root["node"].global_noise_factors_precisions_shape - ) + def update_node_params(self, key, memoized=True, **kwargs): + """ + This performs a tree traversal to update the stick parameters and the kernel parameters. + The node parameters are updated using stochastic memoized VI. + """ + def descend(root, depth=0): + mass_down = 0 + for i, child in enumerate(root["children"][::-1]): + j = len(root['children'])-1-i + child_mass = descend(child, depth + 1) + child['node'].variational_parameters['sigma_1'] = root['psi_priors'][j]["alpha_psi"] + child_mass + child['node'].variational_parameters['sigma_2'] = root['psi_priors'][j]["beta_psi"] + mass_down + mass_down += child_mass + + # Traverse inner TSSB + root['node'].update_node_params(key, memoized=memoized, **kwargs) + + if memoized: + mass_here = root['node'].suff_stats['mass']['total'] + else: + mass_here = jnp.sum(root['node'].variational_parameters['q_c']) + root['node'].variational_parameters['delta_1'] = root['alpha_nu'] + mass_here + root['node'].variational_parameters['delta_2'] = root['beta_nu'] + mass_down - # print(all_nodes_mask) - full_data_indices = jnp.array(np.arange(self.num_data)) - data_mask_subset = jnp.array(data_mask) - sub_data_indices = np.where(data_mask)[0] - current_elbo = self.elbo - - # Get max width in node_mask - # Should not count below root? - max_width = 1 - node_mask_idx_below_root = node_mask_idx[ - np.where(ancestor_nodes_indices[node_mask_idx, 0] != 0)[0] - ] - local_node_mask = jnp.array(node_mask) - if len(node_mask_idx_below_root) > 0: - # original previous_branches_indices only contains within tssb! - previous_branches_indices2 = self.get_previous_branches_indices( - nodes, within_tssb=False - ) - masked_prevs = np.array( - previous_branches_indices2[node_mask_idx_below_root] - ).reshape(len(node_mask_idx_below_root), -1) - max_width = np.max(np.sum(masked_prevs >= 0, axis=1)) + 1 - - # Get max depth within TSSB - masked_ancs = np.array(ancestor_nodes_indices)[node_mask_idx] - max_depth = np.max(np.sum(masked_ancs >= 0, axis=1)) + 1 - - # end = time.time() - # print(f"before run: {end-start}") - if run: - # Main loop. 
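# The delta/sigma updates above are standard conjugate (CAVI) updates for Beta sticks:
# the posterior shape parameters are the prior parameters plus the expected number of cells
# that stop at a node versus continue past it (to descendants, or to later siblings for the
# psi sticks). A toy sketch with scalar expected masses; names are illustrative.
def update_nu_stick(mass_here, mass_below, alpha_nu=1.0, beta_nu=1.0):
    delta_1 = alpha_nu + mass_here    # expected cells that stay at this node
    delta_2 = beta_nu + mass_below    # expected cells that continue to its descendants
    return delta_1, delta_2

# e.g. a node with 12.3 expected cells and 4.1 expected cells attached below it:
print(update_nu_stick(12.3, 4.1))   # -> (13.3, 5.1)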
- current_elbo = self.elbo - # if self.elbo == -np.inf: - # current_elbo = -self.batch_objective(obs_params, parent_vector, children_vector, ancestor_nodes_indices, tssb_indices, previous_branches_indices, tssb_weights, dp_alphas, dp_gammas, all_nodes_mask, do_global, global_only, sticks_only, num_samples, init_params, 0) - # print(f"Current ELBO: {current_elbo:.5f}") - # print(f"Optimizing variational parameters from node {root_label}...") - elbos = [] - means = [] - minibatch_probs = np.ones((self.num_data,)) - minibatch_probs[sub_data_indices] = 1e6 - minibatch_probs = minibatch_probs / np.sum(minibatch_probs) - for t in range(n_iters): - minibatch_idx = np.random.choice( - self.num_data, p=minibatch_probs, size=mb_size, replace=False - ) - minibatch_idx = jnp.array(np.sort(minibatch_idx)).ravel() - data_mask_subset = jnp.array(data_mask)[minibatch_idx] - # minibatch_idx = np.arange(self.num_data) - # data_mask_subset = data_mask - # start = time.time() - # Update params - # Iterate through nodes - # This is very slow: should reduce number of iterations?! Or - # randomly choose a node at each iteration? - # If I reduce the number of iterations I don't account for different complexities - # Only take gradient of one node at a time - # If I do it randomly I'm probably not using momentum, right? - # Only do it for big widths - do_sticks = jnp.array(1.0) - if max_width > 2 or (max_depth > 2 and np.sum(node_mask) > 2): - node_idx = np.random.choice(node_mask_idx) - local_node_mask = np.array(node_mask) - off_idx = np.where(node_mask >= 0)[0] - local_node_mask[off_idx] = 0 - local_node_mask[node_idx] = 1 - local_node_mask = jnp.array(local_node_mask) - opt_state, g, params, elbo = cna.opt_funcs.update( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - local_node_mask, - data_mask_subset, - minibatch_idx, - do_global, - global_only, - do_sticks, - sticks_only, - num_samples, - t, - opt_state, - opt_update, - get_params, - ) - # end = time.time() - # print(f"update: {end-start}") - elbos.append(-elbo) - try: - callback(elbos, **callback_kwargs) - except StopIteration as e: - logger.debug(f"Stopped optimization at iteration {t}/{n_iters}") - break + return mass_here + mass_down + + descend(self.root) - # Without node mask - # start = time.time() - ret = cna.opt_funcs.batch_objective( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - all_nodes_mask, - jnp.array(1.0), - jnp.array(0.0), - jnp.array(1.0), - jnp.array(0.0), - num_samples, - data.shape[0], - get_params(opt_state), - 0, - ) - # end = time.time() - # print(f"batch_objective: {end-start}") - self.elbo = np.array(ret[0]).item() - self.ll = np.array(ret[1]).item() - self.kl = np.array(ret[2]).item() - self.node_kl = np.array(ret[3]) - - # Weigh by tree prior - subtrees = self.get_mixture()[1][1:] # without the root - for subtree in subtrees: - pivot_node = subtree.root["node"].parent() - parent_subtree = pivot_node.tssb - prior_weights, 
subnodes = parent_subtree.get_fixed_weights() - # Weight ELBO by chosen pivot's prior probability - node_idx = np.where(pivot_node == np.array(subnodes))[0][0] - self.elbo = self.elbo + np.log(prior_weights[node_idx]) - - # Combinatorial penalization to avoid duplicates -- also avoids real clusters! - # self.elbo = self.elbo + np.log(1/(2**len(data_indices))) - - new_elbo = self.elbo - # print(f"Done. Speed: {avg_speed} s/it, Total: {total} s") - # print(f"New ELBO: {new_elbo:.5f}") - # print(f"New ELBO improvement: {(new_elbo - current_elbo)/np.abs(current_elbo) * 100:.3f}%\n") - - # start = time.time() - self.set_node_means( - get_params(opt_state), - nodes, - local_names, - global_names, - node_mask=node_mask, - do_global=do_global, - ) - self.update_ass_logits( - nodes=nodes, variational=True - ) - self.assign_to_best(nodes=nodes) - # end = time.time() - # print(f"last part: {end-start}") - return elbos - else: - ret = cna.opt_funcs.batch_objective( - data, - cell_covariates, - lib_sizes, - unobserved_factors_kernel_rate, - unobserved_factors_kernel_concentration, - unobserved_factors_root_kernel, - global_noise_factors_precisions_shape, - obs_params, - parent_vector, - children_vector, - ancestor_nodes_indices, - tssb_indices, - previous_branches_indices, - tssb_weights, - dp_alphas, - dp_gammas, - all_nodes_mask, - jnp.array(1.0), - jnp.array(0.0), - jnp.array(1.0), - jnp.array(0.0), - num_samples, - data.shape[0], - get_params(opt_state), - 0, - ) - self.elbo = np.array(ret[0]).item() - self.ll = np.array(ret[1]).item() - self.kl = np.array(ret[2]).item() - self.node_kl = np.array(ret[3]) - - # Weigh by tree prior - subtrees = self.get_mixture()[1][1:] # without the root - for subtree in subtrees: - pivot_node = subtree.root["node"].parent() - parent_subtree = pivot_node.tssb - prior_weights, subnodes = parent_subtree.get_fixed_weights() - # Weight ELBO by chosen pivot's prior probability - node_idx = np.where(pivot_node == np.array(subnodes))[0][0] - self.elbo = self.elbo + np.log(prior_weights[node_idx]) - self.update_ass_logits(variational=True) - self.assign_to_best(nodes=nodes) - return None - - def set_node_means( - self, params, nodes, local_names, global_names, node_mask=None, do_global=True - ): - # start = time.time() - globals_start = len(local_names) - params_idx = 0 - for i, global_param in enumerate(global_names): - params_idx = globals_start + i - if ( - do_global or "cell" in global_param - ): # always update cell-specific parameters - self.root["node"].root["node"].variational_parameters["globals"][ - global_param - ] = np.array(params[params_idx]) - - if node_mask is None: - node_indices = np.arange(len(nodes)) - else: - node_indices = np.where(node_mask == 1)[0] - for node_idx in node_indices: - for i, local_param in enumerate(local_names): - nodes[node_idx].variational_parameters["locals"][ - local_param - ] = np.array(params[i][node_idx]) - nodes[node_idx].set_mean(variational=True) - - def update_ass_logits( - self, nodes=None, indices=None, variational=False, prior=True - ): - # start = time.time() - if indices is None: - indices = list(range(self.num_data)) - - ns, weights = self.get_node_mixture() - if nodes is None: - nodes = ns - - for i, node in enumerate(ns): - if node in nodes: - node_lls = node.loglh( - np.array(indices), variational=variational, axis=1 - ) - node_lls = node_lls + np.log(weights[i] + 1e-300) if prior else node_lls - node.data_ass_logits[np.array(indices)] = node_lls - # print(f"update_ass_logits: {time.time()-start}") - - - def 
assign_to_best(self): - total_weights = [] - node_list = [] - tssbs = self.get_subtrees() - tssb_list = [] - for tssb in tssbs: - tssb._data.clear() - _, nodes = tssb.get_fixed_weights() - for node in nodes: - node_list.append(node) - tssb_list.append(tssb) - total_weights.append(tssb.data_weights * node.data_weights) - total_weights = np.vstack(total_weights).T - self.total_weights = total_weights - self.node_list = node_list - self.tssb_list = tssb_list - assignments = np.argmax(total_weights,axis=1) - for i, node in enumerate(node_list): - node.remove_data() - data = np.where(assignments == i)[0] - node.add_data(np.where(assignments == i)[0]) - node.tssb._data.update(np.where(assignments == i)[0]) - self.assignments = list(np.array(node_list)[assignments]) + def update_root_node_params(self, key, memoized=True, adaptive=True, i=0, **kwargs): + """ + This performs a tree traversal to update the kernel parameters of root nodes + The node parameters are updated using stochastic memoized VI. - def _assign_to_best(self, nodes=None): - # start = time.time() - if nodes is None: - nodes = self.get_nodes() + Go inside TSSB1, compute the gradient of TSSB2's root wrt TSSB1's nodes, return their sum + Go inside TSSB2, compute the gradient of TSSB2's root wrt TSSB3's root, add to previous gradient and update + Repeat + """ + def descend(root, depth=0): + ll_grads = [] + locals_grads = [] + children_grads = [] + for child in root["children"]: + child_ll_grad, child_locals_grads, child_children_grads = descend(child, depth + 1) + ll_grads.append(child_ll_grad) + locals_grads.append(child_locals_grads) + children_grads.append(child_children_grads) + + if len(root["children"]) > 0: + # Compute gradient of children roots wrt possible parents here + parent_grads = root['node'].compute_children_root_node_grads(**kwargs) + + # Update parameters of each child root + for ii, child in enumerate(root["children"]): + ll_grad = ll_grads[ii] + local_grad = locals_grads[ii] + children_grad = children_grads[ii] + parent_grad = parent_grads[ii] + child['node'].update_root_node_params(key, ll_grad, local_grad, children_grad, parent_grad, adaptive=adaptive, i=i, **kwargs) + + if depth > 0: # root of root TSSB has no parameters + # Sample root node and compute gradients wrt children (including child roots) + # direction_grads = [direction_params_grad, direction_params_entropy_grad, direction_sample_grad] + # state_grads = [state_params_grad, state_params_entropy_grad, state_sample_grad] + ll_grad, locals_grads, children_grads = root['node'].sample_grad_root_node(key, memoized=memoized, **kwargs) + + return ll_grad, locals_grads, children_grads + + descend(self.root) + + def update_pivot_probs(self): + def descend(root): + if len(root["children"]) > 0: + root["node"].update_pivot_probs() + for child in root["children"]: + descend(child) + descend(self.root) - assignment_logits = jnp.array([node.data_ass_logits for node in nodes]).T + def assign_samples(self): + def descend(root): + nodes_probs = [root['node'].variational_parameters['q_z'].ravel() * root['node'].tssb.variational_parameters['q_c'].ravel()] + nodes = [root['node']] + for child in root["children"]: + cnodes_probs, cnodes = descend(child) + nodes_probs.extend(cnodes_probs) + nodes.extend(cnodes) + return nodes_probs, nodes + def super_descend(super_root): + nodes_probs_lists = [] + nodes_lists = [] + nodes_probs, nodes = descend(super_root["node"].root) + nodes_probs_lists = [nodes_probs] + nodes_lists = [nodes] + for super_child in super_root["children"]: + 
children_nodes_probs, children_nodes = super_descend(super_child) + nodes_probs_lists.extend(children_nodes_probs) + nodes_lists.extend(children_nodes) + return nodes_probs_lists, nodes_lists + + nodes_probs, nodes = super_descend(self.root) + nodes = [x for xs in nodes for x in xs] + nodes_probs = [x for xs in nodes_probs for x in xs] + nodes_probs = np.array(nodes_probs).T + self.assignments = np.array(nodes)[np.argmax(nodes_probs, axis=1)] + for node in nodes: + node.remove_data() + node.add_data(np.where(self.assignments == node)[0]) + node.tssb._data.update(np.where(self.assignments == node)[0]) - @jit - def get_assignments(assignment_logits): - assignment_probs = jnp.array(jnn.softmax(assignment_logits, axis=1)) - return jax.vmap(jnp.argmax)(assignment_probs) + def assign_pivots(self): + def descend(root): + pivot_probs_nodes = [root['node'].get_pivot_probabilities(i) for i in range(len(root['children']))] + for i, child in enumerate(root['children']): + pivot_probs = [l for l in pivot_probs_nodes[i][0]] + pivot_nodes = [l for l in pivot_probs_nodes[i][1]] + child['pivot_node'] = pivot_nodes[np.argmax(pivot_probs)] + child['node'].root['node'].set_parent(child['pivot_node']) + descend(child) + descend(self.root) - assignments = np.array(get_assignments(assignment_logits)) + # ========= Functions to update tree structure. ========= - # Clear all - for i, node in enumerate(nodes): - node.remove_data() - node.add_data(np.where(assignments == i)[0]) + def prune_subtrees(self): + # Remove all nodes except observed ones + def descend(root): + def sub_descend(sub_root): + for sub_child in sub_root['children']: + sub_descend(sub_child) + root['node'].merge_nodes(sub_root, sub_child, sub_root) + sub_descend(root['node'].root) + for child in root['children']: + descend(child) + descend(self.root) - self.assignments = list(np.array(nodes)[assignments]) + def birth(self, source, seed=None): + def sub_descend(root, target_root=None): + if source == root['node'].label: # node label + target_root = root + for child in root['children']: + if target_root is not None: + return target_root + else: + target_root = sub_descend(child, target_root=target_root) + return target_root - # print(f"assign_to_best: {time.time()-start}") + def descend(root): + if root['node'].label in source: # TSSB label + target_root = sub_descend(root['node'].root) + root['node'].add_node(target_root, seed=seed) + for child in root['children']: + descend(child) - # ========= Functions to update tree structure. 
========= + descend(self.root) - def add_node_to(self, node, optimal_init=True, factor_idx=None, return_parent_root=True): - if isinstance(node, dict): - root = node - else: - nodes = self._get_nodes(get_roots=True) - nodes_list = np.array([node[0] for node in nodes]) - roots_list = np.array([node[1] for node in nodes]) - if isinstance(node, str): - self.plot_tree(super_only=False) - nodes_list = np.array([node[0].label for node in nodes]) - node_idx = np.where(nodes_list == node)[0][0] - root = roots_list[node_idx] - - # Create child - stick_length = boundbeta(1, self.dp_gamma) - root["sticks"] = np.vstack([root["sticks"], stick_length]) - root["children"].append( - { - "node": root["node"].spawn(False, root["node"].observed_parameters), - "main": boundbeta( - 1.0, (self.alpha_decay ** (root["node"].depth + 1)) * self.dp_alpha - ) - if self.min_depth <= (root["node"].depth + 1) - else 0.0, - "sticks": np.empty((0, 1)), - "children": [], - } - ) - root["children"][-1]["node"].reset_variational_parameters() + def merge(self, source, target): + def sub_descend(root, parent_root=None, source_root=None, target_root=None): + if target == root['node'].label: # node label + target_root = root + for child in root['children']: + if source == child['node'].label: # node label + source_root = child + if target == child['node'].label: # node label + target_root = child + if source_root is not None and target_root is not None and parent_root is None: + parent_root = root if source_root['node'].parent().label == root['node'].label else target_root + return parent_root, source_root, target_root + else: + parent_root, source_root, target_root = sub_descend(child, parent_root=parent_root, source_root=source_root, target_root=target_root) + return parent_root, source_root, target_root - if optimal_init: - # Remove some mass from the parent - root["node"].variational_parameters["locals"]["nu_log_mean"] = np.array(0.0) - root["node"].variational_parameters["locals"]["nu_log_std"] = np.array(0.0) - root["children"][-1]["node"].data_ass_logits = -np.inf * np.ones( - (self.num_data) - ) - baseline = np.append( - 1, np.exp(self.root["node"].root["node"].log_baseline_caller()) - ) + def descend(root): + if root['node'].label in source: # TSSB label + parent_root, source_root, target_root = sub_descend(root['node'].root) + root['node'].merge_nodes(parent_root, source_root, target_root) + for child in root['children']: + descend(child) - if factor_idx is not None: - target_genes = np.argsort( - np.abs( - self.root["node"] - .root["node"] - .variational_parameters["globals"]["noise_factors_mean"][ - factor_idx - ] - ) - )[-5:] - root["children"][-1]["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][target_genes] = -1.0 - else: - data_indices = np.where(np.array(self.assignments) == root["node"])[0] - if len(data_indices) > 0: - data_in_node = np.array(self.data)[data_indices] - target_genes = np.argsort(np.var(np.log(data_in_node + 1), axis=0))[ - -10: - ] - root["children"][-1]["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][target_genes] = -1.0 + descend(self.root) - if return_parent_root: - return root - else: - return root["children"][-1] - def perturb_node(self, node, target): - # Perturb parameters of node to become closer to data explained by target - if isinstance(node, str) and isinstance(target, str): - self.plot_tree(super_only=False) - nodes_list = np.array(self.get_nodes()) - node_labels = np.array([node.label for node in 
nodes_list]) - node = nodes_list[np.where(node_labels == node)[0][0]] - target = nodes_list[np.where(node_labels == target)[0][0]] - - # data_indices = list(target.data.copy()) - data_indices = np.where(np.array(self.assignments) == target)[0] - - if len(data_indices) > 0: - index = np.random.choice(np.array(data_indices)) - # worst_index = np.argmin(root['node'].data_ass_logits[data_indices]) - # worst_index = np.random.choice(np.array(data_indices)[np.array([np.argsort(target.data_ass_logits[data_indices])[:5]])].ravel()) - logger.debug(f"Setting node to explain datum {index}") - worst_datum = self.data[index] - baseline = np.append( - 1, np.exp(self.root["node"].root["node"].log_baseline_caller()) - ) - noise = ( - self.root["node"] - .root["node"] - .variational_parameters["globals"]["cell_noise_mean"][index] - .dot( - self.root["node"] - .root["node"] - .variational_parameters["globals"]["noise_factors_mean"] - ) - ) - total_rna = np.sum( - baseline - * node.cnvs - / 2 - * np.exp( - node.variational_parameters["locals"]["unobserved_factors_mean"] - + noise - ) - ) - node.variational_parameters["locals"]["unobserved_factors_mean"] = np.log( - (worst_datum + 1) - * total_rna - / ( - self.root["node"].root["node"].lib_sizes[index] - * baseline - * node.cnvs - / 2 - * np.exp(noise) - ) - ) - node.set_mean( - node.get_mean( - unobserved_factors=node.variational_parameters["locals"][ - "unobserved_factors_mean" - ], - baseline=baseline, - ) - ) + def swap_nodes(self, parent_node, child_node): + # Put parameters of child in parent and of parent in child + parent_params = deepcopy(parent_node.variational_parameters) + parent_suffstats = deepcopy(parent_node.suff_stats) + child_params = deepcopy(child_node.variational_parameters) + child_suffstats = deepcopy(child_node.suff_stats) + parent_node.variational_parameters = child_params + parent_node.suff_stats = child_suffstats + child_node.variational_parameters = parent_params + child_node.suff_stats = parent_suffstats + + def swap_root(self, parent, child): + def descend(root, depth=0): + if depth > 0: + if root['node'].label == parent: # TSSB label + for ch in root['node'].root['children']: + print(ch['node'].label) + if ch['node'].label == child: + self.swap_nodes(root['node'].root['node'], ch['node']) + for ch in root['children']: + descend(ch, depth+1) + + descend(self.root) def remove_last_leaf_node(self, parent_label): nodes = self._get_nodes(get_roots=True) @@ -2342,40 +1761,6 @@ def prune_reattach(self, node, new_parent): # node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = np.log(node.unobserved_factors_kernel_concentration_caller())*np.ones((n_genes,)) # node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5 - def pivot_reattach_to(self, subtree, pivot): - nodes = self._get_nodes(get_roots=False) - nodes = np.array(nodes) - if isinstance(subtree, str) or isinstance(pivot, str): - self.plot_tree(super_only=False) - node_labels = np.array([node.label for node in nodes]) - subtree = nodes[np.where(node_labels == subtree)[0][0]] - subtree = subtree.tssb - pivot = nodes[np.where(node_labels == pivot)[0][0]] - - subtree_label = subtree.label - - root_node_idx = np.where(nodes == subtree.root["node"])[0][0] - root_node = nodes[root_node_idx] - pivot_node_idx = np.where(nodes == pivot)[0][0] - pivot_node = nodes[pivot_node_idx] - - subtrees = self.get_subtrees(get_roots=True) - subtree_objs = np.array([s[0] for s in subtrees]) - subtree_idx = np.where(subtree_objs == subtree)[0][0] - 
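# The swap above exchanges two nodes' variational parameters and sufficient statistics by
# deep-copying both dictionaries before assignment, so neither side is mutated while the
# other is being written. A generic, self-contained sketch of that pattern (attribute names
# here mirror the ones used above but the helper itself is illustrative).
from copy import deepcopy
from types import SimpleNamespace

def swap_attributes(obj_a, obj_b, attributes=("variational_parameters", "suff_stats")):
    for name in attributes:
        a_val = deepcopy(getattr(obj_a, name))
        b_val = deepcopy(getattr(obj_b, name))
        setattr(obj_a, name, b_val)
        setattr(obj_b, name, a_val)

# usage with stand-in objects
a = SimpleNamespace(variational_parameters={"m": 0.0}, suff_stats={"mass": 1.0})
b = SimpleNamespace(variational_parameters={"m": 2.0}, suff_stats={"mass": 5.0})
swap_attributes(a, b)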
subtrees[subtree_idx][1]["pivot_node"] = pivot_node - - # prev_unobserved_factors = root_node[0].unobserved_factors_mean - # root_node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5 - root_node.set_parent(pivot_node, reset=False) - root_node.set_mean(variational=True) - # Reset the kernel posterior - # root_node[0].unobserved_factors_kernel_log_mean = -1.*jnp.ones((root_node[0].n_genes,)) - # root_node[0].unobserved_factors_kernel_log_std = -1.*jnp.ones((root_node[0].n_genes,)) - # root_node[0].unobserved_factors_mean = prev_unobserved_factors - # root_node[0].set_mean(variational=True) - # self.update_ass_logits(variational=True) - # self.assign_to_best() - def extract_pivot(self, node): """ extract_pivot(B): @@ -2405,10 +1790,10 @@ def extract_pivot(self, node): node.variational_parameters["locals"]["unobserved_factors_log_std"] ) paramsB_k = np.array( - node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] + node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"] ) paramsB_k_std = np.array( - node.variational_parameters["locals"]["unobserved_factors_kernel_log_std"] + node.variational_parameters["locals"]["unobserved_factors_kernel_log_rate"] ) # Set new node's parameters equal to the previous parameters of node @@ -2419,10 +1804,10 @@ def extract_pivot(self, node): "unobserved_factors_log_std" ] = np.array(paramsB_std) new_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" ] = np.array(paramsB_k) new_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" + "unobserved_factors_kernel_log_rate" ] = np.array(paramsB_k_std) new_node.set_mean(variational=True) @@ -2447,7 +1832,7 @@ def extract_pivot(self, node): np.abs(node.variational_parameters["locals"]["unobserved_factors_mean"]) > 0.5 )[0] - node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"][ + node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"][ affected_genes ] = np.log(node.unobserved_factors_kernel_concentration_caller()) @@ -2456,75 +1841,6 @@ def extract_pivot(self, node): return new_node_root - def push_subtree(self, node): - """ - push_subtree(B): - A-0 -> B -> B-0 to A-0 -> A-0-0 -> B -> B-0 - """ - if isinstance(node, str): - self.plot_tree(super_only=False) - nodes = self.get_nodes(None) - node_labels = np.array([node.label for node in nodes]) - node = nodes[np.where(node_labels == node)[0][0]] - - if node.parent() is None: - raise ValueError("Can't pull from root tree") - if not node.is_observed: - raise ValueError("Can't pull unobserved node") - - parent_node = node.parent() - children = np.array(list(node.children())) - idx = np.argsort([c.label for c in children]) - children = children[idx] - if len(children) > 0: - children = [n for n in children if not n.is_observed] - child_node = None - if len(children) > 0: - child_node = children[0] - - # Add node below parent - new_node_root = self.add_node_to(parent_node, return_parent_root=False) - new_node = new_node_root["node"] - self.pivot_reattach_to(node.tssb, new_node) - paramsB = np.array( - node.variational_parameters["locals"]["unobserved_factors_mean"] - ) - paramsB_k = np.array( - node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] - ) - dataB = node.data.copy() - logitsB = np.array(node.data_ass_logits) - if child_node: - node.variational_parameters["locals"]["unobserved_factors_mean"] = np.array( - 
child_node.variational_parameters["locals"]["unobserved_factors_mean"] - ) - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.array( - child_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - node.data = child_node.data.copy() - node.data_ass_logits = np.array(child_node.data_ass_logits) - # Merge B with child that it has become equal to - self.merge_nodes(child_node, node) - node.set_mean(variational=True) - - # Set new node's parameters equal to the previous parameters of node - new_node.variational_parameters["locals"]["unobserved_factors_mean"] = np.array( - paramsB - ) - new_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.array(paramsB_k) - new_node.set_mean(variational=True) - if child_node: - new_node.data = dataB.copy() - new_node.data_ass_logits = np.array(logitsB) - - return new_node_root - def path_to_node(self, node): path = [] path.append(node) @@ -2567,788 +1883,6 @@ def path_between_nodes(self, nodeA, nodeB): path = path + list(pathB[np.where(pathB == mrca)[0][0] :]) return path - def swap_nodes(self, nodeA, nodeB, update_pivots=True, use_top_kernel=True): - self.plot_tree(super_only=False) - if isinstance(nodeA, str) and isinstance(nodeB, str): - nodes = self.get_nodes(None) - node_labels = np.array([node.label for node in nodes]) - nodeA = nodes[np.where(node_labels == nodeA)[0][0]] - nodeB = nodes[np.where(node_labels == nodeB)[0][0]] - - # If we are swapping the root node, need to change the baseline too - root_node = None - non_root_node = None - child_unobserved_factors = None - initial_log_baseline = None - if nodeA.parent() is None: - root_node = nodeA - non_root_node = nodeB - elif nodeB.parent() is None: - root_node = nodeB - non_root_node = nodeA - - def swap_params(nA, nB): - params_names = list(nA.variational_parameters["locals"].keys()) - paramsA_list = [ - np.array(nA.variational_parameters["locals"][key]) - for key in params_names - ] - # paramsA = np.array(nA.variational_parameters['locals']['unobserved_factors_mean']) - paramsA_k = np.array( - nA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - # nu_sticksA = np.array(nA.variational_parameters['locals']['nu_log_mean']) - # psi_sticksA = np.array(nA.variational_parameters['locals']['psi_log_mean']) - paramsB_k = np.array( - nB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - params_k = np.array([paramsA_k, paramsB_k]) - # We will initialize the kernels to be the one with the most events, just in case - top_k_idx = np.argmax(np.array([np.var(paramsA_k), np.var(paramsB_k)])) - - # Relax kernel of intermediate nodes - int_nodes = [] - if nA.label in nB.label: - n = nB - while True: - n = n.parent() - if n == nA: - break - else: - int_nodes.append(n) - elif nB.label in nA.label: - n = nA - while True: - n = n.parent() - if n == nB: - break - else: - int_nodes.append(n) - if len(int_nodes) > 0: - if use_top_kernel: - for node in int_nodes: - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.clip(params_k[top_k_idx], -3, 10) - - dataA = nA.data.copy() - logitsA = np.array(nA.data_ass_logits) - data_weights_A = np.array(nA.data_weights) - for param in params_names: - nA.variational_parameters["locals"][param] = np.array( - nB.variational_parameters["locals"][param] - ) - if use_top_kernel: - nA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = 
np.array(params_k[top_k_idx]) - if nA == nB.parent(): - # If one is a child of the other, the child should not have high kernel where parent does - parent_events = np.where( - np.abs( - nA.variational_parameters["locals"]["unobserved_factors_mean"] - ) - > 0.1 - )[0] - nA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.array( - nA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - nA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][parent_events] -= 1 - nA.data = nB.data.copy() - nA.data_ass_logits = np.array(nB.data_ass_logits) - nA.data_weights = np.array(nB.data_weights) - nA.set_mean(variational=True) - for i, param in enumerate(params_names): - nB.variational_parameters["locals"][param] = np.array(paramsA_list[i]) - if use_top_kernel: - nB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.array(params_k[top_k_idx]) - if nB == nA.parent(): - # If one is a child of the other, the child should not have high kernel where parent does - parent_events = np.where( - np.abs( - nB.variational_parameters["locals"]["unobserved_factors_mean"] - ) - > 0.1 - )[0] - nB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.array( - nB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - nB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][parent_events] -= 1 - nB.data = dataA.copy() - nB.data_ass_logits = np.array(logitsA) - nB.data_weights = np.array(data_weights_A) - nB.set_mean(variational=True) - - if not root_node: - if nodeA.tssb == nodeB.tssb: - swap_params(nodeA, nodeB) - else: - # e.g. A-0 with C - if nodeA.parent() == nodeB or nodeB.parent() == nodeA: - unobserved_node_idx = np.where( - [not nodeA.is_observed, not nodeB.is_observed] - )[0] - if len(unobserved_node_idx) > 1: - logger.debug( - "Warning: both nodes to swap are unobserved but are part of different TSSBs:" - ) - logger.debug( - f"{nodeA.label}: {nodeA.tssb.label}, {nodeB.label}: {nodeB.tssb.label}" - ) - logger.debug("Proceeding without swapping.") - return - elif len(unobserved_node_idx) == 1: - # change A -> A-0 -> B to A-> B -> B-0: and put params of A-0 in B and of B in B-0 - unobserved_node_idx = unobserved_node_idx[0] - unobserved_node = [nodeA, nodeB][unobserved_node_idx] - observed_node = [nodeA, nodeB][1 - unobserved_node_idx] - parent_unobserved = unobserved_node.parent() - # if unobserved node is parent of more than one subtree, update pivot of the others to parent of unobserved_node - unobserved_node_children = np.array( - list(unobserved_node.children()) - ) - idx = np.argsort([n.label for n in unobserved_node_children]) - unobserved_node_children = unobserved_node_children[idx] - if ( - np.sum( - np.array( - [ - child.is_observed - for child in unobserved_node_children - ] - ) - ) - > 1 - ): - for child in list(unobserved_node_children): - if child.is_observed: - self.pivot_reattach_to( - child.tssb, parent_unobserved - ) - # if not (np.sum(np.array([child.is_observed for child in unobserved_node_children])) > 1): - # Update params - # init_obs_params = observed_node.variational_parameters['locals']['unobserved_factors_mean'] - # observed_node.variational_parameters['locals']['unobserved_factors_mean'] = parent_unobserved.variational_parameters['locals']['unobserved_factors_mean'] - # unobserved_node.variational_parameters['locals']['unobserved_factors_mean'] = init_obs_params - swap_params(nodeA, nodeB) - # 
observed_node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 2#np.log(observed_node.unobserved_factors_kernel_concentration_caller())*np.ones((observed_node.n_genes,)) - # unobserved_node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 2#np.log(unobserved_node.unobserved_factors_kernel_concentration_caller())*np.ones((unobserved_node.n_genes,)) - - # Put same-subtree children of unobserved node in its parent in order to not move a whole subtree - nodes = self._get_nodes(get_roots=True) - unobserved_node_tssb_root = [ - node[1] for node in nodes if node[0] == unobserved_node - ][0] - parent_unobserved_node_tssb_root = [ - node[1] for node in nodes if node[0] == parent_unobserved - ][0] - for i, unobs_child in enumerate( - unobserved_node_tssb_root["children"] - ): - unobs_child["node"].set_parent(unobserved_node.parent()) - # Add children from unobserved to the parent dict - parent_unobserved_node_tssb_root["children"].append( - unobs_child - ) - parent_unobserved_node_tssb_root["sticks"] = np.vstack( - [ - parent_unobserved_node_tssb_root["sticks"], - unobserved_node_tssb_root["sticks"][i], - ] - ) - if len(unobserved_node_tssb_root["children"]) > 0: - # Remove children from unobserved - unobserved_node_tssb_root["sticks"] = np.array([]).reshape( - 0, 1 - ) - unobserved_node_tssb_root["children"] = [] - - # Now move the unobserved node to below the observed one - observed_node.set_parent(parent_unobserved) - unobserved_node.set_parent(observed_node) - unobserved_node.tssb = observed_node.tssb - unobserved_node.cnvs = observed_node.cnvs - unobserved_node.observed_parameters = ( - observed_node.observed_parameters - ) - n_siblings = len(list(observed_node.children())) - unobserved_node.label = ( - observed_node.label + "-" + str(n_siblings - 1) - ) - - nodes = self._get_nodes(get_roots=True) - unobserved_node_tssb_root = [ - node[1] for node in nodes if node[0] == unobserved_node - ][0] - parent_unobserved_node_tssb_root = [ - node[1] for node in nodes if node[0] == parent_unobserved - ][0] - - # Update dicts - # Remove unobserved_node from its parent dict - childnodes = np.array( - [ - n["node"] - for n in parent_unobserved_node_tssb_root["children"] - ] - ) - tokeep = ( - np.where(childnodes != unobserved_node_tssb_root["node"])[0] - .astype(int) - .ravel() - ) - parent_unobserved_node_tssb_root[ - "sticks" - ] = parent_unobserved_node_tssb_root["sticks"][tokeep] - parent_unobserved_node_tssb_root["children"] = list( - np.array(parent_unobserved_node_tssb_root["children"])[ - tokeep - ] - ) - # Update observed_node's pivot_node to unobserved_node's parent - observed_node_ntssb_root = observed_node.tssb.get_ntssb_root() - observed_node_ntssb_root["pivot_node"] = parent_unobserved - # Add unobserved_node to observed_node's dict - observed_node_tssb_root = observed_node_ntssb_root["node"].root - observed_node_tssb_root["children"].append( - unobserved_node_tssb_root - ) - observed_node_tssb_root["sticks"] = np.vstack( - [observed_node_tssb_root["sticks"], 1.0] - ) - else: # random swap: change data and parameters - swap_params(nodeA, nodeB) - else: # e.g. 
A with A-0 - # init_baseline = np.mean(self.data / np.sum(self.data, axis=1).reshape(-1,1) * self.data.shape[1], axis=0) - # init_baseline = init_baseline / init_baseline[0] - # init_log_baseline = np.log(init_baseline[1:] + 1e-6) - init_bs = np.array( - root_node.variational_parameters["globals"]["log_baseline_mean"] - - np.mean( - root_node.variational_parameters["globals"]["log_baseline_mean"] - ) - ) - root_node.variational_parameters["globals"]["log_baseline_mean"] = np.log( - non_root_node.node_mean / non_root_node.node_mean[0] - )[1:] - root_node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] = np.zeros((self.data.shape[1],)) - root_node.variational_parameters["locals"]["unobserved_factors_log_std"] = ( - np.zeros((self.data.shape[1],)) - 2 - ) - root_node.set_mean(variational=True) - non_root_node_init_psi = np.array( - non_root_node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - nodes = self.get_nodes()[1:] - for node in nodes: - node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] -= non_root_node_init_psi - # node.variational_parameters['locals']['unobserved_factors_log_std'] += .5 - # node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 1 - # node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5 - # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.clip(normal_sample(0., gamma_sample(root_node.unobserved_factors_kernel_concentration_caller(), - # root_node.unobserved_factors_kernel_concentration_caller(), size=self.data.shape[1])), a_min=-5, a_max=5) - data_indices = np.where(np.array(self.assignments) == root_node)[ - 0 - ] # list(root_node.data) - if len(data_indices) > 0: - # idx = np.random.choice(np.array(data_indices)) - # print(f'Setting new node to explain datum {idx}') - # datum = self.data[idx] - # baseline = np.append(1, np.exp(non_root_node.log_baseline_caller())) - # total_rna = np.sum(baseline * non_root_node.cnvs/2 * np.exp(root_node.variational_parameters['locals']['unobserved_factors_mean'])) - # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.log((datum+1) * total_rna/(root_node.lib_sizes[idx]*baseline * root_node.cnvs/2)) - non_root_node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] = np.zeros((self.data.shape[1],)) - new_bs = np.array( - root_node.variational_parameters["globals"]["log_baseline_mean"] - - np.mean( - root_node.variational_parameters["globals"]["log_baseline_mean"] - ) - ) - non_root_node.variational_parameters["locals"][ - "unobserved_factors_mean" - ][1:] = np.array(init_bs - new_bs) - non_root_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.log( - root_node.unobserved_factors_kernel_concentration_caller() - ) * np.ones( - (self.data.shape[1],) - ) - data_in_node = np.array(self.data)[data_indices] - target_genes_1 = np.argsort(np.var(np.log(data_in_node + 1), axis=0))[ - -5: - ] - target_genes_2 = np.where( - np.abs( - non_root_node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - > 0.5 - )[0] - target_genes = np.unique( - np.concatenate([np.array(target_genes_1), np.array(target_genes_2)]) - ) - non_root_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][target_genes] = -1.0 - - non_root_node.set_mean(variational=True) - dataRoot = root_node.data.copy() - logitsRoot = np.array(root_node.data_ass_logits) - root_node.data = non_root_node.data.copy() - 
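# When the root is swapped with one of its children, the removed code re-expresses the
# baseline relative to the new root: the log baseline is recentered, rebuilt from the
# swapped node's mean (with gene 0 as reference), and the difference is pushed into the
# non-root node's unobserved factors so the implied gene means stay consistent. A small
# NumPy sketch of that reparameterization, with illustrative names only.
import numpy as np

def recenter_log_baseline(old_log_baseline, new_root_mean):
    # old_log_baseline, returned psi_shift: length n_genes - 1 (genes 1:)
    old_centered = old_log_baseline - np.mean(old_log_baseline)
    new_log_baseline = np.log(new_root_mean / new_root_mean[0])[1:]  # gene 0 is the reference
    new_centered = new_log_baseline - np.mean(new_log_baseline)
    psi_shift = old_centered - new_centered   # absorbed by the non-root node's factors
    return new_log_baseline, psi_shift

_ = recenter_log_baseline(np.array([0.3, -0.2]), np.array([1.0, 2.0, 0.5]))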
root_node.data_ass_logits = np.array(non_root_node.data_ass_logits) - non_root_node.data = dataRoot.copy() - non_root_node.data_ass_logits = np.array(logitsRoot) - - nodeA_children = np.array(list(nodeA.children())) - idx = np.argsort([n.label for n in nodeA_children]) - nodeA_children = nodeA_children[idx] - nodeA_children = [ - node - for node in nodeA_children - if not node.is_observed and node != nodeB - ] - - nodeB_children = np.array(list(nodeB.children())) - idx = np.argsort([n.label for n in nodeB_children]) - nodeB_children = nodeB_children[idx] - nodeB_children = [ - node - for node in nodeB_children - if not node.is_observed and node != nodeA - ] - - if not non_root_node.is_observed: - if non_root_node.parent() == root_node: - # Go through children of nodeA and set them as children of nodeB and vice-versa - for nodeA_child in nodeA_children: - self.prune_reattach(nodeA_child, nodeB) - - for nodeB_child in nodeB_children: - self.prune_reattach(nodeB_child, nodeA) - - if update_pivots: - if nodeA.tssb == nodeB.tssb: - root = nodeB.tssb.get_ntssb_root() - # For each subtree, if pivot was swapped, update it - for child in root["children"]: - if child["pivot_node"] == nodeA: - child["pivot_node"] = nodeB - child["node"].root["node"].set_parent(nodeB, reset=False) - child["node"].root["node"].set_mean(variational=True) - elif child["pivot_node"] == nodeB: - child["pivot_node"] = nodeA - child["node"].root["node"].set_parent(nodeA, reset=False) - child["node"].root["node"].set_mean(variational=True) - - # def swap_nodes(self, nodeA, nodeB): - # if isinstance(nodeA, str) and isinstance(nodeB, str): - # self.plot_tree(super_only=False) - # nodes = self.get_nodes(None) - # node_labels = np.array([node.label for node in nodes]) - # nodeA = nodes[np.where(node_labels == nodeA)[0][0]] - # nodeB = nodes[np.where(node_labels == nodeB)[0][0]] - # - # # If we are swapping the root node, need to change the baseline too - # root_node = None - # non_root_node = None - # child_unobserved_factors = None - # initial_log_baseline = None - # if nodeA.parent() is None: - # root_node = nodeA - # non_root_node = nodeB - # elif nodeB.parent() is None: - # root_node = nodeB - # non_root_node = nodeA - # if root_node and non_root_node: - # child_unobserved_factors = non_root_node.unobserved_factors - # child_unobserved_factors_k = non_root_node.unobserved_factors - # initial_log_baseline = root_node.log_baseline_mean - # - # if not root_node: - # paramsA = nodeA.variational_parameters['locals']['unobserved_factors_mean'] - # paramsA_k = nodeA.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] - # # sticks_alphaA = nodeA.nu_log_alpha - # # sticks_betaA = nodeA.nu_log_beta - # dataA = nodeA.data - # logitsA = nodeA.data_ass_logits - # - # nodeA.variational_parameters['locals']['unobserved_factors_mean'] = nodeB.variational_parameters['locals']['unobserved_factors_mean'] - # nodeA.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = nodeB.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] - # nodeA.data = nodeB.data - # nodeA.data_ass_logits = nodeB.data_ass_logits - # # nodeA.nu_log_alpha = nodeB.nu_log_alpha - # # nodeA.nu_log_beta = nodeB.nu_log_beta - # nodeA.set_mean(variational=True) - # - # nodeB.variational_parameters['locals']['unobserved_factors_mean'] = paramsA - # nodeB.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = paramsA_k - # nodeB.data = dataA - # nodeB.data_ass_logits = logitsA - # # nodeB.nu_log_alpha = 
sticks_alphaA - # # nodeB.nu_log_beta = sticks_betaA - # nodeB.set_mean(variational=True) - # - # if root_node and non_root_node: - # root_node.variational_parameters['locals']['unobserved_factors_mean'] = child_unobserved_factors - # root_node.variational_parameters['globals']['log_baseline_mean'] = non_root_node.variational_parameters['locals']['unobserved_factors_mean'][1:] - # root_node.set_mean(variational=True) - # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.append(0., initial_log_baseline) - # non_root_node.set_mean(variational=True) - # - # if nodeA.tssb == nodeB.tssb: - # # Go to subtrees - # # For each subtree, if pivot was swapped, update it - # root = nodeB.tssb.get_ntssb_root() - # for child in root['children']: - # if child['pivot_node'] == nodeA: - # child['pivot_node'] = nodeB - # child['node'].root['node'].set_parent(nodeB, reset=False) - # child['node'].root['node'].set_mean(variational=True) - # elif child['pivot_node'] == nodeB: - # child['pivot_node'] = nodeA - # child['node'].root['node'].set_parent(nodeA, reset=False) - # child['node'].root['node'].set_mean(variational=True) - # else: - # if nodeA.parent() == nodeB or nodeB.parent() == nodeA: - # unobserved_node_idx = np.where([not nodeA.is_observed, not nodeB.is_observed])[0] - # if len(unobserved_node_idx) > 0: - # # change A -> A-0 -> B to A-> B -> B-0: - # unobserved_node_idx = unobserved_node_idx[0] - # unobserved_node = [nodeA, nodeB][unobserved_node_idx] - # # if unobserved node is parent of more than one subtree, don't proceed with full swap - # unobserved_node_children = unobserved_node.children() - # if not (np.sum(np.array([child.is_observed for child in unobserved_node_children])) > 1): - # observed_node = [nodeA, nodeB][1 - unobserved_node_idx] - # parent_unobserved = unobserved_node.parent() - # observed_node.set_parent(parent_unobserved) - # unobserved_node.set_parent(observed_node) - # - # nodes = self._get_nodes(get_roots=True) - # unobserved_node_tssb_root = [node[1] for node in nodes if node[0] == unobserved_node][0] - # parent_unobserved_node_tssb_root = [node[1] for node in nodes if node[0] == parent_unobserved][0] - # - # # Update dicts - # # Remove unobserved_node from its parent dict - # childnodes = np.array([n['node'] for n in parent_unobserved_node_tssb_root['children']]) - # tokeep = np.where(childnodes != unobserved_node_tssb_root['node'])[0].astype(int).ravel() - # parent_unobserved_node_tssb_root['sticks'] = parent_unobserved_node_tssb_root['sticks'][tokeep] - # parent_unobserved_node_tssb_root['children'] = list(np.array(parent_unobserved_node_tssb_root['children'])[tokeep]) - # - # # Update observed_node's pivot_node to unobserved_node's parent - # observed_node_ntssb_root = observed_node.tssb.get_ntssb_root() - # observed_node_ntssb_root['pivot_node'] = parent_unobserved - # - # # Add unobserved_node to observed_node's dict - # observed_node_tssb_root = observed_node_ntssb_root['node'].root - # observed_node_tssb_root['children'].append(unobserved_node_tssb_root) - # observed_node_tssb_root['sticks'] = np.vstack([observed_node_tssb_root['sticks'], 1.]) - - def merge_nodes(self, nodeA, nodeB, optimal_params=True): - if isinstance(nodeA, str) and isinstance(nodeB, str): - self.plot_tree(super_only=False) - nodes = self.get_nodes(None) - node_labels = np.array([node.label for node in nodes]) - nodeA = nodes[np.where(node_labels == nodeA)[0][0]] - nodeB = nodes[np.where(node_labels == nodeB)[0][0]] - - nodes = self._get_nodes(get_roots=True) - 
nodes_list = np.array([node[0] for node in nodes]) - nodeA_idx = np.where(nodes_list == nodeA)[0][0] - nodeB_idx = np.where(nodes_list == nodeB)[0][0] - nodeA_root = nodes[nodeA_idx][1] - nodeB_root = nodes[nodeB_idx][1] - nodeA_parent_root = nodes[np.where(np.array(nodes) == nodeA.parent())[0][0]][1] - if nodeB.parent() is not None: - nodeB_parent_root = nodes[ - np.where(np.array(nodes) == nodeB.parent())[0][0] - ][1] - - numDataA, numDataB = (len(nodeA.data), len(nodeB.data)) - - nodeA_init_psi = np.array( - nodeA.variational_parameters["locals"]["unobserved_factors_mean"] - ) - - if not nodeA.is_observed or not nodeB.is_observed: - if nodeA.tssb == nodeB.tssb: - # Move child nodes of nodeA to nodeB and remove nodeA - # And make sure the children of nodeA keep their parameters - n_childrenA = len(nodeA_root["children"]) - n_childrenB = len(nodeB_root["children"]) - for i, nodeA_child in enumerate(nodeA_root["children"]): - nodeA_child["node"].set_parent(nodeB_root["node"], reset=False) - nodeA_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.maximum( - nodeA_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - ) - nodeA_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ] = np.minimum( - nodeA_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - ) - nodeA_child["node"].set_mean(variational=True) - nodeB_root["children"].append(nodeA_child) - nodeB_root["sticks"] = np.vstack([nodeB_root["sticks"], 1.0]) - nodeA_root["children"] = [] - - # If nodeA was the pivot of a downstream tree, update the pivot to nodeB - nodeA_children = np.array(list(nodeA_root["node"].children().copy())) - idx = np.argsort([n.label for n in nodeA_children]) - nodeA_children = nodeA_children[idx] - for nodeA_child in nodeA_children: - if nodeA_child.tssb != nodeA_root["node"].tssb: - nodeA_child.set_parent(nodeB_root["node"], reset=False) - nodeA_child.set_mean(variational=True) - ntssb_root = nodeA_child.tssb.get_ntssb_root() - ntssb_root["pivot_node"] = nodeB_root["node"] - nodeA_root["node"].children().clear() - - nodeB.data_weights = np.array(nodeA.data_weights) - nodeB.data.update(nodeA.data) - else: - # nodeB is parent (and pivot) of nodeA - # Set parent of nodeB as pivot of nodeA's subtree - ntssb_root = nodeA.tssb.get_ntssb_root() - ntssb_root["pivot_node"] = nodeB_parent_root["node"] - nodeA.set_parent(nodeB.parent(), reset=False) - - # Set all children of nodeB as children of nodeB's parent - for i, nodeB_child in enumerate(nodeB_root["children"]): - nodeB_child["node"].set_parent( - nodeB_parent_root["node"], reset=False - ) - nodeB_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.maximum( - nodeB_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - nodeB_parent_root["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - ) - nodeB_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ] = np.minimum( - nodeB_child["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - nodeB_parent_root["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - ) - 
nodeB_child["node"].set_mean(variational=True) - nodeB_parent_root["children"].append(nodeB_child) - nodeB_parent_root["sticks"] = np.vstack( - [nodeB_parent_root["sticks"], 1.0] - ) - nodeB_root["children"] = [] - - # If nodeB was pivot of another tree, update the pivot to nodeB_parent - nodeB_children = np.array(list(nodeB_root["node"].children().copy())) - idx = np.argsort([n.label for n in nodeB_children]) - nodeB_children = nodeB_children[idx] - for nodeB_child in nodeB_children: - if ( - nodeB_child.tssb != nodeB_root["node"].tssb - and nodeB_child.tssb != nodeA.tssb - ): - nodeB_child.set_parent(nodeB_parent_root["node"], reset=False) - nodeB_child.set_mean(variational=True) - ntssb_root = nodeB_child.tssb.get_ntssb_root() - ntssb_root["pivot_node"] = nodeB_parent_root["node"] - nodeB_root["node"].children().clear() - - nodeA.data_weights = np.array(nodeB.data_weights) - nodeA.data.update(nodeB.data) - else: - nodeB.data_weights = np.array(nodeA.data_weights) - nodeB.data.update(nodeA.data) - nodeA.data.clear() - - # Keep node that explains the most data - if optimal_params: - if nodeA.tssb == nodeB.tssb: - # Merge kernels - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.maximum( - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - ) - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ] = np.minimum( - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - ) - # if numDataA > numDataB: - if nodeB.parent() is not None: - nodeB.variational_parameters["locals"][ - "unobserved_factors_mean" - ] = np.array( - nodeA.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - nodeB.variational_parameters["locals"][ - "unobserved_factors_log_std" - ] = np.array( - nodeA.variational_parameters["locals"][ - "unobserved_factors_log_std" - ] - ) - nodeB.variational_parameters["locals"][ - "nu_log_mean" - ] = np.array( - nodeA.variational_parameters["locals"]["nu_log_mean"] - ) - nodeB.variational_parameters["locals"]["nu_log_std"] = np.array( - nodeA.variational_parameters["locals"]["nu_log_std"] - ) - nodeB.variational_parameters["locals"][ - "psi_log_mean" - ] = np.array( - nodeA.variational_parameters["locals"]["psi_log_mean"] - ) - nodeB.variational_parameters["locals"][ - "psi_log_std" - ] = np.array( - nodeA.variational_parameters["locals"]["psi_log_std"] - ) - else: # We're trying to merge to root and root has no data, so adjust its baseline - nodeB.variational_parameters["globals"][ - "log_baseline_mean" - ] = np.log(nodeA.node_mean / nodeA.node_mean[0])[1:] - nodeB.variational_parameters["locals"][ - "unobserved_factors_mean" - ] *= 0.0 - # Also adjust all unobserved factors by removing previous nodeA psi from all nodes below it: it is now present as the baseline - def update_psi(node_obj): - node_obj.variational_parameters["locals"][ - "unobserved_factors_mean" - ] -= nodeA_init_psi - # node_obj.variational_parameters["locals"][ - # "unobserved_factors_kernel_log_mean" - # ] += 0.5 - for child in node_obj.children(): - update_psi(child) - update_psi(nodeA) - # - # nodes = self.get_nodes()[1:] - # for node in nodes: - # node.variational_parameters["locals"][ - # "unobserved_factors_mean" - # ] -= nodeA_init_psi - # node.variational_parameters["locals"][ - # 
"unobserved_factors_kernel_log_mean" - # ] += 0.5 - nodeB.set_mean(variational=True) - else: - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.maximum( - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - ) - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ] = np.minimum( - nodeA.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - nodeB.variational_parameters["locals"][ - "unobserved_factors_kernel_log_std" - ], - ) - if nodeB.parent() is None: - nodeB.variational_parameters["globals"][ - "log_baseline_mean" - ] = np.log(nodeA.node_mean / nodeA.node_mean[0])[1:] - nodes = self.get_nodes()[1:] - for node in nodes: - node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] -= nodeA_init_psi - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] += 1.0 - - if not nodeA.is_observed or not nodeB.is_observed: - if nodeA.tssb == nodeB.tssb: - # Remove nodeA from tssb root dict - nodes = np.array([n["node"] for n in nodeA_parent_root["children"]]) - tokeep = np.where(nodes != nodeA)[0].astype(int).ravel() - nodeA_root["node"].kill() - del nodeA_root["node"] - - nodeA_parent_root["sticks"] = nodeA_parent_root["sticks"][tokeep] - nodeA_parent_root["children"] = list( - np.array(nodeA_parent_root["children"])[tokeep] - ) - else: - # Remove nodeB from tssb root dict - nodes = np.array([n["node"] for n in nodeB_parent_root["children"]]) - tokeep = np.where(nodes != nodeB)[0].astype(int).ravel() - nodeB_root["node"].kill() - del nodeB_root["node"] - - nodeB_parent_root["sticks"] = nodeB_parent_root["sticks"][tokeep] - nodeB_parent_root["children"] = list( - np.array(nodeB_parent_root["children"])[tokeep] - ) - def subtree_reattach_to(self, node, target_clone, optimal_init=True): # Get the node and its parent root nodes = self._get_nodes(get_roots=True) @@ -3448,7 +1982,7 @@ def descend(root): )[0] for node in nodes_below_nodeA: node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" ][parent_affected_genes] = -1 if roots[nodeA_idx]["node"].parent().parent() is not None: @@ -3464,12 +1998,12 @@ def descend(root): genes = np.where(roots[nodeA_idx]["node"].cnvs != 2)[0] roots[nodeA_idx]["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" ][genes] -= 1 genes = np.where(init_cnvs != 2)[0] roots[nodeA_idx]["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" ][genes] -= 1 # Need to accomodate events this node has that were previously inherited @@ -3482,7 +2016,7 @@ def descend(root): > 0.5 )[0] roots[nodeA_idx]["node"].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" ][affected_genes] = -1 # Reset variational parameters: all log_std and unobserved factors kernel @@ -3690,7 +2224,15 @@ def label_nodes(self, counts=False, names=False): elif not names or counts is True: self.label_nodes_counts() - def set_node_names(self, root_name="X"): + def set_node_names(self): + def descend(root): + root['node'].set_node_names(root_name=root['label']) + for child in root["children"]: + descend(child) + + descend(self.root) + + def set_tree_names(self, root_name="X"): self.root["label"] = str(root_name) 
self.root["node"].label = str(root_name) @@ -3831,31 +2373,27 @@ def descend(root, g): return g def get_node_unobs(self): - nodes = self.get_nodes(None) + nodes = self.get_nodes() unobs = [] estimated = ( - np.var(nodes[1].variational_parameters["locals"]["unobserved_factors_mean"]) + np.var(nodes[1].variational_parameters["kernel"]["state"]['mean']) != 0 ) if estimated: logger.debug("Getting the learned unobserved factors.") for node in nodes: - unobs_factors = ( - node.unobserved_factors - if not estimated - else node.variational_parameters["locals"]["unobserved_factors_mean"] - ) + unobs_factors = node.params[0] unobs.append(unobs_factors) return nodes, unobs def get_node_unobs_affected_genes(self): - nodes = self.get_nodes(None) + nodes = self.get_nodes() unobs = [] estimated = ( np.var( - nodes[1].variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] + np.exp(nodes[1].variational_parameters["kernel"][ + "direction" + ]['log_alpha']) ) != 0 ) @@ -3870,22 +2408,24 @@ def get_node_unobs_affected_genes(self): unobs_factors_kernel = ( node.unobserved_factors_kernel if not estimated - else node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] + else np.exp(node.variational_parameters["kernel"]['direction'][ + "log_alpha" + ] - node.variational_parameters["kernel"]['direction'][ + "log_beta" + ]) ) unobs.append(unobs_factors) return nodes, unobs def get_node_obs(self): - nodes = self.get_nodes(None) + nodes = self.get_nodes() obs = [] for node in nodes: obs.append(node.observed_parameters) return nodes, obs def get_avg_node_exp(self, norm=True): - nodes = self.get_nodes(None) + nodes = self.get_nodes() data = self.data if norm: try: @@ -4190,87 +2730,68 @@ def descend(node): descend(self.root["node"].root["node"]) - # - # def path_to_node(self, node_id): - # path = [] - # path.append(node_id) - # parent_id = self.node_dict[node_id]['parent'] - # while parent_id != 'NULL': - # path.append(parent_id) - # parent_id = self.node_dict[parent_id]['parent'] - # return path[::-1][:] - # - # def path_between_nodes(self, nodeA, nodeB): - # pathA = np.array(self.path_to_node(nodeA)) - # pathB = np.array(self.path_to_node(nodeB)) - # path = [] - # # Get MRCA - # i = -1 - # for node in pathA: - # if node in pathB: - # i += 1 - # else: - # break - # mrca = pathA[i] - # pathA = np.array(pathA[::-1]) - # # Get path from A to MRCA - # path = path + list(pathA[:np.where(pathA == mrca)[0][0]]) - # # Get path from MRCA to B - # path = path + list(pathB[np.where(pathB == mrca)[0][0]:]) - # return path - - # TODO: Should have a distance that counts the number of changed genes while going through the path - def get_distance(self, id1, id2, distance="n_nodes"): - path = self.path_between_nodes(id1, id2) - + def get_distance(self, node1, node2): + path = self.path_between_nodes(node1, node2) + path_labels = [n.label for n in path] + node_dict = dict(zip(path_labels, path)) dist = 0 - if distance == "n_nodes": - dist = len(path) - else: - prev_node = path[0] - for node in path: - if node != prev_node: - if dist == "estimated": - dist += np.sqrt( - np.sum( - ( - self.node_dict[node]["node"].variational_parameters[ - "locals" - ]["unobserved_factors_mean"] - - self.node_dict[prev_node][ - "node" - ].variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - ** 2 - ) - ) - else: - dist += np.sqrt( - np.sum( - ( - self.node_dict[node].unobserved_factors - - self.node_dict[prev_node][ - "node" - ].unobserved_factors - ) - ** 2 - ) - ) - prev_node = node - dist 
+= self.node_dict[node][distance] + prev_node = path_labels[0] + for node in path_labels: + if node != prev_node: + dist += np.sqrt( + np.sum( + (node_dict[node].get_mean() - node_dict[prev_node].get_mean())** 2 + ) + ) + prev_node = node - return dist + return dist - def get_pairwise_cell_distances(self, distance="n_nodes"): + def get_pairwise_obs_distances(self): n_cells = len(self.assignments) mat = np.zeros((n_cells, n_cells)) - - for i in range(1, n_cells): - id1 = self.assignments[i].label - for j in range(i): - id2 = self.assignments[j].label - mat[i][j] = self.get_distance(str(id1), str(id2), distance=distance) + nodes = self.get_nodes() + + for node1 in nodes: + idx1 = np.where(self.assignments == node1)[0] + if len(idx1) == 0: + continue + for node2 in nodes: + idx2 = np.where(self.assignments == node2)[0] + if len(idx2) == 0: + continue + mat[np.meshgrid(idx1,idx2)] = self.get_distance(node1, node2) return mat + + def set_combined_params(self): + def sub_descend(subroot, obs_param): + for i, child in enumerate(subroot['children']): + child['combined_param'] = child['node'].combine_params(child['param'], obs_param) + sub_descend(child, obs_param) + + def descend(root): + sub_descend(root['node'], root['obs_param']) + for i, child in enumerate(root['children']): + descend(child) + + descend(self.root) + + def set_ntssb_colors(self, **cmap_kwargs): + # Traverse tree to update dict + def descend(root): + tree_colors.make_tree_colormap(root['node'].root, root['color'], **cmap_kwargs) + for i, child in enumerate(root['children']): + descend(child) + descend(self.root) + + def show_tree(self, **kwargs): + self.set_learned_parameters() + self.set_node_names() + self.set_expected_weights() + self.assign_samples() + self.set_ntssb_colors() + tree = self.get_param_dict() + plt.figure(figsize=(4,4)) + ax = plt.gca() + plot_full_tree(tree, ax=ax, node_size=101, **kwargs) \ No newline at end of file diff --git a/scatrex/ntssb/tree.py b/scatrex/ntssb/observed_tree.py similarity index 68% rename from scatrex/ntssb/tree.py rename to scatrex/ntssb/observed_tree.py index 51f7d67..4e3247e 100644 --- a/scatrex/ntssb/tree.py +++ b/scatrex/ntssb/observed_tree.py @@ -1,5 +1,6 @@ import string from abc import ABC, abstractmethod +from copy import deepcopy import numpy as np import pandas as pd @@ -8,23 +9,30 @@ import matplotlib.pyplot as plt from graphviz import Digraph +from .node import AbstractNode from ..plotting import constants +from ..utils.tree_utils import tree_to_dict, dict_to_tree, subsample_tree, condense_tree -class Tree(ABC): +class ObservedTree(ABC): def __init__( self, n_nodes=3, dp_alpha_subtree=1.0, - alpha_decay_subtree=0.9, + alpha_decay_subtree=1.0, dp_gamma_subtree=1.0, dp_alpha_parent_edge=1.0, - alpha_decay_parent_edge=0.9, - eta=0.0, + alpha_decay_parent_edge=1.0, + eta=1.0, node_weights=None, + seed=42, + add_root=True, + **kwargs, ): - - self.tree_dict = dict() + self.node_constructor = AbstractNode + self.seed = seed + self.tree_dict = dict() # flat + self.tree = dict() # recursive self.n_nodes = n_nodes self.dp_alpha_subtree = dp_alpha_subtree self.alpha_decay_subtree = alpha_decay_subtree @@ -35,16 +43,71 @@ def __init__( self.adata = None self.cmap = None self.node_weights = node_weights + self.add_root = add_root if self.node_weights is None: - self.node_weights = np.random.dirichlet([10.0] * self.n_nodes) + rng = np.random.default_rng(seed=self.seed) + self.node_weights = rng.dirichlet([100.0] * self.n_nodes) + + def condense(self, min_weight=.1, inplace=False): + """ + 
Traverse the tree from the bottom up and merge nodes until all nodes have at least min_weight + """ + if inplace: + condense_tree(self.tree, min_weight=min_weight) + self.tree_dict = tree_to_dict(self.tree) + self.tree = dict_to_tree(self.tree_dict, root_name=self.tree['label']) + self.n_nodes = len(self.tree_dict.keys()) + else: + new_tree = deepcopy(self.tree) + condense_tree(new_tree, min_weight=min_weight) + new_tree_dict = tree_to_dict(new_tree) + new_tree = dict_to_tree(new_tree_dict, root_name=new_tree['label']) + n_nodes = len(new_tree_dict.keys()) + new_obj = self.__class__(n_nodes=n_nodes, seed=self.seed) + new_obj.tree = new_tree + new_obj.tree_dict = new_tree_dict + return new_obj + + def subsample(self, keep_prob=0.5, force=True, inplace=False): + """ + Randomly choose a fraction of nodes to keep in the tree + """ + init_n_nodes = self.n_nodes + seed = self.seed + while True: + new_tree = deepcopy(self.tree) + subsample_tree(new_tree, keep_prob=keep_prob, seed=seed) + new_tree_dict = tree_to_dict(new_tree) + new_tree = dict_to_tree(new_tree_dict, root_name=new_tree['label']) + n_nodes = len(new_tree_dict.keys()) + if force: + if n_nodes == int(init_n_nodes*keep_prob): + break + else: + seed += 1 + else: + break + + if inplace: + self.tree = new_tree + self.tree_dict = new_tree_dict + self.n_nodes = n_nodes + else: + new_obj = self.__class__(n_nodes=n_nodes, seed=self.seed) + new_obj.tree = new_tree + new_obj.tree_dict = new_tree_dict + return new_obj def get_size(self): return len(self.tree_dict.keys()) + def get_param_size(self): + return self.tree["param"].size + def get_params(self): params = [] for node in self.tree_dict: - params.append(self.tree_dict[node]["params"]) + params.append(self.tree_dict[node]["param"]) return np.array(params, dtype=np.float) def change_names(self, keep="root"): @@ -62,7 +125,8 @@ def change_names(self, keep="root"): self.tree_dict[node]["children"] = [] self.tree_dict[alphabet[i]] = self.tree_dict[node] self.tree_dict[alphabet[i]]["label"] = alphabet[i] - del self.tree_dict[node] + if new_names[node] != node: + del self.tree_dict[node] for i in self.tree_dict: self.tree_dict[i]["children"] = [] @@ -72,7 +136,15 @@ def change_names(self, keep="root"): if self.tree_dict[j]["parent"] == i: self.tree_dict[i]["children"].append(j) - def set_colors(self, root_node=None): + root_name = keep + for node in self.tree_dict: + if self.tree_dict[node]['parent'] == '-1': + root_name = node + + self.tree = dict_to_tree(self.tree_dict, root_name=root_name) + self.tree_dict = tree_to_dict(self.tree) + + def set_colors(self, root_node='root'): idx = 0 if root_node in self.tree_dict: self.tree_dict[root_node]["color"] = "lightgray" @@ -82,6 +154,14 @@ def set_colors(self, root_node=None): self.tree_dict[node]["color"] = constants.LABEL_COLORS_DICT[node] except: self.tree_dict[node]["color"] = constants.CLONES_PAL[i] + + root_name = root_node + for node in self.tree_dict: + if self.tree_dict[node]['parent'] == '-1': + root_name = node + + self.tree = dict_to_tree(self.tree_dict, root_name=root_name) + self.tree_dict = tree_to_dict(self.tree) def add_tree_parameters(self, change_name=True): @@ -125,28 +205,32 @@ def add_tree_parameters(self, change_name=True): def generate_tree(self): alphabet = list(string.ascii_uppercase) # Add healthy node - self.tree_dict = dict( - root=dict( - parent="-1", - children=[], - params=None, - dp_alpha_subtree=self.dp_alpha_subtree, - alpha_decay_subtree=self.alpha_decay_subtree, - dp_gamma_subtree=self.dp_gamma_subtree, - 
dp_alpha_parent_edge=self.dp_alpha_parent_edge, - alpha_decay_parent_edge=self.alpha_decay_parent_edge, - eta=self.eta, - weight=0, - size=int(0), - color="lightgray", - label="root", + if self.add_root: + self.tree_dict = dict( + root=dict( + parent="-1", + children=[], + param=None, + dp_alpha_subtree=self.dp_alpha_subtree, + alpha_decay_subtree=self.alpha_decay_subtree, + dp_gamma_subtree=self.dp_gamma_subtree, + dp_alpha_parent_edge=self.dp_alpha_parent_edge, + alpha_decay_parent_edge=self.alpha_decay_parent_edge, + eta=self.eta, + weight=0, + size=int(0), + color="lightgray", + label="root", + ) ) - ) + mrca_parent = "root" + else: + mrca_parent = "-1" # Add MRCA node self.tree_dict["A"] = dict( - parent="root", + parent=mrca_parent, children=[], - params=None, + param=None, dp_alpha_subtree=self.dp_alpha_subtree, alpha_decay_subtree=self.alpha_decay_subtree, dp_gamma_subtree=self.dp_gamma_subtree, @@ -159,11 +243,12 @@ def generate_tree(self): label="A", ) for c in range(1, self.n_nodes): - parent = alphabet[np.random.choice(np.arange(0, c))] + rng = np.random.default_rng(seed=self.seed+c) + parent = alphabet[rng.choice(np.arange(0, c))] self.tree_dict[alphabet[c]] = dict( parent=parent, children=[], - params=None, + param=None, dp_alpha_subtree=self.dp_alpha_subtree, alpha_decay_subtree=self.alpha_decay_subtree, dp_gamma_subtree=self.dp_gamma_subtree, @@ -181,6 +266,12 @@ def generate_tree(self): if self.tree_dict[j]["parent"] == i: self.tree_dict[i]["children"].append(j) + for node in self.tree_dict: + if self.tree_dict[node]['parent'] == '-1': + root_name = node + + self.tree = dict_to_tree(self.tree_dict, root_name=root_name) + def get_sum_weights_subtree(self, label): if "weight" not in self.tree_dict["A"].keys(): raise KeyError("No weights were specified in the input tree.") @@ -288,7 +379,7 @@ def create_adata(self, var_names=None): ) params.append( np.vstack( - [self.tree_dict[node]["params"]] * self.tree_dict[node]["size"] + [self.tree_dict[node]["param"]] * self.tree_dict[node]["size"] ) ) params = pd.DataFrame(np.vstack(params)) @@ -335,7 +426,7 @@ def plot_heatmap(self, var_names=None, cmap=None, **kwds): def read_tree_from_dict( self, tree_dict, - input_params_key="params", + input_params_key="param", input_label_key="label", input_parent_key="parent", input_sizes_key="size", @@ -417,6 +508,12 @@ def read_tree_from_dict( for j in self.tree_dict: if self.tree_dict[j]["parent"] == i: self.tree_dict[i]["children"].append(j) + + for node in self.tree_dict: + if self.tree_dict[node]['parent'] == '-1': + root_name = node + + self.tree = dict_to_tree(self.tree_dict, root_name=root_name) def root(self): nodes = list(self.tree_dict.keys()) @@ -447,12 +544,62 @@ def update_weights(self, uniform=False): def subset_genes(self, gene_list): for node in self.tree_dict: - self.tree_dict[node]["params"] = pd.DataFrame( - self.tree_dict[node]["params"][:, np.newaxis].T, + self.tree_dict[node]["param"] = pd.DataFrame( + self.tree_dict[node]["param"][:, np.newaxis].T, columns=self.adata.var_names, )[gene_list].values.ravel() self.adata = self.adata[:, gene_list] @abstractmethod - def add_node_params(self): + def sample_root(self): return + + @abstractmethod + def sample_kernel(self): + return + + def update_tree(self): + for node in self.tree_dict: + if self.tree_dict[node]['parent'] == '-1': + root_name = node + self.tree = dict_to_tree(self.tree_dict, root_name=root_name) + + def update_dict(self): + self.tree_dict = tree_to_dict(self.tree) + + def param_distance(self, paramA, paramB): + 
return np.sqrt(np.sum((paramA-paramB)**2)) + + def add_node_params( + self, n_genes=2, min_dist=0.2, **params + ): + def descend(root, idx=0, depth=1): + for i, child in enumerate(root['children']): + seed = self.seed+idx + accepted = False + while not accepted: + child["param"] = self.sample_kernel(root["param"], seed=seed, depth=depth, **params) + dist_to_parent = self.param_distance(root["param"], child["param"]) + # Reject sample if too close to any other child + dists = [] + for j, child2 in enumerate(root['children']): + if j < i: + dists.append(self.param_distance(child["param"], child2["param"])) + if np.all(np.array(dists) >= min_dist*dist_to_parent): + accepted = True + else: + seed += 1 + + idx = descend(child, idx+1, depth=depth+1) + return idx + + # Set root param + self.tree["param"] = self.sample_root(n_genes=n_genes, seed=self.seed, **params) + + # Set node params recursively + descend(self.tree) + + # Update tree_dict too + self.tree_dict = tree_to_dict(self.tree) + + self.create_adata() diff --git a/scatrex/ntssb/search.py b/scatrex/ntssb/search.py index ef621e3..0f5cb0a 100644 --- a/scatrex/ntssb/search.py +++ b/scatrex/ntssb/search.py @@ -1,67 +1,272 @@ import numpy as np from copy import deepcopy -from tqdm.auto import tqdm +from tqdm.auto import trange from time import time -from ..util import * -from jax import jit -from jax.example_libraries.optimizers import adam import matplotlib.pyplot as plt -import logging - -logger = logging.getLogger(__name__) +from ..utils.math_utils import * +import jax +import jax.numpy as jnp -def search_callback(inf): - return - +import logging -MOVE_WEIGHTS = { - "add": 1, - "merge": 1.0, - "prune_reattach": 1.0, - "pivot_reattach": 1.0, - "swap": 1.0, - "add_reattach_pivot": 1.0, - "subtree_reattach": 0.5, - "push_subtree": 1.0, - "extract_pivot": 1.0, - "subtree_pivot_reattach": 0.5, - "perturb_node": .1, - "perturb_globals": .1, - "optimize_node": 1, - "transfer_factor": 0.5, - "transfer_unobserved": 0.5, - "clean_node": 0. 
-} +logger = logging.getLogger(__name__) class StructureSearch(object): def __init__(self, ntssb): - + # Keep a pointer to the current tree self.tree = deepcopy(ntssb) + # And a pointer to the changed tree + self.proposed_tree = deepcopy(self.tree) + self.traces = dict() self.traces["tree"] = [] self.traces["elbo"] = [] - self.traces["score"] = [] - self.traces["move"] = [] - self.traces["temperature"] = [] - self.traces["accepted"] = [] - self.traces["times"] = [] self.traces["n_nodes"] = [] - self.traces["gamma"] = [] self.traces["elbos"] = [] - self.best_elbo = self.tree.elbo self.best_tree = deepcopy(self.tree) - self.opt_triplet = None - def init_optimizer(self, lr=0.01, opt=adam): - opt_init, opt_update, get_params = opt(lr=lr) - get_params = jit(get_params) - opt_update = jit(opt_update) - opt_init = jit(opt_init) - self.opt_triplet = (opt_init, opt_update, get_params) + def run_search(self, n_iters=10, n_epochs=10, mc_samples=10, step_size=0.01, moves_per_tssb=1, global_freq=0, memoized=True, update_roots=True, seed=42, swap_freq=0, update_outer_ass=True): + """ + Start with learning the parameters of the model for the non-augmented tree, + which are just the assignments of cells to TSSBs, the outer stick parameters, + and the global model parameters + """ + # Generate PRNG key + key = jax.random.PRNGKey(seed) + + self.proposed_tree = deepcopy(self.tree) + + # Run the structure search + t = trange(n_iters, desc='Finding NTSSB', leave=True) + for i in t: + key, subkey = jax.random.split(key) + + update_globals = False + if global_freq != 0: + if i % global_freq == 0: + # Do birth-merge step in which other local params are updated + update_globals = True + + # Birth: traverse the tree and spawn a bunch of nodes (quick and helps escape local optima) + self.birth(subkey, moves_per_tssb=moves_per_tssb) + + # Update parameters in n_epochs passes through the data, interleaving node updates with local batch updates + self.tree.learn_params(n_epochs, update_roots=update_roots, mc_samples=mc_samples, + step_size=step_size, memoized=memoized) + self.tree.compute_elbo(memoized=memoized) + self.proposed_tree = deepcopy(self.tree) + + # Merge: traverse the tree and propose merges and accept/reject based on their summary statistics (reliable) + self.merge(subkey, moves_per_tssb=moves_per_tssb, memoized=memoized, update_globals=update_globals, + n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size) + + # Prune and reattach: traverse the tree and propose pruning nodes and reattaching somewhere else inside their TSSB + # self.prune_reattach(subkey, moves_per_tssb=moves_per_tssb) + + # Swap roots: traverse the tree and propose swapping roots of TSSBs with their immediate children + # self.swap_roots(subkey, moves_per_tssb=moves_per_tssb) + + # Keep best + if self.tree.elbo > self.best_tree.elbo: + self.best_tree = deepcopy(self.tree) + + # Compute ELBO + self.traces['elbo'].append(self.tree.elbo) + + t.set_description(f"Finding NTSSB ({self.tree.n_total_nodes} nodes, elbo: {self.tree.elbo})" ) + t.refresh() # to show immediately the update + + def swap(self, key, memoized=True, n_epochs=10, **learn_kwargs): + """ + Propose changing the root of a subtree. 
Preferably if + """ + if self.proposed_tree.n_total_nodes == self.proposed_tree.n_nodes: + logger.debug("Nothing to swap.") + return + + n_children = 0 + while n_children == 0: + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + # See in which subtree it lands + subtree, _, u = self.proposed_tree.find_node(u) + # Choose a child of the root to swap with + parent = subtree.root + n_children = len(parent['children']) + + # Swap parent-child + source_idx = jax.random.choice(subkey, n_children) + source = parent['children'][source_idx] + target = parent + + self.proposed_tree.swap_root(source, target) + self.proposed_tree.compute_elbo(memoized=memoized) + + if self.proposed_tree.elbo > self.tree.elbo: + # print(f"Merged {source_label} to {target_label}") + self.tree = deepcopy(self.proposed_tree) + else: + self.proposed_tree.learn_params(n_epochs, update_globals=False, memoized=memoized, **learn_kwargs) + self.proposed_tree.compute_elbo(memoized=memoized) + if self.proposed_tree.elbo > self.tree.elbo: + self.tree = deepcopy(self.proposed_tree) + else: + self.proposed_tree = deepcopy(self.tree) + + def birth(self, key, moves_per_tssb=1): + """ + Spawn `moves_per_tssb` new nodes per TSSB + """ + n_births = self.proposed_tree.n_nodes * moves_per_tssb + for _ in range(n_births): + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + target = self.proposed_tree.get_node(u, key=subkey, uniform=True) + new_node = target['node'].tssb.add_node(target, seed=int(subkey[1])) + if jax.random.bernoulli(subkey): + new_node.init_new_node_kernel() + + # Always accept + self.tree = deepcopy(self.proposed_tree) + + def merge_root(self, key, memoized=True, n_epochs=10, **learn_kwargs): + """ + Propose merging a root's child to the root and keeping the child's parameters.
Optimize and accept if ELBO improves + This is done by swapping the parameters of root and child and then merging child to root as usual + """ + if self.proposed_tree.n_total_nodes == self.proposed_tree.n_nodes: + logger.debug("Nothing to merge.") + return + + n_children = 0 + while n_children == 0: + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + # See in which subtree it lands + subtree, _, u = self.proposed_tree.find_node(u) + # Choose either couple of children to merge or a child to merge with parent + parent = subtree.root + n_children = len(parent['children']) + + # Merge parent-child + source_idx = jax.random.choice(subkey, n_children) + source = parent['children'][source_idx] + target = parent + + slab = source['node'].label + tlab = target['node'].label + # self.proposed_tree.swap_nodes(source['node'], target['node']) + self.proposed_tree.swap_root(source['node'].label, target['node'].label) + subtree.merge_nodes(target, source, target) + self.proposed_tree.compute_elbo(memoized=memoized) + + if self.proposed_tree.elbo > self.tree.elbo: # if ELBO improves even before optimizing + self.tree = deepcopy(self.proposed_tree) + else: + # Update node parameters + self.proposed_tree.learn_model(n_epochs, update_globals=False, update_roots=True, memoized=memoized, **learn_kwargs) + self.proposed_tree.compute_elbo(memoized=memoized) + if self.proposed_tree.elbo > self.tree.elbo: + self.tree = deepcopy(self.proposed_tree) + else: + self.proposed_tree = deepcopy(self.tree) + + def merge(self, key, moves_per_tssb=1, memoized=True, update_globals=False, n_epochs=10, **learn_kwargs): + """ + Traverse the trees and propose and accept/reject merges as we go using local suff stats + """ + n_merges = int(0.7 * self.proposed_tree.n_total_nodes * moves_per_tssb * 2) + if update_globals: + n_merges = int(0.7 * self.proposed_tree.n_total_nodes * moves_per_tssb) + for _ in range(n_merges): + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + parent = self.proposed_tree.get_node(u, key=subkey, uniform=True, include_leaves=False) # get non-leaf node, without accounting for weights + tssb = parent['node'].tssb + # Choose either couple of children to merge or a child to merge with parent + n_children = len(parent['children']) + if n_children == 0: + continue + if n_children > 1: + # Choose + if jax.random.bernoulli(subkey, 0.5) == 1: + # Merge sibling-sibling + # Choose a child + source_idx, target_idx = jax.random.choice(subkey, n_children, shape=(2,), replace=False) + # Choose most similar sibling + source = parent['children'][source_idx] + target = parent['children'][target_idx] + else: + # Merge parent-child + # Choose most similar child + source_idx = jax.random.choice(subkey, n_children) + source = parent['children'][source_idx] + target = parent + else: + # Merge parent-child + source = parent['children'][0] + target = parent + + source_label = source['node'].label + target_label = target['node'].label + # print(f"Will merge {source_label} to {target_label}") + # Merge, updating suff stats + tssb.merge_nodes(parent, source, target) + # Update node sticks + tssb.update_stick_params(parent) + # Update pivot probs + tssb.update_pivot_probs() + # Compute ELBO of new tree + self.proposed_tree.compute_elbo(memoized=memoized) + # print(f"{self.tree.elbo} -> {self.proposed_tree.elbo}") + # Update if ELBO improved + if self.proposed_tree.elbo > self.tree.elbo: + # print(f"Merged {source_label} to {target_label}") + self.tree = deepcopy(self.proposed_tree) + else: + #
Maybe update other locals + if update_globals: + # print("Inference") + # print(self.tree.elbo) + self.proposed_tree.learn_params(n_epochs, update_globals=True, memoized=memoized, **learn_kwargs) + self.proposed_tree.compute_elbo(memoized=memoized) + # print(self.proposed_tree.elbo) + if self.proposed_tree.elbo > self.tree.elbo: + self.tree = deepcopy(self.proposed_tree) + else: + self.proposed_tree = deepcopy(self.tree) + else: + self.proposed_tree = deepcopy(self.tree) + + + def prune_reattach(self, moves_per_tssb=1): + """ + Prune subtree and reattach somewhere else within the same TSSB + """ + n_prs = self.proposed_tree.n_nodes * moves_per_tssb + for _ in range(n_prs): + key, subkey = jax.random.split(key) + u = jax.random.uniform(subkey) + parent, source, target = self.proposed_tree.get_nodes(u, n_nodes=2) # find two nodes in the same TSSB + tssb = parent['node'].tssb + tssb.prune_reattach(parent, source, target) + # Update node stick parameters to account for changed mass distribution + tssb.update_stick_params(parent) + tssb.update_stick_params(target) + + # Optimize kernel parameters of root of moved subtree + tssb.update_node_params(source) + + # Compute ELBO of new tree + self.proposed_tree.compute_elbo() + # Update if ELBO improved + if self.proposed_tree.elbo > self.tree.elbo: + self.tree = self.proposed_tree + def plot_traces( self, @@ -112,2251 +317,3 @@ def plot_traces( color="black", ) plt.show() - - def run_search( - self, - n_iters=500, - n_iters_elbo=20, - n_iters_elbo_init=20, - factor_delay=0, - posterior_delay=0, - global_delay=0, - joint_init=True, - thin=10, - local=False, - mc_samples=1, - lr=0.001, - verbosity=logging.INFO, - tol=1e-6, - mb_size=256, - max_nodes=5, - debug=False, - callback=None, - alpha=0.0, - Tmax=10, - anneal=False, - restart_step=10, - move_weights=None, - weighted=True, - merge_n_tries=5, - opt=adam, - search_callback=None, - add_rule="accept", - add_rule_thres=1.0, - seed=1, - rescore_best=False, - **callback_kwargs, - ): - - logger.setLevel(verbosity) - - np.random.seed(seed) - - elbos = [] - gamma = 1.0 - - if move_weights is None: - move_weights = MOVE_WEIGHTS - - self.tree.max_nodes = ( - len(self.tree.input_tree_dict.keys()) * max_nodes - ) # upper bound on number of nodes - - mb_size = min(mb_size, self.tree.data.shape[0]) - - score_type = "elbo" - # if posterior_delay > 0: - # score_type = 'll' - - main_lr = lr - T = Tmax - if not anneal: - T = 1.0 - - if not local and global_delay > 0: - local = True - - n_factors = self.tree.root["node"].root["node"].num_global_noise_factors - - transfer_factor_weight = 0.0 - if "transfer_factor" in move_weights: - if n_factors == 0: - move_weights["transfer_factor"] = 0.0 - move_weights["transfer_unobserved"] = 0.0 - transfer_factor_weight = move_weights["transfer_factor"] - - init_baseline = np.mean(self.tree.data, axis=0) - init_baseline = init_baseline / np.median( - self.tree.input_tree.adata.X / 2, axis=0 - ) - init_baseline = init_baseline / np.std(init_baseline) - init_baseline = init_baseline / init_baseline[0] - init_log_baseline = np.log(init_baseline[1:] + 1e-6) - init_log_baseline = np.clip(init_log_baseline, -2, 2) - - if len(self.traces["score"]) == 0: - if n_factors > 0 and factor_delay > 0: - self.tree.root["node"].root["node"].num_global_noise_factors = 0 - if "transfer_factor" in move_weights: - move_weights["transfer_factor"] = 0.0 - move_weights["transfer_unobserved"] = 0.0 - - # Compute score of initial tree -- should we really optimize the baseline to the max before doing it 
with the unobs factors? - self.tree.reset_variational_parameters() - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "log_baseline_mean" - ] = init_log_baseline - self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - restricted=False, - update_all=True, - sticks_only=True, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo_init, - n_local_traverses=5, - lr=lr, - mb_size=mb_size, - ) - - # full update -- maybe without globals? - root_node_init = self.tree.root["node"].root["node"] - update_all = False - if joint_init: - update_all = True - root_node_init = None - self.tree.root["node"].root["node"].reset_variational_parameters( - means=False - ) - - self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - restricted=False, - update_all=update_all, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo, - n_local_traverses=5, - lr=lr, - mb_size=mb_size, - ) - - self.tree.plot_tree(super_only=False) - self.tree.update_ass_logits(variational=True) - self.tree.assign_to_best() - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - else: - self.tree.root = deepcopy(self.best_tree.root) - self.best_elbo = self.best_tree.elbo - self.tree.elbo = self.best_tree.elbo - gamma = self.traces["gamma"][-1] - - moves = list(move_weights.keys()) - moves_weights = list(move_weights.values()) - - logger.debug( - f"Will search for the maximum marginal likelihood tree with the following moves: {moves}\n" - ) - - init_score = self.tree.elbo if score_type == "elbo" else self.tree.ll - init_root = deepcopy(self.tree.root) - move_id = "full" - n_merge = 0 - # Search tree - p = np.array(moves_weights) - p = p / np.sum(p) - for i in tqdm(range(n_iters)): - - try: - if np.mod(i, 5) == 0: - nodes, mixture = self.tree.get_node_mixture() - n_nodes = len(nodes) - if ( - np.sum(mixture < 0.1 * 1.0 / n_nodes) > np.ceil(n_nodes / 3) - or n_nodes > self.tree.max_nodes * add_rule_thres - ): - # Reduce probability of adding, and only add if score improves - # p = np.array(move_weights) - # p[np.where(np.array(moves)=='add')[0][0]] = 0.25 * 1/len(moves) - add_rule = "improve" - else: - # Keep uniform move probability and always accept adds - p = np.array(moves_weights) - add_rule = "accept" - except: - pass - - p = p / np.sum(p) - # if move_id == 'add': - # move_id = 'merge' # always try to merge after adding # not -- must give a chance for the noise factors to be updated too - # else: - move_id = np.random.choice(moves, p=p) - - if ( - i == factor_delay - and n_factors > 0 - and self.tree.root["node"].root["node"].num_global_noise_factors == 0 - ): - self.tree = deepcopy(self.best_tree) - self.tree.root["node"].root["node"].num_global_noise_factors = n_factors - self.tree.root["node"].root["node"].init_noise_factors() - move_weights["transfer_factor"] = transfer_factor_weight - move_weights["transfer_unobserved"] = transfer_factor_weight - moves = list(move_weights.keys()) - moves_weights = list(move_weights.values()) - move_id = "full" - - if lr < main_lr: - move_id = "reset_globals" - - # nits_check = int(np.max([50, .1*n_iters])) - # if i > 50 and score_type == 'elbo' and np.sum(np.array(self.traces['accepted'][-nits_check:]) == False) == nits_check: - # logger.debug(f"No moves accepted in {nits_check} iterations. 
Using ll for 10 iterations.") - # posterior_delay = i + 10 - # - # if i < posterior_delay: - # score_type = 'll' - # elif i == posterior_delay: - # # Go back to best in terms of ELBO - # # self.tree.root = deepcopy(self.best_tree.root) - # # self.tree.elbo = self.best_elbo - # score_type = 'elbo' - - if global_delay > 0 and i > global_delay: - local = False - - init_root = deepcopy(self.tree.root) - init_elbo = self.tree.elbo - init_score = self.tree.elbo if score_type == "elbo" else self.tree.ll - success = True - - # nodes = self.tree.get_nodes() - # nodes = self.tree._get_nodes(get_roots=True) - nodes = self.tree.get_node_roots() - self.tree.n_nodes = len(nodes) - start = time() - if move_id == "add" and self.tree.n_nodes < self.tree.max_nodes - 1: - success, elbos = self.add_node( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - mb_size=mb_size, - max_nodes=max_nodes, - debug=debug, - opt=opt, - weighted=weighted, - callback=callback, - **callback_kwargs, - ) - elif move_id == "merge": - success, elbos = self.merge_nodes( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - mb_size=mb_size, - max_nodes=max_nodes, - debug=debug, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "prune_reattach": - success, elbos = self.prune_reattach( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - mb_size=mb_size, - max_nodes=max_nodes, - debug=debug, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "pivot_reattach": - success, elbos = self.pivot_reattach( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - mb_size=mb_size, - max_nodes=max_nodes, - debug=debug, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif ( - move_id == "add_reattach_pivot" - and self.tree.n_nodes < self.tree.max_nodes - 1 - ): - success, elbos = self.add_reattach_pivot( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - mb_size=mb_size, - max_nodes=max_nodes, - debug=debug, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "swap": - success, elbos = self.swap_nodes( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "subtree_reattach": - success, elbos = self.subtree_reattach( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "subtree_pivot_reattach": - success, elbos = self.subtree_pivot_reattach( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif ( - move_id == "push_subtree" - and self.tree.n_nodes < self.tree.max_nodes - 1 - ): - success, elbos = self.push_subtree( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif ( - move_id == "extract_pivot" - and self.tree.n_nodes < self.tree.max_nodes - 1 - ): - success, elbos = 
self.extract_pivot( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "optimize_node": - # Randomly choose a node - node_root = np.random.choice(nodes[1:]) - node = node_root["node"] - logger.debug(f"Optimizing {node.label}...") - elbos = self.tree.optimize_elbo( - root=None, - sub_root=node_root, - restricted=True, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - elif move_id == "perturb_node": - # Randomly choose a node - node_root = np.random.choice(nodes[1:]) - node = node_root["node"] - logger.debug(f"Perturbing {node.label}...") - # Perturb a bit and avoid clash between kernel and effect - perturbation = np.random.normal(0, 0.5, size=init_log_baseline.size + 1) - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] += np.abs(perturbation) - node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] += perturbation - # self.tree.root['node'].root['node'].variational_parameters['globals']['log_baseline_log_std'] += 2. # increase std - elbos = self.tree.optimize_elbo( - root=None, - sub_root=node_root, - restricted=False, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - # init_root, init_elbo, success, elbos = self.perturb_node(local=local, mc_samples=mc_samples, n_iters=n_iters_elbo, thin=thin, lr=lr, tol=tol, debug=debug, mb_size=mb_size, max_nodes=max_nodes, opt=opt, callback=callback, **callback_kwargs) - elif move_id == "clean_node": - success, elbos = self.clean_node( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "transfer_factor": - success, elbos = self.transfer_factor( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - - elif move_id == "transfer_unobserved": - success, elbos = self.transfer_unobserved( - local=local, - mc_samples=mc_samples, - n_iters=n_iters_elbo, - thin=thin, - lr=lr, - tol=tol, - debug=debug, - mb_size=mb_size, - max_nodes=max_nodes, - opt=opt, - callback=callback, - **callback_kwargs, - ) - elif move_id == "globals": - elbos = self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - restricted=False, - globals_only=True, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - elif move_id == "perturb_globals": - logger.debug("Perturbing globals...") - # Perturb a bit and optimize all - perturbation = np.random.normal(0, 0.5, size=init_log_baseline.size) - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "log_baseline_mean" - ] += perturbation - perturbation = np.random.normal( - 0, - 0.1, - size=( - self.tree.root["node"].root["node"].num_global_noise_factors, - init_log_baseline.size + 1, - ), - ) - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "noise_factors_mean" - ] += perturbation - self.tree.root["node"].root["node"].variational_parameters["locals"][ - "unobserved_factors_mean" - ] *= 0 - elbos = self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - restricted=False, - update_all=True, - mc_samples=mc_samples, - 
n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - elif move_id == "full": - logger.debug(f"Full update...") - elbos = self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - restricted=False, - update_all=True, - mc_samples=mc_samples, - n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - self.traces["times"].append(time() - start) - - if np.isnan(self.tree.elbo): - logger.debug("Got NaN!") - self.tree.root = deepcopy(init_root) - self.tree.elbo = init_elbo - logger.debug( - "Proceeding with previous tree, reducing step size and doing `reset_globals`." - ) - lr = lr * 0.1 - if lr < 1e-4: - logger.debug( - "Step size is becoming small. Fetching best tree with noise factors." - ) - self.tree.root = deepcopy(self.best_tree.root) - self.tree.elbo = self.best_elbo - if ( - self.tree.root["node"].root["node"].num_global_noise_factors - == 0 - and n_factors > 0 - and i > factor_delay - ): - self.tree.root["node"].root[ - "node" - ].num_global_noise_factors = n_factors - self.tree.root["node"].root["node"].init_noise_factors() - if lr < 1e-6: - raise ValueError("Step size became too small due to too many NaNs!") - # self.init_optimizer(lr=lr) - continue - else: - if lr != main_lr: - lr = main_lr - # self.init_optimizer(lr=lr) - - # if anneal: - # if i/thin >= 0 and np.mod(i, thin) == 0: - # idx = int(i/thin) - # if Tmax != 1: - # T = Tmax - alpha*idx - # T = T * (1 + (self.tree.elbo - self.best_elbo)/self.tree.elbo) - - new_score = self.tree.elbo if score_type == "elbo" else self.tree.ll - - accepted = True - - if move_id == "full" or success == False: - accepted = False - - logger.debug(f"ELBO change: {init_elbo} -> {self.tree.elbo}") - if self.tree.n_nodes >= self.tree.max_nodes: - self.tree.root = deepcopy(init_root) - self.tree.elbo = init_elbo - accepted = False - elif move_id == "add" or move_id == "add_reattach_pivot": - if ( - add_rule == "accept" and score_type == "elbo" - ): # only accept immediatly if using ELBO to score - logger.debug( - f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*" - ) - if self.tree.elbo > self.best_elbo: - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - logger.debug(f"New best! {self.best_elbo}") - else: - rate = -(init_score - new_score) / T - rate = rate * gamma - if rate < np.log(np.random.rand()): - self.tree.root = deepcopy(init_root) - self.tree.elbo = init_elbo - accepted = False - gamma = gamma * np.exp((0.0 - alpha) * alpha) - else: - logger.debug( - f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*" - ) - gamma = gamma * np.exp((1.0 - alpha) * alpha) - if self.tree.elbo > self.best_elbo: - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - logger.debug(f"New best! {self.best_elbo}") - elif move_id != "full" and success == True: - rate = -(init_score - new_score) / T - rate = rate * gamma - rate_thres = np.log(np.random.rand()) - logger.debug(f"{rate}, {rate_thres}") - if rate <= rate_thres: - self.tree.root = deepcopy(init_root) - self.tree.elbo = init_elbo - accepted = False - gamma = gamma * np.exp((0.0 - alpha) * alpha) - else: # Accepted - logger.debug( - f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*" - ) - gamma = gamma * np.exp((1.0 - alpha) * alpha) - if self.tree.elbo > self.best_elbo: - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - logger.debug(f"New best! 
{self.best_elbo}") - - if self.tree.elbo > self.best_elbo: - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - logger.debug(f"New best! {self.best_elbo}") - - if i == factor_delay and factor_delay > 0 and n_factors > 0: - logger.debug( - "Setting current tree with complete number of factors as the best." - ) - self.best_elbo = self.tree.elbo - self.best_tree = deepcopy(self.tree) - logger.debug(f"New best! {self.best_elbo}") - - gamma = np.max([gamma, 1e-5]) - score = self.tree.elbo if score_type == "elbo" else self.tree.ll - if rescore_best: - self.traces["tree"].append(deepcopy(self.tree)) - else: - self.traces["tree"].append(self.tree.plot_tree(counts=True)) - self.traces["elbo"].append(self.tree.elbo) - self.traces["score"].append(score) - self.traces["move"].append(move_id) - self.traces["n_nodes"].append(self.tree.n_nodes) - self.traces["temperature"].append(T) - self.traces["accepted"].append(accepted) - - self.traces["gamma"].append(gamma) - self.traces["elbos"].append(elbos) - - if search_callback is not None: - search_callback(self) - - if anneal: - if i / restart_step > 0 and np.mod(i, restart_step) == 0: - self.tree.root = deepcopy(self.best_tree.root) - self.tree.elbo = self.best_elbo - - if T == 0: - break - - if rescore_best: - logger.debug("Re-scoring top 10% trees") - logger.debug(f"Current best one is tree {np.argmax(self.traces['elbo'])}") - # Get top 10 unique-scoring trees - elbos, indices = np.unique(self.traces["elbo"], return_index=True) - top_scoring = indices[np.where(elbos > np.quantile(elbos, q=0.90))[0]] - logger.debug( - f"Re-scoring top 10% trees: got {len(top_scoring)} trees to score" - ) - # Re-score them - new_elbos = [] - for tree_idx in top_scoring: - elbos = self.traces["tree"][tree_idx].optimize_elbo( - local_node=None, - root_node=None, - mc_samples=5, - n_iters=10000, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=True, - debug=False, - opt=opt, - opt_triplet=None, - callback=callback, - **callback_kwargs, - ) - logger.debug( - f"Tree {tree_idx}: {self.traces['elbo'][tree_idx]} -> {self.traces['tree'][tree_idx].elbo}" - ) - new_elbos.append(self.traces["tree"][tree_idx].elbo) - # Get new best - new_best_idx = top_scoring[np.argmax(new_elbos)] - logger.debug(f"New best is tree {new_best_idx}") - new_best = self.traces["tree"][new_best_idx] - self.best_tree = deepcopy(new_best) - self.best_elbo = new_best.elbo - - self.tree.plot_tree(super_only=False) - self.best_tree.plot_tree(super_only=False) - self.best_tree.update_ass_logits(variational=True) - self.best_tree.assign_to_best() - return self.best_tree - - def add_node( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - weighted=True, - callback=None, - **callback_kwargs, - ): - success = True - elbos = [] - - nodes, target_probs = self.tree.get_node_data_sizes(normalized=True) - if not weighted: - target_probs = np.array([1.0] * len(nodes)) - target_probs /= np.sum(target_probs) - node_idx = np.random.choice(range(len(nodes)), p=np.array(target_probs)) - node = nodes[node_idx] - - logger.debug(f"Trying to add node below {node.label}") - - # Decide wether to initialize from a factor - from_factor = False - factor_idx = None - if self.tree.root["node"].root["node"].num_global_noise_factors > 0: - if len(nodes) < len(self.tree.input_tree_dict.keys()) * 2: - p = 0.8 - else: - p = 0.5 - from_factor = bool(np.random.binomial(1, p)) - - if from_factor: - 
logger.debug(f"Initializing new node from noise factor") - # Choose factor that the data in the node like - cells_in_node = np.where(np.array(self.tree.assignments) == node) - factor_idx = np.argmax( - np.mean( - np.abs( - self.tree.root["node"] - .root["node"] - .variational_parameters["globals"]["cell_noise_mean"][ - cells_in_node - ] - ), - axis=0, - ) - ) - - # new_node = self.tree.add_node_to(node, optimal_init=True, factor_idx=factor_idx) - parent_root = self.tree.add_node_to(node, optimal_init=True, factor_idx=factor_idx) - new_node = parent_root["children"][-1]["node"] - - root = self.tree.root - sub_root = None - restricted = False - go_down = False - if local: - root = None - sub_root = parent_root - restricted = True - - children = np.array(list(node.children())) - idx = np.argsort([n.label for n in children]) - children = children[idx] - unobs_children = [ - child for child in children if not child.is_observed and child != new_node - ] - if len(unobs_children) > 0: - pr = np.random.binomial(1, 0.5) - if pr: - child = np.random.choice(unobs_children) - logger.debug(f"Also attaching {child.label} to new node") - self.tree.prune_reattach(child, new_node) - go_down = True - - self.tree.plot_tree() - # Ensure constant node order - if from_factor: - # Remove factor - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "noise_factors_mean" - ][factor_idx] *= 0.0 - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "cell_noise_mean" - ][:, factor_idx] = 0.0 - elbos = self.tree.optimize_elbo( - root=root, - sub_root=None, - update_all=True, - mc_samples=mc_samples, - n_inner_steps=n_iters,# * 5, - lr=lr, - mb_size=mb_size, - ) - else: - if node.parent() is None: # if root, need to update also global factors - elbos = self.tree.optimize_elbo( - root=self.tree.root, - sub_root=None, - update_all=True, - mc_samples=mc_samples, - n_inner_steps=n_iters,# * 5, - lr=lr, - mb_size=mb_size, - ) - else: - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - go_down=go_down, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def merge_nodes( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # merge_paris = self.tree.get_merge_pairs() - # for pair in merge_pairs: - success = False - elbos = [] - # self.tree.merge_nodes(pair[0], pair[1]) - # self.tree.optimize_elbo(unique_node=None, root_node=pair[1], run=True, mc_samples=mc_samples, n_iters=n_iters, thin=thin, tol=tol, lr=lr, mb_size=mb_size, max_nodes=max_nodes, init=False, debug=debug, opt=opt, opt_triplet=self.opt_triplet, callback=callback, **callback_kwargs) - - # Choose a subtree - # _, subtrees = self.tree.get_mixture() - subtrees = self.tree.get_tree_roots() - - n_nodes = [] - nodes = [] - for subtree in subtrees: - nodes.append(subtree["node"].get_node_roots()) - n_nodes.append(len(nodes[-1])) - n_nodes = np.array(n_nodes) - # Only proceed if there is at least one subtree with mergeable nodes - if np.any(n_nodes > 1): - success = True - probs = n_nodes - 1 - probs = probs / np.sum(probs) - - # Choose a subtree with more than 1 node - idx = np.random.choice(range(len(subtrees)), p=probs) - subtree = subtrees[idx] - nodes = nodes[idx] - - # Choose a first node A (which can't be the root), biased by size - inv_sizes = np.array([1/(len(n["node"].data) + 1.) 
for n in nodes[1:]]) - node_idx = np.random.choice( - range(len(nodes[1:])), p=inv_sizes/np.sum(inv_sizes) - ) - nodeA_root = nodes[1:][node_idx] - nodeA = nodeA_root["node"] - - # Get parent and siblings in the same subtree - parent = nodeA.parent() - nodes = np.array(list(parent.children())) - idx = np.argsort([n.label for n in nodes]) - nodes = nodes[idx] - nodes = [s for s in nodes if s != nodeA and nodeA.tssb == s.tssb] - nodes.append(parent) - - # If nodeA is pivot node, it's also possible to merge the child with it - n_pivots = 0 - nodeA_children = np.array(list(nodeA.children())) - idx = np.argsort([n.label for n in nodeA_children]) - nodeA_children = nodeA_children[idx] - for nodeA_child in nodeA_children: - if nodeA_child.tssb != nodeA.tssb: - n_pivots += 1 - nodes.append(nodeA_child) - - sims = [ - 1.0 / (np.mean(np.abs(nodeA.node_mean - node.node_mean)) + 1e-8) - for node in nodes - ] - - # Choose nodeB proportionally to similarities - nodeB = np.random.choice(nodes, p=sims / np.sum(sims)) - nodeB_root = nodeB.get_tssb_root() - - # Choose initialization - optimal_params = True - if nodeB.parent() is None: - optimal_params = bool(np.random.choice(2, p=[0.4, 0.6])) - - local_node = None - restricted = False - root = self.tree.root - sub_root = None - update_all = True - # If a pivot was chosen, choose merge root of subtree with it - if nodeB.tssb != nodeA.tssb: - logger.debug(f"Trying to merge {nodeB.label} to {nodeA.label}...") - self.tree.merge_nodes(nodeB, nodeA, optimal_params=optimal_params) - if local: - local_node = nodeB - root = None - restricted = True - sub_root = nodeB_root - update_all = False - else: - logger.debug(f"Trying to merge {nodeA.label} to {nodeB.label}...") - self.tree.merge_nodes(nodeA, nodeB, optimal_params=optimal_params) - if local: - local_node = nodeB.parent() - root = None - restricted = True - sub_root = nodeB_root - update_all = False - - # Account for merging to baseline - update_all_n_iters = 40 - if nodeB.parent() is None: - if optimal_params: - logger.debug( - f"Global update since the baseline has been changed..." 
- ) - local_node = None - root = self.tree.root - sub_root = None - restricted = False - update_all = True - # n_iters = np.max([n_iters, 50]) - - self.tree.plot_tree() - # Ensure constant node order - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - update_all=update_all, - restricted=restricted, - run=True, - n_iters=update_all_n_iters, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def prune_reattach( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = False - elbos = [] - - # Choose a subtree - # _, subtrees = self.tree.get_mixture() - subtrees = self.tree.get_tree_roots() - - n_nodes = [] - nodes = [] - for subtree in subtrees: - nodes.append(subtree["node"].get_mixture()[1]) - n_nodes.append(len(nodes[-1])) - n_nodes = np.array(n_nodes) - # Only proceed if there is at least one subtree with pruneable nodes -- i.e., it needs to have at least 2 extra nodes - if np.any(n_nodes > 2): - success = True - probs = np.array(n_nodes) - probs[probs <= 2] = 0.0 - probs = probs / np.sum(probs) - - # Choose a subtree with more than 1 node - idx = np.random.choice(range(len(subtrees)), p=probs) - subtree = subtrees[idx] - nodes = nodes[idx] - - # Get the nodes which can be reattached, use labels - self.tree.plot_tree() - possible_nodes = dict() - node_label_dict = dict() - for node in nodes: - possible_nodes[node.label] = [ - n.label - for n in nodes - if node.label not in n.label and n != node.parent() - ] - node_label_dict[node.label] = node - - # Choose a first node - possible_nodesA = [ - node for node in possible_nodes if len(possible_nodes[node]) > 0 - ] - nodeA_label = np.random.choice( - possible_nodesA, p=[1.0 / len(possible_nodesA)] * len(possible_nodesA) - ) - nodeA = node_label_dict[nodeA_label] - - # Get nodes not below node A: use labels - sims = [ - 1.0 - / ( - np.mean( - np.abs(nodeA.node_mean - node_label_dict[node_label].node_mean) - ) - + 1e-8 - ) - for node_label in possible_nodes[nodeA_label] - ] - - # Choose nodeB proportionally to similarities - nodeB_label = np.random.choice( - possible_nodes[nodeA_label], p=sims / np.sum(sims) - ) - nodeB = node_label_dict[nodeB_label] - - logger.debug(f"Trying to reattach {nodeA_label} to {nodeB_label}...") - - self.tree.prune_reattach(nodeA, nodeB) - self.tree.plot_tree() - # Ensure constant node order - local_node = None - root = self.tree.root - sub_root = None - restricted = False - if local: - root = None - local_node = nodeA - sub_root = subtree["node"].root - restricted = True - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def pivot_reattach( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = False - elbos = [] - - # Uniformly pick a subtree - subtrees = self.tree.get_mixture()[1][1:] # without the root - subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees)) - init_pivot_node = subtree.root["node"].parent() - init_pivot = init_pivot_node.label - init_pivot_node_parent = init_pivot_node.parent() - - # Choose a pivot node from the parent subtree that isn't the current one - 
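# Illustrative sketch (not part of this patch): the merge, prune-reattach and
# pivot-reattach moves around this point all pick a partner node with probability
# proportional to an inverse-distance similarity between node profiles (node means, or
# unobserved factors combined with stick weights, depending on the move). A minimal
# version of that weighting, assuming each node exposes a `node_mean` array as in the
# surrounding code; `rng` is a numpy Generator.
import numpy as np

def choose_similar_node(node, candidates, rng):
    sims = np.array([1.0 / (np.mean(np.abs(node.node_mean - c.node_mean)) + 1e-8)
                     for c in candidates])
    idx = rng.choice(len(candidates), p=sims / sims.sum())
    return candidates[idx]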
weights, nodes = init_pivot_node.tssb.get_fixed_weights() - # Only proceed if parent subtree has more than 1 node - if len(nodes) > 1: - success = True - # weights = [weight for i, weight in enumerate(weights) if nodes[i] != init_pivot_node] - # weights = np.array(weights) / np.sum(weights) - # Also use the similarity of the parent subtree's nodes' unobserved factors with the subtree root - sims = [ - 1.0 - / ( - np.mean( - np.abs( - subtree.root["node"].variational_parameters["locals"][ - "unobserved_factors_mean" - ] - - node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - ) - ) - + 1e-8 - ) - for node in nodes - ] - log_weights = [ - np.log(weights[i]) + np.log(sims[i]) - for i, node in enumerate(nodes) - if node != init_pivot_node - ] - weights = np.exp(np.array(log_weights)) - weights = weights / np.sum(weights) - nodes = [ - node for node in nodes if node != init_pivot_node - ] # remove the current pivot - node_idx = np.random.choice(range(len(nodes)), p=weights) - node = nodes[node_idx] - - # Update pivot - logger.debug(f"Trying to set {node.label} as pivot of {subtree.label}") - - self.tree.pivot_reattach_to(subtree, node) - - removed_pivot = False - if ( - len(init_pivot_node.data) == 0 - and init_pivot_node_parent is not None - and len(init_pivot_node.children()) == 0 - ): - logger.debug( - f"Also removing initial pivot ({init_pivot_node.label}) from tree" - ) - self.tree.merge_nodes(init_pivot_node, init_pivot_node_parent) - # self.tree.optimize_elbo(sticks_only=True, root_node=init_pivot_node_parent, mc_samples=mc_samples, n_iters=n_iters, thin=thin, tol=tol, lr=lr, mb_size=mb_size, max_nodes=max_nodes, init=False, debug=debug, opt=opt, opt_triplet=self.opt_triplet, callback=callback, **callback_kwargs) - removed_pivot = True - - self.tree.plot_tree() - # Ensure constant node order - root = self.tree.root - sub_root = None - restricted = False - if local: - root = None - sub_root = subtree.root - restricted = True - # if root_node.parent() is not None: - # root_node = root_node.parent() # more robust - # if removed_pivot: - # root_node = init_pivot_node_parent - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def add_reattach_pivot( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = True - elbos = [] - - # Add a node below a subtree with children subtrees - subtrees = self.tree.get_subtrees(get_roots=True) - nonleaf_subtrees = [ - subtree for subtree in subtrees if len(subtree[1]["children"]) > 0 - ] - - # Don't waste too much time with subtrees which are not expected to have any complexity - subtree_weights = [subtree[0].weight for subtree in nonleaf_subtrees] - - # If every subtree has a parent with no weight, don't proceed - if np.sum(subtree_weights) < 1e-10: - return False, elbos - - subtree_weights = np.array(subtree_weights) / np.sum(subtree_weights) - - # Pick a subtree - parent_subtree = nonleaf_subtrees[ - np.random.choice( - len(nonleaf_subtrees), - p=subtree_weights, - ) - ] - - # Pick a node in the parent subtree - nodes, target_probs = self.tree.get_node_data_sizes(normalized=True) - target_probs = [ - prob + 1e-8 - for i, prob in enumerate(target_probs) - if nodes[i].tssb == parent_subtree[0] - ] - nodes = [node for node in nodes if node.tssb 
== parent_subtree[0]] - target_probs /= np.sum(target_probs) - node = np.random.choice(nodes, p=np.array(target_probs)) - pivot_node_parent_root = self.tree.add_node_to(node, optimal_init=True, return_parent_root=True) - - # Pick one of the children subtrees - subtrees = [subtree for subtree in parent_subtree[1]["children"]] - subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees)) - init_pivot = subtree["node"].root["node"].parent() - - # Update pivot - self.tree.pivot_reattach_to(subtree["node"], pivot_node_parent_root["children"][-1]["node"]) - - logger.debug( - f"Trying to add node {pivot_node_parent_root['children'][-1]['node'].label} and setting it as pivot of {subtree['node'].label}" - ) - - self.tree.plot_tree() - # Ensure constant node order - root_node = pivot_node_parent_root - n_iters_elbo = n_iters - root = self.tree.root - sub_root = None - restricted = False - if pivot_node_parent_root["children"][-1]["node"].parent().parent() is None: - n_iters_elbo = n_iters #* 5 - if local: - root_node = pivot_node_parent_root - root = None - sub_root = pivot_node_parent_root - restricted = True - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - mc_samples=mc_samples, - go_down=True, - n_inner_steps=n_iters_elbo, - lr=lr, - mb_size=mb_size, - ) - - - return success, elbos - - def push_subtree( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = True - elbos = [] - - # Uniformly pick a subtree - subtrees = self.tree.get_tree_roots()[1:] # without the root - subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees)) - - # Push subtree down - new_pivot_root = self.tree.push_subtree(subtree["node"].root["node"]) - - logger.debug(f"Trying to push {subtree['node'].label} down") - - root_node = None - root = self.tree.root - sub_root = None - restricted = False - if local: - root = None - sub_root = new_pivot_root - restricted = True - go_down=True - # root_node = subtree.root["node"].parent() - self.tree.plot_tree() - # Ensure constant node order - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - go_down=go_down, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def extract_pivot( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = True - elbos = [] - - # Pick a subtree biased by the weight of the parent TSSB - subtrees = self.tree.get_mixture()[1][1:] # without the root - parent_subtree_weights = [ - subtree.root["node"].parent().tssb.weight for subtree in subtrees - ] - - # If every subtree has a parent with no weight, don't proceed - if np.sum(parent_subtree_weights) < 1e-10: - return False, elbos - - parent_subtree_weights = np.array(parent_subtree_weights) / np.sum( - parent_subtree_weights - ) - subtree = np.random.choice(subtrees, p=parent_subtree_weights) - - # Push subtree down - new_node = self.tree.extract_pivot(subtree.root["node"]) - - logger.debug(f"Trying to extract pivot from {subtree.label}") - - root_node = None - root = self.tree.root - sub_root = None - restricted = False - if local: - root_node = new_node - root = None - sub_root = new_node - restricted = True - self.tree.plot_tree() - # Ensure 
constant node order - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - mc_samples=mc_samples, - n_inner_steps=n_iters, - go_down=True, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def subtree_reattach( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - """ - Move a subtree to a different clone - """ - success = False - elbos = [] - - subtrees = self.tree.get_subtrees(get_roots=True) - - # Pick a subtree with more than 1 node - in_subtrees = [ - subtree[1] for subtree in subtrees if len(subtree[0].root["children"]) > 0 - ] - - # If there is any subtree with more than 1 node, proceed - if len(in_subtrees) > 0: - success = True - # Choose one subtree - subtreeA = np.random.choice( - in_subtrees, p=[1.0 / len(in_subtrees)] * len(in_subtrees) - ) - - # Choose one of its nodes uniformly which is not the root - node_weights, nodes, roots = subtreeA["node"].get_mixture(get_roots=True) - nodeA_idx = ( - np.random.choice( - len(roots[1:]), p=[1.0 / len(roots[1:])] * len(roots[1:]) - ) - + 1 - ) - nodeA_parent_idx = np.where(np.array(nodes) == nodes[nodeA_idx].parent())[ - 0 - ][0] - - # Choose another subtree that's similar to the subtree's top node - rem_subtrees = [s[1] for s in subtrees if s[1]["node"] != subtreeA["node"]] - sims = [ - 1.0 - / ( - np.mean( - np.abs( - roots[nodeA_idx]["node"].node_mean - - s["node"].root["node"].node_mean - ) - ) - + 1e-8 - ) - for s in rem_subtrees - ] - new_subtree = np.random.choice(rem_subtrees, p=sims / np.sum(sims)) - - logger.debug( - f"Trying to set {roots[nodeA_idx]['node'].label} below {new_subtree['node'].label}" - ) - - # Move subtree - optimal_init = bool(np.random.binomial(1, 0.5)) - pivot_changed = self.tree.subtree_reattach_to( - roots[nodeA_idx]["node"], new_subtree["node"], optimal_init=optimal_init - ) - - # Also swap to make the moved subtree root be the new root? 
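# Illustrative sketch (not part of this patch): add_reattach_pivot and extract_pivot
# above sample a target subtree in proportion to (parent-)TSSB weight and abort the move
# when every candidate has essentially zero weight. A compact version of that guard:
import numpy as np

def pick_weighted_subtree(subtrees, weights, rng, eps=1e-10):
    weights = np.asarray(weights, dtype=float)
    if weights.sum() < eps:
        return None  # caller aborts the move (returns success=False)
    return subtrees[rng.choice(len(subtrees), p=weights / weights.sum())]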
- # if len(list(roots[nodeA_idx]["node"].children())) == 0: - # if np.random.binomial(1, 0.5): - # self.tree.swap_nodes( - # roots[nodeA_idx]["node"], new_subtree["node"].root["node"] - # ) - - # self.tree.reset_variational_parameters(variances_only=True) - # init_baseline = jnp.mean(self.tree.data, axis=0) - # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:] - # self.tree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=self.tree.data.shape[1]-1) - self.tree.plot_tree() - # Ensure constant node order - root_node = None - root = self.tree.root - sub_root = None - restricted = False - update_all = True - if local: - root_node = roots[nodeA_idx]["node"] - root = None - sub_root = roots[nodeA_idx] - restricted = True - update_all = False - if pivot_changed: - n_iters = n_iters #* 5 - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - update_all=update_all, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def subtree_pivot_reattach( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - """ - Reattach subtree of node A to clone B and use it as pivot of A - """ - success = False - elbos = [] - - subtrees = self.tree.get_subtrees(get_roots=True) - - # Pick a non-root subtree with more than 1 node - in_subtrees = [ - subtree[1] - for subtree in subtrees[1:] - if len(subtree[0].root["children"]) > 0 - ] - - # If there is any subtree with more than 1 node, proceed - if len(in_subtrees) > 0: - success = True - # Choose one subtree - subtreeA = np.random.choice( - in_subtrees, p=[1.0 / len(in_subtrees)] * len(in_subtrees) - ) - - # Choose one of its nodes uniformly which is not the root - node_weights, nodes, roots = subtreeA["node"].get_mixture(get_roots=True) - nodeA_idx = ( - np.random.choice( - len(roots[1:]), p=[1.0 / len(roots[1:])] * len(roots[1:]) - ) - + 1 - ) - nodeA_parent_idx = np.where(np.array(nodes) == nodes[nodeA_idx].parent())[ - 0 - ][0] - - logger.debug( - f"Trying to set {roots[nodeA_idx]['node'].label} below {subtreeA['super_parent'].label} and use it as pivot of {subtreeA['node'].label}" - ) - - # Move subtree to parent - self.tree.subtree_reattach_to( - roots[nodeA_idx]["node"], subtreeA["super_parent"].label - ) # Use label to avoid bugs with references - - # Set root of moved subtree as new pivot. 
TODO: Choose one node from the leaves instead - self.tree.pivot_reattach_to(subtreeA["node"], roots[nodeA_idx]["node"]) - - # And choose a leaf node of that subtree as the pivot of old subtree - # self.tree.reset_variational_parameters(variances_only=True) - # init_baseline = jnp.mean(self.tree.data, axis=0) - # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:] - # self.tree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=self.tree.data.shape[1]-1) - self.tree.plot_tree() - # Ensure constant node order - root_node = None - root = self.tree.root - sub_root = None - restricted = False - update_all = True - if local: - root_node = roots[nodeA_idx]["node"] - root = None - sub_root = roots[nodeA_idx] - restricted = True - update_all = False - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - update_all=update_all, - go_down=True, - mc_samples=mc_samples, - n_inner_steps=n_iters,# * 5, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def swap_nodes( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - success = True - elbos = [] - - # Randomly decide whether to update pivots - update_pivots = np.random.binomial(1, 0.5) - - def tssb_swap(tssb, children_trees, ntree, n_iters, update_pivots=True): - weights, nodes = tssb.get_mixture() - nodes = tssb.get_node_roots() - - empty_root = False - if len(nodes) > 1: - nodeA_root, nodeB_root = np.random.choice(nodes, replace=False, size=2) - nodeA = nodeA_root["node"] - nodeB = nodeB_root["node"] - # Root can't be empty in the original TSSB - if len(nodes[0]["node"].data) == 0 and nodes[0]["node"].tssb.weight > 1e-6 and nodes[0]["node"].parent() is not None: - logger.debug("Swapping root") - empty_root = True - nodeA_root = nodes[0] - nodeA = nodeA_root["node"] - nodeB_root = np.random.choice(nodes[1:]) - nodeB = nodeB_root["node"] - - logger.debug(f"Trying to swap {nodeA.label} with {nodeB.label}...") - self.tree.swap_nodes(nodeA, nodeB, update_pivots=update_pivots) - - if empty_root: - # self.tree = deepcopy(ntree) - logger.debug(f"Swapped {nodeA.label} with {nodeB.label}") - # Go through all nodes below root and reset their unobserved_factors_kernel_log_std - self.tree.plot_tree() - # Ensure constant node order - ntree.optimize_elbo( - root = self.tree.root, - mc_samples=mc_samples, - n_inner_steps=n_iters,# * 10, - lr=lr, - mb_size=mb_size, - ) - else: - # ntree = self.compute_expected_score(ntree, n_burnin=n_burnin, n_samples=n_samples, thin=thin, global_params=global_params, compound=compound) - # ntree.reset_variational_parameters(variances_only=True) - # init_baseline = jnp.mean(ntree.data, axis=0) - # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:] - # ntree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=ntree.data.shape[1]-1) - root_node = nodes[0] - root = None - sub_root = None - restricted = False - update_all = False - go_down = False - n_all_iters = 40 - if nodeB == nodeA.parent(): - root_node = nodeB_root - root = None - sub_root = nodeB_root - restricted = True - go_down = True - update_all = False - elif nodeA == nodeB.parent(): - root_node = nodeA_root - root = None - sub_root = nodeA_root - restricted = True - go_down = True - update_all = False - if not local: - root_node = None - root = self.tree.root - sub_root = 
None - restricted = False - update_all = True - # n_iters *= 10 # Big change, so give time to converge - if root_node: - if root_node["node"].parent() is None: - root_node = None # Update everything! - # n_iters *= 2 # Big change, so give time to converge - root = self.tree.root - sub_root = None - restricted = False - update_all = True - self.tree.plot_tree() - # Ensure constant node order - logger.debug(f"Using {root_node} as root within in-tssb swap...") - elbos = ntree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - update_all=update_all, - n_iters=n_all_iters, - go_down=go_down, - mc_samples=mc_samples, - n_inner_steps=n_iters, - lr=lr, - mb_size=mb_size, - ) - - def descend(root, subtree, ntree, done): - if not done: - if root["node"] == subtree: - tssb_swap( - subtree, - root["children"], - ntree, - n_iters, - update_pivots=update_pivots, - ) - return True - else: - for index, child in enumerate(root["children"]): - done = descend(child, subtree, ntree, done) - if done: - break - - # Randomly decide between within TSSB swap or unrestricted in ntssb - within_tssb = np.random.binomial(1, 0.5) - - if within_tssb: - # Uniformly pick a subtree with more than 1 node - subtrees = self.tree.get_mixture()[1] - subtrees = [ - subtree for subtree in subtrees if len(subtree.root["children"]) > 0 - ] - if len(subtrees) > 0: - subtree = np.random.choice( - subtrees, p=[1.0 / len(subtrees)] * len(subtrees) - ) - - descend(self.tree.root, subtree, self.tree, False) - else: - nodes = self.tree.get_node_roots() - nodes = nodes[1:] # without root - - # Randomly decide between parent-child and unrestricted - unrestricted = np.random.binomial(1, 0.5) - if unrestricted: - nodeA_root, nodeB_root = np.random.choice(nodes, replace=False, size=2) - nodeA = nodeA_root["node"] - nodeB = nodeB_root["node"] - else: - nodeA_root = np.random.choice(nodes) - nodeA = nodeA_root["node"] - nodeB = nodeA.parent() - nodeB_root = nodeB.get_tssb_root() - - if nodeB is not None: - logger.debug(f"Trying to swap {nodeA.label} with {nodeB.label}...") - self.tree.swap_nodes(nodeA, nodeB, update_pivots=update_pivots) - root_node = nodeB - root = None - sub_root=nodeB_root - restricted=False - update_all=False - go_down=True - if unrestricted: - if nodeA == nodeB.parent(): - root_node = nodeA - root=None - sub_root=nodeA_root - restricted=True - update_all=False - go_down=True - elif nodeB == nodeA.parent(): - root_node = nodeB - root=None - sub_root=nodeB_root - restricted=True - update_all=False - go_down=True - else: - mrca = self.tree.get_mrca(nodeA, nodeB) - mrca_root = mrca.get_tssb_root() - root=None - sub_root=mrca_root - restricted=True - update_all=False - go_down=True - for child in root_node.children(): - child.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.clip( - root_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ], - -10, - 10, - ) - # root_node = self.tree.root['node'].root['node'] - new_n_iters = n_iters - if root_node.parent() is None: - new_n_iters = n_iters - if not local: - root_node = None - root = self.tree.root - sub_root = None - restricted = False - update_all = True - go_down = False - self.tree.plot_tree() - # Ensure constant node order - elbos = self.tree.optimize_elbo( - root=root, - sub_root=sub_root, - restricted=restricted, - update_all=update_all, - go_down=go_down, - mc_samples=mc_samples, - n_inner_steps=new_n_iters, - lr=lr, - mb_size=mb_size, - ) - - return success, elbos - - def transfer_factor( 
- self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # Take events from a factor and put them in the node that contains - # cells that use that factor the most - - success = True - elbos = [] - - # Randomly choose a factor - factor_idx = np.random.randint( - self.tree.root["node"].root["node"].num_global_noise_factors - ) - - # Get genes in the factor - target_genes = np.argsort( - np.abs( - self.tree.root["node"] - .root["node"] - .variational_parameters["globals"]["noise_factors_mean"][factor_idx] - ) - )[-10:] - - # Get cells that give large weight to it - thres = np.quantile( - np.abs( - self.tree.root["node"] - .root["node"] - .variational_parameters["globals"]["cell_noise_mean"][:, factor_idx] - ), - 0.75, - ) - target_cells = np.where( - np.abs( - self.tree.root["node"] - .root["node"] - .variational_parameters["globals"]["cell_noise_mean"][:, factor_idx] - ) - > thres - )[0] - - # Get node that most of them attach to - target_node = max( - list(np.array(self.tree.assignments)[target_cells]), - key=list(np.array(self.tree.assignments)[target_cells]).count, - ) - - # Increase kernel on the genes that are affected by that factor - target_node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ][target_genes] = -1.0 - - # Remove these genes from the factor - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "noise_factors_mean" - ][factor_idx, target_genes] *= 0.0 - - # Remove weight of this factor from the target cells - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "cell_noise_mean" - ][target_cells, factor_idx] *= 0.0 - - logger.debug( - f"Trying to move factor {factor_idx} to node {target_node.label}..." 
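# Illustrative sketch (not part of this patch): transfer_factor above selects the noise
# factor's top-loading genes, the cells that weight that factor most strongly, and the
# node where most of those cells are currently assigned. Condensed version, assuming
# noise_factors_mean is (n_factors, n_genes), cell_noise_mean is (n_cells, n_factors) and
# assignments is a length-n_cells sequence of node objects, as in the surrounding code.
import numpy as np
from collections import Counter

def factor_transfer_targets(noise_factors_mean, cell_noise_mean, assignments, factor_idx,
                            n_top_genes=10, cell_quantile=0.75):
    target_genes = np.argsort(np.abs(noise_factors_mean[factor_idx]))[-n_top_genes:]
    loads = np.abs(cell_noise_mean[:, factor_idx])
    target_cells = np.where(loads > np.quantile(loads, cell_quantile))[0]
    node_counts = Counter(np.asarray(assignments, dtype=object)[target_cells])
    target_node = node_counts.most_common(1)[0][0]  # mode of the assigned nodes
    return target_genes, target_cells, target_node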
- ) - - # This move has to be global because we mess with a noise factor - root_node = None - elbos = self.tree.optimize_elbo( - root_node=root_node, - mc_samples=mc_samples, - n_iters=n_iters, #* 5, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=False, - debug=debug, - opt=opt, - opt_triplet=self.opt_triplet, - callback=callback, - **callback_kwargs, - ) - - return success, elbos - - def transfer_unobserved( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # Take events from a node and put them in the factor that cells below - # that node most use - - success = True - elbos = [] - - # Randomly choose a node, biased towards nodes with most events - nodes = self.tree.get_nodes() - n_events = [ - np.sum( - np.exp( - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - ) - for node in nodes[1:] - ] - node = np.random.choice(nodes[1:], p=n_events / np.sum(n_events)) - - # Get all descendants of node - nodes = node.get_descendants() - - # Get cells in those nodes - data = np.concatenate([list(n.data) for n in nodes]).astype(int) - if len(data) != 0: - - # Get a factor (biased towards the ones those cells like the most) - factor_counts = np.bincount( - np.argmax( - np.abs( - self.tree.root["node"] - .root["node"] - .variational_parameters["globals"]["cell_noise_mean"][data] - ), - axis=1, - ) - ) - target_factor = np.random.choice( - np.arange(len(factor_counts)), p=factor_counts / np.sum(factor_counts) - ) - else: - target_factor = np.random.choice( - np.arange(self.tree.root["node"].root["node"].num_global_noise_factors) - ) - - # Get events from node and put them in factor - node_event_locs = np.where( - node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] - > -1 - )[0] - node_event_vec = node.variational_parameters["locals"][ - "unobserved_factors_mean" - ] - self.tree.root["node"].root["node"].variational_parameters["globals"][ - "noise_factors_mean" - ][target_factor, node_event_locs] = node_event_vec[node_event_locs] - - # Remove those events from node and its descendants - node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"][ - node_event_locs - ] = -2.0 - for desc in nodes: - desc.variational_parameters["locals"]["unobserved_factors_mean"][ - node_event_locs - ] = 0.0 - - logger.debug( - f"Trying to move events in node {node.label} to factor {target_factor}..." 
- ) - - # This move has to be global because we mess with a noise factor - root_node = None - elbos = self.tree.optimize_elbo( - root_node=root_node, - mc_samples=mc_samples, - n_iters=n_iters,# * 5, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=False, - debug=debug, - opt=opt, - opt_triplet=self.opt_triplet, - callback=callback, - **callback_kwargs, - ) - - return success, elbos - - def clean_factors( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # If a factor encodes a CNV profile, remove it and move cells that used - # it to the actual clone with that CNV profile - pass - - def perturb_node( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # Move node towards the data in its neighborhood that is currently explained by nodes in its neighborhood - success = True - elbos = [] - - nodes = self.tree.get_nodes() - node = np.random.choice(nodes) - - # Decide wether to move closer to parent, sibling or child - parent = node.parent() - if parent is not None: - siblings = np.array([n for n in list(parent.children()) if n != node]) - idx = np.argsort([n.label for n in siblings]) - siblings = siblings[idx] - parent = np.array([parent]) - else: - parent = np.array([]) - siblings = np.array([]) - children = np.array(list(node.children())) - idx = np.argsort([n.label for n in children]) - children = children[idx] - if len(children) == 0: - children = np.array([]) - possibilities = np.concatenate([parent, siblings, children]) - probs = np.array( - [1 + node.num_local_data() for node in possibilities] - ) # the more data they have, the more likely it is that we decide to move towards them - probs = probs / np.sum(probs) - - target = np.random.choice(possibilities, p=probs) - - logger.debug(f"Trying to move {node.label} close to {target.label}...") - - self.tree.perturb_node(node, target) - root_node = node - if not local: - root_node = None - elbos = self.tree.optimize_elbo( - root_node=root_node, - mc_samples=mc_samples, - n_iters=n_iters, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=False, - debug=debug, - opt=opt, - opt_triplet=self.opt_triplet, - callback=callback, - **callback_kwargs, - ) - - return success, elbos - - def clean_node( - self, - local=False, - mc_samples=1, - n_iters=100, - thin=10, - tol=1e-7, - lr=0.05, - mb_size=100, - max_nodes=5, - debug=False, - opt=None, - callback=None, - **callback_kwargs, - ): - # Get node with bad kernel and clean it up - success = True - elbos = [] - - nodes = self.tree.get_nodes() - - n_genes = nodes[0].observed_parameters.shape[0] - - frac_events = np.array( - [ - np.sum( - np.exp( - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] - ) - > 0.2 - ) - / n_genes - for node in nodes - ] - ) - - probs = ( - 1e-6 + frac_events - ) # the more complex the kernel, the more likely it is to clean it - probs = probs / np.sum(probs) - - node = np.random.choice(nodes, p=probs) - - # Get the number of nodes with too many events - n_bad_nodes = np.sum(frac_events > 1 / 3) - frac_bad_nodes = n_bad_nodes / len(nodes) - if frac_bad_nodes > 1 / 3 or n_bad_nodes > 3: - # Reset all unobserved_factors - logger.debug(f"Trying to clean all nodes...") - for node in nodes: - 
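# Illustrative sketch (not part of this patch): clean_node above scores each node by the
# fraction of genes whose unobserved-factors kernel is effectively "on", samples the node
# to clean in proportion to that fraction, and resets every node when too many look
# overloaded. kernel_log_means is assumed to be an (n_nodes, n_genes) stack of each
# node's unobserved_factors_kernel_log_mean.
import numpy as np

def kernel_cleanup_plan(kernel_log_means, rng, on_thres=0.2, bad_frac=1/3, max_bad=3):
    frac_events = np.mean(np.exp(kernel_log_means) > on_thres, axis=1)
    probs = (1e-6 + frac_events) / np.sum(1e-6 + frac_events)
    target_idx = rng.choice(len(frac_events), p=probs)      # node most likely to be cleaned
    n_bad = np.sum(frac_events > bad_frac)
    clean_all = (n_bad / len(frac_events) > bad_frac) or (n_bad > max_bad)
    return target_idx, clean_all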
node.variational_parameters["locals"]["unobserved_factors_mean"] *= 0.0 - # node.variational_parameters["locals"][ - # "unobserved_factors_log_std" - # ] = -2 * np.ones((n_genes,)) - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.log( - node.unobserved_factors_kernel_concentration_caller() - ) * np.ones( - (n_genes,) - ) - # node.variational_parameters["locals"][ - # "unobserved_factors_kernel_log_std" - # ] = -2 * np.ones((n_genes,)) - - root_node = nodes[0] - if not local: - root_node = None - elbos = self.tree.optimize_elbo( - root_node=root_node, - mc_samples=mc_samples, - n_iters=n_iters,# * 5, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=False, - debug=debug, - opt=opt, - opt_triplet=self.opt_triplet, - callback=callback, - **callback_kwargs, - ) - else: - logger.debug(f"Trying to clean {node.label}...") - - node.variational_parameters["locals"]["unobserved_factors_mean"] *= 0.0 - # node.variational_parameters["locals"][ - # "unobserved_factors_log_std" - # ] = -2 * np.ones((n_genes,)) - node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" - ] = np.log(node.unobserved_factors_kernel_concentration_caller()) * np.ones( - (n_genes,) - ) - # node.variational_parameters["locals"][ - # "unobserved_factors_kernel_log_std" - # ] = -2 * np.ones((n_genes,)) - root_node = node - if not local: - root_node = None - elbos = self.tree.optimize_elbo( - root_node=root_node, - mc_samples=mc_samples, - n_iters=n_iters, - thin=thin, - tol=tol, - lr=lr, - mb_size=mb_size, - max_nodes=max_nodes, - init=False, - debug=debug, - opt=opt, - opt_triplet=self.opt_triplet, - callback=callback, - **callback_kwargs, - ) - - return success, elbos diff --git a/scatrex/ntssb/tssb.py b/scatrex/ntssb/tssb.py index f82978b..46faf40 100644 --- a/scatrex/ntssb/tssb.py +++ b/scatrex/ntssb/tssb.py @@ -15,7 +15,7 @@ import numpy as np import matplotlib -from ..util import * +from ..utils.math_utils import * import logging @@ -35,13 +35,16 @@ def __init__( root_node, label, ntssb=None, + parent=None, dp_alpha=1.0, dp_gamma=1.0, min_depth=0, max_depth=15, - alpha_decay=0.5, - eta=0.0, + alpha_decay=1., + eta=1.0, + children_root_nodes=[], # Root nodes of any children TSSBs color="black", + seed=42, ): if root_node is None: raise Exception("Root node must be specified.") @@ -52,7 +55,10 @@ def __init__( self.dp_gamma = dp_gamma # smaller dp_gamma => larger psi => less nodes self.alpha_decay = alpha_decay self.weight = 1.0 + self.children_root_nodes = children_root_nodes self.color = color + self.seed = seed + self._parent = parent self.eta = eta # wether to put more weight on top or bottom of tree in case we want fixed weights @@ -62,7 +68,7 @@ def __init__( self.root = { "node": root_node, - "main": boundbeta(1.0, dp_alpha) + "main": boundbeta(1.0, dp_alpha, np.random.default_rng(self.seed)) if self.min_depth == 0 else 0.0, # if min_depth > 0, no data can be added to the root (main stick is nu) "sticks": empty((0, 1)), # psi sticks @@ -82,8 +88,56 @@ def __init__( self.kl = -1e6 self._data = set() - def add_data(self, l): - self._data.update(l) + self.n_nodes = 1 + + self.variational_parameters = {'delta_1': 1., 'delta_2': 1., # nu stick + 'sigma_1': 1., 'sigma_2': 1., # psi stick + 'q_c': [], # prob of assigning each cell to this TSSB + 'LSE_z': [], # normalizing constant for prob of assigning each cell to nodes + 'LSE_rho': [], # normalizing constant for each child TSSB's probs of assigning each node to nodes in this TSSB 
(this is a list) + 'll': [], # auxiliary quantity + 'sum_E_log_1_nu': 0., # auxiliary quantity + 'E_log_phi': 0., # auxiliary quantity + } + + def parent(self): + return self._parent + + def get_param_dict(self): + """ + Go from a dictionary where each node is a TSSB to a dictionary where each node is a dictionary, + with `params` and `weight` keys + """ + param_dict = { + "param": self.root['node'].get_params(), + "mean": self.root['node'].get_mean(), + "weight": self.root['weight'], + "children": [], + "label": self.root['label'], + "color": self.root['color'], + "size": len(self.root['node'].data), + "pivot_probs": self.root['node'].variational_parameters['q_rho'], + } + def descend(root, root_new): + for child in root["children"]: + child_new = { + "param": child['node'].get_params(), + "mean": child['node'].get_mean(), + "weight": child['weight'], + "children": [], + "label": child['label'], + "color": child['color'], + "size": len(child['node'].data), + "pivot_probs": child['node'].variational_parameters['q_rho'], + } + root_new['children'].append(child_new) + descend(child, root_new['children'][-1]) + + descend(self.root, param_dict) + return param_dict + + def add_datum(self, id): + self._data.add(id) def remove_data(self): self._data.clear() @@ -106,6 +160,7 @@ def descend(root): self.root["node"].remove_data() descend(self.root) self.assignments = [] + self.remove_data() def reset_tree(self): # Clear tree @@ -122,27 +177,167 @@ def reset_tree(self): } def reset_node_parameters( - self, root_params=True, down_params=True, node_hyperparams=None + self, min_dist=0.7, **node_hyperparams ): + def get_distance(nodeA, nodeB): + return np.sqrt(np.sum((nodeA.get_mean() - nodeB.get_mean())**2)) + # Reset node parameters def descend(root): - root["node"].reset_parameters( - root_params=root_params, down_params=down_params, **node_hyperparams + for i, child in enumerate(root["children"]): + accepted = False + while not accepted: + child["node"].reset_parameters( + **node_hyperparams + ) + dist_to_parent = get_distance(root["node"], child["node"]) + # Reject sample if too close to any other child + dists = [] + for j, child2 in enumerate(root["children"]): + if j < i: + dists.append(get_distance(child["node"], child2["node"])) + if np.all(np.array(dists) >= min_dist*dist_to_parent): + accepted = True + else: + child["node"].seed += 1 + descend(child) + + self.root["node"].reset_parameters(**node_hyperparams) + descend(self.root) + + def set_node_hyperparams(self, **kwargs): + def descend(root): + root['node'].set_node_hyperparams(**kwargs) + for child in root['children']: + descend(child) + descend(self.root) + + def set_weights(self, node_weights_dict): + def descend(root): + root['weight'] = node_weights_dict[root['label']] + for child in root['children']: + descend(child) + descend(self.root) + + def set_sticks_from_weights(self): + def descend(root): + stick = self.input_tree.get_sum_weights_subtree(child) + if i < len(input_tree_dict[label]["children"]) - 1: + sum = 0 + for j, c in enumerate(input_tree_dict[label]["children"][i:]): + sum = sum + self.input_tree.get_sum_weights_subtree(c) + stick = stick / sum + else: + stick = 1.0 + + sticks = vstack( + [ + root['sticks'], + stick + ] ) - for child in root["children"]: + main = root["weight"] + subtree_weights_sum = self.input_tree.get_sum_weights_subtree(child) + main = main / subtree_weights_sum + root['main'] = main + root['sticks'] = sticks + for child in root['children']: descend(child) + def set_pivot_priors(self): + def 
descend(root, depth=0): + root["node"].pivot_prior_prob = self.eta ** depth + prior_norm = root["node"].pivot_prior_prob + for child in root['children']: + child_prior_norm = descend(child, depth=depth+1) + prior_norm += child_prior_norm + return prior_norm + + prior_norm = descend(self.root) + + def descend(root, depth=0): + root["node"].pivot_prior_prob = root["node"].pivot_prior_prob/prior_norm + for child in root['children']: + descend(child, depth=depth+1) + + # Normalize + descend(self.root) + + def sample_variational_distributions(self, **kwargs): + def descend(root): + root['node'].sample_variational_distributions(**kwargs) + for child in root['children']: + descend(child) descend(self.root) - def reset_node_variational_parameters(self, **kwargs): + def set_learned_parameters(self): + def descend(root): + root['node'].set_learned_parameters() + for child in root['children']: + descend(child) + descend(self.root) + + def get_pivot_probabilities(self, i): + def descend(root): + probs = [root['node'].variational_parameters['q_rho'][i]] + nodes = [root['node']] + for child in root['children']: + child_probs, child_nodes = descend(child) + nodes.extend(child_nodes) + probs.extend(child_probs) + return probs, nodes + return descend(self.root) + + def reset_sufficient_statistics(self, num_batches=1): + def descend(root): + root["node"].reset_sufficient_statistics(num_batches=num_batches) + for child in root["children"]: + descend(child) + descend(self.root) + + self.suff_stats = { + 'mass': {'total': 0, 'batch': [0.] * num_batches}, # \sum_n q(c_n = this tree) + 'ent': {'total': 0, 'batch': [0.] * num_batches}, # \sum_n q(c_n = this tree) log q(c_n = this tree) + } + + def reset_variational_parameters(self, alpha_nu=1., beta_nu=1., alpha_psi=1., beta_psi=1., **kwargs): + # Reset NTSSB parameters relative to this TSSB + self.variational_parameters = {'delta_1': alpha_nu, 'delta_2': beta_nu, # nu stick + 'sigma_1': alpha_psi, 'sigma_2': beta_psi, # psi stick + 'q_c': jnp.ones(self.ntssb.num_data,) * alpha_nu/(alpha_nu+beta_nu), # prob of assigning each cell to this TSSB + 'LSE_z': jnp.ones(self.ntssb.num_data,), # normalizing constant for prob of assigning each cell to nodes + 'LSE_rho': [1.] 
* len(self.children_root_nodes), # normalizing constant for each child TSSB's probs of assigning each node to nodes in this TSSB (this is a list) + 'll': jnp.ones(self.ntssb.num_data,), # auxiliary quantity + 'sum_E_log_1_nu': 0., # auxiliary quantity + 'E_log_phi': 0., # auxiliary quantity + } + # Reset node parameters def descend(root): root["node"].reset_variational_parameters(**kwargs) + z_norm = jnp.array(root["node"].variational_parameters['q_z']) + rho_norm = jnp.array(root["node"].variational_parameters['q_rho']) for child in root["children"]: - descend(child) + child_z_norm, child_rho_norm = descend(child) + z_norm += child_z_norm + rho_norm += child_rho_norm + return z_norm, rho_norm + + self.root["node"]._parent = None + + # Get normalizing constants + z_norm, rho_norm = descend(self.root) + # Apply normalization + def descend(root): + root["node"].reset_variational_parameters(**kwargs) + root["node"].variational_parameters['q_z'] = root["node"].variational_parameters['q_z'] / z_norm + root["node"].variational_parameters['q_rho'] = list(root["node"].variational_parameters['q_rho'] / rho_norm) + for child in root["children"]: + descend(child) descend(self.root) + def sample_new_tree(self, num_data, cull=True): # Clear current tree self.reset_tree() @@ -234,336 +429,694 @@ def add_data(self, data, initialize_assignments=False): node.add_datum(n) self.assignments.append(node) - def resample_node_params( - self, - iters=1, - independent_subtrees=False, - top_node=None, - data_indices=None, - compound=True, - ): - """ - Go through all nodes in the tree, starting at the bottom, and resample their parameters iteratively. - """ - if top_node is None: - top_node = self.root + def add_node(self, target_root, seed=None): + if seed is None: + seed = self.seed + target_root["node"].seed + len(target_root["children"]) - for iter in range(iters): + node = target_root["node"].spawn(target_root["node"].observed_parameters, seed=seed) - def descend(root): - for index, child in enumerate(root["children"]): - descend(child) - root["node"].resample_params( - independent_subtrees=independent_subtrees, - data_indices=data_indices, - compound=compound, + rng = np.random.default_rng(seed) + stick_length = boundbeta(1, self.dp_gamma, rng) + target_root["sticks"] = np.vstack([target_root["sticks"], stick_length]) + target_root["children"].append( + { + "node": node, + "main": boundbeta( + 1.0, (self.alpha_decay ** (target_root["node"].depth + 1)) * self.dp_alpha, np.random.default_rng(seed+1) ) + if self.min_depth <= (target_root["node"].depth + 1) + else 0.0, + "sticks": np.empty((0, 1)), + "children": [], + "label": node.label, + } + ) + + # Update pivot prior + self.set_pivot_priors() + + self.n_nodes += 1 + + return node + + def merge_nodes(self, parent_root, source_root, target_root): + # Add mass of source to mass of target + target_root['node'].variational_parameters['q_z'] += source_root['node'].variational_parameters['q_z'] + # Only need to update totals, because after the merges we go back to iterate + target_root['node'].merge_suff_stats(source_root['node'].suff_stats) + + # Update pivot probs + for i in range(len(self.children_root_nodes)): + target_root['node'].variational_parameters['q_rho'][i] += source_root['node'].variational_parameters['q_rho'][i] + + # Set children of source as children of target + for i, child in enumerate(source_root['children']): + target_root['children'].append(child) + target_root["sticks"] = np.vstack([target_root["sticks"], source_root['sticks'][i]]) + 
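# Illustrative sketch (not part of this patch): add_node above derives a deterministic
# seed from the TSSB seed, the parent node's seed and its current number of children,
# then draws the child's psi stick from Beta(1, dp_gamma) and its nu ("main") stick from
# Beta(1, alpha_decay**(depth+1) * dp_alpha). The repo's boundbeta helper is replaced
# here by a hypothetical stand-in that keeps the draw away from exactly 0 or 1.
import numpy as np

def _bounded_beta(a, b, rng, eps=1e-12):
    # stand-in for boundbeta (assumed: a Beta(a, b) draw clipped off the boundaries)
    return np.clip(rng.beta(a, b), eps, 1.0 - eps)

def draw_child_sticks(parent_depth, dp_alpha, dp_gamma, alpha_decay, seed, min_depth=0):
    psi_stick = _bounded_beta(1.0, dp_gamma, np.random.default_rng(seed))
    main_stick = (
        _bounded_beta(1.0, (alpha_decay ** (parent_depth + 1)) * dp_alpha,
                      np.random.default_rng(seed + 1))
        if min_depth <= parent_depth + 1 else 0.0
    )
    return psi_stick, main_stick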
child['node'].set_parent(target_root['node']) + + # Remove source from its parent's dict + nodes = np.array([n["node"] for n in parent_root["children"]]) + tokeep = np.where(nodes != source_root['node'])[0].astype(int).ravel() + parent_root["sticks"] = parent_root["sticks"][tokeep] + parent_root["children"] = list(np.array(parent_root["children"])[tokeep]) + + # Delete source node object and dict + source_root["node"].kill() + del source_root["node"] + + # Update names + self.set_node_names(root=parent_root, root_name=parent_root['node'].label) + + # Update pivot prior + self.set_pivot_priors() - descend(top_node) - - def resample_assignment(self, current_node, n, obs): - def path_lt(path1, path2): - if len(path1) == 0 and len(path2) == 0: - return 0 - elif len(path1) == 0: - return 1 - elif len(path2) == 0: - return -1 - s1 = "".join(map(lambda i: "%03d" % (i), path1)) - s2 = "".join(map(lambda i: "%03d" % (i), path2)) - - return cmp(s2, s1) - - epsilon = finfo(float64).eps - lengths = [] - reassign = 0 - better = 0 - - # Get an initial uniform variate. - ancestors = current_node.get_ancestors() - current = self.root - indices = [] - for anc in ancestors[1:]: - index = list(map(lambda c: c["node"], current["children"])).index(anc) - current = current["children"][index] - indices.append(index) - - max_u = 1.0 - min_u = 0.0 - llh_s = log(rand()) + current_node.logprob(obs) - # llh_s = self.assignments[n].logprob(self.data[n:n+1]) - 0.0000001 - while True: - new_u = (max_u - min_u) * rand() + min_u - (new_node, new_path, _) = self.find_node(new_u) - new_llh = new_node.logprob(obs) - if new_llh > llh_s: - if new_node != current_node: - if new_llh > current_node.logprob(obs): - better += 1 - current_node.remove_datum(n) - new_node.add_datum(n) - current_node = new_node - reassign += 1 - break - elif abs(max_u - min_u) < epsilon: - logger.debug("Slice sampler shrank down. Keep current state.") - break + self.n_nodes -= 1 + + def compute_elbo(self, idx): + """ + Compute the ELBO of the model in a tree traversal, abstracting away the likelihood and kernel specific functions + for the model. The seed is used for MC sampling from the variational distributions for which Eq[logp] is not analytically + available (which is the likelihood and the kernel distribution). 
+ """ + def descend(root, depth=0, ll_contrib=0, ass_contrib=0, global_contrib=0): + self.n_nodes += 1 + # Assignments + ## E[log p(z|nu,psi))] + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + eq_logp_z = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + eq_logp_z += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + eq_logp_z += root['node'].variational_parameters['E_log_phi'] + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu'] else: - path_comp = path_lt(indices, new_path) - if path_comp < 0: - min_u = new_u - # if we are at a leaf in a fixed tree, move on, because we can't create more nodes - if len(new_node.children()) == 0: - break - elif path_comp > 0: - max_u = new_u + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + ## E[log q(z)] + eq_logq_z = jax.lax.select(root['node'].variational_parameters['q_z'][idx] != 0, + root['node'].variational_parameters['q_z'][idx] * jnp.log(root['node'].variational_parameters['q_z'][idx]), + root['node'].variational_parameters['q_z'][idx]) + ass_contrib += eq_logp_z*root['node'].variational_parameters['q_z'][idx] - eq_logq_z + + # Sticks + E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu + nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu + nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth) + psi_kl = 0. + if depth != 0: + E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi + psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi + psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl -= logbeta_func(1, self.dp_gamma) + global_contrib += nu_kl + psi_kl + + # Kernel + if root['node'].parent() is None and self.parent() is None: # is the root of the root TSSB + ## E[log p(kernel)] + eq_logp_kernel = root['node'].compute_root_prior() + ## -E[log q(kernel)] + negeq_logq_kernel = root['node'].compute_root_entropy() + else: + if root['node'].parent() is not None: # Not a TSSB root. 
The root param probs are computed in the parent TSSBs, weighted by the pivots + ## E[log p(kernel)] + eq_logp_kernel = root['node'].compute_kernel_prior() else: - raise Exception("Slice sampler weirdness.") - - return current_node - - def resample_assignments(self): - def path_lt(path1, path2): - if len(path1) == 0 and len(path2) == 0: - return 0 - elif len(path1) == 0: - return 1 - elif len(path2) == 0: - return -1 - s1 = "".join(map(lambda i: "%03d" % (i), path1)) - s2 = "".join(map(lambda i: "%03d" % (i), path2)) - - return cmp(s2, s1) - - epsilon = finfo(float64).eps - lengths = [] - reassign = 0 - better = 0 - for n in range(self.num_data): + eq_logp_kernel = 0. + ## -E[log q(kernel)] + negeq_logq_kernel = root['node'].compute_kernel_entropy() + global_contrib += eq_logp_kernel + negeq_logq_kernel + + # Pivots + eq_logp_rootkernel = 0. + eq_logp_rho = 0. + eq_logq_rho = 0. + for i, next_tssb_root_node in enumerate(self.children_root_nodes): + ## E[log p(root kernel | rho kernel)] + eq_logp_rootkernel += root['node'].variational_parameters['q_rho'][i] * next_tssb_root_node.compute_root_kernel_prior(root['node'].samples) + ## E[log p(rho))] + eq_logp_rho += root['node'].pivot_prior_prob + ## E[log q(rho))] + eq_logq_rho = jax.lax.select(root['node'].variational_parameters['q_rho'][i] != 0, + root['node'].variational_parameters['q_rho'][i] * jnp.log(root['node'].variational_parameters['q_rho'][i]), + root['node'].variational_parameters['q_rho'][i]) + global_contrib += eq_logp_rootkernel + eq_logp_rho-eq_logq_rho + + # Likelihood + # Use node's kernel sample to evaluate likelihood + ll_contrib += root['node'].compute_loglikelihood(idx) * root['node'].variational_parameters['q_z'][idx] + + sum_E_log_1_psi = 0. + for child in root['children']: + # Auxiliary quantities + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + ll_contrib, ass_contrib, global_contrib = descend(child, depth=depth+1, + ll_contrib=ll_contrib, + ass_contrib=ass_contrib, + global_contrib=global_contrib) + + return ll_contrib, ass_contrib, global_contrib + + self.n_nodes = 0 + return descend(self.root) - # Get an initial uniform variate. - ancestors = self.assignments[n].get_ancestors() - current = self.root - indices = [] - for anc in ancestors[1:]: - index = list(map(lambda c: c["node"], current["children"])).index(anc) - current = current["children"][index] - indices.append(index) - - max_u = 1.0 - min_u = 0.0 - llh_s = log(rand()) + self.assignments[n].logprob(self.data[n : n + 1]) - # llh_s = self.assignments[n].logprob(self.data[n:n+1]) - 0.0000001 - while True: - new_u = (max_u - min_u) * rand() + min_u - (new_node, new_path, _) = self.find_node(new_u) - new_llh = new_node.logprob(self.data[n : n + 1]) - if new_llh > llh_s: - if new_node != self.assignments[n]: - if new_llh > self.assignments[n].logprob(self.data[n : n + 1]): - better += 1 - self.assignments[n].remove_datum(n) - new_node.add_datum(n) - self.assignments[n] = new_node - reassign += 1 - break - elif abs(max_u - min_u) < epsilon: - logger.debug("Slice sampler shrank down. 
Keep current state.") - break + def compute_elbo_suff(self): + """ + Compute the ELBO of the model in a tree traversal, abstracting away the likelihood and kernel specific functions + for the model. The seed is used for MC sampling from the variational distributions for which Eq[logp] is not analytically + available (which is the likelihood and the kernel distribution). + """ + def descend(root, depth=0, ll_contrib=0, ass_contrib=0, global_contrib=0): + self.n_nodes += 1 + # Assignments + ## E[log p(z|nu,psi))] + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + eq_logp_z = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + eq_logp_z += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + eq_logp_z += root['node'].variational_parameters['E_log_phi'] + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu'] + else: + root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + ## E[log q(z)] + ass_contrib += eq_logp_z*root['node'].suff_stats['mass']['total'] + root['node'].suff_stats['ent']['total'] + + # Sticks + E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu + nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu + nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth) + psi_kl = 0. + if depth != 0: + E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi + psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi + psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2']) + psi_kl -= logbeta_func(1, self.dp_gamma) + global_contrib += nu_kl + psi_kl + + # Kernel + if root['node'].parent() is None and self.parent() is None: # is the root of the root TSSB + ## E[log p(kernel)] + eq_logp_kernel = root['node'].compute_root_prior() + ## -E[log q(kernel)] + negeq_logq_kernel = root['node'].compute_root_entropy() + else: + if root['node'].parent() is not None: # Not a TSSB root. 
The root param probs are computed in the parent TSSBs, weighted by the pivots + ## E[log p(kernel)] + eq_logp_kernel = root['node'].compute_kernel_prior() else: - path_comp = path_lt(indices, new_path) - if path_comp < 0: - min_u = new_u - # if we are at a leaf in a fixed tree, move on, because we can't create more nodes - if len(new_node.children()) == 0: - break - elif path_comp > 0: - max_u = new_u - else: - raise Exception("Slice sampler weirdness.") - lengths.append(len(new_path)) - lengths = array(lengths) - # logger.debug "reassign: "+str(reassign)+" better: "+str(better) - - # def resample_birth(self): - # # Break sticks to choose node - # u = rand() - # node, root = self.find_node(u) - # - # # Create child - # stick_length = boundbeta(1, self.dp_gamma) - # root['sticks'] = vstack([ root['sticks'], stick_length ]) - # root['children'].append({ 'node' : root['node'].spawn(False, self.root_node.observed_parameters), - # 'main' : boundbeta(1.0, (self.alpha_decay**(depth+1))*self.dp_alpha) if self.min_depth <= (depth+1) else 0.0, - # 'sticks' : empty((0,1)), - # 'children' : [] }) - # - # # Update parameters of data in parent and child node until convergence: - # # assignments (starting from parent) - # # stick lengths - # # parameters - # - # def resample_merge(self): - # # Choose any node a at random - # - # # Compute similarity of parameters of leaf sibilings and parent - # - # # Choose closest node b - # - # - # - # # If the merge is accepted, the child nodes of a are transferred to node b. - - def resample_tree_topology(self, children_trees, independent_subtrees=False): - # x = self.complete_data_log_likelihood_nomix() - post = self.ntssb.unnormalized_posterior() - weights, nodes = self.get_mixture() + eq_logp_kernel = 0. + ## -E[log q(kernel)] + negeq_logq_kernel = root['node'].compute_kernel_entropy() + global_contrib += eq_logp_kernel + negeq_logq_kernel + + # Pivots + eq_logp_rootkernel = 0. + eq_logp_rho = 0. + eq_logq_rho = 0. + for i, next_tssb_root_node in enumerate(self.children_root_nodes): + ## E[log p(root kernel | rho kernel)] + eq_logp_rootkernel += root['node'].variational_parameters['q_rho'][i] * next_tssb_root_node.compute_root_kernel_prior(root['node'].samples) + ## E[log p(rho))] + eq_logp_rho += root['node'].pivot_prior_prob + ## E[log q(rho))] + eq_logq_rho = jax.lax.select(root['node'].variational_parameters['q_rho'][i] != 0, + root['node'].variational_parameters['q_rho'][i] * jnp.log(root['node'].variational_parameters['q_rho'][i]), + root['node'].variational_parameters['q_rho'][i]) + global_contrib += eq_logp_rootkernel + eq_logp_rho-eq_logq_rho + + # Likelihood + # Use node's kernel sample to evaluate likelihood + ll_contrib += root['node'].compute_loglikelihood_suff() + + sum_E_log_1_psi = 0. 
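+            # Stick-breaking bookkeeping: each child's E[log phi] is its own E[log psi] plus the
+            # accumulated E[log(1 - psi)] of its preceding siblings, computed before recursing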
+ for child in root['children']: + # Auxiliary quantities + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + ll_contrib, ass_contrib, global_contrib = descend(child, depth=depth+1, + ll_contrib=ll_contrib, + ass_contrib=ass_contrib, + global_contrib=global_contrib) + + return ll_contrib, ass_contrib, global_contrib + + self.n_nodes = 0 + return descend(self.root) + + + def update_sufficient_statistics(self, batch_idx=None): + def descend(root): + root['node'].update_sufficient_statistics(batch_idx=batch_idx) + for child in root['children']: + descend(child) + + descend(self.root) + + if batch_idx is not None: + idx = self.ntssb.batch_indices[batch_idx] + else: + idx = jnp.arange(self.ntssb.num_data) + + ent = assignment_entropies(self.variational_parameters['q_c'][idx]) + E_ass = self.variational_parameters['q_c'][idx] + + new_ent = jnp.sum(ent) + new_mass = jnp.sum(E_ass) + + if batch_idx is not None: + self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx] + self.suff_stats['ent']['batch'][batch_idx] = new_ent + self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx] + + self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx] + self.suff_stats['mass']['batch'][batch_idx] = new_mass + self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx] + else: + self.suff_stats['ent']['total'] = new_ent + self.suff_stats['mass']['total'] = new_mass - empty_root = False - if len(nodes) > 1: - if len(nodes[0].data) == 0: - logger.debug("Swapping root") - empty_root = True - nodeAnum = 0 + def update_local_params(self, idx, update_ass=True, take_gradients=False, **kwargs): + """ + This performs a tree traversal to update the cell to node attachment probabilities. 
+
+        Returns sum_node Eq[log p(y_n|psi_node)] * q(z_n=node), computed with the updated q(z_n=node).
+        """
+        def descend(root, local_grads=None):
+            E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+            logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+            if root['node'].parent() is not None:
+                logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+                logprior += root['node'].variational_parameters['E_log_phi']
+                root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
+            else:
+                root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+
+            # Take gradient of locals
+            weights = root['node'].variational_parameters['q_z'][idx] * self.variational_parameters['q_c'][idx]
+            if take_gradients:
+                local_grads_down = root["node"].compute_ll_locals_grad(self.ntssb.data[idx], idx, weights) # returns a tuple of grads for each cell-specific param
+                if local_grads is None:
+                    local_grads = list(local_grads_down)
+                else:
+                    for i, grads in enumerate(list(local_grads_down)):
+                        local_grads[i] += grads
+            ll = root['node'].compute_loglikelihood(idx)
+            root['node'].variational_parameters['ll'] = ll
+            root['node'].variational_parameters['logprior'] = logprior
+            new_log_prob = ll + logprior
+            if update_ass:
+                root['node'].variational_parameters['q_z'] = root['node'].variational_parameters['q_z'].at[idx].set(new_log_prob)
+
+            logqs = [new_log_prob]
+            sum_E_log_1_psi = 0.
+ for child in root['children']: + E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi + E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2']) + sum_E_log_1_psi += E_log_1_psi + + # Go down + child_log_probs, child_local_param_grads = descend(child, local_grads=local_grads) + logqs.extend(child_log_probs) + if child_local_param_grads is not None: + for i, grads in enumerate(list(child_local_param_grads)): + local_grads[i] += grads + # local_grads += child_local_param_grads + + return logqs, local_grads + + logqs, local_grads = descend(self.root) + + # Compute LSE + logqs = jnp.array(logqs).T # make it obs by nodes + self.variational_parameters['LSE_z'] = jax.scipy.special.logsumexp(logqs, axis=1) + # Set probs and return sum of weighted likelihoods + def descend(root): + if update_ass: + newvals = jnp.exp(root['node'].variational_parameters['q_z'][idx] - self.variational_parameters['LSE_z']) + root['node'].variational_parameters['q_z'] = root['node'].variational_parameters['q_z'].at[idx].set(newvals) + ell = (root['node'].variational_parameters['ll'] + root['node'].variational_parameters['logprior']) * root['node'].variational_parameters['q_z'][idx] + ent = -jax.lax.select(root['node'].variational_parameters['q_z'][idx] != 0, + root['node'].variational_parameters['q_z'][idx] * jnp.log(root['node'].variational_parameters['q_z'][idx]), + root['node'].variational_parameters['q_z'][idx]) + + for child in root['children']: + ell_, ent_ = descend(child) + ell += ell_ + ent += ent_ + return ell, ent + + ell, ent = descend(self.root) + return ell, ent, local_grads + + def get_global_grads(self, idx): + """ + This performs a tree traversal to update the global parameters + """ + def descend(root, globals_grads=None): + weights = root['node'].variational_parameters['q_z'][idx] * self.variational_parameters['q_c'][idx] + globals_grads_down = root["node"].compute_ll_globals_grad(self.ntssb.data[idx], idx, weights) + if globals_grads is None: + globals_grads = list(globals_grads_down) + else: + for i, grads in enumerate(list(globals_grads_down)): + globals_grads[i] += grads + for child in root['children']: + child_globals_grads = descend(child, globals_grads=globals_grads) + for i, grads in enumerate(list(child_globals_grads)): + globals_grads[i] += grads + return globals_grads + + # Get gradient of loss of data likelihood weighted by assignment probability to each node wrt current sample of global params + return descend(self.root) - (nodeA, nodeB, nodeNum) = findNodes(self.root, nodeNum=0) + def update_stick_params(self, root=None, memoized=True): + def descend(root, depth=0): + mass_down = 0 + for child in root['children'][::-1]: + child_mass = descend(child, depth=depth+1) - paramsA = nodeA["node"].unobserved_factors - dataA = set(nodeA["node"].data) - mainA = nodeA["main"] + child['node'].variational_parameters['sigma_1'] = 1.0 + child_mass + child['node'].variational_parameters['sigma_2'] = self.dp_gamma + mass_down + mass_down += child_mass - nodeA["node"].unobserved_factors = nodeB["node"].unobserved_factors - nodeA["node"].node_mean = ( - nodeA["node"].baseline_caller() - * nodeA["node"].cnvs - / 2 - * np.exp(nodeA["node"].unobserved_factors) - ) + if memoized: + mass_here = root['node'].suff_stats['mass']['total'] + else: + mass_here = 
jnp.sum(root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c']) + root['node'].variational_parameters['delta_1'] = 1.0 + mass_here + root['node'].variational_parameters['delta_2'] = (self.alpha_decay**depth) * self.dp_alpha + mass_down - for dataid in list(dataA): - nodeA["node"].remove_datum(dataid) - for dataid in nodeB["node"].data: - nodeA["node"].add_datum(dataid) - self.ntssb.assignments[dataid]["node"] = nodeA["node"] - nodeA["main"] = nodeB["main"] - - nodeB["node"].unobserved_factors = paramsA - nodeB["node"].node_mean = ( - nodeB["node"].baseline_caller() - * nodeB["node"].cnvs - / 2 - * np.exp(nodeB["node"].unobserved_factors) - ) + return mass_here + mass_down - dataB = set(nodeB["node"].data) - - for dataid in list(dataB): - nodeB["node"].remove_datum(dataid) - for dataid in dataA: - nodeB["node"].add_datum(dataid) - self.ntssb.assignments[dataid]["node"] = nodeB["node"] - nodeB["main"] = mainA - - if not independent_subtrees: - # Go to subtrees - # For each subtree, if pivot was swapped, update it - for child in children_trees: - if child["pivot_node"] == nodeA["node"]: - child["pivot_node"] = nodeB["node"] - child["node"].root["node"].set_parent( - nodeB["node"], reset=False - ) - elif child["pivot_node"] == nodeB["node"]: - child["pivot_node"] = nodeA["node"] - child["node"].root["node"].set_parent( - nodeA["node"], reset=False - ) + # Update sticks + if root is None: + root = self.root + descend(root) - logger.debug( - f"Swapped {nodeA['node'].label} with {nodeB['node'].label}" - ) + def update_node_params(self, key, root=None, memoized=True, step_size=0.0001, mc_samples=10, i=0, adaptive=True, **kwargs): + """ + Update variational parameters for kernels, sticks and pivots - if empty_root: - logger.debug("checking alternative root") - nodenum = [] - for ii, nn in enumerate(nodes): - if len(nodes[ii].data) > 0: - nodenum.append(ii) - post_temp = zeros(len(nodenum)) - for idx, nodeBnum in enumerate(nodenum): - logger.debug(f"nodeBnum: {nodeBnum}") - logger.debug(f"nodeAnum: {nodeAnum}") - swap_nodes(nodeAnum, nodeBnum) - for i in range(5): - self.resample_sticks() - self.ntssb.root["node"].root["node"].resample_cell_params() - post_new = self.ntssb.unnormalized_posterior() - post_temp[idx] = post_new - - accept_prob = np.exp(np.min([0.0, post_new - post])) - - if rand() > accept_prob: - swap_nodes(nodeAnum, nodeBnum) - for i in range(5): - self.resample_sticks() - self.ntssb.root["node"].root["node"].resample_cell_params() - - if nodeBnum == len(nodes) - 1: - logger.debug("forced swapping") - nodeBnum = post_temp.argmax() + 1 - swap_nodes(nodeAnum, nodeBnum) - for i in range(5): - self.resample_sticks() - self.ntssb.root["node"].root[ - "node" - ].resample_cell_params() - - self.resample_node_params() - self.resample_stick_orders() - else: - logger.debug("Successful swap!") - self.resample_node_params() - self.resample_stick_orders() - break - # else: - # swap_nodes(nodeAnum,nodeBnum, verbose=True) - # for i in range(5): - # self.resample_sticks() - # self.ntssb.root['node'].root['node'].resample_cell_params() - # - # post_new = self.ntssb.unnormalized_posterior() - # accept_prob = np.exp(np.min([0., post_new - post])) - # if (rand() > accept_prob): - # logger.debug("Unsuccessful swap.") - # swap_nodes(nodeAnum,nodeBnum) # swap back - # for i in range(5): - # self.resample_sticks() - # self.ntssb.root['node'].root['node'].resample_cell_params() - # - # else: - # logger.debug("Successful swap!") - # self.resample_node_params() - # 
self.resample_stick_orders() + Each node must have two parameters for the kernel: a direction and a state. + We assume the tree kernel, regardless of the model, is always defined as + P(direction|parent_direction) and P(state|direction,parent_state). + For each node, we first update the direction and then the state, taking one gradient step for each + parameter and then moving on to the next nodes in the tree traversal + + PSEUDOCODE: + def descend(root): + alpha, alpha_grad = sample_grad_alpha + psi, psi_grad = sample_grad_psi + + alpha_grad += Gradient of logp(alpha|parent_alpha) wrt this alpha + alpha_grad += Gradient of logp(psi|parent_psi,alpha) wrt this alpha + + alpha_grad += Gradient of logq(alpha) wrt this alpha + psi_grad += Gradient of logq(psi) wrt this psi + + psi_grad += Gradient of logp(x|psi) wrt this psi + + for each child: + child_alpha, child_psi = descend(child) + alpha_grad += Gradient of logp(child_alpha|alpha) wrt this alpha + psi_grad += Gradient of logp(child_psi|psi,child_alpha) wrt this psi + + for each child_root: + alpha_grad += Gradient of logp(child_root_alpha|alpha) wrt this alpha + psi_grad += Gradient of logp(child_root_psi|psi,child_root_alpha) wrt this psi + + new_alpha_params = alpha_params + alpha_grad * step_size + new_alpha = sample_alpha + + psi_grad += Gradient of logp(psi|parent_psi,new_alpha) wrt this psi + new_psi_params = psi_params + psi_grad * step_size + new_psi = sample_psi + + return new_alpha, new_psi + + """ + def descend(root, key, depth=0): + direction_sample_grad = 0. + state_sample_grad = 0. + + if depth != 0: + key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples) + direction_curr_sample, direction_params_grad = sample_grad + key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples) + state_curr_sample, state_params_grad = sample_grad + else: + root['node'].sample_kernel(n_samples=mc_samples) + direction_curr_sample = root['node'].get_direction_sample() + state_curr_sample = root['node'].get_state_sample() + + if depth != 0: + direction_parent_sample = root["node"].parent().get_direction_sample() + state_parent_sample = root["node"].parent().get_state_sample() + + direction_sample_grad += root["node"].compute_direction_prior_grad(direction_curr_sample, direction_parent_sample, state_parent_sample) + direction_sample_grad += root["node"].compute_state_prior_grad_wrt_direction(state_curr_sample, state_parent_sample, direction_curr_sample) + + direction_params_entropy_grad = root["node"].compute_direction_entropy_grad() + state_params_entropy_grad = root["node"].compute_state_entropy_grad() + + if memoized: + state_sample_grad += root["node"].compute_ll_state_grad_suff(state_curr_sample) + else: + weights = root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c'] + state_sample_grad += root["node"].compute_ll_state_grad(self.ntssb.data, weights, state_curr_sample) + + mass_down = 0 + for child in root['children'][::-1]: + child_mass, direction_child_sample, state_child_sample = descend(child, key, depth=depth+1) + + child['node'].variational_parameters['sigma_1'] = 1.0 + child_mass + child['node'].variational_parameters['sigma_2'] = self.dp_gamma + mass_down + mass_down += child_mass + + if depth != 0: + direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample) + state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, 
direction_curr_sample, state_curr_sample) + state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample) + + if depth != 0: + for ii, child_root in enumerate(self.children_root_nodes): + direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(child_root.get_direction_sample(), direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][ii] + state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(child_root.get_direction_sample(), direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][ii] + state_sample_grad += root["node"].compute_root_state_prior_child_grad(child_root.get_state_sample(), state_curr_sample, child_root.get_direction_sample()) * root['node'].variational_parameters['q_rho'][ii] + + if depth != 0: + if adaptive and i == 0: + root['node'].reset_opt() + + # Combine gradients of functions wrt sample with gradient of sample wrt var params + if adaptive: + root['node'].update_direction_adaptive(direction_params_grad, direction_sample_grad, direction_params_entropy_grad, + step_size=step_size, i=i) + else: + root['node'].update_direction_params(direction_params_grad, direction_sample_grad, direction_params_entropy_grad, + step_size=step_size) + key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples) + direction_curr_sample, _ = sample_grad + + state_sample_grad += root["node"].compute_state_prior_grad(state_curr_sample, state_parent_sample, direction_curr_sample) + + if adaptive: + root['node'].update_state_adaptive(state_params_grad, state_sample_grad, state_params_entropy_grad, + step_size=step_size, i=i) + else: + root['node'].update_state_params(state_params_grad, state_sample_grad, state_params_entropy_grad, + step_size=step_size) + key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples) + state_curr_sample, _ = sample_grad + root['node'].samples[0] = state_curr_sample + root['node'].samples[1] = direction_curr_sample + + if memoized: + mass_here = root['node'].suff_stats['mass']['total'] + else: + mass_here = jnp.sum(root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c']) + root['node'].variational_parameters['delta_1'] = 1.0 + mass_here + root['node'].variational_parameters['delta_2'] = (self.alpha_decay**depth) * self.dp_alpha + mass_down + + return mass_here + mass_down, direction_curr_sample, state_curr_sample + + # Update kernels and sticks + if root is None: + root = self.root + descend(root, key) + + + def sample_grad_root_node(self, key, memoized=True, mc_samples=10, **kwargs): + """ + B->B0 --> C + Compute gradient of p(stateB0|dirB0,stateB) wrt stateB, p(dirB0|dirB,stateB) wrt dirB, stateB + and gradient p(stateC|dirC,stateB) wrt stateB, p(dirC|dirB,stateB) wrt dirB, stateB + """ + # Sample root params and compute initial gradients wrt children + root = self.root + direction_sample_grad = 0. + state_sample_grad = 0. 
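+        # The root's direction and state are sampled below; the returned parameter gradients
+        # are passed along for the root node's variational updates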
+ + key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples) + direction_curr_sample, direction_params_grad = sample_grad + key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples) + state_curr_sample, state_params_grad = sample_grad + + # Gradient of entropy + direction_params_entropy_grad = root["node"].compute_direction_entropy_grad() + state_params_entropy_grad = root["node"].compute_state_entropy_grad() + + # Gradient of likelihood + if memoized: + ll_grad = root["node"].compute_ll_state_grad_suff(state_curr_sample) + else: + weights = root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c'] + ll_grad = root["node"].compute_ll_state_grad(self.ntssb.data, weights, state_curr_sample) + + # Gradient of children in TSSB + for child in root['children'][::-1]: + direction_child_sample = child['node'].get_direction_sample() + state_child_sample = child['node'].get_state_sample() + direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample) + state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, direction_curr_sample, state_curr_sample) + state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample) + + # Gradient of roots of children TSSB + for i, child_root in enumerate(self.children_root_nodes): + direction_child_sample = child_root.get_direction_sample() + state_child_sample = child_root.get_state_sample() + # Gradient of the root nodes of children TSSBs wrt to their parameters using this TSSB root as parent + direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i] + state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i] + state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i] + + direction_locals_grads = [direction_params_grad, direction_params_entropy_grad] + state_locals_grads = [state_params_grad, state_params_entropy_grad] + + return ll_grad, [direction_locals_grads, state_locals_grads], [direction_sample_grad, state_sample_grad] + + def compute_children_root_node_grads(self, **kwargs): + """ + A -> B + Compute gradient of p(dirB|dirA,stateA) wrt dirB and p(stateB|dirB,stateA) wrt stateB, dirB + """ + def descend(root, children_grads=None): + direction_curr_sample = root['node'].samples[1] + state_curr_sample = root['node'].samples[0] + + # Compute gradient of children roots wrt their params + if children_grads is None: + children_grads = [[0., 0.]] * len(self.children_root_nodes) + for i, child_root in enumerate(self.children_root_nodes): + # Gradient of the root nodes of children TSSBs wrt to their parameters using this + direction_child_sample = child_root.get_direction_sample() + state_child_sample = child_root.get_state_sample() + direction_sample_grad = root["node"].compute_direction_prior_grad(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i] + direction_sample_grad += 
root["node"].compute_state_prior_grad_wrt_direction(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i] + state_sample_grad = root["node"].compute_state_prior_grad(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i] + children_grads[i][0] += direction_sample_grad + children_grads[i][1] += state_sample_grad + + for child in root['children']: + descend(child, children_grads=children_grads) + + return children_grads + + # Return list of children, for each child a tuple with direction grad sum and sample grad sum + return descend(self.root) + + + def update_root_node_params(self, key, ll_grad, local_grads, children_grads, parent_grads, i=0, adaptive=True, step_size=0.0001, mc_samples=10, **kwargs): + """ + ll_grads: ll_grad_state_D + local_grads: params_grad_D, params_entropy_grad_D + children_grads: grad_D p(D0|D) + grad_D p(D|D) + parent_grads: grad_D p(D|B) + grad_D p(D|B0) + """ + if adaptive and i == 0: + self.root['node'].reset_opt() + + direction_params_grad, direction_params_entropy_grad = local_grads[0] + direction_sample_grad = children_grads[0] + parent_grads[0] + + # Combine gradients of functions wrt sample with gradient of sample wrt var params + if adaptive: + self.root['node'].update_direction_adaptive(direction_params_grad, direction_sample_grad, direction_params_entropy_grad, + step_size=step_size, i=i) + else: + self.root['node'].update_direction_params(direction_params_grad, direction_sample_grad, direction_params_entropy_grad, + step_size=step_size) + key, sample_grad = self.root['node'].direction_sample_and_grad(key, n_samples=mc_samples) + direction_curr_sample, _ = sample_grad + self.root['node'].samples[1] = direction_curr_sample + + state_params_grad, state_params_entropy_grad = local_grads[1] + state_sample_grad = children_grads[1] + parent_grads[1] + state_sample_grad += ll_grad + + if adaptive: + self.root['node'].update_state_adaptive(state_params_grad, state_sample_grad, state_params_entropy_grad, + step_size=step_size, i=i) + else: + self.root['node'].update_state_params(state_params_grad, state_sample_grad, state_params_entropy_grad, + step_size=step_size) + key, sample_grad = self.root['node'].state_sample_and_grad(key, n_samples=mc_samples) + state_curr_sample, _ = sample_grad + self.root['node'].samples[0] = state_curr_sample + + + def update_pivot_probs(self, **kwargs): + def descend(root): + # Update pivot assignment probabilities + new_log_probs = [] + root['node'].variational_parameters['q_rho'] = [0.] 
* len(self.children_root_nodes) + for i, child_root in enumerate(self.children_root_nodes): + pivot_direction_score = child_root.compute_root_direction_prior(root['node'].get_direction_sample()) + pivot_state_score = child_root.compute_root_state_prior(root['node'].get_state_sample()) + new_log_prob = pivot_direction_score + pivot_state_score + jnp.log(root['node'].pivot_prior_prob) + root['node'].variational_parameters['q_rho'][i] = new_log_prob + new_log_probs.append([new_log_prob]) + + for child in root['children']: + children_log_probs = descend(child) + for i, child_root in enumerate(self.children_root_nodes): + new_log_probs[i].extend(children_log_probs[i]) + + return new_log_probs + + logqs = descend(self.root) + # Compute LSE for pivot probs + for i in range(len(self.children_root_nodes)): + this_logqs = jnp.array(logqs[i]) + self.variational_parameters['LSE_rho'][i] = jax.scipy.special.logsumexp(this_logqs) + + # Set probs and return sum of weighted likelihoods + def descend(root): + for i, child_root in enumerate(self.children_root_nodes): + root['node'].variational_parameters['q_rho'][i] = jnp.exp(root['node'].variational_parameters['q_rho'][i] - self.variational_parameters['LSE_rho'][i]) + for child in root['children']: + descend(child) + # Normalize pivot probs + descend(self.root) + def cull_tree(self, verbose=False, resample_sticks=True): """ If a leaf node has no data assigned to it, remove it @@ -989,10 +1542,9 @@ def descend(root): return pivot_node, pivot_tssb - def find_node(self, u, truncated=False): + def find_node(self, u, include_leaves=True): def descend(root, u, depth=0): if depth >= self.max_depth: - # logger.debug >>sys.stderr, "WARNING: Reached maximum depth." return (root["node"], [], root) elif u < root["main"]: return (root["node"], [], root) @@ -1000,36 +1552,19 @@ def descend(root, u, depth=0): # Rescale the uniform variate to the remaining interval. u = (u - root["main"]) / (1.0 - root["main"]) - if not truncated: - # Perhaps break sticks out appropriately. 
- while ( - not root["children"] or (1.0 - prod(1.0 - root["sticks"])) < u - ): - stick_length = boundbeta(1, self.dp_gamma) - root["sticks"] = vstack([root["sticks"], stick_length]) - root["children"].append( - { - "node": root["node"].spawn( - False, self.root_node.observed_parameters - ), - "main": boundbeta( - 1.0, - (self.alpha_decay ** (depth + 1)) * self.dp_alpha, - ) - if self.min_depth <= (depth + 1) - else 0.0, - "sticks": empty((0, 1)), - "children": [], - } - ) - else: - root["sticks"][-1] = 1.0 + # Don't break sticks edges = 1.0 - cumprod(1.0 - root["sticks"]) index = sum(u > edges) + if index >= len(root['sticks']): + return (root["node"], [], root) edges = hstack([0.0, edges]) u = (u - edges[index]) / (edges[index + 1] - edges[index]) + # Perhaps stop before continuing to a leaf + if not include_leaves and len(root["children"][index]["children"]) == 0: + return (root["node"], [], root) + (node, path, root) = descend(root["children"][index], u, depth + 1) path.insert(0, index) @@ -1038,6 +1573,33 @@ def descend(root, u, depth=0): return descend(self.root, u) + def find_node_uniform(self, key, include_leaves=True): + def descend(root, key, depth=0): + if depth >= self.max_depth: + return (root["node"], [], root) + elif len(root["children"]) == 0: + return (root["node"], [], root) + else: + key, subkey = jax.random.split(key) + n_children = len(root["children"]) + if jax.random.bernoulli(subkey, p=1./(n_children+1)): + return (root["node"], [], root) + else: + key, subkey = jax.random.split(key) + index = jax.random.choice(subkey, len(root["children"])) + + # Perhaps stop before continuing to a leaf + if not include_leaves and len(root["children"][index]["children"]) == 0: + return (root["node"], [], root) + + (node, path, root) = descend(root["children"][index], key, depth + 1) + + path.insert(0, index) + + return (node, path, root) + + return descend(self.root, key) + def get_expected_mixture(self, reset_names=False): """ Computes the expected weight for each node. 
@@ -1085,6 +1647,26 @@ def descend(root): return sr return descend(self.root) + def set_weights(self): + def descend(root, mass): + root['weight'] = mass * root["main"] + edges = sticks_to_edges(root["sticks"]) + weights = diff(hstack([0.0, edges])) + for i, child in enumerate(root["children"]): + descend(child, mass * (1.0 - root["main"]) * weights[i]) + return descend(self.root, 1.0) + + def set_expected_weights(self): + def descend(root): + logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2']) + if root['node'].parent() is not None: + logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu'] + logprior += root['node'].variational_parameters['E_log_phi'] + root['weight'] = jnp.exp(logprior) + for child in root['children']: + descend(child) + descend(self.root) + def get_mixture( self, reset_names=False, get_roots=False, get_depths=False, truncate=False ): @@ -1481,20 +2063,21 @@ def label_nodes(self, counts=False, names=False): elif not names or counts is True: self.label_nodes_counts() - def set_node_names(self, root_name="X"): - self.root["label"] = str(root_name) - self.root["node"].label = str(root_name) + def set_node_names(self, root=None, root_name="X"): + if root is None: + root = self.root + + root["label"] = str(root_name) + root["node"].label = str(root_name) def descend(root, name): for i, child in enumerate(root["children"]): - child_name = "%s-%d" % (name, i) - + child_name = f"{name}-{i}" root["children"][i]["label"] = child_name root["children"][i]["node"].label = child_name - descend(child, child_name) - descend(self.root, root_name) + descend(root, root_name) def set_subcluster_node_names(self): # Assumes the other fixed nodes have already been named, and ignores the root diff --git a/scatrex/plotting/__init__.py b/scatrex/plotting/__init__.py index ee3f060..5da138d 100644 --- a/scatrex/plotting/__init__.py +++ b/scatrex/plotting/__init__.py @@ -1 +1,2 @@ from .constants import * +from .scatterplot import * \ No newline at end of file diff --git a/scatrex/plotting/scatterplot.py b/scatrex/plotting/scatterplot.py index 9f8e71f..8aab3d4 100644 --- a/scatrex/plotting/scatterplot.py +++ b/scatrex/plotting/scatterplot.py @@ -3,7 +3,101 @@ """ import matplotlib import matplotlib.pyplot as plt +import networkx as nx import numpy as np +from ..utils.tree_utils import tree_to_dict + + +def plot_full_tree(tree, ax=None, figsize=(6,6), **kwargs): + if ax is None: + plt.figure(figsize=figsize) + ax = plt.gca() + def descend(root, graph, pos={}): + pos_out = plot_tree(root['node'], G=graph, ax=ax, alpha=1., draw=False, **kwargs) # Draw subtree + pos.update(pos_out) + for child in root['children']: + descend(child, graph, pos) + + def sub_descend(sub_root, graph): + parent = sub_root['label'] + for i, super_child in enumerate(root['children']): + child = super_child['label'] + graph.add_edge(parent, child, alpha=sub_root['pivot_probs'][i], ls='--') + nx.draw_networkx_edges(graph, pos, edgelist=[(parent, child)], edge_color=sub_root['color'], alpha=sub_root['pivot_probs'][i], style='--') + for child in sub_root['children']: + sub_descend(child, graph) + + if len(root['children']) > 0: + sub_descend(root['node'], graph) # Draw pivot edges + + G = nx.DiGraph() + descend(tree, G) + + ax.margins(0.20) # Set margins for the axes so that nodes aren't clipped + ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + 
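A minimal usage sketch of these plotting helpers, for orientation only: it assumes a fitted SCATrEx object `sca` (hypothetical name) whose NTSSB exposes `get_param_dict()`, as used by `plot_tree_projection` later in this diff; the keyword values are illustrative, not defaults.

    import matplotlib.pyplot as plt
    from scatrex.plotting import scatterplot

    tree = sca.ntssb.get_param_dict()   # nested dict with 'node', 'children', 'label', 'color'
    fig, ax = plt.subplots(figsize=(6, 6))
    scatterplot.plot_full_tree(tree, ax=ax, node_size=800)      # subtree layouts plus dashed pivot edges
    plt.show()
    scatterplot.plot_nested_tree(tree, param_key='obs_param')   # observed nodes with nested unobserved subtrees
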
+ +def plot_tree(tree, G = None, param_key='param', data=None, labels=True, alpha=0.5, font_size=12, node_size=1500, edge_width=1., arrows=True, draw=True, ax=None): + tree_dict = tree_to_dict(tree, param_key=param_key) + + # Get all positions + pos = {} + pos[tree['label']] = tree[param_key] + for node in tree_dict: + if tree_dict[node]['parent'] != '-1': + pos[node] = tree_dict[node]['param'] + + # Draw graph + node_options = {'alpha': alpha, + 'node_size': node_size,} + edge_options = {'alpha': alpha, + 'width': edge_width, + 'node_size':node_size, + 'arrows': arrows} + label_options = {'alpha': alpha, + 'font_size': font_size,} + + if ax is None: + fig = plt.figure(figsize=(6,6)) + + if G is None: + G = nx.DiGraph() + for node in tree_dict: + nx.draw_networkx_nodes(G, pos, nodelist=[node], node_color=tree_dict[node]['color'], + **node_options) + if tree_dict[node]['parent'] != '-1': + parent = tree_dict[node]['parent'] + G.add_edge(parent, node) + nx.draw_networkx_edges(G, pos, edgelist=[(parent, node)], edge_color=tree_dict[parent]['color'],**edge_options) + if labels: + labs = dict(zip(list(tree_dict.keys()), list(tree_dict.keys()))) + nx.draw_networkx_labels(G, pos, labs, **label_options) + + ax = plt.gca() + ax.margins(0.20) # Set margins for the axes so that nodes aren't clipped + ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True) + ax.spines['right'].set_visible(False) + ax.spines['top'].set_visible(False) + if draw: + plt.show() + else: + return pos + +def plot_nested_tree(tree, top=True, param_key='param', out_alpha=0.4, in_alpha=1., large_node_size=5000, small_node_size=500, draw=True, ax=None, **kwargs): + tree_dict = tree_to_dict(tree, param_key=param_key) + + if top: + # Plot main tree with transparency, large nodes and without labels + ax = plot_tree(tree, param_key=param_key, labels=False, node_size=large_node_size, alpha=out_alpha, draw=False, **kwargs) + + for subtree in tree_dict: # Do this in a tree traversal so that we add the pivots + # Plot each subtree + plot_tree(tree_dict[subtree]['node'], param_key='mean', labels=False, node_size=small_node_size, alpha=in_alpha, ax=ax, draw=False, **kwargs) + + if draw: + plt.show() def plot_tree_proj( diff --git a/scatrex/plotting/tree_colors.py b/scatrex/plotting/tree_colors.py new file mode 100644 index 0000000..89e1bde --- /dev/null +++ b/scatrex/plotting/tree_colors.py @@ -0,0 +1,38 @@ +# Functions to make a colormap out of a tree +import matplotlib +import seaborn as sns +from colorsys import rgb_to_hls + +def adjust_color(rgba, lightness_scale, saturation_scale): + # scale the lightness (The values should be between 0 and 1) + rgb = rgba[:-1] + a = rgba[-1] + hls = rgb_to_hls(*rgb) + lightness = max(0,min(1, hls[1] * lightness_scale)) + saturation = max(0,min(1,hls[2] * saturation_scale)) + rgb = sns.set_hls_values(color = rgb, h = None, l = lightness, s = saturation) + rgba = rgb + (a,) + hex = matplotlib.colors.to_hex(rgba, keep_alpha=True) + return hex + + +def make_tree_colormap(tree, base_color, brightness_mult=0.7, saturation_mult=1.3): + """ + Updates the tree dictionary with colors defined from node depth and breadth + tree: nested dictionary containing {node: children} with children a list of also dictionaries {child: children} + base_color: HEX code for base color + saturation_mult: how much to change saturation for each step in depth, centered at 1 + brightness_mult: how much to change brightness for each step in breadth, centered at 1 + """ + base_color_rgba = 
matplotlib.colors.ColorConverter.to_rgba(base_color) + + tree['color'] = base_color + + # Traverse tree to update out_dict + def descend(root, depth=1, breadth=1): + for i, child in enumerate(root['children']): + breadth += i + color = adjust_color(base_color_rgba, brightness_mult**depth, saturation_mult**breadth) + child['color'] = color + descend(child, depth=depth+1, breadth=breadth) + descend(tree) diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py index c1cf3cc..5082a33 100644 --- a/scatrex/scatrex.py +++ b/scatrex/scatrex.py @@ -1,7 +1,11 @@ -from .models import * +from .models import CNATree, TrajectoryTree from .ntssb import NTSSB from .ntssb import StructureSearch from .plotting import scatterplot, constants +from .utils import tree_utils + +import jax +import jax.numpy as jnp import numpy as np from sklearn.decomposition import PCA @@ -25,13 +29,13 @@ class SCATrEx(object): def __init__( self, - model=cna, + model=CNATree, model_args=dict(), verbosity=logging.INFO, temppath="./temppath", + seed=42, ): - - self.model = cna + self.model = model self.model_args = model_args self.observed_tree = None self.ntssb = None @@ -42,6 +46,7 @@ def __init__( self.temppath = temppath if not os.path.exists(temppath): os.makedirs(temppath, exist_ok=True) + self.seed = seed logger.setLevel(verbosity) @@ -93,14 +98,11 @@ def simulate_tree( observed_tree=None, observed_tree_args=dict(), observed_tree_params=dict(), - model_args=None, n_genes=50, n_extra_per_observed=1, - seed=42, copy=False, + **cmap_kwargs, ): - np.random.seed(seed) - self.observed_tree = observed_tree if not self.observed_tree: @@ -111,7 +113,7 @@ def simulate_tree( for arg in observed_tree_args: logger.info(f"{arg}: {observed_tree_args[arg]}") - self.observed_tree = self.model.ObservedTree(**observed_tree_args) + self.observed_tree = self.model(**observed_tree_args) self.observed_tree.generate_tree() self.observed_tree.add_node_params(n_genes=n_genes, **observed_tree_params) @@ -123,58 +125,281 @@ def simulate_tree( logger.info(f"{arg}: {self.model_args[arg]}") self.ntssb = NTSSB( - self.observed_tree, self.model.Node, node_hyperparams=self.model_args + self.observed_tree, node_hyperparams=self.model_args, seed=self.seed ) self.ntssb.create_new_tree(n_extra_per_observed=n_extra_per_observed) + self.ntssb.set_ntssb_colors(**cmap_kwargs) + logger.info("Tree is stored in `self.observed_tree` and `self.ntssb`") return self.ntssb if copy else None - def simulate_data(self, n_cells=100, seed=42, copy=False): - np.random.seed(seed) - self.ntssb.put_data_in_nodes(n_cells) - self.ntssb.root["node"].root["node"].generate_data_params() - - # Sample observations - observations = [] - assignments = [] - assignments_labels = [] - assignments_obs_labels = [] - for obs in range(len(self.ntssb.assignments)): - sample = self.ntssb.assignments[obs].sample_observation(obs).reshape(1, -1) - observations.append(sample) - assignments.append(self.ntssb.assignments[obs]) - assignments_labels.append(self.ntssb.assignments[obs].label) - assignments_obs_labels.append(self.ntssb.assignments[obs].tssb.label) - assignments = np.array(assignments) - assignments_labels = np.array(assignments_labels) - assignments_obs_labels = np.array(assignments_obs_labels) - observations = np.concatenate(observations) - - self.ntssb.data = observations - self.ntssb.num_data = observations.shape[0] + def simulate_data(self, n_cells=100, copy=False): + node_assignments, obs_node_assignments = self.ntssb.sample_assignments(n_cells) + data = np.array(self.ntssb.simulate_data()) 
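+        # Keep a noise-free version of the simulated counts; it is stored below as the 'corrected' layer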
+ noiseless_data = self.ntssb.root['node'].root['node'].remove_noise(data) - if self.ntssb.root["node"].root["node"].num_batches > 1: - self.ntssb.covariates = self.ntssb.root["node"].root["node"].cell_covariates + self.adata = AnnData(data) + self.adata.layers['corrected'] = noiseless_data + self.adata.obs["obs_node"] = obs_node_assignments + self.adata.uns["obs_node_colors"] = [ + self.observed_tree.tree_dict[node]["color"] + for node in self.observed_tree.tree_dict + ] + tree = self.ntssb.get_param_dict() + tree_dict = tree_utils.tree_to_dict(tree) + node_colors = [] + for node in tree_dict: + d = tree_utils.tree_to_dict(tree_dict[node]['node']) + for n in d: + if d[n]['size'] > 0: + node_colors.append(d[n]['color']) + self.adata.obs["node"] = node_assignments + self.adata.uns["node_colors"] = node_colors[:] # to remove the root + self.adata.raw = self.adata logger.info("Labeled data are stored in `self.adata`") - self.adata = AnnData(observations) - self.adata.obs["node"] = assignments_labels - self.adata.obs["obs_node"] = assignments_obs_labels - if self.ntssb.root["node"].root["node"].num_batches > 1: - self.adata.obs["batch"] = np.argmax(self.ntssb.covariates, axis=1) + return self.adata if copy else None - self.adata.uns["obs_node_colors"] = [ + def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01): + logger.info("Learning cell and gene scales") + n_cells = self.ntssb.data.shape[0] + n_genes = self.ntssb.data.shape[1] + gs = np.sqrt(np.median(self.ntssb.data)) + + root = self.ntssb.root['node'].root['node'] + gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes)) + gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * jnp.exp(gs + 0. * np.random.normal(size=n_genes)) + root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init) + root.variational_parameters['global']['gene_scales']['log_beta'] = jnp.log(gene_scales_beta_init) + + cell_scales_alpha_init = 10. * jnp.ones((n_cells,1)) #* jnp.exp(np.random.normal(size=[500,1])) + cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * jnp.exp(gs + 0. 
* np.random.normal(size=[n_cells,1])) + root.variational_parameters['local']['cell_scales']['log_alpha'] = jnp.log(cell_scales_alpha_init) + root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init) + + # Initialize MC samples + self.ntssb.sample_variational_distributions(n_samples=mc_samples) + self.ntssb.update_sufficient_statistics() + + # Update cell, gene scales, assignments + self.ntssb.learn_globals(n_epochs=n_epochs, step_size=step_size,mc_samples=mc_samples, + update_ass=True, update_locals=True, update_roots=False, + locals_names=['cell_scales'], + globals_names=['gene_scales']) + + # Add noise factors to learn better scales + # root.variational_parameters['global']['factor_weights']['mean'] = jnp.array(0.01 * np.random.normal(size=(2, n_genes))) + # root.variational_parameters['local']['obs_weights']['mean'] = jnp.array(0.01 * np.random.normal(size=(500, 2))) + self.ntssb.sample_variational_distributions(n_samples=mc_samples) + self.ntssb.learn_globals(n_epochs=n_epochs, step_size=step_size, mc_samples=mc_samples, + update_ass=True, update_locals=True, update_roots=False) + + def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=10, memoized=True, mc_samples=10, step_size=0.01, seed=42): + logger.info("Learning roots and noise") + # Remove noise and learn roots + self.ntssb.set_node_hyperparams(n_factors=0) + self.ntssb.root['node'].root['node'].reset_variational_noise_factors() + self.ntssb.sample_variational_distributions(n_samples=mc_samples) + self.ntssb.update_sufficient_statistics() + self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False) + + # Update assignments + self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False) + + # Learn a tree with root updates on noiseless data (over-cluster) + searcher = StructureSearch(self.ntssb) + searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,) + searcher.tree.sample_variational_distributions(n_samples=10) + searcher.tree.update_sufficient_statistics() + searcher.tree.compute_elbo(memoized=memoized) + searcher.proposed_tree = deepcopy(searcher.tree) + searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size, + memoized=memoized, seed=seed, update_roots=True) + + self.ntssb = deepcopy(searcher.tree) + self.ntssb.set_node_hyperparams(n_factors=self.model_args['n_factors']) + self.ntssb.root['node'].root['node'].reset_variational_noise_factors() + self.ntssb.sample_variational_distributions(n_samples=mc_samples) + self.ntssb.update_sufficient_statistics() + self.ntssb.compute_elbo(memoized=memoized) + # Cleanup parameters: learn noise and update all parameters (including roots) except scales, no memoization needed + self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True, + locals_names=['obs_weights'], + globals_names=['factor_weights', 'factor_precisions'], + update_roots=True, step_size=step_size, mc_samples=mc_samples, + memoized=False) + + self.ntssb.update_sufficient_statistics() + self.ntssb.compute_elbo(memoized=memoized) + + # Propose merges to account for new noise + searcher = StructureSearch(self.ntssb) + searcher.tree.sample_variational_distributions(n_samples=mc_samples) + searcher.tree.update_sufficient_statistics() + searcher.tree.compute_elbo(memoized=memoized) + searcher.proposed_tree = deepcopy(searcher.tree) + key = jax.random.PRNGKey(seed) + for i in range(10): + key, subkey = 
jax.random.split(key) + searcher.merge(subkey, moves_per_tssb=1, memoized=memoized) + + searcher.tree.update_sufficient_statistics() + searcher.tree.compute_elbo(memoized=memoized) + searcher.proposed_tree = deepcopy(searcher.tree) + # Propose root merges with opt after these merges + for i in range(n_merges): + key, subkey = jax.random.split(key) + searcher.merge_root(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size) + + searcher.tree.update_sufficient_statistics() + searcher.tree.compute_elbo(memoized=memoized) + searcher.proposed_tree = deepcopy(searcher.tree) + # Propose root swaps after these merges + for i in range(n_swaps): + key, subkey = jax.random.split(key) + searcher.swap(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size, + update_ass=False) # fixed assignments + + # Re-learn noise and parameters except scales + self.ntssb = deepcopy(searcher.tree) + self.ntssb.sample_variational_distributions(n_samples=mc_samples) + self.ntssb.update_sufficient_statistics() + self.ntssb.compute_elbo(memoized=memoized) + self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True, + locals_names=['obs_weights'], + globals_names=['factor_weights', 'factor_precisions'], + update_roots=True, step_size=step_size, mc_samples=mc_samples, + memoized=False) # If I do this with memoization the noise ends up explaining too much and messing up the scales! Best not to mix + + def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, step_size=0.01, dp_alpha=.1, dp_gamma=.1, prune=True, seed=42): + logger.info("Learning augmented tree") + # Re-learn tree (optionally from scratch) with fixed noise and roots + if prune: + self.ntssb.prune_subtrees() + searcher = StructureSearch(self.ntssb) + searcher.tree.set_tssb_params(dp_alpha=dp_alpha, dp_gamma=dp_gamma) + searcher.tree.sample_variational_distributions(n_samples=mc_samples) + searcher.tree.update_sufficient_statistics() + searcher.tree.compute_elbo(memoized=memoized) + searcher.proposed_tree = deepcopy(searcher.tree) + searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size, + memoized=memoized, seed=seed, + update_roots=False) + self.ntssb = deepcopy(searcher.tree) + + def update_anndata(self, adata): + """ + Use learned NTSSB to add annotations to input AnnData object. + Cells and genes must be in same order as the data in NTSSB! 
+ """ + node_assignments = [] + for i in range(adata.shape[0]): + node_assignments.append(self.ntssb.assignments[i].label) + + obs_node_assignments = [] + for i in range(adata.shape[0]): + obs_node_assignments.append(self.ntssb.assignments[i].tssb.label) + + adata.obs["scatrex_node"] = node_assignments + adata.obs["scatrex_obs_node"] = obs_node_assignments + + adata.uns["scatrex_node_colors"] = [ + self.observed_tree.tree_dict[node.split("-")[0]]["color"] + for node in np.unique(adata.obs["scatrex_node"]) + ] + adata.uns["scatrex_obs_node_colors"] = [ self.observed_tree.tree_dict[node]["color"] - for node in self.observed_tree.tree_dict + for node in np.unique(adata.obs["scatrex_obs_node"]) ] - self.adata.raw = self.adata - return (observations, assignments_labels) if copy else None + labels = list(self.observed_tree.tree_dict.keys()) + sizes = [ + np.count_nonzero(adata.obs["scatrex_obs_node"] == label) + for label in labels + ] + adata.uns["scatrex_estimated_frequencies"] = dict(zip(labels, sizes)) - def learn_tree( + adata.layers["scatrex_noise"] = ( + self.ntssb.root["node"] + .root["node"] + .variational_parameters["local"]["obs_weights"]['mean'] + .dot( + self.ntssb.root["node"] + .root["node"] + .variational_parameters["global"]["factor_weights"]['mean'] + ) + ) + + genes_pos = np.arange(adata.var_names.size) + + xi_mat = np.zeros(adata.shape) + om_mat = np.zeros(adata.shape) + cnv_mat = np.zeros(adata.shape) + mean_mat = np.zeros(adata.shape) + nodes = np.array(self.ntssb.get_nodes()) + nodes_labels = np.array([node.label for node in nodes]) + for node_id in np.unique(adata.obs["scatrex_node"]): + cells = np.where(adata.obs["scatrex_node"] == node_id)[0] + node = nodes[np.where(node_id == nodes_labels)[0][0]] + pos = np.meshgrid(cells, genes_pos) + xi_mat[tuple(pos)] = ( + np.array( + node.params[0] + ).reshape(-1, 1) + * np.ones((len(cells), len(genes_pos))).T + ) + om_mat[tuple(pos)] = ( + np.array( + node.params[1] + ).reshape(-1, 1) + * np.ones((len(cells), len(genes_pos))).T + ) + cnv_mat[tuple(pos)] = ( + np.array(node.cnvs).reshape(-1, 1) + * np.ones((len(cells), len(genes_pos))).T + ) + mean_mat[tuple(pos)] = ( + np.array(node.get_mean()).reshape(-1, 1) + * np.ones((len(cells), len(genes_pos))).T + ) + adata.layers["scatrex_cell_states"] = xi_mat + adata.layers["scatrex_cell_state_events"] = om_mat + adata.layers["scatrex_cnvs"] = cnv_mat + adata.layers["scatrex_mean"] = mean_mat + + def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees=True, allow_root_subtrees=False, root_cells=None, + batch_size=None, seed=42, + n_epochs=100, mc_samples=10, step_size=0.01, n_iters=10, n_merges=10, n_swaps=10, memoized=True, dp_alpha=.1, dp_gamma=.1): + """ + Complete NTSSB learning procedure. 
+ """ + if observed_tree is not None: + self.set_observed_tree(observed_tree) + + # Setup NTSSB + self.ntssb = NTSSB(self.observed_tree, + node_hyperparams=self.model_args, + seed=seed,) + self.ntssb.add_data(np.array(adata.layers[counts_layer])) + self.ntssb.make_batches(batch_size, seed) + self.ntssb.reset_variational_parameters() + + # Learn + self.learn_scales(n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size) + self.learn_roots_and_noise(n_iters=n_iters, n_epochs=n_epochs, n_merges=n_merges, n_swaps=n_swaps, memoized=memoized, mc_samples=mc_samples, step_size=step_size, seed=seed) + self.learn_tree(n_iters=n_iters, n_epochs=n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, dp_alpha=dp_alpha, dp_gamma=dp_gamma, prune=True, seed=seed) + + # Create outputs for analysis + self.ntssb.set_learned_parameters() + self.ntssb.assign_samples() + self.ntssb.create_augmented_tree_dict() + self.ntssb.initialize_gene_node_colormaps() + self.update_anndata(adata) + + def _learn_tree( self, observed_tree=None, reset=True, @@ -363,7 +588,9 @@ def learn_tree( np.array( np.exp( node.variational_parameters["locals"][ - "unobserved_factors_kernel_log_mean" + "unobserved_factors_kernel_log_shape" + ] - node.variational_parameters["locals"][ + "unobserved_factors_kernel_log_rate" ] ) ).reshape(-1, 1) @@ -783,7 +1010,7 @@ def assign_new_data(self, data): return def set_node_event_strings(self, **kwargs): - self.ntssb.set_node_event_strings(var_names=sim_sca.adata.var_names, **kwargs) + self.ntssb.set_node_event_strings(var_names=self.adata.var_names, **kwargs) def plot_tree( self, @@ -871,6 +1098,63 @@ def plot_tree( return g + def plot_data(self, level=0, draw=True, layer=None, color=None, remove_noise=False, **kwargs): + # Plot data projection using node colors + data = self.adata.X + if layer is not None: + data = self.adata.layers[layer] + + if remove_noise: + # Remove noise according to noise model + data = self.ntssb.root['node'].root['node'].remove_noise(data) + + def super_descend(super_root, go_down=True): + if go_down: + descend(super_root['node'].root) + else: + # Plot data + attached_cells = np.array(list(super_root['node']._data)) + if len(attached_cells > 0): + color = super_root['color'] + if color is not None: + cell_color = color + plt.scatter(data[attached_cells,0], data[attached_cells,1], + color=cell_color, label=super_root['label'], **kwargs) + + for super_child in super_root['children']: + super_descend(super_child, go_down=go_down) + + def descend(root): + # Plot data + attached_cells = np.array(list(root['node'].data)) + if len(attached_cells > 0): + cell_color = root['color'] + if color is not None: + cell_color = color + plt.scatter(data[attached_cells,0], data[attached_cells,1], + color=cell_color, label=root['label'], **kwargs) + for child in root['children']: + descend(child) + + if level == 0: # Unobs + super_descend(self.ntssb.root, go_down=True) + elif level == 1: # Obs + super_descend(self.ntssb.root, go_down=False) + + if draw: + plt.show() + + def plot_tree_projection(self, level=None, ax=None, title="", **kwargs): + + tree = self.ntssb.get_param_dict() + if level is None: # Both levels + ax = scatterplot.plot_nested_tree(tree, param_key='obs_param', top=True, ax=ax, **kwargs) + elif level == 1: # Only obs + ax = scatterplot.plot_tree(tree, param_key='obs_param', ax=ax, **kwargs) + elif level == 0: # Only unobs + ax = scatterplot.plot_nested_tree(tree, top=False, ax=ax, **kwargs) + return ax + def plot_tree_proj( self, project=True, @@ 
-981,6 +1265,8 @@ def plot_unobserved_parameters( if gene_names is not None: # Transform gene names into gene indices genes = np.array([self.adata.var_names.get_loc(g) for g in gene_names]) + else: + genes = np.arange(len(self.adata.var_names)) if self.search is not None: if len(self.search.traces["elbo"]) > 0: @@ -1038,14 +1324,17 @@ def plot_unobserved_parameters( ls=ls, ) else: - plt.plot( - unobs[genes].ravel() - step * i, + plt.bar( + np.arange(len(genes)), + unobs[genes].ravel(),# - step * i, + bottom=- step * i, label=node.label, color=node.tssb.color, lw=lw, alpha=alpha, ls=ls, ) + plt.plot(np.zeros((len(genes),)) - step * i, ls='--', color='gray', alpha=alpha) if gene_names is not None and show_names: plt.xticks(np.arange(len(gene_names)), labels=gene_names) else: @@ -1059,6 +1348,77 @@ def plot_unobserved_parameters( if ax is None: plt.show() + def plot_estimated_unobserved_kernels( + self, + node_names=None, + gene=None, + gene_names=None, + ax=None, + figsize=(4, 4), + lw=4, + alpha=0.7, + title="", + fontsize=18, + step=4, + x_max=1, + show_names=False, + save=None, + ): + nodes, _ = self.ntssb.get_node_mixture() + + if node_names is not None: + nodes = [node for node in nodes if node.label in node_names] + + genes = None + if gene_names is not None: + # Transform gene names into gene indices + genes = np.array([self.adata.var_names.get_loc(g) for g in gene_names]) + else: + genes = np.arange(len(self.adata.var_names)) + + if self.search is not None: + if len(self.search.traces["elbo"]) > 0: + estimated = True + + if ax is None: + plt.figure(figsize=figsize) + else: + plt.gca(ax) + ticklabs = [] + tickpos = [] + for i, node in enumerate(nodes): + if node.parent() is not None: + ls = "-" + shape = np.exp(node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"]) + rate = np.exp(node.variational_parameters["locals"]["unobserved_factors_kernel_log_rate"]) + unobs = shape/rate + std = np.sqrt( + shape/rate**2 + ) + plt.bar( + np.arange(len(genes)), + unobs[genes].ravel(), + bottom= - step * i, + label=node.label, + color=node.tssb.color, + lw=lw, + alpha=alpha, + ls=ls, + ) + plt.plot(np.zeros((len(genes),)) - step * i, ls='--', color='gray', alpha=alpha) + if gene_names is not None and show_names: + plt.xticks(np.arange(len(gene_names)), labels=gene_names) + else: + plt.xticks([]) + tickpos.append(-step * i) + ticklabs.append(rf"{node.label.replace('-', '')}") + plt.yticks(tickpos, labels=ticklabs, fontsize=fontsize) + plt.title(title, fontsize=fontsize) + if save is not None: + plt.savefig(save, bbox_inches="tight") + if ax is None: + plt.show() + def bulkify(self): self.adata.var["raw_bulk"] = np.mean(self.adata.X, axis=0) try: diff --git a/scatrex/utils/annotate_utils.py b/scatrex/utils/annotate_utils.py new file mode 100644 index 0000000..e495ab6 --- /dev/null +++ b/scatrex/utils/annotate_utils.py @@ -0,0 +1,105 @@ +from pybiomart import Server +import pandas as pd + + +def convert_tidy_to_matrix(tidy_df, rows="single_cell_id", columns="copy_number"): + # Takes a tidy dataframe specifying the CNVs of cells along genomic bins + # and converts it to a cell by bin matrix + cell_df = tidy_df.loc[tidy_df[rows] == tidy_df[rows][0]] + bins_df = cell_df.drop(columns=[columns, rows], inplace=False) + tidy_df["bin_id"] = np.tile(bins_df.index, tidy_df[rows].unique().size) + matrix = tidy_df[[columns, rows, "bin_id"]].pivot_table( + values=columns, index=rows, columns="bin_id" + ) + + return matrix, bins_df + + +def annotate_bins(bins_df): + # Takes a dataframe of 
genomic regions and returns an ordered list of full genes in each region + server = Server("www.ensembl.org", use_cache=False) + dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets["hsapiens_gene_ensembl"] + gene_coordinates = dataset.query( + attributes=[ + "chromosome_name", + "start_position", + "end_position", + "external_gene_name", + ], + filters={ + "chromosome_name": [ + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + 12, + 13, + 14, + 15, + 16, + 17, + 18, + 19, + 20, + 21, + 22, + "X", + "Y", + ] + }, + use_attr_names=True, + ) + # Drop duplicate genes + gene_coordinates.drop_duplicates(subset="external_gene_name", ignore_index=True) + + annotated_bins = bins_df.copy() + annotated_bins["genes"] = [list() for _ in range(annotated_bins.shape[0])] + + bin_size = bins_df["end"][0] - bins_df["start"][0] + for index, row in gene_coordinates.iterrows(): + gene = row["external_gene_name"] + if pd.isna(gene): + continue + start_bin_in_chr = int(row["start_position"] / bin_size) + stop_bin_in_chr = int(row["end_position"] / bin_size) + chromosome = str(row["chromosome_name"]) + chr_start = np.where(bins_df["chr"] == chromosome)[0][0] + start_bin = start_bin_in_chr + chr_start + stop_bin = stop_bin_in_chr + chr_start + + if stop_bin < annotated_bins.shape[0]: + if np.all(annotated_bins.iloc[start_bin:stop_bin].chr == chromosome): + for bin in range(start_bin, stop_bin + 1): + annotated_bins.loc[bin, "genes"].append(gene) + + return annotated_bins + + +def annotate_matrix(matrix, annotated_bins): + # Takes a dataframe of cells by bins and a dataframe with gene lists for each bin + # and returns a dataframe of cells by genes + df_list = [] + chrs = [] + for bin, row in annotated_bins.iterrows(): + genes = row["genes"] + chr = row["chr"] + if len(genes) > 0: + df_list.append( + pd.concat([matrix[bin]] * len(genes), axis=1, ignore_index=True).rename( + columns=dict(zip(range(len(genes)), genes)) + ) + ) + chrs.append([chr] * df_list[-1].shape[1]) + chrs = np.concatenate(chrs) + df = pd.concat(df_list, axis=1) + chrs = chrs[np.where(~df.columns.duplicated())[0]] + df = df.loc[:, ~df.columns.duplicated()] + return df, chrs + diff --git a/scatrex/util.py b/scatrex/utils/math_utils.py similarity index 71% rename from scatrex/util.py rename to scatrex/utils/math_utils.py index 90eb396..e8544f8 100644 --- a/scatrex/util.py +++ b/scatrex/utils/math_utils.py @@ -5,6 +5,7 @@ from functools import partial import numpy as np +import jax from jax import vmap from jax import random import jax.numpy as jnp @@ -12,8 +13,6 @@ from jax.scipy.stats import norm, gamma, laplace, beta, dirichlet, poisson from jax.scipy.special import digamma, betaln, gammaln -from pybiomart import Server -import pandas as pd def relative_difference(current, prev, eps=1e-6): @@ -25,7 +24,7 @@ def absolute_difference(current, prev): def diag_gamma_sample(rng, log_alpha, log_beta): - return jnp.exp(-log_beta) * random.gamma(rng, jnp.exp(log_alpha)) + return jnp.clip(jnp.exp(-log_beta) * random.gamma(rng, jnp.exp(log_alpha)), a_min=1e-10, a_max=1e30) def diag_gamma_logpdf(x, log_alpha, log_beta): @@ -384,9 +383,9 @@ def betapdfln(x, a, b): ) -def boundbeta(a, b): +def boundbeta(a, b, rng): return (1.0 - numpy.finfo(numpy.float64).eps) * ( - numpy.random.beta(a, b) - 0.5 + rng.beta(a, b) - 0.5 ) + 0.5 # return numpy.random.beta(a,b) @@ -411,109 +410,70 @@ def logsumexp(X, axis=None): return numpy.log(numpy.sum(numpy.exp(X - maxes), axis=axis)) + maxes -def convert_tidy_to_matrix(tidy_df, rows="single_cell_id", 
columns="copy_number"): - # Takes a tidy dataframe specifying the CNVs of cells along genomic bins - # and converts it to a cell by bin matrix - cell_df = tidy_df.loc[tidy_df[rows] == tidy_df[rows][0]] - bins_df = cell_df.drop(columns=[columns, rows], inplace=False) - tidy_df["bin_id"] = np.tile(bins_df.index, tidy_df[rows].unique().size) - matrix = tidy_df[[columns, rows, "bin_id"]].pivot_table( - values=columns, index=rows, columns="bin_id" - ) +# JAX-optimized functions - return matrix, bins_df - - -def annotate_bins(bins_df): - # Takes a dataframe of genomic regions and returns an ordered list of full genes in each region - server = Server("www.ensembl.org", use_cache=False) - dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets["hsapiens_gene_ensembl"] - gene_coordinates = dataset.query( - attributes=[ - "chromosome_name", - "start_position", - "end_position", - "external_gene_name", - ], - filters={ - "chromosome_name": [ - 1, - 2, - 3, - 4, - 5, - 6, - 7, - 8, - 9, - 10, - 11, - 12, - 13, - 14, - 15, - 16, - 17, - 18, - 19, - 20, - 21, - 22, - "X", - "Y", - ] - }, - use_attr_names=True, - ) - # Drop duplicate genes - gene_coordinates.drop_duplicates(subset="external_gene_name", ignore_index=True) - - annotated_bins = bins_df.copy() - annotated_bins["genes"] = [list() for _ in range(annotated_bins.shape[0])] - - bin_size = bins_df["end"][0] - bins_df["start"][0] - for index, row in gene_coordinates.iterrows(): - gene = row["external_gene_name"] - if pd.isna(gene): - continue - start_bin_in_chr = int(row["start_position"] / bin_size) - stop_bin_in_chr = int(row["end_position"] / bin_size) - chromosome = str(row["chromosome_name"]) - chr_start = np.where(bins_df["chr"] == chromosome)[0][0] - start_bin = start_bin_in_chr + chr_start - stop_bin = stop_bin_in_chr + chr_start - - if stop_bin < annotated_bins.shape[0]: - if np.all(annotated_bins.iloc[start_bin:stop_bin].chr == chromosome): - for bin in range(start_bin, stop_bin + 1): - annotated_bins.loc[bin, "genes"].append(gene) - - return annotated_bins - - -def annotate_matrix(matrix, annotated_bins): - # Takes a dataframe of cells by bins and a dataframe with gene lists for each bin - # and returns a dataframe of cells by genes - df_list = [] - chrs = [] - for bin, row in annotated_bins.iterrows(): - genes = row["genes"] - chr = row["chr"] - if len(genes) > 0: - df_list.append( - pd.concat([matrix[bin]] * len(genes), axis=1, ignore_index=True).rename( - columns=dict(zip(range(len(genes)), genes)) - ) - ) - chrs.append([chr] * df_list[-1].shape[1]) - chrs = np.concatenate(chrs) - df = pd.concat(df_list, axis=1) - chrs = chrs[np.where(~df.columns.duplicated())[0]] - df = df.loc[:, ~df.columns.duplicated()] - return df, chrs - - -def convert_phylogeny_to_clonal_tree(threshold): - # Converts a phylogenetic tree to a clonal tree by choosing the main clades - # according to some threshold - raise NotImplementedError +@jax.jit +def logbeta_func(a,b): + return gammaln(a) + gammaln(b) - gammaln(a+b) + +@jax.jit +def beta_kl(a1, b1, a2, b2): + """ + logp(a1,b1) - logp(a2,b2) + """ + kl = logbeta_func(a1,b1) - logbeta_func(a2,b2) + kl += (a1-a2)*digamma(a1) + kl += (b1-b2)*digamma(b1) + kl += (a2-a1 + b2-b1)*digamma(a1+b1) + return kl + +@jax.jit +def E_log_beta( + alpha, + beta, +): + return digamma(alpha) - digamma(alpha + beta) + +@jax.jit +def E_log_1_beta( + alpha, + beta, +): + return digamma(beta) - digamma(alpha + beta) + +@jax.jit +def E_q_log_beta( + alpha1, + beta1, + alpha2, + beta2, +): + """ + Expected value of log p(y) with 
x~Beta(alpha1,beta1) wrt to q(y) with y~Beta(alpha2,beta2) + """ + return (alpha1-1)*E_log_beta(alpha2,beta2) + (beta1-1)*E_log_1_beta(alpha2,beta2) - betaln(alpha1,beta1) + + +# This computes E_q[p(z_n=\epsilon | \nu, \psi)] +@jax.jit +def compute_expected_weight( + nu_alpha, + nu_beta, + psi_alpha, + psi_beta, + prev_nu_sticks_sum, + prev_psi_sticks_sum, +): + nu_digamma_sum = digamma(nu_alpha + nu_beta) + E_log_nu = digamma(nu_alpha) - nu_digamma_sum + nu_sticks_sum = digamma(nu_beta) - nu_digamma_sum + prev_nu_sticks_sum + psi_sticks_sum = local_psi_sticks_sum(psi_alpha, psi_beta) + prev_psi_sticks_sum + weight = E_log_nu + nu_sticks_sum + psi_sticks_sum + return weight, nu_sticks_sum, psi_sticks_sum + + +@jax.jit +def assignment_entropies(probs): + return -jax.lax.select(probs != 0, + probs * jnp.log(probs), + probs) diff --git a/scatrex/utils/tree_utils.py b/scatrex/utils/tree_utils.py new file mode 100644 index 0000000..0b3948c --- /dev/null +++ b/scatrex/utils/tree_utils.py @@ -0,0 +1,164 @@ +import numpy as np +import jax + +def tree_to_dict(tree, param_key='param', root_par='-1'): + """Converts a tree in a recursive dictionary to a flat dictionary with parent keys + """ + tree_dict = dict() + + def descend(node, par_id): + label = node['label'] + tree_dict[label] = dict() + tree_dict[label]["parent"] = par_id + for key in node: + if key == param_key: + key = 'param' + tree_dict[label]['param'] = node[param_key] + elif key == 'children': + tree_dict[label][key] = [c['label'] for c in node[key]] + elif key != 'parent': + tree_dict[label][key] = node[key] + for child in node['children']: + descend(child, node['label']) + + descend(tree, root_par) + + return tree_dict + +def dict_to_tree(tree_dict, root_name, param_key='param'): + """Converts a tree in a flat dictionary with parent keys to a recursive dictionary + """ + # Make children in the flat dict + for i in tree_dict: + tree_dict[i]["children"] = [] + for i in tree_dict: + for j in tree_dict: + if tree_dict[j]["parent"] == i: + tree_dict[i]["children"].append(j) + + # Make tree + root = {} + root['label'] = root_name + for key in tree_dict[root_name]: + # if isinstance(tree_dict[root_name][key], list): + # if isinstance(tree_dict[root_name][key][0], dict): + # continue + root[key] = tree_dict[root_name][key] + root['children'] = [] + + # Recursively construct tree + def descend(super_tree, label): + for i, child in enumerate(tree_dict[label]["children"]): + d = {} + for key in tree_dict[child]: + # if isinstance(tree_dict[child][key], list): + # if isinstance(tree_dict[child][key][0], dict): + # continue + d[key] = tree_dict[child][key] + d['children'] = [] + super_tree["children"].append(d) + descend(super_tree["children"][-1], child) + + descend(root, root_name) + return root + +def condense_tree(tree, min_weight=0.1): + """ + Traverse and choose with some prob whether to keep each node in the tree + """ + def descend(root): + to_keep = [] + for child in root['children']: + descend(child) + to_keep.append(int(child['weight'] > min_weight)) + if len(to_keep) > 0: + to_remove_idx = [i for i in range(len(to_keep)) if to_keep[i] == 0] + if 'weight' in root: + root['weight'] += np.sum([r['weight'] for i, r in enumerate(root['children']) if to_keep[i]==0]) + if 'size' in root: + root['size'] += int(np.sum([r['size'] for i, r in enumerate(root['children']) if to_keep[i]==0])) + # Set children of source as children of target + for child_to_remove in list(np.array(root['children'])[to_remove_idx]): + for child_of_to_remove in 
child_to_remove['children']: + to_keep.append(1) + root['children'].append(child_of_to_remove) + # Remove child + root['children'] = list(np.array(root['children'])[np.where(np.array(to_keep))[0]]) + descend(tree) + +def subsample_tree(tree, keep_prob=0.5, seed=42): + """ + Traverse and choose with some prob whether to keep each node in the tree + """ + def descend(root, key): + to_keep = [] + for child in root['children']: + key, subkey = jax.random.split(key) + descend(child, key) + to_keep.append(int(jax.random.bernoulli(subkey, keep_prob))) + if len(to_keep) > 0: + to_remove_idx = [i for i in range(len(to_keep)) if to_keep[i] == 0] + if 'weight' in root: + root['weight'] += np.sum([r['weight'] for i, r in enumerate(root['children']) if to_keep[i]==0]) + if 'size' in root: + root['size'] += int(np.sum([r['size'] for i, r in enumerate(root['children']) if to_keep[i]==0])) + # Set children of source as children of target + for child_to_remove in list(np.array(root['children'])[to_remove_idx]): + for child_of_to_remove in child_to_remove['children']: + to_keep.append(1) + root['children'].append(child_of_to_remove) + # Remove child + root['children'] = list(np.array(root['children'])[np.where(np.array(to_keep))[0]]) + key = jax.random.PRNGKey(seed) + descend(tree, key) + +def convert_phylogeny_to_clonal_tree(threshold): + # Converts a phylogenetic tree to a clonal tree by choosing the main clades + # according to some threshold + raise NotImplementedError + +def obs_rmse(ntssb1, ntssb2, param='observed'): + ntssb1_obs = np.zeros(ntssb1.data.shape) + ntssb2_obs = np.zeros(ntssb2.data.shape) + + ntssb1_nodes = ntssb1.get_nodes() + for node in ntssb1_nodes: + idx = np.where(ntssb1.assignments == node) + ntssb1_obs[idx] = node.get_param(param) + + ntssb2_nodes = ntssb2.get_nodes() + for node in ntssb2_nodes: + idx = np.where(ntssb2.assignments == node) + ntssb2_obs[idx] = node.get_param(param) + + return np.mean(np.sqrt(np.mean((ntssb1_obs - ntssb2_obs)**2, axis=1))) + +def ntssb_distance(ntssb1, ntssb2): + pdist1 = ntssb1.get_pairwise_obs_distances() + pdist2 = ntssb2.get_pairwise_obs_distances() + n_obs = pdist1.shape[0] + return np.sqrt(2./(n_obs*(n_obs-1)) * np.sum((pdist1-pdist2)**2)) + + +def subtree_to_tree_distance(subtree, tree): + """subtree is a TSSB in the NTSSB + tree is a dictionary containing the true tree that should be there + To check that the substructure we find is close to the real structure and not just something + completely different + """ + subtree_nodes = [] + tree_nodes = [] + dists = np.zeros((len(subtree_nodes), len(tree_nodes))) + for i, subtree_node in enumerate(subtree_nodes): + for j, tree_node in enumerate(tree_nodes): + dists[i,j] = compute_distance(subtree_node, tree_node) + return np.sqrt(np.mean(dists**2)) + +def print_tree(tree, tab=' '): + def descend(root, depth=0): + tabs = [tab] * depth + tabs = ''.join(tabs) + print(f"{tabs}{root['label']}") + for child in root['children']: + descend(child, depth=depth+1) + descend(tree) \ No newline at end of file
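As a quick sanity check of the tree_utils helpers added above, the flat (parent-keyed) and recursive tree representations can be round-tripped. The sketch below is illustrative only: it assumes the new module is importable as scatrex.utils.tree_utils (no utils/__init__.py appears in this part of the diff) and uses a toy tree in place of a real NTSSB parameter dictionary.

from scatrex.utils.tree_utils import tree_to_dict, dict_to_tree, print_tree

# Toy recursive tree in the format tree_utils expects: each node carries a
# 'label', a parameter under 'param', and a 'children' list.
tree = {
    "label": "A",
    "param": [1.0, 0.0],
    "children": [
        {"label": "B", "param": [0.5, 0.2], "children": []},
        {"label": "C", "param": [0.1, 0.9], "children": []},
    ],
}

flat = tree_to_dict(tree, param_key="param", root_par="-1")   # flat dict keyed by node label
assert flat["B"]["parent"] == "A"

rebuilt = dict_to_tree(flat, root_name="A", param_key="param")  # back to the recursive form
print_tree(rebuilt)
# A
#  B
#  C

condense_tree and subsample_tree traverse the same recursive form; condense_tree additionally reads a 'weight' entry on every child to decide which nodes to keep.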