diff --git a/scatrex/models/__init__.py b/scatrex/models/__init__.py
index 78adaa0..b43e0b6 100644
--- a/scatrex/models/__init__.py
+++ b/scatrex/models/__init__.py
@@ -1 +1,3 @@
-from . import cna
+from .gaussian import GaussianTree
+from .trajectory import TrajectoryTree
+from .cna import CNATree
\ No newline at end of file
diff --git a/scatrex/models/cna/__init__.py b/scatrex/models/cna/__init__.py
index b04881c..388d504 100644
--- a/scatrex/models/cna/__init__.py
+++ b/scatrex/models/cna/__init__.py
@@ -1,3 +1,3 @@
-from .tree import ObservedTree
-from .node import Node
-from .opt_funcs import *
+from .tree import CNATree
+from .node import CNANode
+from .node_opt import *
diff --git a/scatrex/models/cna/node.py b/scatrex/models/cna/node.py
index c37f72d..47085af 100644
--- a/scatrex/models/cna/node.py
+++ b/scatrex/models/cna/node.py
@@ -3,879 +3,1391 @@
from numpy.random import *
from functools import partial
-
-import jax
import jax.numpy as jnp
-from jax import jit, grad, vmap
-from jax import random, ops
-from jax.example_libraries import optimizers
-from jax.scipy.stats import norm
-from jax.scipy.stats import gamma
-from jax.scipy.stats import poisson
-from jax.scipy.special import logit
-
-from ...util import *
+import jax
+import tensorflow_probability.substrates.jax.distributions as tfd
+
+from .node_opt import * # node optimization functions
+from .node_opt import _mc_obs_ll
+from ...utils.math_utils import *
from ...ntssb.node import *
-from ...ntssb.tree import *
-MIN_CNV = 1e-6
-MAX_XI = 1
+MIN_ALPHA = jnp.log(0.1)
+MAX_BETA = jnp.log(1./0.1)
+def update_params(params, params_gradient, step_size):
+ new_params = []
+ for i, param in enumerate(params):
+ new_params.append(param + step_size * params_gradient[i])
+ return new_params
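+# Illustrative usage (sketch, not part of the model code): given gradients with the same
+# structure as `params`, this takes one gradient-ascent step, e.g.
+#   params = update_params(params, params_gradient, step_size=1e-2)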
-class Node(AbstractNode):
+class CNANode(AbstractNode):
def __init__(
self,
- is_observed,
- observed_parameters,
- log_lib_size_mean=7.1,
- log_lib_size_std=0.6,
- num_global_noise_factors=4,
- global_noise_factors_precisions_shape=2.0,
- cell_global_noise_factors_weights_scale=1.0,
- unobserved_factors_root_kernel=0.1,
- unobserved_factors_kernel=1.0,
- unobserved_factors_kernel_concentration=0.01,
- unobserved_factors_kernel_rate=1.0,
- frac_dosage=1.0,
- frac_overlap=0.25,
- baseline_shape=0.7,
- num_batches=0,
+ observed_parameters, # copy number state
+ cell_scale_mean=1e2,
+ cell_scale_shape=1.,
+ gene_scale_mean=1e2,
+ gene_scale_shape=1.,
+ direction_shape=.1,
+ inheritance_strength=1.,
+ n_factors=2,
+ obs_weight_variance=1.,
+ factor_precision_shape=2.,
+        min_cnv=1e-6,
**kwargs,
):
- super(Node, self).__init__(is_observed, observed_parameters, **kwargs)
+ """
+        This model generates nodes in gene expression space, combining observed copy number states with latent expression states, cell and gene scale factors, and structured noise.
+ """
+ super(CNANode, self).__init__(observed_parameters, **kwargs)
# The observed parameters are the CNVs of all genes
self.cnvs = np.array(self.observed_parameters)
- self.cnvs[np.where(self.cnvs == 0)[0]] = MIN_CNV
+ self.cnvs[np.where(self.cnvs == 0)[0]] = min_cnv # To avoid zero in Poisson likelihood
self.observed_parameters = np.array(self.cnvs)
+ self.cnvs = jnp.array(self.cnvs)
self.n_genes = self.cnvs.size
# Node hyperparameters
if self.parent() is None:
self.node_hyperparams = dict(
- log_lib_size_mean=log_lib_size_mean,
- log_lib_size_std=log_lib_size_std,
- num_global_noise_factors=num_global_noise_factors,
- global_noise_factors_precisions_shape=global_noise_factors_precisions_shape,
- cell_global_noise_factors_weights_scale=cell_global_noise_factors_weights_scale,
- unobserved_factors_root_kernel=unobserved_factors_root_kernel,
- unobserved_factors_kernel=unobserved_factors_kernel,
- unobserved_factors_kernel_concentration=unobserved_factors_kernel_concentration,
- unobserved_factors_kernel_rate=unobserved_factors_kernel_rate,
- frac_dosage=frac_dosage,
- frac_overlap=frac_overlap,
- baseline_shape=baseline_shape,
- num_batches=num_batches,
+ cell_scale_mean=cell_scale_mean,
+ cell_scale_shape=cell_scale_shape,
+ gene_scale_mean=gene_scale_mean,
+ gene_scale_shape=gene_scale_shape,
+ direction_shape=direction_shape,
+ inheritance_strength=inheritance_strength,
+ n_factors=n_factors,
+ obs_weight_variance=obs_weight_variance,
+ factor_precision_shape=factor_precision_shape,
)
else:
self.node_hyperparams = self.node_hyperparams_caller()
self.reset_parameters(**self.node_hyperparams)
- def inherit_parameters(self):
- if not self.is_observed:
- # Make sure we use the right observed parameters
- self.cnvs = self.parent().cnvs
-
- def get_node_mean(self, log_baseline, unobserved_factors, noise, cnvs):
- node_mean = jnp.exp(
- log_baseline + unobserved_factors + noise + jnp.log(cnvs / 2)
- )
- sum = jnp.sum(node_mean, axis=1).reshape(self.tssb.ntssb.num_data, 1)
- node_mean = node_mean / sum
- return node_mean
-
- def init_noise_factors(self):
- # Noise
- self.variational_parameters["globals"]["cell_noise_mean"] = np.zeros(
- (self.tssb.ntssb.num_data, self.num_global_noise_factors)
- )
- self.variational_parameters["globals"]["cell_noise_log_std"] = -np.ones(
- (self.tssb.ntssb.num_data, self.num_global_noise_factors)
- )
- self.variational_parameters["globals"]["noise_factors_mean"] = np.zeros(
- (self.num_global_noise_factors, self.n_genes)
- )
- self.variational_parameters["globals"]["noise_factors_log_std"] = -np.ones(
- (self.num_global_noise_factors, self.n_genes)
- )
- self.variational_parameters["globals"]["factor_precision_log_means"] = np.log(
- self.global_noise_factors_precisions_shape
- ) * np.ones((self.num_global_noise_factors))
- self.variational_parameters["globals"]["factor_precision_log_stds"] = -np.ones(
- (self.num_global_noise_factors)
- )
-
- # Batch effects
- self.variational_parameters["globals"]["batch_effects_mean"] = np.zeros(
- (self.num_batches, self.n_genes)
- )
- self.variational_parameters["globals"]["batch_effects_log_std"] = -np.ones(
- (self.num_batches, self.n_genes)
- )
-
- def reset_variational_parameters(self, means=True, variances=True):
- if self.parent() is None:
- # Baseline: first value is 1
- if means:
- self.variational_parameters["globals"]["log_baseline_mean"] = np.array(
- normal_sample(0, 1, self.n_genes - 1)
- ) # np.zeros((self.n_genes-1,))
- if self.tssb.ntssb.data is not None:
- self.full_data = jnp.array(self.tssb.ntssb.data)
- # init_baseline = np.mean(self.tssb.ntssb.data, axis=0)
- # init_log_baseline = np.log(init_baseline / init_baseline[0])[1:]
- init_baseline = np.mean(
- self.tssb.ntssb.data
- / np.sum(self.tssb.ntssb.data, axis=1).reshape(-1, 1)
- * self.n_genes,
- axis=0,
- )
- init_baseline = init_baseline / init_baseline[0]
- init_log_baseline = np.log(init_baseline[1:] + 1e-6)
- self.variational_parameters["globals"][
- "log_baseline_mean"
- ] = np.clip(init_log_baseline, -1, 1)
- if variances:
- self.variational_parameters["globals"][
- "log_baseline_log_std"
- ] = np.array(-2*np.ones((self.n_genes - 1,)))
-
- # Overdispersion
- # self.log_od_mean = np.zeros(1)
- # self.log_od_log_std = np.zeros(1)
-
- if self.tssb.ntssb.num_data is not None:
- # Noise
- if means:
- self.variational_parameters["globals"][
- "cell_noise_mean"
- ] = np.random.normal(
- 0,
- 0.1,
- (self.tssb.ntssb.num_data, self.num_global_noise_factors),
- )
- if variances:
- self.variational_parameters["globals"][
- "cell_noise_log_std"
- ] = -np.ones(
- (self.tssb.ntssb.num_data, self.num_global_noise_factors)
- )
- if means:
- self.variational_parameters["globals"][
- "noise_factors_mean"
- ] = np.random.normal(
- 0, 0.1, (self.num_global_noise_factors, self.n_genes)
- )
- if variances:
- self.variational_parameters["globals"][
- "noise_factors_log_std"
- ] = -np.ones((self.num_global_noise_factors, self.n_genes))
- if means:
- self.variational_parameters["globals"][
- "factor_precision_log_means"
- ] = np.log(self.global_noise_factors_precisions_shape) * np.ones(
- (self.num_global_noise_factors)
- )
- if variances:
- self.variational_parameters["globals"][
- "factor_precision_log_stds"
- ] = -np.ones((self.num_global_noise_factors))
- if means:
- self.variational_parameters["globals"][
- "batch_effects_mean"
- ] = np.random.normal(0, 0.1, (self.num_batches, self.n_genes))
- if variances:
- self.variational_parameters["globals"][
- "batch_effects_log_std"
- ] = -np.ones((self.num_batches, self.n_genes))
-
- self.data_ass_logits = np.array([])
- if self.tssb.ntssb.num_data is not None:
- self.data_ass_logits = np.zeros((self.tssb.ntssb.num_data))
- try:
- if self.parent().ll.shape[0] == self.tssb.ntssb.num_data:
- self.ll = np.array(self.parent().ll)
- self.data_weights = np.array(self.parent().data_weights)
- except:
- self.ll = -1e10 * np.ones((self.tssb.ntssb.num_data,))
- self.data_weights = np.zeros((self.tssb.ntssb.num_data,))
- if self.is_observed:
- self.tssb.data_weights = np.zeros((self.tssb.ntssb.num_data,))
- self.tssb.unnormalized_data_weights = np.zeros((self.tssb.ntssb.num_data,))
-
- # Sticks
- if means:
- self.variational_parameters["locals"]["nu_log_mean"] = np.array(
- np.log(1.0 * self.tssb.dp_alpha) + np.random.normal(0, 0.01)
- )
- if variances:
- self.variational_parameters["locals"]["nu_log_std"] = np.array(
- np.log(1.0 * self.tssb.dp_alpha)
- )
- if means:
- self.variational_parameters["locals"]["psi_log_mean"] = np.array(
- np.log(1.0 * self.tssb.dp_gamma) + np.random.normal(0, 0.01)
- )
- if variances:
- self.variational_parameters["locals"]["psi_log_std"] = np.array(
- np.log(1.0 * self.tssb.dp_gamma)
- )
-
- # Unobserved factors
- if means:
- try:
- self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] = np.array(
- self.parent().variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- except AttributeError:
- self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] = np.zeros((self.n_genes,))
- if variances:
- self.variational_parameters["locals"][
- "unobserved_factors_log_std"
- ] = -2.0 * np.ones((self.n_genes,))
- if means:
- try:
- kernel_means = np.clip(
- self.unobserved_factors_kernel_concentration_caller()
- / (
- np.exp(
- (self.parent() != None)
- * (self.parent().parent() != None)
- * self.unobserved_factors_kernel_rate_caller()
- * np.abs(
- self.parent().variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- )
- ),
- self.unobserved_factors_kernel_concentration_caller() / 10,
- 1e2,
- )
- self.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = (np.log(kernel_means) - (np.exp(-4) ** 2) / 2)
- except AttributeError:
- self.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.log(
- self.unobserved_factors_kernel_concentration_caller()
- ) * np.ones(
- (self.n_genes,)
- )
- if variances:
- self.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ] = -4.0 * np.ones((self.n_genes,))
- self.set_mean(
- self.get_mean(
- baseline=np.append(1, np.exp(self.log_baseline_caller())),
- unobserved_factors=self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ],
- )
- )
+ if self.tssb is not None:
+ self.reset_variational_parameters()
+ self.sample_variational_distributions()
+ self.reset_sufficient_statistics(self.tssb.ntssb.num_batches)
+ # For adaptive optimization
+ self.reset_opt()
+
+ def reset_opt(self):
+ # For adaptive optimization
+ self.direction_states = self.initialize_direction_states()
+ self.state_states = self.initialize_state_states()
+
+ def apply_clip(self, param, minval=-jnp.inf, maxval=jnp.inf):
+ param = jnp.maximum(param, minval)
+ param = jnp.minimum(param, maxval)
+ return param
+
+ def combine_params(self):
+        return np.exp(self.params[0]) * 0.5 * self.cnvs # params is a list of two: 0 is the state, 1 is the direction
+
+ def get_mean(self):
+ return self.combine_params()
+
+ def set_mean(self, node_mean=None):
+ if node_mean is not None:
+ self.node_mean = node_mean
+ else:
+ self.node_mean = self.get_mean()
+
+ def get_observed_parameters(self):
+ return self.cnvs
+
+ def get_params(self):
+ return self.get_mean()
+
+ def get_param(self, param='mean'):
+ if param == 'observed':
+ return self.get_observed_parameters()
+ elif param == 'mean':
+ return self.get_mean()
+ else:
+ raise ValueError(f"No param available for `{param}`")
+
+ def remove_noise(self, data):
+ """
+ Noise is multiplicative in this model
+ """
+ return data * 1./self.cell_scales_caller() * 1./self.gene_scales_caller() * 1./jnp.exp(self.noise_factors_caller())
# ========= Functions to initialize node. =========
- def reset_data_parameters(self):
- self.full_data = jnp.array(self.tssb.ntssb.data)
- self.lib_sizes = np.sum(self.tssb.ntssb.data, axis=1).reshape(
- self.tssb.ntssb.num_data, 1
- )
- self.cell_global_noise_factors_weights = normal_sample(
- 0,
- self.cell_global_noise_factors_weights_scale,
- size=[self.tssb.ntssb.num_data, self.num_global_noise_factors],
- )
- self.cell_covariates = jnp.array(self.tssb.ntssb.covariates)
- self.num_batches = self.cell_covariates.shape[1]
-
- def generate_data_params(self):
- self.lib_sizes = 20.0 + np.exp(
- normal_sample(
- self.log_lib_size_mean,
- self.log_lib_size_std,
- size=self.tssb.ntssb.num_data,
- )
- ).reshape(self.tssb.ntssb.num_data, 1)
- self.lib_sizes = np.ceil(self.lib_sizes)
- self.cell_global_noise_factors_weights = normal_sample(
- 0,
- self.cell_global_noise_factors_weights_scale,
- size=[self.tssb.ntssb.num_data, self.num_global_noise_factors],
- )
- n_cells = self.tssb.ntssb.num_data
- self.cell_covariates = np.zeros((n_cells, self.num_batches))
- if self.num_batches > 1:
- batches = np.random.choice(self.num_batches, size=n_cells)
- self.cell_covariates[range(n_cells), batches] = 1
- else:
- self.cell_covariates = np.zeros((n_cells, 0))
- self.num_batches = 0
+ def set_node_hyperparams(self, **kwargs):
+ self.node_hyperparams.update(**kwargs)
def reset_parameters(
self,
- root_params=True,
- down_params=True,
- log_lib_size_mean=6,
- log_lib_size_std=0.8,
- num_global_noise_factors=4,
- global_noise_factors_precisions_shape=2.0,
- cell_global_noise_factors_weights_scale=1.0,
- unobserved_factors_root_kernel=0.1,
- unobserved_factors_kernel=1.0,
- unobserved_factors_kernel_concentration=0.01,
- unobserved_factors_kernel_rate=1.0,
- frac_dosage=1.0,
- frac_overlap=0.25,
- baseline_shape=0.1,
- num_batches=0,
+ cell_scale_mean=1e2,
+ cell_scale_shape=1.,
+ gene_scale_mean=1e2,
+ gene_scale_shape=1.,
+ direction_shape=.1,
+ inheritance_strength=1.,
+ n_factors=2,
+ obs_weight_variance=1.,
+ factor_precision_shape=2.,
):
- parent = self.parent()
-
- if parent is None: # this is the root
- self.node_hyperparams = dict(
- log_lib_size_mean=log_lib_size_mean,
- log_lib_size_std=log_lib_size_std,
- num_global_noise_factors=num_global_noise_factors,
- global_noise_factors_precisions_shape=global_noise_factors_precisions_shape,
- cell_global_noise_factors_weights_scale=cell_global_noise_factors_weights_scale,
- unobserved_factors_root_kernel=unobserved_factors_root_kernel,
- unobserved_factors_kernel=unobserved_factors_kernel,
- unobserved_factors_kernel_concentration=unobserved_factors_kernel_concentration,
- unobserved_factors_kernel_rate=unobserved_factors_kernel_rate,
- frac_dosage=frac_dosage,
- frac_overlap=frac_overlap,
- baseline_shape=baseline_shape,
- num_batches=num_batches,
- )
-
- if root_params:
- # The root is used to store global parameters: mu
- # self.log_baseline = normal_sample(0, 1., size=self.n_genes)
- # self.log_baseline[0] = 0.
- # self.baseline = np.exp(self.log_baseline)
- self.baseline_shape = baseline_shape
- self.baseline = np.random.gamma(
- self.baseline_shape, 1, size=self.n_genes
- )
- # self.baseline = np.concatenate([1, self.baseline])
- self.log_baseline = np.log(self.baseline)
-
- self.overdispersion = np.exp(normal_sample(0, 1))
- self.log_lib_size_mean = log_lib_size_mean
- self.log_lib_size_std = log_lib_size_std
-
- # Structured noise: keep all cells' noise factors at the root
- self.num_global_noise_factors = num_global_noise_factors
- self.global_noise_factors_precisions_shape = (
- global_noise_factors_precisions_shape
- )
- self.cell_global_noise_factors_weights_scale = (
- cell_global_noise_factors_weights_scale
- )
-
- self.global_noise_factors_precisions = gamma_sample(
- global_noise_factors_precisions_shape,
- 1.0,
- size=self.num_global_noise_factors,
- )
- self.global_noise_factors = normal_sample(
- 0,
- 1.0 / np.sqrt(self.global_noise_factors_precisions),
- size=[self.n_genes, self.num_global_noise_factors],
- ).T # K x P
-
- # Batch effects
- self.num_batches = num_batches
- self.batch_effects_factors = normal_sample(
- 0,
- 1.0,
- size=[self.n_genes, self.num_batches],
- ).T # K x P
-
- self.unobserved_factors_root_kernel = unobserved_factors_root_kernel
- self.unobserved_factors_kernel_concentration = (
- unobserved_factors_kernel_concentration
- )
- self.unobserved_factors_kernel_rate = unobserved_factors_kernel_rate
-
- self.frac_dosage = frac_dosage
- self.frac_overlap = frac_overlap
-
- self.inert_genes = np.random.choice(
- self.n_genes,
- size=int(self.n_genes * (1.0 - self.frac_dosage)),
- replace=False,
- )
-
- cnv_genes = np.arange(
- self.n_genes,
- )
- if self.tssb is not None:
- cnv_genes = self.tssb.ntssb.input_tree.get_affected_genes()
- self.unavailable_genes = np.sort(
- np.random.choice(
- cnv_genes,
- size=int(len(cnv_genes) * (1 - self.frac_overlap)),
- replace=False,
- )
- )
-
- # Root should not have unobserved factors
- self.unobserved_factors_kernel = 0 * np.array(
- [self.unobserved_factors_root_kernel] * self.n_genes
- )
- self.unobserved_factors = 0 * normal_sample(
- 0.0, self.unobserved_factors_kernel
- )
-
- self.set_mean()
+ self.node_hyperparams = dict(
+ cell_scale_mean=cell_scale_mean,
+ cell_scale_shape=cell_scale_shape,
+ gene_scale_mean=gene_scale_mean,
+ gene_scale_shape=gene_scale_shape,
+ direction_shape=direction_shape,
+ inheritance_strength=inheritance_strength,
+ n_factors=n_factors,
+ obs_weight_variance=obs_weight_variance,
+ factor_precision_shape=factor_precision_shape,
+ )
+ parent = self.parent()
+
+ if parent is None:
self.depth = 0.0
+ self.params = [np.zeros((self.n_genes,)), np.ones((self.n_genes,))] # state and direction
+
+ # Gene scales
+ rng = np.random.default_rng(seed=self.seed)
+ self.gene_scales = rng.gamma(self.node_hyperparams['gene_scale_shape'], self.node_hyperparams['gene_scale_mean']/self.node_hyperparams['gene_scale_shape'], size=(1, self.n_genes))
+
+ # Structured noise
+ n_factors = self.node_hyperparams['n_factors']
+ factor_precision_shape = self.node_hyperparams['factor_precision_shape']
+
+ self.factor_precisions = rng.gamma(factor_precision_shape, 1., size=(n_factors,1))
+ factor_scales = np.sqrt(1./self.factor_precisions)
+ self.factor_weights = rng.normal(0., factor_scales, size=(n_factors, self.n_genes)) * 1./np.sqrt(factor_precision_shape)
+
+ if n_factors > 0:
+                n_genes_per_factor = int(self.n_genes/n_factors)
+ offset = np.sqrt(factor_precision_shape)
+ perm = np.random.permutation(self.n_genes)
+ for factor in range(n_factors):
+ gene_idx = perm[factor*n_genes_per_factor:(factor+1)*n_genes_per_factor]
+ self.factor_weights[factor,gene_idx] *= offset
+
+ # Set data-dependent parameters
+ if self.tssb is not None:
+ num_data = self.tssb.ntssb.num_data
+ if num_data is not None:
+ self.reset_data_parameters()
else: # Non-root node: inherits everything from upstream node
- self.node_hyperparams = self.node_hyperparams_caller()
- if down_params:
- self.unobserved_factors_kernel = gamma_sample(
- self.unobserved_factors_kernel_concentration_caller(),
- np.exp(np.abs(parent.unobserved_factors)),
- )
- unavailable_genes = self.unavailable_genes_caller()
- if len(unavailable_genes) > 0:
- self.unobserved_factors_kernel[unavailable_genes] = (
- np.min(self.unobserved_factors_kernel) / 2.0
- )
-
- # Make sure some genes are affected in unobserved nodes
- if not self.is_observed:
- top_genes = np.argsort(self.unobserved_factors_kernel)[::-1][:5]
- self.unobserved_factors_kernel[top_genes] = np.max(
- [5.0, np.max(self.unobserved_factors_kernel)]
- )
- self.unobserved_factors = normal_sample(
- parent.unobserved_factors, self.unobserved_factors_kernel
- )
- self.unobserved_factors = np.clip(
- self.unobserved_factors, -MAX_XI, MAX_XI
- )
- if not self.is_observed:
- self.unobserved_factors[
- top_genes
- ] = MAX_XI # just force an amplification
-
- # Observation mean
- self.set_mean()
-
self.depth = parent.depth + 1
+ rng = np.random.default_rng(seed=self.seed)
+ sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'],
+ jnp.exp(-self.node_hyperparams['inheritance_strength']*jnp.abs(parent.params[0])))
+ sampled_state = jnp.maximum(rng.normal(parent.params[0], sampled_direction), -1)
+ sampled_state = jnp.minimum(sampled_state, 1)
+ self.params = [sampled_state, sampled_direction]
+
+ self.set_mean()
+
+ # Generate cell sizes and structure on the factors
+ def reset_data_parameters(self):
+ num_data = self.tssb.ntssb.num_data
+ rng = np.random.default_rng(seed=self.seed)
+ self.cell_scales = rng.gamma(self.node_hyperparams['cell_scale_shape'], self.node_hyperparams['cell_scale_mean']/self.node_hyperparams['cell_scale_shape'], size=(num_data, 1))
+
+ n_factors = self.node_hyperparams['n_factors']
+ self.obs_weights = rng.normal(0., 1., size=(num_data, n_factors)) * 1./3.
+ if n_factors > 0:
+ n_obs_per_factor = int(num_data/n_factors)
+ offset = 6.
+ perm = np.random.permutation(num_data)
+ for factor in range(n_factors):
+ obs_idx = perm[factor*n_obs_per_factor:(factor+1)*n_obs_per_factor]
+ self.obs_weights[obs_idx,factor] *= offset
+
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+
+ def reset_data_variational_parameters(self):
+ if self.parent() is None and self.tssb.parent() is None:
+ num_data = self.tssb.ntssb.num_data
+
+ # Set priors
+ lib_sizes = np.sum(self.tssb.ntssb.data, axis=1)
+ self.lib_ratio = np.ones((self.tssb.ntssb.data.shape[0], 1))
+ self.lib_ratio *= np.mean(lib_sizes) / np.var(lib_sizes)
+ gene_sizes = np.sum(self.tssb.ntssb.data, axis=0)
+ self.gene_ratio = np.mean(gene_sizes) / np.var(gene_sizes)
+
+ cell_scales_alpha_init = self.node_hyperparams['cell_scale_shape'] * jnp.ones((num_data,1))
+ cell_scales_beta_init = self.node_hyperparams['cell_scale_shape'] * jnp.ones((num_data,1))
+ gene_scales_alpha_init = self.node_hyperparams['gene_scale_shape'] * jnp.ones((self.n_genes,))
+ gene_scales_beta_init = self.node_hyperparams['gene_scale_shape'] * jnp.ones((self.n_genes,))
+
+ rng = np.random.default_rng(self.seed)
+ # root stores global parameters
+ n_factors = self.node_hyperparams['n_factors']
+ factor_precision_shape = self.node_hyperparams['factor_precision_shape']
+ self.variational_parameters["global"] = {
+ 'gene_scales': {'log_alpha': jnp.log(gene_scales_alpha_init),
+ 'log_beta': jnp.log(gene_scales_beta_init)},
+ 'factor_precisions': {'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))),
+ 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1)))},
+ 'factor_weights': {'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))),
+ 'log_std': -2. + jnp.zeros((n_factors, self.n_genes))}
+ }
+ rng = np.random.default_rng(self.seed+1)
+ self.variational_parameters["local"] = {
+ 'cell_scales': {'log_alpha': jnp.log(cell_scales_alpha_init),
+ 'log_beta': jnp.log(cell_scales_beta_init)},
+ 'obs_weights': {'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))),
+ 'log_std': -2. + jnp.zeros((num_data, n_factors))}
+ }
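+        # Note: for a Gamma(alpha, beta) variational factor the mean is alpha/beta, so
+        # exp(log_alpha - log_beta) below gives the variational means of the scales.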
+ self.cell_scales = jnp.exp(self.variational_parameters["local"]["cell_scales"]["log_alpha"]-self.variational_parameters["local"]["cell_scales"]["log_beta"])
+ self.gene_scales = jnp.exp(self.variational_parameters["global"]["gene_scales"]["log_alpha"]-self.variational_parameters["global"]["gene_scales"]["log_beta"])
+ self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"]
+        self.factor_precisions = jnp.exp(self.variational_parameters["global"]["factor_precisions"]["log_alpha"]-self.variational_parameters["global"]["factor_precisions"]["log_beta"])
+ self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"]
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+
+ def reset_variational_parameters(self):
+ # Assignments
+ num_data = self.tssb.ntssb.num_data
+ if num_data is not None:
+ self.variational_parameters['q_z'] = jnp.ones(num_data,)
+
+ self.variational_parameters['sum_E_log_1_nu'] = 0.
+ self.variational_parameters['E_log_phi'] = 0.
- def set_mean(self, node_mean=None, variational=False):
- if node_mean is not None:
- self.node_mean = node_mean
- else:
- if variational:
- self.node_mean = (
- np.append(1, np.exp(self.log_baseline_caller()))
- * self.cnvs
- / 2
- * np.exp(
- (self.parent() is not None)
- * self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- )
- else:
- self.node_mean = (
- self.baseline_caller()
- * self.cnvs
- / 2
- * np.exp(self.unobserved_factors)
- )
- self.node_mean = self.node_mean / np.sum(self.node_mean)
-
- def get_mean(
- self,
- baseline=None,
- unobserved_factors=None,
- noise=0.0,
- batch_effects=0.0,
- cell_factors=None,
- global_factors=None,
- cnvs=None,
- norm=True,
- inert_genes=None,
- ):
- baseline = (
- np.append(1, np.exp(self.log_baseline_caller()))
- if baseline is None
- else baseline
- )
- unobserved_factors = (
- self.variational_parameters["locals"]["unobserved_factors_mean"]
- if unobserved_factors is None
- else unobserved_factors
- )
- unobserved_factors *= self.parent() is not None
- cnvs = self.cnvs if cnvs is None else cnvs
- if inert_genes is not None:
- cnvs = np.array(cnvs)
- inert_genes = np.array(inert_genes)
- zero_genes = np.where(cnvs == MIN_CNV)[0]
- cnvs[
- inert_genes
- ] = 2.0 # set these genes to 2, i.e., act like they have no CNV
- cnvs[zero_genes] = MIN_CNV # (except if they are zero)
- node_mean = (
- baseline * cnvs / 2 * np.exp(unobserved_factors + noise + batch_effects)
- )
- if norm:
- if len(node_mean.shape) == 1:
- sum = np.sum(node_mean)
- else:
- sum = np.sum(node_mean, axis=1)
- if len(sum.shape) > 0:
- sum = sum.reshape(-1, 1)
- node_mean = node_mean / sum
- return node_mean
-
- def get_noise(self, variational=False):
- return self.cell_global_noise_factors_weights_caller(
- variational=variational
- ).dot(self.global_noise_factors_caller(variational=variational))
-
- def get_batch_effects(self, variational=False):
- return self.cell_covariates_caller().dot(
- self.batch_effects_factors_caller(variational=variational)
- )
-
- # ========= Functions to take samples from node. =========
- def sample_observation(self, n):
- noise = self.cell_global_noise_factors_weights_caller()[n].dot(
- self.global_noise_factors_caller()
- )
- batch_effects = self.cell_global_noise_factors_weights_caller()[n].dot(
- self.global_noise_factors_caller()
- )
- node_mean = self.get_mean(
- unobserved_factors=self.unobserved_factors,
- baseline=self.baseline_caller(),
- noise=noise,
- batch_effects=batch_effects,
- inert_genes=self.inert_genes_caller(),
- )
- s = multinomial_sample(self.lib_sizes_caller()[n], node_mean)
- # s = negative_binomial_sample(self.lib_sizes_caller()[n] * node_mean, 0.01)
- return s
-
- # ========= Functions to evaluate node's parameters. =========
- def log_baseline_logprior(self, x=None):
- if x is None:
- x = np.log(self.baseline)
- return normal_lpdf(x, 0, 1)
-
- def log_overdispersion_logprior(self, x=None):
- if x is None:
- x = np.log(self.overdispersion)
- return normal_lpdf(x, 0, 1)
-
- def global_noise_factors_logprior(self):
- return normal_lpdf(
- self.global_noise_factors,
- 0.0,
- 1.0 / np.sqrt(self.global_noise_factors_precisions),
- )
-
- def cell_global_noise_factors_logprior(self):
- return normal_lpdf(
- self.cell_global_noise_factors_weights,
- 0.0,
- self.cell_global_noise_factors_weights_scale,
- )
+ # Sticks
+ self.variational_parameters["delta_1"] = 1.
+ self.variational_parameters["delta_2"] = 1.
+ self.variational_parameters["sigma_1"] = 1.
+ self.variational_parameters["sigma_2"] = 1.
- def unobserved_factors_logprior(self, x=None):
- if x is None:
- x = self.unobserved_factors
+ # Pivots
+ self.variational_parameters["q_rho"] = np.ones(len(self.tssb.children_root_nodes),)
- if self.parent() is not None:
- if self.is_observed:
- llp = normal_lpdf(
- x,
- self.parent().unobserved_factors,
- self.unobserved_factors_root_kernel,
- )
+ parent = self.parent()
+ if parent is None and self.tssb.parent() is None:
+ self.params = [jnp.zeros((self.n_genes,)), jnp.ones((self.n_genes,))]
+
+ if num_data is not None:
+ self.reset_data_variational_parameters()
else:
- llp = normal_lpdf(
- x, self.parent().unobserved_factors, self.unobserved_factors_kernel
- )
- else:
- llp = normal_lpdf(x, 0.0, self.unobserved_factors_root_kernel)
-
- return llp
-
- def logprior(self):
- # Prior of current node
- llp = self.unobserved_factors_logprior()
- if self.parent() is None:
- llp = (
- llp
- + self.global_noise_factors_logprior()
- + self.cell_global_noise_factors_logprior()
- )
- llp = llp + self.log_baseline_logprior()
-
- return llp
-
- def loglh(self, n, variational=False, axis=None):
- noise = self.get_noise(variational=variational)[n]
- batch_effects = self.get_batch_effects(variational=variational)[n]
- unobs_factors = self.unobserved_factors
- baseline = self.baseline_caller()
- if variational:
- unobs_factors = self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- baseline = np.append(1, np.exp(self.log_baseline_caller(variational=True)))
- node_mean = self.get_mean(
- baseline=baseline,
- unobserved_factors=unobs_factors,
- noise=noise,
- batch_effects=batch_effects,
- )
- lib_sizes = self.lib_sizes_caller()[n]
- return partial(jit, static_argnums=2)(poisson_lpmf)(
- jnp.array(self.tssb.ntssb.data[n]), lib_sizes * node_mean, axis=axis
- )
-
- def complete_loglh(self):
- return self.loglh(list(self.data))
-
- def logprob(self, n):
- # Add prior
- l = self.loglh(n) + self.logprior()
-
- # Add prob of children nodes given current node's parameters
- for child in self.children():
- l = l + child.unobserved_factors_logprior()
-
- return l
-
- # Sum over all data points attached to this node
- def complete_logprob(self):
- return self.logprob(list(self.data))
-
- # Sum over all data
- def full_logprob(self):
- return self.logprob(list(range(len(self.tssb.ntssb.data))))
-
- # ========= Functions to acess root's parameters. =========
- def node_hyperparams_caller(self):
- if self.parent() is None:
- return self.node_hyperparams
- else:
- return self.parent().node_hyperparams_caller()
-
- def unavailable_genes_caller(self):
- if self.parent() is None:
- return self.unavailable_genes
- else:
- return self.parent().unavailable_genes_caller()
-
- def global_noise_factors_precisions_shape_caller(self):
- if self.parent() is None:
- return self.global_noise_factors_precisions_shape
- else:
- return self.parent().global_noise_factors_precisions_shape_caller()
-
- def cell_covariates_caller(self):
- if self.parent() is None:
- return self.cell_covariates
- else:
- return self.parent().cell_covariates_caller()
-
- def lib_sizes_caller(self):
- if self.parent() is None:
- return self.lib_sizes
- else:
- return self.parent().lib_sizes_caller()
-
- def log_kernel_mean_caller(self):
- if self.parent() is None:
- return self.logkernel_means
- else:
- return self.parent().log_kernel_mean_caller()
-
- def kernel_caller(self):
- if self.parent() is None:
- return self.kernel
- else:
- return self.parent().kernel_caller()
-
- def node_std_caller(self):
- if self.parent() is None:
- return self.node_std
- else:
- return self.parent().node_std_caller()
-
- def inert_genes_caller(self):
- if self.parent() is None:
- return self.inert_genes
- else:
- return self.parent().inert_genes_caller()
-
- def log_baseline_caller(self, variational=True):
- if self.parent() is None:
- if variational:
- return self.variational_parameters["globals"]["log_baseline_mean"]
+ rng = np.random.default_rng(self.seed)
+ # root stores global parameters
+ n_factors = self.node_hyperparams['n_factors']
+ factor_precision_shape = self.node_hyperparams['factor_precision_shape']
+ self.variational_parameters["global"] = {
+ 'gene_scales': {'log_alpha': jnp.ones((self.n_genes,)),
+ 'log_beta': jnp.ones((self.n_genes,))},
+ 'factor_precisions': {'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))),
+ 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1)))},
+ 'factor_weights': {'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))),
+ 'log_std': -2. + jnp.zeros((n_factors, self.n_genes))}
+ }
+ else: # only the non-root nodes have kernel variational parameters
+ # Kernel
+ parent_param = jnp.zeros((self.n_genes,))
+ if parent is not None:
+ parent_param = parent.params[0]
+
+ rng = np.random.default_rng(self.seed+2)
+ sampled_direction = rng.gamma(self.node_hyperparams['direction_shape'],
+ jnp.exp(-self.node_hyperparams['inheritance_strength'] * jnp.abs(parent_param)))
+ rng = np.random.default_rng(self.seed+3)
+ if np.all(parent_param == 0):
+                sampled_state = rng.normal(parent_param*0.1, 0.01) # parent is a root (state 0), so initialize close to it to avoid disturbing the main node attachments
else:
- return self.log_baseline
+ sampled_state = jnp.clip(rng.normal(parent_param*0.1, sampled_direction), a_min=-1, a_max=1) # to explore (without numerical explosions)
+
+ init_concentration = 10.
+ self.variational_parameters["kernel"] = {
+ 'direction': {'log_alpha': jnp.log(init_concentration*jnp.ones((self.n_genes,))), 'log_beta': jnp.log(init_concentration/self.node_hyperparams['direction_shape'] * jnp.ones((self.n_genes,)))},
+ 'state': {'mean': jnp.array(sampled_state), 'log_std': jnp.array(rng.normal(-2., 0.1, size=self.n_genes))}
+ }
+ self.params = [self.variational_parameters["kernel"]["state"]["mean"],
+ jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])]
+
+ def reset_variational_noise_factors(self):
+ rng = np.random.default_rng(self.seed)
+ n_factors = self.node_hyperparams['n_factors']
+ factor_precision_shape = self.node_hyperparams['factor_precision_shape']
+ self.variational_parameters["global"]["factor_precisions"] = {
+ 'log_alpha': jnp.log(10. * jnp.ones((n_factors,1))),
+ 'log_beta' : jnp.log(10./factor_precision_shape * jnp.ones((n_factors,1)))
+ }
+ self.variational_parameters["global"]["factor_weights"] = {
+ 'mean': jnp.array(0.01*rng.normal(size=(n_factors, self.n_genes))),
+ 'log_std': -2. + jnp.zeros((n_factors, self.n_genes))
+ }
+ num_data = self.tssb.ntssb.num_data
+ self.variational_parameters["local"]["obs_weights"] = {
+ 'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))),
+ 'log_std': -2. + jnp.zeros((num_data, n_factors))
+ }
+
+ def set_learned_parameters(self):
+ if self.parent() is None and self.tssb.parent() is None:
+ self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"]
+ self.factor_precisions = jnp.exp(self.variational_parameters["global"]["factor_precisions"]["log_alpha"]
+ -self.variational_parameters["global"]["factor_precisions"]["log_beta"])
+ self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"]
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+ self.cell_scales = jnp.exp(self.variational_parameters["local"]["cell_scales"]["log_alpha"]
+ -self.variational_parameters["local"]["cell_scales"]["log_beta"])
+ self.gene_scales = jnp.exp(self.variational_parameters["global"]["gene_scales"]["log_alpha"]
+ -self.variational_parameters["global"]["gene_scales"]["log_beta"])
else:
- return self.parent().log_baseline_caller(variational=variational)
-
- def baseline_caller(self):
- if self.parent() is None:
- return self.baseline
+ self.params = [self.variational_parameters["kernel"]["state"]["mean"],
+ jnp.exp(self.variational_parameters["kernel"]["direction"]["log_alpha"]-self.variational_parameters["kernel"]["direction"]["log_beta"])]
+
+ def reset_sufficient_statistics(self, num_batches=1):
+ self.suff_stats = {
+ 'ent': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(c_n = this tree) q(z_n = this node) * log q(z_n = this node)
+ 'mass': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node)
+            'A': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * E[log \gamma_n] * \sum_g x_ng
+ 'B_g': {'total': 0, 'batch': np.zeros((num_batches,self.n_genes))}, # \sum_n q(z_n = this node) * x_ng
+ 'C': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g]
+            'D_g': {'total': 0, 'batch': np.zeros((num_batches,self.n_genes))}, # \sum_n q(z_n = this node) * E[\gamma_n] * E[exp(s_nW)_g]
+            'E': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g lgamma(x_ng+1)
+ }
+ if self.parent() is None and self.tssb.parent() is None:
+ self.local_suff_stats = {
+ 'locals_kl': {'total': 0., 'batch': np.zeros((num_batches,))},
+ }
+
+ def init_new_node_kernel(self, **kwargs):
+        # Get data for which prob of assigning to parent is > 1/sqrt(n_total_nodes)
+ weights = self.parent().variational_parameters['q_z'] * self.tssb.variational_parameters['q_c']
+ idx = np.where(weights > 1./np.sqrt(self.tssb.ntssb.n_total_nodes))[0]
+ # Initialize prioritizing cells with lowest ll in parent
+ if len(idx) > 0 and 'll' in self.parent().variational_parameters: # only proceed in this manner if parent has already been evaluated
+ thres = np.quantile(self.parent().variational_parameters['ll'][idx], q=.1)
+ idx = idx[np.where(self.parent().variational_parameters['ll'][idx] < thres)[0]]
+ self.reset_variational_state(idx=idx, **kwargs)
+ # Sample
+ if self.parent().samples is not None:
+ n_samples = self.parent().samples[0].shape[0]
+ self.sample_kernel(n_samples=n_samples)
+
+ def init_kernel(self, **kwargs):
+        # Get data for which prob of assigning here is > 1/sqrt(n_total_nodes)
+ weights = self.variational_parameters['q_z'] * self.tssb.variational_parameters['q_c']
+ idx = np.where(weights > 1./np.sqrt(self.tssb.ntssb.n_total_nodes))[0]
+ # Initialize prioritizing cells with highest ll here
+ thres = np.quantile(self.variational_parameters['ll'][idx], q=.9)
+ idx = idx[np.where(self.variational_parameters['ll'][idx] >= thres)[0]]
+ self.reset_variational_state(idx=idx, **kwargs)
+
+ def reset_variational_state(self, log_std=-2, idx=None, weights=None):
+ if self.parent() is None and self.tssb.parent() is None:
+ return
else:
- return self.parent().baseline_caller()
-
- def log_overdispersion_caller(self):
- if self.parent() is None:
- return self.log_od_mean
+ if idx is None:
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+ if weights is None:
+ weights = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx]
+ cell_scales_mean = jnp.mean(self.get_cell_scales_sample()[:,idx],axis=0)
+ gene_scales_mean = jnp.mean(self.get_gene_scales_sample(),axis=0)
+ noise_factors = jnp.mean(self.get_noise_sample(idx),axis=0)
+ cnvs_contrib = self.cnvs/2
+ init_state = jnp.log(jnp.sum(self.tssb.ntssb.data[idx]/(cell_scales_mean*gene_scales_mean*cnvs_contrib*jnp.exp(noise_factors)) * weights[:,None],axis=0)/jnp.sum(weights[:,None]))
+ self.variational_parameters['kernel']['state']['mean'] = init_state
+ self.variational_parameters['kernel']['state']['log_std'] = log_std * jnp.ones((self.n_genes,))
+
+ def merge_suff_stats(self, suff_stats):
+ for stat in self.suff_stats:
+ self.suff_stats[stat]['total'] += suff_stats[stat]['total']
+ self.suff_stats[stat]['batch'] += suff_stats[stat]['batch']
+
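+    # update_sufficient_statistics (below) maintains the running totals incrementally: each statistic
+    # caches its per-batch contribution, so refreshing a minibatch only subtracts that batch's old
+    # value from 'total' and adds the new one.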
+ def update_sufficient_statistics(self, batch_idx=None):
+ if batch_idx is not None:
+ idx = self.tssb.ntssb.batch_indices[batch_idx]
else:
- return self.parent().log_overdispersion_caller()
-
- def overdispersion_caller(self):
- if self.parent() is None:
- return self.overdispersion
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+
+ if self.parent() is None and self.tssb.parent() is None:
+ locals_kl = self.compute_local_priors(idx) + self.compute_local_entropies(idx)
+ if batch_idx is not None:
+ self.local_suff_stats['locals_kl']['total'] -= self.local_suff_stats['locals_kl']['batch'][batch_idx]
+ self.local_suff_stats['locals_kl']['batch'][batch_idx] = locals_kl
+ self.local_suff_stats['locals_kl']['total'] += self.local_suff_stats['locals_kl']['batch'][batch_idx]
+ else:
+ self.local_suff_stats['locals_kl']['total'] = locals_kl
+
+ ent = assignment_entropies(self.variational_parameters['q_z'][idx])
+ ent *= self.tssb.variational_parameters['q_c'][idx]
+ E_ass = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx]
+ E_loggamma = jnp.mean(jnp.log(self.get_cell_scales_sample()[:,idx]),axis=0)
+ E_gamma = jnp.mean(self.get_cell_scales_sample()[:,idx],axis=0)
+ E_sw = jnp.mean(self.get_noise_sample(idx),axis=0)
+ E_expsw = jnp.mean(jnp.exp(self.get_noise_sample(idx)),axis=0)
+ x = self.tssb.ntssb.data[idx]
+
+ new_ent = jnp.sum(ent)
+ new_mass = jnp.sum(E_ass)
+ new_A = jnp.sum(E_ass * E_loggamma.ravel() * jnp.sum(x, axis=1))
+ new_B = jnp.sum(E_ass[:,None] * x, axis=0)
+ new_C = jnp.sum(E_ass * jnp.sum(x * E_sw, axis=1))
+ new_D = jnp.sum(E_ass[:,None] * E_gamma * E_expsw, axis=0)
+ new_E = jnp.sum(E_ass * jnp.sum(gammaln(x+1), axis=1))
+
+ if batch_idx is not None:
+ self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx]
+ self.suff_stats['ent']['batch'][batch_idx] = new_ent
+ self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx]
+
+ self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx]
+ self.suff_stats['mass']['batch'][batch_idx] = new_mass
+ self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx]
+
+ self.suff_stats['A']['total'] -= self.suff_stats['A']['batch'][batch_idx]
+ self.suff_stats['A']['batch'][batch_idx] = new_A
+ self.suff_stats['A']['total'] += self.suff_stats['A']['batch'][batch_idx]
+
+ self.suff_stats['B_g']['total'] -= self.suff_stats['B_g']['batch'][batch_idx]
+ self.suff_stats['B_g']['batch'][batch_idx] = new_B
+ self.suff_stats['B_g']['total'] += self.suff_stats['B_g']['batch'][batch_idx]
+
+ self.suff_stats['C']['total'] -= self.suff_stats['C']['batch'][batch_idx]
+ self.suff_stats['C']['batch'][batch_idx] = new_C
+ self.suff_stats['C']['total'] += self.suff_stats['C']['batch'][batch_idx]
+
+ self.suff_stats['D_g']['total'] -= self.suff_stats['D_g']['batch'][batch_idx]
+ self.suff_stats['D_g']['batch'][batch_idx] = new_D
+ self.suff_stats['D_g']['total'] += self.suff_stats['D_g']['batch'][batch_idx]
+
+ self.suff_stats['E']['total'] -= self.suff_stats['E']['batch'][batch_idx]
+ self.suff_stats['E']['batch'][batch_idx] = new_E
+ self.suff_stats['E']['total'] += self.suff_stats['E']['batch'][batch_idx]
else:
- return self.parent().overdispersion_caller()
+ self.suff_stats['ent']['total'] = new_ent
+ self.suff_stats['mass']['total'] = new_mass
+ self.suff_stats['A']['total'] = new_A
+ self.suff_stats['B_g']['total'] = new_B
+ self.suff_stats['C']['total'] = new_C
+ self.suff_stats['D_g']['total'] = new_D
+ self.suff_stats['E']['total'] = new_E
- def unobserved_factors_damp_caller(self):
- if self.parent() is None:
- return self.unobserved_factors_damp
- else:
- return self.parent().unobserved_factors_damp_caller()
+ # ========= Functions to take samples from node. =========
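+    # Generative model for a single cell n at this node (as implemented below):
+    #   x_ng ~ Poisson(cell_scale_n * gene_scale_g * (cnv_g / 2) * exp(state_g + noise_ng)),
+    # where noise_ng = (obs_weights @ factor_weights)_ng.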
+ def sample_observation(self, n):
+ state = self.params[0]
+ cnvs = self.cnvs
+ noise_factors = self.noise_factors_caller()[n]
+ cell_scales = self.cell_scales_caller()[n]
+ gene_scales = self.gene_scales_caller()
+ rng = np.random.default_rng(seed=self.seed+n)
+ s = rng.poisson(cell_scales*gene_scales*cnvs/2*jnp.exp(state + noise_factors))
+ return s
- def unobserved_factors_kernel_concentration_caller(self):
- if self.parent() is None:
- return self.unobserved_factors_kernel_concentration
- else:
- return self.parent().unobserved_factors_kernel_concentration_caller()
+ def sample_observations(self):
+ n_obs = len(self.data)
+ state = self.params[0]
+ cnvs = self.cnvs
+ noise_factors = self.noise_factors_caller()[np.array(list(self.data))]
+ cell_scales = self.cell_scales_caller()[np.array(list(self.data))]
+ gene_scales = self.gene_scales_caller()
+ rng = np.random.default_rng(seed=self.seed)
+ s = rng.poisson(cell_scales*gene_scales*cnvs/2*jnp.exp(state + noise_factors), size=[n_obs, self.n_genes])
+ return s
- def unobserved_factors_kernel_rate_caller(self):
+ # ========= Functions to access root's parameters. =========
+ def node_hyperparams_caller(self):
if self.parent() is None:
- return self.unobserved_factors_kernel_rate
+ return self.node_hyperparams
else:
- return self.parent().unobserved_factors_kernel_rate_caller()
+ return self.parent().node_hyperparams_caller()
- def unobserved_factors_root_kernel_caller(self):
- if self.parent() is None:
- return self.unobserved_factors_root_kernel
+ def noise_factors_caller(self):
+ return self.tssb.ntssb.root['node'].root['node'].noise_factors
+
+ def cell_scales_caller(self):
+ return self.tssb.ntssb.root['node'].root['node'].cell_scales
+
+ def get_cell_scales_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].cell_scales_sample
+
+ def gene_scales_caller(self):
+ return self.tssb.ntssb.root['node'].root['node'].gene_scales
+
+ def get_gene_scales_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].gene_scales_sample
+
+ def get_obs_weights_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].obs_weights_sample
+
+ def set_local_sample(self, sample, idx=None):
+ """
+ obs_weights, cell_scales
+ """
+ if idx is None:
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+ self.tssb.ntssb.root['node'].root['node'].obs_weights_sample = self.tssb.ntssb.root['node'].root['node'].obs_weights_sample.at[:,idx].set(sample[0])
+ self.tssb.ntssb.root['node'].root['node'].cell_scales_sample = self.tssb.ntssb.root['node'].root['node'].cell_scales_sample.at[:,idx].set(sample[1])
+
+ def get_factor_weights_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].factor_weights_sample
+
+ def set_global_sample(self, sample):
+ """
+ factor_weights, gene_scales
+ """
+ self.tssb.ntssb.root['node'].root['node'].factor_weights_sample = jnp.array(sample[0])
+ self.tssb.ntssb.root['node'].root['node'].gene_scales_sample = jnp.array(sample[1])
+
+ def get_noise_sample(self, idx):
+ obs_weights = self.get_obs_weights_sample()[:,idx]
+ factor_weights = self.get_factor_weights_sample()
+ return jax.vmap(sample_prod, in_axes=(0,0))(obs_weights,factor_weights)
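+    # Note: sample_prod (pulled in via the star imports) is assumed to compute the per-sample
+    # product obs_weights @ factor_weights, i.e. one noise matrix per Monte Carlo sample.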
+
+ def get_direction_sample(self):
+ return self.samples[1]
+
+ def get_state_sample(self):
+ return self.samples[0]
+
+ # ======== Functions using the variational parameters. =========
+ def compute_loglikelihood(self, idx):
+ # Use stored samples for loc
+ node_mean_samples = self.samples[0]
+ cnvs = self.cnvs
+ obs_weights_samples = self.get_obs_weights_sample()[:,idx]
+ factor_weights_samples = self.get_factor_weights_sample()
+ cell_scales_samples = self.get_cell_scales_sample()[:,idx]
+ gene_scales_samples = self.get_gene_scales_sample()
+ # Average over samples for each observation
+ ll = jnp.mean(jax.vmap(_mc_obs_ll, in_axes=[None,0,None,0,0,0,0])(self.tssb.ntssb.data[idx],
+ node_mean_samples,
+ cnvs,
+ obs_weights_samples,
+ factor_weights_samples,
+ cell_scales_samples,
+ gene_scales_samples,
+ ), axis=0) # mean over MC samples
+ return ll
+
+ def compute_loglikelihood_suff(self):
+ state_samples = self.samples[0]
+ gene_scales_samples = self.get_gene_scales_sample()
+ cnv = self.cnvs
+ ll = jnp.mean(jax.vmap(ll_suffstats, in_axes=[0, None, 0, None, None, None, None, None])
+ (state_samples, cnv, gene_scales_samples, self.suff_stats['A']['total'], self.suff_stats['B_g']['total'],
+ self.suff_stats['C']['total'], self.suff_stats['D_g']['total'], self.suff_stats['E']['total'])
+ )
+ return ll
+
+ def sample_variational_distributions(self, n_samples=10):
+ if self.parent() is not None:
+ if self.parent().samples is not None:
+ n_samples = self.parent().samples[0].shape[0]
+ if self.parent() is None and self.tssb.parent() is None:
+ self.sample_locals(n_samples=n_samples, store=True)
+ self.sample_globals(n_samples=n_samples, store=True)
+ self.sample_kernel(n_samples=n_samples, store=True)
+
+ def sample_locals(self, n_samples, store=True):
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.local_sample_and_grad(jnp.arange(self.tssb.ntssb.num_data), key, n_samples=n_samples)
+ sample, _ = sample_grad
+ obs_weights_sample, cell_scales_sample = sample
+ if store:
+ self.obs_weights_sample = obs_weights_sample
+ self.cell_scales_sample = cell_scales_sample
else:
- return self.parent().unobserved_factors_root_kernel_caller()
-
- def unobserved_factors_kernel_caller(self):
- if self.parent() is None:
- return self.unobserved_factors_kernel
+ return obs_weights_sample, cell_scales_sample
+
+ def sample_globals(self, n_samples, store=True):
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.global_sample_and_grad(key, n_samples=n_samples)
+ sample, _ = sample_grad
+ factor_weights_sample, gene_scales_sample, factor_precisions_sample = sample
+ if store:
+ self.factor_weights_sample = factor_weights_sample
+ self.gene_scales_sample = gene_scales_sample
+ self.factor_precisions_sample = factor_precisions_sample
else:
- return self.parent().unobserved_factors_kernel_caller()
-
- def cell_global_noise_factors_weights_caller(self, variational=False):
- if self.parent() is None:
- if variational:
- return self.variational_parameters["globals"]["cell_noise_mean"]
- else:
- return self.cell_global_noise_factors_weights
+ return factor_weights_sample, gene_scales_sample, factor_precisions_sample
+
+ def sample_kernel(self, n_samples=10, store=True):
+ if self.parent() is None and self.tssb.parent() is None:
+ return self._sample_root_kernel(n_samples=n_samples, store=store)
+
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.direction_sample_and_grad(key, n_samples=n_samples)
+        sampled_direction, _ = sample_grad
+
+        key, sample_grad = self.state_sample_and_grad(key, n_samples=n_samples)
+        sampled_state, _ = sample_grad
+
+        samples = [sampled_state, sampled_direction]
+ if store:
+ self.samples = samples
else:
- return self.parent().cell_global_noise_factors_weights_caller(
- variational=variational
- )
-
- def global_noise_factors_caller(self, variational=False):
- if self.parent() is None:
- if variational:
- return self.variational_parameters["globals"]["noise_factors_mean"]
- else:
- return self.global_noise_factors
- else:
- return self.parent().global_noise_factors_caller(variational=variational)
-
- def batch_effects_factors_caller(self, variational=False):
- if self.parent() is None:
- if variational:
- return self.variational_parameters["globals"]["batch_effects_mean"]
- else:
- return self.batch_effects_factors
+ return samples
+
+ def _sample_root_kernel(self, n_samples=10, store=True):
+ # In this model the complete tree's root parameters are fixed and not learned, so just store n_samples copies of them to mimic a sample
+ sampled_direction = jnp.vstack(jnp.repeat(jnp.array([self.params[1]]), n_samples, axis=0))
+ sampled_state = jnp.vstack(jnp.repeat(jnp.array([self.params[0]]), n_samples, axis=0))
+ samples = [sampled_state, sampled_direction]
+ if store:
+ self.samples = samples
else:
- return self.parent().batch_effects_factors_caller(variational=variational)
-
- def set_event_string(
- self,
- var_names=None,
- estimated=True,
- unobs_threshold=1.0,
- kernel_threshold=1.0,
- max_len=5,
- event_fontsize=14,
- ):
- if var_names is None:
- var_names = np.arange(self.n_genes).astype(int).astype(str)
-
- unobserved_factors = self.unobserved_factors
- unobserved_factors_kernel = self.unobserved_factors_kernel
- if estimated:
- unobserved_factors = self.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- unobserved_factors_kernel = np.exp(
- self.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
-
- # Up-regulated
- up_color = "red"
- up_list = np.where(
- np.logical_and(
- unobserved_factors > unobs_threshold,
- unobserved_factors_kernel > kernel_threshold,
- )
- )[0]
- sorted_idx = np.argsort(unobserved_factors[up_list])[:max_len]
- up_list = up_list[sorted_idx]
- up_str = ""
- if len(up_list) > 0:
- up_str = (
- f'+'
- + ",".join(var_names[up_list])
- )
-
- # Down-regulated
- down_color = "blue"
- down_list = np.where(
- np.logical_and(
- unobserved_factors < -unobs_threshold,
- unobserved_factors_kernel > kernel_threshold,
- )
- )[0]
- sorted_idx = np.argsort(-unobserved_factors[down_list])[:max_len]
- down_list = down_list[sorted_idx]
- down_str = ""
- if len(down_list) > 0:
- down_str = (
- f'-'
- + ",".join(var_names[down_list])
- )
-
- self.event_str = up_str
- sep_str = ""
- if len(up_list) > 0 and len(down_list) > 0:
- sep_str = "
"
- self.event_str = self.event_str + sep_str + down_str
+ return samples
+
+ def compute_global_priors(self):
+ factor_weights_contrib = jnp.sum(jnp.mean(mc_factor_weights_logp_val_and_grad(self.factor_weights_sample, 0., self.factor_precisions_sample)[0], axis=0))
+ log_alpha = jnp.log(self.node_hyperparams['gene_scale_shape'])
+ log_beta = jnp.log(self.node_hyperparams['gene_scale_shape'] * self.gene_ratio)
+ gene_scales_contrib = jnp.sum(jnp.mean(mc_gene_scales_logp_val_and_grad(self.gene_scales_sample, log_alpha, log_beta)[0], axis=0))
+ log_alpha = jnp.log(self.node_hyperparams['factor_precision_shape'])
+ log_beta = jnp.log(1.)
+ factor_precisions_contrib = jnp.sum(jnp.mean(mc_factor_precisions_logp_val_and_grad(self.factor_precisions_sample, log_alpha, log_beta)[0], axis=0))
+ return factor_weights_contrib + gene_scales_contrib + factor_precisions_contrib
+
+ def compute_local_priors(self, batch_indices):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance']))
+ obs_weights_contrib = jnp.sum(jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0], axis=0))
+ log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape'])
+ log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio)
+ cell_scales_contrib = jnp.sum(jnp.mean(mc_cell_scales_logp_val_and_grad(self.cell_scales_sample[batch_indices], log_alpha, log_beta)[0], axis=0))
+ return obs_weights_contrib + cell_scales_contrib
+
+ def compute_global_entropies(self):
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ factor_weights_contrib = jnp.sum(factor_weights_logq_val_and_grad(mean, log_std)[0])
+
+ log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha']
+ log_beta = self.variational_parameters['global']['gene_scales']['log_beta']
+ gene_scales_contrib = jnp.sum(gene_scales_logq_val_and_grad(log_alpha, log_beta)[0])
+
+ log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha']
+ log_beta = self.variational_parameters['global']['factor_precisions']['log_beta']
+ factor_precisions_contrib = jnp.sum(factor_precisions_logq_val_and_grad(log_alpha, log_beta)[0])
+
+ return factor_weights_contrib + gene_scales_contrib + factor_precisions_contrib
+
+ def compute_local_entropies(self, batch_indices):
+ mean = self.variational_parameters['local']['obs_weights']['mean'][batch_indices]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][batch_indices]
+ obs_weights_contrib = jnp.sum(obs_weights_logq_val_and_grad(mean, log_std)[0])
+
+ log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][batch_indices]
+ log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][batch_indices]
+ cell_scales_contrib = jnp.sum(cell_scales_logq_val_and_grad(log_alpha, log_beta)[0])
+ return obs_weights_contrib + cell_scales_contrib
+
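+    # Node kernel (matching the sampling in reset_parameters): per gene g,
+    #   direction_g ~ Gamma(direction_shape, scale=exp(-inheritance_strength * |parent_state_g|))
+    #   state_g ~ Normal(loc=parent_state_g, scale=direction_g)
+    # compute_kernel_prior averages these log-densities over the stored Monte Carlo samples.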
+ def compute_kernel_prior(self):
+ parent = self.parent()
+ if parent is None:
+ return self.compute_root_prior()
+
+ parent_state = self.parent().get_state_sample()
+ direction_samples = self.get_direction_sample()
+
+ direction_shape = self.node_hyperparams['direction_shape']
+ inheritance_strength = self.node_hyperparams['inheritance_strength']
+
+ direction_logpdf = mc_direction_logp_val_and_grad(direction_samples, parent_state, direction_shape, inheritance_strength)[0]
+ direction_logpdf = jnp.sum(direction_logpdf,axis=1)
+
+ state_samples = self.get_state_sample()
+ state_logpdf = mc_state_logp_val_and_grad(state_samples, parent_state, direction_samples)[0]
+ state_logpdf = jnp.sum(state_logpdf, axis=1)
+
+ return jnp.mean(direction_logpdf + state_logpdf)
+
+ def compute_root_direction_prior(self, parent_state):
+ direction_samples = self.get_direction_sample()
+ direction_shape = self.node_hyperparams['direction_shape']
+ inheritance_strength = self.node_hyperparams['inheritance_strength']
+ return jnp.mean(jnp.sum(mc_direction_logp_val_and_grad(direction_samples, parent_state, direction_shape, inheritance_strength)[0], axis=1))
+
+ def compute_root_state_prior(self, parent_state):
+ direction_samples = jnp.sqrt(self.get_direction_sample())
+ state_samples = self.get_state_sample()
+ return jnp.mean(jnp.sum(mc_state_logp_val_and_grad(state_samples, parent_state, direction_samples)[0], axis=1))
+
+ def compute_root_kernel_prior(self, samples):
+ parent_state = samples[0]
+ logp = self.compute_root_direction_prior(parent_state)
+ logp += self.compute_root_state_prior(parent_state)
+ return logp
+
+ def compute_root_prior(self):
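+        # Root nodes have no kernel parameters to evaluate in this model, so the prior term is zero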
+ return 0.
+
+ def compute_kernel_entropy(self):
+ parent = self.parent()
+ if parent is None:
+ return self.compute_root_entropy()
+
+ # Direction
+        direction_logpdf = tfd.Gamma(jnp.exp(self.variational_parameters['kernel']['direction']['log_alpha']),
+ jnp.exp(self.variational_parameters['kernel']['direction']['log_beta'])
+ ).entropy()
+ direction_logpdf = jnp.sum(direction_logpdf)
+
+ # State
+ state_logpdf = tfd.Normal(self.variational_parameters['kernel']['state']['mean'],
+ jnp.exp(self.variational_parameters['kernel']['state']['log_std'])
+ ).entropy()
+ state_logpdf = jnp.sum(state_logpdf) # Sum across features
+
+ return direction_logpdf + state_logpdf
+
+ def compute_root_entropy(self):
+ # In this model the root nodes have no unknown parameters
+ return 0.
+
+ # ======== Functions for updating the variational parameters. =========
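+    # The updates below use reparameterization (pathwise) gradients: each Monte Carlo sample
+    # carries the gradient of the sample w.r.t. its variational parameters, so the ELBO
+    # gradient is estimated as mean(model-term gradient * sample gradient) over samples plus
+    # the analytical entropy gradient, followed by a gradient-ascent step (with clipping for
+    # some parameters).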
+ def local_sample_and_grad(self, idx, key, n_samples):
+ """Sample and take gradient of local parameters. Must be root"""
+ mean = self.variational_parameters['local']['obs_weights']['mean'][idx]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx]
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ obs_weights_sample_grad = mc_sample_obs_weights_val_and_grad(jnp.array(sub_keys), mean, log_std)
+ obs_weights_sample = obs_weights_sample_grad[0]
+ # obs_weights_sample = obs_weights_sample.at[:,0,:].set(0.)
+
+ log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][idx]
+ log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][idx]
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ cell_scales_sample_grad = mc_sample_cell_scales_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta)
+
+ sample = [obs_weights_sample, cell_scales_sample_grad[0]]
+ grad = [obs_weights_sample_grad[1], cell_scales_sample_grad[1]]
+
+ return key, (sample, grad)
+
+ def global_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of global parameters. Must be root"""
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ factor_weights_sample_grad = mc_sample_factor_weights_val_and_grad(jnp.array(sub_keys), mean, log_std)
+ factor_weights_sample = factor_weights_sample_grad[0]
+ # factor_weights_sample = factor_weights_sample.at[:,:,0].set(0.)
+
+ log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha']
+ log_beta = self.variational_parameters['global']['gene_scales']['log_beta']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ gene_scales_sample_grad = mc_sample_gene_scales_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta)
+ # gene_scales_sample = gene_scales_sample_grad[0]
+ # gene_scales_sample = gene_scales_sample.at[jnp.arange(n_samples),0].set(1.)
+
+ log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha']
+ log_beta = self.variational_parameters['global']['factor_precisions']['log_beta']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ factor_precisions_sample_grad = mc_sample_factor_precisions_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta)
+
+ sample = [factor_weights_sample, gene_scales_sample_grad[0], factor_precisions_sample_grad[0]]
+ grad = [factor_weights_sample_grad[1], gene_scales_sample_grad[1], factor_precisions_sample_grad[1]]
+
+ return key, (sample, grad)
+
+ def compute_locals_prior_grad(self, sample):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance']))
+ obs_weights_grad = mc_obs_weights_logp_val_and_grad(sample[0], 0., log_std)[1]
+ log_alpha = jnp.log(self.node_hyperparams['cell_scale_shape'])
+ log_beta = jnp.log(self.node_hyperparams['cell_scale_shape'] * self.lib_ratio)
+ cell_scales_grad = mc_cell_scales_logp_val_and_grad(sample[1], log_alpha, log_beta)[1]
+ return obs_weights_grad, cell_scales_grad
+
+ def compute_globals_prior_grad(self, sample):
+ factor_weights_grad = mc_factor_weights_logp_val_and_grad(sample[0], 0., sample[2])[1]
+ log_alpha = jnp.log(self.node_hyperparams['gene_scale_shape'])
+ log_beta = jnp.log(self.node_hyperparams['gene_scale_shape'] * self.gene_ratio)
+ gene_scales_grad = mc_gene_scales_logp_val_and_grad(sample[1], log_alpha, log_beta)[1]
+ log_alpha = jnp.log(self.node_hyperparams['factor_precision_shape'])
+ log_beta = jnp.log(1.)
+ factor_precisions_grad = mc_factor_precisions_logp_val_and_grad(sample[2], log_alpha, log_beta)[1]
+ factor_precisions_grad += mc_factor_weights_logp_val_and_grad_wrt_precisions(sample[0], 0., sample[2])[1]
+ return factor_weights_grad, gene_scales_grad, factor_precisions_grad
+
+ def compute_locals_entropy_grad(self, idx):
+ mean = self.variational_parameters['local']['obs_weights']['mean'][idx]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx]
+ obs_weights_grad = obs_weights_logq_val_and_grad(mean, log_std)[1]
+
+ log_alpha = self.variational_parameters['local']['cell_scales']['log_alpha'][idx]
+ log_beta = self.variational_parameters['local']['cell_scales']['log_beta'][idx]
+ cell_scales_grad = cell_scales_logq_val_and_grad(log_alpha, log_beta)[1]
+
+ return obs_weights_grad, cell_scales_grad
+
+ def compute_globals_entropy_grad(self):
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ factor_weights_grad = factor_weights_logq_val_and_grad(mean, log_std)[1]
+
+ log_alpha = self.variational_parameters['global']['gene_scales']['log_alpha']
+ log_beta = self.variational_parameters['global']['gene_scales']['log_beta']
+ gene_scales_grad = gene_scales_logq_val_and_grad(log_alpha, log_beta)[1]
+
+ log_alpha = self.variational_parameters['global']['factor_precisions']['log_alpha']
+ log_beta = self.variational_parameters['global']['factor_precisions']['log_beta']
+ factor_precisions_grad = factor_precisions_logq_val_and_grad(log_alpha, log_beta)[1]
+
+ return factor_weights_grad, gene_scales_grad, factor_precisions_grad
+
+ def state_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of state"""
+ mu = self.variational_parameters['kernel']['state']['mean']
+ log_std = self.variational_parameters['kernel']['state']['log_std']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_state_val_and_grad(jnp.array(sub_keys), mu, log_std)
+
+ def direction_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of direction"""
+ log_alpha = self.variational_parameters['kernel']['direction']['log_alpha']
+ log_beta = self.variational_parameters['kernel']['direction']['log_beta']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_direction_val_and_grad(jnp.array(sub_keys), log_alpha, log_beta)
+
+ def compute_direction_prior_grad(self, direction, parent_direction, parent_state):
+ """Gradient of logp(direction|parent_state) wrt this direction"""
+ direction_shape = self.node_hyperparams['direction_shape']
+ inheritance_strength = self.node_hyperparams['inheritance_strength']
+ return mc_direction_logp_val_and_grad(direction, parent_state, direction_shape, inheritance_strength)[1]
+
+ def compute_direction_prior_child_grad_wrt_direction(self, child_direction, direction, state):
+ """Gradient of logp(child_alpha|alpha) wrt this direction"""
+ return 0. # no influence under this model
+
+ def compute_direction_prior_child_grad_wrt_state(self, child_direction, direction, state):
+ """Gradient of logp(child_alpha|alpha) wrt this direction"""
+ direction_shape = self.node_hyperparams['direction_shape']
+ inheritance_strength = self.node_hyperparams['inheritance_strength']
+ return mc_direction_logp_val_and_grad_wrt_parent(child_direction, state, direction_shape, inheritance_strength)[1]
+
+ def compute_state_prior_grad(self, state, parent_state, direction):
+ """Gradient of logp(state|parent_state,direction) wrt this psi"""
+ return mc_state_logp_val_and_grad(state, parent_state, direction)[1]
+
+ def compute_state_prior_child_grad(self, child_state, state, child_direction):
+ """Gradient of logp(child_state|state,child_direction) wrt this state"""
+ return mc_state_logp_val_and_grad_wrt_parent(child_state, state, child_direction)[1]
+
+ def compute_root_state_prior_child_grad(self, child_state, state, child_direction):
+ """Gradient of logp(child_psi|psi,child_alpha) wrt this psi"""
+ return self.compute_state_prior_child_grad(child_state, state, child_direction)
+
+ def compute_state_prior_grad_wrt_direction(self, state, parent_state, direction):
+ """Gradient of logp(state|parent_state,direction) wrt this direction"""
+ return mc_state_logp_val_and_grad_wrt_direction(state, parent_state, direction)[1]
+
+ def compute_direction_entropy_grad(self):
+ """Gradient of logq(alpha) wrt this alpha"""
+ log_alpha = self.variational_parameters['kernel']['direction']['log_alpha']
+ log_beta = self.variational_parameters['kernel']['direction']['log_beta']
+ return direction_logq_val_and_grad(log_alpha, log_beta)[1]
+
+ def compute_state_entropy_grad(self):
+ """Gradient of logq(psi) wrt this psi"""
+ mu = self.variational_parameters['kernel']['state']['mean']
+ log_std = self.variational_parameters['kernel']['state']['log_std']
+ return state_logq_val_and_grad(mu, log_std)[1]
+
+ def compute_ll_state_grad(self, x, weights, state):
+ """Gradient of logp(x|psi,noise) wrt this psi"""
+ obs_weights = self.get_obs_weights_sample()
+ factor_weights = self.get_factor_weights_sample()
+ gene_scales = self.get_gene_scales_sample()
+ cell_scales = self.get_cell_scales_sample()
+ cnvs = self.cnvs
+ return mc_ll_val_and_grad_state(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1]
+
+ def compute_ll_state_grad_suff(self, state):
+ """Gradient of logp(x|state,noise) wrt this state using suff stats"""
+ cnv = self.cnvs
+ gene_scales = self.get_gene_scales_sample()
+ return mc_ll_state_suff_val_and_grad(state, cnv, gene_scales,
+ self.suff_stats['B_g']['total'], self.suff_stats['D_g']['total'])[1]
+
+ def compute_ll_locals_grad(self, x, idx, weights):
+ """Gradient of logp(x|psi,locals,globals) wrt locals"""
+ state = self.get_state_sample()
+ obs_weights = self.get_obs_weights_sample()[:,idx]
+ factor_weights = self.get_factor_weights_sample()
+ gene_scales = self.get_gene_scales_sample()
+ cell_scales = self.get_cell_scales_sample()[:,idx]
+ cnvs = self.cnvs
+ obs_weights_grad = mc_ll_val_and_grad_obs_weights(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1]
+ cell_scales_grad = mc_ll_val_and_grad_cell_scales(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1]
+ return obs_weights_grad, cell_scales_grad
+
+ def compute_ll_globals_grad(self, x, idx, weights):
+ """Gradient of logp(x|psi,locals,globals) wrt globals"""
+ state = self.get_state_sample()
+ obs_weights = self.get_obs_weights_sample()[:,idx]
+ factor_weights = self.get_factor_weights_sample()
+ gene_scales = self.get_gene_scales_sample()
+ cell_scales = self.get_cell_scales_sample()[:,idx]
+ cnvs = self.cnvs
+ factor_weights_grad = mc_ll_val_and_grad_factor_weights(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1]
+ gene_scales_grad = mc_ll_val_and_grad_gene_scales(x, weights, state, cnvs, obs_weights, factor_weights, cell_scales, gene_scales)[1]
+ return factor_weights_grad, gene_scales_grad
+
+ def update_direction_params(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0)
+ direction_log_alpha_grad = mc_grad + direction_params_entropy_grad[0]
+ self.variational_parameters['kernel']['direction']['log_alpha'] += direction_log_alpha_grad * step_size
+
+ mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0)
+ direction_log_beta_grad = mc_grad + direction_params_entropy_grad[1]
+ self.variational_parameters['kernel']['direction']['log_beta'] += direction_log_beta_grad * step_size
+
+ self.variational_parameters['kernel']['direction']['log_alpha'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['kernel']['direction']['log_beta'] = self.apply_clip(self.variational_parameters['kernel']['direction']['log_beta'], maxval=MAX_BETA)
+
+ def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
+ loc_mean_grad = mc_grad + state_params_entropy_grad[0]
+ self.variational_parameters['kernel']['state']['mean'] += loc_mean_grad * step_size
+
+ mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0)
+ loc_log_std_grad = mc_grad + state_params_entropy_grad[1]
+ self.variational_parameters['kernel']['state']['log_std'] += loc_log_std_grad * step_size
+
+ self.variational_parameters['kernel']['state']['mean'] = self.apply_clip(self.variational_parameters['kernel']['state']['mean'], minval=-5., maxval=5.)
+ self.variational_parameters['kernel']['state']['log_std'] = self.apply_clip(self.variational_parameters['kernel']['state']['log_std'], minval=-5., maxval=5.)
+
+ def update_cell_scales_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
+ mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0]
+ new_param = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + param_grad * step_size
+ self.variational_parameters['local']['cell_scales']['log_alpha'] = self.variational_parameters['local']['cell_scales']['log_alpha'].at[idx].set(new_param)
+
+ mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1]
+ new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + param_grad * step_size
+ self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)
+
+ self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+
+ def update_obs_weights_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
+ mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0]
+ new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + param_grad * step_size
+ self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param)
+
+ mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1]
+ new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + param_grad * step_size
+ self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param)
+
+ def update_local_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=.001, param_names=["obs_weights", "cell_scales"], **kwargs):
+ if param_names is None:
+ param_names=["obs_weights", "cell_scales"]
+ if "obs_weights" in param_names:
+ self.update_obs_weights_params(idx, local_params_grad[0], local_sample_grad[0], local_params_entropy_grad[0], ent_anneal=ent_anneal, step_size=step_size)
+ if "cell_scales" in param_names:
+ self.update_cell_scales_params(idx, local_params_grad[1], local_sample_grad[1], local_params_entropy_grad[1], ent_anneal=ent_anneal, step_size=step_size)
+
+ def update_gene_scales_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+ self.variational_parameters['global']['gene_scales']['log_alpha'] += param_grad * step_size
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+ self.variational_parameters['global']['gene_scales']['log_beta'] += param_grad * step_size
+
+ self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+
+ def update_factor_weights_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+ self.variational_parameters['global']['factor_weights']['mean'] += param_grad * step_size
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+ self.variational_parameters['global']['factor_weights']['log_std'] += param_grad * step_size
+
+ def update_factor_precisions_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+ self.variational_parameters['global']['factor_precisions']['log_alpha'] += param_grad * step_size
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+ self.variational_parameters['global']['factor_precisions']['log_beta'] += param_grad * step_size
+
+ self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+
+ def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001,
+ param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs):
+ if param_names is None:
+ param_names=["factor_weights", "gene_scales", "factor_precisions"]
+ if "factor_weights" in param_names:
+ self.update_factor_weights_params(global_params_grad[0], global_sample_grad[0], global_params_entropy_grad[0], step_size=step_size)
+ if "gene_scales" in param_names:
+ self.update_gene_scales_params(global_params_grad[1], global_sample_grad[1], global_params_entropy_grad[1], step_size=step_size)
+ if "factor_precisions" in param_names:
+ self.update_factor_precisions_params(global_params_grad[2], global_sample_grad[2], global_params_entropy_grad[2], step_size=step_size)
+
+ def initialize_global_opt_states(self, param_names=["factor_weights", "gene_scales", "factor_precisions"]):
+ states = dict()
+ if param_names is None:
+ param_names=["factor_weights", "gene_scales", "factor_precisions"]
+ if "factor_weights" in param_names:
+ factor_weights_states = self.initialize_factor_weights_states()
+ states["factor_weights"] = factor_weights_states
+ if "gene_scales" in param_names:
+ gene_scales_states = self.initialize_gene_scales_states()
+ states["gene_scales"] = gene_scales_states
+ if "factor_precisions" in param_names:
+ factor_precisions_states = self.initialize_factor_precisions_states()
+ states["factor_precisions"] = factor_precisions_states
+ return states
+
+ def initialize_factor_weights_states(self):
+ n_factors = self.node_hyperparams['n_factors']
+ m = jnp.zeros((n_factors,self.n_genes))
+ v = jnp.zeros((n_factors,self.n_genes))
+ state1 = (m,v)
+ m = jnp.zeros((n_factors,self.n_genes))
+ v = jnp.zeros((n_factors,self.n_genes))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def initialize_gene_scales_states(self):
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state1 = (m,v)
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def initialize_factor_precisions_states(self):
+ n_factors = self.node_hyperparams['n_factors']
+ m = jnp.zeros((n_factors,1))
+ v = jnp.zeros((n_factors,1))
+ state1 = (m,v)
+ m = jnp.zeros((n_factors,1))
+ v = jnp.zeros((n_factors,1))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
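+    # The *_adaptive updates below keep Adam-style first/second moment estimates (m, v) with
+    # bias correction for each variational parameter; the opt states initialized above hold
+    # these moment pairs.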
+ def update_factor_weights_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['global']['factor_weights']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['global']['factor_weights']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ states = (state1, state2)
+ return states
+
+ def update_gene_scales_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['global']['gene_scales']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['global']['gene_scales']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ self.variational_parameters['global']['gene_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['global']['gene_scales']['log_beta'] = self.apply_clip(self.variational_parameters['global']['gene_scales']['log_beta'], maxval=MAX_BETA)
+
+ states = (state1, state2)
+ return states
+
+ def update_factor_precisions_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['global']['factor_precisions']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['global']['factor_precisions']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ self.variational_parameters['global']['factor_precisions']['log_alpha'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['global']['factor_precisions']['log_beta'] = self.apply_clip(self.variational_parameters['global']['factor_precisions']['log_beta'], maxval=MAX_BETA)
+
+ states = (state1, state2)
+ return states
+
+ def update_global_params_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001, param_names=["factor_weights", "gene_scales", "factor_precisions"], **kwargs):
+ if param_names is None:
+ param_names=["factor_weights", "gene_scales", "factor_precisions"]
+ if "factor_weights" in param_names:
+ factor_weights_states = self.update_factor_weights_adaptive(global_params_grad[0], global_sample_grad[0], global_params_entropy_grad[0],
+ i=i, states=states["factor_weights"], b1=b1, b2=b2, eps=eps, step_size=step_size)
+ states["factor_weights"] = factor_weights_states
+ if "gene_scales" in param_names:
+ gene_scales_states = self.update_gene_scales_adaptive(global_params_grad[1], global_sample_grad[1], global_params_entropy_grad[1],
+ i=i, states=states["gene_scales"], b1=b1, b2=b2, eps=eps, step_size=step_size)
+ states["gene_scales"] = gene_scales_states
+ if "factor_precisions" in param_names:
+ factor_precisions_states = self.update_factor_precisions_adaptive(global_params_grad[2], global_sample_grad[2], global_params_entropy_grad[2],
+ i=i, states=states["factor_precisions"], b1=b1, b2=b2, eps=eps, step_size=step_size)
+ states["factor_precisions"] = factor_precisions_states
+ return states
+
+
+ def initialize_local_opt_states(self, param_names=["obs_weights", "cell_scales"]):
+ states = dict()
+ if param_names is None:
+ param_names=["obs_weights", "cell_scales"]
+ if "obs_weights" in param_names:
+ obs_weights_states = self.initialize_obs_weights_states()
+ states["obs_weights"] = obs_weights_states
+ if "cell_scales" in param_names:
+ cell_scales_states = self.initialize_cell_scales_states()
+ states["cell_scales"] = cell_scales_states
+ return states
+
+ def initialize_obs_weights_states(self):
+ n_obs = self.tssb.ntssb.num_data
+ n_factors = self.node_hyperparams['n_factors']
+ m = jnp.zeros((n_obs,n_factors))
+ v = jnp.zeros((n_obs,n_factors))
+ state1 = (m,v)
+ m = jnp.zeros((n_obs,n_factors))
+ v = jnp.zeros((n_obs,n_factors))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def initialize_cell_scales_states(self):
+ n_obs = self.tssb.ntssb.num_data
+ m = jnp.zeros((n_obs,1))
+ v = jnp.zeros((n_obs,1))
+ state1 = (m,v)
+ m = jnp.zeros((n_obs,1))
+ v = jnp.zeros((n_obs,1))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def update_obs_weights_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001, ent_anneal=1.):
+ """
+ states are not indexed
+ """
+ mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0]
+
+ m, v = states[0]
+ new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate.
+ new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate.
+ m = m.at[idx].set(new_m)
+ v = v.at[idx].set(new_v)
+ mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
+ self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param)
+
+ mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1]
+
+ m, v = states[1]
+ new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate.
+ new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate.
+ m = m.at[idx].set(new_m)
+ v = v.at[idx].set(new_v)
+ mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
+ self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param)
+
+ states = (state1, state2)
+ return states
+
+ def update_cell_scales_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001, ent_anneal=1.):
+ """
+ states are already indexed, as are the gradients
+ """
+ mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0]
+
+ m, v = states[0]
+ new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate.
+ new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate.
+ m = m.at[idx].set(new_m)
+ v = v.at[idx].set(new_v)
+ mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ new_param = self.variational_parameters['local']['cell_scales']['log_alpha'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
+ self.variational_parameters['local']['cell_scales']['log_alpha'] = self.variational_parameters['local']['cell_scales']['log_alpha'].at[idx].set(new_param)
+
+ mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1]
+
+ m, v = states[1]
+ new_m = (1 - b1) * param_grad + b1 * m[idx] # First moment estimate.
+ new_v = (1 - b2) * jnp.square(param_grad) + b2 * v[idx] # Second moment estimate.
+ m = m.at[idx].set(new_m)
+ v = v.at[idx].set(new_v)
+ mhat = new_m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = new_v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ new_param = self.variational_parameters['local']['cell_scales']['log_beta'][idx] + step_size * mhat / (jnp.sqrt(vhat) + eps)
+ self.variational_parameters['local']['cell_scales']['log_beta'] = self.variational_parameters['local']['cell_scales']['log_beta'].at[idx].set(new_param)
+
+ self.variational_parameters['local']['cell_scales']['log_alpha'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_alpha'], minval=MIN_ALPHA)
+ self.variational_parameters['local']['cell_scales']['log_beta'] = self.apply_clip(self.variational_parameters['local']['cell_scales']['log_beta'], maxval=MAX_BETA)
+
+ states = (state1, state2)
+ return states
+
+ def update_local_params_adaptive(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, i, states, ent_anneal=1., b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001, param_names=["obs_weights", "cell_scales"], **kwargs):
+ if param_names is None:
+ param_names = ["obs_weights", "cell_scales"]
+
+ if "obs_weights" in param_names:
+ obs_weights_states = self.update_obs_weights_adaptive(idx, local_params_grad[0], local_sample_grad[0], local_params_entropy_grad[0],
+ i=i, states=states["obs_weights"], b1=b1, b2=b2, eps=eps, step_size=step_size, ent_anneal=ent_anneal)
+ states["obs_weights"] = obs_weights_states
+ if "cell_scales" in param_names:
+ cell_scales_states = self.update_cell_scales_adaptive(idx, local_params_grad[1], local_sample_grad[1], local_params_entropy_grad[1],
+ i=i, states=states["cell_scales"], b1=b1, b2=b2, eps=eps, step_size=step_size, ent_anneal=ent_anneal)
+ states["cell_scales"] = cell_scales_states
+ return states
+
+ def initialize_state_states(self):
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state1 = (m,v)
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def update_state_adaptive(self, state_params_grad, state_sample_grad, state_params_entropy_grad, i, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ states = self.state_states
+
+ mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
+ param_grad = mc_grad + state_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['kernel']['state']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0)
+ param_grad = mc_grad + state_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['kernel']['state']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ states = (state1, state2)
+ self.state_states = states
+
+ def initialize_direction_states(self):
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state1 = (m,v)
+ m = jnp.zeros((self.n_genes,))
+ v = jnp.zeros((self.n_genes,))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def update_direction_adaptive(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, i, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ states = self.direction_states
+ mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0)
+ param_grad = mc_grad + direction_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['kernel']['direction']['log_alpha'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0)
+ param_grad = mc_grad + direction_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['kernel']['direction']['log_beta'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ states = (state1, state2)
+ self.direction_states = states
\ No newline at end of file
diff --git a/scatrex/models/cna/node_opt.py b/scatrex/models/cna/node_opt.py
new file mode 100644
index 0000000..fac0fd3
--- /dev/null
+++ b/scatrex/models/cna/node_opt.py
@@ -0,0 +1,233 @@
+import jax
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.distributions as tfd
+
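+# Each latent quantity below follows the same pattern: an elementwise sampler, log-density
+# and entropy function, wrapped with jax.vmap/jax.jit into per-dimension value_and_grad
+# functions and mc_* variants that additionally map over Monte Carlo samples.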
+@jax.jit
+def sample_direction(key, log_alpha, log_beta): # univariate: one sample
+ print("haahahdirection")
+ return jnp.maximum(jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), log_rate=log_beta).sample(seed=key)), 1e-6)
+sample_direction_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(sample_direction, argnums=(1,2)), in_axes=(None, 0, 0))) # per-dimension val and grad
+mc_sample_direction_val_and_grad = jax.jit(jax.vmap(sample_direction_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_state(key, mu, log_std): # univariate: one sample
+ print("haahahstate")
+ return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_state_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(sample_state, argnums=(1,2)), in_axes=(None, 0, 0))) # per-dimension val and grad
+mc_sample_state_val_and_grad = jax.jit(jax.vmap(sample_state_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def direction_logp(this_direction, parent_state, direction_shape, inheritance_strength): # single sample
+ return tfd.Gamma(direction_shape, log_rate=inheritance_strength*jnp.abs(parent_state)).log_prob(this_direction)
+univ_direction_logp_val_and_grad = jax.jit(jax.value_and_grad(direction_logp, argnums=0)) # Take grad wrt to this direction
+direction_logp_val_and_grad = jax.jit(jax.vmap(univ_direction_logp_val_and_grad, in_axes=(0,0,None,None))) # Per-dimension value_and_grad
+mc_direction_logp_val_and_grad = jax.jit(jax.vmap(direction_logp_val_and_grad, in_axes=(0,0,None,None))) # Multiple sample value_and_grad
+
+univ_direction_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(direction_logp, argnums=1)) # Take grad wrt to parent state
+direction_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(univ_direction_logp_val_and_grad_wrt_parent, in_axes=(0,0,None,None))) # Per-dimension value_and_grad
+mc_direction_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(direction_logp_val_and_grad_wrt_parent, in_axes=(0,0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def direction_logq(log_alpha, log_beta):
+ return tfd.Gamma(jnp.exp(log_alpha), log_rate=log_beta).entropy()
+direction_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(direction_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters
+
+@jax.jit
+def state_logp(this_state, parent_state, this_direction): # single sample
+    return tfd.Normal(parent_state, this_direction).log_prob(this_state) # per-dimension log-density
+state_logp_val = jax.jit(state_logp)
+mc_loc_logp_val = jax.jit(jax.vmap(state_logp_val, in_axes=(0,0,0))) # Multiple sample
+
+univ_state_logp_val_and_grad = jax.jit(jax.value_and_grad(state_logp, argnums=0)) # Take grad wrt to this
+state_logp_val_and_grad = jax.jit(jax.vmap(univ_state_logp_val_and_grad, in_axes=(0,0,0))) # Take grad wrt to this
+mc_state_logp_val_and_grad = jax.jit(jax.vmap(state_logp_val_and_grad, in_axes=(0,0,0))) # Multiple sample value_and_grad
+
+univ_state_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(state_logp, argnums=1)) # Take grad wrt to parent
+state_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(univ_state_logp_val_and_grad_wrt_parent, in_axes=(0,0,0))) # Take grad wrt to parent
+mc_state_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(state_logp_val_and_grad_wrt_parent, in_axes=(0,0,0))) # Multiple sample value_and_grad
+
+univ_state_logp_val_and_grad_wrt_direction = jax.jit(jax.value_and_grad(state_logp, argnums=2)) # Take grad wrt to angle
+state_logp_val_and_grad_wrt_direction = jax.jit(jax.vmap(univ_state_logp_val_and_grad_wrt_direction, in_axes=(0,0,0))) # Take grad wrt to angle
+mc_state_logp_val_and_grad_wrt_direction = jax.jit(jax.vmap(state_logp_val_and_grad_wrt_direction, in_axes=(0,0,0))) # Multiple sample value_and_grad
+
+@jax.jit
+def state_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+state_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(state_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+# Noise
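+# obs_weights (N x K) and factor_weights (K x G) parameterize a low-rank noise term on the
+# expression rates; each factor's weights get their own Gamma-distributed precision
+# (akin to an ARD-style shrinkage prior).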
+
+@jax.jit
+def sample_obs_weights(key, mean, log_std): # NxK
+ return tfd.Normal(mean, jnp.exp(log_std)).sample(seed=key)
+sample_obs_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_obs_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_obs_weights_val_and_grad = jax.jit(jax.vmap(sample_obs_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_factor_weights(key, mu, log_std): # KxG
+ return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_factor_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_factor_weights_val_and_grad = jax.jit(jax.vmap(sample_factor_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+
+@jax.jit
+def obs_weights_logp(sample, mean, log_std): # single sample, NxK
+    return tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample) # per-element log-density
+univ_obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt to sample (Nx1)
+obs_weights_logp_val_and_grad = jax.jit(jax.vmap(jax.vmap(univ_obs_weights_logp_val_and_grad, in_axes=(0, None, None)), in_axes=(0,None,None))) # Take grad wrt to sample (NxK)
+mc_obs_weights_logp_val_and_grad = jax.jit(jax.vmap(obs_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNxK
+
+@jax.jit
+def obs_weights_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+obs_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(obs_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+@jax.jit
+def factor_weights_logp(sample, mean, precision): # single sample, KxG
+    return jnp.sum(tfd.Normal(mean, 1./jnp.sqrt(precision)).log_prob(sample)) # called element-wise via the vmaps below, so the sum is a no-op
+univ_factor_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_weights_logp, argnums=0))
+factor_weights_logp_val_and_grad = jax.jit(jax.vmap(jax.vmap(univ_factor_weights_logp_val_and_grad, in_axes=(0,None,None)), in_axes=(0,None,0))) # Take grad wrt to sample (KxG)
+mc_factor_weights_logp_val_and_grad = jax.jit(jax.vmap(factor_weights_logp_val_and_grad, in_axes=(0,None,0))) # Multiple sample value_and_grad: SxKxG
+
+@jax.jit
+def factor_weights_logp_summed(sample, mean, precision): # single sample, KxG
+    return jnp.sum(tfd.Normal(mean, 1./jnp.sqrt(precision)).log_prob(sample)) # sum over factors and genes
+factor_weights_logp_val_and_grad_wrt_precisions = jax.jit(jax.value_and_grad(factor_weights_logp_summed, argnums=2))
+mc_factor_weights_logp_val_and_grad_wrt_precisions = jax.jit(jax.vmap(factor_weights_logp_val_and_grad_wrt_precisions, in_axes=(0,None,0))) # Multiple sample value_and_grad: SxKxG
+
+@jax.jit
+def factor_weights_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+factor_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+# Cell scales
+@jax.jit
+def sample_cell_scales(key, log_alpha, log_beta): # Nx1
+ return jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key))
+sample_cell_scales_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_cell_scales, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_cell_scales_val_and_grad = jax.jit(jax.vmap(sample_cell_scales_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def cell_scales_logp(sample, log_alpha, log_beta): # single sample, Nx1
+ return jnp.sum(tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample)) # sum across obs and dimensions
+univ_cell_scales_logp_val_and_grad = jax.jit(jax.value_and_grad(cell_scales_logp, argnums=0)) # Take grad wrt to sample (Nx1)
+cell_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_cell_scales_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (Nx1)
+mc_cell_scales_logp_val_and_grad = jax.jit(jax.vmap(cell_scales_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNx1
+
+@jax.jit
+def cell_scales_logq(log_alpha, log_beta):
+ return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy()
+cell_scales_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(cell_scales_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+# Gene scales
+@jax.jit
+def sample_gene_scales(key, log_alpha, log_beta): # G
+ return jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key))
+sample_gene_scales_val_and_grad = jax.vmap(jax.value_and_grad(sample_gene_scales, argnums=(1,2)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_gene_scales_val_and_grad = jax.jit(jax.vmap(sample_gene_scales_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def gene_scales_logp(sample, log_alpha, log_beta): # single sample
+    return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample) # per-gene log-density
+univ_gene_scales_logp_val_and_grad = jax.jit(jax.value_and_grad(gene_scales_logp, argnums=0)) # Take grad wrt to sample (G,)
+gene_scales_logp_val_and_grad = jax.jit(jax.vmap(univ_gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (G,)
+mc_gene_scales_logp_val_and_grad = jax.jit(jax.vmap(gene_scales_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxG
+
+@jax.jit
+def gene_scales_logq(log_alpha, log_beta):
+ return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy()
+gene_scales_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(gene_scales_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+# Factor variances
+@jax.jit
+def sample_factor_precisions(key, log_alpha, log_beta): # Kx1
+ return jnp.exp(tfd.ExpGamma(jnp.exp(log_alpha), jnp.exp(log_beta)).sample(seed=key))
+sample_factor_precisions_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_precisions, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_factor_precisions_val_and_grad = jax.jit(jax.vmap(sample_factor_precisions_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def factor_precisions_logp(sample, log_alpha, log_beta): # single sample
+ return jnp.sum(tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).log_prob(sample)) # sum across obs and dimensions
+univ_factor_precisions_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_precisions_logp, argnums=0)) # Take grad wrt to sample (G,)
+factor_precisions_logp_val_and_grad = jax.jit(jax.vmap(univ_factor_precisions_logp_val_and_grad, in_axes=(0,None,None))) # Take grad wrt to sample (G,)
+mc_factor_precisions_logp_val_and_grad = jax.jit(jax.vmap(factor_precisions_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxG
+
+@jax.jit
+def factor_precisions_logq(log_alpha, log_beta):
+ return tfd.Gamma(jnp.exp(log_alpha), jnp.exp(log_beta)).entropy()
+factor_precisions_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_precisions_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
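+# Expression likelihood: for cell n and gene g the Poisson rate is
+#   mu_ng = cell_scale_n * gene_scale_g * (cnv_g / 2) * exp(state_g + (obs_weights @ factor_weights)_ng)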
+@jax.jit
+def _mc_obs_ll(obs, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # For each MC sample: NxG
+ m = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights))
+ return jnp.sum(jax.vmap(jax.scipy.stats.poisson.logpmf, in_axes=[0, 0])(obs, m), axis=1) # sum over dimensions
+
+@jax.jit
+def ll(x, weights, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # single sample
+ loc = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights))
+ return jnp.sum(jnp.sum(tfd.Poisson(loc).log_prob(x),axis=1) * weights)
+ll_val_and_grad_state = jax.jit(jax.value_and_grad(ll, argnums=2)) # Take grad wrt to psi
+mc_ll_val_and_grad_state = jax.jit(jax.vmap(ll_val_and_grad_state,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
+ll_val_and_grad_factor_weights = jax.jit(jax.value_and_grad(ll, argnums=5)) # Take grad wrt to factor_weights
+mc_ll_val_and_grad_factor_weights = jax.jit(jax.vmap(ll_val_and_grad_factor_weights,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
+ll_val_and_grad_cell_scales = jax.jit(jax.value_and_grad(ll, argnums=6)) # Take grad wrt to cell_scales
+mc_ll_val_and_grad_cell_scales = jax.jit(jax.vmap(ll_val_and_grad_cell_scales,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
+ll_val_and_grad_gene_scales = jax.jit(jax.value_and_grad(ll, argnums=7)) # Take grad wrt to gene_scales
+mc_ll_val_and_grad_gene_scales = jax.jit(jax.vmap(ll_val_and_grad_gene_scales,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
+
+@jax.jit
+def ll_obs(x, weight, state, cnv, obs_weights, factor_weights, cell_scales, gene_scales): # single obs
+ loc = cell_scales * gene_scales * cnv/2 * jnp.exp(state + obs_weights.dot(factor_weights))
+ return jnp.sum(tfd.Poisson(loc).log_prob(x)) * weight
+
+univ_ll_val_and_grad_obs_weights = jax.jit(jax.value_and_grad(ll_obs, argnums=4)) # Take grad wrt to obs_weights
+ll_val_and_grad_obs_weights = jax.jit(jax.vmap(univ_ll_val_and_grad_obs_weights, in_axes=(0,0, None, None, 0, None, 0, None)))
+mc_ll_val_and_grad_obs_weights = jax.jit(jax.vmap(ll_val_and_grad_obs_weights,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
+univ_ll_val_and_grad_cell_scales = jax.jit(jax.value_and_grad(ll_obs, argnums=6)) # Take grad wrt to cell_scales
+ll_val_and_grad_cell_scales = jax.jit(jax.vmap(univ_ll_val_and_grad_cell_scales, in_axes=(0,0, None, None, 0, None, 0, None)))
+mc_ll_val_and_grad_cell_scales = jax.jit(jax.vmap(ll_val_and_grad_cell_scales,
+ in_axes=(None,None,0,None,0,0,0,0)))
+
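+# Sufficient-statistic form of the expected log-likelihood: with the Poisson rate mu_ng as
+# above, log p(x_ng) = x_ng*log(mu_ng) - mu_ng - lgamma(x_ng + 1); taking responsibilities
+# q(z_n) and expectations of the cell-level terms under q collapses the sum over cells into
+# the per-node statistics A, B_g, C, D_g and E described in the docstrings below.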
+@jax.jit
+def ll_suffstats(state, cnv, gene_scales, A, B_g, C, D_g, E): # for a single state sample
+ """
+ A: \sum_n q(z_n = this node) * \sum_g x_ng * E[log\gamma_n]
+ B_g: \sum_n q(z_n = this node) * x_ng
+ C: \sum_n q(z_n = this node) * \sum_g x_ng * E[(s_nW_g)]
+ D_g: \sum_n q(z_n = this node) * E[\gamma_n] * E[exp(s_nW_g)]
+ E: \sum_n q(z_n = this node) * lgamma(x_ng+1)
+ """
+ ll = A + jnp.sum(B_g * (jnp.log(gene_scales) + jnp.log(cnv/2) + state)) + \
+ C - jnp.sum(gene_scales * cnv/2 * jnp.exp(state) * D_g) - E
+ return ll
+
+@jax.jit
+def ll_state_suff(state, cnv, gene_scales, B_g, D_g): # for a single state sample
+ """
+ B_g: \sum_n q(z_n = this node) * x_ng
+ D_g: \sum_n q(z_n = this node) * E[\gamma_n] * E[s_nW_g]
+ """
+ ll = jnp.sum(B_g * state) - jnp.sum(gene_scales * cnv/2 * jnp.exp(state) * D_g)
+ return ll
+
+ll_state_suff_val_and_grad = jax.jit(jax.value_and_grad(ll_state_suff, argnums=0)) # Take grad wrt to psi
+mc_ll_state_suff_val_and_grad = jax.jit(jax.vmap(ll_state_suff_val_and_grad,
+ in_axes=(0,None,0,None, None)))
+
+# To get noise sample
+sample_prod = jax.jit(lambda mat1, mat2: mat1.dot(mat2))
\ No newline at end of file
diff --git a/scatrex/models/cna/opt_funcs.py b/scatrex/models/cna/opt_funcs.py
deleted file mode 100644
index 64ca0ba..0000000
--- a/scatrex/models/cna/opt_funcs.py
+++ /dev/null
@@ -1,2111 +0,0 @@
-import jax
-from jax import jit, grad, vmap
-from jax import random
-from jax.example_libraries import optimizers
-import jax.numpy as jnp
-import jax.nn as jnn
-
-from scatrex.util import *
-from scatrex.callbacks import elbos_callback
-
-from jax.scipy.special import digamma, betaln, gammaln
-
-from functools import partial
-
-def loggaussian_ent(mean, std):
- return jnp.log(std) + mean + 0.5 + 0.5*jnp.log(2*jnp.pi)
-
-def diag_loggaussian_ent(mean, log_std, axis=None):
- return jnp.sum(vmap(loggaussian_ent)(mean, jnp.exp(log_std)), axis=axis)
-
-
-def complete_elbo(self, rng, mc_samples=3):
- rngs = random.split(rng, mc_samples)
-
- # Get data
- data = self.data
- lib_sizes = self.root["node"].root["node"].lib_sizes
- n_cells, n_genes = self.data.shape
-
- # Get global variational parameters
- log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"]
- log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"]
- cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"]
- cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"]
- noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"]
- noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"]
-
- # Sample global
- baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
- cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
- noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std)
-
- def sub_descend(root, depth=0):
- elbo = 0
- subtree_weight = 0
-
- if root["node"].parent() is None:
- parent_unobserved_samples = jnp.zeros((mc_samples, n_genes))
- unobserved_samples = jnp.zeros((mc_samples, n_genes))
- unobserved_kernel_samples = jnp.zeros((mc_samples, n_genes))
- else:
- # Sample parent
- parent_unobserved_means = root["node"].parent().variational_parameters["locals"]["unobserved_factors_mean"]
- parent_unobserved_log_stds = root["node"].parent().variational_parameters["locals"]["unobserved_factors_log_std"]
- parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds)
-
- # Sample node
- unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"]
- unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"]
- unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds)
- unobserved_factors_kernel_log_mean = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]
- unobserved_factors_kernel_log_std = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"]
- unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std)
-
- psi_not_prev_sum = 0
- for i, child in enumerate(root["children"]):
- nu_alpha = child["node"].variational_parameters["locals"]["nu_log_mean"]
- nu_beta = child["node"].variational_parameters["locals"]["nu_log_std"]
- psi_alpha = child["node"].variational_parameters["locals"]["psi_log_mean"]
- psi_beta = child["node"].variational_parameters["locals"]["psi_log_std"]
-
- # Compute local exact KL divergence term
- child["node"].stick_kl = beta_kl(nu_alpha, nu_beta, 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha) + beta_kl(psi_alpha, psi_beta, 1, root["node"].tssb.dp_gamma)
-
- # Compute expected local NTSSB weight term
- E_log_phi = E_q_log_beta(psi_alpha, psi_beta) + psi_not_prev_sum
- E_log_nu = E_q_log_beta(nu_alpha, nu_beta)
- E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta)
-
- child["node"].weight_until_here = child["node"].data_weights + root["node"].weight_until_here
- child["node"].expected_weights = child["node"].data_weights*E_log_nu + root["node"].weight_until_here*E_log_1_nu + child["node"].weight_until_here*E_log_phi
-
-            # Accumulate E[log(1 - psi_i)] after it is used, so the next sibling sees it
-            psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta)
-
- # Go down in the tree
-            child_elbo, child_weight = sub_descend(child, depth=depth+1)
-            elbo += child_elbo
-            subtree_weight += child_weight
-
-
- # Compute local approximate expected log likelihood term
- root["node"].ell = ell(root["node"].data_weights, baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples, root["node"].cnvs, lib_sizes, data)
-
- if root["node"].parent() is None:
- root["node"].param_kl = 0.
- else:
- # Compute local approximate KL divergence term
- root["node"].param_kl = local_paramkl(parent_unobserved_samples, unobserved_samples, unobserved_kernel_samples,
- unobserved_means,
- unobserved_log_stds,
- unobserved_factors_kernel_log_mean,
- unobserved_factors_kernel_log_std,
- jnp.array([0.1, 1.]), jnp.array(0.))
-
- # If this node is the root of its subtree, compute expected weights here
- if depth == 0:
- nu_alpha = child["node"].variational_parameters["locals"]["nu_log_mean"]
- nu_beta = child["node"].variational_parameters["locals"]["nu_log_std"]
- psi_alpha = child["node"].variational_parameters["locals"]["psi_log_mean"]
- psi_beta = child["node"].variational_parameters["locals"]["psi_log_std"]
-
- E_log_nu = E_q_log_beta(nu_alpha, nu_beta)
- root["node"].expected_weights = root["node"].data_weights*E_log_nu
-
- # Compute local exact KL divergence term
- root["node"].stick_kl = beta_kl(nu_alpha, nu_beta, 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha) + beta_kl(psi_alpha, psi_beta, 1, root["node"].tssb.dp_gamma)
-
-
- # Compute local ELBO quantities
- root["node"].local_elbo = root["node"].ell + root["node"].expected_weights - root["node"].data_weights*np.log(root["node"].data_weights)
- root["node"].local_elbo = root["node"].local_elbo - root["node"].param_kl - root["node"].stick_kl
-
-        # Add this node's local contribution (the root node's is removed at the end)
- elbo = elbo + root["node"].local_elbo
-
- # Track total weight in the subtree
- subtree_weight += root["node"].data_weights
-
- return elbo, subtree_weight
-
- def descend(super_root, elbo=0):
- subtree_elbo, subtree_weights = sub_descend(super_root["node"].root)
-        elbo += np.sum(subtree_elbo + subtree_weights * np.log(super_root["node"].weight))
- for super_child in super_root["children"]:
- elbo += descend(super_child)
- return elbo
-
- elbo = descend(self.root)
-
- # Compute global KL
- global_kl = jnp.sum(baseline_kl(log_baseline_mean, log_baseline_log_std))
- global_kl += jnp.sum(noise_factors_kl(noise_factors_mean, noise_factors_log_std))
- global_kl += jnp.sum(cell_noise_kl(cell_noise_mean, cell_noise_log_std))
-
-    # Subtract the global KL and the root node's contribution from the ELBO
- elbo = elbo - global_kl - self.root["node"].root["node"].local_elbo
-
- return elbo
-
-@jax.jit
-def beta_kl(a1, b1, a2, b2):
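-    # Closed form: KL(Beta(a1,b1) || Beta(a2,b2)) = log B(a2,b2) - log B(a1,b1)
-    #   + (a1-a2)*digamma(a1) + (b1-b2)*digamma(b1) + (a2-a1+b2-b1)*digamma(a1+b1)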
- def logbeta_func(a,b):
- return gammaln(a) + gammaln(b) - gammaln(a+b)
-
-    kl = logbeta_func(a2,b2) - logbeta_func(a1,b1)
- kl += (a1-a2)*digamma(a1)
- kl += (b1-b2)*digamma(b1)
- kl += (a2-a1 + b2-b1)*digamma(a1+b1)
- return kl
-
-@jax.jit
-def E_q_log_beta(
- alpha,
- beta,
-):
- return digamma(alpha) - digamma(alpha + beta)
-
-@jax.jit
-def E_q_log_1_beta(
- alpha,
- beta,
-):
- return digamma(beta) - digamma(alpha + beta)
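-# For x ~ Beta(alpha, beta): E[log x] = digamma(alpha) - digamma(alpha + beta) and
-# E[log(1 - x)] = digamma(beta) - digamma(alpha + beta), which is exactly what the
-# two helpers above return.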
-
-# This computes E_q[log p(z_n=\epsilon | \nu, \psi)]
-@jax.jit
-def compute_expected_weight(
- nu_alpha,
- nu_beta,
- psi_alpha,
- psi_beta,
- prev_nu_sticks_sum,
- prev_psi_sticks_sum,
-):
- nu_digamma_sum = digamma(nu_alpha + nu_beta)
- E_log_nu = digamma(nu_alpha) - nu_digamma_sum
- nu_sticks_sum = digamma(nu_beta) - nu_digamma_sum + prev_nu_sticks_sum
- psi_sticks_sum = local_psi_sticks_sum(psi_alpha, psi_beta) + prev_psi_sticks_sum
- weight = E_log_nu + nu_sticks_sum + psi_sticks_sum
- return weight, nu_sticks_sum, psi_sticks_sum
-
-@jax.jit
-def sample_baseline(
- rngs,
- log_baseline_mean,
- log_baseline_log_std,
-):
- def _sample(rng, log_baseline_mean, log_baseline_log_std):
- return jnp.append(1, jnp.exp(diag_gaussian_sample(rng, log_baseline_mean, log_baseline_log_std)))
- vectorized_sample = vmap(_sample, in_axes=[0, None, None])
- return vectorized_sample(rngs, log_baseline_mean, log_baseline_log_std)
-
-@jax.jit
-def sample_cell_noise_factors(
- rngs,
- cell_noise_mean,
- cell_noise_log_std,
-):
- def _sample(rng, cell_noise_mean, cell_noise_log_std):
- return diag_gaussian_sample(rng, cell_noise_mean, cell_noise_log_std)
- vectorized_sample = vmap(_sample, in_axes=[0, None, None])
- return vectorized_sample(rngs, cell_noise_mean, cell_noise_log_std)
-
-@jax.jit
-def sample_noise_factors(
- rngs,
- noise_factors_mean,
- noise_factors_log_std,
-):
- def _sample(rng, noise_factors_mean, noise_factors_log_std):
- return diag_gaussian_sample(rng, noise_factors_mean, noise_factors_log_std)
- vectorized_sample = vmap(_sample, in_axes=[0, None, None])
- return vectorized_sample(rngs, noise_factors_mean, noise_factors_log_std)
-
-@jax.jit
-def sample_unobserved(
- rngs,
- unobserved_means,
- unobserved_log_stds,
-):
- def _sample(rng, unobserved_means, unobserved_log_stds):
- return diag_gaussian_sample(rng, unobserved_means, unobserved_log_stds)
- vectorized_sample = vmap(_sample, in_axes=[0, None, None])
- return vectorized_sample(rngs, unobserved_means, unobserved_log_stds)
-
-@jax.jit
-def sample_unobserved_kernel(
- rngs,
- log_unobserved_factors_kernel_means,
- log_unobserved_factors_kernel_log_stds,
-):
- def _sample(rng, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds):
- return jnp.exp(diag_gaussian_sample(rng, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds))
- vectorized_sample = vmap(_sample, in_axes=[0, None, None])
- return vectorized_sample(rngs, log_unobserved_factors_kernel_means, log_unobserved_factors_kernel_log_stds)
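-# The sampling helpers above (baseline, cell noise, noise factors, unobserved factors
-# and their kernels) share one pattern: a single diag_gaussian_sample draw is vmapped
-# over the stack of PRNG keys in `rngs`, so the leading output axis indexes Monte
-# Carlo samples. Baseline and kernel samples are exponentiated to keep them positive.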
-
-@jax.jit
-def baseline_kl(
- log_baseline_mean,
- log_baseline_log_std,
-):
- std = jnp.exp(log_baseline_log_std)
- return -log_baseline_log_std + .5 * (std**2 + log_baseline_mean**2 - 1.)
-
-
-@jax.jit
-def cell_noise_kl(
- cell_noise_mean,
- cell_noise_log_std,
-):
- std = jnp.exp(cell_noise_log_std)
- return -cell_noise_log_std + .5 * (std**2 + cell_noise_mean**2 - 1.)
-
-
-@jax.jit
-def noise_factors_kl(
- noise_factors_mean,
- noise_factors_log_std,
-):
- std = jnp.exp(noise_factors_log_std)
- return -noise_factors_log_std + .5 * (std**2 + noise_factors_mean**2 - 1.)
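-# baseline_kl, cell_noise_kl and noise_factors_kl above all evaluate the elementwise
-# KL(N(mu, sigma^2) || N(0, 1)) = 0.5*(sigma^2 + mu^2 - 1) - log(sigma), with
-# sigma = exp(log_std).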
-
-@jax.jit
-def ell(
- data_node_weights, # lambda_{nk} with only the node attachments
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
-    unobserved_samples,  # per-node unobserved factor samples (gene-length)
-    cnv,  # per-node copy number profile (gene-length)
- lib_sizes,
- data,
- mask,
-):
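-    # Monte Carlo estimate of each cell's expected Poisson log-likelihood under this
-    # node: the per-gene rate is lib_size * normalize(baseline * cnv/2 * exp(unobserved
-    # + noise)), and the result is weighted by the cell-to-node responsibilities
-    # (data_node_weights) and the minibatch mask.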
- def _ell(data_node_weights, baseline_sample, cell_noise_sample, noise_factor_sample, unobserved_sample):
- noise_sample = cell_noise_sample.dot(noise_factor_sample)
- node_mean = (
- baseline_sample * cnv/2 * jnp.exp(unobserved_sample + noise_sample)
- )
- sum = jnp.sum(node_mean, axis=1).reshape(-1, 1)
- node_mean = node_mean / sum
- node_mean = node_mean * lib_sizes
- pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean)
- ell = jnp.sum(pll, axis=1) # N-vector
- ell *= data_node_weights
- ell *= mask
- return ell
- vectorized = vmap(_ell, in_axes=[None, 0,0,0,0])
- return jnp.mean(vectorized(data_node_weights, baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples), axis=0)
-
-
-@jax.jit
-def ll(
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- unobserved_samples,
- cnv,
- lib_sizes,
- data,
-):
- def _ll(baseline_sample, cell_noise_sample, noise_factor_sample, unobserved_sample):
- noise_sample = cell_noise_sample.dot(noise_factor_sample)
- node_mean = (
- baseline_sample * cnv/2 * jnp.exp(unobserved_sample + noise_sample)
- )
- sum = jnp.sum(node_mean, axis=1).reshape(-1, 1)
- node_mean = node_mean / sum
- node_mean = node_mean * lib_sizes
- pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean)
- ll = jnp.sum(pll, axis=1) # N-vector
- return ll
- vectorized = vmap(_ll, in_axes=[0,0,0,0])
- return jnp.mean(vectorized(baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples), axis=0)
-
-
-
-@jax.jit
-def local_paramkl(
- parent_unobserved_samples,
- child_unobserved_samples,
- child_unobserved_kernel_samples,
- child_unobserved_means,
- child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means,
- child_log_unobserved_factors_kernel_log_stds,
- hyperparams, # [concentration, rate]
- child_in_root_subtree, # to make the prior prefer amplifications
-):
- def _local_paramkl(
- parent_unobserved_sample,
- child_unobserved_sample,
- child_unobserved_kernel_sample,
- child_unobserved_means,
- child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means,
- child_log_unobserved_factors_kernel_log_stds,
- hyperparams,
- ):
- kl = 0.0
- # Kernel
- pl = diag_gamma_logpdf(
- child_unobserved_kernel_sample,
- jnp.log(hyperparams[0] * jnp.ones((child_unobserved_kernel_sample.shape[0],))),
- (hyperparams[1]*jnp.abs(parent_unobserved_sample)),
- )
- ent = -diag_loggaussian_logpdf(
- child_unobserved_kernel_sample,
- child_log_unobserved_factors_kernel_means,
- child_log_unobserved_factors_kernel_log_stds,
- )
-# ent = diag_loggaussian_ent(child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds)
- kl += pl + ent
-
- # Cell state
- pl = diag_gaussian_logpdf(
- child_unobserved_sample,
- parent_unobserved_sample,
- jnp.log(0.001 + child_unobserved_kernel_sample),
- )
- ent = -diag_gaussian_logpdf(
- child_unobserved_sample, child_unobserved_means, child_unobserved_log_stds
- )
- kl += pl + ent
-
- return kl
- vectorized = vmap(_local_paramkl, in_axes=[0,0,0,None,None,None,None,None])
- return jnp.mean(vectorized(parent_unobserved_samples, child_unobserved_samples,
- child_unobserved_kernel_samples, child_unobserved_means, child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds, hyperparams))
-
-@jax.jit
-def update_local_parameters(
- rngs,
- child_unobserved_means,
- child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means,
- child_log_unobserved_factors_kernel_log_stds,
- data_node_weights,
- parent_unobserved_samples,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- hyperparams, # [concentration, rate]
- child_in_root_subtree, # to make the prior prefer amplifications
- cnv,
- lib_sizes,
- data,
- mask,
- states,
- i,
- mb_scaling=1,
- lr=0.01
- ):
-
- def local_loss(
- params,
- parent_unobserved_samples,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- hyperparams,
- child_in_root_subtree
- ):
- child_unobserved_means = params[0]
- child_unobserved_log_stds = params[1]
- child_log_unobserved_factors_kernel_means = params[2]
- child_log_unobserved_factors_kernel_log_stds = params[3]
-
- child_unobserved_kernel_samples = sample_unobserved_kernel(rngs, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds)
- child_unobserved_samples = sample_unobserved(rngs, child_unobserved_means, child_unobserved_log_stds)
-
- child_unobserved_kernel_samples = jnp.clip(child_unobserved_kernel_samples, a_min=1e-8, a_max=5)
-
- loss = 0.
-
- loss = jnp.sum(ell(data_node_weights,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- child_unobserved_samples,
- cnv,
- lib_sizes,
- data,
- mask,))
- kl = local_paramkl(parent_unobserved_samples,
- child_unobserved_samples,
- child_unobserved_kernel_samples,
- child_unobserved_means,
- child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means,
- child_log_unobserved_factors_kernel_log_stds,
- hyperparams, # [concentration, rate]
- child_in_root_subtree,)
- # Scale by minibatch
- loss = loss + kl*mb_scaling
- return loss
-
- child_unobserved_log_stds = jnp.clip(child_unobserved_log_stds, a_min=jnp.log(1e-8))
- child_log_unobserved_factors_kernel_means = jnp.clip(child_log_unobserved_factors_kernel_means, a_min=jnp.log(1e-8), a_max=0.)
- child_log_unobserved_factors_kernel_log_stds = jnp.clip(child_log_unobserved_factors_kernel_log_stds, a_min=jnp.log(1e-8), a_max=0.)
-
- params = jnp.array([child_unobserved_means, child_unobserved_log_stds,
- child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds])
- loss, grads = jax.value_and_grad(local_loss)(params, parent_unobserved_samples,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- hyperparams,
- child_in_root_subtree)
-
- state1, state2, state3, state4 = states
-
-
-
- m3, v3 = state3
- b1=0.9
- b2=0.999
- eps=1e-8
- m3 = (1 - b1) * grads[2] + b1 * m3 # First moment estimate.
- v3 = (1 - b2) * jnp.square(grads[2]) + b2 * v3 # Second moment estimate.
- state3 = (m3, v3)
- mhat = m3 / (1 - jnp.asarray(b1, m3.dtype) ** (i + 1)) # Bias correction.
- vhat = v3 / (1 - jnp.asarray(b2, m3.dtype) ** (i + 1))
- child_log_unobserved_factors_kernel_means = child_log_unobserved_factors_kernel_means + lr * mhat / (jnp.sqrt(vhat) + eps)
-
-
- m4, v4 = state4
- b1=0.9
- b2=0.999
- eps=1e-8
- m4 = (1 - b1) * grads[3] + b1 * m4 # First moment estimate.
- v4 = (1 - b2) * jnp.square(grads[3]) + b2 * v4 # Second moment estimate.
-    state4 = (m4, v4)
- mhat = m4 / (1 - jnp.asarray(b1, m4.dtype) ** (i + 1)) # Bias correction.
- vhat = v4 / (1 - jnp.asarray(b2, m4.dtype) ** (i + 1))
- child_log_unobserved_factors_kernel_log_stds = child_log_unobserved_factors_kernel_log_stds + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- m1, v1 = state1
- b1=0.9
- b2=0.999
- eps=1e-8
- m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate.
- v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate.
- state1 = (m1, v1)
- mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction.
- vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1))
- child_unobserved_means = child_unobserved_means + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- m2, v2 = state2
- m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate.
- v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate.
- state2 = (m2, v2)
- mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction.
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1))
- child_unobserved_log_stds = child_unobserved_log_stds + lr * mhat / (jnp.sqrt(vhat) + eps)
-
-
-
- states = (state1, state2, state3, state4)
-
-
- return loss, states, child_unobserved_means, child_unobserved_log_stds, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds
-
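-# The Adam updates in update_local_parameters above (and in baseline_step, noise_step
-# and cellnoise_step below) are hand-unrolled per parameter. A minimal sketch of the
-# same ascent step, assuming the same hyperparameters (b1=0.9, b2=0.999, eps=1e-8);
-# the name `adam_ascent_step` is illustrative and not part of the original code:
-def adam_ascent_step(param, grad, state, i, lr=0.01, b1=0.9, b2=0.999, eps=1e-8):
-    m, v = state
-    m = (1 - b1) * grad + b1 * m  # first moment estimate
-    v = (1 - b2) * jnp.square(grad) + b2 * v  # second moment estimate
-    mhat = m / (1 - b1 ** (i + 1))  # bias correction
-    vhat = v / (1 - b2 ** (i + 1))
-    return param + lr * mhat / (jnp.sqrt(vhat) + eps), (m, v)  # ascent step and new state
-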
-@jax.jit
-def baseline_node_grad(
- rngs,
- log_baseline_mean,
- log_baseline_log_std,
- data_node_weights,
- child_unobserved_samples,
- cell_noise_samples,
- noise_factor_samples,
- cnv,
- lib_sizes,
- data,
- mask,):
- def local_loss(
- params,
- child_unobserved_samples,
- cell_noise_samples,
- noise_factor_samples,
- ):
- log_baseline_mean, log_baseline_log_std = params[0], params[1]
- baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
-
- loss = jnp.sum(ell(data_node_weights,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- child_unobserved_samples,
- cnv,
- lib_sizes,
- data,
- mask,))
- return loss
-
- params = jnp.array([log_baseline_mean, log_baseline_log_std])
- grads = jax.grad(local_loss)(params, child_unobserved_samples, cell_noise_samples, noise_factor_samples,)
- return grads
-
-@jax.jit
-def baseline_kl_grad(mean, log_std):
- def _kl(mean, log_std):
- return jnp.sum(baseline_kl(mean, log_std))
- return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std))
-
-@jax.jit
-def baseline_step(log_baseline_mean, log_baseline_log_std, grads, states, i, lr=0.01):
- state1, state2 = states
-
- m1, v1 = state1
- b1=0.9
- b2=0.999
- eps=1e-8
- m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate.
- v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate.
- state1 = (m1, v1)
- mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction.
- vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1))
- log_baseline_mean = log_baseline_mean + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- m2, v2 = state2
- m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate.
- v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate.
- state2 = (m2, v2)
- mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction.
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1))
- log_baseline_log_std = log_baseline_log_std + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- state = (state1, state2)
-
- return state, log_baseline_mean, log_baseline_log_std
-
-@jax.jit
-def noise_node_grad(
- rngs,
- noise_factors_mean,
- noise_factors_log_std,
- data_node_weights,
- child_unobserved_samples,
- cell_noise_samples,
- baseline_samples,
- cnv,
- lib_sizes,
- data,
- mask,):
- def local_loss(
- params,
- child_unobserved_samples,
- cell_noise_samples,
- baseline_samples,
- ):
- noise_factors_mean, noise_factors_log_std = params[0], params[1]
- noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std)
-
- loss = jnp.sum(ell(data_node_weights,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- child_unobserved_samples,
- cnv,
- lib_sizes,
- data,
- mask,))
- return loss
-
- params = jnp.array([noise_factors_mean, noise_factors_log_std])
- grads = jnp.array(jax.grad(local_loss)(params, child_unobserved_samples, cell_noise_samples, baseline_samples,))
- return grads
-
-@jax.jit
-def noise_kl_grad(mean, log_std):
- def _kl(mean, log_std):
- return jnp.sum(noise_factors_kl(mean, log_std))
- return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std))
-
-@jax.jit
-def noise_step(noise_factors_mean, noise_factors_log_std, grads, states, i, lr=0.01):
- state1, state2 = states
-
- m1, v1 = state1
- b1=0.9
- b2=0.999
- eps=1e-8
- m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate.
- v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate.
- state1 = (m1, v1)
- mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction.
- vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1))
- noise_factors_mean = noise_factors_mean + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- m2, v2 = state2
- m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate.
- v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate.
- state2 = (m2, v2)
- mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction.
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1))
- noise_factors_log_std = noise_factors_log_std + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- state = (state1, state2)
-
- return state, noise_factors_mean, noise_factors_log_std
-
-
-@jax.jit
-def cellnoise_node_grad(
- rngs,
- cell_noise_mean,
- cell_noise_log_std,
- data_node_weights,
- child_unobserved_samples,
- noise_factor_samples,
- baseline_samples,
- cnv,
- lib_sizes,
- data,
- mask,):
- def local_loss(
- params,
- child_unobserved_samples,
- noise_factor_samples,
- baseline_samples,
- ):
- cell_noise_mean, cell_noise_log_std = params[0], params[1]
- cell_noise_samples = sample_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
-
- loss = jnp.sum(ell(data_node_weights,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- child_unobserved_samples,
- cnv,
- lib_sizes,
- data,
- mask,))
- return loss
-
- params = jnp.array([cell_noise_mean, cell_noise_log_std])
- grads = jnp.array(jax.grad(local_loss)(params, child_unobserved_samples, noise_factor_samples, baseline_samples,))
- return grads
-
-@jax.jit
-def cellnoise_kl_grad(mean, log_std):
- def _kl(mean, log_std):
- return jnp.sum(cell_noise_kl(mean, log_std))
- return jnp.array(jax.grad(_kl, argnums=(0,1))(mean, log_std))
-
-@jax.jit
-def cellnoise_step(cell_noise_mean, cell_noise_log_std, grads, states, i, lr=0.01):
- state1, state2 = states
-
- m1, v1 = state1
- b1=0.9
- b2=0.999
- eps=1e-8
- m1 = (1 - b1) * grads[0] + b1 * m1 # First moment estimate.
- v1 = (1 - b2) * jnp.square(grads[0]) + b2 * v1 # Second moment estimate.
- state1 = (m1, v1)
- mhat = m1 / (1 - jnp.asarray(b1, m1.dtype) ** (i + 1)) # Bias correction.
- vhat = v1 / (1 - jnp.asarray(b2, m1.dtype) ** (i + 1))
- cell_noise_mean = cell_noise_mean + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- m2, v2 = state2
- m2 = (1 - b1) * grads[1] + b1 * m2 # First moment estimate.
- v2 = (1 - b2) * jnp.square(grads[1]) + b2 * v2 # Second moment estimate.
- state2 = (m2, v2)
- mhat = m2 / (1 - jnp.asarray(b1, m2.dtype) ** (i + 1)) # Bias correction.
- vhat = v2 / (1 - jnp.asarray(b2, m2.dtype) ** (i + 1))
- cell_noise_log_std = cell_noise_log_std + lr * mhat / (jnp.sqrt(vhat) + eps)
-
- state = (state1, state2)
-
- return state, cell_noise_mean, cell_noise_log_std
-
-
-#
-# @jax.jit
-# def update_global_parameters(
-# hyperparams, # [concentration, rate]
-# child_in_root_subtree, # to make the prior prefer amplifications
-# cnv,
-# lib_sizes,
-# data,
-# parent_unobserved_sample,
-# child_unobserved_sample,
-# child_unobserved_kernel_sample,
-# ):
-# def local_loss(local_params):
-# return ell(local_params) + baseline_kl(baseline)
-# return loss
-#
-#
-# return child_unobserved_means, child_unobserved_log_stds, child_log_unobserved_factors_kernel_means, child_log_unobserved_factors_kernel_log_stds
-
-#
-# def tree_traversal_compute_elbo(rng, mc_samples=3):
-# """
-# This function traverses the tree starting at the root and computes the
-# complete ELBO
-# """
-# rngs = random.split(rng, mc_samples)
-#
-# # Get data
-# data = self.data
-# lib_sizes = self.root["node"].root["node"].lib_sizes
-# n_cells, n_genes = self.data.shape
-#
-# # Get global variational parameters
-# log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"]
-# log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"]
-# cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"]
-# cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"]
-# noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"]
-# noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"]
-#
-# # Sample global
-# baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
-# cell_noise_factors_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
-# noise_factors_samples = sample_noise_factors(rng, noise_factors_mean, noise_factors_log_std)
-# noise_samples = cell_noise_factors_samples.dot(noise_factors_samples)
-#
-# # Traverse tree
-# def descend(root):
-# total_subtree_weight = 0
-# indices = list(range(len(root["children"])))
-# # indices = indices[::-1]
-#
-# parent_unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"]
-# parent_unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"]
-# # Sample parent
-# parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds)
-# psi_not_prev_sum = 0
-# for i in indices:
-# child = root["children"][i]
-# cnv = child.cnvs * jnp.ones((n_cells,n_genes))
-# data_node_weights = child.data_weights
-#
-# # Sample node
-# unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds)
-# unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_means, unobserved_factors_kernel_log_stds)
-#
-# # Compute node ll
-# unobserved_samples = unobserved_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast
-# unobserved_kernel_samples = unobserved_kernel_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast
-# ll = ell(data_node_weights, baseline_samples, noise_samples, unobserved_samples, cnv, lib_sizes, data)
-# child.ell = ll
-# child.param_kl = local_paramkl(unobserved_samples, unobserved_kernel_samples, parent_unobserved_samples)
-#
-# E_log_phi = E_q_log_beta(psi_alpha, psi_beta) + psi_not_prev_sum
-# E_log_nu = E_q_log_beta(nu_alpha, nu_beta)
-# E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta)
-#
-# child.weight_until_here = child.data_weights + root.weight_until_here
-# child.expected_weights = child.data_weights*E_log_nu + root.weight_until_here*E_log_1_nu + child.weight_until_here*E_log_phi
-# total_weight += child.data_weights
-#
-# if i > 0:
-# psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta)
-#
-# node.stick_kl = beta_kl(root["node"].variational_parameters["locals"]["nu_log_mean"], 1, root["node"].variational_parameters["locals"]["nu_log_mean"], root.tssb.dp_alpha)
-# node.stick_kl += beta_kl(root["node"].variational_parameters["locals"]["psi_log_mean"], 1, root["node"].variational_parameters["locals"]["psi_log_mean"], root.tssb.dp_gamma)
-#
-# expected_weight, nu_sticks_sum, psi_sticks_sum = compute_expected_weight(nu_alpha, nu_beta, psi_alpha, psi_beta, prev_nu_sticks_sum, prev_psi_sticks_sum)
-# node.ew = expected_weight
-#
-# return total_weight, lls, expected_weight, nodes
-#
-# _, lls, expected_w, nodes = descend(self.root)
-#
-# # Normalize data_node_weights and compute local ELBO contributions
-# data_weights = []
-# for node in nodes:
-# data_weights.append(node.data_weights)
-# data_weights = np.array(data_weights).reshape(-1,mb_size)/np.sum(data_weights)
-# elbo = 0
-# for i, node in nodes:
-# node.data_weights = data_weights[data_indices,i]
-# # Compute ELBO using normalized data_weights
-# node_data_elbo_contributions = node.data_weights*node.ell + node.ew - node.data_weights*np.log(node.data_weights))
-# node_kl = node.stick_kl + node.param_kl
-# elbo += node_data_elbo_contributions - node_kl
-#
-# return elbo
-#
-#
-# def tree_traversal_update(root, rng, update_global=True, n_inner_steps=10, mc_samples=3):
-# """
-# This function traverses the tree starting at `root` and updates the
-# variational parameters and ELBO contributions while doing it
-# """
-# rngs = random.split(rng, mc_samples)
-#
-# # Get data
-# data = self.data
-# lib_sizes = self.root["node"].root["node"].lib_sizes
-# n_cells, n_genes = self.data.shape
-#
-# # Get global variational parameters
-# log_baseline_mean = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"]
-# log_baseline_log_std = self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"]
-# cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"]
-# cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"]
-# noise_factors_mean = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"]
-# noise_factors_log_std = self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"]
-#
-# # Sample global
-# baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
-# cell_noise_factors_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
-# noise_factors_samples = sample_noise_factors(rng, noise_factors_mean, noise_factors_log_std)
-# noise_samples = cell_noise_factors_samples.dot(noise_factors_samples)
-#
-# # Traverse tree and update variational parameters
-# def descend(root, depth=0):
-# weight_down = 0
-# indices = list(range(len(root["children"])))
-# indices = indices[::-1]
-#
-# parent_unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"]
-# parent_unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"]
-# for i in indices:
-# child = root["children"][i]
-# cnv = child.cnvs * jnp.ones((n_cells,n_genes))
-# data_node_weights = child.data_weights
-#
-# # Sample parent
-# parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds)
-#
-# # Update local parameters
-# unobserved_means = root["node"].variational_parameters["locals"]["unobserved_factors_mean"]
-# unobserved_log_stds = root["node"].variational_parameters["locals"]["unobserved_factors_log_std"]
-# unobserved_factors_kernel_log_mean = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]
-# unobserved_factors_kernel_log_std = root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"]
-# update_local_parameters(baseline_samples, noise_samples, parent_unobserved_samples, steps=n_inner_steps)
-#
-# # Sample updated node
-# unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds)
-# unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_means, unobserved_factors_kernel_log_stds)
-#
-# # Compute node ll
-# unobserved_samples = unobserved_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast
-# unobserved_kernel_samples = unobserved_kernel_samples * jnp.ones((mc_samples, n_cells, n_genes)) # broadcast
-# ll = ell(data_node_weights, baseline_samples, noise_samples, unobserved_samples, cnv, lib_sizes, data)
-# child.ell[data_indices] = ll
-# child.param_kl = local_paramkl(unobserved_samples, unobserved_kernel_samples, parent_unobserved_samples)
-#
-# child_weight, _, _, _ = descend(child, depth + 1)
-# post_alpha = 1.0 + child_weight
-# post_beta = self.dp_gamma + weight_down
-# child["node"].variational_parameters["locals"]["psi_log_mean"] = np.log(post_alpha)
-# child["node"].variational_parameters["locals"]["psi_log_std"] = np.log(post_beta)
-# weight_down += child_weight
-#
-# prev_psi_sticks_sum += local_psi_sticks_sum(psi_alpha, psi_beta)
-#
-# weight_here = np.sum(root["node"].data_weights)
-# total_weight = weight_here + weight_down
-# post_alpha = 1.0 + weight_here
-# post_beta = (self.alpha_decay**depth) * self.dp_alpha + weight_down
-# root["node"].variational_parameters["locals"]["nu_log_mean"] = np.log(post_alpha)
-# root["node"].variational_parameters["locals"]["nu_log_std"] = np.log(post_beta)
-#
-# node.stick_kl = beta_kl(root["node"].variational_parameters["locals"]["nu_log_mean"], root["node"].variational_parameters["locals"]["nu_log_mean"])
-# node.stick_kl += beta_kl(root["node"].variational_parameters["locals"]["psi_log_mean"], root["node"].variational_parameters["locals"]["psi_log_mean"])
-#
-# expected_weight, nu_sticks_sum, psi_sticks_sum = compute_expected_weight(nu_alpha, nu_beta, psi_alpha, psi_beta, prev_nu_sticks_sum, prev_psi_sticks_sum)
-# node.ew = expected_weight
-# node.data_weights[data_indices] = np.exp(node.ell + node.ew)
-#
-# return total_weight, lls, expected_weight, nodes
-#
-# _, lls, expected_w, nodes = descend(root)
-#
-# # Update global parameters
-# if root.parent() is None and update_global:
-# # Use samples from tree traversal to update global parameters
-# update_global_parameters(steps=n_inner_steps)
-#
-# # Compute global KL
-# global_kl = baseline_kl(log_baseline_mean, log_baseline_log_std)
-#
-# # Use lls from tree traversal to update data_node_weights
-# data_weights = []
-# for node in nodes:
-# data_weights.append(node.data_weights)
-# data_weights = np.array(data_weights).reshape(-1,mb_size)/np.sum(data_weights)
-# for i, node in nodes:
-# node.data_weights[data_indices] = data_weights[data_indices,i]
-# # Compute ELBO using normalized data_weights
-# node_data_elbo_contributions = node.data_weights*(node.ell + node.ew - np.log(node.data_weights))
-# node_elbo_contributions = node.stick_kl + node.param_kl
-#
-# # Compute total elbo: use decomposibility!
-# elbo =
-#
-# return elbo
-
-
-def compute_elbo(
- rng,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- cnvs,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- unobserved_means,
- unobserved_log_stds,
- log_unobserved_factors_kernel_means,
- log_unobserved_factors_kernel_log_stds,
- log_baseline_mean,
- log_baseline_log_std,
- cell_noise_mean,
- cell_noise_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
-):
- elbo, ll, kl, node_kl = _compute_elbo(
- rng,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- cnvs,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- unobserved_means,
- unobserved_log_stds,
- log_unobserved_factors_kernel_means,
- log_unobserved_factors_kernel_log_stds,
- log_baseline_mean,
- log_baseline_log_std,
- cell_noise_mean,
- cell_noise_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- )
- return elbo
-
-
-def _compute_elbo(
- rng,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- cnvs,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- unobserved_means,
- unobserved_log_stds,
- log_unobserved_factors_kernel_means,
- log_unobserved_factors_kernel_log_stds,
- log_baseline_mean,
- log_baseline_log_std,
- cell_noise_mean,
- cell_noise_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
-):
-
- # single-sample Monte Carlo estimate of the variational lower bound
- mb_size = len(indices)
-
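-    # Pattern used below: jax.lax.cond switches between the raw variational parameters
-    # and jax.lax.stop_gradient copies, so a single jitted ELBO can selectively freeze
-    # groups of parameters (controlled by do_global, node_mask, sticks_only, do_sticks
-    # and global_only) without retracing.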
- def stop_global(globals):
- (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- ) = (
- globals[0],
- globals[1],
- globals[2],
- globals[3],
- globals[4],
- globals[5],
- globals[6],
- globals[7],
- )
- log_baseline_mean = jax.lax.stop_gradient(log_baseline_mean)
- log_baseline_log_std = jax.lax.stop_gradient(log_baseline_log_std)
- noise_factors_mean = jax.lax.stop_gradient(noise_factors_mean)
- noise_factors_log_std = jax.lax.stop_gradient(noise_factors_log_std)
- factor_precision_log_means = jax.lax.stop_gradient(factor_precision_log_means)
- factor_precision_log_stds = jax.lax.stop_gradient(factor_precision_log_stds)
- batch_effects_mean = jax.lax.stop_gradient(batch_effects_mean)
- batch_effects_log_std = jax.lax.stop_gradient(batch_effects_log_std)
- return (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- )
-
- def alt_global(globals):
- (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- ) = (
- globals[0],
- globals[1],
- globals[2],
- globals[3],
- globals[4],
- globals[5],
- globals[6],
- globals[7],
- )
- return (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- )
-
- (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- ) = jax.lax.cond(
- do_global,
- alt_global,
- stop_global,
- (
- log_baseline_mean,
- log_baseline_log_std,
- noise_factors_mean,
- noise_factors_log_std,
- factor_precision_log_means,
- factor_precision_log_stds,
- batch_effects_mean,
- batch_effects_log_std,
- ),
- )
-
- def stop_node_grad(i):
- return (
- jax.lax.stop_gradient(nu_sticks_log_means[i]),
- jax.lax.stop_gradient(nu_sticks_log_stds[i]),
- jax.lax.stop_gradient(psi_sticks_log_means[i]),
- jax.lax.stop_gradient(psi_sticks_log_stds[i]),
- jax.lax.stop_gradient(log_unobserved_factors_kernel_log_stds[i]),
- jax.lax.stop_gradient(log_unobserved_factors_kernel_means[i]),
- jax.lax.stop_gradient(unobserved_log_stds[i]),
- jax.lax.stop_gradient(unobserved_means[i]),
- )
-
- def stop_node_grads(i):
- return jax.lax.cond(
- node_mask[i] != 1,
- stop_node_grad,
- lambda i: (
- nu_sticks_log_means[i],
- nu_sticks_log_stds[i],
- psi_sticks_log_means[i],
- psi_sticks_log_stds[i],
- log_unobserved_factors_kernel_log_stds[i],
- log_unobserved_factors_kernel_means[i],
- unobserved_log_stds[i],
- unobserved_means[i],
- ),
- i,
- ) # Sample all
-
- (
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- ) = vmap(stop_node_grads)(jnp.arange(len(cnvs)))
-
- def stop_node_params_grads(locals):
- return (
- jax.lax.stop_gradient(log_unobserved_factors_kernel_log_stds),
- jax.lax.stop_gradient(log_unobserved_factors_kernel_means),
- jax.lax.stop_gradient(unobserved_log_stds),
- jax.lax.stop_gradient(unobserved_means),
- )
-
- def alt_node_params(locals):
- return (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- )
-
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- ) = jax.lax.cond(
- sticks_only,
- stop_node_params_grads,
- alt_node_params,
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- ),
- )
-
- def stop_sticks(locals):
- return (
- jax.lax.stop_gradient(nu_sticks_log_means),
- jax.lax.stop_gradient(nu_sticks_log_stds),
- jax.lax.stop_gradient(psi_sticks_log_means),
- jax.lax.stop_gradient(psi_sticks_log_stds),
- )
-
- def keep_sticks(locals):
- return (
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- )
-
- (
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- ) = jax.lax.cond(
- do_sticks,
- keep_sticks,
- stop_sticks,
- (
- nu_sticks_log_means,
- nu_sticks_log_stds,
- psi_sticks_log_means,
- psi_sticks_log_stds,
- ),
- )
-
- def stop_non_global(locals):
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- ) = (
- locals[0],
- locals[1],
- locals[2],
- locals[3],
- locals[4],
- locals[5],
- locals[6],
- locals[7],
- )
- log_unobserved_factors_kernel_log_stds = jax.lax.stop_gradient(
- log_unobserved_factors_kernel_log_stds
- )
- log_unobserved_factors_kernel_means = jax.lax.stop_gradient(
- log_unobserved_factors_kernel_means
- )
- unobserved_log_stds = jax.lax.stop_gradient(unobserved_log_stds)
- unobserved_means = jax.lax.stop_gradient(unobserved_means)
- psi_sticks_log_stds = jax.lax.stop_gradient(psi_sticks_log_stds)
- psi_sticks_log_means = jax.lax.stop_gradient(psi_sticks_log_means)
- nu_sticks_log_stds = jax.lax.stop_gradient(nu_sticks_log_stds)
- nu_sticks_log_means = jax.lax.stop_gradient(nu_sticks_log_means)
- return (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- )
-
- def alt_non_global(locals):
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- ) = (
- locals[0],
- locals[1],
- locals[2],
- locals[3],
- locals[4],
- locals[5],
- locals[6],
- locals[7],
- )
- return (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- )
-
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- ) = jax.lax.cond(
- global_only,
- stop_non_global,
- alt_non_global,
- (
- log_unobserved_factors_kernel_log_stds,
- log_unobserved_factors_kernel_means,
- unobserved_log_stds,
- unobserved_means,
- psi_sticks_log_stds,
- psi_sticks_log_means,
- nu_sticks_log_stds,
- nu_sticks_log_means,
- ),
- )
-
- def keep_cell_grad(i):
- return cell_noise_mean[indices][i], cell_noise_log_std[indices][i]
-
- def stop_cell_grad(i):
- return jax.lax.stop_gradient(
- cell_noise_mean[indices][i]
- ), jax.lax.stop_gradient(cell_noise_log_std[indices][i])
-
- def stop_cell_grads(i):
- return jax.lax.cond(data_mask_subset[i] != 1, stop_cell_grad, keep_cell_grad, i)
-
- cell_noise_mean, cell_noise_log_std = vmap(stop_cell_grads)(jnp.arange(mb_size))
-
- def has_children(i):
- return jnp.any(ancestor_nodes_indices.ravel() == i)
-
- def has_next_branch(i):
- return jnp.any(previous_branches_indices.ravel() == i)
-
- ones_vec = jnp.ones(cnvs[0].shape)
- zeros_vec = jnp.log(ones_vec)
-
- log_baseline = diag_gaussian_sample(rng, log_baseline_mean, log_baseline_log_std)
-
- # noise
- factor_precision_log_means = jnp.clip(
- factor_precision_log_means, a_min=jnp.log(1e-3), a_max=jnp.log(1e3)
- )
- factor_precision_log_stds = jnp.clip(
- factor_precision_log_stds, a_min=jnp.log(1e-3), a_max=jnp.log(1e2)
- )
- log_factors_precisions = diag_gaussian_sample(
- rng, factor_precision_log_means, factor_precision_log_stds
- )
- noise_factors_mean = jnp.clip(noise_factors_mean, a_min=-10.0, a_max=10.0)
- noise_factors_log_std = jnp.clip(
- noise_factors_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2)
- )
- noise_factors = diag_gaussian_sample(rng, noise_factors_mean, noise_factors_log_std)
- cell_noise_mean = jnp.clip(cell_noise_mean, a_min=-10.0, a_max=10.0)
- cell_noise_log_std = jnp.clip(
- cell_noise_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2)
- )
- cell_noise = diag_gaussian_sample(rng, cell_noise_mean, cell_noise_log_std)
- noise = jnp.dot(cell_noise, noise_factors)
-
- # batch effects
- batch_effects_mean = jnp.clip(batch_effects_mean, a_min=-10.0, a_max=10.0)
- batch_effects_log_std = jnp.clip(
- batch_effects_log_std, a_min=jnp.log(1e-3), a_max=jnp.log(1e2)
- )
- batch_effects_factors = diag_gaussian_sample(
- rng, batch_effects_mean, batch_effects_log_std
- )
- cell_covariates = jnp.array(cell_covariates)[indices]
- batch_effects = jnp.dot(cell_covariates, batch_effects_factors)
-
- # unobserved factors
- log_unobserved_factors_kernel_means = jnp.clip(
- log_unobserved_factors_kernel_means, a_min=jnp.log(1e-6), a_max=jnp.log(1e2)
- )
- log_unobserved_factors_kernel_log_stds = jnp.clip(
- log_unobserved_factors_kernel_log_stds,
- a_min=jnp.log(1e-6),
- a_max=jnp.log(1e2),
- )
-
- def sample_unobs_kernel(i):
- return jnp.clip(
- diag_gaussian_sample(
- rng,
- log_unobserved_factors_kernel_means[i],
- log_unobserved_factors_kernel_log_stds[i],
- ),
- a_min=jnp.log(1e-6),
- )
-
- def sample_all_unobs_kernel(i):
- return jax.lax.cond(
- node_mask[i] >= 0, sample_unobs_kernel, lambda i: zeros_vec, i
- )
-
- nodes_log_unobserved_factors_kernels = vmap(sample_all_unobs_kernel)(
- jnp.arange(len(cnvs))
- )
-
- unobserved_means = jnp.clip(unobserved_means, a_min=jnp.log(1e-3), a_max=10)
- unobserved_log_stds = jnp.clip(
- unobserved_log_stds, a_min=jnp.log(1e-6), a_max=jnp.log(1e2)
- )
-
- def sample_unobs(i):
- return diag_gaussian_sample(rng, unobserved_means[i], unobserved_log_stds[i])
-
- def sample_all_unobs(i):
- return jax.lax.cond(node_mask[i] >= 0, sample_unobs, lambda i: zeros_vec, i)
-
- nodes_unobserved_factors = vmap(sample_all_unobs)(jnp.arange(len(cnvs)))
-
- nu_sticks_log_alphas = jnp.clip(nu_sticks_log_means, -3.0, 3.0)
- nu_sticks_log_betas = jnp.clip(nu_sticks_log_stds, -3.0, 3.0)
-
- # def sample_nu(i):
- # return jnp.clip(
- # diag_gaussian_sample(
- # rng, nu_sticks_log_means[i], nu_sticks_log_stds[i]
- # ),
- # -4.0,
- # 4.0,
- # )
- #
- # def sample_valid_nu(i):
- # return jax.lax.cond(
- # has_children(i), sample_nu, lambda i: jnp.array([logit(1 - 1e-6)]), i
- # ) # Sample all valid
- #
- # def sample_all_nus(i):
- # return jax.lax.cond(
- # cnvs[i][0] >= 0, sample_valid_nu, lambda i: jnp.array([logit(1e-6)]), i
- # ) # Sample all
- #
- # log_nu_sticks = vmap(sample_all_nus)(jnp.arange(len(cnvs)))
- #
- # def sigmoid_nus(i):
- # return jax.lax.cond(
- # cnvs[i][0] >= 0,
- # lambda i: jnn.sigmoid(log_nu_sticks[i]),
- # lambda i: jnp.array([1e-6]),
- # i,
- # ) # Sample all
- #
- # nu_sticks = jnp.clip(vmap(sigmoid_nus)(jnp.arange(len(cnvs))), 1e-6, 1 - 1e-6)
-
- def sample_nu(i):
- return jnp.clip(
- beta_sample(rng, nu_sticks_log_alphas[i], nu_sticks_log_betas[i]),
- 1e-4,
- 1 - 1e-4,
- )
-
- def sample_valid_nu(i):
- return jax.lax.cond(
- has_children(i), sample_nu, lambda i: jnp.array([1 - 1e-4]), i
- ) # Sample all valid
-
- def sample_all_nus(i):
- return jax.lax.cond(
- cnvs[i][0] >= 0, sample_valid_nu, lambda i: jnp.array([1e-4]), i
- ) # Sample all
-
- nu_sticks = vmap(sample_all_nus)(jnp.arange(len(cnvs)))
-
- psi_sticks_log_alphas = jnp.clip(psi_sticks_log_means, -3.0, 3.0)
- psi_sticks_log_betas = jnp.clip(psi_sticks_log_stds, -3.0, 3.0)
-
- # def sample_psi(i):
- # return jnp.clip(
- # diag_gaussian_sample(
- # rng, psi_sticks_log_means[i], psi_sticks_log_stds[i]
- # ),
- # -4.0,
- # 4.0,
- # )
- #
- # def sample_valid_psis(i):
- # return jax.lax.cond(
- # has_next_branch(i),
- # sample_psi,
- # lambda i: jnp.array([logit(1 - 1e-6)]),
- # i,
- # ) # Sample all valid
- #
- # def sample_all_psis(i):
- # return jax.lax.cond(
- # cnvs[i][0] >= 0,
- # sample_valid_psis,
- # lambda i: jnp.array([logit(1e-6)]),
- # i,
- # ) # Sample all
- #
- # log_psi_sticks = vmap(sample_all_psis)(jnp.arange(len(cnvs)))
- #
- # def sigmoid_psis(i):
- # return jax.lax.cond(
- # cnvs[i][0] >= 0,
- # lambda i: jnn.sigmoid(log_psi_sticks[i]),
- # lambda i: jnp.array([1e-6]),
- # i,
- # ) # Sample all
- #
- # psi_sticks = jnp.clip(vmap(sigmoid_psis)(jnp.arange(len(cnvs))), 1e-6, 1 - 1e-6)
-
- def sample_psi(i):
- return jnp.clip(
- beta_sample(rng, psi_sticks_log_alphas[i], psi_sticks_log_betas[i]),
- 1e-4,
- 1 - 1e-4,
- )
-
- def sample_valid_psis(i):
- return jax.lax.cond(
- has_next_branch(i),
- sample_psi,
- lambda i: jnp.array([1 - 1e-4]),
- i,
- ) # Sample all valid
-
- def sample_all_psis(i):
- return jax.lax.cond(
- cnvs[i][0] >= 0,
- sample_valid_psis,
- lambda i: jnp.array([1e-4]),
- i,
- ) # Sample all
-
- psi_sticks = vmap(sample_all_psis)(jnp.arange(len(cnvs)))
-
- lib_sizes = jnp.array(lib_sizes)[indices]
- data = jnp.array(data)[indices]
- baseline = jnp.exp(jnp.append(0, log_baseline))
-
- def compute_node_ll(i):
- unobserved_factors = nodes_unobserved_factors[i] * (parent_vector[i] != -1)
-
- node_mean = (
- baseline * cnvs[i] / 2 * jnp.exp(unobserved_factors + noise + batch_effects)
- )
- sum = jnp.sum(node_mean, axis=1).reshape(mb_size, 1)
- node_mean = node_mean / sum
- node_mean = node_mean * lib_sizes
- pll = vmap(jax.scipy.stats.poisson.logpmf)(data, node_mean)
- ll = jnp.sum(pll, axis=1) # N-vector
-
- # TSSB prior
- nu_stick = nu_sticks[i]
- psi_stick = psi_sticks[i]
-
- def prev_branches_psi(idx):
- return (idx != -1) * jnp.log(1.0 - psi_sticks[idx])
-
- def ancestors_nu(idx):
- _log_phi = jnp.log(psi_sticks[idx]) + jnp.sum(
- vmap(prev_branches_psi)(previous_branches_indices[idx])
- )
- _log_1_nu = jnp.log(1.0 - nu_sticks[idx])
- total = _log_phi + _log_1_nu
- return (idx != -1) * total
-
- # log_phi = jnp.log(psi_stick) + jnp.sum(
- # vmap(prev_branches_psi)(previous_branches_indices[i])
- # )
- # log_node_weight = (
- # jnp.log(nu_stick)
- # + log_phi
- # + jnp.sum(vmap(ancestors_nu)(ancestor_nodes_indices[i]))
- # )
- # log_node_weight = log_node_weight + jnp.log(tssb_weights[i])
- # ll = ll + log_node_weight # N-vector
-
- return ll
-
- small_ll = -1e30 * jnp.ones((mb_size))
-
- def get_node_ll(i):
- return jnp.where(
- node_mask[i] >= 0,
- compute_node_ll(jnp.where(node_mask[i] >= 0, i, 0)),
- small_ll,
- )
-
- out = jnp.array(vmap(get_node_ll)(jnp.arange(len(parent_vector))))
- l = jnp.sum(jnn.logsumexp(out, axis=0) * data_mask_subset)
-
- log_rate = jnp.log(unobserved_factors_kernel_rate)
- log_concentration = jnp.log(unobserved_factors_kernel_concentration)
- log_kernel = jnp.log(unobserved_factors_root_kernel)
- broadcasted_concentration = log_concentration * ones_vec
- broadcasted_rate = log_rate * ones_vec
-
- def compute_node_kl(i):
- kl = 0.0
- pl = diag_gamma_logpdf(
- jnp.clip(jnp.exp(nodes_log_unobserved_factors_kernels[i]), a_min=1e-6),
- broadcasted_concentration,
- (parent_vector[i] != -1)
- * (parent_vector[i] != 0)
- * unobserved_factors_kernel_rate
- * (jnp.abs(nodes_unobserved_factors[parent_vector[i]])),
- )
- ent = -diag_loggaussian_logpdf(
- jnp.clip(jnp.exp(nodes_log_unobserved_factors_kernels[i]), a_min=1e-6),
- log_unobserved_factors_kernel_means[i],
- log_unobserved_factors_kernel_log_stds[i],
- )
- kl += (parent_vector[i] != -1) * (pl + ent)
-
- # # Penalize copies in unobserved nodes
- # pl = diag_gamma_logpdf(1e-6 * jnp.ones(broadcasted_concentration.shape), broadcasted_concentration,
- # (parent_vector[i] != -1)*(log_rate + jnp.abs(nodes_unobserved_factors[parent_vector[i]])))
- # ent = - diag_gaussian_logpdf(jnp.log(1e-6 * jnp.ones(broadcasted_concentration.shape)), log_unobserved_factors_kernel_means[i], log_unobserved_factors_kernel_log_stds[i])
- # kl -= (parent_vector[i] != -1) * jnp.all(tssb_indices[i] == tssb_indices[parent_vector[i]]) * (pl + 0*ent)
-
- # unobserved_factors
- is_root_subtree = jnp.all(tssb_indices[i] == tssb_indices[0])
- pl = diag_gaussian_logpdf(
- nodes_unobserved_factors[i],
- (parent_vector[i] != -1)
- * (parent_vector[i] != 0)
- * nodes_unobserved_factors[parent_vector[i]]
- + (parent_vector[i] != -1)
- * is_root_subtree # promote overexpressing events near the root
- * (jnp.exp(nodes_log_unobserved_factors_kernels[i]) > 0.1)
- * 0.2,
- jnp.clip(nodes_log_unobserved_factors_kernels[i], a_min=jnp.log(1e-6))
- * (parent_vector[i] != -1),
- )
- ent = -diag_gaussian_logpdf(
- nodes_unobserved_factors[i], unobserved_means[i], unobserved_log_stds[i]
- )
- kl += (parent_vector[i] != -1) * (pl + ent)
-
- # # Penalize copied unobserved_factors
- # pl = diag_gaussian_logpdf(nodes_unobserved_factors[parent_vector[i]],
- # (parent_vector[i] != -1) * nodes_unobserved_factors[parent_vector[i]],
- # jnp.clip(nodes_log_unobserved_factors_kernels[i], a_min=jnp.log(1e-6))*(parent_vector[i] != -1) + log_kernel*(parent_vector[i] == -1))
- # ent = - diag_gaussian_logpdf(nodes_unobserved_factors[parent_vector[i]], unobserved_means[i], unobserved_log_stds[i])
- # kl -= (parent_vector[i] != -1) * jnp.all(tssb_indices[i] == tssb_indices[parent_vector[i]]) * (pl + 0*ent)
-
- # sticks
- nu_pl = has_children(i) * beta_logpdf(
- nu_sticks[i],
- jnp.log(jnp.array([1.0])),
- jnp.log(jnp.array([dp_alphas[i]])),
- )
- nu_ent = has_children(i) * -beta_logpdf(
- nu_sticks[i], nu_sticks_log_alphas[i], nu_sticks_log_betas[i]
- )
- kl += nu_pl + nu_ent
-
- psi_pl = has_next_branch(i) * beta_logpdf(
- psi_sticks[i],
- jnp.log(jnp.array([1.0])),
- jnp.log(jnp.array([dp_gammas[i]])),
- )
- psi_ent = has_next_branch(i) * -beta_logpdf(
- psi_sticks[i], psi_sticks_log_alphas[i], psi_sticks_log_betas[i]
- )
- kl += psi_pl + psi_ent
-
- return kl
-
- def get_node_kl(i):
- return jnp.where(
- node_mask[i] == 1,
- compute_node_kl(jnp.where(node_mask[i] == 1, i, 0)),
- 0.0,
- )
-
- node_kls = vmap(get_node_kl)(jnp.arange(len(parent_vector)))
- node_kl = jnp.sum(node_kls)
-
- # Global vars KL
- baseline_kl = diag_gaussian_logpdf(
- log_baseline, zeros_vec[1:], zeros_vec[1:]
- ) - diag_gaussian_logpdf(log_baseline, log_baseline_mean, log_baseline_log_std)
- ones_mat = jnp.ones(log_factors_precisions.shape)
- zeros_mat = jnp.zeros(log_factors_precisions.shape)
- factor_precision_kl = diag_gamma_logpdf(
- jnp.exp(log_factors_precisions),
- jnp.log(global_noise_factors_precisions_shape) * ones_mat,
- zeros_mat,
- ) - diag_loggaussian_logpdf(
- jnp.exp(log_factors_precisions),
- factor_precision_log_means,
- factor_precision_log_stds,
- )
- noise_factors_kl = diag_gaussian_logpdf(
- noise_factors,
- jnp.zeros(noise_factors.shape),
- jnp.log(jnp.sqrt(1.0 / jnp.exp(log_factors_precisions)).reshape(-1, 1))
- * jnp.ones(noise_factors.shape),
- ) - diag_gaussian_logpdf(noise_factors, noise_factors_mean, noise_factors_log_std)
- batch_effects_kl = diag_gaussian_logpdf(
- batch_effects_factors,
- jnp.zeros(batch_effects_factors.shape),
- jnp.zeros(batch_effects_factors.shape),
- ) - diag_gaussian_logpdf(
- batch_effects_factors, batch_effects_mean, batch_effects_log_std
- )
- total_kl = (
- node_kl
- + baseline_kl
- + factor_precision_kl
- + noise_factors_kl
- + batch_effects_kl
- )
-
- # Scale the KL by the data size
- total_kl = total_kl * jnp.sum(data_mask_subset != 0) / data.shape[0]
-
- zeros_mat = jnp.zeros(cell_noise.shape)
- cell_noise_kl = diag_gaussian_logpdf(
- cell_noise, zeros_mat, zeros_mat, axis=1
- ) - diag_gaussian_logpdf(cell_noise, cell_noise_mean, cell_noise_log_std, axis=1)
- cell_noise_kl = jnp.sum(cell_noise_kl * data_mask_subset)
- total_kl = total_kl + cell_noise_kl
-
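-    # Note: each "kl" quantity above is assembled as (prior logpdf - q logpdf), i.e. a
-    # negative KL contribution, which is why total_kl is added to the expected
-    # log-likelihood below instead of subtracted.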
- elbo_val = l + total_kl
-
- return elbo_val, l, total_kl, node_kls
-
-
-def batch_elbo(
- rng,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- params,
- num_samples,
-):
- # Average over a batch of random samples from the var approx.
- rngs = random.split(rng, num_samples)
- init = [0]
- init.extend([None] * (23 + len(params)))
- vectorized_elbo = vmap(compute_elbo, in_axes=init)
- return jnp.mean(
- vectorized_elbo(
- rngs,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- *params,
- )
- )
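-# In batch_elbo, in_axes = [0] + [None] * (23 + len(params)) maps compute_elbo over the
-# leading axis of the PRNG keys only; the data and variational parameters are broadcast
-# and the per-sample ELBOs are averaged.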
-
-
-@partial(jit, static_argnums=(3, 4, 5, 6, 23))
-def objective(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- params,
- t,
-):
- rng = random.PRNGKey(t)
- return -batch_elbo(
- rng,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- params,
- num_samples,
- )
-
-
-@partial(jit, static_argnums=(3, 4, 5, 6, 21, 22))
-def batch_objective(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- num_data,
- params,
- t,
-):
- rng = random.PRNGKey(t)
- # Average over a batch of random samples from the var approx.
- rngs = random.split(rng, num_samples)
- init = [0]
- init.extend([None] * (23 + len(params)))
- vectorized_elbo = vmap(_compute_elbo, in_axes=init)
- elbos, lls, kls, node_kls = vectorized_elbo(
- rngs,
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- jnp.ones((num_data,)),
- jnp.arange(num_data),
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- *params,
- )
- elbo = jnp.mean(elbos)
- ll = jnp.mean(lls)
- kl = jnp.mean(kls)
- node_kl = node_kls
- return elbo, ll, kl, node_kl
-
-
-@partial(jit, static_argnums=(3, 4, 5, 6, 23))
-def do_grad(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- params,
- i,
-):
- return jax.value_and_grad(objective, argnums=24)(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- params,
- i,
- )
-
-
-def update(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- i,
- opt_state,
- opt_update,
- get_params,
-):
- # print("Recompiling update!")
- params = get_params(opt_state)
- value, gradient = do_grad(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- node_mask,
- data_mask_subset,
- indices,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- params,
- i,
- )
- opt_state = opt_update(i, gradient, opt_state)
- return opt_state, gradient, params, value
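
The block removed above implemented the stochastic variational updates as a jit-compiled objective passed through jax.value_and_grad and stepped with jax.example_libraries.optimizers. The following is a minimal, self-contained sketch of that pattern on a toy Gaussian objective; the loss function and all names (toy_negative_loglik, step) are illustrative only, not the scatrex ELBO.

# Sketch of the jit + value_and_grad + example_libraries.optimizers pattern
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def toy_negative_loglik(params, data):
    mean, log_std = params
    ll = jax.scipy.stats.norm.logpdf(data, mean, jnp.exp(log_std)).sum()
    return -ll  # minimize the negative log-likelihood

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)

@jax.jit
def step(i, opt_state, data):
    params = get_params(opt_state)
    value, grads = jax.value_and_grad(toy_negative_loglik)(params, data)
    return opt_update(i, grads, opt_state), value

data = jnp.array([0.5, 1.2, -0.3])
opt_state = opt_init((jnp.array(0.0), jnp.array(0.0)))
for i in range(200):
    opt_state, value = step(i, opt_state, data)
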
diff --git a/scatrex/models/cna/tree.py b/scatrex/models/cna/tree.py
index 233d87b..0e77309 100644
--- a/scatrex/models/cna/tree.py
+++ b/scatrex/models/cna/tree.py
@@ -1,11 +1,11 @@
import numpy as np
import matplotlib
-from ...util import *
-from ...ntssb.node import *
-from ...ntssb.tree import *
+from .node import CNANode
+from ...utils.math_utils import *
+from ...ntssb.observed_tree import *
from ...plotting import *
-
+from ...utils.tree_utils import dict_to_tree
def get_cnv_cmap(vmax=4, vmid=2):
# Extend amplification colors beyond 4
@@ -31,47 +31,126 @@ def get_cnv_cmap(vmax=4, vmid=2):
return cmap
-class ObservedTree(Tree):
+class CNATree(ObservedTree):
def __init__(self, **kwargs):
- super(ObservedTree, self).__init__(**kwargs)
+ super(CNATree, self).__init__(**kwargs)
+ self.node_constructor = CNANode
self.cmap = get_cnv_cmap()
self.sign_colors = {"-": "blue", "+": "red"}
- def add_node_params(
+ def sample_kernel(self, parent_params, min_nevents=1, max_nevents_frac=0.67, min_cn=0, seed=None, **kwargs):
+ n_genes = parent_params.shape[0]
+ i = 0
+ if seed is None:
+ seed = self.seed
+ while True: # Rejection sampling
+ cnvs = np.array(parent_params)
+
+ i += 1
+ # Sample number of regions to be affected
+ rng = np.random.default_rng(seed=seed + i)
+ n_r = rng.choice(
+ np.arange(
+ min_nevents, np.max([int(max_nevents_frac * self.n_regions), 2])
+ )
+ )
+
+ # Sample regions to be affected
+ rng = np.random.default_rng(seed=seed + i + 1)
+ affected_regions = rng.choice(
+ np.arange(0, self.n_regions), size=n_r, replace=False
+ )
+
+ all_affected_genes = []
+ for r in affected_regions:
+ # Apply event to region
+ if r > 0:
+ affected_genes = np.arange(self.region_stops[r - 1], self.region_stops[r])
+ elif r == 0:
+ affected_genes = np.arange(0, self.region_stops[r])
+
+ all_affected_genes.append(affected_genes)
+
+                if np.any(parent_params[affected_genes] == 0):  # skip regions that already contain a fully deleted gene
+ continue
+
+ # Sample event sign
+ rng = np.random.default_rng(seed=seed + i + 2 + r)
+ s = rng.choice([-1, 1])
+
+ if np.any(parent_params[affected_genes] < 2):
+ s = -1
+ elif np.any(parent_params[affected_genes] > 2):
+ s = 1
+
+ # Sample event magnitude
+ rng = np.random.default_rng(seed=seed + i + 3 + r)
+ m = np.max([1, rng.poisson(0.5)])
+
+ # Record event
+ clone_cn_events_genes = np.zeros((n_genes,))
+ clone_cn_events_genes[affected_genes] = s * m
+
+ cnvs[affected_genes] = (
+ parent_params[affected_genes]
+ + clone_cn_events_genes[affected_genes]
+ )
+
+
+ all_affected_genes = np.concatenate(all_affected_genes)
+ if np.all(cnvs[all_affected_genes] >= min_cn):
+ break
+
+ return cnvs
+
+ def sample_root(self, n_genes=50, n_regions=5, **kwargs):
+ self.n_regions = np.max([n_regions, 3])
+ # Define regions
+ rng = np.random.default_rng(self.seed)
+ self.region_stops = np.sort(
+            rng.choice(np.arange(n_genes), size=self.n_regions, replace=False)
+ )
+ return 2*np.ones((n_genes,))
+
+ def _add_node_params(
self, n_genes=50, n_regions=5, min_nevents=1, max_nevents_frac=0.67, min_cn=0
):
C = len(self.tree_dict.keys())
n_regions = np.max([n_regions, 3])
# Define regions
+ rng = np.random.default_rng(self.seed)
region_stops = np.sort(
- np.random.choice(np.arange(n_genes), size=n_regions, replace=False)
+ rng.choice(np.arange(n_genes), size=n_regions, replace=False)
)
         # Traverse the tree and generate events for each node
for node in self.tree_dict:
if self.tree_dict[node]["parent"] == "-1":
- self.tree_dict[node]["params"] = np.ones((n_genes,)) * 2
+ self.tree_dict[node]["param"] = np.ones((n_genes,)) * 2
self.tree_dict[node]["params_label"] = ""
continue
parent_params = np.array(
- self.tree_dict[self.tree_dict[node]["parent"]]["params"]
+ self.tree_dict[self.tree_dict[node]["parent"]]["param"]
)
-
+ i = 0
while True:
- self.tree_dict[node]["params"] = np.array(
- self.tree_dict[self.tree_dict[node]["parent"]]["params"]
+ i += 1
+ self.tree_dict[node]["param"] = np.array(
+ self.tree_dict[self.tree_dict[node]["parent"]]["param"]
)
self.tree_dict[node]["params_label"] = ""
# Sample number of regions to be affected
- n_r = np.random.choice(
+ rng = np.random.default_rng(seed=self.seed + i)
+ n_r = rng.choice(
np.arange(
min_nevents, np.max([int(max_nevents_frac * n_regions), 2])
)
)
# Sample regions to be affected
- affected_regions = np.random.choice(
+ rng = np.random.default_rng(seed=self.seed + i + 1)
+ affected_regions = rng.choice(
np.arange(0, n_regions), size=n_r, replace=False
)
@@ -89,7 +168,8 @@ def add_node_params(
continue
# Sample event sign
- s = np.random.choice([-1, 1])
+ rng = np.random.default_rng(seed=self.seed + i + 2 + r)
+ s = rng.choice([-1, 1])
if np.any(parent_params[affected_genes] < 2):
s = -1
@@ -97,13 +177,14 @@ def add_node_params(
s = 1
# Sample event magnitude
- m = np.max([1, np.random.poisson(0.5)])
+ rng = np.random.default_rng(seed=self.seed + i + 3 + r)
+ m = np.max([1, rng.poisson(0.5)])
# Record event
clone_cn_events_genes = np.zeros((n_genes,))
clone_cn_events_genes[affected_genes] = s * m
- self.tree_dict[node]["params"][affected_genes] = (
+ self.tree_dict[node]["param"][affected_genes] = (
parent_params[affected_genes]
+ clone_cn_events_genes[affected_genes]
)
@@ -127,9 +208,11 @@ def add_node_params(
)
all_affected_genes = np.concatenate(all_affected_genes)
- if np.all(self.tree_dict[node]["params"][all_affected_genes] >= min_cn):
+ if np.all(self.tree_dict[node]["param"][all_affected_genes] >= min_cn):
break
+ self.tree = dict_to_tree(self.tree_dict)
+
self.create_adata()
def get_affected_genes(self):
@@ -138,7 +221,7 @@ def get_affected_genes(self):
def set_neutral_nodes(self, thres=0.95, neutral_level=2):
for node in self.tree_dict:
self.tree_dict[node]["is_neutral"] = False
- cnvs = self.tree_dict[node]["params"].ravel()
+ cnvs = self.tree_dict[node]["param"].ravel()
frac_neutral = np.sum(cnvs == neutral_level) / cnvs.size
if frac_neutral > thres:
self.tree_dict[node]["is_neutral"] = True
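
The rejection loop in the new sample_kernel above builds a child profile by repeatedly drawing copy-number events: a set of affected regions, an event sign, and a Poisson-distributed magnitude added to the parent profile. Below is a standalone sketch of a single such event; the region boundaries and hyperparameters are illustrative values, not the scatrex defaults.

import numpy as np

rng = np.random.default_rng(0)
parent_cn = 2 * np.ones(10)              # neutral parent profile over 10 genes
region = np.arange(3, 6)                 # genes in one affected region (illustrative)
sign = rng.choice([-1, 1])               # amplification (+) or deletion (-)
magnitude = max(1, rng.poisson(0.5))     # at least one copy gained or lost
child_cn = parent_cn.copy()
child_cn[region] = parent_cn[region] + sign * magnitude
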
diff --git a/scatrex/models/trajectory/__init__.py b/scatrex/models/trajectory/__init__.py
new file mode 100644
index 0000000..ff72ac8
--- /dev/null
+++ b/scatrex/models/trajectory/__init__.py
@@ -0,0 +1,2 @@
+from .tree import TrajectoryTree
+from .node import TrajectoryNode
diff --git a/scatrex/models/trajectory/node.py b/scatrex/models/trajectory/node.py
new file mode 100644
index 0000000..10009bc
--- /dev/null
+++ b/scatrex/models/trajectory/node.py
@@ -0,0 +1,772 @@
+from numpy import *
+import numpy as np
+from numpy.random import *
+
+from functools import partial
+import jax.numpy as jnp
+import jax
+import tensorflow_probability.substrates.jax.distributions as tfd
+
+from .node_opt import * # node optimization functions
+from .node_opt import _mc_obs_ll
+from ...utils.math_utils import *
+from ...ntssb.node import *
+
+def update_params(params, params_gradient, step_size):
+ new_params = []
+ for i, param in enumerate(params):
+ new_params.append(param + step_size * params_gradient[i])
+ return new_params
+
+class TrajectoryNode(AbstractNode):
+ def __init__(
+ self,
+ observed_parameters, # subtree root location and angle
+ root_loc_mean=2.,
+ loc_mean=.5,
+ angle_concentration=10.,
+ loc_variance=.1,
+ obs_variance=.1,
+ n_factors=2,
+ obs_weight_variance=1.,
+ factor_variance=1.,
+ **kwargs,
+ ):
+ """
+ This model generates nodes in a 2D space by sampling an angle, which roughly follows
+ the angle at which the parent was generated, and a location around some radius.
+ TODO: Make the location prior a Gamma instead of a Normal, still centered around some radius
+ but with a peak at zero and a long tail, to allow for more or less close by nodes
+ """
+ super(TrajectoryNode, self).__init__(observed_parameters, **kwargs)
+
+ self.n_genes = self.observed_parameters[0].size
+
+ # Node hyperparameters
+ if self.parent() is None:
+ self.node_hyperparams = dict(
+ root_loc_mean=root_loc_mean,
+ angle_concentration=angle_concentration,
+ loc_mean=loc_mean,
+ loc_variance=loc_variance,
+ obs_variance=obs_variance,
+ n_factors=n_factors,
+ obs_weight_variance=obs_weight_variance,
+ factor_variance=factor_variance,
+ )
+ else:
+ self.node_hyperparams = self.node_hyperparams_caller()
+
+ self.reset_parameters(**self.node_hyperparams)
+
+ if self.tssb is not None:
+ self.reset_variational_parameters()
+ self.sample_variational_distributions()
+ self.reset_sufficient_statistics(self.tssb.ntssb.num_batches)
+
+ def combine_params(self):
+ return self.params[0] # get loc
+
+ def get_mean(self):
+ return self.combine_params()
+
+ def set_mean(self, node_mean=None):
+ if node_mean is not None:
+ self.node_mean = node_mean
+ else:
+ self.node_mean = self.get_mean()
+
+ def get_observed_parameters(self):
+ return self.observed_parameters[0] # get root loc
+
+ def get_params(self):
+ return self.get_mean()
+
+ def get_param(self, param='mean'):
+ if param == 'observed':
+ return self.get_observed_parameters()
+ elif param == 'mean':
+ return self.get_mean()
+ else:
+ raise ValueError(f"No param available for `{param}`")
+
+ def remove_noise(self, data):
+ """
+ Noise is additive in this model
+ """
+ return data - self.noise_factors_caller()
+
+ # ========= Functions to initialize node. =========
+ def set_node_hyperparams(self, **kwargs):
+ self.node_hyperparams.update(**kwargs)
+
+ def reset_parameters(
+ self,
+ root_loc_mean=2.,
+ loc_mean=.5,
+ angle_concentration=10.,
+ loc_variance=.1,
+ obs_variance=.1,
+ n_factors=2,
+ obs_weight_variance=1.,
+ factor_variance=1.,
+ ):
+ self.node_hyperparams = dict(
+ root_loc_mean=root_loc_mean,
+ angle_concentration=angle_concentration,
+ loc_mean=loc_mean,
+ loc_variance=loc_variance,
+ obs_variance=obs_variance,
+ n_factors=n_factors,
+ obs_weight_variance=obs_weight_variance,
+ factor_variance=factor_variance,
+ )
+
+ parent = self.parent()
+
+ if parent is None:
+ self.depth = 0.0
+ self.params = self.observed_parameters # loc and angle
+
+ n_factors = self.node_hyperparams['n_factors']
+ factor_variance = self.node_hyperparams['factor_variance']
+ rng = np.random.default_rng(seed=self.seed)
+ self.factor_weights = rng.normal(0., np.sqrt(factor_variance), size=(n_factors, 2)) * 1./3.
+
+ if n_factors > 0:
+ n_genes_per_factor = int(2/n_factors)
+ offset = 6.
+                perm = rng.permutation(2)
+ for factor in range(n_factors):
+ gene_idx = perm[factor*n_genes_per_factor:(factor+1)*n_genes_per_factor]
+ self.factor_weights[factor,gene_idx] *= offset
+
+
+ # Set data-dependent parameters
+ if self.tssb is not None:
+ num_data = self.tssb.ntssb.num_data
+ if num_data is not None:
+ self.reset_data_parameters()
+
+ elif parent.tssb != self.tssb:
+ self.depth = 0.0
+ self.params = self.observed_parameters # loc and angle
+ else: # Non-root node: inherits everything from upstream node
+ self.depth = parent.depth + 1
+ loc_mean = self.node_hyperparams['loc_mean']
+ angle_concentration = self.node_hyperparams['angle_concentration'] * self.depth
+ rng = np.random.default_rng(seed=self.seed)
+ sampled_angle = rng.vonmises(parent.params[1], angle_concentration)
+ sampled_loc = rng.normal(
+ loc_mean,
+ self.node_hyperparams['loc_variance']
+ )
+ sampled_loc = parent.params[0] + np.array([np.cos(sampled_angle)*np.abs(sampled_loc), np.sin(sampled_angle)*np.abs(sampled_loc)])
+ self.params = [sampled_loc, sampled_angle]
+
+ self.set_mean()
+
+ # Generate structure on the factors
+ def reset_data_parameters(self):
+ num_data = self.tssb.ntssb.num_data
+ n_factors = self.node_hyperparams['n_factors']
+ rng = np.random.default_rng(seed=self.seed)
+ self.obs_weights = rng.normal(0., 1., size=(num_data, n_factors)) * 1./3.
+ if n_factors > 0:
+ n_obs_per_factor = int(num_data/n_factors)
+ offset = 6.
+            perm = rng.permutation(num_data)
+ for factor in range(n_factors):
+ obs_idx = perm[factor*n_obs_per_factor:(factor+1)*n_obs_per_factor]
+ self.obs_weights[obs_idx,factor] *= offset
+
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+
+ def reset_variational_parameters(self):
+ # Assignments
+ num_data = self.tssb.ntssb.num_data
+ if num_data is not None:
+ self.variational_parameters['q_z'] = jnp.ones(num_data,)
+
+ self.variational_parameters['sum_E_log_1_nu'] = 0.
+ self.variational_parameters['E_log_phi'] = 0.
+
+ # Sticks
+ self.variational_parameters["delta_1"] = 1.
+ self.variational_parameters["delta_2"] = 1.
+ self.variational_parameters["sigma_1"] = 1.
+ self.variational_parameters["sigma_2"] = 1.
+
+ # Pivots
+ self.variational_parameters["q_rho"] = np.ones(len(self.tssb.children_root_nodes),)
+
+ parent = self.parent()
+ if parent is None and self.tssb.parent() is None:
+ rng = np.random.default_rng(self.seed)
+ # root stores global parameters
+ n_factors = self.node_hyperparams['n_factors']
+ self.variational_parameters["global"] = {
+ 'factor_weights': {'mean': jnp.array(self.node_hyperparams['factor_variance']/10.*rng.normal(size=(n_factors, 2))),
+ 'log_std': -2. + jnp.zeros((n_factors, 2))}
+ }
+ if num_data is not None:
+ rng = np.random.default_rng(self.seed+1)
+ self.variational_parameters["local"] = {
+ 'obs_weights': {'mean': jnp.array(self.node_hyperparams['obs_weight_variance']/10.*rng.normal(size=(num_data, n_factors))),
+ 'log_std': -2. + jnp.zeros((num_data, n_factors))}
+ }
+ self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"]
+ self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"]
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+ elif parent is None:
+ return # no variational parameters for root nodes of TSSBs in this model
+ else: # only the non-root nodes have variational parameters
+ # Kernel
+ radius = self.node_hyperparams['loc_mean']
+
+ if "angle" not in parent.variational_parameters["kernel"]:
+ mean_angle = jnp.array([parent.observed_parameters[1]])
+ parent_loc = jnp.array(parent.observed_parameters[0])
+ else:
+ mean_angle = parent.variational_parameters["kernel"]["angle"]["mean"]
+ parent_loc = parent.variational_parameters["kernel"]["loc"]["mean"]
+
+ rng = np.random.default_rng(self.seed+2)
+ mean_angle = rng.vonmises(mean_angle, self.node_hyperparams['angle_concentration'] * self.depth)
+ mean_loc = parent_loc + jnp.array([np.cos(mean_angle[0])*radius, jnp.sin(mean_angle[0])*radius])
+ rng = np.random.default_rng(self.seed+3)
+ mean_loc = rng.normal(mean_loc, self.node_hyperparams['loc_variance'])
+ self.variational_parameters["kernel"] = {
+ 'angle': {'mean': jnp.array(mean_angle), 'log_kappa': jnp.array([-1.])},
+ 'loc': {'mean': jnp.array(mean_loc), 'log_std': jnp.array([-1., -1.])}
+ }
+ self.params = [self.variational_parameters["kernel"]["loc"]["mean"],
+ self.variational_parameters["kernel"]["angle"]["mean"]]
+
+ def set_learned_parameters(self):
+ if self.parent() is None and self.tssb.parent() is None:
+ self.obs_weights = self.variational_parameters["local"]["obs_weights"]["mean"]
+ self.factor_weights = self.variational_parameters["global"]["factor_weights"]["mean"]
+ self.noise_factors = self.obs_weights.dot(self.factor_weights)
+ elif self.parent() is None:
+ self.params = self.observed_parameters
+ else:
+ self.params = [self.variational_parameters["kernel"]["loc"]["mean"],
+ self.variational_parameters["kernel"]["angle"]["mean"]]
+
+ def reset_sufficient_statistics(self, num_batches=1):
+ self.suff_stats = {
+ 'ent': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(c_n = this tree) q(z_n = this node) * log q(z_n = this node)
+ 'mass': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node)
+ 'A': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng ** 2
+ 'B_g': {'total': 0, 'batch': np.zeros((num_batches,2))}, # \sum_n q(z_n = this node) * x_ng
+ 'C': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g]
+ 'D_g': {'total': 0, 'batch': np.zeros((num_batches,2))}, # \sum_n q(z_n = this node) * E[s_nW_g]
+ 'E': {'total': 0, 'batch': np.zeros((num_batches,))}, # \sum_n q(z_n = this node) * \sum_g E[(s_nW_g)**2]
+ }
+ if self.parent() is None and self.tssb.parent() is None:
+ self.local_suff_stats = {
+ 'locals_kl': {'total': 0., 'batch': np.zeros((num_batches,))},
+ }
+
+ def merge_suff_stats(self, suff_stats):
+ for stat in self.suff_stats:
+ self.suff_stats[stat]['total'] += suff_stats[stat]['total']
+ self.suff_stats[stat]['batch'] += suff_stats[stat]['batch']
+
+ def update_sufficient_statistics(self, batch_idx=None):
+ if batch_idx is not None:
+ idx = self.tssb.ntssb.batch_indices[batch_idx]
+ else:
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+
+ if self.parent() is None and self.tssb.parent() is None:
+ locals_kl = self.compute_local_priors(idx) + self.compute_local_entropies(idx)
+ if batch_idx is not None:
+ self.local_suff_stats['locals_kl']['total'] -= self.local_suff_stats['locals_kl']['batch'][batch_idx]
+ self.local_suff_stats['locals_kl']['batch'][batch_idx] = locals_kl
+ self.local_suff_stats['locals_kl']['total'] += self.local_suff_stats['locals_kl']['batch'][batch_idx]
+ else:
+ self.local_suff_stats['locals_kl']['total'] = locals_kl
+
+ ent = assignment_entropies(self.variational_parameters['q_z'][idx])
+ ent *= self.tssb.variational_parameters['q_c'][idx]
+ E_ass = self.variational_parameters['q_z'][idx] * self.tssb.variational_parameters['q_c'][idx]
+ E_sw = jnp.mean(self.get_noise_sample(idx),axis=0)
+ E_sw2 = jnp.mean(self.get_noise_sample(idx)**2,axis=0)
+ x = self.tssb.ntssb.data[idx]
+
+ new_ent = jnp.sum(ent)
+ new_mass = jnp.sum(E_ass)
+ new_A = jnp.sum(E_ass * jnp.sum(x**2, axis=1))
+ new_B = jnp.sum(E_ass[:,None] * x, axis=0)
+ new_C = jnp.sum(E_ass * jnp.sum(x * E_sw, axis=1))
+ new_D = jnp.sum(E_ass[:,None] * E_sw, axis=0)
+ new_E = jnp.sum(E_ass * jnp.sum(E_sw2, axis=1))
+
+ if batch_idx is not None:
+ self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx]
+ self.suff_stats['ent']['batch'][batch_idx] = new_ent
+ self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx]
+
+ self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx]
+ self.suff_stats['mass']['batch'][batch_idx] = new_mass
+ self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx]
+
+ self.suff_stats['A']['total'] -= self.suff_stats['A']['batch'][batch_idx]
+ self.suff_stats['A']['batch'][batch_idx] = new_A
+ self.suff_stats['A']['total'] += self.suff_stats['A']['batch'][batch_idx]
+
+ self.suff_stats['B_g']['total'] -= self.suff_stats['B_g']['batch'][batch_idx]
+ self.suff_stats['B_g']['batch'][batch_idx] = new_B
+ self.suff_stats['B_g']['total'] += self.suff_stats['B_g']['batch'][batch_idx]
+
+ self.suff_stats['C']['total'] -= self.suff_stats['C']['batch'][batch_idx]
+ self.suff_stats['C']['batch'][batch_idx] = new_C
+ self.suff_stats['C']['total'] += self.suff_stats['C']['batch'][batch_idx]
+
+ self.suff_stats['D_g']['total'] -= self.suff_stats['D_g']['batch'][batch_idx]
+ self.suff_stats['D_g']['batch'][batch_idx] = new_D
+ self.suff_stats['D_g']['total'] += self.suff_stats['D_g']['batch'][batch_idx]
+
+ self.suff_stats['E']['total'] -= self.suff_stats['E']['batch'][batch_idx]
+ self.suff_stats['E']['batch'][batch_idx] = new_E
+ self.suff_stats['E']['total'] += self.suff_stats['E']['batch'][batch_idx]
+ else:
+ self.suff_stats['ent']['total'] = new_ent
+ self.suff_stats['mass']['total'] = new_mass
+ self.suff_stats['A']['total'] = new_A
+ self.suff_stats['B_g']['total'] = new_B
+ self.suff_stats['C']['total'] = new_C
+ self.suff_stats['D_g']['total'] = new_D
+ self.suff_stats['E']['total'] = new_E
+
+ # ========= Functions to take samples from node. =========
+ def sample_observation(self, n):
+ node_mean = self.get_mean()
+ noise_factors = self.noise_factors_caller()[n]
+ rng = np.random.default_rng(seed=self.seed+n)
+ s = rng.normal(node_mean + noise_factors, self.node_hyperparams['obs_variance'])
+ return s
+
+ def sample_observations(self):
+ n_obs = len(self.data)
+ node_mean = self.get_mean()
+ noise_factors = self.noise_factors_caller()[np.array(list(self.data))]
+ rng = np.random.default_rng(seed=self.seed)
+ s = rng.normal(node_mean + noise_factors, self.node_hyperparams['obs_variance'], size=[n_obs, self.n_genes])
+ return s
+
+ # ========= Functions to access root's parameters. =========
+ def node_hyperparams_caller(self):
+ if self.parent() is None:
+ return self.node_hyperparams
+ else:
+ return self.parent().node_hyperparams_caller()
+
+ def noise_factors_caller(self):
+ return self.tssb.ntssb.root['node'].root['node'].noise_factors
+
+ def get_obs_weights_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].obs_weights_sample
+
+ def set_local_sample(self, sample, idx=None):
+ if idx is None:
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+ self.tssb.ntssb.root['node'].root['node'].obs_weights_sample = self.tssb.ntssb.root['node'].root['node'].obs_weights_sample.at[:,idx].set(sample)
+
+ def get_factor_weights_sample(self):
+ return self.tssb.ntssb.root['node'].root['node'].factor_weights_sample
+
+ def set_global_sample(self, sample):
+ self.tssb.ntssb.root['node'].root['node'].factor_weights_sample = jnp.array(sample)
+
+ def get_noise_sample(self, idx):
+ obs_weights = self.get_obs_weights_sample()[:,idx]
+ factor_weights = self.get_factor_weights_sample()
+ return jax.vmap(sample_prod, in_axes=(0,0))(obs_weights,factor_weights)
+
+ def get_direction_sample(self):
+ return self.samples[1]
+
+ def get_state_sample(self):
+ return self.samples[0]
+
+ def get_prior_angle_concentration(self, depth=None):
+ if depth is None:
+ depth = self.depth
+ return self.node_hyperparams['angle_concentration'] * jnp.maximum(depth, 1) # Prior hyperparameter
+
+ # ======== Functions using the variational parameters. =========
+ def compute_loglikelihood(self, idx):
+ # Use stored samples for loc
+ node_mean_samples = self.samples[0]
+ obs_weights_samples = self.get_obs_weights_sample()[:,idx]
+ factor_weights_samples = self.get_factor_weights_sample()
+ std = jnp.sqrt(self.node_hyperparams['obs_variance'])
+ # Average over samples for each observation
+ ll = jnp.mean(jax.vmap(_mc_obs_ll, in_axes=[None,0,0,0,None])(self.tssb.ntssb.data[idx],
+ node_mean_samples,
+ obs_weights_samples,
+ factor_weights_samples,
+ std), axis=0) # mean over MC samples
+ return ll
+
+ def compute_loglikelihood_suff(self):
+ node_mean_samples = self.samples[0]
+ std = jnp.sqrt(self.node_hyperparams['obs_variance'])
+ ll = jnp.mean(jax.vmap(ll_suffstats, in_axes=[0, None, None, None, None, None, None, None])
+ (node_mean_samples,self.suff_stats['mass']['total'],self.suff_stats['A']['total'],
+ self.suff_stats['B_g']['total'], self.suff_stats['C']['total'],
+ self.suff_stats['D_g']['total'],self.suff_stats['E']['total'],std)
+ )
+ return ll
+
+ def sample_variational_distributions(self, n_samples=10):
+ if self.parent() is not None:
+ if self.parent().samples is not None:
+ n_samples = self.parent().samples[0].shape[0]
+ if self.parent() is None and self.tssb.parent() is None:
+ self.sample_locals(n_samples=n_samples, store=True)
+ self.sample_globals(n_samples=n_samples, store=True)
+ self.sample_kernel(n_samples=n_samples, store=True)
+
+ def sample_locals(self, n_samples, store=True):
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.local_sample_and_grad(jnp.arange(self.tssb.ntssb.num_data), key, n_samples=n_samples)
+ sampled_obs_weights, _ = sample_grad
+ if store:
+ self.obs_weights_sample = sampled_obs_weights
+ else:
+ return sampled_obs_weights
+
+ def sample_globals(self, n_samples, store=True):
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.global_sample_and_grad(key, n_samples=n_samples)
+ sampled_factor_weights, _ = sample_grad
+ if store:
+ self.factor_weights_sample = sampled_factor_weights
+ else:
+ return sampled_factor_weights
+
+ def sample_kernel(self, n_samples=10, store=True):
+ parent = self.parent()
+ if parent is None:
+ return self._sample_root_kernel(n_samples=n_samples, store=store)
+
+ key = jax.random.PRNGKey(self.seed)
+ key, sample_grad = self.direction_sample_and_grad(key, n_samples=n_samples)
+ sampled_angle, _ = sample_grad
+
+ key, sample_grad = self.state_sample_and_grad(key, n_samples=n_samples)
+ sampled_loc, _ = sample_grad
+
+ samples = [sampled_loc, sampled_angle]
+ if store:
+ self.samples = samples
+ else:
+ return samples
+
+ def _sample_root_kernel(self, n_samples=10, store=True):
+ # In this model the root is just the known parameters, so just store n_samples copies of them to mimic a sample
+ observed_angle = jnp.array([self.observed_parameters[1]]) # Observed angle
+ observed_loc = jnp.array([self.observed_parameters[0]]) # Observed location
+ sampled_angle = jnp.vstack(jnp.repeat(observed_angle, n_samples, axis=0))
+ sampled_loc = jnp.vstack(jnp.repeat(observed_loc, n_samples, axis=0))
+ samples = [sampled_loc, sampled_angle]
+ if store:
+ self.samples = samples
+ else:
+ return samples
+
+ def compute_global_priors(self):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['factor_variance']))
+ return jnp.mean(mc_factor_weights_logp_val_and_grad(self.factor_weights_sample, 0., log_std)[0])
+
+ def compute_local_priors(self, batch_indices):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance']))
+ return jnp.mean(mc_obs_weights_logp_val_and_grad(self.obs_weights_sample[:,batch_indices], 0., log_std)[0])
+
+ def compute_global_entropies(self):
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ return jnp.sum(factor_weights_logq_val_and_grad(mean, log_std)[0])
+
+ def compute_local_entropies(self, batch_indices):
+ mean = self.variational_parameters['local']['obs_weights']['mean'][batch_indices]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][batch_indices]
+ return jnp.sum(obs_weights_logq_val_and_grad(mean, log_std)[0])
+
+ def compute_kernel_prior(self):
+ parent = self.parent()
+ if parent is None:
+ return self.compute_root_prior()
+
+ prior_mean_angle = self.parent().get_direction_sample()
+ prior_angle_concentration = self.get_prior_angle_concentration()
+ angle_samples = self.get_direction_sample()
+ angle_logpdf = mc_angle_logp_val_and_grad(angle_samples, prior_mean_angle, prior_angle_concentration)[0]
+
+ radius = self.node_hyperparams['loc_mean']
+ parent_loc = self.parent().get_state_sample()
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ loc_samples = self.get_state_sample()
+ loc_logpdf = mc_loc_logp_val_and_grad(loc_samples, parent_loc, angle_samples, log_std, radius)[0]
+
+ return jnp.mean(angle_logpdf + loc_logpdf)
+
+ def compute_root_direction_prior(self, parent_alpha):
+ concentration = self.get_prior_angle_concentration()
+ alpha = self.get_direction_sample()
+ return jnp.mean(mc_angle_logp_val_and_grad(alpha, parent_alpha, concentration)[0])
+
+ def compute_root_state_prior(self, parent_psi):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ radius = self.node_hyperparams['root_loc_mean']
+ psi = self.get_state_sample()
+ alpha = self.get_direction_sample()
+ return jnp.mean(mc_loc_logp_val_and_grad(psi, parent_psi, alpha, log_std, radius)[0])
+
+ def compute_root_kernel_prior(self, samples):
+ parent_alpha = samples[0]
+ logp = self.compute_root_direction_prior(parent_alpha)
+ parent_psi = samples[1]
+ logp += self.compute_root_state_prior(parent_psi)
+ return logp
+
+ def compute_root_prior(self):
+ return 0.
+
+ def compute_kernel_entropy(self):
+ parent = self.parent()
+ if parent is None:
+ return self.compute_root_entropy()
+
+ # Angle
+        angle_logpdf = tfd.VonMises(self.variational_parameters['kernel']['angle']['mean'],
+                                    jnp.exp(self.variational_parameters['kernel']['angle']['log_kappa'])
+                                    ).entropy()
+ angle_logpdf = jnp.sum(angle_logpdf)
+
+ # Location
+ loc_logpdf = tfd.Normal(self.variational_parameters['kernel']['loc']['mean'],
+ jnp.exp(self.variational_parameters['kernel']['loc']['log_std'])
+ ).entropy()
+ loc_logpdf = jnp.sum(loc_logpdf) # Sum across features
+
+ return angle_logpdf + loc_logpdf
+
+ def compute_root_entropy(self):
+ # In this model the root nodes have no unknown parameters
+ return 0.
+
+ # ======== Functions for updating the variational parameters. =========
+ def local_sample_and_grad(self, idx, key, n_samples):
+ """Sample and take gradient of local parameters. Must be root"""
+ mean = self.variational_parameters['local']['obs_weights']['mean'][idx]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx]
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_obs_weights_val_and_grad(jnp.array(sub_keys), mean, log_std)
+
+ def global_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of global parameters. Must be root"""
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_factor_weights_val_and_grad(jnp.array(sub_keys), mean, log_std)
+
+ def compute_locals_prior_grad(self, sample):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_weight_variance']))
+ return mc_obs_weights_logp_val_and_grad(sample, 0., log_std)[1]
+
+ def compute_globals_prior_grad(self, sample):
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['factor_variance']))
+ return mc_factor_weights_logp_val_and_grad(sample, 0., log_std)[1]
+
+ def compute_locals_entropy_grad(self, idx):
+ mean = self.variational_parameters['local']['obs_weights']['mean'][idx]
+ log_std = self.variational_parameters['local']['obs_weights']['log_std'][idx]
+ return obs_weights_logq_val_and_grad(mean, log_std)[1]
+
+ def compute_globals_entropy_grad(self):
+ mean = self.variational_parameters['global']['factor_weights']['mean']
+ log_std = self.variational_parameters['global']['factor_weights']['log_std']
+ return factor_weights_logq_val_and_grad(mean, log_std)[1]
+
+ def state_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of state"""
+ mu = self.variational_parameters['kernel']['loc']['mean']
+ log_std = self.variational_parameters['kernel']['loc']['log_std']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_loc_val_and_grad(jnp.array(sub_keys), mu, log_std)
+
+ def direction_sample_and_grad(self, key, n_samples):
+ """Sample and take gradient of direction"""
+ mu = self.variational_parameters['kernel']['angle']['mean']
+ log_kappa = self.variational_parameters['kernel']['angle']['log_kappa']
+ key, *sub_keys = jax.random.split(key, n_samples+1)
+ return key, mc_sample_angle_val_and_grad(jnp.array(sub_keys), mu, log_kappa)
+
+ def compute_direction_prior_grad_wrt_direction(self, alpha, parent_alpha, parent_loc):
+ """Gradient of logp(alpha|parent_alpha) wrt this alpha"""
+ concentration = self.get_prior_angle_concentration()
+ return mc_angle_logp_val_and_grad(alpha, parent_alpha, concentration)[1]
+
+ def compute_direction_prior_grad_wrt_state(self, alpha, parent_alpha, parent_loc):
+        """Gradient of logp(alpha|parent_alpha) wrt this psi (state): always zero, since the direction prior does not depend on the state"""
+ return 0.
+
+ def compute_direction_prior_child_grad(self, child_alpha, alpha):
+ """Gradient of logp(child_alpha|alpha) wrt this alpha"""
+ concentration = self.get_prior_angle_concentration(depth=self.depth+1)
+ return mc_angle_logp_val_and_grad_wrt_parent(child_alpha, alpha, concentration)[1]
+
+ def compute_state_prior_grad(self, psi, parent_psi, alpha):
+ """Gradient of logp(psi|parent_psi,new_alpha) wrt this psi"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ radius = self.node_hyperparams['loc_mean']
+ return mc_loc_logp_val_and_grad(psi, parent_psi, alpha, log_std, radius)[1]
+
+ def compute_state_prior_child_grad(self, child_psi, psi, child_alpha):
+ """Gradient of logp(child_psi|psi,child_alpha) wrt this psi"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ radius = self.node_hyperparams['loc_mean']
+ return mc_loc_logp_val_and_grad_wrt_parent(child_psi, psi, child_alpha, log_std, radius)[1]
+
+ def compute_root_state_prior_child_grad(self, child_psi, psi, child_alpha):
+ """Gradient of logp(child_psi|psi,child_alpha) wrt this psi"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ radius = self.node_hyperparams['root_loc_mean']
+ return mc_loc_logp_val_and_grad_wrt_parent(child_psi, psi, child_alpha, log_std, radius)[1]
+
+ def compute_state_prior_grad_wrt_direction(self, psi, parent_psi, alpha):
+ """Gradient of logp(psi|parent_psi,alpha) wrt this alpha"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['loc_variance']))
+ radius = self.node_hyperparams['loc_mean']
+ return mc_loc_logp_val_and_grad_wrt_angle(psi, parent_psi, alpha, log_std, radius)[1]
+
+ def compute_direction_entropy_grad(self):
+ """Gradient of logq(alpha) wrt this alpha"""
+ mu = self.variational_parameters['kernel']['angle']['mean']
+ log_kappa = self.variational_parameters['kernel']['angle']['log_kappa']
+ return angle_logq_val_and_grad(mu, log_kappa)[1]
+
+ def compute_state_entropy_grad(self):
+ """Gradient of logq(psi) wrt this psi"""
+ mu = self.variational_parameters['kernel']['loc']['mean']
+ log_std = self.variational_parameters['kernel']['loc']['log_std']
+ return loc_logq_val_and_grad(mu, log_std)[1]
+
+ def compute_ll_state_grad(self, x, weights, psi):
+ """Gradient of logp(x|psi,noise) wrt this psi"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance']))
+ locals = self.get_obs_weights_sample()
+ globals = self.get_factor_weights_sample()
+ return mc_ll_val_and_grad_psi(x, weights, psi, locals, globals, log_std)[1]
+
+ def compute_ll_state_grad_suff(self, psi):
+ """Gradient of logp(x|psi,noise) wrt this psi using suff stats"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance']))
+ return mc_ll_node_mean_suff_val_and_grad(psi, self.suff_stats['mass']['total'],
+ self.suff_stats['B_g']['total'],
+ self.suff_stats['D_g']['total'], log_std)[1]
+
+ def compute_ll_locals_grad(self, x, idx, weights):
+ """Gradient of logp(x|psi,locals,globals) wrt locals"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance']))
+ psi = self.get_state_sample()
+ locals = self.get_obs_weights_sample()[:,idx]
+ globals = self.get_factor_weights_sample()
+ return mc_ll_val_and_grad_obs_weights(x, weights, psi, locals, globals, log_std)[1]
+
+ def compute_ll_globals_grad(self, x, idx, weights):
+ """Gradient of logp(x|psi,locals,globals) wrt globals"""
+ log_std = jnp.log(jnp.sqrt(self.node_hyperparams['obs_variance']))
+ psi = self.get_state_sample()
+ locals = self.get_obs_weights_sample()[:,idx]
+ globals = self.get_factor_weights_sample()
+ return mc_ll_val_and_grad_factor_weights(x, weights, psi, locals, globals, log_std)[1]
+
+ def update_direction_params(self, direction_params_grad, direction_sample_grad, direction_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(direction_params_grad[0] * direction_sample_grad, axis=0)
+ angle_mean_grad = mc_grad + direction_params_entropy_grad[0]
+ self.variational_parameters['kernel']['angle']['mean'] += angle_mean_grad * step_size
+
+ mc_grad = jnp.mean(direction_params_grad[1] * direction_sample_grad, axis=0)
+ angle_log_kappa_grad = mc_grad + direction_params_entropy_grad[1]
+ self.variational_parameters['kernel']['angle']['log_kappa'] += angle_log_kappa_grad * step_size
+
+ def update_state_params(self, state_params_grad, state_sample_grad, state_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(state_params_grad[0] * state_sample_grad, axis=0)
+ loc_mean_grad = mc_grad + state_params_entropy_grad[0]
+ self.variational_parameters['kernel']['loc']['mean'] += loc_mean_grad * step_size
+
+ mc_grad = jnp.mean(state_params_grad[1] * state_sample_grad, axis=0)
+ loc_log_std_grad = mc_grad + state_params_entropy_grad[1]
+ self.variational_parameters['kernel']['loc']['log_std'] += loc_log_std_grad * step_size
+
+ def update_local_params(self, idx, local_params_grad, local_sample_grad, local_params_entropy_grad, ent_anneal=1., step_size=0.001):
+ mc_grad = jnp.mean(local_params_grad[0] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[0]
+ new_param = self.variational_parameters['local']['obs_weights']['mean'][idx] + param_grad * step_size
+ self.variational_parameters['local']['obs_weights']['mean'] = self.variational_parameters['local']['obs_weights']['mean'].at[idx].set(new_param)
+
+ mc_grad = jnp.mean(local_params_grad[1] * local_sample_grad, axis=0)
+ param_grad = mc_grad + ent_anneal * local_params_entropy_grad[1]
+ new_param = self.variational_parameters['local']['obs_weights']['log_std'][idx] + param_grad * step_size
+ self.variational_parameters['local']['obs_weights']['log_std'] = self.variational_parameters['local']['obs_weights']['log_std'].at[idx].set(new_param)
+
+ def update_global_params(self, global_params_grad, global_sample_grad, global_params_entropy_grad, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+ self.variational_parameters['global']['factor_weights']['mean'] += param_grad * step_size
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+ self.variational_parameters['global']['factor_weights']['log_std'] += param_grad * step_size
+
+ def initialize_global_opt_states(self):
+ n_factors = self.node_hyperparams['n_factors']
+ m = jnp.zeros((n_factors,self.n_genes))
+ v = jnp.zeros((n_factors,self.n_genes))
+ state1 = (m,v)
+ m = jnp.zeros((n_factors,self.n_genes))
+ v = jnp.zeros((n_factors,self.n_genes))
+ state2 = (m,v)
+ states = (state1, state2)
+ return states
+
+ def update_global_params_adaptive(self, global_params_grad, global_sample_grad, global_params_entropy_grad, i, states, b1=0.9,
+ b2=0.999, eps=1e-8, step_size=0.001):
+ mc_grad = jnp.mean(global_params_grad[0] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[0]
+
+ m, v = states[0]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state1 = (m, v)
+ self.variational_parameters['global']['factor_weights']['mean'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+
+ mc_grad = jnp.mean(global_params_grad[1] * global_sample_grad, axis=0)
+ param_grad = mc_grad + global_params_entropy_grad[1]
+
+ m, v = states[1]
+ m = (1 - b1) * param_grad + b1 * m # First moment estimate.
+ v = (1 - b2) * jnp.square(param_grad) + b2 * v # Second moment estimate.
+ mhat = m / (1 - jnp.asarray(b1, m.dtype) ** (i + 1)) # Bias correction.
+ vhat = v / (1 - jnp.asarray(b2, m.dtype) ** (i + 1))
+ state2 = (m, v)
+ self.variational_parameters['global']['factor_weights']['log_std'] += step_size * mhat / (jnp.sqrt(vhat) + eps)
+
+ states = (state1, state2)
+ return states
\ No newline at end of file
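
update_global_params_adaptive above ascends the ELBO gradient with an Adam-style update built from bias-corrected first and second moment estimates. A standalone sketch of that update rule on a dummy parameter follows; the shapes, the helper name adam_ascent_step, and the constant placeholder gradient are illustrative.

import jax.numpy as jnp

def adam_ascent_step(param, grad, m, v, i, b1=0.9, b2=0.999, eps=1e-8, step_size=1e-3):
    m = (1 - b1) * grad + b1 * m              # first moment estimate
    v = (1 - b2) * jnp.square(grad) + b2 * v  # second moment estimate
    mhat = m / (1 - b1 ** (i + 1))            # bias correction
    vhat = v / (1 - b2 ** (i + 1))
    return param + step_size * mhat / (jnp.sqrt(vhat) + eps), m, v

param = jnp.zeros(3)
m = v = jnp.zeros(3)
for i in range(5):
    grad = jnp.ones(3)                        # placeholder gradient (ascent direction)
    param, m, v = adam_ascent_step(param, grad, m, v, i)
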
diff --git a/scatrex/models/trajectory/node_opt.py b/scatrex/models/trajectory/node_opt.py
new file mode 100644
index 0000000..756b431
--- /dev/null
+++ b/scatrex/models/trajectory/node_opt.py
@@ -0,0 +1,142 @@
+import jax
+import jax.numpy as jnp
+import tensorflow_probability.substrates.jax.distributions as tfd
+
+@jax.jit
+def sample_angle(key, mu, log_kappa): # univariate: one sample
+ return tfd.VonMises(mu, jnp.exp(log_kappa)).sample(seed=key)
+sample_angle_val_and_grad = jax.vmap(jax.value_and_grad(sample_angle, argnums=(1,2)), in_axes=(None, 0, 0)) # per-dimension val and grad
+mc_sample_angle_val_and_grad = jax.jit(jax.vmap(sample_angle_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_loc(key, mu, log_std): # univariate: one sample
+ return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_loc_val_and_grad = jax.vmap(jax.value_and_grad(sample_loc, argnums=(1,2)), in_axes=(None, 0, 0)) # per-dimension val and grad
+mc_sample_loc_val_and_grad = jax.jit(jax.vmap(sample_loc_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def angle_logp(this_angle, parent_angle, concentration): # single sample
+ return jnp.sum(tfd.VonMises(parent_angle, concentration).log_prob(this_angle))
+angle_logp_val_and_grad = jax.jit(jax.value_and_grad(angle_logp, argnums=0)) # Take grad wrt to this
+mc_angle_logp_val_and_grad = jax.jit(jax.vmap(angle_logp_val_and_grad, in_axes=(0,0, None))) # Multiple sample value_and_grad
+
+angle_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(angle_logp, argnums=1)) # Take grad wrt parent
+mc_angle_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(angle_logp_val_and_grad_wrt_parent, in_axes=(0,0, None))) # Multiple sample value_and_grad
+
+@jax.jit
+def angle_logq(mu, log_kappa):
+ return jnp.sum(tfd.VonMises(mu, jnp.exp(log_kappa)).entropy())
+angle_logq_val_and_grad = jax.jit(jax.value_and_grad(angle_logq, argnums=(0,1))) # Take grad wrt to parameters
+
+@jax.jit
+def loc_logp(this_loc, parent_loc, this_angle, log_std, radius): # single sample
+ mean = parent_loc + jnp.hstack([jnp.cos(this_angle)*radius, jnp.sin(this_angle)*radius]) # Use samples from parent
+ return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(this_loc)) # sum across dimensions
+loc_logp_val = jax.jit(loc_logp)
+mc_loc_logp_val = jax.jit(jax.vmap(loc_logp_val, in_axes=(0,0,0, None, None))) # Multiple sample
+
+loc_logp_val_and_grad = jax.jit(jax.value_and_grad(loc_logp, argnums=0)) # Take grad wrt to this
+mc_loc_logp_val_and_grad = jax.jit(jax.vmap(loc_logp_val_and_grad, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
+
+loc_logp_val_and_grad_wrt_parent = jax.jit(jax.value_and_grad(loc_logp, argnums=1)) # Take grad wrt to parent
+mc_loc_logp_val_and_grad_wrt_parent = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_parent, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
+
+loc_logp_val_and_grad_wrt_angle = jax.jit(jax.value_and_grad(loc_logp, argnums=2)) # Take grad wrt to angle
+mc_loc_logp_val_and_grad_wrt_angle = jax.jit(jax.vmap(loc_logp_val_and_grad_wrt_angle, in_axes=(0,0,0, None, None))) # Multiple sample value_and_grad
+
+@jax.jit
+def loc_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+loc_logq_val_and_grad = jax.jit(jax.vmap(jax.value_and_grad(loc_logq, argnums=(0,1)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+
+
+@jax.jit
+def sample_obs_weights(key, mean, log_std): # NxK
+ return tfd.Normal(mean, jnp.exp(log_std)).sample(seed=key)
+sample_obs_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_obs_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_obs_weights_val_and_grad = jax.jit(jax.vmap(sample_obs_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+@jax.jit
+def sample_factor_weights(key, mu, log_std): # KxG
+ return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)
+sample_factor_weights_val_and_grad = jax.vmap(jax.vmap(jax.value_and_grad(sample_factor_weights, argnums=(1,2)), in_axes=(None,0,0)), in_axes=(None,0,0)) # per-dimension val and grad
+mc_sample_factor_weights_val_and_grad = jax.jit(jax.vmap(sample_factor_weights_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad
+
+
+@jax.jit
+def obs_weights_logp(sample, mean, log_std): # single sample, NxK
+ return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample)) # sum across obs and dimensions
+obs_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(obs_weights_logp, argnums=0)) # Take grad wrt to sample (NxK)
+mc_obs_weights_logp_val_and_grad = jax.jit(jax.vmap(obs_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxNxK
+
+@jax.jit
+def obs_weights_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+obs_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(obs_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+
+@jax.jit
+def factor_weights_logp(sample, mean, log_std): # single sample, KxG
+ return jnp.sum(tfd.Normal(mean, jnp.exp(log_std)).log_prob(sample)) # sum across factors and genes
+factor_weights_logp_val_and_grad = jax.jit(jax.value_and_grad(factor_weights_logp, argnums=0)) # Take grad wrt to sample (KxG)
+mc_factor_weights_logp_val_and_grad = jax.jit(jax.vmap(factor_weights_logp_val_and_grad, in_axes=(0,None,None))) # Multiple sample value_and_grad: SxKxG
+
+@jax.jit
+def factor_weights_logq(mu, log_std):
+ return tfd.Normal(mu, jnp.exp(log_std)).entropy()
+factor_weights_logq_val_and_grad = jax.jit(jax.vmap(jax.vmap(jax.value_and_grad(factor_weights_logq, argnums=(0,1)), in_axes=(0,0)), in_axes=(0,0))) # Take grad wrt to parameters
+
+@jax.jit
+def _mc_obs_ll(obs, node_mean, obs_weights, factor_weights, std): # For each MC sample: Nx2
+ m = node_mean + obs_weights.dot(factor_weights)
+ return jnp.sum(jax.vmap(jax.scipy.stats.norm.logpdf, in_axes=[0, 0, None])(obs, m, std), axis=1) # sum over dimensions
+
+@jax.jit
+def ll(x, weights, psi, obs_weights, factor_weights, log_std): # single sample
+ loc = psi + obs_weights.dot(factor_weights)
+ return jnp.sum(jnp.sum(tfd.Normal(loc, jnp.exp(log_std)).log_prob(x),axis=1) * weights)
+ll_val_and_grad_psi = jax.jit(jax.value_and_grad(ll, argnums=2)) # Take grad wrt to psi
+mc_ll_val_and_grad_psi = jax.jit(jax.vmap(ll_val_and_grad_psi,
+ in_axes=(None,None,0,0,0,None)))
+
+ll_val_and_grad_obs_weights = jax.jit(jax.value_and_grad(ll, argnums=3)) # Take grad wrt to obs_weights
+mc_ll_val_and_grad_obs_weights = jax.jit(jax.vmap(ll_val_and_grad_obs_weights,
+ in_axes=(None,None,0,0,0,None)))
+
+ll_val_and_grad_factor_weights = jax.jit(jax.value_and_grad(ll, argnums=4)) # Take grad wrt to factor_weights
+mc_ll_val_and_grad_factor_weights = jax.jit(jax.vmap(ll_val_and_grad_factor_weights,
+ in_axes=(None,None,0,0,0,None)))
+
+@jax.jit
+def ll_suffstats(node_mean, mass, A, B_g, C, D_g, E, std): # for a single node_mean sample
+ """
+ mass: \sum_n q(z_n = this node)
+ A: \sum_n q(z_n = this node) * \sum_g x_ng ** 2
+ B_g: \sum_n q(z_n = this node) * x_ng
+ C: \sum_n q(z_n = this node) * \sum_g x_ng * E[s_nW_g]
+ D_g: \sum_n q(z_n = this node) * E[s_nW_g]
+ E: \sum_n q(z_n = this node) * \sum_g E[(s_nW_g)**2]
+ """
+ v = std**2
+ ll = -jnp.log(2*jnp.pi*v) * mass - A/(2*v) + jnp.sum(B_g*node_mean)/v + C/v - \
+ (mass*jnp.sum(node_mean**2))/(2*v) - jnp.sum(D_g*node_mean)/v - E/(2*v)
+ return ll
+
+@jax.jit
+def ll_node_mean_suff(node_mean, mass, B_g, D_g, log_std): # for a single node_mean sample
+ """
+ mass: \sum_n q(z_n = this node)
+ B_g: \sum_n q(z_n = this node) * x_ng
+ D_g: \sum_n q(z_n = this node) * E[s_nW_g]
+ """
+ v = jnp.exp(log_std)**2
+ ll = jnp.sum(B_g*node_mean)/v - (mass*jnp.sum(node_mean**2))/(2*v) - jnp.sum(D_g*node_mean)/v
+ return ll
+ll_node_mean_suff_val_and_grad = jax.jit(jax.value_and_grad(ll_node_mean_suff, argnums=0)) # Take grad wrt to psi
+mc_ll_node_mean_suff_val_and_grad = jax.jit(jax.vmap(ll_node_mean_suff_val_and_grad,
+ in_axes=(0,None, None, None, None)))
+
+# To get noise sample
+sample_prod = jax.jit(lambda mat1, mat2: mat1.dot(mat2))
\ No newline at end of file
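
The sampling helpers above all follow the same pattern: wrap a scalar reparameterized draw in jax.value_and_grad with respect to the variational parameters, vmap it over dimensions, and vmap again over the MC sample keys. A minimal self-contained sketch of that pattern for a Normal location; the 2-D shapes and 4 MC samples are illustrative.

import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd

def sample_loc(key, mu, log_std):  # one univariate reparameterized draw
    return tfd.Normal(mu, jnp.exp(log_std)).sample(seed=key)

# value_and_grad of each scalar draw wrt (mu, log_std), vmapped over dimensions,
# then vmapped again over MC sample keys
per_dim = jax.vmap(jax.value_and_grad(sample_loc, argnums=(1, 2)), in_axes=(None, 0, 0))
mc_sample = jax.jit(jax.vmap(per_dim, in_axes=(0, None, None)))

keys = jax.random.split(jax.random.PRNGKey(0), 4)   # 4 MC samples
mu, log_std = jnp.zeros(2), -jnp.ones(2)             # a 2-D location
samples, (grad_mu, grad_log_std) = mc_sample(keys, mu, log_std)
# samples: (4, 2); grad_mu and grad_log_std hold the pathwise gradients of each
# draw with respect to the variational parameters, also shape (4, 2).
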
diff --git a/scatrex/models/trajectory/tree.py b/scatrex/models/trajectory/tree.py
new file mode 100644
index 0000000..7ba9268
--- /dev/null
+++ b/scatrex/models/trajectory/tree.py
@@ -0,0 +1,67 @@
+import numpy as np
+from .node import TrajectoryNode
+from ...ntssb.observed_tree import *
+import pandas as pd
+from anndata import AnnData
+
+class TrajectoryTree(ObservedTree):
+ def __init__(self, **kwargs):
+ super(TrajectoryTree, self).__init__(**kwargs)
+ self.node_constructor = TrajectoryNode
+
+ def sample_kernel(self, parent_params, mean_dist=1., angle_concentration=1., loc_variance=.1, seed=42, depth=1., **kwargs):
+ rng = np.random.default_rng(seed=seed)
+ parent_loc = parent_params[0]
+ parent_angle = parent_params[1]
+ angle_concentration = angle_concentration * depth
+ sampled_angle = rng.vonmises(parent_angle, angle_concentration)
+ sampled_loc = rng.normal(mean_dist, loc_variance)
+ sampled_loc = parent_loc + np.array([np.cos(sampled_angle)*np.abs(sampled_loc), np.sin(sampled_angle)*np.abs(sampled_loc)])
+ return [sampled_loc, sampled_angle]
+
+ def sample_root(self, **kwargs):
+ return [np.array([0., 0.]), 0.]
+
+ def get_param_size(self):
+ return self.tree["param"][0].size
+
+ def get_params(self):
+ params = []
+ for node in self.tree_dict:
+ params.append(self.tree_dict[node]["param"][0])
+        return np.array(params, dtype=float)
+
+ def param_distance(self, paramA, paramB):
+ return np.sqrt(np.sum((paramA[0]-paramB[0])**2))
+
+ def create_adata(self, var_names=None):
+ params = []
+ params_labels = []
+ for node in self.tree_dict:
+ if self.tree_dict[node]["size"] != 0:
+ params_labels.append(
+ [self.tree_dict[node]["label"]] * self.tree_dict[node]["size"]
+ )
+ params.append(
+ np.vstack(
+ [self.tree_dict[node]["param"][0]] * self.tree_dict[node]["size"]
+ )
+ )
+ params = pd.DataFrame(np.vstack(params))
+ params_labels = np.concatenate(params_labels).tolist()
+ if var_names is not None:
+ params.columns = var_names
+ self.adata = AnnData(params)
+ self.adata.obs["node"] = params_labels
+ self.adata.uns["node_colors"] = [
+ self.tree_dict[node]["color"]
+ for node in self.tree_dict
+ if self.tree_dict[node]["size"] != 0
+ ]
+ self.adata.uns["node_sizes"] = np.array(
+ [
+ self.tree_dict[node]["size"]
+ for node in self.tree_dict
+ if self.tree_dict[node]["size"] != 0
+ ]
+ )
+ self.adata.var["bulk"] = np.mean(self.adata.X, axis=0)
\ No newline at end of file
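
TrajectoryTree.sample_kernel above places each child node roughly mean_dist away from its parent, in a direction drawn from a von Mises distribution centred on the parent's angle. A standalone sketch of that radial kernel with illustrative values:

import numpy as np

rng = np.random.default_rng(0)
parent_loc, parent_angle = np.array([0.0, 0.0]), 0.0
angle = rng.vonmises(parent_angle, 10.0)     # child angle concentrated around the parent's
dist = abs(rng.normal(1.0, 0.1))             # distance roughly equal to mean_dist
child_loc = parent_loc + dist * np.array([np.cos(angle), np.sin(angle)])
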
diff --git a/scatrex/ntssb/__init__.py b/scatrex/ntssb/__init__.py
index 2648fba..551928f 100644
--- a/scatrex/ntssb/__init__.py
+++ b/scatrex/ntssb/__init__.py
@@ -1,4 +1,4 @@
from .node import AbstractNode
from .ntssb import NTSSB
-from .tree import Tree
+from .observed_tree import ObservedTree
from .search import StructureSearch
diff --git a/scatrex/ntssb/node.py b/scatrex/ntssb/node.py
index bc0c94c..82d5a48 100644
--- a/scatrex/ntssb/node.py
+++ b/scatrex/ntssb/node.py
@@ -1,22 +1,33 @@
import numpy as np
from numpy.random import *
+import jax.numpy as jnp
+from abc import ABC, abstractmethod
-class AbstractNode(object):
+class AbstractNode(ABC):
def __init__(
- self, is_observed, observed_parameters, parent=None, tssb=None, label=""
+ self, observed_parameters, parent=None, tssb=None, label="", seed=42,
):
self.data = set([])
self._children = set([])
self.tssb = tssb
- self.is_observed = is_observed
- self.observed_parameters = observed_parameters
self.label = label
self.event_str = ""
+ self.seed = seed
self.params = dict()
+ self.observed_parameters = observed_parameters
self.variational_parameters = dict(globals=dict(), locals=dict())
-
+ self.variational_parameters = {'delta_1': 1., 'delta_2': 1., # nu stick
+ 'sigma_1': 1., 'sigma_2': 1., # psi stick
+ 'q_z': [], # prob of assigning each cell to this node
+ 'kernel' : dict(), # model-specific
+ 'q_rho': [], # prob of assigning the root node of each child TSSB to this node
+ 'E_log_phi': 0., # auxiliary quantity
+ 'sum_E_log_1_nu': 0., # auxiliary quantity
+ }
+ self.samples = None
+ self.pivot_prior_prob = 1. # prior prob of assigning the root node of each child TSSB to this node
self.data_weights = 0.
self.weight_until_here = 0.
@@ -38,6 +49,34 @@ def __init__(
else:
self._parent = None
+ @abstractmethod
+ def compute_loglikelihood(self):
+ return
+
+ @abstractmethod
+ def combine_params(self):
+ return
+
+ @abstractmethod
+ def sample_kernel(self):
+ return
+
+ @abstractmethod
+ def compute_kernel_prior(self):
+ return
+
+ @abstractmethod
+ def compute_root_kernel_prior(self):
+ return
+
+ @abstractmethod
+ def compute_kernel_entropy(self):
+ return
+
+ @abstractmethod
+ def remove_noise(self):
+ return
+
def set_parent(self, parent, reset=False):
if self._parent is not None and self._parent._children is not None:
self._parent._children.remove(self)
@@ -45,20 +84,11 @@ def set_parent(self, parent, reset=False):
parent.add_child(self)
self._parent = parent
- if not self.is_observed:
- # Make sure we use the right observed parameters
- self.observed_parameters = parent.observed_parameters
- self.inherit_parameters()
-
- self.unobserved_factors = parent.unobserved_factors + self.unobserved_factors
self.set_mean()
if reset:
self.reset_variational_parameters()
- def inherit_parameters(self):
- pass
-
def reset_parameters(self):
pass
@@ -69,11 +99,17 @@ def kill(self):
self._parent = None
self._children = None
- def spawn(self, is_observed, observed_parameters):
+ def spawn(self, observed_parameters, seed=42):
return self.__class__(
- is_observed, observed_parameters, parent=self, tssb=self.tssb
+ observed_parameters, parent=self, tssb=self.tssb, seed=seed,
)
+ def get_observed_parameters(self):
+ return self.observed_parameters
+
+ def get_params(self):
+ return self.params
+
def has_data(self):
if len(self.data):
return True
@@ -158,7 +194,7 @@ def global_param(self, key):
return self.parent().global_param(key)
def get_ancestors(self, all=True):
- if self._parent is None or (not all and self.is_observed):
+ if self._parent is None or (not all and self.tssb != self._parent.tssb):
return [self]
else:
ancestors = self._parent.get_ancestors(all=all)
@@ -179,11 +215,20 @@ def descend(node):
def parameter_log_prior(self):
return 0
- def tssb_caller(self):
- if self.parent() is None:
- return self.tssb
- else:
- return self.parent().tssb_caller()
-
def set_event_string(self):
pass
+
+ def get_top_obs(self, q=70, idx=None):
+ """
+ Return the observations that are particularly well explained by this node's parameters
+ """
+ if idx is None:
+ idx = jnp.arange(self.tssb.ntssb.num_data)
+ if len(idx) == 0:
+ return np.array([])
+ lls = self.compute_loglikelihood(idx)
+ top_obs = idx[np.where(lls > np.percentile(lls, q=q))[0]]
+ return top_obs
+
+ def reset_variational_state(self, **kwargs):
+ return
\ No newline at end of file
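
Note: `get_top_obs` keeps only the observations whose log-likelihood under the node exceeds the q-th percentile of the candidate set. A minimal, self-contained sketch of that selection with synthetic log-likelihoods (plain NumPy, not the scatrex API; the array names are hypothetical):

    import numpy as np

    rng = np.random.default_rng(0)
    idx = np.arange(100)                 # candidate observation indices
    lls = rng.normal(size=idx.size)      # hypothetical per-observation log-likelihoods
    q = 70
    top_obs = idx[np.where(lls > np.percentile(lls, q=q))[0]]
    print(top_obs.size)                  # roughly 30% of the candidates survive
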
diff --git a/scatrex/ntssb/ntssb.py b/scatrex/ntssb/ntssb.py
index 225d5fd..2cba92b 100644
--- a/scatrex/ntssb/ntssb.py
+++ b/scatrex/ntssb/ntssb.py
@@ -8,6 +8,7 @@
from graphviz import Digraph
import matplotlib
import matplotlib.cm
+import matplotlib.pyplot as plt
import numpy as np
from numpy import *
@@ -19,13 +20,10 @@
import jax.numpy as jnp
import jax.nn as jnn
-from ..util import *
+from ..utils.math_utils import *
from ..callbacks import elbos_callback
from .tssb import TSSB
-from ..models import cna
-from ..models.cna.opt_funcs import *
-
-import time
+from ..plotting import tree_colors, plot_full_tree
import logging
@@ -62,7 +60,6 @@ class NTSSB(object):
def __init__(
self,
input_tree,
- node_constructor,
dp_alpha=1.0,
dp_gamma=1.0,
alpha_decay=1.0,
@@ -70,8 +67,11 @@ def __init__(
max_depth=15,
fixed_weights_pivot_sampling=True,
use_weights=True,
+ weights_concentration=10.,
+ min_weight=1e-6,
verbosity=logging.INFO,
node_hyperparams=dict(),
+ seed=42,
):
if input_tree is None:
raise Exception("Input tree must be specified.")
@@ -80,11 +80,14 @@ def __init__(
self.max_depth = max_depth
self.dp_alpha = dp_alpha # smaller dp_alpha => larger nu => less nodes
self.dp_gamma = dp_gamma # smaller dp_gamma => larger psi => less nodes
- self.alpha_decay = alpha_decay
+ self.alpha_decay = alpha_decay # smaller alpha_decay => larger decay with depth => less nodes
+
+ self.seed = seed
self.input_tree = input_tree
self.input_tree_dict = self.input_tree.tree_dict
- self.node_constructor = node_constructor
+ self.node_constructor = self.input_tree.node_constructor
+ self.node_hyperparams = node_hyperparams
self.fixed_weights_pivot_sampling = fixed_weights_pivot_sampling
@@ -99,11 +102,14 @@ def __init__(
self.data = None
self.num_data = None
self.covariates = None
+ self.num_batches = 1
+ self.batch_size = None
self.max_nodes = (
len(self.input_tree_dict.keys()) * 1
) # upper bound on number of nodes
self.n_nodes = len(self.input_tree_dict.keys())
+ self.n_total_nodes = self.n_nodes
self.obs_cmap = self.input_tree.cmap
self.exp_cmap = matplotlib.cm.viridis
@@ -111,199 +117,331 @@ def __init__(
logger.setLevel(verbosity)
- self.reset_tree(use_weights=use_weights, node_hyperparams=node_hyperparams)
+ self.reset_tree(use_weights=use_weights, weights_concentration=weights_concentration, min_weight=min_weight)
+
+ self.set_pivot_priors()
+
+ self.variational_parameters = {
+ 'LSE_c': [], # normalizing constant for cell-TSSB assignments
+ }
# ========= Functions to initialize tree. =========
- def reset_tree(self, use_weights=False, node_hyperparams=dict()):
+ def reset_tree(self, use_weights=False, weights_concentration=10., min_weight=1e-6):
if use_weights and "weight" not in self.input_tree_dict["A"].keys():
raise KeyError("No weights were specified in the input tree.")
# Clear tree
self.assignments = []
- input_tree_dict = self.input_tree_dict
-
- # Get MRCA node
- root_node = self.input_tree.mrca()
-
- obj = self.node_constructor(
- True,
- input_tree_dict[root_node]["params"],
- parent=None,
- label=root_node,
- **node_hyperparams,
- )
- input_tree_dict[root_node]["subtree"] = TSSB(
- obj,
- root_node,
- ntssb=self,
- dp_alpha=input_tree_dict[root_node]["dp_alpha_subtree"],
- alpha_decay=input_tree_dict[root_node]["alpha_decay_subtree"],
- dp_gamma=input_tree_dict[root_node]["dp_gamma_subtree"],
- color=input_tree_dict[root_node]["color"],
- )
-
- main = (
- boundbeta(1.0, self.dp_alpha) if self.min_depth == 0 else 0.0
- ) # if min_depth > 0, no data can be added to the root (main stick is nu)
- if use_weights:
- main = self.input_tree_dict[root_node]["weight"]
- input_tree_dict[root_node]["subtree"].weight = self.input_tree_dict[
- root_node
- ]["weight"]
-
- self.root = {
- "node": input_tree_dict[root_node]["subtree"],
- "main": main,
- "sticks": empty((0, 1)), # psi sticks
- "children": [],
- "label": root_node,
- "super_parent": None,
- "parent": None,
- }
-
- # Recursively construct tree of subtrees
- def descend(super_tree, label, depth=0):
- for i, child in enumerate(input_tree_dict[label]["children"]):
-
- stick = boundbeta(1, self.dp_gamma)
+ # Traverse the tree depth-first and create TSSBs
+ def descend(input_root, idx=1, depth=0):
+ alpha_nu = 1.
+ beta_nu = (self.alpha_decay**depth) * self.dp_alpha
+
+ children_roots = []
+ sticks = []
+ psi_priors = []
+ for i, child in enumerate(input_root['children']):
+ child_root, idx = descend(child, idx, depth+1)
+ children_roots.append(child_root)
+
+ rng = np.random.default_rng(int(self.seed+idx*1e6))
+ stick = boundbeta(1, self.dp_gamma, rng)
+ psi_prior = {"alpha_psi": 1., "beta_psi": self.dp_gamma}
if use_weights:
- stick = self.input_tree.get_sum_weights_subtree(child)
- if i < len(input_tree_dict[label]["children"]) - 1:
+ stick = self.input_tree.get_sum_weights_subtree(child_root["label"])
+ if i < len(input_root["children"]) - 1:
sum = 0
- for j, c in enumerate(input_tree_dict[label]["children"][i:]):
- sum = sum + self.input_tree.get_sum_weights_subtree(c)
+ for j, c in enumerate(input_root["children"][i:]):
+ sum = sum + self.input_tree.get_sum_weights_subtree(c["label"])
stick = stick / sum
else:
stick = 1.0
-
- super_tree["sticks"] = vstack(
- [
- super_tree["sticks"],
- stick
- if i < len(input_tree_dict[label]["children"]) - 1
- else 1.0,
- ]
+ psi_prior["alpha_psi"] = stick * (weights_concentration - 2) + 1
+ psi_prior["beta_psi"] = (1-stick) * (weights_concentration -2) + 1
+ psi_priors.append(psi_prior)
+ sticks.append(stick)
+ if len(sticks) == 0:
+ sticks = empty((0, 1))
+ else:
+ sticks = vstack(sticks)
+
+ label = input_root["label"]
+
+ # Create node
+ local_seed = int(self.seed+idx*1e6)
+ node = self.node_constructor(
+ input_root["param"],
+ label=label,
+ seed=local_seed,
+ **self.node_hyperparams,
+ )
+
+ # Create TSSB with pointers to children root nodes
+ rng = np.random.default_rng(local_seed)
+
+ children_nodes = [c["node"].root["node"] for c in children_roots]
+ tssb = TSSB(
+ node,
+ label,
+ ntssb=self,
+ children_root_nodes=children_nodes,
+ dp_alpha=input_root["dp_alpha_subtree"],
+ alpha_decay=input_root["alpha_decay_subtree"],
+ dp_gamma=input_root["dp_gamma_subtree"],
+ eta=input_root["eta"],
+ color=input_root["color"],
+ seed=local_seed,
)
+ input_root["subtree"] = tssb
+
+ # Create root dict
+ if depth >= self.min_depth:
+ main = boundbeta(1.0, (self.alpha_decay ** depth) * self.dp_alpha, rng)
+ else: # if depth < min_depth, no data can be added to this node (main stick is nu)
+ main = 0.
+ if use_weights:
+ main = input_root["weight"]
+ subtree_weights_sum = self.input_tree.get_sum_weights_subtree(label)
+ main = main / subtree_weights_sum
+ input_root["subtree"].weight = input_root["weight"]
+ if len(input_root["children"]) < 1:
+ main = 1.0 # stop at leaf node
+
+ if use_weights:
+ alpha_nu = main * (weights_concentration - 2) + 1
+ beta_nu = (1-main) * (weights_concentration - 2) + 1
+
+ root_dict = {
+ "node": tssb,
+ "main": main,
+ "sticks": sticks,
+ "children": children_roots,
+ "label": input_root["label"],
+ "super_parent": None, # maybe remove
+ "pivot_node": None, # maybe remove
+ "pivot_tssb": None, # maybe remove
+ "color": input_root["color"],
+ "alpha_nu": alpha_nu,
+ "beta_nu": beta_nu,
+ "psi_priors": psi_priors,
+ }
+
+ return root_dict, idx+1
+
+ self.root, _ = descend(self.input_tree.tree)
- main = boundbeta(1.0, (self.alpha_decay ** (depth + 1)) * self.dp_alpha)
- if use_weights:
- main = self.input_tree_dict[child]["weight"]
- subtree_weights_sum = self.input_tree.get_sum_weights_subtree(child)
- main = main / subtree_weights_sum
-
- if len(input_tree_dict[child]["children"]) < 1:
- main = 1.0 # stop at leaf node
-
- pivot_tssb = input_tree_dict[label]["subtree"]
- # pivot_tssb.dp_alpha = input_tree_dict[child]['dp_alpha_parent_edge']
- # pivot_tssb.alpha_decay = input_tree_dict[child]['alpha_decay_parent_edge']
- # pivot_tssb.truncate()
-
- # pivot_tssb.eta = input_tree_dict[child]['eta']
- pivot_node = super_tree["node"].root["node"]
-
- obj = self.node_constructor(
- True,
- input_tree_dict[child]["params"],
- parent=pivot_node,
- label=child,
- )
+ def descend(root):
+ for child in root['children']:
+ child["node"]._parent = root["node"]
+ descend(child)
+
+ # Set parents
+ descend(self.root)
- input_tree_dict[child]["subtree"] = TSSB(
- obj,
- child,
- ntssb=self,
- dp_alpha=input_tree_dict[child]["dp_alpha_subtree"],
- alpha_decay=input_tree_dict[child]["alpha_decay_subtree"],
- dp_gamma=input_tree_dict[child]["dp_gamma_subtree"],
- color=input_tree_dict[child]["color"],
- )
- input_tree_dict[child]["subtree"].eta = input_tree_dict[child]["eta"]
+ # And add weights keys
+ self.set_weights()
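
Aside: the stick priors above turn an input-tree weight into Beta pseudo-counts via alpha = w*(c-2)+1 and beta = (1-w)*(c-2)+1, so alpha+beta always equals the concentration c and the prior mean approaches w as c grows. A small standalone sketch mirroring those expressions (helper name is illustrative only):

    def stick_prior(w, concentration=10.):
        # Beta pseudo-counts whose total mass equals `concentration`
        alpha = w * (concentration - 2) + 1
        beta = (1 - w) * (concentration - 2) + 1
        return alpha, beta

    a, b = stick_prior(0.3, concentration=10.)
    print(a, b, a / (a + b))   # 3.4 6.6 0.34 -- close to the input weight 0.3
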
- if use_weights:
- input_tree_dict[child]["subtree"].weight = self.input_tree_dict[
- child
- ]["weight"]
-
- super_tree["children"].append(
- {
- "node": input_tree_dict[child]["subtree"],
- "main": main if self.min_depth <= (depth + 1) else 0.0,
- "sticks": empty((0, 1)), # psi sticks
- "children": [],
- "label": child,
- "super_parent": super_tree["node"],
- "pivot_node": pivot_node,
- "pivot_tssb": pivot_tssb,
- }
- )
+ def set_tssb_params(self, dp_alpha=1., alpha_decay=1., dp_gamma=1.):
+ def descend(root):
+ root['node'].dp_alpha = dp_alpha
+ root['node'].alpha_decay = alpha_decay
+ root['node'].dp_gamma = dp_gamma
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
- descend(super_tree["children"][-1], child, depth + 1)
+ def set_node_hyperparams(self, **kwargs):
+ def descend(root):
+ root['node'].set_node_hyperparams(**kwargs)
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
- descend(self.root, root_node)
+ def sample_variational_distributions(self, **kwargs):
+ def descend(root):
+ root['node'].sample_variational_distributions(**kwargs)
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
- def reset_variational_parameters(self, **kwargs):
- # Reset node parameters
+ def set_learned_parameters(self):
+ def descend(root):
+ root['node'].set_learned_parameters()
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
+ def reset_sufficient_statistics(self):
def descend(super_tree):
- super_tree["node"].reset_node_variational_parameters(**kwargs)
+ super_tree["node"].reset_sufficient_statistics(num_batches=self.num_batches)
for child in super_tree["children"]:
descend(child)
+ descend(self.root)
+ def reset_variational_parameters(self, **kwargs):
+ # Reset node parameters
+ def descend(super_tree, alpha_psi=1., beta_psi=1.):
+ alpha_nu = super_tree['alpha_nu']
+ beta_nu = super_tree['beta_nu']
+ super_tree["node"].reset_variational_parameters(alpha_nu=alpha_nu, beta_nu=beta_nu,
+ alpha_psi=alpha_psi,beta_psi=beta_psi,
+ **kwargs)
+ c_norm = jnp.array(super_tree["node"].variational_parameters['q_c'])
+ for i, child in enumerate(super_tree["children"]):
+ alpha_psi = super_tree['psi_priors'][i]["alpha_psi"]
+ beta_psi = super_tree['psi_priors'][i]["beta_psi"]
+ c_norm += descend(child, alpha_psi=alpha_psi, beta_psi=beta_psi)
+ return c_norm
+
+ c_norm = descend(self.root)
+
+ # Apply normalization
+ def descend(root, alpha_psi=1., beta_psi=1.):
+ alpha_nu = root['alpha_nu']
+ beta_nu = root['beta_nu']
+ root["node"].reset_variational_parameters(alpha_nu=alpha_nu, beta_nu=beta_nu,
+ alpha_psi=alpha_psi,beta_psi=beta_psi,
+ **kwargs)
+ root["node"].variational_parameters['q_c'] = root["node"].variational_parameters['q_c'] / c_norm
+ for i, child in enumerate(root["children"]):
+ alpha_psi = root['psi_priors'][i]["alpha_psi"]
+ beta_psi = root['psi_priors'][i]["beta_psi"]
+ descend(child, alpha_psi=alpha_psi, beta_psi=beta_psi)
descend(self.root)
+ def init_root_kernels(self, **kwargs):
+ def descend(super_tree):
+ for child in super_tree["children"]:
+ child["node"].root["node"].init_kernel(**kwargs)
+ descend(child)
+ descend(self.root)
+
def reset_node_parameters(
- self, root_params=True, down_params=True, node_hyperparams=None
- ):
+ self, **node_hyperparams
+ ):
# Reset node parameters
def descend(super_tree):
- super_tree["node"].reset_node_variational_parameters()
- super_tree["node"].reset_node_parameters(
- root_params=root_params,
- down_params=down_params,
- node_hyperparams=node_hyperparams,
- )
+ super_tree["node"].reset_node_parameters(**node_hyperparams)
for child in super_tree["children"]:
descend(child)
descend(self.root)
+ def remake_observed_params(self):
+ def descend(super_tree):
+ self.input_tree.tree_dict[super_tree["label"]]["param"] = super_tree["node"].root["node"].params
+ super_tree["node"].root["node"].observed_parameters = self.input_tree.tree_dict[super_tree["label"]]["param"]
+ for child in super_tree["children"]:
+ descend(child)
+
+ descend(self.root)
+ self.input_tree.update_tree()
+
+ def set_radial_positions(self):
+ """
+ Create a radial layout from the full NTSSB and set the node means
+ as their positions in the layout.
+
+ This should make the outer nodes' parameters correspond to longer branches than the internal ones.
+ """
+ import networkx as nx
+ self.create_augmented_tree_dict() # create self.node_dict
+
+ G = nx.DiGraph()
+ for node in self.node_dict:
+ G.add_node(node)
+ if self.node_dict[node]['parent'] != '-1':
+ parent = self.node_dict[node]['parent']
+ G.add_edge(parent, node)
+ pos = nx.nx_pydot.graphviz_layout(G, prog="twopi")
+
+ self.set_node_means(pos) # to sample observations from nodes in these positions
+
+ def set_node_means(self, pos):
+ for node in self.node_dict:
+ self.node_dict[node].set_node_mean(pos[node])
+
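
For reference, `set_radial_positions` relies on graphviz's `twopi` engine through networkx. A hedged sketch of the same layout call on a toy parent dictionary (requires networkx with pydot and graphviz installed; the node names are hypothetical):

    import networkx as nx

    parents = {"A": "-1", "B": "A", "C": "A", "D": "B"}  # '-1' marks the root
    G = nx.DiGraph()
    for node, parent in parents.items():
        G.add_node(node)
        if parent != "-1":
            G.add_edge(parent, node)
    pos = nx.nx_pydot.graphviz_layout(G, prog="twopi")   # radial position per node
    print(pos["D"])
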
def sync_subtrees(self):
subtrees = self.get_subtrees()
for subtree in subtrees:
subtree.ntssb = self
- def put_data_in_nodes(self, num_data, eta=0):
+ def get_node(self, u, key=None, uniform=False, include_leaves=True):
+ # See in which subtree it lands
+ subtree, _, u = self.find_node(u)
+
+ # See in which node it lands
+ if uniform:
+ _, _, root = subtree.find_node_uniform(key, include_leaves=include_leaves)
+ else:
+ _, _, root = subtree.find_node(u, include_leaves=include_leaves)
+
+ return root
+
+ def sample_assignments(self, num_data):
+ self.num_data = num_data
+
+ node_assignments = []
+ obs_node_assignments = []
+
self.assignments = []
+ self.subtree_assignments = []
subtrees = self.get_subtrees()
for tssb in subtrees:
tssb.assignments = []
+ tssb.remove_data()
+
+ # Draw sticks
+ rng = np.random.default_rng(self.seed)
+ u_vector = rng.random(size=num_data)
+ for n in range(num_data):
+ u = u_vector[n]
+ # See in which subtree it lands
+ subtree, _, u = self.find_node(u)
- self.num_data = num_data
+ # See in which node it lands
+ node, _, _ = subtree.find_node(u)
- # Get mixture weights
- nodes, weights = self.get_node_weights(eta=eta)
+ self.assignments.append(node)
+ self.subtree_assignments.append(subtree)
- for node in nodes:
- node.remove_data()
+ node_assignments.append(node.label)
+ obs_node_assignments.append(subtree.label)
- for n in range(self.num_data):
- node = np.random.choice(nodes, p=weights)
- node.tssb.assignments.append(node)
+ subtree.assignments.append(node)
+ subtree.add_datum(n)
node.add_datum(n)
- self.assignments.append(node)
+
+ return node_assignments, obs_node_assignments
+
+ def simulate_data(self):
+ self.data = np.zeros((self.num_data, self.input_tree.get_param_size()))
+
+ # Reset root node parameters to set data-dependent variables if applicable
+ self.root["node"].root["node"].reset_data_parameters()
+
+ # Sample observations
+ def super_descend(super_root):
+ descend(super_root['node'].root)
+ for super_child in super_root['children']:
+ super_descend(super_child)
- def normalize_data(self):
- if self.data is None:
- raise Exception("Need to `call add_data(self, data, to_root=False)` first.")
+ def descend(root):
+ attached_cells = np.array(list(root['node'].data))
+ if len(attached_cells) > 0:
+ self.data[attached_cells] = root['node'].sample_observations()
+ for child in root['children']:
+ descend(child)
+
+ super_descend(self.root)
+ self.data = jnp.array(self.data)
- self.normalized_data = np.log(
- 10000 * self.data / np.sum(self.data, axis=1).reshape(self.num_data, 1) + 1
- )
+ return self.data
- def add_data(self, data, covariates=None, to_root=False):
- self.data = data
- self.num_data = 0 if data is None else data.shape[0]
+ def add_data(self, data, covariates=None):
+ self.data = jnp.array(data)
+ self.num_data = data.shape[0]
if covariates is None:
self.covariates = np.zeros((self.num_data, 0))
else:
@@ -312,29 +450,15 @@ def add_data(self, data, covariates=None, to_root=False):
logger.debug(f"Adding data of shape {data.shape} to NTSSB")
- self.assignments = []
-
- for n in range(self.num_data):
- if to_root:
- subtree = self.root["node"]
- node = self.root["node"].root["node"]
- else:
- u = rand()
- subtree, _, u = self.find_node(u)
-
- # Now choose the node
- node, _, _ = subtree.find_node(u)
-
- subtree.assignments.append(node)
- node.add_datum(n)
- self.assignments.append(node)
-
try:
# Reset root node parameters to set data-dependent variables if applicable
self.root["node"].root["node"].reset_data_parameters()
except AttributeError:
pass
+ # Reset node variational parameters to use this data size
+ self.reset_variational_parameters()
+
def clear_data(self):
def descend(root):
for index, child in enumerate(root["children"]):
@@ -363,44 +487,47 @@ def descend(super_tree, label, depth=0):
descend(self.root, "A")
- def create_new_tree(self, n_extra_per_observed=1, num_data=None):
+ def create_new_tree(self, n_extra_per_observed=1):
# Clear current tree (including subtrees)
self.reset_tree(
- True, node_hyperparams=self.root["node"].root["node"].node_hyperparams
+ True
)
- self.reset_node_parameters(
- node_hyperparams=self.root["node"].root["node"].node_hyperparams
- )
- self.plot_tree(super_only=False) # update names
+ self.set_weights()
- # Add nodes to subtrees
- subtrees = self.get_subtrees()
- for subtree in subtrees:
- n_nodes = 0
- while n_nodes < n_extra_per_observed:
- _, nodes = subtree.get_mixture()
- # Uniformly choose a node from the subtree
- snode = np.random.choice(nodes)
- self.add_node_to(snode.label, optimal_init=False)
- self.plot_tree(super_only=False) # update names
- n_nodes = n_nodes + 1
-
- # Choose pivots
+ def get_distance(nodeA, nodeB):
+ return np.sqrt(np.sum((nodeA.get_mean() - nodeB.get_mean())**2))
+
+ # Add nodes and set pivots
def descend(super_tree):
- for child in super_tree["children"]:
+ if super_tree['weight'] != 0: # Add nodes only if it has some mass
+ n_nodes = 0
+ while n_nodes < n_extra_per_observed:
+ _, _, nodes_roots = super_tree['node'].get_mixture(get_roots=True)
+ # Uniformly choose a node from the subtree
+ rng = np.random.default_rng(super_tree['node'].seed + n_nodes)
+ snode = rng.choice(nodes_roots)
+ super_tree['node'].add_node(snode)
+ n_nodes = n_nodes + 1
+ super_tree['node'].reset_node_parameters(**self.node_hyperparams) # adjust parameters to avoid overlapping subnodes
+ for i, child in enumerate(super_tree["children"]):
weights, nodes = super_tree["node"].get_fixed_weights(
eta=child["node"].eta
)
- pivot_node = np.random.choice(nodes, p=weights)
+ weights = np.array([w/get_distance(child['node'].root['node'], n) for w, n in zip(weights, nodes)])
+ weights = weights / np.sum(weights)
+ # rng = np.random.default_rng(super_tree['node'].seed + i)
+ # pivot_node = rng.choice(nodes, p=weights)
+ pivot_node = nodes[np.argmax(weights)]
child["pivot_node"] = pivot_node
child["node"].root["node"].set_parent(pivot_node)
descend(child)
+ super_tree['node'].truncate()
+ super_tree['node'].set_weights()
+ super_tree['node'].set_pivot_priors()
descend(self.root)
-
- if num_data is not None:
- self.put_data_in_nodes(num_data, eta=0)
+ self.plot_tree(super_only=False) # update names
def sample_new_tree(self, num_data, use_weights=False):
self.num_data = num_data
@@ -408,7 +535,6 @@ def sample_new_tree(self, num_data, use_weights=False):
# Clear current tree (including subtrees)
self.reset_tree(
use_weights,
- node_hyperparams=self.root["node"].root["node"].node_hyperparams,
)
# Break sticks to assign data
@@ -449,6 +575,38 @@ def descend(super_tree):
descend(self.root)
+ def get_param_dict(self):
+ """
+ Convert the tree dictionary, in which each node holds a TSSB object, into a nested dictionary
+ in which each node is itself a dictionary with the subtree parameters plus `weight`, `obs_param`, `label`, `color` and `size` keys
+ """
+
+ param_dict = {
+ "node": self.root['node'].get_param_dict(),
+ "weight": self.root['weight'],
+ "children": [],
+ "obs_param": self.root['node'].root['node'].get_observed_parameters(),
+ "label": self.root['label'],
+ "color": self.root['color'],
+ "size": len(self.root['node']._data),
+ }
+ def descend(root, root_new):
+ for child in root["children"]:
+ child_new = {
+ "node": child['node'].get_param_dict(),
+ "weight": child['weight'],
+ "children": [],
+ "obs_param": child['node'].root['node'].get_observed_parameters(),
+ "label": child['label'],
+ "color": child['color'],
+ "size": len(child['node']._data)
+ }
+ root_new['children'].append(child_new)
+ descend(child, root_new['children'][-1])
+
+ descend(self.root, param_dict)
+ return param_dict
+
# ========= Functions to sample tree parameters. =========
def sample_pivot_main_sticks(self):
def descend(super_tree):
@@ -577,6 +735,34 @@ def descend(root, mass):
return descend(self.root, 1.0)
+ def set_weights(self):
+ def descend(root, mass):
+ root['weight'] = mass * root["main"]
+ edges = sticks_to_edges(root["sticks"])
+ weights = diff(hstack([0.0, edges]))
+ for i, child in enumerate(root["children"]):
+ descend(child, mass * (1.0 - root["main"]) * weights[i])
+ return descend(self.root, 1.0)
+
+ def set_expected_weights(self):
+ def descend(root):
+ logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ logprior += root['node'].variational_parameters['E_log_phi']
+ root['weight'] = jnp.exp(logprior)
+ root['node'].set_expected_weights()
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
+ def set_pivot_priors(self):
+ def descend(root):
+ for child in root['children']:
+ root['node'].set_pivot_priors()
+ descend(child)
+ descend(self.root)
+
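
As a reminder, `set_weights` converts each node's nu stick (`main`) and its children's psi sticks into mixture weights by standard stick breaking: the node keeps mass*main, and the remaining mass*(1-main) is split among the children. A self-contained sketch of that conversion (plain NumPy; `sticks_to_edges` is reproduced here under the usual cumulative-product definition, which is an assumption about the utility function):

    import numpy as np

    def sticks_to_edges(sticks):
        # stick-breaking: edge_k = 1 - prod_{j<=k} (1 - stick_j)
        return 1.0 - np.cumprod(1.0 - sticks)

    main = 0.4                               # nu stick of this node
    psi_sticks = np.array([0.5, 0.3, 1.0])   # last psi stick is always 1
    edges = sticks_to_edges(psi_sticks)
    child_fracs = np.diff(np.hstack([0.0, edges]))
    mass = 1.0
    node_weight = mass * main
    child_masses = mass * (1.0 - main) * child_fracs
    print(node_weight, child_masses, node_weight + child_masses.sum())  # sums to `mass`
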
def get_node_data_sizes(self, normalized=False, super_only=False):
nodes, _ = self.get_node_mixture()
sizes = []
@@ -663,29 +849,20 @@ def descend(super_root):
return descend(self.root)
- def get_nodes(self, root_node=None, parent_vector=False):
- def descend(root, idx=0, prev_idx=-1):
- idx = idx + 1
- node = [root]
- parent_idx = [prev_idx]
- parent_idx = [prev_idx]
- prev_idx = idx
- children = list(root.children())
- labs = [c.label for c in children]
- children = np.array(children)[np.argsort(labs)]
- for i, child in enumerate(children):
- nodes, idx, parents_idx = descend(child, idx, prev_idx - 1)
- node.extend(nodes)
- parent_idx.extend(parents_idx)
- return node, idx, parent_idx
-
- if root_node is None:
- root_node = self.root["node"].root["node"]
- node_list, _, parent_list = descend(root_node)
- if parent_vector:
- return node_list, parent_list
- else:
- return node_list
+ def get_nodes(self):
+ def descend(root):
+ nodes = [root['node']]
+ for child in root['children']:
+ nodes.extend(descend(child))
+ return nodes
+
+ def super_descend(root):
+ nodes = descend(root['node'].root)
+ for child in root['children']:
+ nodes.extend(super_descend(child))
+ return nodes
+
+ return super_descend(self.root)
def get_width_distribution(self):
def descend(root, depth, width_vec):
@@ -812,136 +989,6 @@ def vbis_estimate(self, importance_samples=None, n_samples=1000):
# ========= Functions to update tree parameters given data. =========
- def get_node_mean(self, log_baseline, unobserved_factors, noise, cnvs):
- node_mean = jnp.exp(
- log_baseline + unobserved_factors + noise + jnp.log(cnvs / 2)
- )
- sum = jnp.sum(node_mean, axis=1).reshape(self.num_data, 1)
- node_mean = node_mean / sum
- return node_mean
-
- def get_tssb_indices(self, nodes, tssbs):
- # start = time.time()
- max_len = self.max_nodes
- tssb_indices = []
- for node in nodes:
- tssb_indices.append(
- np.array([i for i, tssb in enumerate(tssbs) if tssb == node.tssb.label])
- )
- # if len(tssb_indices[-1].shape[0]) > max_len:
- # max_len = len(tssb_indices[-1].shape[0])
-
- for i, c in enumerate(tssb_indices):
- l = c.shape[0]
- if l < max_len:
- c = np.concatenate([c, np.array([-1] * (max_len - l))])
- tssb_indices[i] = c
- tssb_indices = jnp.array(tssb_indices).astype(int)
- # end = time.time()
- # print(f"get_tssb_indices: {end-start}")
- return tssb_indices
-
- def get_below_root(self, root_idx, children_vector, tssbs=None):
- def descend(idx):
- below_root = [idx]
- for child_idx in children_vector[idx]:
- if child_idx > 0:
- if tssbs is not None:
- if tssbs[child_idx] == tssbs[root_idx]:
- aux = descend(child_idx)
- below_root.extend(aux)
- else:
- aux = descend(child_idx)
- below_root.extend(aux)
- return below_root
-
- return np.array(descend(root_idx))
-
- @partial(jax.jit, static_argnums=0)
- def get_children_vector(self, parent_vector):
- def f(i):
- return jnp.where(parent_vector == i, size=self.max_nodes, fill_value=-1)[0]
-
- return jax.vmap(f)(jnp.arange(self.max_nodes))
-
- def get_ancestor_indices(self, nodes, parent_vector, inclusive=False):
- # start = time.time()
- ancestor_indices = []
- max_len = self.max_nodes
- for i in range(len(nodes)):
- # get ancestor nodes in the same subtree
- p = i
- indices = []
- while p != -1 and nodes[p].tssb == nodes[i].tssb:
- if not (not inclusive and p == i):
- indices.append(p)
- p = parent_vector[p]
-
- indices = np.array(indices)
- ancestor_indices.append(indices)
- # if len(indices) > max_len:
- # max_len = len(indices)
-
- for i, c in enumerate(ancestor_indices):
- l = c.shape[0]
- if l < max_len:
- c = np.concatenate([c, np.array([-1] * (max_len - l))])
- ancestor_indices[i] = c
- ancestor_indices = jnp.array(ancestor_indices).astype(int)
- # end = time.time()
- # print(f"get_ancestor_indices: {end-start}")
- return ancestor_indices
-
- def get_previous_branches_indices(self, nodes, within_tssb=True):
- # start = time.time()
- previous_branches_indices = []
- max_len = self.max_nodes
- for node in nodes:
- indices = []
- if within_tssb and not node.is_observed:
- children = np.array(list(node.parent().children()))
- labs = [l.label for l in children]
- children = children[np.argsort(labs)]
- for j, prev_child in enumerate(children):
- if prev_child.is_observed:
- continue
- if prev_child == node:
- break
- # Locate prev_child in nodes list
- for idx, n_ in enumerate(nodes):
- if n_ == prev_child:
- indices.append(idx)
- break
- elif not within_tssb:
- if node.parent() is not None:
- children = np.array(list(node.parent().children()))
- if len(children) > 0:
- labs = [l.label for l in children]
- children = children[np.argsort(labs)]
- for j, prev_child in enumerate(children):
- if prev_child.is_observed:
- continue
- if prev_child == node:
- break
- # Locate prev_child in nodes list
- for idx, n_ in enumerate(nodes):
- if n_ == prev_child:
- indices.append(idx)
- break
- previous_branches_indices.append(np.array(indices))
- # if len(indices) > max_len:
- # max_len = len(indices)
-
- for i, c in enumerate(previous_branches_indices):
- l = c.shape[0]
- if l < max_len:
- c = np.concatenate([c, np.array([-1] * (max_len - l))])
- previous_branches_indices[i] = c
- previous_branches_indices = jnp.array(previous_branches_indices).astype(int)
- # end = time.time()
- # print(f"get_previous_branches_indices: {end-start}")
- return previous_branches_indices
-
def Eq_log_p_nu(self, dp_alpha, nu_sticks_alpha, nu_sticks_beta):
l = 0
aux = digamma(nu_sticks_beta) - digamma(nu_sticks_alpha + nu_sticks_beta)
@@ -999,1292 +1046,664 @@ def Eq_log_q_tau(self, tau_alpha, tau_beta):
)
return l
- def tssb_log_priors(self):
- nodes, parent_vector = self.get_nodes(root_node=None, parent_vector=True)
- tssb_weights = jnp.array([node.tssb.weight for node in nodes])
- init_nu_log_alphas = jnp.array([node.nu_log_alpha for node in nodes])
- init_nu_log_betas = jnp.array([node.nu_log_beta for node in nodes])
- init_psi_log_alphas = jnp.array([node.psi_log_alpha for node in nodes])
- init_psi_log_betas = jnp.array([node.psi_log_beta for node in nodes])
- ancestor_nodes_indices = self.get_ancestor_indices(nodes, parent_vector)
- previous_branches_indices = self.get_previous_branches_indices(nodes)
-
- rem = self.max_nodes - len(nodes)
- init_psi_log_betas = jnp.concatenate(
- [init_psi_log_betas, -1 * jnp.ones((rem,))]
- )
- init_psi_log_alphas = jnp.concatenate(
- [init_psi_log_alphas, -1 * jnp.ones((rem,))]
- )
- init_nu_log_betas = jnp.concatenate([init_nu_log_betas, -1 * jnp.ones((rem,))])
- init_nu_log_alphas = jnp.concatenate(
- [init_nu_log_alphas, -1 * jnp.ones((rem,))]
- )
- tssb_weights = jnp.concatenate([tssb_weights, 10 * jnp.ones((rem,))])
- previous_branches_indices = jnp.concatenate(
- [
- previous_branches_indices,
- -1 * jnp.ones((rem, previous_branches_indices.shape[1])),
- ],
- axis=0,
- ).astype(int)
- ancestor_nodes_indices = jnp.concatenate(
- [
- ancestor_nodes_indices,
- -1 * jnp.ones((rem, ancestor_nodes_indices.shape[1])),
- ],
- axis=0,
- ).astype(int)
-
- nu_sticks = jnp.exp(init_nu_log_alphas) / (
- jnp.exp(init_nu_log_alphas) + jnp.exp(init_nu_log_betas)
- )
- psi_sticks = jnp.exp(init_psi_log_alphas) / (
- jnp.exp(init_psi_log_alphas) + jnp.exp(init_psi_log_betas)
- )
+ def make_batches(self, batch_size=None, seed=42):
+ if batch_size is None:
+ batch_size = self.num_data
+ self.batch_size = batch_size
- logpis = []
- for i, node in enumerate(nodes):
- logpis.append(
- self.tssb_log_prior(
- i,
- nu_sticks,
- psi_sticks,
- previous_branches_indices,
- ancestor_nodes_indices,
- tssb_weights,
- )
- )
- ws = list(jnp.exp(np.array(logpis)))
- return list(np.array(logpis)), ws
+ rng = np.random.RandomState(seed)
+ perm = rng.permutation(self.num_data)
- def tssb_log_prior(
- self,
- i,
- nu_sticks,
- psi_sticks,
- previous_branches_indices,
- ancestor_nodes_indices,
- tssb_weights,
- ):
- # TSSB prior
- nu_stick = nu_sticks[i]
- psi_stick = psi_sticks[i]
-
- def prev_branches_psi(idx):
- return (idx != -1) * jnp.log(1.0 - psi_sticks[idx])
+ num_complete_batches, leftover = divmod(self.num_data, self.batch_size)
+ self.num_batches = num_complete_batches + bool(leftover)
- def ancestors_nu(idx):
- _log_phi = jnp.log(psi_sticks[idx]) + jnp.sum(
- vmap(prev_branches_psi)(previous_branches_indices[idx])
- )
- _log_1_nu = jnp.log(1.0 - nu_sticks[idx])
- total = _log_phi + _log_1_nu
- return (idx != -1) * total
+ self.batch_indices = []
+ for i in range(self.num_batches):
+ batch_idx = perm[i * self.batch_size : (i + 1) * self.batch_size]
+ self.batch_indices.append(batch_idx)
- log_phi = jnp.log(psi_stick) + jnp.sum(
- vmap(prev_branches_psi)(previous_branches_indices[i])
- )
- log_node_weight = (
- jnp.log(nu_stick)
- + log_phi
- + jnp.sum(vmap(ancestors_nu)(ancestor_nodes_indices[i]))
- )
- log_node_weight = log_node_weight + jnp.log(tssb_weights[i])
+ self.reset_sufficient_statistics()
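
`make_batches` shuffles the observation indices once and slices them into fixed-size chunks, with divmod deciding whether a smaller leftover batch is needed. A minimal sketch of the same bookkeeping on toy sizes:

    import numpy as np

    num_data, batch_size, seed = 10, 4, 42
    rng = np.random.RandomState(seed)
    perm = rng.permutation(num_data)

    num_complete, leftover = divmod(num_data, batch_size)
    num_batches = num_complete + bool(leftover)   # 3 batches: sizes 4, 4, 2
    batch_indices = [perm[i * batch_size:(i + 1) * batch_size] for i in range(num_batches)]
    print([len(b) for b in batch_indices])
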
- return log_node_weight
+ def get_top_node_obs(self, q=70):
+ """
+ Return the observations that are particularly well explained by the node they are attached to
+ """
+ def sub_descend(root):
+ # Get cells attached to this node
+ idx = np.where(np.array(self.assignments) == root['node'])[0]
+ top_obs = root['node'].get_top_obs(q=q, idx=idx)
+ for child in root['children']:
+ top_obs = np.concatenate([top_obs,sub_descend(child)])
+ return top_obs
- def optimize_elbo(
- self,
- update_all=False,
- global_only=False,
- sticks_only=False,
- n_iters=20,
- run=True,
- n_inner_steps=50,
- n_local_traverses=1,
- **update_kwargs,
- ):
- # Init
- if not run:
- self.update_elbo(update=False, root=self.root,
- sub_root=None, n_traverses=1, update_global=False, compute_global=True,
- n_inner_steps=0, mb_size=self.data.shape[0])
- self.assign_to_best()
- return [self.elbo]
-
- # Full: alternate between globals and locals for n_iters
- elbos = []
- if update_all:
- logger.debug("Updating global and node parameters")
- update = True
- if sticks_only:
- update = False
- logger.debug("Updating only sticks")
- for i in range(n_iters):
- # Globals
- self.update_elbo(update=False, n_traverses=3, update_global=True,
- compute_global=True, n_inner_steps=0, **update_kwargs)
- # Locals
- self.update_elbo(update=update, n_traverses=n_local_traverses, update_global=False,
- compute_global=False, n_inner_steps=n_inner_steps, **update_kwargs)
- elbos.append(self.elbo)
+ def descend(root):
+ top_obs = sub_descend(root['node'].root)
+ for child in root['children']:
+ top_obs = np.concatenate([top_obs, descend(child)])
+ return top_obs
+
+ top_obs = descend(self.root)
+ top_obs = np.unique(top_obs).astype(int)
+ return top_obs
+
+ def compute_elbo(self, memoized=True, batch_idx=None, **kwargs):
+ if memoized:
+ return self.compute_elbo_suff()
else:
- if update_kwargs["root"] is not None:
- update = True
- update_global = False
- compute_global = False
- logger.debug("Updating from root")
- if global_only:
- n_inner_steps = 0
- update = False
- update_global = True
- compute_global = True
- logger.debug("Updating only global parameters")
- elif sticks_only:
- n_inner_steps = 0
- logger.debug("Updating only stick parameters")
- elbos = self.update_elbo(update=update, n_traverses=20, update_global=update_global,
- compute_global=compute_global, n_inner_steps=n_inner_steps, **update_kwargs)
- # Local only
- elif update_kwargs["sub_root"] is not None:
- if sticks_only:
- n_inner_steps = 0
- nlabel = update_kwargs["sub_root"]["node"].label
- logger.debug(f"Updating parameters below {nlabel}")
- elbos = self.update_elbo(update=True, n_traverses=20, update_global=False,
- compute_global=False, n_inner_steps=n_inner_steps, **update_kwargs)
- self.assign_to_best()
- self.plot_tree(counts=True)
- return elbos
-
- def update_elbo(self, update=True, root=None, sub_root=None, go_down=True, compute_global=False, update_global=False, restricted=False, mb_size=128, n_traverses=1, n_inner_steps=10, mc_samples=3, lr=0.01):
- if root is None and sub_root is None:
- raise ValueError("`root` and `sub_root` can't both be None!")
-
- n_cells, n_genes = self.data.shape
- data_indices = np.arange(n_cells)
- res_data_indices = np.arange(n_cells)
- mask = np.ones((n_cells,))
- probs = np.ones((n_cells,))
- probs = probs / np.sum(probs)
-
- n_factors = self.root["node"].root["node"].num_global_noise_factors
- bs_grads = jnp.zeros((2,n_genes-1))
- noise_grads = jnp.zeros((2,n_factors,n_genes))
- cellnoise_grads = jnp.zeros((2,n_cells,n_factors))
-
- # Get global variational parameters
- log_baseline_mean = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"])
- log_baseline_log_std = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"])
- noise_factors_mean = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"])
- noise_factors_log_std = jnp.array(self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"])
-
- def sub_update_params(root, data_indices, mask, depth=0):
- data = self.data[data_indices]
- lib_sizes = self.root["node"].root["node"].lib_sizes[data_indices]
- def _sub_update_params(root, depth=0):
- if root["node"].parent() is None:
- parent_unobserved_samples = jnp.zeros((mc_samples, n_genes))
- unobserved_samples = jnp.zeros((mc_samples, n_genes))
- unobserved_kernel_samples = jnp.zeros((mc_samples, n_genes))
- else:
- # Sample parent
- if root["node"].parent().parent() is None:
- parent_unobserved_samples = jnp.zeros((mc_samples, n_genes))
- else:
- parent_unobserved_means = root["node"].parent().variational_parameters["locals"]["unobserved_factors_mean"]
- parent_unobserved_log_stds = root["node"].parent().variational_parameters["locals"]["unobserved_factors_log_std"]
- parent_unobserved_samples = sample_unobserved(rngs, parent_unobserved_means, parent_unobserved_log_stds)
-
- unobserved_means = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_mean"])
- unobserved_log_stds = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_log_std"])
- unobserved_factors_kernel_log_mean = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"])
- unobserved_factors_kernel_log_std = jnp.array(root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"])
-
- m1 = jnp.zeros_like(unobserved_means)
- v1 = jnp.zeros_like(unobserved_means)
- state1 = (m1,v1)
-
- m2 = jnp.zeros_like(unobserved_means)
- v2 = jnp.zeros_like(unobserved_means)
- state2 = (m2,v2)
-
- m3 = jnp.zeros_like(unobserved_means)
- v3 = jnp.zeros_like(unobserved_means)
- state3 = (m3,v3)
-
- m4 = jnp.zeros_like(unobserved_means)
- v4 = jnp.zeros_like(unobserved_means)
- state4 = (m4,v4)
- states = (state1, state2, state3, state4)
- local_grad = obs_ll_grad + parent_dep
- for i in range(n_inner_steps):
- loss, states, unobserved_means, unobserved_log_stds, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std = update_local_parameters(rngs,
- unobserved_means,
- unobserved_log_stds,
- unobserved_factors_kernel_log_mean,
- unobserved_factors_kernel_log_std,
- root["node"].data_weights[data_indices] * root["node"].tssb.data_weights[data_indices],
- parent_unobserved_samples,
- baseline_samples,
- cell_noise_samples,
- noise_factor_samples,
- jnp.array([self.root["node"].root["node"].unobserved_factors_kernel_concentration, self.root["node"].root["node"].unobserved_factors_kernel_rate]), # [concentration, rate]
- jnp.array(0.), # to make the prior prefer amplifications
- root["node"].cnvs,
- lib_sizes,
- data,
- mask,
- states,
- i,
- mb_scaling=np.sum(mask)/n_cells,
- lr=lr,
- )
- root["node"].variational_parameters["locals"]["unobserved_factors_mean"] = np.array(unobserved_means)
- root["node"].variational_parameters["locals"]["unobserved_factors_log_std"] = np.array(unobserved_log_stds)
- root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_mean"] = np.array(unobserved_factors_kernel_log_mean)
- root["node"].variational_parameters["locals"]["unobserved_factors_kernel_log_std"] = np.array(unobserved_factors_kernel_log_std)
-
- root["node"].set_mean(variational=True)
-
- unobserved_samples = sample_unobserved(rngs, unobserved_means, unobserved_log_stds)
- unobserved_kernel_samples = sample_unobserved_kernel(rngs, unobserved_factors_kernel_log_mean, unobserved_factors_kernel_log_std)
+ return self.compute_elbo_batch(batch_idx=batch_idx)
+ def compute_elbo_batch(self, batch_idx=None):
+ """
+ Compute the ELBO of the model in a tree traversal, abstracting away the model-specific likelihood and kernel functions.
+ A seed is used for MC sampling from the variational distributions of the terms for which Eq[log p] is generally not
+ analytically available (the likelihood and the kernel distribution).
- # Compute local approximate expected log likelihood term -- unweighted
- updated_indices = data_indices[np.where(mask)[0]]
- ll_res = ll(baseline_samples, cell_noise_samples, noise_factor_samples, unobserved_samples, root["node"].cnvs, lib_sizes, data)[np.where(mask)[0]]
- root["node"].ll[updated_indices] = ll_res
-
- # Compute local approximate KL divergence term
- if root["node"].parent() is None:
- root["node"].param_kl = 0.
- else:
- root["node"].param_kl = local_paramkl(parent_unobserved_samples, unobserved_samples, unobserved_kernel_samples,
- unobserved_means,
- unobserved_log_stds,
- unobserved_factors_kernel_log_mean,
- unobserved_factors_kernel_log_std,
- jnp.array([self.root["node"].root["node"].unobserved_factors_kernel_concentration, self.root["node"].root["node"].unobserved_factors_kernel_rate]), jnp.array(0.))
-
- bs_grads = 0.
- noise_grads = 0.
- cellnoise_grads = 0.
- if update_global:
- # Compute gradient of node ell wrt globals
- data_weights = root["node"].data_weights[data_indices] * root["node"].tssb.data_weights[data_indices]
- bs_grads = baseline_node_grad(rngs, log_baseline_mean, log_baseline_log_std, data_weights, unobserved_samples, cell_noise_samples, noise_factor_samples, root["node"].cnvs, lib_sizes, data, mask)
- noise_grads = noise_node_grad(rngs, noise_factors_mean, noise_factors_log_std, data_weights, unobserved_samples, cell_noise_samples, baseline_samples, root["node"].cnvs, lib_sizes, data, mask)
- cellnoise_grads = cellnoise_node_grad(rngs, cell_noise_mean, cell_noise_log_std, data_weights, unobserved_samples, noise_factor_samples, baseline_samples, root["node"].cnvs, lib_sizes, data, mask)
-
- weight_down = 0
- indices = list(range(len(root["children"])))
- indices = indices[::-1]
-
- for i in indices:
- child = root["children"][i]
- # Go down in the tree and get its weight
- child_weight, child_bs_grads, child_noise_grads, child_cellnoise_grads = _sub_update_params(child, depth + 1)
- if update:
- post_alpha = 1.0 + child_weight
- post_beta = self.dp_gamma + weight_down
- child["node"].variational_parameters["locals"]["psi_log_mean"] = np.log(post_alpha)
- child["node"].variational_parameters["locals"]["psi_log_std"] = np.log(post_beta)
- weight_down += child_weight
- bs_grads += child_bs_grads
- noise_grads += child_noise_grads
- cellnoise_grads += child_cellnoise_grads
-
- # Compute local exact KL divergence term
- child["node"].psi_stick_kl = -beta_kl(np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"]), np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"]), 1, root["node"].tssb.dp_gamma)
-
- weight_here = np.sum(root["node"].data_weights)
- total_weight = weight_here + weight_down
- if update:
- post_alpha = 1.0 + weight_here
- post_beta = (self.alpha_decay**depth) * self.dp_alpha + weight_down
- root["node"].variational_parameters["locals"]["nu_log_mean"] = np.log(post_alpha)
- root["node"].variational_parameters["locals"]["nu_log_std"] = np.log(post_beta)
- root["node"].total_weight = total_weight
- root["node"].weight_down = weight_down
- root["node"].weight_here = weight_here
- root["node"].nu_stick_kl = -beta_kl(np.exp(root["node"].variational_parameters["locals"]["nu_log_mean"]), np.exp(root["node"].variational_parameters["locals"]["nu_log_std"]), 1, (root["node"].tssb.alpha_decay**depth)*root["node"].tssb.dp_alpha)
-
- return total_weight, bs_grads, noise_grads, cellnoise_grads
- return _sub_update_params(root, depth=depth)
-
- def sub_update_weights(root):
- node = [root["node"]]
-
- nu_alpha = np.exp(root["node"].variational_parameters["locals"]["nu_log_mean"])
- nu_beta = np.exp(root["node"].variational_parameters["locals"]["nu_log_std"])
- psi_alpha = np.exp(root["node"].variational_parameters["locals"]["psi_log_mean"])
- psi_beta = np.exp(root["node"].variational_parameters["locals"]["psi_log_std"])
-
- # Compute expected local NTSSB weight term
- E_log_psi = E_q_log_beta(psi_alpha, psi_beta)
- E_log_nu = E_q_log_beta(nu_alpha, nu_beta)
- E_log_1_nu = E_q_log_1_beta(nu_alpha, nu_beta)
-
- if root["node"].is_observed:
- ancestors_E_log_1_nu = 0.
- ancestors_and_this_E_log_phi = 0.
- root["node"].ancestors_and_this_E_log_1_nu = E_log_1_nu
- root["node"].ancestors_and_this_E_log_phi = 0.
- root["node"].psi_not_prev_sum = 0.
+ If batch_idx is not None, return an estimate of the ELBO based on just the subset of data in batch_idx;
+ otherwise, use all of the data. See compute_elbo_suff for the memoized version based on sufficient statistics.
+ """
+ if batch_idx is None:
+ idx = jnp.arange(self.num_data)
+ else:
+ idx = self.batch_indices[batch_idx]
+ def descend(root, depth=0, local_contrib=0, global_contrib=0):
+ # Traverse inner TSSB
+ subtree_ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo(idx)
+ ll_contrib = subtree_ll_contrib * root['node'].variational_parameters['q_c'][idx]
+
+ # Assignments
+ ## E[log p(c|nu,psi)]
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ eq_logp_c = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ eq_logp_c += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ eq_logp_c += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
else:
- ancestors_E_log_1_nu = root["node"].parent().ancestors_and_this_E_log_1_nu
- root["node"].ancestors_and_this_E_log_1_nu = ancestors_E_log_1_nu + E_log_nu
- root["node"].ancestors_and_this_E_log_phi = root["node"].parent().ancestors_and_this_E_log_phi + root["node"].psi_not_prev_sum + E_log_psi
- ancestors_and_this_E_log_phi = root["node"].ancestors_and_this_E_log_phi
-
- # root["node"].weight_until_here = root["node"].data_weights + until_here
- # root["node"].prior_weight = root["node"].data_weights*E_log_nu + until_here*E_log_1_nu + root["node"].weight_until_here*E_log_phi
- root["node"].prior_weight = E_log_nu + ancestors_E_log_1_nu + ancestors_and_this_E_log_phi
- root["node"].ew = root["node"].data_weights*E_log_nu + root["node"].weight_down*E_log_1_nu + root["node"].total_weight*E_log_psi
- # root["node"].data_weights = root["node"].ll + root["node"].prior_weight
- root["node"].unnormalized_data_weights = root["node"].ll + root["node"].prior_weight
-
- data_weight = [root["node"].unnormalized_data_weights]
- psi_not_prev_sum = 0
- for i, child in enumerate(root["children"]):
- nu_alpha = np.exp(child["node"].variational_parameters["locals"]["nu_log_mean"])
- nu_beta = np.exp(child["node"].variational_parameters["locals"]["nu_log_std"])
- psi_alpha = np.exp(child["node"].variational_parameters["locals"]["psi_log_mean"])
- psi_beta = np.exp(child["node"].variational_parameters["locals"]["psi_log_std"])
-
- if i > 0:
- psi_not_prev_sum += E_q_log_1_beta(psi_alpha, psi_beta)
- child["node"].psi_not_prev_sum = psi_not_prev_sum
-
- # Go down in the tree
- nodes, data_weights = sub_update_weights(child)
- node.extend(nodes)
- data_weight.extend(data_weights)
- return node, data_weight
-
- def sub_normalize_update_elbo(nodes):
- data_weights = np.vstack([node.unnormalized_data_weights for node in nodes]).T
- data_weights = np.exp(data_weights - jnn.logsumexp(data_weights,axis=1).reshape(-1,1))
- tree_ell = 0
- tree_ew = 0
- tree_kl = 0
- for i, node in enumerate(nodes):
- node.data_weights = data_weights[:,i]
- node.ell = node.data_weights*node.ll
- tree_ell += node.ell
- tree_ew += node.data_weights*node.prior_weight
- tree_kl += node.nu_stick_kl + node.psi_stick_kl + node.param_kl
- return tree_ell, tree_ew, tree_kl
-
- def tssb_normalize_update_elbo(tssbs):
- data_weights = np.vstack([tssb.unnormalized_data_weights for tssb in tssbs]).T
- data_weights = np.exp(data_weights - jnn.logsumexp(data_weights,axis=1).reshape(-1,1))
- total_elbo = 0
- total_ell = 0
- total_kl = 0
- for i, tssb in enumerate(tssbs):
- tssb.data_weights = data_weights[:,i]
- tssb.elbo = np.sum(tssb.data_weights * tssb.ell + tssb.data_weights * jnp.log(tssb.weight) + tssb.ew) + tssb.kl
- total_elbo += tssb.elbo
- total_ell += np.sum(tssb.data_weights * tssb.ell + tssb.data_weights * jnp.log(tssb.weight) + tssb.ew)
- total_kl += tssb.kl
- return total_elbo, total_ell, total_kl
-
- def descend(super_root, elbo=0):
- _, bs_grads, noise_grads, cellnoise_grads = sub_update_params(super_root["node"].root, data_indices, mask)#sub_update_params(super_root["node"].root)
-
- nodes, data_weights = sub_update_weights(super_root["node"].root)
- tree_ell, tree_ew, tree_kl = sub_normalize_update_elbo(nodes)
- super_root["node"].ell = tree_ell
- super_root["node"].ew = tree_ew
- super_root["node"].kl = tree_kl
-
- # Update TSSB assignment
- super_root["node"].unnormalized_data_weights = tree_ell + np.log(super_root["node"].weight)
-
- for super_child in super_root["children"]:
- child_bs_grads, child_noise_grads, child_cellnoise_grads = descend(super_child)
- bs_grads += child_bs_grads
- noise_grads += child_noise_grads
- cellnoise_grads += child_cellnoise_grads
-
- return bs_grads, noise_grads, cellnoise_grads
-
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+ ## E[log q(c)]
+ eq_logq_c = jax.lax.select(root['node'].variational_parameters['q_c'][idx] != 0,
+ root['node'].variational_parameters['q_c'][idx] * jnp.log(root['node'].variational_parameters['q_c'][idx]),
+ root['node'].variational_parameters['q_c'][idx])
+ ass_contrib = eq_logp_c*root['node'].variational_parameters['q_c'][idx] - eq_logq_c + subtree_ass_contrib * root['node'].variational_parameters['q_c'][idx]
+
+ # Sticks
+ E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu
+ nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu
+ nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth)
+ psi_kl = 0.
+ if depth != 0:
+ E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi
+ psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi
+ psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl -= logbeta_func(1, self.dp_gamma)
+ stick_contrib = nu_kl + psi_kl
+
+ self.n_total_nodes += root['node'].n_nodes
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ # Auxiliary quantities
+ ## Branches
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+ local_contrib, global_contrib = descend(child, depth=depth+1, local_contrib=local_contrib, global_contrib=global_contrib)
+
+ local_contrib += ll_contrib + ass_contrib
+ global_contrib += subtree_node_contrib + stick_contrib
+ return local_contrib, global_contrib
+
+ self.n_total_nodes = 0
+ local_contrib, global_contrib = descend(self.root)
+
+ # Add tree-independent contributions
+ global_contrib += self.num_data/len(idx) * (self.root['node'].root['node'].compute_local_priors(idx) + self.root['node'].root['node'].compute_local_entropies(idx))
+ global_contrib += self.root['node'].root['node'].compute_global_priors() + self.root['node'].root['node'].compute_global_entropies()
+
+ elbo = self.num_data/len(idx) * np.sum(local_contrib) + global_contrib
+ self.elbo = elbo
+ return elbo
+
+ def compute_elbo_suff(self):
+ def descend(root, depth=0, local_contrib=0, global_contrib=0):
+ # Traverse inner TSSB
+ ll_contrib, subtree_ass_contrib, subtree_node_contrib = root['node'].compute_elbo_suff()
+
+ # Assignments
+ ## E[log p(c|nu,psi)]
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ eq_logp_c = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ eq_logp_c += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ eq_logp_c += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ else:
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+ ass_contrib = eq_logp_c*root['node'].suff_stats['mass']['total'] + root['node'].suff_stats['ent']['total'] + subtree_ass_contrib
+
+ # Sticks
+ E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu
+ nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu
+ nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth)
+ psi_kl = 0.
+ if depth != 0:
+ E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi
+ psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi
+ psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl -= logbeta_func(1, self.dp_gamma)
+ stick_contrib = nu_kl + psi_kl
+
+ self.n_total_nodes += root['node'].n_nodes
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ # Auxiliary quantities
+ ## Branches
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+ local_contrib, global_contrib = descend(child, depth=depth+1, local_contrib=local_contrib, global_contrib=global_contrib)
+
+ local_contrib += ll_contrib + ass_contrib
+ global_contrib += subtree_node_contrib + stick_contrib
+ return local_contrib, global_contrib
+
+ self.n_total_nodes = 0
+ local_contrib, global_contrib = descend(self.root)
+
+ # Add tree-independent contributions
+ global_contrib += self.root['node'].root['node'].local_suff_stats['locals_kl']['total']
+ global_contrib += self.root['node'].root['node'].compute_global_priors() + self.root['node'].root['node'].compute_global_entropies()
+
+ elbo = local_contrib + global_contrib
+ self.elbo = elbo
+ return elbo
+
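
In both ELBO traversals above, each stick contributes E_q[log p(nu)] - E_q[log q(nu)], i.e. the negative KL between the Beta variational posterior and its Beta(1, alpha) prior, with the expectations computed from digamma differences as in E_log_beta/E_log_1_beta. A standalone sketch of that nu-stick term with SciPy (helper names are my own, not the scatrex utilities):

    from scipy.special import digamma, betaln

    def E_log_beta(a, b):        # E[log x], x ~ Beta(a, b)
        return digamma(a) - digamma(a + b)

    def E_log_1_beta(a, b):      # E[log(1 - x)]
        return digamma(b) - digamma(a + b)

    def nu_stick_elbo_term(delta_1, delta_2, dp_alpha, alpha_decay, depth):
        prior_beta = dp_alpha * alpha_decay**depth    # prior is Beta(1, prior_beta)
        term = (prior_beta - delta_2) * E_log_1_beta(delta_1, delta_2)
        term -= (delta_1 - 1) * E_log_beta(delta_1, delta_2)
        term += betaln(delta_1, delta_2) - betaln(1.0, prior_beta)
        return term   # equals -KL(q || p); zero when delta_1=1 and delta_2=prior_beta

    print(nu_stick_elbo_term(1.0, 1.0, dp_alpha=1.0, alpha_decay=1.0, depth=0))  # 0.0
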
+ def learn_model(self, n_epochs, seed=42, memoized=True, update_roots=True, update_globals=True, adaptive=True, return_trace=False,
+ locals_names=None, globals_names=None, **kwargs):
+ key = jax.random.PRNGKey(seed)
elbos = []
- ells = []
- kls = []
- if root:
- if update_global:
- m1 = jnp.zeros((n_genes-1,))
- v1 = jnp.zeros((n_genes-1,))
- state1 = (m1,v1)
- m2 = jnp.zeros((n_genes-1,))
- v2 = jnp.zeros((n_genes-1,))
- state2 = (m2,v2)
- bs_states = (state1, state2)
-
- m1 = jnp.zeros((n_factors,n_genes))
- v1 = jnp.zeros((n_factors,n_genes))
- state1 = (m1,v1)
- m2 = jnp.zeros((n_factors,n_genes))
- v2 = jnp.zeros((n_factors,n_genes))
- state2 = (m2,v2)
- noise_states = (state1, state2)
-
- m1 = jnp.zeros((n_cells,n_factors))
- v1 = jnp.zeros((n_cells,n_factors))
- state1 = (m1,v1)
- m2 = jnp.zeros((n_cells,n_factors))
- v2 = jnp.zeros((n_cells,n_factors))
- state2 = (m2,v2)
- cellnoise_states = (state1, state2)
- for i in range(n_traverses):
- rng = random.PRNGKey(i)
- rngs = random.split(rng, mc_samples)
-
- if update or update_global:
- # Setup minibatch optimization
- mask = np.ones((n_cells,))
- mask[res_data_indices] = 1
- data_indices = np.sort(random.choice(rng, n_cells, shape=(mb_size,), p=probs, replace=False))
- mask = mask[data_indices]
-
- # Get cell variational parameters
- cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices]
- cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices]
- cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
-
- # Get data
- data = self.data[data_indices]
- lib_sizes = self.root["node"].root["node"].lib_sizes[data_indices]
-
- # Sample global
- baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
- noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std)
-
- if compute_global:
- # Compute global KL
- self.global_kl = -jnp.sum(baseline_kl(log_baseline_mean, log_baseline_log_std))
- self.global_kl += -jnp.sum(noise_factors_kl(noise_factors_mean, noise_factors_log_std))
- self.cell_kl = -jnp.sum(cell_noise_kl(cell_noise_mean, cell_noise_log_std))
- self.global_kl += self.cell_kl
-
- # Traverse tree
- bs_grads, noise_grads, cellnoise_grads = descend(root)
-
- # Normalize across-tssbs and update global ELBOs
- self.elbo, ell, kl = tssb_normalize_update_elbo(self.get_subtrees())
- self.elbo += self.global_kl
- elbos.append(self.elbo)
-
- if update_global:
- # Update baseline
- bs_klgrad = baseline_kl_grad(log_baseline_mean, log_baseline_log_std)
- bs_klgrad *= len(data_indices)/n_cells # minibatch scaling
- bs_grads += bs_klgrad
- bs_states, log_baseline_mean, log_baseline_log_std = baseline_step(log_baseline_mean, log_baseline_log_std, bs_grads, bs_states, i, lr=lr)
- #
- # # Update noise factors
- # noise_klgrad = noise_kl_grad(noise_factors_mean, noise_factors_log_std)
- # noise_klgrad *= len(data_indices)/n_cells # minibatch scaling
- # noise_grads += noise_klgrad
- # noise_states, noise_factors_mean, noise_factors_log_std = noise_step(noise_factors_mean, noise_factors_log_std, noise_grads, noise_states, i, lr=lr)
- #
- # # Update cell noise
- # cellnoise_klgrad = cellnoise_kl_grad(cell_noise_mean, cell_noise_log_std)
- # cellnoise_grads += cellnoise_klgrad
- # local_state1, local_state2 = cellnoise_states
- # local_state1 = (local_state1[0][data_indices], local_state1[1][data_indices])
- # local_state2 = (local_state2[0][data_indices], local_state2[1][data_indices])
- # local_states = (local_state1, local_state2)
- # local_states, cell_noise_mean, cell_noise_log_std = cellnoise_step(cell_noise_mean, cell_noise_log_std, cellnoise_grads, local_states, i, lr=lr)
- # cellnoise_states[0][0].at[data_indices].set(local_states[0][0])
- # cellnoise_states[0][1].at[data_indices].set(local_states[0][1])
- # cellnoise_states[1][0].at[data_indices].set(local_states[1][0])
- # cellnoise_states[1][1].at[data_indices].set(local_states[1][1])
-
- self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_mean"] = log_baseline_mean
- self.root["node"].root["node"].variational_parameters["globals"]["log_baseline_log_std"] = log_baseline_log_std
-
- # self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_mean"] = noise_factors_mean
- # self.root["node"].root["node"].variational_parameters["globals"]["noise_factors_log_std"] = noise_factors_log_std
- #
- # self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices] = cell_noise_mean
- # self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices] = cell_noise_log_std
-
- if sub_root:
- rng = random.PRNGKey(0)
- rngs = random.split(rng, mc_samples)
-
- # Sample global
- baseline_samples = sample_baseline(rngs, log_baseline_mean, log_baseline_log_std)
- noise_factor_samples = sample_noise_factors(rngs, noise_factors_mean, noise_factors_log_std)
-
- if go_down:
- this_root = sub_root["node"].tssb.get_ntssb_root()
-
- if restricted:
- # Get data in nodes below current one
- res_data_indices = set()
- def data_below(node_obj):
- res_data_indices.update(node_obj.data)
- for child in node_obj.children():
- if child.tssb != sub_root["node"].tssb:
- if go_down:
- data_below(child)
- else:
- data_below(child)
- data_below(sub_root["node"])
- res_data_indices = np.sort(jnp.array(list(res_data_indices)))
- if len(res_data_indices) == 0:
- res_data_indices = np.arange(n_cells)
- mask[res_data_indices] = 1
- probs = np.ones((n_cells,))
- probs[res_data_indices] = 1e6
- probs = probs / np.sum(probs)
-
- for i in range(n_traverses):
- rng = random.PRNGKey(i)
- rngs = random.split(rng, mc_samples)
- mask = np.zeros((n_cells,))
- mask[res_data_indices] = 1
- data_indices = np.sort(random.choice(rng, n_cells, shape=(mb_size,), p=probs, replace=False))
- mask = mask[data_indices]
-
- # Get cell variational parameters
- cell_noise_mean = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_mean"][data_indices]
- cell_noise_log_std = self.root["node"].root["node"].variational_parameters["globals"]["cell_noise_log_std"][data_indices]
- cell_noise_samples = sample_cell_noise_factors(rngs, cell_noise_mean, cell_noise_log_std)
-
- # Get data
- data = self.data[data_indices]
- lib_sizes = self.root["node"].root["node"].lib_sizes[data_indices]
-
- # Update node parameters
- sub_update_params(sub_root, data_indices, mask, depth=len(sub_root["node"].label)-1)
-
- # Update within-tssb data assignment probabilities
- nodes, data_weights = sub_update_weights(sub_root)
-
- # Normalize within-tssb and update local ELBOs
- # This can be done approximately: no need to re-normalize across all the others if their relative contributions
- # didn't change and we end up with weights close to 0. This applies only to restricted ELBO updates on selected data
- tree_ell, tree_ew, tree_kl = sub_normalize_update_elbo(sub_root["node"].tssb.get_mixture()[1])
- sub_root["node"].tssb.ell = tree_ell
- sub_root["node"].tssb.ew = tree_ew
- sub_root["node"].tssb.kl = tree_kl
-
- # Update TSSB assignment
- sub_root["node"].tssb.unnormalized_data_weights = tree_ell + np.log(sub_root["node"].tssb.weight)
-
- # Maybe continue down to any children subtrees?
- if go_down:
- for sub_child in this_root["children"]:
- # Only continue to children subtrees in sub_root's path
- if sub_root["node"].label in sub_child["pivot_node"].label:
- logger.debug(f"Also updating parameters of {sub_child['node'].label} and down")
- descend(sub_child)
-
- # Normalize across-tssbs and update global ELBOs
- self.elbo, ell, kl = tssb_normalize_update_elbo(self.get_subtrees())
- self.elbo += self.global_kl
- elbos.append(self.elbo)
- ells.append(ell)
- kls.append(kl)
-
- return elbos, ells, kls
-
- def _optimize_elbo(
- self,
- root_node=None,
- local_node=None,
- global_only=False,
- sticks_only=False,
- unique_node=None,
- num_samples=10,
- n_iters=100,
- thin=10,
- step_size=0.05,
- debug=False,
- tol=1e-5,
- run=True,
- max_nodes=5,
- init=False,
- opt=None,
- opt_triplet=None,
- mb_size=100,
- callback=None,
- **callback_kwargs,
- ):
- # start = time.time()
- self.max_nodes = (
- len(self.input_tree_dict.keys()) * max_nodes
- ) # upper bound on number of nodes
- self.data = jnp.array(self.data, dtype="float32")
+ states = None
+ local_states = None
+ if adaptive:
+ local_states = self.root['node'].root['node'].initialize_local_opt_states(param_names=locals_names)
+ states = self.root['node'].root['node'].initialize_global_opt_states(param_names=globals_names)
+ it = 0
+ idx = None
+ for i in range(n_epochs):
+ for batch_idx in range(self.num_batches):
+ key, subkey = jax.random.split(key)
+ local_states = self.update_local_params(subkey, batch_idx=batch_idx, adaptive=adaptive, states=local_states, i=it,
+ param_names=locals_names, update_globals=update_globals, **kwargs)
+ if update_globals:
+ states = self.update_global_params(subkey, idx=idx, batch_idx=batch_idx, adaptive=adaptive, states=states, i=it, param_names=globals_names, **kwargs)
+ if memoized:
+ self.update_sufficient_statistics(batch_idx=batch_idx)
+ self.update_node_params(subkey, i=it, adaptive=adaptive, memoized=memoized, **kwargs)
+ if update_roots:
+ self.update_root_node_params(subkey, memoized=memoized, batch_idx=batch_idx, adaptive=adaptive, i=it, **kwargs)
+ self.update_pivot_probs()
+ it += 1
+ if return_trace:
+ elbos.append(self.compute_elbo(memoized=memoized, batch_idx=batch_idx))
+
+ if return_trace:
+ return elbos
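+
+    # Usage sketch (assuming `ntssb` is an instance of this tree class with data and batches
+    # already set up; the variable name is hypothetical):
+    #
+    #   trace = ntssb.learn_model(n_epochs=10, seed=0, memoized=True, return_trace=True)
+    #   print(trace[-1])  # most recent ELBO estimate, one entry per mini-batch update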
- # Var params of nodes below root
- nodes, parent_vector = self.get_nodes(root_node=None, parent_vector=True)
+ def learn_roots(self, n_epochs, seed=42, adaptive=True, return_trace=False, **kwargs):
+ key = jax.random.PRNGKey(seed)
+ elbos = []
+ it = 0
+ for i in range(n_epochs):
+ for batch_idx in range(self.num_batches):
+ key, subkey = jax.random.split(key)
+ self.update_root_node_params(subkey, batch_idx=batch_idx, adaptive=adaptive, i=it, **kwargs)
+ if return_trace:
+ elbos.append(self.compute_elbo(batch_idx=batch_idx, **kwargs))
+ it += 1
+
+ if return_trace:
+ return elbos
- n_nodes = len(nodes)
- rem = self.max_nodes - n_nodes
+ def learn_globals(self, n_epochs, globals_names=None, locals_names=None, ass_anneal=1., ent_anneal=1., update_ass=True, update_locals=True, update_roots=False, subset=None, adaptive=True, seed=42, return_trace=False, **kwargs):
+ key = jax.random.PRNGKey(seed)
+ elbos = []
+ states = None
+ local_states = None
+ if adaptive:
+ local_states = self.root['node'].root['node'].initialize_local_opt_states(param_names=locals_names)
+ states = self.root['node'].root['node'].initialize_global_opt_states(param_names=globals_names)
+ it = 0
+ idx = None
+ for i in range(n_epochs):
+ for batch_idx in range(self.num_batches):
+ key, subkey = jax.random.split(key)
+ if subset is not None:
+ idx = jnp.array(list(set(self.batch_indices[batch_idx]).intersection(set(subset))))
+ local_states = self.update_local_params(subkey, idx=idx, batch_idx=batch_idx, ass_anneal=ass_anneal, ent_anneal=ent_anneal,
+ update_ass=update_ass, update_globals=update_locals, adaptive=adaptive, states=local_states, i=it,
+ param_names=locals_names, **kwargs)
+ states = self.update_global_params(subkey, idx=idx, batch_idx=batch_idx, adaptive=adaptive, states=states, i=it, param_names=globals_names, **kwargs)
+ if update_roots:
+ self.update_root_node_params(subkey, memoized=False, adaptive=adaptive, i=it, **kwargs)
+ it += 1
+ if return_trace:
+ elbos.append(self.compute_elbo_batch(batch_idx=batch_idx))
+
+ if return_trace:
+ return elbos
+
+ def learn_params(self, n_epochs, seed=42, memoized=True, adaptive=True, update_roots=False, return_trace=False, **kwargs):
+ key = jax.random.PRNGKey(seed)
+ elbos = []
+ it = 0
+ for i in range(n_epochs):
+ for batch_idx in range(self.num_batches):
+ key, subkey = jax.random.split(key)
+ self.update_local_params(subkey, batch_idx=batch_idx, update_globals=False, **kwargs)
+ if memoized:
+ self.update_sufficient_statistics(batch_idx=batch_idx)
+ if update_roots:
+ self.update_root_node_params(subkey, memoized=memoized, adaptive=adaptive, i=it, **kwargs)
+ self.update_node_params(subkey, i=it, adaptive=adaptive, memoized=memoized, **kwargs)
+ self.update_pivot_probs()
+ it += 1
+ if return_trace:
+ elbos.append(self.compute_elbo(memoized=memoized, batch_idx=batch_idx))
+
+ if return_trace:
+ return elbos
- # Root node label
- data_indices = np.array(list(range(self.num_data)))
- do_global = False
- if root_node is not None:
- root_label = root_node.label
- data_indices = set()
+ def update_sufficient_statistics(self, batch_idx=None):
+ """
+        Go to each node and update its sufficient statistics: set the statistics for batch
+        `batch_idx` and refresh the running totals.
+ """
+ def descend(root):
+ root['node'].update_sufficient_statistics(batch_idx=batch_idx)
+ for child in root['children']:
+ descend(child)
+
+ descend(self.root)
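+
+    # Sketch of the memoized bookkeeping (field names other than ['total'] are hypothetical):
+    # each statistic keeps one entry per batch plus a running total, so refreshing batch `b`
+    # only swaps out that batch's contribution.
+    #
+    #   def refresh_batch_stat(stats, b, new_value):
+    #       stats['total'] = stats['total'] - stats['batches'][b] + new_value
+    #       stats['batches'][b] = new_value
+    #       return stats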
+ def update_local_params(self, key, ass_anneal=1., ent_anneal=1., idx=None, batch_idx=None, states=None, adaptive=False, i=0, step_size=0.0001, mc_samples=10, update_ass=True, update_outer_ass=True, update_globals=True, **kwargs):
+ """
+        This performs a tree traversal to update the sample-to-node attachment probabilities and other sample-specific parameters.
+ """
+ if idx is None:
+ if batch_idx is None:
+ idx = jnp.arange(self.num_data)
+ else:
+ idx = self.batch_indices[batch_idx]
+
+ take_gradients = False
+ if update_globals:
+ take_gradients = True
+
+ def descend(root, local_grads=None):
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ logprior += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ else:
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+
+ # Traverse inner TSSB
+ ll_node_sum, ent_node_sum, local_grads_down = root['node'].update_local_params(idx, update_ass=update_ass, take_gradients=take_gradients)
+ new_log_probs = ass_anneal*(ll_node_sum + logprior + ent_node_sum)
+ if update_ass and update_outer_ass:
+ root['node'].variational_parameters['q_c'] = root['node'].variational_parameters['q_c'].at[idx].set(new_log_probs)
+ if local_grads_down is not None:
+ if local_grads is None:
+ local_grads = list(local_grads_down)
+ else:
+ for ii, grads in enumerate(list(local_grads_down)):
+ local_grads[ii] += grads
+ logqs = [new_log_probs]
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+ child_log_probs, child_local_grads = descend(child, local_grads=local_grads)
+ logqs.extend(child_log_probs)
+ if child_local_grads is not None:
+ for ii, grads in enumerate(list(child_local_grads)):
+ local_grads[ii] += grads
+
+ return logqs, local_grads # batch-sized
+
+ if update_globals:
+ # Take MC sample and its gradient wrt parameters
+ key, sample_grad = self.root['node'].root['node'].local_sample_and_grad(idx, key, n_samples=mc_samples)
+ locals_curr_sample, locals_params_grad = sample_grad
+ self.root['node'].root['node'].set_local_sample(locals_curr_sample, idx=idx)
+
+ # Traverse tree
+ logqs, locals_sample_grad = descend(self.root)
+
+ if update_globals:
+ locals_prior_grads = self.root['node'].root['node'].compute_locals_prior_grad(locals_curr_sample)
+ for ii, grads in enumerate(locals_prior_grads):
+ locals_sample_grad[ii] += grads
+
+ # Gradient of entropy wrt parameters
+ locals_entropies_grad = self.root['node'].root['node'].compute_locals_entropy_grad(idx)
+
+ # Take gradient step
+ if adaptive:
+ states = self.root['node'].root['node'].update_local_params_adaptive(idx, locals_params_grad, locals_sample_grad, locals_entropies_grad, ent_anneal=ent_anneal, step_size=step_size,
+ states=states, i=i, **kwargs)
+ else:
+ self.root['node'].root['node'].update_local_params(idx, locals_params_grad, locals_sample_grad, locals_entropies_grad, ent_anneal=ent_anneal, step_size=step_size, **kwargs)
+
+ # Resample and store
+ key, sample_grad = self.root['node'].root['node'].local_sample_and_grad(idx, key, n_samples=mc_samples)
+ locals_curr_sample, _ = sample_grad
+ self.root['node'].root['node'].set_local_sample(locals_curr_sample, idx=idx)
+
+ if update_ass and update_outer_ass:
+ # Compute LSE
+ logqs = jnp.array(logqs).T
+ self.variational_parameters['LSE_c'] = jax.scipy.special.logsumexp(logqs, axis=1)
+ # Set probs
def descend(root):
- data_indices.update(root.data)
- for child in root.children():
+ newvals = jnp.exp(root['node'].variational_parameters['q_c'][idx] - self.variational_parameters['LSE_c'])
+ root['node'].variational_parameters['q_c'] = root['node'].variational_parameters['q_c'].at[idx].set(newvals)
+ for child in root['children']:
descend(child)
+ descend(self.root)
+
+ return states
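+
+    # Standalone sketch: the assignment update above is a log-space softmax over nodes,
+    # i.e. q_c[n, k] = exp(logq[n, k] - logsumexp(logq[n, :])).
+    #
+    #   import jax.numpy as jnp
+    #   from jax.scipy.special import logsumexp
+    #   logq = jnp.array([[0.0, -1.0], [2.0, 2.0]])  # (cells, nodes) unnormalized log-probabilities
+    #   q = jnp.exp(logq - logsumexp(logq, axis=1, keepdims=True))  # each row sums to 1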
- descend(root_node)
- data_indices = np.sort(list(data_indices))
- if len(data_indices) == 0 and root_node.parent() is not None:
- data_indices = np.sort(list(root_node.parent().data))
- # data_indices = list(root_node.data)
- else:
- do_global = True
- root_label = self.root["node"].label
- root_node = self.root["node"].root["node"]
-
- if local_node is not None:
- do_global = False
- root_node = local_node.parent()
- root_label = root_node.label
- init_ass_logits = np.array([node.data_ass_logits for node in nodes]).T
-
- @jax.jit
- def f(a):
- return jnn.softmax(a, axis=1)
-
- init_ass_probs = f(init_ass_logits)
- data_indices = list(
- np.where(
- init_ass_probs[:, np.where(root_node == np.array(nodes))[0][0]]
- > 1.0 / np.sqrt(len(nodes))
- )[0]
- )
- if len(data_indices) == 0:
- data_indices = np.arange(self.num_data)
-
- data_mask = np.zeros((self.num_data,))
- data_mask[data_indices] = 1.0
- data_mask = data_mask.astype(int)
-
- parent_vector = np.array(parent_vector)
- parent_vector = jnp.array(
- np.concatenate([parent_vector, -2 * np.ones((rem,))])
- ).astype(int)
-
- tssbs = [node.tssb.label for node in nodes]
- tssb_indices = self.get_tssb_indices(nodes, tssbs)
- # start3 = time.time()
- children_vector = self.get_children_vector(parent_vector)
- # end3 = time.time()
- # print(f"get_children_vector: {end3-start3}")
- ancestor_nodes_indices = self.get_ancestor_indices(nodes, parent_vector)
- previous_branches_indices = self.get_previous_branches_indices(nodes)
- node_idx = np.where(np.array(nodes) == root_node)[0][0]
- node_mask_idx = node_idx
- if local_node is None:
- node_mask_idx = self.get_below_root(node_idx, children_vector, tssbs=None)
+ def update_global_params(self, key, idx=None, batch_idx=None, mc_samples=10, adaptive=False, step_size=0.0001, states=None, i=0, **kwargs):
+ """
+ This performs a tree traversal to update the global parameters.
+ The global parameters are updated using stochastic mini-batch VI.
+ """
+ if idx is None:
+ if batch_idx is None:
+ idx = jnp.arange(self.num_data)
+ else:
+ idx = self.batch_indices[batch_idx]
+ def descend(root, globals_grads=None):
+ globals_grads_down = root['node'].get_global_grads(idx)
+ if globals_grads_down is not None:
+ if globals_grads is None:
+ globals_grads = list(globals_grads_down)
+ else:
+ for ii, grads in enumerate(list(globals_grads_down)):
+ globals_grads[ii] += grads
+ for child in root['children']:
+ child_globals_grads = descend(child, globals_grads=globals_grads)
+ if child_globals_grads is not None:
+ for ii, grads in enumerate(list(child_globals_grads)):
+ globals_grads[ii] += grads
+ return globals_grads
+
+ # Take MC sample and its gradient wrt parameters
+ key, sample_grad = self.root['node'].root['node'].global_sample_and_grad(key, n_samples=mc_samples)
+ globals_curr_sample, globals_params_grad = sample_grad
+ self.root['node'].root['node'].set_global_sample(globals_curr_sample)
+
+        # Gradient of the data log-likelihood, weighted by each node's assignment probabilities, wrt the current sample of the global params
+ globals_sample_grad = descend(self.root)
+
+ # Scale gradient by batch size
+ for ii in range(len(globals_sample_grad)):
+ globals_sample_grad[ii] *= self.num_data/len(idx)
+
+ # Gradient of prior wrt sample
+ globals_prior_grads = self.root['node'].root['node'].compute_globals_prior_grad(globals_curr_sample)
+ len_globals_in_ll = len(globals_sample_grad)
+ # Add the priors
+ for ii in range(len_globals_in_ll):
+ globals_sample_grad[ii] += globals_prior_grads[ii]
+
+ # Extend to hierarchical parameters
+ for grads in globals_prior_grads[len_globals_in_ll:]:
+ globals_sample_grad.append(grads)
+
+ # Gradient of entropy wrt parameters
+ globals_entropies_grad = self.root['node'].root['node'].compute_globals_entropy_grad()
+
+ # Take gradient step
+ if adaptive:
+ states = self.root['node'].root['node'].update_global_params_adaptive(globals_params_grad, globals_sample_grad,
+ globals_entropies_grad, step_size=step_size, states=states, i=i, **kwargs)
else:
- local_node_idx = np.where(np.array(nodes) == local_node)[0][0]
- node_mask_idx = np.array([node_idx, local_node_idx])
-
- if unique_node is not None:
- node_idx = np.where(np.array(nodes) == unique_node)[0][0]
- node_mask_idx = node_idx
- do_global = False
- data_indices = list(unique_node.data)
- data_mask = np.zeros((self.num_data,))
- data_mask[data_indices] = 1.0
- data_mask = data_mask.astype(int)
- sticks_only = True
-
- # start2 = time.time()
- node_mask = np.zeros((len(nodes),))
- node_mask[node_mask_idx] = 1
-
- dp_alphas = np.array([node.tssb.dp_alpha for node in nodes])
- dp_gammas = np.array([node.tssb.dp_gamma for node in nodes])
- tssb_weights = np.array([node.tssb.weight for node in nodes])
-
- global_names = list(nodes[0].variational_parameters["globals"].keys())
- global_params = [
- jnp.array(nodes[0].variational_parameters["globals"][param])
- for param in global_names
- ]
+ states = self.root['node'].root['node'].update_global_params(globals_params_grad, globals_sample_grad,
+ globals_entropies_grad, step_size=step_size, **kwargs)
- local_names = list(nodes[0].variational_parameters["locals"].keys())
- local_params = []
- for node in nodes:
- local_params.append(
- [node.variational_parameters["locals"][param] for param in local_names]
- )
+ # Resample and store
+ key, sample_grad = self.root['node'].root['node'].global_sample_and_grad(key, n_samples=mc_samples)
+ globals_curr_sample, _ = sample_grad
+ self.root['node'].root['node'].set_global_sample(globals_curr_sample)
- obs_params = np.array([node.observed_parameters for node in nodes])
+ if adaptive:
+ return states
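+
+    # Note with a minimal sketch: the factor self.num_data / len(idx) above is the standard
+    # stochastic VI rescaling that makes a mini-batch gradient an unbiased estimate of the
+    # full-data gradient.
+    #
+    #   def scale_minibatch_grad(batch_grad_sum, n_total, batch_size):
+    #       return (n_total / batch_size) * batch_grad_sum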
- tssb_weights = jnp.array(np.concatenate([tssb_weights, 10 * np.ones((rem,))]))
- dp_gammas = jnp.array(np.concatenate([dp_gammas, 1 * np.ones((rem,))]))
- dp_alphas = jnp.array(np.concatenate([dp_alphas, 1 * np.ones((rem,))]))
- previous_branches_indices = jnp.array(
- np.concatenate(
- [
- previous_branches_indices,
- -1 * np.ones((rem, previous_branches_indices.shape[1])),
- ],
- axis=0,
- )
- ).astype(int)
- ancestor_nodes_indices = jnp.array(
- np.concatenate(
- [
- ancestor_nodes_indices,
- -1 * np.ones((rem, ancestor_nodes_indices.shape[1])),
- ],
- axis=0,
- )
- ).astype(int)
- # children_vector = jnp.concatenate([children_vector, -1*jnp.ones((rem, children_vector.shape[1]))], axis=0).astype(int)
- tssb_indices = jnp.array(
- np.concatenate(
- [tssb_indices, -1 * np.ones((rem, tssb_indices.shape[1]))], axis=0
- )
- ).astype(int)
- obs_params = jnp.array(
- np.concatenate(
- [obs_params, np.zeros((rem, nodes[0].observed_parameters.size))], axis=0
- )
- )
- node_mask = jnp.array(np.concatenate([node_mask, -2 * np.ones((rem,))])).astype(
- int
- )
- all_nodes_mask = np.ones(len(node_mask)) * -2
- all_nodes_mask[np.where(node_mask >= 0)[0]] = 1
- all_nodes_mask = jnp.array(all_nodes_mask)
- local_params_list = []
- for param_idx in range(len(local_params[0])):
- l = []
- for node_idx in range(len(nodes)):
- l.append(local_params[node_idx][param_idx])
- l = np.vstack(l)
-
- # Add dummy nodes
- param_shape = l[0].shape[0]
- l = jnp.array(np.concatenate([l, np.zeros((rem, param_shape))], axis=0))
-
- local_params_list.append(l)
- # print([node.label for node in nodes])
- # print(parent_vector)
- # print(children_vector)
- # print([cnv[0] for cnv in cnvs])
-
- if not do_global:
- logger.debug("Won't take derivatives wrt global parameters")
- elif global_only:
- logger.debug("Won't take derivatives wrt local parameters")
-
- do_global = do_global * jnp.array(1.0)
- global_only = global_only * jnp.array(1.0)
- sticks_only = sticks_only * jnp.array(1.0)
-
- init_params = local_params_list + global_params
-
- # end2 = time.time()
- # print(f"Getting parameters: {end2-start2}")
-
- if opt_triplet is None:
- if opt is None:
- opt = optimizers.adam
- opt_init, opt_update, get_params = opt(step_size=step_size)
- get_params = jit(get_params)
- opt_update = jit(opt_update)
- opt_init = jit(opt_init)
- else:
- opt_init, opt_update, get_params = (
- opt_triplet[0],
- opt_triplet[1],
- opt_triplet[2],
- )
- opt_state = opt_init(init_params)
-
- # print(f"Time to prepare optimizer: {end-start} s")
- # n_nodes = jnp.array(n_nodes)
- self.n_nodes = n_nodes
- # print(n_nodes)
- if callback is None:
- callback = elbos_callback
- # print("Iteration {} lower bound {}".format(t, self.batch_objective(cnvs, parent_vector, children_vector, ancestor_nodes_indices, tssb_indices, previous_branches_indices, tssb_weights, dp_alphas, dp_gammas, params, t)))
-
- data = jnp.array(self.data)
- lib_sizes = jnp.array(self.root["node"].root["node"].lib_sizes)
- cell_covariates = jnp.array(self.covariates)
-
- unobserved_factors_kernel_rate = (
- self.root["node"].root["node"].unobserved_factors_kernel_rate
- )
- unobserved_factors_kernel_concentration = (
- self.root["node"].root["node"].unobserved_factors_kernel_concentration
- )
- unobserved_factors_root_kernel = (
- self.root["node"].root["node"].unobserved_factors_root_kernel
- )
- global_noise_factors_precisions_shape = (
- self.root["node"].root["node"].global_noise_factors_precisions_shape
- )
+ def update_node_params(self, key, memoized=True, **kwargs):
+ """
+ This performs a tree traversal to update the stick parameters and the kernel parameters.
+ The node parameters are updated using stochastic memoized VI.
+ """
+ def descend(root, depth=0):
+ mass_down = 0
+ for i, child in enumerate(root["children"][::-1]):
+ j = len(root['children'])-1-i
+ child_mass = descend(child, depth + 1)
+ child['node'].variational_parameters['sigma_1'] = root['psi_priors'][j]["alpha_psi"] + child_mass
+ child['node'].variational_parameters['sigma_2'] = root['psi_priors'][j]["beta_psi"] + mass_down
+ mass_down += child_mass
+
+ # Traverse inner TSSB
+ root['node'].update_node_params(key, memoized=memoized, **kwargs)
+
+ if memoized:
+ mass_here = root['node'].suff_stats['mass']['total']
+ else:
+ mass_here = jnp.sum(root['node'].variational_parameters['q_c'])
+ root['node'].variational_parameters['delta_1'] = root['alpha_nu'] + mass_here
+ root['node'].variational_parameters['delta_2'] = root['beta_nu'] + mass_down
- # print(all_nodes_mask)
- full_data_indices = jnp.array(np.arange(self.num_data))
- data_mask_subset = jnp.array(data_mask)
- sub_data_indices = np.where(data_mask)[0]
- current_elbo = self.elbo
-
- # Get max width in node_mask
- # Should not count below root?
- max_width = 1
- node_mask_idx_below_root = node_mask_idx[
- np.where(ancestor_nodes_indices[node_mask_idx, 0] != 0)[0]
- ]
- local_node_mask = jnp.array(node_mask)
- if len(node_mask_idx_below_root) > 0:
- # original previous_branches_indices only contains within tssb!
- previous_branches_indices2 = self.get_previous_branches_indices(
- nodes, within_tssb=False
- )
- masked_prevs = np.array(
- previous_branches_indices2[node_mask_idx_below_root]
- ).reshape(len(node_mask_idx_below_root), -1)
- max_width = np.max(np.sum(masked_prevs >= 0, axis=1)) + 1
-
- # Get max depth within TSSB
- masked_ancs = np.array(ancestor_nodes_indices)[node_mask_idx]
- max_depth = np.max(np.sum(masked_ancs >= 0, axis=1)) + 1
-
- # end = time.time()
- # print(f"before run: {end-start}")
- if run:
- # Main loop.
- current_elbo = self.elbo
- # if self.elbo == -np.inf:
- # current_elbo = -self.batch_objective(obs_params, parent_vector, children_vector, ancestor_nodes_indices, tssb_indices, previous_branches_indices, tssb_weights, dp_alphas, dp_gammas, all_nodes_mask, do_global, global_only, sticks_only, num_samples, init_params, 0)
- # print(f"Current ELBO: {current_elbo:.5f}")
- # print(f"Optimizing variational parameters from node {root_label}...")
- elbos = []
- means = []
- minibatch_probs = np.ones((self.num_data,))
- minibatch_probs[sub_data_indices] = 1e6
- minibatch_probs = minibatch_probs / np.sum(minibatch_probs)
- for t in range(n_iters):
- minibatch_idx = np.random.choice(
- self.num_data, p=minibatch_probs, size=mb_size, replace=False
- )
- minibatch_idx = jnp.array(np.sort(minibatch_idx)).ravel()
- data_mask_subset = jnp.array(data_mask)[minibatch_idx]
- # minibatch_idx = np.arange(self.num_data)
- # data_mask_subset = data_mask
- # start = time.time()
- # Update params
- # Iterate through nodes
- # This is very slow: should reduce number of iterations?! Or
- # randomly choose a node at each iteration?
- # If I reduce the number of iterations I don't account for different complexities
- # Only take gradient of one node at a time
- # If I do it randomly I'm probably not using momentum, right?
- # Only do it for big widths
- do_sticks = jnp.array(1.0)
- if max_width > 2 or (max_depth > 2 and np.sum(node_mask) > 2):
- node_idx = np.random.choice(node_mask_idx)
- local_node_mask = np.array(node_mask)
- off_idx = np.where(node_mask >= 0)[0]
- local_node_mask[off_idx] = 0
- local_node_mask[node_idx] = 1
- local_node_mask = jnp.array(local_node_mask)
- opt_state, g, params, elbo = cna.opt_funcs.update(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- local_node_mask,
- data_mask_subset,
- minibatch_idx,
- do_global,
- global_only,
- do_sticks,
- sticks_only,
- num_samples,
- t,
- opt_state,
- opt_update,
- get_params,
- )
- # end = time.time()
- # print(f"update: {end-start}")
- elbos.append(-elbo)
- try:
- callback(elbos, **callback_kwargs)
- except StopIteration as e:
- logger.debug(f"Stopped optimization at iteration {t}/{n_iters}")
- break
+ return mass_here + mass_down
+
+ descend(self.root)
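+
+    # Worked example with hypothetical numbers: with a prior Beta(alpha_nu=1, beta_nu=3), an
+    # expected mass of 10 assigned to this node and a mass of 4 in its descendants, the update
+    # above gives delta_1 = 1 + 10 = 11 and delta_2 = 3 + 4 = 7. The psi sticks are updated
+    # analogously, using a child's mass against the mass of the siblings that come after it.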
- # Without node mask
- # start = time.time()
- ret = cna.opt_funcs.batch_objective(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- all_nodes_mask,
- jnp.array(1.0),
- jnp.array(0.0),
- jnp.array(1.0),
- jnp.array(0.0),
- num_samples,
- data.shape[0],
- get_params(opt_state),
- 0,
- )
- # end = time.time()
- # print(f"batch_objective: {end-start}")
- self.elbo = np.array(ret[0]).item()
- self.ll = np.array(ret[1]).item()
- self.kl = np.array(ret[2]).item()
- self.node_kl = np.array(ret[3])
-
- # Weigh by tree prior
- subtrees = self.get_mixture()[1][1:] # without the root
- for subtree in subtrees:
- pivot_node = subtree.root["node"].parent()
- parent_subtree = pivot_node.tssb
- prior_weights, subnodes = parent_subtree.get_fixed_weights()
- # Weight ELBO by chosen pivot's prior probability
- node_idx = np.where(pivot_node == np.array(subnodes))[0][0]
- self.elbo = self.elbo + np.log(prior_weights[node_idx])
-
- # Combinatorial penalization to avoid duplicates -- also avoids real clusters!
- # self.elbo = self.elbo + np.log(1/(2**len(data_indices)))
-
- new_elbo = self.elbo
- # print(f"Done. Speed: {avg_speed} s/it, Total: {total} s")
- # print(f"New ELBO: {new_elbo:.5f}")
- # print(f"New ELBO improvement: {(new_elbo - current_elbo)/np.abs(current_elbo) * 100:.3f}%\n")
-
- # start = time.time()
- self.set_node_means(
- get_params(opt_state),
- nodes,
- local_names,
- global_names,
- node_mask=node_mask,
- do_global=do_global,
- )
- self.update_ass_logits(
- nodes=nodes, variational=True
- )
- self.assign_to_best(nodes=nodes)
- # end = time.time()
- # print(f"last part: {end-start}")
- return elbos
- else:
- ret = cna.opt_funcs.batch_objective(
- data,
- cell_covariates,
- lib_sizes,
- unobserved_factors_kernel_rate,
- unobserved_factors_kernel_concentration,
- unobserved_factors_root_kernel,
- global_noise_factors_precisions_shape,
- obs_params,
- parent_vector,
- children_vector,
- ancestor_nodes_indices,
- tssb_indices,
- previous_branches_indices,
- tssb_weights,
- dp_alphas,
- dp_gammas,
- all_nodes_mask,
- jnp.array(1.0),
- jnp.array(0.0),
- jnp.array(1.0),
- jnp.array(0.0),
- num_samples,
- data.shape[0],
- get_params(opt_state),
- 0,
- )
- self.elbo = np.array(ret[0]).item()
- self.ll = np.array(ret[1]).item()
- self.kl = np.array(ret[2]).item()
- self.node_kl = np.array(ret[3])
-
- # Weigh by tree prior
- subtrees = self.get_mixture()[1][1:] # without the root
- for subtree in subtrees:
- pivot_node = subtree.root["node"].parent()
- parent_subtree = pivot_node.tssb
- prior_weights, subnodes = parent_subtree.get_fixed_weights()
- # Weight ELBO by chosen pivot's prior probability
- node_idx = np.where(pivot_node == np.array(subnodes))[0][0]
- self.elbo = self.elbo + np.log(prior_weights[node_idx])
- self.update_ass_logits(variational=True)
- self.assign_to_best(nodes=nodes)
- return None
-
- def set_node_means(
- self, params, nodes, local_names, global_names, node_mask=None, do_global=True
- ):
- # start = time.time()
- globals_start = len(local_names)
- params_idx = 0
- for i, global_param in enumerate(global_names):
- params_idx = globals_start + i
- if (
- do_global or "cell" in global_param
- ): # always update cell-specific parameters
- self.root["node"].root["node"].variational_parameters["globals"][
- global_param
- ] = np.array(params[params_idx])
-
- if node_mask is None:
- node_indices = np.arange(len(nodes))
- else:
- node_indices = np.where(node_mask == 1)[0]
- for node_idx in node_indices:
- for i, local_param in enumerate(local_names):
- nodes[node_idx].variational_parameters["locals"][
- local_param
- ] = np.array(params[i][node_idx])
- nodes[node_idx].set_mean(variational=True)
-
- def update_ass_logits(
- self, nodes=None, indices=None, variational=False, prior=True
- ):
- # start = time.time()
- if indices is None:
- indices = list(range(self.num_data))
-
- ns, weights = self.get_node_mixture()
- if nodes is None:
- nodes = ns
-
- for i, node in enumerate(ns):
- if node in nodes:
- node_lls = node.loglh(
- np.array(indices), variational=variational, axis=1
- )
- node_lls = node_lls + np.log(weights[i] + 1e-300) if prior else node_lls
- node.data_ass_logits[np.array(indices)] = node_lls
- # print(f"update_ass_logits: {time.time()-start}")
-
-
- def assign_to_best(self):
- total_weights = []
- node_list = []
- tssbs = self.get_subtrees()
- tssb_list = []
- for tssb in tssbs:
- tssb._data.clear()
- _, nodes = tssb.get_fixed_weights()
- for node in nodes:
- node_list.append(node)
- tssb_list.append(tssb)
- total_weights.append(tssb.data_weights * node.data_weights)
- total_weights = np.vstack(total_weights).T
- self.total_weights = total_weights
- self.node_list = node_list
- self.tssb_list = tssb_list
- assignments = np.argmax(total_weights,axis=1)
- for i, node in enumerate(node_list):
- node.remove_data()
- data = np.where(assignments == i)[0]
- node.add_data(np.where(assignments == i)[0])
- node.tssb._data.update(np.where(assignments == i)[0])
- self.assignments = list(np.array(node_list)[assignments])
+ def update_root_node_params(self, key, memoized=True, adaptive=True, i=0, **kwargs):
+ """
+        This performs a tree traversal to update the kernel parameters of root nodes.
+ The node parameters are updated using stochastic memoized VI.
- def _assign_to_best(self, nodes=None):
- # start = time.time()
- if nodes is None:
- nodes = self.get_nodes()
+        Go inside TSSB1, compute the gradient of TSSB2's root wrt TSSB1's nodes, and return their sum.
+        Go inside TSSB2, compute the gradient of TSSB2's root wrt TSSB3's root, add it to the previous gradient, and update.
+        Repeat down the tree.
+ """
+ def descend(root, depth=0):
+ ll_grads = []
+ locals_grads = []
+ children_grads = []
+ for child in root["children"]:
+ child_ll_grad, child_locals_grads, child_children_grads = descend(child, depth + 1)
+ ll_grads.append(child_ll_grad)
+ locals_grads.append(child_locals_grads)
+ children_grads.append(child_children_grads)
+
+ if len(root["children"]) > 0:
+ # Compute gradient of children roots wrt possible parents here
+ parent_grads = root['node'].compute_children_root_node_grads(**kwargs)
+
+ # Update parameters of each child root
+ for ii, child in enumerate(root["children"]):
+ ll_grad = ll_grads[ii]
+ local_grad = locals_grads[ii]
+ children_grad = children_grads[ii]
+ parent_grad = parent_grads[ii]
+ child['node'].update_root_node_params(key, ll_grad, local_grad, children_grad, parent_grad, adaptive=adaptive, i=i, **kwargs)
+
+ if depth > 0: # root of root TSSB has no parameters
+ # Sample root node and compute gradients wrt children (including child roots)
+ # direction_grads = [direction_params_grad, direction_params_entropy_grad, direction_sample_grad]
+ # state_grads = [state_params_grad, state_params_entropy_grad, state_sample_grad]
+ ll_grad, locals_grads, children_grads = root['node'].sample_grad_root_node(key, memoized=memoized, **kwargs)
+
+ return ll_grad, locals_grads, children_grads
+
+ descend(self.root)
+
+ def update_pivot_probs(self):
+ def descend(root):
+ if len(root["children"]) > 0:
+ root["node"].update_pivot_probs()
+ for child in root["children"]:
+ descend(child)
+ descend(self.root)
- assignment_logits = jnp.array([node.data_ass_logits for node in nodes]).T
+ def assign_samples(self):
+ def descend(root):
+ nodes_probs = [root['node'].variational_parameters['q_z'].ravel() * root['node'].tssb.variational_parameters['q_c'].ravel()]
+ nodes = [root['node']]
+ for child in root["children"]:
+ cnodes_probs, cnodes = descend(child)
+ nodes_probs.extend(cnodes_probs)
+ nodes.extend(cnodes)
+ return nodes_probs, nodes
+ def super_descend(super_root):
+ nodes_probs_lists = []
+ nodes_lists = []
+ nodes_probs, nodes = descend(super_root["node"].root)
+ nodes_probs_lists = [nodes_probs]
+ nodes_lists = [nodes]
+ for super_child in super_root["children"]:
+ children_nodes_probs, children_nodes = super_descend(super_child)
+ nodes_probs_lists.extend(children_nodes_probs)
+ nodes_lists.extend(children_nodes)
+ return nodes_probs_lists, nodes_lists
+
+ nodes_probs, nodes = super_descend(self.root)
+ nodes = [x for xs in nodes for x in xs]
+ nodes_probs = [x for xs in nodes_probs for x in xs]
+ nodes_probs = np.array(nodes_probs).T
+ self.assignments = np.array(nodes)[np.argmax(nodes_probs, axis=1)]
+ for node in nodes:
+ node.remove_data()
+ node.add_data(np.where(self.assignments == node)[0])
+ node.tssb._data.update(np.where(self.assignments == node)[0])
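+
+    # Sketch of the hard assignment above: per cell, the node probability factorizes as
+    # q_c(cell -> TSSB) * q_z(cell -> node | TSSB), and the argmax is taken over all nodes of all TSSBs.
+    #
+    #   probs = np.vstack(nodes_probs).T   # (cells, nodes)
+    #   best = np.argmax(probs, axis=1)    # index of the chosen node per cell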
- @jit
- def get_assignments(assignment_logits):
- assignment_probs = jnp.array(jnn.softmax(assignment_logits, axis=1))
- return jax.vmap(jnp.argmax)(assignment_probs)
+ def assign_pivots(self):
+ def descend(root):
+ pivot_probs_nodes = [root['node'].get_pivot_probabilities(i) for i in range(len(root['children']))]
+ for i, child in enumerate(root['children']):
+ pivot_probs = [l for l in pivot_probs_nodes[i][0]]
+ pivot_nodes = [l for l in pivot_probs_nodes[i][1]]
+ child['pivot_node'] = pivot_nodes[np.argmax(pivot_probs)]
+ child['node'].root['node'].set_parent(child['pivot_node'])
+ descend(child)
+ descend(self.root)
- assignments = np.array(get_assignments(assignment_logits))
+ # ========= Functions to update tree structure. =========
- # Clear all
- for i, node in enumerate(nodes):
- node.remove_data()
- node.add_data(np.where(assignments == i)[0])
+ def prune_subtrees(self):
+ # Remove all nodes except observed ones
+ def descend(root):
+ def sub_descend(sub_root):
+ for sub_child in sub_root['children']:
+ sub_descend(sub_child)
+ root['node'].merge_nodes(sub_root, sub_child, sub_root)
+ sub_descend(root['node'].root)
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
- self.assignments = list(np.array(nodes)[assignments])
+ def birth(self, source, seed=None):
+ def sub_descend(root, target_root=None):
+ if source == root['node'].label: # node label
+ target_root = root
+ for child in root['children']:
+ if target_root is not None:
+ return target_root
+ else:
+ target_root = sub_descend(child, target_root=target_root)
+ return target_root
- # print(f"assign_to_best: {time.time()-start}")
+ def descend(root):
+ if root['node'].label in source: # TSSB label
+ target_root = sub_descend(root['node'].root)
+ root['node'].add_node(target_root, seed=seed)
+ for child in root['children']:
+ descend(child)
- # ========= Functions to update tree structure. =========
+ descend(self.root)
- def add_node_to(self, node, optimal_init=True, factor_idx=None, return_parent_root=True):
- if isinstance(node, dict):
- root = node
- else:
- nodes = self._get_nodes(get_roots=True)
- nodes_list = np.array([node[0] for node in nodes])
- roots_list = np.array([node[1] for node in nodes])
- if isinstance(node, str):
- self.plot_tree(super_only=False)
- nodes_list = np.array([node[0].label for node in nodes])
- node_idx = np.where(nodes_list == node)[0][0]
- root = roots_list[node_idx]
-
- # Create child
- stick_length = boundbeta(1, self.dp_gamma)
- root["sticks"] = np.vstack([root["sticks"], stick_length])
- root["children"].append(
- {
- "node": root["node"].spawn(False, root["node"].observed_parameters),
- "main": boundbeta(
- 1.0, (self.alpha_decay ** (root["node"].depth + 1)) * self.dp_alpha
- )
- if self.min_depth <= (root["node"].depth + 1)
- else 0.0,
- "sticks": np.empty((0, 1)),
- "children": [],
- }
- )
- root["children"][-1]["node"].reset_variational_parameters()
+ def merge(self, source, target):
+ def sub_descend(root, parent_root=None, source_root=None, target_root=None):
+ if target == root['node'].label: # node label
+ target_root = root
+ for child in root['children']:
+ if source == child['node'].label: # node label
+ source_root = child
+ if target == child['node'].label: # node label
+ target_root = child
+ if source_root is not None and target_root is not None and parent_root is None:
+ parent_root = root if source_root['node'].parent().label == root['node'].label else target_root
+ return parent_root, source_root, target_root
+ else:
+ parent_root, source_root, target_root = sub_descend(child, parent_root=parent_root, source_root=source_root, target_root=target_root)
+ return parent_root, source_root, target_root
- if optimal_init:
- # Remove some mass from the parent
- root["node"].variational_parameters["locals"]["nu_log_mean"] = np.array(0.0)
- root["node"].variational_parameters["locals"]["nu_log_std"] = np.array(0.0)
- root["children"][-1]["node"].data_ass_logits = -np.inf * np.ones(
- (self.num_data)
- )
- baseline = np.append(
- 1, np.exp(self.root["node"].root["node"].log_baseline_caller())
- )
+ def descend(root):
+ if root['node'].label in source: # TSSB label
+ parent_root, source_root, target_root = sub_descend(root['node'].root)
+ root['node'].merge_nodes(parent_root, source_root, target_root)
+ for child in root['children']:
+ descend(child)
- if factor_idx is not None:
- target_genes = np.argsort(
- np.abs(
- self.root["node"]
- .root["node"]
- .variational_parameters["globals"]["noise_factors_mean"][
- factor_idx
- ]
- )
- )[-5:]
- root["children"][-1]["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][target_genes] = -1.0
- else:
- data_indices = np.where(np.array(self.assignments) == root["node"])[0]
- if len(data_indices) > 0:
- data_in_node = np.array(self.data)[data_indices]
- target_genes = np.argsort(np.var(np.log(data_in_node + 1), axis=0))[
- -10:
- ]
- root["children"][-1]["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][target_genes] = -1.0
+ descend(self.root)
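+
+    # Usage sketch for the structure moves (object and node labels are hypothetical):
+    #
+    #   ntssb.birth('A-0')           # spawn a new child node below node 'A-0'
+    #   ntssb.merge('A-0-1', 'A-0')  # absorb node 'A-0-1' into node 'A-0'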
- if return_parent_root:
- return root
- else:
- return root["children"][-1]
- def perturb_node(self, node, target):
- # Perturb parameters of node to become closer to data explained by target
- if isinstance(node, str) and isinstance(target, str):
- self.plot_tree(super_only=False)
- nodes_list = np.array(self.get_nodes())
- node_labels = np.array([node.label for node in nodes_list])
- node = nodes_list[np.where(node_labels == node)[0][0]]
- target = nodes_list[np.where(node_labels == target)[0][0]]
-
- # data_indices = list(target.data.copy())
- data_indices = np.where(np.array(self.assignments) == target)[0]
-
- if len(data_indices) > 0:
- index = np.random.choice(np.array(data_indices))
- # worst_index = np.argmin(root['node'].data_ass_logits[data_indices])
- # worst_index = np.random.choice(np.array(data_indices)[np.array([np.argsort(target.data_ass_logits[data_indices])[:5]])].ravel())
- logger.debug(f"Setting node to explain datum {index}")
- worst_datum = self.data[index]
- baseline = np.append(
- 1, np.exp(self.root["node"].root["node"].log_baseline_caller())
- )
- noise = (
- self.root["node"]
- .root["node"]
- .variational_parameters["globals"]["cell_noise_mean"][index]
- .dot(
- self.root["node"]
- .root["node"]
- .variational_parameters["globals"]["noise_factors_mean"]
- )
- )
- total_rna = np.sum(
- baseline
- * node.cnvs
- / 2
- * np.exp(
- node.variational_parameters["locals"]["unobserved_factors_mean"]
- + noise
- )
- )
- node.variational_parameters["locals"]["unobserved_factors_mean"] = np.log(
- (worst_datum + 1)
- * total_rna
- / (
- self.root["node"].root["node"].lib_sizes[index]
- * baseline
- * node.cnvs
- / 2
- * np.exp(noise)
- )
- )
- node.set_mean(
- node.get_mean(
- unobserved_factors=node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ],
- baseline=baseline,
- )
- )
+ def swap_nodes(self, parent_node, child_node):
+ # Put parameters of child in parent and of parent in child
+ parent_params = deepcopy(parent_node.variational_parameters)
+ parent_suffstats = deepcopy(parent_node.suff_stats)
+ child_params = deepcopy(child_node.variational_parameters)
+ child_suffstats = deepcopy(child_node.suff_stats)
+ parent_node.variational_parameters = child_params
+ parent_node.suff_stats = child_suffstats
+ child_node.variational_parameters = parent_params
+ child_node.suff_stats = parent_suffstats
+
+ def swap_root(self, parent, child):
+ def descend(root, depth=0):
+ if depth > 0:
+ if root['node'].label == parent: # TSSB label
+ for ch in root['node'].root['children']:
+                    logger.debug(ch['node'].label)
+ if ch['node'].label == child:
+ self.swap_nodes(root['node'].root['node'], ch['node'])
+ for ch in root['children']:
+ descend(ch, depth+1)
+
+ descend(self.root)
def remove_last_leaf_node(self, parent_label):
nodes = self._get_nodes(get_roots=True)
@@ -2342,40 +1761,6 @@ def prune_reattach(self, node, new_parent):
# node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = np.log(node.unobserved_factors_kernel_concentration_caller())*np.ones((n_genes,))
# node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5
- def pivot_reattach_to(self, subtree, pivot):
- nodes = self._get_nodes(get_roots=False)
- nodes = np.array(nodes)
- if isinstance(subtree, str) or isinstance(pivot, str):
- self.plot_tree(super_only=False)
- node_labels = np.array([node.label for node in nodes])
- subtree = nodes[np.where(node_labels == subtree)[0][0]]
- subtree = subtree.tssb
- pivot = nodes[np.where(node_labels == pivot)[0][0]]
-
- subtree_label = subtree.label
-
- root_node_idx = np.where(nodes == subtree.root["node"])[0][0]
- root_node = nodes[root_node_idx]
- pivot_node_idx = np.where(nodes == pivot)[0][0]
- pivot_node = nodes[pivot_node_idx]
-
- subtrees = self.get_subtrees(get_roots=True)
- subtree_objs = np.array([s[0] for s in subtrees])
- subtree_idx = np.where(subtree_objs == subtree)[0][0]
- subtrees[subtree_idx][1]["pivot_node"] = pivot_node
-
- # prev_unobserved_factors = root_node[0].unobserved_factors_mean
- # root_node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5
- root_node.set_parent(pivot_node, reset=False)
- root_node.set_mean(variational=True)
- # Reset the kernel posterior
- # root_node[0].unobserved_factors_kernel_log_mean = -1.*jnp.ones((root_node[0].n_genes,))
- # root_node[0].unobserved_factors_kernel_log_std = -1.*jnp.ones((root_node[0].n_genes,))
- # root_node[0].unobserved_factors_mean = prev_unobserved_factors
- # root_node[0].set_mean(variational=True)
- # self.update_ass_logits(variational=True)
- # self.assign_to_best()
-
def extract_pivot(self, node):
"""
extract_pivot(B):
@@ -2405,10 +1790,10 @@ def extract_pivot(self, node):
node.variational_parameters["locals"]["unobserved_factors_log_std"]
)
paramsB_k = np.array(
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]
+ node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"]
)
paramsB_k_std = np.array(
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_std"]
+ node.variational_parameters["locals"]["unobserved_factors_kernel_log_rate"]
)
# Set new node's parameters equal to the previous parameters of node
@@ -2419,10 +1804,10 @@ def extract_pivot(self, node):
"unobserved_factors_log_std"
] = np.array(paramsB_std)
new_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
] = np.array(paramsB_k)
new_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
+ "unobserved_factors_kernel_log_rate"
] = np.array(paramsB_k_std)
new_node.set_mean(variational=True)
@@ -2447,7 +1832,7 @@ def extract_pivot(self, node):
np.abs(node.variational_parameters["locals"]["unobserved_factors_mean"])
> 0.5
)[0]
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"][
+ node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"][
affected_genes
] = np.log(node.unobserved_factors_kernel_concentration_caller())
@@ -2456,75 +1841,6 @@ def extract_pivot(self, node):
return new_node_root
- def push_subtree(self, node):
- """
- push_subtree(B):
- A-0 -> B -> B-0 to A-0 -> A-0-0 -> B -> B-0
- """
- if isinstance(node, str):
- self.plot_tree(super_only=False)
- nodes = self.get_nodes(None)
- node_labels = np.array([node.label for node in nodes])
- node = nodes[np.where(node_labels == node)[0][0]]
-
- if node.parent() is None:
- raise ValueError("Can't pull from root tree")
- if not node.is_observed:
- raise ValueError("Can't pull unobserved node")
-
- parent_node = node.parent()
- children = np.array(list(node.children()))
- idx = np.argsort([c.label for c in children])
- children = children[idx]
- if len(children) > 0:
- children = [n for n in children if not n.is_observed]
- child_node = None
- if len(children) > 0:
- child_node = children[0]
-
- # Add node below parent
- new_node_root = self.add_node_to(parent_node, return_parent_root=False)
- new_node = new_node_root["node"]
- self.pivot_reattach_to(node.tssb, new_node)
- paramsB = np.array(
- node.variational_parameters["locals"]["unobserved_factors_mean"]
- )
- paramsB_k = np.array(
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]
- )
- dataB = node.data.copy()
- logitsB = np.array(node.data_ass_logits)
- if child_node:
- node.variational_parameters["locals"]["unobserved_factors_mean"] = np.array(
- child_node.variational_parameters["locals"]["unobserved_factors_mean"]
- )
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(
- child_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- node.data = child_node.data.copy()
- node.data_ass_logits = np.array(child_node.data_ass_logits)
- # Merge B with child that it has become equal to
- self.merge_nodes(child_node, node)
- node.set_mean(variational=True)
-
- # Set new node's parameters equal to the previous parameters of node
- new_node.variational_parameters["locals"]["unobserved_factors_mean"] = np.array(
- paramsB
- )
- new_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(paramsB_k)
- new_node.set_mean(variational=True)
- if child_node:
- new_node.data = dataB.copy()
- new_node.data_ass_logits = np.array(logitsB)
-
- return new_node_root
-
def path_to_node(self, node):
path = []
path.append(node)
@@ -2567,788 +1883,6 @@ def path_between_nodes(self, nodeA, nodeB):
path = path + list(pathB[np.where(pathB == mrca)[0][0] :])
return path
- def swap_nodes(self, nodeA, nodeB, update_pivots=True, use_top_kernel=True):
- self.plot_tree(super_only=False)
- if isinstance(nodeA, str) and isinstance(nodeB, str):
- nodes = self.get_nodes(None)
- node_labels = np.array([node.label for node in nodes])
- nodeA = nodes[np.where(node_labels == nodeA)[0][0]]
- nodeB = nodes[np.where(node_labels == nodeB)[0][0]]
-
- # If we are swapping the root node, need to change the baseline too
- root_node = None
- non_root_node = None
- child_unobserved_factors = None
- initial_log_baseline = None
- if nodeA.parent() is None:
- root_node = nodeA
- non_root_node = nodeB
- elif nodeB.parent() is None:
- root_node = nodeB
- non_root_node = nodeA
-
- def swap_params(nA, nB):
- params_names = list(nA.variational_parameters["locals"].keys())
- paramsA_list = [
- np.array(nA.variational_parameters["locals"][key])
- for key in params_names
- ]
- # paramsA = np.array(nA.variational_parameters['locals']['unobserved_factors_mean'])
- paramsA_k = np.array(
- nA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- # nu_sticksA = np.array(nA.variational_parameters['locals']['nu_log_mean'])
- # psi_sticksA = np.array(nA.variational_parameters['locals']['psi_log_mean'])
- paramsB_k = np.array(
- nB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- params_k = np.array([paramsA_k, paramsB_k])
- # We will initialize the kernels to be the one with the most events, just in case
- top_k_idx = np.argmax(np.array([np.var(paramsA_k), np.var(paramsB_k)]))
-
- # Relax kernel of intermediate nodes
- int_nodes = []
- if nA.label in nB.label:
- n = nB
- while True:
- n = n.parent()
- if n == nA:
- break
- else:
- int_nodes.append(n)
- elif nB.label in nA.label:
- n = nA
- while True:
- n = n.parent()
- if n == nB:
- break
- else:
- int_nodes.append(n)
- if len(int_nodes) > 0:
- if use_top_kernel:
- for node in int_nodes:
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.clip(params_k[top_k_idx], -3, 10)
-
- dataA = nA.data.copy()
- logitsA = np.array(nA.data_ass_logits)
- data_weights_A = np.array(nA.data_weights)
- for param in params_names:
- nA.variational_parameters["locals"][param] = np.array(
- nB.variational_parameters["locals"][param]
- )
- if use_top_kernel:
- nA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(params_k[top_k_idx])
- if nA == nB.parent():
- # If one is a child of the other, the child should not have high kernel where parent does
- parent_events = np.where(
- np.abs(
- nA.variational_parameters["locals"]["unobserved_factors_mean"]
- )
- > 0.1
- )[0]
- nA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(
- nA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- nA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][parent_events] -= 1
- nA.data = nB.data.copy()
- nA.data_ass_logits = np.array(nB.data_ass_logits)
- nA.data_weights = np.array(nB.data_weights)
- nA.set_mean(variational=True)
- for i, param in enumerate(params_names):
- nB.variational_parameters["locals"][param] = np.array(paramsA_list[i])
- if use_top_kernel:
- nB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(params_k[top_k_idx])
- if nB == nA.parent():
- # If one is a child of the other, the child should not have high kernel where parent does
- parent_events = np.where(
- np.abs(
- nB.variational_parameters["locals"]["unobserved_factors_mean"]
- )
- > 0.1
- )[0]
- nB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.array(
- nB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- nB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][parent_events] -= 1
- nB.data = dataA.copy()
- nB.data_ass_logits = np.array(logitsA)
- nB.data_weights = np.array(data_weights_A)
- nB.set_mean(variational=True)
-
- if not root_node:
- if nodeA.tssb == nodeB.tssb:
- swap_params(nodeA, nodeB)
- else:
- # e.g. A-0 with C
- if nodeA.parent() == nodeB or nodeB.parent() == nodeA:
- unobserved_node_idx = np.where(
- [not nodeA.is_observed, not nodeB.is_observed]
- )[0]
- if len(unobserved_node_idx) > 1:
- logger.debug(
- "Warning: both nodes to swap are unobserved but are part of different TSSBs:"
- )
- logger.debug(
- f"{nodeA.label}: {nodeA.tssb.label}, {nodeB.label}: {nodeB.tssb.label}"
- )
- logger.debug("Proceeding without swapping.")
- return
- elif len(unobserved_node_idx) == 1:
- # change A -> A-0 -> B to A-> B -> B-0: and put params of A-0 in B and of B in B-0
- unobserved_node_idx = unobserved_node_idx[0]
- unobserved_node = [nodeA, nodeB][unobserved_node_idx]
- observed_node = [nodeA, nodeB][1 - unobserved_node_idx]
- parent_unobserved = unobserved_node.parent()
- # if unobserved node is parent of more than one subtree, update pivot of the others to parent of unobserved_node
- unobserved_node_children = np.array(
- list(unobserved_node.children())
- )
- idx = np.argsort([n.label for n in unobserved_node_children])
- unobserved_node_children = unobserved_node_children[idx]
- if (
- np.sum(
- np.array(
- [
- child.is_observed
- for child in unobserved_node_children
- ]
- )
- )
- > 1
- ):
- for child in list(unobserved_node_children):
- if child.is_observed:
- self.pivot_reattach_to(
- child.tssb, parent_unobserved
- )
- # if not (np.sum(np.array([child.is_observed for child in unobserved_node_children])) > 1):
- # Update params
- # init_obs_params = observed_node.variational_parameters['locals']['unobserved_factors_mean']
- # observed_node.variational_parameters['locals']['unobserved_factors_mean'] = parent_unobserved.variational_parameters['locals']['unobserved_factors_mean']
- # unobserved_node.variational_parameters['locals']['unobserved_factors_mean'] = init_obs_params
- swap_params(nodeA, nodeB)
- # observed_node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 2#np.log(observed_node.unobserved_factors_kernel_concentration_caller())*np.ones((observed_node.n_genes,))
- # unobserved_node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 2#np.log(unobserved_node.unobserved_factors_kernel_concentration_caller())*np.ones((unobserved_node.n_genes,))
-
- # Put same-subtree children of unobserved node in its parent in order to not move a whole subtree
- nodes = self._get_nodes(get_roots=True)
- unobserved_node_tssb_root = [
- node[1] for node in nodes if node[0] == unobserved_node
- ][0]
- parent_unobserved_node_tssb_root = [
- node[1] for node in nodes if node[0] == parent_unobserved
- ][0]
- for i, unobs_child in enumerate(
- unobserved_node_tssb_root["children"]
- ):
- unobs_child["node"].set_parent(unobserved_node.parent())
- # Add children from unobserved to the parent dict
- parent_unobserved_node_tssb_root["children"].append(
- unobs_child
- )
- parent_unobserved_node_tssb_root["sticks"] = np.vstack(
- [
- parent_unobserved_node_tssb_root["sticks"],
- unobserved_node_tssb_root["sticks"][i],
- ]
- )
- if len(unobserved_node_tssb_root["children"]) > 0:
- # Remove children from unobserved
- unobserved_node_tssb_root["sticks"] = np.array([]).reshape(
- 0, 1
- )
- unobserved_node_tssb_root["children"] = []
-
- # Now move the unobserved node to below the observed one
- observed_node.set_parent(parent_unobserved)
- unobserved_node.set_parent(observed_node)
- unobserved_node.tssb = observed_node.tssb
- unobserved_node.cnvs = observed_node.cnvs
- unobserved_node.observed_parameters = (
- observed_node.observed_parameters
- )
- n_siblings = len(list(observed_node.children()))
- unobserved_node.label = (
- observed_node.label + "-" + str(n_siblings - 1)
- )
-
- nodes = self._get_nodes(get_roots=True)
- unobserved_node_tssb_root = [
- node[1] for node in nodes if node[0] == unobserved_node
- ][0]
- parent_unobserved_node_tssb_root = [
- node[1] for node in nodes if node[0] == parent_unobserved
- ][0]
-
- # Update dicts
- # Remove unobserved_node from its parent dict
- childnodes = np.array(
- [
- n["node"]
- for n in parent_unobserved_node_tssb_root["children"]
- ]
- )
- tokeep = (
- np.where(childnodes != unobserved_node_tssb_root["node"])[0]
- .astype(int)
- .ravel()
- )
- parent_unobserved_node_tssb_root[
- "sticks"
- ] = parent_unobserved_node_tssb_root["sticks"][tokeep]
- parent_unobserved_node_tssb_root["children"] = list(
- np.array(parent_unobserved_node_tssb_root["children"])[
- tokeep
- ]
- )
- # Update observed_node's pivot_node to unobserved_node's parent
- observed_node_ntssb_root = observed_node.tssb.get_ntssb_root()
- observed_node_ntssb_root["pivot_node"] = parent_unobserved
- # Add unobserved_node to observed_node's dict
- observed_node_tssb_root = observed_node_ntssb_root["node"].root
- observed_node_tssb_root["children"].append(
- unobserved_node_tssb_root
- )
- observed_node_tssb_root["sticks"] = np.vstack(
- [observed_node_tssb_root["sticks"], 1.0]
- )
- else: # random swap: change data and parameters
- swap_params(nodeA, nodeB)
- else: # e.g. A with A-0
- # init_baseline = np.mean(self.data / np.sum(self.data, axis=1).reshape(-1,1) * self.data.shape[1], axis=0)
- # init_baseline = init_baseline / init_baseline[0]
- # init_log_baseline = np.log(init_baseline[1:] + 1e-6)
- init_bs = np.array(
- root_node.variational_parameters["globals"]["log_baseline_mean"]
- - np.mean(
- root_node.variational_parameters["globals"]["log_baseline_mean"]
- )
- )
- root_node.variational_parameters["globals"]["log_baseline_mean"] = np.log(
- non_root_node.node_mean / non_root_node.node_mean[0]
- )[1:]
- root_node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] = np.zeros((self.data.shape[1],))
- root_node.variational_parameters["locals"]["unobserved_factors_log_std"] = (
- np.zeros((self.data.shape[1],)) - 2
- )
- root_node.set_mean(variational=True)
- non_root_node_init_psi = np.array(
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- nodes = self.get_nodes()[1:]
- for node in nodes:
- node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] -= non_root_node_init_psi
- # node.variational_parameters['locals']['unobserved_factors_log_std'] += .5
- # node.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] += 1
- # node.variational_parameters['locals']['unobserved_factors_kernel_log_std'] += .5
- # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.clip(normal_sample(0., gamma_sample(root_node.unobserved_factors_kernel_concentration_caller(),
- # root_node.unobserved_factors_kernel_concentration_caller(), size=self.data.shape[1])), a_min=-5, a_max=5)
- data_indices = np.where(np.array(self.assignments) == root_node)[
- 0
- ] # list(root_node.data)
- if len(data_indices) > 0:
- # idx = np.random.choice(np.array(data_indices))
- # print(f'Setting new node to explain datum {idx}')
- # datum = self.data[idx]
- # baseline = np.append(1, np.exp(non_root_node.log_baseline_caller()))
- # total_rna = np.sum(baseline * non_root_node.cnvs/2 * np.exp(root_node.variational_parameters['locals']['unobserved_factors_mean']))
- # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.log((datum+1) * total_rna/(root_node.lib_sizes[idx]*baseline * root_node.cnvs/2))
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] = np.zeros((self.data.shape[1],))
- new_bs = np.array(
- root_node.variational_parameters["globals"]["log_baseline_mean"]
- - np.mean(
- root_node.variational_parameters["globals"]["log_baseline_mean"]
- )
- )
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ][1:] = np.array(init_bs - new_bs)
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.log(
- root_node.unobserved_factors_kernel_concentration_caller()
- ) * np.ones(
- (self.data.shape[1],)
- )
- data_in_node = np.array(self.data)[data_indices]
- target_genes_1 = np.argsort(np.var(np.log(data_in_node + 1), axis=0))[
- -5:
- ]
- target_genes_2 = np.where(
- np.abs(
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- > 0.5
- )[0]
- target_genes = np.unique(
- np.concatenate([np.array(target_genes_1), np.array(target_genes_2)])
- )
- non_root_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][target_genes] = -1.0
-
- non_root_node.set_mean(variational=True)
- dataRoot = root_node.data.copy()
- logitsRoot = np.array(root_node.data_ass_logits)
- root_node.data = non_root_node.data.copy()
- root_node.data_ass_logits = np.array(non_root_node.data_ass_logits)
- non_root_node.data = dataRoot.copy()
- non_root_node.data_ass_logits = np.array(logitsRoot)
-
- nodeA_children = np.array(list(nodeA.children()))
- idx = np.argsort([n.label for n in nodeA_children])
- nodeA_children = nodeA_children[idx]
- nodeA_children = [
- node
- for node in nodeA_children
- if not node.is_observed and node != nodeB
- ]
-
- nodeB_children = np.array(list(nodeB.children()))
- idx = np.argsort([n.label for n in nodeB_children])
- nodeB_children = nodeB_children[idx]
- nodeB_children = [
- node
- for node in nodeB_children
- if not node.is_observed and node != nodeA
- ]
-
- if not non_root_node.is_observed:
- if non_root_node.parent() == root_node:
- # Go through children of nodeA and set them as children of nodeB and vice-versa
- for nodeA_child in nodeA_children:
- self.prune_reattach(nodeA_child, nodeB)
-
- for nodeB_child in nodeB_children:
- self.prune_reattach(nodeB_child, nodeA)
-
- if update_pivots:
- if nodeA.tssb == nodeB.tssb:
- root = nodeB.tssb.get_ntssb_root()
- # For each subtree, if pivot was swapped, update it
- for child in root["children"]:
- if child["pivot_node"] == nodeA:
- child["pivot_node"] = nodeB
- child["node"].root["node"].set_parent(nodeB, reset=False)
- child["node"].root["node"].set_mean(variational=True)
- elif child["pivot_node"] == nodeB:
- child["pivot_node"] = nodeA
- child["node"].root["node"].set_parent(nodeA, reset=False)
- child["node"].root["node"].set_mean(variational=True)
-
- # def swap_nodes(self, nodeA, nodeB):
- # if isinstance(nodeA, str) and isinstance(nodeB, str):
- # self.plot_tree(super_only=False)
- # nodes = self.get_nodes(None)
- # node_labels = np.array([node.label for node in nodes])
- # nodeA = nodes[np.where(node_labels == nodeA)[0][0]]
- # nodeB = nodes[np.where(node_labels == nodeB)[0][0]]
- #
- # # If we are swapping the root node, need to change the baseline too
- # root_node = None
- # non_root_node = None
- # child_unobserved_factors = None
- # initial_log_baseline = None
- # if nodeA.parent() is None:
- # root_node = nodeA
- # non_root_node = nodeB
- # elif nodeB.parent() is None:
- # root_node = nodeB
- # non_root_node = nodeA
- # if root_node and non_root_node:
- # child_unobserved_factors = non_root_node.unobserved_factors
- # child_unobserved_factors_k = non_root_node.unobserved_factors
- # initial_log_baseline = root_node.log_baseline_mean
- #
- # if not root_node:
- # paramsA = nodeA.variational_parameters['locals']['unobserved_factors_mean']
- # paramsA_k = nodeA.variational_parameters['locals']['unobserved_factors_kernel_log_mean']
- # # sticks_alphaA = nodeA.nu_log_alpha
- # # sticks_betaA = nodeA.nu_log_beta
- # dataA = nodeA.data
- # logitsA = nodeA.data_ass_logits
- #
- # nodeA.variational_parameters['locals']['unobserved_factors_mean'] = nodeB.variational_parameters['locals']['unobserved_factors_mean']
- # nodeA.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = nodeB.variational_parameters['locals']['unobserved_factors_kernel_log_mean']
- # nodeA.data = nodeB.data
- # nodeA.data_ass_logits = nodeB.data_ass_logits
- # # nodeA.nu_log_alpha = nodeB.nu_log_alpha
- # # nodeA.nu_log_beta = nodeB.nu_log_beta
- # nodeA.set_mean(variational=True)
- #
- # nodeB.variational_parameters['locals']['unobserved_factors_mean'] = paramsA
- # nodeB.variational_parameters['locals']['unobserved_factors_kernel_log_mean'] = paramsA_k
- # nodeB.data = dataA
- # nodeB.data_ass_logits = logitsA
- # # nodeB.nu_log_alpha = sticks_alphaA
- # # nodeB.nu_log_beta = sticks_betaA
- # nodeB.set_mean(variational=True)
- #
- # if root_node and non_root_node:
- # root_node.variational_parameters['locals']['unobserved_factors_mean'] = child_unobserved_factors
- # root_node.variational_parameters['globals']['log_baseline_mean'] = non_root_node.variational_parameters['locals']['unobserved_factors_mean'][1:]
- # root_node.set_mean(variational=True)
- # non_root_node.variational_parameters['locals']['unobserved_factors_mean'] = np.append(0., initial_log_baseline)
- # non_root_node.set_mean(variational=True)
- #
- # if nodeA.tssb == nodeB.tssb:
- # # Go to subtrees
- # # For each subtree, if pivot was swapped, update it
- # root = nodeB.tssb.get_ntssb_root()
- # for child in root['children']:
- # if child['pivot_node'] == nodeA:
- # child['pivot_node'] = nodeB
- # child['node'].root['node'].set_parent(nodeB, reset=False)
- # child['node'].root['node'].set_mean(variational=True)
- # elif child['pivot_node'] == nodeB:
- # child['pivot_node'] = nodeA
- # child['node'].root['node'].set_parent(nodeA, reset=False)
- # child['node'].root['node'].set_mean(variational=True)
- # else:
- # if nodeA.parent() == nodeB or nodeB.parent() == nodeA:
- # unobserved_node_idx = np.where([not nodeA.is_observed, not nodeB.is_observed])[0]
- # if len(unobserved_node_idx) > 0:
- # # change A -> A-0 -> B to A-> B -> B-0:
- # unobserved_node_idx = unobserved_node_idx[0]
- # unobserved_node = [nodeA, nodeB][unobserved_node_idx]
- # # if unobserved node is parent of more than one subtree, don't proceed with full swap
- # unobserved_node_children = unobserved_node.children()
- # if not (np.sum(np.array([child.is_observed for child in unobserved_node_children])) > 1):
- # observed_node = [nodeA, nodeB][1 - unobserved_node_idx]
- # parent_unobserved = unobserved_node.parent()
- # observed_node.set_parent(parent_unobserved)
- # unobserved_node.set_parent(observed_node)
- #
- # nodes = self._get_nodes(get_roots=True)
- # unobserved_node_tssb_root = [node[1] for node in nodes if node[0] == unobserved_node][0]
- # parent_unobserved_node_tssb_root = [node[1] for node in nodes if node[0] == parent_unobserved][0]
- #
- # # Update dicts
- # # Remove unobserved_node from its parent dict
- # childnodes = np.array([n['node'] for n in parent_unobserved_node_tssb_root['children']])
- # tokeep = np.where(childnodes != unobserved_node_tssb_root['node'])[0].astype(int).ravel()
- # parent_unobserved_node_tssb_root['sticks'] = parent_unobserved_node_tssb_root['sticks'][tokeep]
- # parent_unobserved_node_tssb_root['children'] = list(np.array(parent_unobserved_node_tssb_root['children'])[tokeep])
- #
- # # Update observed_node's pivot_node to unobserved_node's parent
- # observed_node_ntssb_root = observed_node.tssb.get_ntssb_root()
- # observed_node_ntssb_root['pivot_node'] = parent_unobserved
- #
- # # Add unobserved_node to observed_node's dict
- # observed_node_tssb_root = observed_node_ntssb_root['node'].root
- # observed_node_tssb_root['children'].append(unobserved_node_tssb_root)
- # observed_node_tssb_root['sticks'] = np.vstack([observed_node_tssb_root['sticks'], 1.])
-
- def merge_nodes(self, nodeA, nodeB, optimal_params=True):
- if isinstance(nodeA, str) and isinstance(nodeB, str):
- self.plot_tree(super_only=False)
- nodes = self.get_nodes(None)
- node_labels = np.array([node.label for node in nodes])
- nodeA = nodes[np.where(node_labels == nodeA)[0][0]]
- nodeB = nodes[np.where(node_labels == nodeB)[0][0]]
-
- nodes = self._get_nodes(get_roots=True)
- nodes_list = np.array([node[0] for node in nodes])
- nodeA_idx = np.where(nodes_list == nodeA)[0][0]
- nodeB_idx = np.where(nodes_list == nodeB)[0][0]
- nodeA_root = nodes[nodeA_idx][1]
- nodeB_root = nodes[nodeB_idx][1]
- nodeA_parent_root = nodes[np.where(np.array(nodes) == nodeA.parent())[0][0]][1]
- if nodeB.parent() is not None:
- nodeB_parent_root = nodes[
- np.where(np.array(nodes) == nodeB.parent())[0][0]
- ][1]
-
- numDataA, numDataB = (len(nodeA.data), len(nodeB.data))
-
- nodeA_init_psi = np.array(
- nodeA.variational_parameters["locals"]["unobserved_factors_mean"]
- )
-
- if not nodeA.is_observed or not nodeB.is_observed:
- if nodeA.tssb == nodeB.tssb:
- # Move child nodes of nodeA to nodeB and remove nodeA
- # And make sure the children of nodeA keep their parameters
- n_childrenA = len(nodeA_root["children"])
- n_childrenB = len(nodeB_root["children"])
- for i, nodeA_child in enumerate(nodeA_root["children"]):
- nodeA_child["node"].set_parent(nodeB_root["node"], reset=False)
- nodeA_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.maximum(
- nodeA_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- )
- nodeA_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ] = np.minimum(
- nodeA_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- )
- nodeA_child["node"].set_mean(variational=True)
- nodeB_root["children"].append(nodeA_child)
- nodeB_root["sticks"] = np.vstack([nodeB_root["sticks"], 1.0])
- nodeA_root["children"] = []
-
- # If nodeA was the pivot of a downstream tree, update the pivot to nodeB
- nodeA_children = np.array(list(nodeA_root["node"].children().copy()))
- idx = np.argsort([n.label for n in nodeA_children])
- nodeA_children = nodeA_children[idx]
- for nodeA_child in nodeA_children:
- if nodeA_child.tssb != nodeA_root["node"].tssb:
- nodeA_child.set_parent(nodeB_root["node"], reset=False)
- nodeA_child.set_mean(variational=True)
- ntssb_root = nodeA_child.tssb.get_ntssb_root()
- ntssb_root["pivot_node"] = nodeB_root["node"]
- nodeA_root["node"].children().clear()
-
- nodeB.data_weights = np.array(nodeA.data_weights)
- nodeB.data.update(nodeA.data)
- else:
- # nodeB is parent (and pivot) of nodeA
- # Set parent of nodeB as pivot of nodeA's subtree
- ntssb_root = nodeA.tssb.get_ntssb_root()
- ntssb_root["pivot_node"] = nodeB_parent_root["node"]
- nodeA.set_parent(nodeB.parent(), reset=False)
-
- # Set all children of nodeB as children of nodeB's parent
- for i, nodeB_child in enumerate(nodeB_root["children"]):
- nodeB_child["node"].set_parent(
- nodeB_parent_root["node"], reset=False
- )
- nodeB_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.maximum(
- nodeB_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- nodeB_parent_root["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- )
- nodeB_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ] = np.minimum(
- nodeB_child["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- nodeB_parent_root["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- )
- nodeB_child["node"].set_mean(variational=True)
- nodeB_parent_root["children"].append(nodeB_child)
- nodeB_parent_root["sticks"] = np.vstack(
- [nodeB_parent_root["sticks"], 1.0]
- )
- nodeB_root["children"] = []
-
- # If nodeB was pivot of another tree, update the pivot to nodeB_parent
- nodeB_children = np.array(list(nodeB_root["node"].children().copy()))
- idx = np.argsort([n.label for n in nodeB_children])
- nodeB_children = nodeB_children[idx]
- for nodeB_child in nodeB_children:
- if (
- nodeB_child.tssb != nodeB_root["node"].tssb
- and nodeB_child.tssb != nodeA.tssb
- ):
- nodeB_child.set_parent(nodeB_parent_root["node"], reset=False)
- nodeB_child.set_mean(variational=True)
- ntssb_root = nodeB_child.tssb.get_ntssb_root()
- ntssb_root["pivot_node"] = nodeB_parent_root["node"]
- nodeB_root["node"].children().clear()
-
- nodeA.data_weights = np.array(nodeB.data_weights)
- nodeA.data.update(nodeB.data)
- else:
- nodeB.data_weights = np.array(nodeA.data_weights)
- nodeB.data.update(nodeA.data)
- nodeA.data.clear()
-
- # Keep node that explains the most data
- if optimal_params:
- if nodeA.tssb == nodeB.tssb:
- # Merge kernels
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.maximum(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- )
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ] = np.minimum(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- )
- # if numDataA > numDataB:
- if nodeB.parent() is not None:
- nodeB.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] = np.array(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- nodeB.variational_parameters["locals"][
- "unobserved_factors_log_std"
- ] = np.array(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_log_std"
- ]
- )
- nodeB.variational_parameters["locals"][
- "nu_log_mean"
- ] = np.array(
- nodeA.variational_parameters["locals"]["nu_log_mean"]
- )
- nodeB.variational_parameters["locals"]["nu_log_std"] = np.array(
- nodeA.variational_parameters["locals"]["nu_log_std"]
- )
- nodeB.variational_parameters["locals"][
- "psi_log_mean"
- ] = np.array(
- nodeA.variational_parameters["locals"]["psi_log_mean"]
- )
- nodeB.variational_parameters["locals"][
- "psi_log_std"
- ] = np.array(
- nodeA.variational_parameters["locals"]["psi_log_std"]
- )
- else: # We're trying to merge to root and root has no data, so adjust its baseline
- nodeB.variational_parameters["globals"][
- "log_baseline_mean"
- ] = np.log(nodeA.node_mean / nodeA.node_mean[0])[1:]
- nodeB.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] *= 0.0
- # Also adjust all unobserved factors by removing previous nodeA psi from all nodes below it: it is now present as the baseline
- def update_psi(node_obj):
- node_obj.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] -= nodeA_init_psi
- # node_obj.variational_parameters["locals"][
- # "unobserved_factors_kernel_log_mean"
- # ] += 0.5
- for child in node_obj.children():
- update_psi(child)
- update_psi(nodeA)
- #
- # nodes = self.get_nodes()[1:]
- # for node in nodes:
- # node.variational_parameters["locals"][
- # "unobserved_factors_mean"
- # ] -= nodeA_init_psi
- # node.variational_parameters["locals"][
- # "unobserved_factors_kernel_log_mean"
- # ] += 0.5
- nodeB.set_mean(variational=True)
- else:
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.maximum(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- )
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ] = np.minimum(
- nodeA.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- nodeB.variational_parameters["locals"][
- "unobserved_factors_kernel_log_std"
- ],
- )
- if nodeB.parent() is None:
- nodeB.variational_parameters["globals"][
- "log_baseline_mean"
- ] = np.log(nodeA.node_mean / nodeA.node_mean[0])[1:]
- nodes = self.get_nodes()[1:]
- for node in nodes:
- node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] -= nodeA_init_psi
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] += 1.0
-
- if not nodeA.is_observed or not nodeB.is_observed:
- if nodeA.tssb == nodeB.tssb:
- # Remove nodeA from tssb root dict
- nodes = np.array([n["node"] for n in nodeA_parent_root["children"]])
- tokeep = np.where(nodes != nodeA)[0].astype(int).ravel()
- nodeA_root["node"].kill()
- del nodeA_root["node"]
-
- nodeA_parent_root["sticks"] = nodeA_parent_root["sticks"][tokeep]
- nodeA_parent_root["children"] = list(
- np.array(nodeA_parent_root["children"])[tokeep]
- )
- else:
- # Remove nodeB from tssb root dict
- nodes = np.array([n["node"] for n in nodeB_parent_root["children"]])
- tokeep = np.where(nodes != nodeB)[0].astype(int).ravel()
- nodeB_root["node"].kill()
- del nodeB_root["node"]
-
- nodeB_parent_root["sticks"] = nodeB_parent_root["sticks"][tokeep]
- nodeB_parent_root["children"] = list(
- np.array(nodeB_parent_root["children"])[tokeep]
- )
-
def subtree_reattach_to(self, node, target_clone, optimal_init=True):
# Get the node and its parent root
nodes = self._get_nodes(get_roots=True)
@@ -3448,7 +1982,7 @@ def descend(root):
)[0]
for node in nodes_below_nodeA:
node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
][parent_affected_genes] = -1
if roots[nodeA_idx]["node"].parent().parent() is not None:
@@ -3464,12 +1998,12 @@ def descend(root):
genes = np.where(roots[nodeA_idx]["node"].cnvs != 2)[0]
roots[nodeA_idx]["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
][genes] -= 1
genes = np.where(init_cnvs != 2)[0]
roots[nodeA_idx]["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
][genes] -= 1
# Need to accomodate events this node has that were previously inherited
@@ -3482,7 +2016,7 @@ def descend(root):
> 0.5
)[0]
roots[nodeA_idx]["node"].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
][affected_genes] = -1
# Reset variational parameters: all log_std and unobserved factors kernel
@@ -3690,7 +2224,15 @@ def label_nodes(self, counts=False, names=False):
elif not names or counts is True:
self.label_nodes_counts()
- def set_node_names(self, root_name="X"):
+ def set_node_names(self):
+ def descend(root):
+ root['node'].set_node_names(root_name=root['label'])
+ for child in root["children"]:
+ descend(child)
+
+ descend(self.root)
+
+ def set_tree_names(self, root_name="X"):
self.root["label"] = str(root_name)
self.root["node"].label = str(root_name)
@@ -3831,31 +2373,27 @@ def descend(root, g):
return g
def get_node_unobs(self):
- nodes = self.get_nodes(None)
+ nodes = self.get_nodes()
unobs = []
estimated = (
- np.var(nodes[1].variational_parameters["locals"]["unobserved_factors_mean"])
+ np.var(nodes[1].variational_parameters["kernel"]["state"]['mean'])
!= 0
)
if estimated:
logger.debug("Getting the learned unobserved factors.")
for node in nodes:
- unobs_factors = (
- node.unobserved_factors
- if not estimated
- else node.variational_parameters["locals"]["unobserved_factors_mean"]
- )
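+            # params[0] holds the node's unobserved factors (its latent state)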
+ unobs_factors = node.params[0]
unobs.append(unobs_factors)
return nodes, unobs
def get_node_unobs_affected_genes(self):
- nodes = self.get_nodes(None)
+ nodes = self.get_nodes()
unobs = []
estimated = (
np.var(
- nodes[1].variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
+ np.exp(nodes[1].variational_parameters["kernel"][
+ "direction"
+ ]['log_alpha'])
)
!= 0
)
@@ -3870,22 +2408,24 @@ def get_node_unobs_affected_genes(self):
unobs_factors_kernel = (
node.unobserved_factors_kernel
if not estimated
- else node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
+ else np.exp(node.variational_parameters["kernel"]['direction'][
+ "log_alpha"
+ ] - node.variational_parameters["kernel"]['direction'][
+ "log_beta"
+ ])
)
unobs.append(unobs_factors)
return nodes, unobs
def get_node_obs(self):
- nodes = self.get_nodes(None)
+ nodes = self.get_nodes()
obs = []
for node in nodes:
obs.append(node.observed_parameters)
return nodes, obs
def get_avg_node_exp(self, norm=True):
- nodes = self.get_nodes(None)
+ nodes = self.get_nodes()
data = self.data
if norm:
try:
@@ -4190,87 +2730,68 @@ def descend(node):
descend(self.root["node"].root["node"])
- #
- # def path_to_node(self, node_id):
- # path = []
- # path.append(node_id)
- # parent_id = self.node_dict[node_id]['parent']
- # while parent_id != 'NULL':
- # path.append(parent_id)
- # parent_id = self.node_dict[parent_id]['parent']
- # return path[::-1][:]
- #
- # def path_between_nodes(self, nodeA, nodeB):
- # pathA = np.array(self.path_to_node(nodeA))
- # pathB = np.array(self.path_to_node(nodeB))
- # path = []
- # # Get MRCA
- # i = -1
- # for node in pathA:
- # if node in pathB:
- # i += 1
- # else:
- # break
- # mrca = pathA[i]
- # pathA = np.array(pathA[::-1])
- # # Get path from A to MRCA
- # path = path + list(pathA[:np.where(pathA == mrca)[0][0]])
- # # Get path from MRCA to B
- # path = path + list(pathB[np.where(pathB == mrca)[0][0]:])
- # return path
-
- # TODO: Should have a distance that counts the number of changed genes while going through the path
- def get_distance(self, id1, id2, distance="n_nodes"):
- path = self.path_between_nodes(id1, id2)
-
+ def get_distance(self, node1, node2):
+ path = self.path_between_nodes(node1, node2)
+ path_labels = [n.label for n in path]
+ node_dict = dict(zip(path_labels, path))
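+        # Distance = sum of Euclidean distances between the means of consecutive nodes along the path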
dist = 0
- if distance == "n_nodes":
- dist = len(path)
- else:
- prev_node = path[0]
- for node in path:
- if node != prev_node:
- if dist == "estimated":
- dist += np.sqrt(
- np.sum(
- (
- self.node_dict[node]["node"].variational_parameters[
- "locals"
- ]["unobserved_factors_mean"]
- - self.node_dict[prev_node][
- "node"
- ].variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- ** 2
- )
- )
- else:
- dist += np.sqrt(
- np.sum(
- (
- self.node_dict[node].unobserved_factors
- - self.node_dict[prev_node][
- "node"
- ].unobserved_factors
- )
- ** 2
- )
- )
- prev_node = node
- dist += self.node_dict[node][distance]
+ prev_node = path_labels[0]
+ for node in path_labels:
+ if node != prev_node:
+ dist += np.sqrt(
+ np.sum(
+                        (node_dict[node].get_mean() - node_dict[prev_node].get_mean()) ** 2
+ )
+ )
+ prev_node = node
- return dist
+ return dist
- def get_pairwise_cell_distances(self, distance="n_nodes"):
+ def get_pairwise_obs_distances(self):
n_cells = len(self.assignments)
mat = np.zeros((n_cells, n_cells))
-
- for i in range(1, n_cells):
- id1 = self.assignments[i].label
- for j in range(i):
- id2 = self.assignments[j].label
- mat[i][j] = self.get_distance(str(id1), str(id2), distance=distance)
+ nodes = self.get_nodes()
+
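+        # Fill the cell-by-cell matrix one pair of nodes at a time: every cell assigned to node1
+        # and every cell assigned to node2 get the same node-level distance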
+ for node1 in nodes:
+ idx1 = np.where(self.assignments == node1)[0]
+ if len(idx1) == 0:
+ continue
+ for node2 in nodes:
+ idx2 = np.where(self.assignments == node2)[0]
+ if len(idx2) == 0:
+ continue
+                mat[np.meshgrid(idx1, idx2)] = self.get_distance(node1, node2)
return mat
+
+ def set_combined_params(self):
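+        # Attach a 'combined_param' to each node: its own parameter combined with its subtree's observed parameter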
+ def sub_descend(subroot, obs_param):
+ for i, child in enumerate(subroot['children']):
+ child['combined_param'] = child['node'].combine_params(child['param'], obs_param)
+ sub_descend(child, obs_param)
+
+ def descend(root):
+ sub_descend(root['node'], root['obs_param'])
+ for i, child in enumerate(root['children']):
+ descend(child)
+
+ descend(self.root)
+
+ def set_ntssb_colors(self, **cmap_kwargs):
+ # Traverse tree to update dict
+ def descend(root):
+ tree_colors.make_tree_colormap(root['node'].root, root['color'], **cmap_kwargs)
+ for i, child in enumerate(root['children']):
+ descend(child)
+ descend(self.root)
+
+ def show_tree(self, **kwargs):
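+        # Convenience plot: refresh learned parameters, node names, weights, assignments and colors, then draw the full tree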
+ self.set_learned_parameters()
+ self.set_node_names()
+ self.set_expected_weights()
+ self.assign_samples()
+ self.set_ntssb_colors()
+ tree = self.get_param_dict()
+ plt.figure(figsize=(4,4))
+ ax = plt.gca()
+ plot_full_tree(tree, ax=ax, node_size=101, **kwargs)
\ No newline at end of file
diff --git a/scatrex/ntssb/tree.py b/scatrex/ntssb/observed_tree.py
similarity index 68%
rename from scatrex/ntssb/tree.py
rename to scatrex/ntssb/observed_tree.py
index 51f7d67..4e3247e 100644
--- a/scatrex/ntssb/tree.py
+++ b/scatrex/ntssb/observed_tree.py
@@ -1,5 +1,6 @@
import string
from abc import ABC, abstractmethod
+from copy import deepcopy
import numpy as np
import pandas as pd
@@ -8,23 +9,30 @@
import matplotlib.pyplot as plt
from graphviz import Digraph
+from .node import AbstractNode
from ..plotting import constants
+from ..utils.tree_utils import tree_to_dict, dict_to_tree, subsample_tree, condense_tree
-class Tree(ABC):
+class ObservedTree(ABC):
def __init__(
self,
n_nodes=3,
dp_alpha_subtree=1.0,
- alpha_decay_subtree=0.9,
+ alpha_decay_subtree=1.0,
dp_gamma_subtree=1.0,
dp_alpha_parent_edge=1.0,
- alpha_decay_parent_edge=0.9,
- eta=0.0,
+ alpha_decay_parent_edge=1.0,
+ eta=1.0,
node_weights=None,
+ seed=42,
+ add_root=True,
+ **kwargs,
):
-
- self.tree_dict = dict()
+ self.node_constructor = AbstractNode
+ self.seed = seed
+ self.tree_dict = dict() # flat
+ self.tree = dict() # recursive
self.n_nodes = n_nodes
self.dp_alpha_subtree = dp_alpha_subtree
self.alpha_decay_subtree = alpha_decay_subtree
@@ -35,16 +43,71 @@ def __init__(
self.adata = None
self.cmap = None
self.node_weights = node_weights
+ self.add_root = add_root
if self.node_weights is None:
- self.node_weights = np.random.dirichlet([10.0] * self.n_nodes)
+ rng = np.random.default_rng(seed=self.seed)
+ self.node_weights = rng.dirichlet([100.0] * self.n_nodes)
+
+ def condense(self, min_weight=.1, inplace=False):
+ """
+ Traverse the tree from the bottom up and merge nodes until all nodes have at least min_weight
+ """
+ if inplace:
+ condense_tree(self.tree, min_weight=min_weight)
+ self.tree_dict = tree_to_dict(self.tree)
+ self.tree = dict_to_tree(self.tree_dict, root_name=self.tree['label'])
+ self.n_nodes = len(self.tree_dict.keys())
+ else:
+ new_tree = deepcopy(self.tree)
+ condense_tree(new_tree, min_weight=min_weight)
+ new_tree_dict = tree_to_dict(new_tree)
+ new_tree = dict_to_tree(new_tree_dict, root_name=new_tree['label'])
+ n_nodes = len(new_tree_dict.keys())
+ new_obj = self.__class__(n_nodes=n_nodes, seed=self.seed)
+ new_obj.tree = new_tree
+ new_obj.tree_dict = new_tree_dict
+ return new_obj
+
+ def subsample(self, keep_prob=0.5, force=True, inplace=False):
+ """
+ Randomly choose a fraction of nodes to keep in the tree
+ """
+ init_n_nodes = self.n_nodes
+ seed = self.seed
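+        # When force=True, retry with increasing seeds until the subsample has exactly int(n_nodes * keep_prob) nodes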
+ while True:
+ new_tree = deepcopy(self.tree)
+ subsample_tree(new_tree, keep_prob=keep_prob, seed=seed)
+ new_tree_dict = tree_to_dict(new_tree)
+ new_tree = dict_to_tree(new_tree_dict, root_name=new_tree['label'])
+ n_nodes = len(new_tree_dict.keys())
+ if force:
+ if n_nodes == int(init_n_nodes*keep_prob):
+ break
+ else:
+ seed += 1
+ else:
+ break
+
+ if inplace:
+ self.tree = new_tree
+ self.tree_dict = new_tree_dict
+ self.n_nodes = n_nodes
+ else:
+ new_obj = self.__class__(n_nodes=n_nodes, seed=self.seed)
+ new_obj.tree = new_tree
+ new_obj.tree_dict = new_tree_dict
+ return new_obj
def get_size(self):
return len(self.tree_dict.keys())
+ def get_param_size(self):
+ return self.tree["param"].size
+
def get_params(self):
params = []
for node in self.tree_dict:
- params.append(self.tree_dict[node]["params"])
+ params.append(self.tree_dict[node]["param"])
return np.array(params, dtype=np.float)
def change_names(self, keep="root"):
@@ -62,7 +125,8 @@ def change_names(self, keep="root"):
self.tree_dict[node]["children"] = []
self.tree_dict[alphabet[i]] = self.tree_dict[node]
self.tree_dict[alphabet[i]]["label"] = alphabet[i]
- del self.tree_dict[node]
+ if new_names[node] != node:
+ del self.tree_dict[node]
for i in self.tree_dict:
self.tree_dict[i]["children"] = []
@@ -72,7 +136,15 @@ def change_names(self, keep="root"):
if self.tree_dict[j]["parent"] == i:
self.tree_dict[i]["children"].append(j)
- def set_colors(self, root_node=None):
+ root_name = keep
+ for node in self.tree_dict:
+ if self.tree_dict[node]['parent'] == '-1':
+ root_name = node
+
+ self.tree = dict_to_tree(self.tree_dict, root_name=root_name)
+ self.tree_dict = tree_to_dict(self.tree)
+
+ def set_colors(self, root_node='root'):
idx = 0
if root_node in self.tree_dict:
self.tree_dict[root_node]["color"] = "lightgray"
@@ -82,6 +154,14 @@ def set_colors(self, root_node=None):
self.tree_dict[node]["color"] = constants.LABEL_COLORS_DICT[node]
except:
self.tree_dict[node]["color"] = constants.CLONES_PAL[i]
+
+ root_name = root_node
+ for node in self.tree_dict:
+ if self.tree_dict[node]['parent'] == '-1':
+ root_name = node
+
+ self.tree = dict_to_tree(self.tree_dict, root_name=root_name)
+ self.tree_dict = tree_to_dict(self.tree)
def add_tree_parameters(self, change_name=True):
@@ -125,28 +205,32 @@ def add_tree_parameters(self, change_name=True):
def generate_tree(self):
alphabet = list(string.ascii_uppercase)
# Add healthy node
- self.tree_dict = dict(
- root=dict(
- parent="-1",
- children=[],
- params=None,
- dp_alpha_subtree=self.dp_alpha_subtree,
- alpha_decay_subtree=self.alpha_decay_subtree,
- dp_gamma_subtree=self.dp_gamma_subtree,
- dp_alpha_parent_edge=self.dp_alpha_parent_edge,
- alpha_decay_parent_edge=self.alpha_decay_parent_edge,
- eta=self.eta,
- weight=0,
- size=int(0),
- color="lightgray",
- label="root",
+ if self.add_root:
+ self.tree_dict = dict(
+ root=dict(
+ parent="-1",
+ children=[],
+ param=None,
+ dp_alpha_subtree=self.dp_alpha_subtree,
+ alpha_decay_subtree=self.alpha_decay_subtree,
+ dp_gamma_subtree=self.dp_gamma_subtree,
+ dp_alpha_parent_edge=self.dp_alpha_parent_edge,
+ alpha_decay_parent_edge=self.alpha_decay_parent_edge,
+ eta=self.eta,
+ weight=0,
+ size=int(0),
+ color="lightgray",
+ label="root",
+ )
)
- )
+ mrca_parent = "root"
+ else:
+ mrca_parent = "-1"
# Add MRCA node
self.tree_dict["A"] = dict(
- parent="root",
+ parent=mrca_parent,
children=[],
- params=None,
+ param=None,
dp_alpha_subtree=self.dp_alpha_subtree,
alpha_decay_subtree=self.alpha_decay_subtree,
dp_gamma_subtree=self.dp_gamma_subtree,
@@ -159,11 +243,12 @@ def generate_tree(self):
label="A",
)
for c in range(1, self.n_nodes):
- parent = alphabet[np.random.choice(np.arange(0, c))]
+ rng = np.random.default_rng(seed=self.seed+c)
+ parent = alphabet[rng.choice(np.arange(0, c))]
self.tree_dict[alphabet[c]] = dict(
parent=parent,
children=[],
- params=None,
+ param=None,
dp_alpha_subtree=self.dp_alpha_subtree,
alpha_decay_subtree=self.alpha_decay_subtree,
dp_gamma_subtree=self.dp_gamma_subtree,
@@ -181,6 +266,12 @@ def generate_tree(self):
if self.tree_dict[j]["parent"] == i:
self.tree_dict[i]["children"].append(j)
+ for node in self.tree_dict:
+ if self.tree_dict[node]['parent'] == '-1':
+ root_name = node
+
+ self.tree = dict_to_tree(self.tree_dict, root_name=root_name)
+
def get_sum_weights_subtree(self, label):
if "weight" not in self.tree_dict["A"].keys():
raise KeyError("No weights were specified in the input tree.")
@@ -288,7 +379,7 @@ def create_adata(self, var_names=None):
)
params.append(
np.vstack(
- [self.tree_dict[node]["params"]] * self.tree_dict[node]["size"]
+ [self.tree_dict[node]["param"]] * self.tree_dict[node]["size"]
)
)
params = pd.DataFrame(np.vstack(params))
@@ -335,7 +426,7 @@ def plot_heatmap(self, var_names=None, cmap=None, **kwds):
def read_tree_from_dict(
self,
tree_dict,
- input_params_key="params",
+ input_params_key="param",
input_label_key="label",
input_parent_key="parent",
input_sizes_key="size",
@@ -417,6 +508,12 @@ def read_tree_from_dict(
for j in self.tree_dict:
if self.tree_dict[j]["parent"] == i:
self.tree_dict[i]["children"].append(j)
+
+ for node in self.tree_dict:
+ if self.tree_dict[node]['parent'] == '-1':
+ root_name = node
+
+ self.tree = dict_to_tree(self.tree_dict, root_name=root_name)
def root(self):
nodes = list(self.tree_dict.keys())
@@ -447,12 +544,62 @@ def update_weights(self, uniform=False):
def subset_genes(self, gene_list):
for node in self.tree_dict:
- self.tree_dict[node]["params"] = pd.DataFrame(
- self.tree_dict[node]["params"][:, np.newaxis].T,
+ self.tree_dict[node]["param"] = pd.DataFrame(
+ self.tree_dict[node]["param"][:, np.newaxis].T,
columns=self.adata.var_names,
)[gene_list].values.ravel()
self.adata = self.adata[:, gene_list]
@abstractmethod
- def add_node_params(self):
+ def sample_root(self):
return
+
+ @abstractmethod
+ def sample_kernel(self):
+ return
+
+ def update_tree(self):
+ for node in self.tree_dict:
+ if self.tree_dict[node]['parent'] == '-1':
+ root_name = node
+ self.tree = dict_to_tree(self.tree_dict, root_name=root_name)
+
+ def update_dict(self):
+ self.tree_dict = tree_to_dict(self.tree)
+
+ def param_distance(self, paramA, paramB):
+ return np.sqrt(np.sum((paramA-paramB)**2))
+
+ def add_node_params(
+ self, n_genes=2, min_dist=0.2, **params
+ ):
+ def descend(root, idx=0, depth=1):
+ for i, child in enumerate(root['children']):
+ seed = self.seed+idx
+ accepted = False
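+                # Rejection sampling: resample with a new seed until the child's parameters are at least
+                # min_dist * (distance to parent) away from every previously sampled sibling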
+ while not accepted:
+ child["param"] = self.sample_kernel(root["param"], seed=seed, depth=depth, **params)
+ dist_to_parent = self.param_distance(root["param"], child["param"])
+ # Reject sample if too close to any other child
+ dists = []
+ for j, child2 in enumerate(root['children']):
+ if j < i:
+ dists.append(self.param_distance(child["param"], child2["param"]))
+ if np.all(np.array(dists) >= min_dist*dist_to_parent):
+ accepted = True
+ else:
+ seed += 1
+
+ idx = descend(child, idx+1, depth=depth+1)
+ return idx
+
+ # Set root param
+ self.tree["param"] = self.sample_root(n_genes=n_genes, seed=self.seed, **params)
+
+ # Set node params recursively
+ descend(self.tree)
+
+ # Update tree_dict too
+ self.tree_dict = tree_to_dict(self.tree)
+
+ self.create_adata()
diff --git a/scatrex/ntssb/search.py b/scatrex/ntssb/search.py
index ef621e3..0f5cb0a 100644
--- a/scatrex/ntssb/search.py
+++ b/scatrex/ntssb/search.py
@@ -1,67 +1,272 @@
import numpy as np
from copy import deepcopy
-from tqdm.auto import tqdm
+from tqdm.auto import trange
from time import time
-from ..util import *
-from jax import jit
-from jax.example_libraries.optimizers import adam
import matplotlib.pyplot as plt
-import logging
-
-logger = logging.getLogger(__name__)
+from ..utils.math_utils import *
+import jax
+import jax.numpy as jnp
-def search_callback(inf):
- return
-
+import logging
-MOVE_WEIGHTS = {
- "add": 1,
- "merge": 1.0,
- "prune_reattach": 1.0,
- "pivot_reattach": 1.0,
- "swap": 1.0,
- "add_reattach_pivot": 1.0,
- "subtree_reattach": 0.5,
- "push_subtree": 1.0,
- "extract_pivot": 1.0,
- "subtree_pivot_reattach": 0.5,
- "perturb_node": .1,
- "perturb_globals": .1,
- "optimize_node": 1,
- "transfer_factor": 0.5,
- "transfer_unobserved": 0.5,
- "clean_node": 0.
-}
+logger = logging.getLogger(__name__)
class StructureSearch(object):
def __init__(self, ntssb):
-
+        # Keep a copy of the current tree
self.tree = deepcopy(ntssb)
+        # And a working copy on which changes are proposed
+ self.proposed_tree = deepcopy(self.tree)
+
self.traces = dict()
self.traces["tree"] = []
self.traces["elbo"] = []
- self.traces["score"] = []
- self.traces["move"] = []
- self.traces["temperature"] = []
- self.traces["accepted"] = []
- self.traces["times"] = []
self.traces["n_nodes"] = []
- self.traces["gamma"] = []
self.traces["elbos"] = []
- self.best_elbo = self.tree.elbo
self.best_tree = deepcopy(self.tree)
- self.opt_triplet = None
- def init_optimizer(self, lr=0.01, opt=adam):
- opt_init, opt_update, get_params = opt(lr=lr)
- get_params = jit(get_params)
- opt_update = jit(opt_update)
- opt_init = jit(opt_init)
- self.opt_triplet = (opt_init, opt_update, get_params)
+ def run_search(self, n_iters=10, n_epochs=10, mc_samples=10, step_size=0.01, moves_per_tssb=1, global_freq=0, memoized=True, update_roots=True, seed=42, swap_freq=0, update_outer_ass=True):
+ """
+        Run the structure search. Start by learning the parameters of the model for the
+        non-augmented tree (the assignments of cells to TSSBs, the outer stick parameters
+        and the global model parameters), then alternate birth moves, parameter updates
+        and merge moves, keeping the best tree found.
+ """
+ # Generate PRNG key
+ key = jax.random.PRNGKey(seed)
+
+ self.proposed_tree = deepcopy(self.tree)
+
+ # Run the structure search
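+        # Each iteration: spawn nodes (births, always accepted), refit the parameters,
+        # then propose merges that are only kept if they improve the ELBO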
+ t = trange(n_iters, desc='Finding NTSSB', leave=True)
+ for i in t:
+ key, subkey = jax.random.split(key)
+
+ update_globals = False
+ if global_freq != 0:
+ if i % global_freq == 0:
+ # Do birth-merge step in which other local params are updated
+ update_globals = True
+
+ # Birth: traverse the tree and spawn a bunch of nodes (quick and helps escape local optima)
+ self.birth(subkey, moves_per_tssb=moves_per_tssb)
+
+ # Update parameters in n_epochs passes through the data, interleaving node updates with local batch updates
+ self.tree.learn_params(n_epochs, update_roots=update_roots, mc_samples=mc_samples,
+ step_size=step_size, memoized=memoized)
+ self.tree.compute_elbo(memoized=memoized)
+ self.proposed_tree = deepcopy(self.tree)
+
+ # Merge: traverse the tree and propose merges and accept/reject based on their summary statistics (reliable)
+ self.merge(subkey, moves_per_tssb=moves_per_tssb, memoized=memoized, update_globals=update_globals,
+ n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size)
+
+ # Prune and reattach: traverse the tree and propose pruning nodes and reattaching somewhere else inside their TSSB
+ # self.prune_reattach(subkey, moves_per_tssb=moves_per_tssb)
+
+ # Swap roots: traverse the tree and propose swapping roots of TSSBs with their immediate children
+ # self.swap_roots(subkey, moves_per_tssb=moves_per_tssb)
+
+ # Keep best
+ if self.tree.elbo > self.best_tree.elbo:
+ self.best_tree = deepcopy(self.tree)
+
+ # Compute ELBO
+ self.traces['elbo'].append(self.tree.elbo)
+
+            t.set_description(f"Finding NTSSB ({self.tree.n_total_nodes} nodes, elbo: {self.tree.elbo})")
+            t.refresh()  # show the update immediately
+
+ def swap(self, key, memoized=True, n_epochs=10, **learn_kwargs):
+ """
+        Propose swapping the root of a subtree with one of its children.
+ """
+ if self.proposed_tree.n_total_nodes == self.proposed_tree.n_nodes:
+            logger.debug("Nothing to swap.")
+ return
+
+ n_children = 0
+ while n_children == 0:
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ # See in which subtree it lands
+ subtree, _, u = self.proposed_tree.find_node(u)
+            # Use that subtree's root; keep sampling until we land on one that has children
+ parent = subtree.root
+ n_children = len(parent['children'])
+
+ # Swap parent-child
+ source_idx = jax.random.choice(subkey, n_children)
+ source = parent['children'][source_idx]
+ target = parent
+
+ self.proposed_tree.swap_root(source, target)
+ self.proposed_tree.compute_elbo(memoized=memoized)
+
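+        # Accept right away if the ELBO improved; otherwise refit the local parameters once and re-check before reverting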
+ if self.proposed_tree.elbo > self.tree.elbo:
+ # print(f"Merged {source_label} to {target_label}")
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ self.proposed_tree.learn_params(n_epochs, update_globals=False, memoized=memoized, **learn_kwargs)
+ self.proposed_tree.compute_elbo(memoized=memoized)
+ if self.proposed_tree.elbo > self.tree.elbo:
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ self.proposed_tree = deepcopy(self.tree)
+
+ def birth(self, key, moves_per_tssb=1):
+ """
+        Spawn `moves_per_tssb` new nodes for each TSSB in the tree.
+ """
+ n_births = self.proposed_tree.n_nodes * moves_per_tssb
+ for _ in range(n_births):
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ target = self.proposed_tree.get_node(u, key=subkey, uniform=True)
+ new_node = target['node'].tssb.add_node(target, seed=int(subkey[1]))
+ if jax.random.bernoulli(subkey):
+ new_node.init_new_node_kernel()
+
+ # Always accept
+ self.tree = deepcopy(self.proposed_tree)
+
+ def merge_root(self, key, memoized=True, n_epochs=10, **learn_kwargs):
+ """
+        Propose merging a root's child into the root while keeping the child's parameters. Optimize and accept if the ELBO improves.
+        This is done by swapping the parameters of the root and the child and then merging the child into the root as usual.
+ """
+ if self.proposed_tree.n_total_nodes == self.proposed_tree.n_nodes:
+            logger.debug("Nothing to merge.")
+ return
+
+ n_children = 0
+ while n_children == 0:
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ # See in which subtree it lands
+ subtree, _, u = self.proposed_tree.find_node(u)
+            # Use that subtree's root; keep sampling until we land on one that has children
+ parent = subtree.root
+ n_children = len(parent['children'])
+
+ # Merge parent-child
+ source_idx = jax.random.choice(subkey, n_children)
+ source = parent['children'][source_idx]
+ target = parent
+
+ slab = source['node'].label
+ tlab = target['node'].label
+ # self.proposed_tree.swap_nodes(source['node'], target['node'])
+ self.proposed_tree.swap_root(source['node'].label, target['node'].label)
+ subtree.merge_nodes(target, source, target)
+ self.proposed_tree.compute_elbo(memoized=memoized)
+
+ if self.proposed_tree.elbo > self.tree.elbo: # if ELBO improves even before optimizing
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ # Update node parameters
+ self.proposed_tree.learn_model(n_epochs, update_globals=False, update_roots=True, memoized=memoized, **learn_kwargs)
+ self.proposed_tree.compute_elbo(memoized=memoized)
+ if self.proposed_tree.elbo > self.tree.elbo:
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ self.proposed_tree = deepcopy(self.tree)
+
+ def merge(self, key, moves_per_tssb=1, memoized=True, update_globals=False, n_epochs=10, **learn_kwargs):
+ """
+        Traverse the tree, proposing merges and accepting or rejecting them as we go based on local sufficient statistics.
+ """
+ n_merges = int(0.7 * self.proposed_tree.n_total_nodes * moves_per_tssb * 2)
+ if update_globals:
+ n_merges = int(0.7 * self.proposed_tree.n_total_nodes * moves_per_tssb)
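+            # propose fewer merges, since each rejected proposal may also trigger a costly global parameter update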
+ for _ in range(n_merges):
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ parent = self.proposed_tree.get_node(u, key=subkey, uniform=True, include_leaves=False) # get non-leaf node, without accounting for weights
+ tssb = parent['node'].tssb
+ # Choose either couple of children to merge or a child to merge with parent
+ n_children = len(parent['children'])
+ if n_children == 0:
+ continue
+ if n_children > 1:
+                # Randomly choose between the two merge types
+                if jax.random.bernoulli(subkey, 0.5) == 1:
+                    # Merge sibling-sibling: choose a random pair of siblings
+                    source_idx, target_idx = jax.random.choice(subkey, n_children, shape=(2,), replace=False)
+ source = parent['children'][source_idx]
+ target = parent['children'][target_idx]
+ else:
+ # Merge parent-child
+                    # Choose a random child
+ source_idx = jax.random.choice(subkey, n_children)
+ source = parent['children'][source_idx]
+ target = parent
+ else:
+ # Merge parent-child
+ source = parent['children'][0]
+ target = parent
+
+ source_label = source['node'].label
+ target_label = target['node'].label
+ # print(f"Will merge {source_label} to {target_label}")
+ # Merge, updating suff stats
+ tssb.merge_nodes(parent, source, target)
+ # Update node sticks
+ tssb.update_stick_params(parent)
+ # Update pivot probs
+ tssb.update_pivot_probs()
+ # Compute ELBO of new tree
+ self.proposed_tree.compute_elbo(memoized=memoized)
+ # print(f"{self.tree.elbo} -> {self.proposed_tree.elbo}")
+ # Update if ELBO improved
+ if self.proposed_tree.elbo > self.tree.elbo:
+ # print(f"Merged {source_label} to {target_label}")
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ # Maybe update other locals
+ if update_globals:
+ # print("Inference")
+ # print(self.tree.elbo)
+ self.proposed_tree.learn_params(n_epochs, update_globals=True, memoized=memoized, **learn_kwargs)
+ self.proposed_tree.compute_elbo(memoized=memoized)
+ # print(self.proposed_tree.elbo)
+ if self.proposed_tree.elbo > self.tree.elbo:
+ self.tree = deepcopy(self.proposed_tree)
+ else:
+ self.proposed_tree = deepcopy(self.tree)
+ else:
+ self.proposed_tree = deepcopy(self.tree)
+
+
+    def prune_reattach(self, key, moves_per_tssb=1):
+ """
+ Prune subtree and reattach somewhere else within the same TSSB
+ """
+ n_prs = self.proposed_tree.n_nodes * moves_per_tssb
+ for _ in range(n_prs):
+ key, subkey = jax.random.split(key)
+ u = jax.random.uniform(subkey)
+ parent, source, target = self.proposed_tree.get_nodes(u, n_nodes=2) # find two nodes in the same TSSB
+ tssb = parent['node'].tssb
+ tssb.prune_reattach(parent, source, target)
+ # Update node stick parameters to account for changed mass distribution
+ tssb.update_stick_params(parent)
+ tssb.update_stick_params(target)
+
+ # Optimize kernel parameters of root of moved subtree
+ tssb.update_node_params(source)
+
+ # Compute ELBO of new tree
+ self.proposed_tree.compute_elbo()
+ # Update if ELBO improved
+ if self.proposed_tree.elbo > self.tree.elbo:
+ self.tree = self.proposed_tree
+
def plot_traces(
self,
@@ -112,2251 +317,3 @@ def plot_traces(
color="black",
)
plt.show()
-
- def run_search(
- self,
- n_iters=500,
- n_iters_elbo=20,
- n_iters_elbo_init=20,
- factor_delay=0,
- posterior_delay=0,
- global_delay=0,
- joint_init=True,
- thin=10,
- local=False,
- mc_samples=1,
- lr=0.001,
- verbosity=logging.INFO,
- tol=1e-6,
- mb_size=256,
- max_nodes=5,
- debug=False,
- callback=None,
- alpha=0.0,
- Tmax=10,
- anneal=False,
- restart_step=10,
- move_weights=None,
- weighted=True,
- merge_n_tries=5,
- opt=adam,
- search_callback=None,
- add_rule="accept",
- add_rule_thres=1.0,
- seed=1,
- rescore_best=False,
- **callback_kwargs,
- ):
-
- logger.setLevel(verbosity)
-
- np.random.seed(seed)
-
- elbos = []
- gamma = 1.0
-
- if move_weights is None:
- move_weights = MOVE_WEIGHTS
-
- self.tree.max_nodes = (
- len(self.tree.input_tree_dict.keys()) * max_nodes
- ) # upper bound on number of nodes
-
- mb_size = min(mb_size, self.tree.data.shape[0])
-
- score_type = "elbo"
- # if posterior_delay > 0:
- # score_type = 'll'
-
- main_lr = lr
- T = Tmax
- if not anneal:
- T = 1.0
-
- if not local and global_delay > 0:
- local = True
-
- n_factors = self.tree.root["node"].root["node"].num_global_noise_factors
-
- transfer_factor_weight = 0.0
- if "transfer_factor" in move_weights:
- if n_factors == 0:
- move_weights["transfer_factor"] = 0.0
- move_weights["transfer_unobserved"] = 0.0
- transfer_factor_weight = move_weights["transfer_factor"]
-
- init_baseline = np.mean(self.tree.data, axis=0)
- init_baseline = init_baseline / np.median(
- self.tree.input_tree.adata.X / 2, axis=0
- )
- init_baseline = init_baseline / np.std(init_baseline)
- init_baseline = init_baseline / init_baseline[0]
- init_log_baseline = np.log(init_baseline[1:] + 1e-6)
- init_log_baseline = np.clip(init_log_baseline, -2, 2)
-
- if len(self.traces["score"]) == 0:
- if n_factors > 0 and factor_delay > 0:
- self.tree.root["node"].root["node"].num_global_noise_factors = 0
- if "transfer_factor" in move_weights:
- move_weights["transfer_factor"] = 0.0
- move_weights["transfer_unobserved"] = 0.0
-
- # Compute score of initial tree -- should we really optimize the baseline to the max before doing it with the unobs factors?
- self.tree.reset_variational_parameters()
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "log_baseline_mean"
- ] = init_log_baseline
- self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- restricted=False,
- update_all=True,
- sticks_only=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo_init,
- n_local_traverses=5,
- lr=lr,
- mb_size=mb_size,
- )
-
- # full update -- maybe without globals?
- root_node_init = self.tree.root["node"].root["node"]
- update_all = False
- if joint_init:
- update_all = True
- root_node_init = None
- self.tree.root["node"].root["node"].reset_variational_parameters(
- means=False
- )
-
- self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- restricted=False,
- update_all=update_all,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- n_local_traverses=5,
- lr=lr,
- mb_size=mb_size,
- )
-
- self.tree.plot_tree(super_only=False)
- self.tree.update_ass_logits(variational=True)
- self.tree.assign_to_best()
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- else:
- self.tree.root = deepcopy(self.best_tree.root)
- self.best_elbo = self.best_tree.elbo
- self.tree.elbo = self.best_tree.elbo
- gamma = self.traces["gamma"][-1]
-
- moves = list(move_weights.keys())
- moves_weights = list(move_weights.values())
-
- logger.debug(
- f"Will search for the maximum marginal likelihood tree with the following moves: {moves}\n"
- )
-
- init_score = self.tree.elbo if score_type == "elbo" else self.tree.ll
- init_root = deepcopy(self.tree.root)
- move_id = "full"
- n_merge = 0
- # Search tree
- p = np.array(moves_weights)
- p = p / np.sum(p)
- for i in tqdm(range(n_iters)):
-
- try:
- if np.mod(i, 5) == 0:
- nodes, mixture = self.tree.get_node_mixture()
- n_nodes = len(nodes)
- if (
- np.sum(mixture < 0.1 * 1.0 / n_nodes) > np.ceil(n_nodes / 3)
- or n_nodes > self.tree.max_nodes * add_rule_thres
- ):
- # Reduce probability of adding, and only add if score improves
- # p = np.array(move_weights)
- # p[np.where(np.array(moves)=='add')[0][0]] = 0.25 * 1/len(moves)
- add_rule = "improve"
- else:
- # Keep uniform move probability and always accept adds
- p = np.array(moves_weights)
- add_rule = "accept"
- except:
- pass
-
- p = p / np.sum(p)
- # if move_id == 'add':
- # move_id = 'merge' # always try to merge after adding # not -- must give a chance for the noise factors to be updated too
- # else:
- move_id = np.random.choice(moves, p=p)
-
- if (
- i == factor_delay
- and n_factors > 0
- and self.tree.root["node"].root["node"].num_global_noise_factors == 0
- ):
- self.tree = deepcopy(self.best_tree)
- self.tree.root["node"].root["node"].num_global_noise_factors = n_factors
- self.tree.root["node"].root["node"].init_noise_factors()
- move_weights["transfer_factor"] = transfer_factor_weight
- move_weights["transfer_unobserved"] = transfer_factor_weight
- moves = list(move_weights.keys())
- moves_weights = list(move_weights.values())
- move_id = "full"
-
- if lr < main_lr:
- move_id = "reset_globals"
-
- # nits_check = int(np.max([50, .1*n_iters]))
- # if i > 50 and score_type == 'elbo' and np.sum(np.array(self.traces['accepted'][-nits_check:]) == False) == nits_check:
- # logger.debug(f"No moves accepted in {nits_check} iterations. Using ll for 10 iterations.")
- # posterior_delay = i + 10
- #
- # if i < posterior_delay:
- # score_type = 'll'
- # elif i == posterior_delay:
- # # Go back to best in terms of ELBO
- # # self.tree.root = deepcopy(self.best_tree.root)
- # # self.tree.elbo = self.best_elbo
- # score_type = 'elbo'
-
- if global_delay > 0 and i > global_delay:
- local = False
-
- init_root = deepcopy(self.tree.root)
- init_elbo = self.tree.elbo
- init_score = self.tree.elbo if score_type == "elbo" else self.tree.ll
- success = True
-
- # nodes = self.tree.get_nodes()
- # nodes = self.tree._get_nodes(get_roots=True)
- nodes = self.tree.get_node_roots()
- self.tree.n_nodes = len(nodes)
- start = time()
- if move_id == "add" and self.tree.n_nodes < self.tree.max_nodes - 1:
- success, elbos = self.add_node(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- mb_size=mb_size,
- max_nodes=max_nodes,
- debug=debug,
- opt=opt,
- weighted=weighted,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "merge":
- success, elbos = self.merge_nodes(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- mb_size=mb_size,
- max_nodes=max_nodes,
- debug=debug,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "prune_reattach":
- success, elbos = self.prune_reattach(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- mb_size=mb_size,
- max_nodes=max_nodes,
- debug=debug,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "pivot_reattach":
- success, elbos = self.pivot_reattach(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- mb_size=mb_size,
- max_nodes=max_nodes,
- debug=debug,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif (
- move_id == "add_reattach_pivot"
- and self.tree.n_nodes < self.tree.max_nodes - 1
- ):
- success, elbos = self.add_reattach_pivot(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- mb_size=mb_size,
- max_nodes=max_nodes,
- debug=debug,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "swap":
- success, elbos = self.swap_nodes(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "subtree_reattach":
- success, elbos = self.subtree_reattach(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "subtree_pivot_reattach":
- success, elbos = self.subtree_pivot_reattach(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif (
- move_id == "push_subtree"
- and self.tree.n_nodes < self.tree.max_nodes - 1
- ):
- success, elbos = self.push_subtree(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif (
- move_id == "extract_pivot"
- and self.tree.n_nodes < self.tree.max_nodes - 1
- ):
- success, elbos = self.extract_pivot(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "optimize_node":
- # Randomly choose a node
- node_root = np.random.choice(nodes[1:])
- node = node_root["node"]
- logger.debug(f"Optimizing {node.label}...")
- elbos = self.tree.optimize_elbo(
- root=None,
- sub_root=node_root,
- restricted=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
- elif move_id == "perturb_node":
- # Randomly choose a node
- node_root = np.random.choice(nodes[1:])
- node = node_root["node"]
- logger.debug(f"Perturbing {node.label}...")
- # Perturb a bit and avoid clash between kernel and effect
- perturbation = np.random.normal(0, 0.5, size=init_log_baseline.size + 1)
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] += np.abs(perturbation)
- node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ] += perturbation
- # self.tree.root['node'].root['node'].variational_parameters['globals']['log_baseline_log_std'] += 2. # increase std
- elbos = self.tree.optimize_elbo(
- root=None,
- sub_root=node_root,
- restricted=False,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
- # init_root, init_elbo, success, elbos = self.perturb_node(local=local, mc_samples=mc_samples, n_iters=n_iters_elbo, thin=thin, lr=lr, tol=tol, debug=debug, mb_size=mb_size, max_nodes=max_nodes, opt=opt, callback=callback, **callback_kwargs)
- elif move_id == "clean_node":
- success, elbos = self.clean_node(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "transfer_factor":
- success, elbos = self.transfer_factor(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
-
- elif move_id == "transfer_unobserved":
- success, elbos = self.transfer_unobserved(
- local=local,
- mc_samples=mc_samples,
- n_iters=n_iters_elbo,
- thin=thin,
- lr=lr,
- tol=tol,
- debug=debug,
- mb_size=mb_size,
- max_nodes=max_nodes,
- opt=opt,
- callback=callback,
- **callback_kwargs,
- )
- elif move_id == "globals":
- elbos = self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- restricted=False,
- globals_only=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
- elif move_id == "perturb_globals":
- logger.debug("Perturbing globals...")
- # Perturb a bit and optimize all
- perturbation = np.random.normal(0, 0.5, size=init_log_baseline.size)
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "log_baseline_mean"
- ] += perturbation
- perturbation = np.random.normal(
- 0,
- 0.1,
- size=(
- self.tree.root["node"].root["node"].num_global_noise_factors,
- init_log_baseline.size + 1,
- ),
- )
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "noise_factors_mean"
- ] += perturbation
- self.tree.root["node"].root["node"].variational_parameters["locals"][
- "unobserved_factors_mean"
- ] *= 0
- elbos = self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- restricted=False,
- update_all=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
- elif move_id == "full":
- logger.debug(f"Full update...")
- elbos = self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- restricted=False,
- update_all=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
- self.traces["times"].append(time() - start)
-
- if np.isnan(self.tree.elbo):
- logger.debug("Got NaN!")
- self.tree.root = deepcopy(init_root)
- self.tree.elbo = init_elbo
- logger.debug(
- "Proceeding with previous tree, reducing step size and doing `reset_globals`."
- )
- lr = lr * 0.1
- if lr < 1e-4:
- logger.debug(
- "Step size is becoming small. Fetching best tree with noise factors."
- )
- self.tree.root = deepcopy(self.best_tree.root)
- self.tree.elbo = self.best_elbo
- if (
- self.tree.root["node"].root["node"].num_global_noise_factors
- == 0
- and n_factors > 0
- and i > factor_delay
- ):
- self.tree.root["node"].root[
- "node"
- ].num_global_noise_factors = n_factors
- self.tree.root["node"].root["node"].init_noise_factors()
- if lr < 1e-6:
- raise ValueError("Step size became too small due to too many NaNs!")
- # self.init_optimizer(lr=lr)
- continue
- else:
- if lr != main_lr:
- lr = main_lr
- # self.init_optimizer(lr=lr)
-
- # if anneal:
- # if i/thin >= 0 and np.mod(i, thin) == 0:
- # idx = int(i/thin)
- # if Tmax != 1:
- # T = Tmax - alpha*idx
- # T = T * (1 + (self.tree.elbo - self.best_elbo)/self.tree.elbo)
-
- new_score = self.tree.elbo if score_type == "elbo" else self.tree.ll
-
- accepted = True
-
- if move_id == "full" or success == False:
- accepted = False
-
- logger.debug(f"ELBO change: {init_elbo} -> {self.tree.elbo}")
- if self.tree.n_nodes >= self.tree.max_nodes:
- self.tree.root = deepcopy(init_root)
- self.tree.elbo = init_elbo
- accepted = False
- elif move_id == "add" or move_id == "add_reattach_pivot":
- if (
- add_rule == "accept" and score_type == "elbo"
-                ):  # only accept immediately if using the ELBO to score
- logger.debug(
- f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*"
- )
- if self.tree.elbo > self.best_elbo:
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- logger.debug(f"New best! {self.best_elbo}")
- else:
- rate = -(init_score - new_score) / T
- rate = rate * gamma
- if rate < np.log(np.random.rand()):
- self.tree.root = deepcopy(init_root)
- self.tree.elbo = init_elbo
- accepted = False
- gamma = gamma * np.exp((0.0 - alpha) * alpha)
- else:
- logger.debug(
- f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*"
- )
- gamma = gamma * np.exp((1.0 - alpha) * alpha)
- if self.tree.elbo > self.best_elbo:
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- logger.debug(f"New best! {self.best_elbo}")
- elif move_id != "full" and success == True:
- rate = -(init_score - new_score) / T
- rate = rate * gamma
- rate_thres = np.log(np.random.rand())
- logger.debug(f"{rate}, {rate_thres}")
- if rate <= rate_thres:
- self.tree.root = deepcopy(init_root)
- self.tree.elbo = init_elbo
- accepted = False
- gamma = gamma * np.exp((0.0 - alpha) * alpha)
- else: # Accepted
- logger.debug(
- f"*Move ({move_id}) accepted. ({init_elbo} -> {self.tree.elbo})*"
- )
- gamma = gamma * np.exp((1.0 - alpha) * alpha)
- if self.tree.elbo > self.best_elbo:
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- logger.debug(f"New best! {self.best_elbo}")
-
- if self.tree.elbo > self.best_elbo:
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- logger.debug(f"New best! {self.best_elbo}")
-
- if i == factor_delay and factor_delay > 0 and n_factors > 0:
- logger.debug(
-                    "Setting the current tree, now with the complete number of factors, as the best."
- )
- self.best_elbo = self.tree.elbo
- self.best_tree = deepcopy(self.tree)
- logger.debug(f"New best! {self.best_elbo}")
-
- gamma = np.max([gamma, 1e-5])
- score = self.tree.elbo if score_type == "elbo" else self.tree.ll
- if rescore_best:
- self.traces["tree"].append(deepcopy(self.tree))
- else:
- self.traces["tree"].append(self.tree.plot_tree(counts=True))
- self.traces["elbo"].append(self.tree.elbo)
- self.traces["score"].append(score)
- self.traces["move"].append(move_id)
- self.traces["n_nodes"].append(self.tree.n_nodes)
- self.traces["temperature"].append(T)
- self.traces["accepted"].append(accepted)
-
- self.traces["gamma"].append(gamma)
- self.traces["elbos"].append(elbos)
-
- if search_callback is not None:
- search_callback(self)
-
- if anneal:
- if i / restart_step > 0 and np.mod(i, restart_step) == 0:
- self.tree.root = deepcopy(self.best_tree.root)
- self.tree.elbo = self.best_elbo
-
- if T == 0:
- break
-
- if rescore_best:
- logger.debug("Re-scoring top 10% trees")
- logger.debug(f"Current best one is tree {np.argmax(self.traces['elbo'])}")
-            # Get the top 10% unique-scoring trees
- elbos, indices = np.unique(self.traces["elbo"], return_index=True)
- top_scoring = indices[np.where(elbos > np.quantile(elbos, q=0.90))[0]]
- logger.debug(
- f"Re-scoring top 10% trees: got {len(top_scoring)} trees to score"
- )
- # Re-score them
- new_elbos = []
- for tree_idx in top_scoring:
- elbos = self.traces["tree"][tree_idx].optimize_elbo(
- local_node=None,
- root_node=None,
- mc_samples=5,
- n_iters=10000,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=True,
- debug=False,
- opt=opt,
- opt_triplet=None,
- callback=callback,
- **callback_kwargs,
- )
- logger.debug(
- f"Tree {tree_idx}: {self.traces['elbo'][tree_idx]} -> {self.traces['tree'][tree_idx].elbo}"
- )
- new_elbos.append(self.traces["tree"][tree_idx].elbo)
- # Get new best
- new_best_idx = top_scoring[np.argmax(new_elbos)]
- logger.debug(f"New best is tree {new_best_idx}")
- new_best = self.traces["tree"][new_best_idx]
- self.best_tree = deepcopy(new_best)
- self.best_elbo = new_best.elbo
-
- self.tree.plot_tree(super_only=False)
- self.best_tree.plot_tree(super_only=False)
- self.best_tree.update_ass_logits(variational=True)
- self.best_tree.assign_to_best()
- return self.best_tree
-
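The search loop above scores each proposed move with an annealed, adaptively scaled acceptance rule: the score improvement is divided by the temperature T, multiplied by an adaptation factor gamma, compared against the log of a uniform draw, and gamma is then nudged according to the target acceptance parameter alpha. A minimal standalone sketch of that rule, under those assumptions (the `accept_move` helper and its signature are illustrative, not part of the package):

import numpy as np

def accept_move(init_score, new_score, T=1.0, gamma=1.0, alpha=0.0, rng=np.random):
    """Annealed acceptance test applied after each tree move (illustrative sketch)."""
    rate = -(init_score - new_score) / T   # score improvement scaled by temperature
    rate = rate * gamma                    # adaptive scaling factor
    accepted = rate > np.log(rng.rand())
    # Nudge gamma up on acceptance and down on rejection, towards the target rate alpha
    gamma = gamma * np.exp(((1.0 if accepted else 0.0) - alpha) * alpha)
    return accepted, max(gamma, 1e-5)      # gamma is floored as in the loop above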
- def add_node(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- weighted=True,
- callback=None,
- **callback_kwargs,
- ):
- success = True
- elbos = []
-
- nodes, target_probs = self.tree.get_node_data_sizes(normalized=True)
- if not weighted:
- target_probs = np.array([1.0] * len(nodes))
- target_probs /= np.sum(target_probs)
- node_idx = np.random.choice(range(len(nodes)), p=np.array(target_probs))
- node = nodes[node_idx]
-
- logger.debug(f"Trying to add node below {node.label}")
-
-        # Decide whether to initialize from a factor
- from_factor = False
- factor_idx = None
- if self.tree.root["node"].root["node"].num_global_noise_factors > 0:
- if len(nodes) < len(self.tree.input_tree_dict.keys()) * 2:
- p = 0.8
- else:
- p = 0.5
- from_factor = bool(np.random.binomial(1, p))
-
- if from_factor:
- logger.debug(f"Initializing new node from noise factor")
-            # Choose the factor that the cells in the node use the most
- cells_in_node = np.where(np.array(self.tree.assignments) == node)
- factor_idx = np.argmax(
- np.mean(
- np.abs(
- self.tree.root["node"]
- .root["node"]
- .variational_parameters["globals"]["cell_noise_mean"][
- cells_in_node
- ]
- ),
- axis=0,
- )
- )
-
- # new_node = self.tree.add_node_to(node, optimal_init=True, factor_idx=factor_idx)
- parent_root = self.tree.add_node_to(node, optimal_init=True, factor_idx=factor_idx)
- new_node = parent_root["children"][-1]["node"]
-
- root = self.tree.root
- sub_root = None
- restricted = False
- go_down = False
- if local:
- root = None
- sub_root = parent_root
- restricted = True
-
- children = np.array(list(node.children()))
- idx = np.argsort([n.label for n in children])
- children = children[idx]
- unobs_children = [
- child for child in children if not child.is_observed and child != new_node
- ]
- if len(unobs_children) > 0:
- pr = np.random.binomial(1, 0.5)
- if pr:
- child = np.random.choice(unobs_children)
- logger.debug(f"Also attaching {child.label} to new node")
- self.tree.prune_reattach(child, new_node)
- go_down = True
-
- self.tree.plot_tree()
- # Ensure constant node order
- if from_factor:
- # Remove factor
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "noise_factors_mean"
- ][factor_idx] *= 0.0
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "cell_noise_mean"
- ][:, factor_idx] = 0.0
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=None,
- update_all=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,# * 5,
- lr=lr,
- mb_size=mb_size,
- )
- else:
-            if node.parent() is None: # if root, need to also update the global factors
- elbos = self.tree.optimize_elbo(
- root=self.tree.root,
- sub_root=None,
- update_all=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,# * 5,
- lr=lr,
- mb_size=mb_size,
- )
- else:
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- go_down=go_down,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
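add_node first chooses where to grow the tree and, optionally, which global noise factor to seed the new node from: the target node is drawn in proportion to how much data it currently explains, and the seeding factor is the one with the largest mean absolute weight among the cells assigned to that node. A compact sketch of that selection step, assuming the same cells-by-factors layout for cell_noise_mean; the helper name is hypothetical:

import numpy as np

def choose_add_target(nodes, target_probs, cell_noise_mean, assignments,
                      p_from_factor=0.5, rng=np.random):
    """Pick the node to grow from and, possibly, a noise factor to initialize from."""
    node = nodes[rng.choice(len(nodes), p=target_probs)]     # bias towards data-heavy nodes
    factor_idx = None
    if cell_noise_mean.shape[1] > 0 and rng.binomial(1, p_from_factor):
        cells = np.where(np.array(assignments) == node)[0]   # cells currently assigned here
        factor_idx = int(np.argmax(np.mean(np.abs(cell_noise_mean[cells]), axis=0)))
    return node, factor_idx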
- def merge_nodes(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
-        # merge_pairs = self.tree.get_merge_pairs()
- # for pair in merge_pairs:
- success = False
- elbos = []
- # self.tree.merge_nodes(pair[0], pair[1])
- # self.tree.optimize_elbo(unique_node=None, root_node=pair[1], run=True, mc_samples=mc_samples, n_iters=n_iters, thin=thin, tol=tol, lr=lr, mb_size=mb_size, max_nodes=max_nodes, init=False, debug=debug, opt=opt, opt_triplet=self.opt_triplet, callback=callback, **callback_kwargs)
-
- # Choose a subtree
- # _, subtrees = self.tree.get_mixture()
- subtrees = self.tree.get_tree_roots()
-
- n_nodes = []
- nodes = []
- for subtree in subtrees:
- nodes.append(subtree["node"].get_node_roots())
- n_nodes.append(len(nodes[-1]))
- n_nodes = np.array(n_nodes)
- # Only proceed if there is at least one subtree with mergeable nodes
- if np.any(n_nodes > 1):
- success = True
- probs = n_nodes - 1
- probs = probs / np.sum(probs)
-
- # Choose a subtree with more than 1 node
- idx = np.random.choice(range(len(subtrees)), p=probs)
- subtree = subtrees[idx]
- nodes = nodes[idx]
-
- # Choose a first node A (which can't be the root), biased by size
- inv_sizes = np.array([1/(len(n["node"].data) + 1.) for n in nodes[1:]])
- node_idx = np.random.choice(
- range(len(nodes[1:])), p=inv_sizes/np.sum(inv_sizes)
- )
- nodeA_root = nodes[1:][node_idx]
- nodeA = nodeA_root["node"]
-
- # Get parent and siblings in the same subtree
- parent = nodeA.parent()
- nodes = np.array(list(parent.children()))
- idx = np.argsort([n.label for n in nodes])
- nodes = nodes[idx]
- nodes = [s for s in nodes if s != nodeA and nodeA.tssb == s.tssb]
- nodes.append(parent)
-
-            # If nodeA is a pivot node, it's also possible to merge its child subtree's root with it
- n_pivots = 0
- nodeA_children = np.array(list(nodeA.children()))
- idx = np.argsort([n.label for n in nodeA_children])
- nodeA_children = nodeA_children[idx]
- for nodeA_child in nodeA_children:
- if nodeA_child.tssb != nodeA.tssb:
- n_pivots += 1
- nodes.append(nodeA_child)
-
- sims = [
- 1.0 / (np.mean(np.abs(nodeA.node_mean - node.node_mean)) + 1e-8)
- for node in nodes
- ]
-
- # Choose nodeB proportionally to similarities
- nodeB = np.random.choice(nodes, p=sims / np.sum(sims))
- nodeB_root = nodeB.get_tssb_root()
-
- # Choose initialization
- optimal_params = True
- if nodeB.parent() is None:
- optimal_params = bool(np.random.choice(2, p=[0.4, 0.6]))
-
- local_node = None
- restricted = False
- root = self.tree.root
- sub_root = None
- update_all = True
-            # If a pivot child was chosen, merge that subtree's root into nodeA
- if nodeB.tssb != nodeA.tssb:
- logger.debug(f"Trying to merge {nodeB.label} to {nodeA.label}...")
- self.tree.merge_nodes(nodeB, nodeA, optimal_params=optimal_params)
- if local:
- local_node = nodeB
- root = None
- restricted = True
- sub_root = nodeB_root
- update_all = False
- else:
- logger.debug(f"Trying to merge {nodeA.label} to {nodeB.label}...")
- self.tree.merge_nodes(nodeA, nodeB, optimal_params=optimal_params)
- if local:
- local_node = nodeB.parent()
- root = None
- restricted = True
- sub_root = nodeB_root
- update_all = False
-
- # Account for merging to baseline
- update_all_n_iters = 40
- if nodeB.parent() is None:
- if optimal_params:
- logger.debug(
- f"Global update since the baseline has been changed..."
- )
- local_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- # n_iters = np.max([n_iters, 50])
-
- self.tree.plot_tree()
- # Ensure constant node order
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- update_all=update_all,
- restricted=restricted,
- run=True,
- n_iters=update_all_n_iters,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
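Several of the moves above (merge, prune-reattach, subtree reattach) pick the partner node in proportion to how similar its mean is to the first node's mean, using the inverse mean absolute difference. A small sketch of that weighting, assuming the same `node_mean` attribute; `sample_similar_node` is a made-up name:

import numpy as np

def sample_similar_node(nodeA, candidates, rng=np.random):
    """Draw a partner node with probability proportional to its similarity to nodeA."""
    sims = np.array([
        1.0 / (np.mean(np.abs(nodeA.node_mean - n.node_mean)) + 1e-8)
        for n in candidates
    ])
    return candidates[rng.choice(len(candidates), p=sims / sims.sum())]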
- def prune_reattach(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = False
- elbos = []
-
- # Choose a subtree
- # _, subtrees = self.tree.get_mixture()
- subtrees = self.tree.get_tree_roots()
-
- n_nodes = []
- nodes = []
- for subtree in subtrees:
- nodes.append(subtree["node"].get_mixture()[1])
- n_nodes.append(len(nodes[-1]))
- n_nodes = np.array(n_nodes)
- # Only proceed if there is at least one subtree with pruneable nodes -- i.e., it needs to have at least 2 extra nodes
- if np.any(n_nodes > 2):
- success = True
- probs = np.array(n_nodes)
- probs[probs <= 2] = 0.0
- probs = probs / np.sum(probs)
-
- # Choose a subtree with more than 1 node
- idx = np.random.choice(range(len(subtrees)), p=probs)
- subtree = subtrees[idx]
- nodes = nodes[idx]
-
- # Get the nodes which can be reattached, use labels
- self.tree.plot_tree()
- possible_nodes = dict()
- node_label_dict = dict()
- for node in nodes:
- possible_nodes[node.label] = [
- n.label
- for n in nodes
- if node.label not in n.label and n != node.parent()
- ]
- node_label_dict[node.label] = node
-
- # Choose a first node
- possible_nodesA = [
- node for node in possible_nodes if len(possible_nodes[node]) > 0
- ]
- nodeA_label = np.random.choice(
- possible_nodesA, p=[1.0 / len(possible_nodesA)] * len(possible_nodesA)
- )
- nodeA = node_label_dict[nodeA_label]
-
- # Get nodes not below node A: use labels
- sims = [
- 1.0
- / (
- np.mean(
- np.abs(nodeA.node_mean - node_label_dict[node_label].node_mean)
- )
- + 1e-8
- )
- for node_label in possible_nodes[nodeA_label]
- ]
-
- # Choose nodeB proportionally to similarities
- nodeB_label = np.random.choice(
- possible_nodes[nodeA_label], p=sims / np.sum(sims)
- )
- nodeB = node_label_dict[nodeB_label]
-
- logger.debug(f"Trying to reattach {nodeA_label} to {nodeB_label}...")
-
- self.tree.prune_reattach(nodeA, nodeB)
- self.tree.plot_tree()
- # Ensure constant node order
- local_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- if local:
- root = None
- local_node = nodeA
- sub_root = subtree["node"].root
- restricted = True
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def pivot_reattach(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = False
- elbos = []
-
- # Uniformly pick a subtree
- subtrees = self.tree.get_mixture()[1][1:] # without the root
- subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees))
- init_pivot_node = subtree.root["node"].parent()
- init_pivot = init_pivot_node.label
- init_pivot_node_parent = init_pivot_node.parent()
-
- # Choose a pivot node from the parent subtree that isn't the current one
- weights, nodes = init_pivot_node.tssb.get_fixed_weights()
- # Only proceed if parent subtree has more than 1 node
- if len(nodes) > 1:
- success = True
- # weights = [weight for i, weight in enumerate(weights) if nodes[i] != init_pivot_node]
- # weights = np.array(weights) / np.sum(weights)
- # Also use the similarity of the parent subtree's nodes' unobserved factors with the subtree root
- sims = [
- 1.0
- / (
- np.mean(
- np.abs(
- subtree.root["node"].variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- - node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- )
- )
- + 1e-8
- )
- for node in nodes
- ]
- log_weights = [
- np.log(weights[i]) + np.log(sims[i])
- for i, node in enumerate(nodes)
- if node != init_pivot_node
- ]
- weights = np.exp(np.array(log_weights))
- weights = weights / np.sum(weights)
- nodes = [
- node for node in nodes if node != init_pivot_node
- ] # remove the current pivot
- node_idx = np.random.choice(range(len(nodes)), p=weights)
- node = nodes[node_idx]
-
- # Update pivot
- logger.debug(f"Trying to set {node.label} as pivot of {subtree.label}")
-
- self.tree.pivot_reattach_to(subtree, node)
-
- removed_pivot = False
- if (
- len(init_pivot_node.data) == 0
- and init_pivot_node_parent is not None
- and len(init_pivot_node.children()) == 0
- ):
- logger.debug(
- f"Also removing initial pivot ({init_pivot_node.label}) from tree"
- )
- self.tree.merge_nodes(init_pivot_node, init_pivot_node_parent)
- # self.tree.optimize_elbo(sticks_only=True, root_node=init_pivot_node_parent, mc_samples=mc_samples, n_iters=n_iters, thin=thin, tol=tol, lr=lr, mb_size=mb_size, max_nodes=max_nodes, init=False, debug=debug, opt=opt, opt_triplet=self.opt_triplet, callback=callback, **callback_kwargs)
- removed_pivot = True
-
- self.tree.plot_tree()
- # Ensure constant node order
- root = self.tree.root
- sub_root = None
- restricted = False
- if local:
- root = None
- sub_root = subtree.root
- restricted = True
- # if root_node.parent() is not None:
- # root_node = root_node.parent() # more robust
- # if removed_pivot:
- # root_node = init_pivot_node_parent
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def add_reattach_pivot(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = True
- elbos = []
-
- # Add a node below a subtree with children subtrees
- subtrees = self.tree.get_subtrees(get_roots=True)
- nonleaf_subtrees = [
- subtree for subtree in subtrees if len(subtree[1]["children"]) > 0
- ]
-
- # Don't waste too much time with subtrees which are not expected to have any complexity
- subtree_weights = [subtree[0].weight for subtree in nonleaf_subtrees]
-
- # If every subtree has a parent with no weight, don't proceed
- if np.sum(subtree_weights) < 1e-10:
- return False, elbos
-
- subtree_weights = np.array(subtree_weights) / np.sum(subtree_weights)
-
- # Pick a subtree
- parent_subtree = nonleaf_subtrees[
- np.random.choice(
- len(nonleaf_subtrees),
- p=subtree_weights,
- )
- ]
-
- # Pick a node in the parent subtree
- nodes, target_probs = self.tree.get_node_data_sizes(normalized=True)
- target_probs = [
- prob + 1e-8
- for i, prob in enumerate(target_probs)
- if nodes[i].tssb == parent_subtree[0]
- ]
- nodes = [node for node in nodes if node.tssb == parent_subtree[0]]
- target_probs /= np.sum(target_probs)
- node = np.random.choice(nodes, p=np.array(target_probs))
- pivot_node_parent_root = self.tree.add_node_to(node, optimal_init=True, return_parent_root=True)
-
- # Pick one of the children subtrees
- subtrees = [subtree for subtree in parent_subtree[1]["children"]]
- subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees))
- init_pivot = subtree["node"].root["node"].parent()
-
- # Update pivot
- self.tree.pivot_reattach_to(subtree["node"], pivot_node_parent_root["children"][-1]["node"])
-
- logger.debug(
- f"Trying to add node {pivot_node_parent_root['children'][-1]['node'].label} and setting it as pivot of {subtree['node'].label}"
- )
-
- self.tree.plot_tree()
- # Ensure constant node order
- root_node = pivot_node_parent_root
- n_iters_elbo = n_iters
- root = self.tree.root
- sub_root = None
- restricted = False
- if pivot_node_parent_root["children"][-1]["node"].parent().parent() is None:
- n_iters_elbo = n_iters #* 5
- if local:
- root_node = pivot_node_parent_root
- root = None
- sub_root = pivot_node_parent_root
- restricted = True
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- mc_samples=mc_samples,
- go_down=True,
- n_inner_steps=n_iters_elbo,
- lr=lr,
- mb_size=mb_size,
- )
-
-
- return success, elbos
-
- def push_subtree(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = True
- elbos = []
-
- # Uniformly pick a subtree
- subtrees = self.tree.get_tree_roots()[1:] # without the root
- subtree = np.random.choice(subtrees, p=[1.0 / len(subtrees)] * len(subtrees))
-
- # Push subtree down
- new_pivot_root = self.tree.push_subtree(subtree["node"].root["node"])
-
- logger.debug(f"Trying to push {subtree['node'].label} down")
-
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- if local:
- root = None
- sub_root = new_pivot_root
- restricted = True
- go_down=True
- # root_node = subtree.root["node"].parent()
- self.tree.plot_tree()
- # Ensure constant node order
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- go_down=go_down,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def extract_pivot(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = True
- elbos = []
-
- # Pick a subtree biased by the weight of the parent TSSB
- subtrees = self.tree.get_mixture()[1][1:] # without the root
- parent_subtree_weights = [
- subtree.root["node"].parent().tssb.weight for subtree in subtrees
- ]
-
- # If every subtree has a parent with no weight, don't proceed
- if np.sum(parent_subtree_weights) < 1e-10:
- return False, elbos
-
- parent_subtree_weights = np.array(parent_subtree_weights) / np.sum(
- parent_subtree_weights
- )
- subtree = np.random.choice(subtrees, p=parent_subtree_weights)
-
-        # Extract a new pivot node for the subtree
- new_node = self.tree.extract_pivot(subtree.root["node"])
-
- logger.debug(f"Trying to extract pivot from {subtree.label}")
-
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- if local:
- root_node = new_node
- root = None
- sub_root = new_node
- restricted = True
- self.tree.plot_tree()
- # Ensure constant node order
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- go_down=True,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def subtree_reattach(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- """
- Move a subtree to a different clone
- """
- success = False
- elbos = []
-
- subtrees = self.tree.get_subtrees(get_roots=True)
-
- # Pick a subtree with more than 1 node
- in_subtrees = [
- subtree[1] for subtree in subtrees if len(subtree[0].root["children"]) > 0
- ]
-
- # If there is any subtree with more than 1 node, proceed
- if len(in_subtrees) > 0:
- success = True
- # Choose one subtree
- subtreeA = np.random.choice(
- in_subtrees, p=[1.0 / len(in_subtrees)] * len(in_subtrees)
- )
-
- # Choose one of its nodes uniformly which is not the root
- node_weights, nodes, roots = subtreeA["node"].get_mixture(get_roots=True)
- nodeA_idx = (
- np.random.choice(
- len(roots[1:]), p=[1.0 / len(roots[1:])] * len(roots[1:])
- )
- + 1
- )
- nodeA_parent_idx = np.where(np.array(nodes) == nodes[nodeA_idx].parent())[
- 0
- ][0]
-
- # Choose another subtree that's similar to the subtree's top node
- rem_subtrees = [s[1] for s in subtrees if s[1]["node"] != subtreeA["node"]]
- sims = [
- 1.0
- / (
- np.mean(
- np.abs(
- roots[nodeA_idx]["node"].node_mean
- - s["node"].root["node"].node_mean
- )
- )
- + 1e-8
- )
- for s in rem_subtrees
- ]
- new_subtree = np.random.choice(rem_subtrees, p=sims / np.sum(sims))
-
- logger.debug(
- f"Trying to set {roots[nodeA_idx]['node'].label} below {new_subtree['node'].label}"
- )
-
- # Move subtree
- optimal_init = bool(np.random.binomial(1, 0.5))
- pivot_changed = self.tree.subtree_reattach_to(
- roots[nodeA_idx]["node"], new_subtree["node"], optimal_init=optimal_init
- )
-
- # Also swap to make the moved subtree root be the new root?
- # if len(list(roots[nodeA_idx]["node"].children())) == 0:
- # if np.random.binomial(1, 0.5):
- # self.tree.swap_nodes(
- # roots[nodeA_idx]["node"], new_subtree["node"].root["node"]
- # )
-
- # self.tree.reset_variational_parameters(variances_only=True)
- # init_baseline = jnp.mean(self.tree.data, axis=0)
- # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:]
- # self.tree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=self.tree.data.shape[1]-1)
- self.tree.plot_tree()
- # Ensure constant node order
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- if local:
- root_node = roots[nodeA_idx]["node"]
- root = None
- sub_root = roots[nodeA_idx]
- restricted = True
- update_all = False
- if pivot_changed:
- n_iters = n_iters #* 5
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- update_all=update_all,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def subtree_pivot_reattach(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- """
- Reattach subtree of node A to clone B and use it as pivot of A
- """
- success = False
- elbos = []
-
- subtrees = self.tree.get_subtrees(get_roots=True)
-
- # Pick a non-root subtree with more than 1 node
- in_subtrees = [
- subtree[1]
- for subtree in subtrees[1:]
- if len(subtree[0].root["children"]) > 0
- ]
-
- # If there is any subtree with more than 1 node, proceed
- if len(in_subtrees) > 0:
- success = True
- # Choose one subtree
- subtreeA = np.random.choice(
- in_subtrees, p=[1.0 / len(in_subtrees)] * len(in_subtrees)
- )
-
- # Choose one of its nodes uniformly which is not the root
- node_weights, nodes, roots = subtreeA["node"].get_mixture(get_roots=True)
- nodeA_idx = (
- np.random.choice(
- len(roots[1:]), p=[1.0 / len(roots[1:])] * len(roots[1:])
- )
- + 1
- )
- nodeA_parent_idx = np.where(np.array(nodes) == nodes[nodeA_idx].parent())[
- 0
- ][0]
-
- logger.debug(
- f"Trying to set {roots[nodeA_idx]['node'].label} below {subtreeA['super_parent'].label} and use it as pivot of {subtreeA['node'].label}"
- )
-
- # Move subtree to parent
- self.tree.subtree_reattach_to(
- roots[nodeA_idx]["node"], subtreeA["super_parent"].label
- ) # Use label to avoid bugs with references
-
- # Set root of moved subtree as new pivot. TODO: Choose one node from the leaves instead
- self.tree.pivot_reattach_to(subtreeA["node"], roots[nodeA_idx]["node"])
-
- # And choose a leaf node of that subtree as the pivot of old subtree
- # self.tree.reset_variational_parameters(variances_only=True)
- # init_baseline = jnp.mean(self.tree.data, axis=0)
- # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:]
- # self.tree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=self.tree.data.shape[1]-1)
- self.tree.plot_tree()
- # Ensure constant node order
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- if local:
- root_node = roots[nodeA_idx]["node"]
- root = None
- sub_root = roots[nodeA_idx]
- restricted = True
- update_all = False
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- update_all=update_all,
- go_down=True,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,# * 5,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def swap_nodes(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- success = True
- elbos = []
-
- # Randomly decide whether to update pivots
- update_pivots = np.random.binomial(1, 0.5)
-
- def tssb_swap(tssb, children_trees, ntree, n_iters, update_pivots=True):
- weights, nodes = tssb.get_mixture()
- nodes = tssb.get_node_roots()
-
- empty_root = False
- if len(nodes) > 1:
- nodeA_root, nodeB_root = np.random.choice(nodes, replace=False, size=2)
- nodeA = nodeA_root["node"]
- nodeB = nodeB_root["node"]
- # Root can't be empty in the original TSSB
- if len(nodes[0]["node"].data) == 0 and nodes[0]["node"].tssb.weight > 1e-6 and nodes[0]["node"].parent() is not None:
- logger.debug("Swapping root")
- empty_root = True
- nodeA_root = nodes[0]
- nodeA = nodeA_root["node"]
- nodeB_root = np.random.choice(nodes[1:])
- nodeB = nodeB_root["node"]
-
- logger.debug(f"Trying to swap {nodeA.label} with {nodeB.label}...")
- self.tree.swap_nodes(nodeA, nodeB, update_pivots=update_pivots)
-
- if empty_root:
- # self.tree = deepcopy(ntree)
- logger.debug(f"Swapped {nodeA.label} with {nodeB.label}")
- # Go through all nodes below root and reset their unobserved_factors_kernel_log_std
- self.tree.plot_tree()
- # Ensure constant node order
- ntree.optimize_elbo(
- root = self.tree.root,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,# * 10,
- lr=lr,
- mb_size=mb_size,
- )
- else:
- # ntree = self.compute_expected_score(ntree, n_burnin=n_burnin, n_samples=n_samples, thin=thin, global_params=global_params, compound=compound)
- # ntree.reset_variational_parameters(variances_only=True)
- # init_baseline = jnp.mean(ntree.data, axis=0)
- # init_log_baseline = jnp.log(init_baseline / init_baseline[0])[1:]
- # ntree.root['node'].root['node'].log_baseline_mean = init_log_baseline + np.random.normal(0, .5, size=ntree.data.shape[1]-1)
- root_node = nodes[0]
- root = None
- sub_root = None
- restricted = False
- update_all = False
- go_down = False
- n_all_iters = 40
- if nodeB == nodeA.parent():
- root_node = nodeB_root
- root = None
- sub_root = nodeB_root
- restricted = True
- go_down = True
- update_all = False
- elif nodeA == nodeB.parent():
- root_node = nodeA_root
- root = None
- sub_root = nodeA_root
- restricted = True
- go_down = True
- update_all = False
- if not local:
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- # n_iters *= 10 # Big change, so give time to converge
- if root_node:
- if root_node["node"].parent() is None:
- root_node = None # Update everything!
- # n_iters *= 2 # Big change, so give time to converge
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- self.tree.plot_tree()
- # Ensure constant node order
-                logger.debug(f"Using {root_node} as root for the within-TSSB swap...")
- elbos = ntree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- update_all=update_all,
- n_iters=n_all_iters,
- go_down=go_down,
- mc_samples=mc_samples,
- n_inner_steps=n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- def descend(root, subtree, ntree, done):
- if not done:
- if root["node"] == subtree:
- tssb_swap(
- subtree,
- root["children"],
- ntree,
- n_iters,
- update_pivots=update_pivots,
- )
- return True
- else:
- for index, child in enumerate(root["children"]):
- done = descend(child, subtree, ntree, done)
- if done:
- break
-
-        # Randomly decide between a within-TSSB swap and an unrestricted swap across the NTSSB
- within_tssb = np.random.binomial(1, 0.5)
-
- if within_tssb:
- # Uniformly pick a subtree with more than 1 node
- subtrees = self.tree.get_mixture()[1]
- subtrees = [
- subtree for subtree in subtrees if len(subtree.root["children"]) > 0
- ]
- if len(subtrees) > 0:
- subtree = np.random.choice(
- subtrees, p=[1.0 / len(subtrees)] * len(subtrees)
- )
-
- descend(self.tree.root, subtree, self.tree, False)
- else:
- nodes = self.tree.get_node_roots()
- nodes = nodes[1:] # without root
-
- # Randomly decide between parent-child and unrestricted
- unrestricted = np.random.binomial(1, 0.5)
- if unrestricted:
- nodeA_root, nodeB_root = np.random.choice(nodes, replace=False, size=2)
- nodeA = nodeA_root["node"]
- nodeB = nodeB_root["node"]
- else:
- nodeA_root = np.random.choice(nodes)
- nodeA = nodeA_root["node"]
- nodeB = nodeA.parent()
- nodeB_root = nodeB.get_tssb_root()
-
- if nodeB is not None:
- logger.debug(f"Trying to swap {nodeA.label} with {nodeB.label}...")
- self.tree.swap_nodes(nodeA, nodeB, update_pivots=update_pivots)
- root_node = nodeB
- root = None
- sub_root=nodeB_root
- restricted=False
- update_all=False
- go_down=True
- if unrestricted:
- if nodeA == nodeB.parent():
- root_node = nodeA
- root=None
- sub_root=nodeA_root
- restricted=True
- update_all=False
- go_down=True
- elif nodeB == nodeA.parent():
- root_node = nodeB
- root=None
- sub_root=nodeB_root
- restricted=True
- update_all=False
- go_down=True
- else:
- mrca = self.tree.get_mrca(nodeA, nodeB)
- mrca_root = mrca.get_tssb_root()
- root=None
- sub_root=mrca_root
- restricted=True
- update_all=False
- go_down=True
- for child in root_node.children():
- child.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.clip(
- root_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ],
- -10,
- 10,
- )
- # root_node = self.tree.root['node'].root['node']
- new_n_iters = n_iters
- if root_node.parent() is None:
- new_n_iters = n_iters
- if not local:
- root_node = None
- root = self.tree.root
- sub_root = None
- restricted = False
- update_all = True
- go_down = False
- self.tree.plot_tree()
- # Ensure constant node order
- elbos = self.tree.optimize_elbo(
- root=root,
- sub_root=sub_root,
- restricted=restricted,
- update_all=update_all,
- go_down=go_down,
- mc_samples=mc_samples,
- n_inner_steps=new_n_iters,
- lr=lr,
- mb_size=mb_size,
- )
-
- return success, elbos
-
- def transfer_factor(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- # Take events from a factor and put them in the node that contains
- # cells that use that factor the most
-
- success = True
- elbos = []
-
- # Randomly choose a factor
- factor_idx = np.random.randint(
- self.tree.root["node"].root["node"].num_global_noise_factors
- )
-
- # Get genes in the factor
- target_genes = np.argsort(
- np.abs(
- self.tree.root["node"]
- .root["node"]
- .variational_parameters["globals"]["noise_factors_mean"][factor_idx]
- )
- )[-10:]
-
- # Get cells that give large weight to it
- thres = np.quantile(
- np.abs(
- self.tree.root["node"]
- .root["node"]
- .variational_parameters["globals"]["cell_noise_mean"][:, factor_idx]
- ),
- 0.75,
- )
- target_cells = np.where(
- np.abs(
- self.tree.root["node"]
- .root["node"]
- .variational_parameters["globals"]["cell_noise_mean"][:, factor_idx]
- )
- > thres
- )[0]
-
- # Get node that most of them attach to
- target_node = max(
- list(np.array(self.tree.assignments)[target_cells]),
- key=list(np.array(self.tree.assignments)[target_cells]).count,
- )
-
- # Increase kernel on the genes that are affected by that factor
- target_node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ][target_genes] = -1.0
-
- # Remove these genes from the factor
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "noise_factors_mean"
- ][factor_idx, target_genes] *= 0.0
-
- # Remove weight of this factor from the target cells
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "cell_noise_mean"
- ][target_cells, factor_idx] *= 0.0
-
- logger.debug(
- f"Trying to move factor {factor_idx} to node {target_node.label}..."
- )
-
- # This move has to be global because we mess with a noise factor
- root_node = None
- elbos = self.tree.optimize_elbo(
- root_node=root_node,
- mc_samples=mc_samples,
- n_iters=n_iters, #* 5,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=False,
- debug=debug,
- opt=opt,
- opt_triplet=self.opt_triplet,
- callback=callback,
- **callback_kwargs,
- )
-
- return success, elbos
-
- def transfer_unobserved(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
-        # Take events from a node and put them in the factor that the cells
-        # below that node use the most
-
- success = True
- elbos = []
-
- # Randomly choose a node, biased towards nodes with most events
- nodes = self.tree.get_nodes()
- n_events = [
- np.sum(
- np.exp(
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- )
- for node in nodes[1:]
- ]
- node = np.random.choice(nodes[1:], p=n_events / np.sum(n_events))
-
- # Get all descendants of node
- nodes = node.get_descendants()
-
- # Get cells in those nodes
- data = np.concatenate([list(n.data) for n in nodes]).astype(int)
- if len(data) != 0:
-
- # Get a factor (biased towards the ones those cells like the most)
- factor_counts = np.bincount(
- np.argmax(
- np.abs(
- self.tree.root["node"]
- .root["node"]
- .variational_parameters["globals"]["cell_noise_mean"][data]
- ),
- axis=1,
- )
- )
- target_factor = np.random.choice(
- np.arange(len(factor_counts)), p=factor_counts / np.sum(factor_counts)
- )
- else:
- target_factor = np.random.choice(
- np.arange(self.tree.root["node"].root["node"].num_global_noise_factors)
- )
-
- # Get events from node and put them in factor
- node_event_locs = np.where(
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"]
- > -1
- )[0]
- node_event_vec = node.variational_parameters["locals"][
- "unobserved_factors_mean"
- ]
- self.tree.root["node"].root["node"].variational_parameters["globals"][
- "noise_factors_mean"
- ][target_factor, node_event_locs] = node_event_vec[node_event_locs]
-
- # Remove those events from node and its descendants
- node.variational_parameters["locals"]["unobserved_factors_kernel_log_mean"][
- node_event_locs
- ] = -2.0
- for desc in nodes:
- desc.variational_parameters["locals"]["unobserved_factors_mean"][
- node_event_locs
- ] = 0.0
-
- logger.debug(
- f"Trying to move events in node {node.label} to factor {target_factor}..."
- )
-
- # This move has to be global because we mess with a noise factor
- root_node = None
- elbos = self.tree.optimize_elbo(
- root_node=root_node,
- mc_samples=mc_samples,
- n_iters=n_iters,# * 5,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=False,
- debug=debug,
- opt=opt,
- opt_triplet=self.opt_triplet,
- callback=callback,
- **callback_kwargs,
- )
-
- return success, elbos
-
- def clean_factors(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- # If a factor encodes a CNV profile, remove it and move cells that used
- # it to the actual clone with that CNV profile
- pass
-
- def perturb_node(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
-        # Move the node towards nearby data that is currently explained by neighboring nodes
- success = True
- elbos = []
-
- nodes = self.tree.get_nodes()
- node = np.random.choice(nodes)
-
-        # Decide whether to move closer to the parent, a sibling, or a child
- parent = node.parent()
- if parent is not None:
- siblings = np.array([n for n in list(parent.children()) if n != node])
- idx = np.argsort([n.label for n in siblings])
- siblings = siblings[idx]
- parent = np.array([parent])
- else:
- parent = np.array([])
- siblings = np.array([])
- children = np.array(list(node.children()))
- idx = np.argsort([n.label for n in children])
- children = children[idx]
- if len(children) == 0:
- children = np.array([])
- possibilities = np.concatenate([parent, siblings, children])
- probs = np.array(
- [1 + node.num_local_data() for node in possibilities]
- ) # the more data they have, the more likely it is that we decide to move towards them
- probs = probs / np.sum(probs)
-
- target = np.random.choice(possibilities, p=probs)
-
- logger.debug(f"Trying to move {node.label} close to {target.label}...")
-
- self.tree.perturb_node(node, target)
- root_node = node
- if not local:
- root_node = None
- elbos = self.tree.optimize_elbo(
- root_node=root_node,
- mc_samples=mc_samples,
- n_iters=n_iters,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=False,
- debug=debug,
- opt=opt,
- opt_triplet=self.opt_triplet,
- callback=callback,
- **callback_kwargs,
- )
-
- return success, elbos
-
- def clean_node(
- self,
- local=False,
- mc_samples=1,
- n_iters=100,
- thin=10,
- tol=1e-7,
- lr=0.05,
- mb_size=100,
- max_nodes=5,
- debug=False,
- opt=None,
- callback=None,
- **callback_kwargs,
- ):
- # Get node with bad kernel and clean it up
- success = True
- elbos = []
-
- nodes = self.tree.get_nodes()
-
- n_genes = nodes[0].observed_parameters.shape[0]
-
- frac_events = np.array(
- [
- np.sum(
- np.exp(
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ]
- )
- > 0.2
- )
- / n_genes
- for node in nodes
- ]
- )
-
- probs = (
- 1e-6 + frac_events
- ) # the more complex the kernel, the more likely it is to clean it
- probs = probs / np.sum(probs)
-
- node = np.random.choice(nodes, p=probs)
-
- # Get the number of nodes with too many events
- n_bad_nodes = np.sum(frac_events > 1 / 3)
- frac_bad_nodes = n_bad_nodes / len(nodes)
- if frac_bad_nodes > 1 / 3 or n_bad_nodes > 3:
- # Reset all unobserved_factors
- logger.debug(f"Trying to clean all nodes...")
- for node in nodes:
- node.variational_parameters["locals"]["unobserved_factors_mean"] *= 0.0
- # node.variational_parameters["locals"][
- # "unobserved_factors_log_std"
- # ] = -2 * np.ones((n_genes,))
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.log(
- node.unobserved_factors_kernel_concentration_caller()
- ) * np.ones(
- (n_genes,)
- )
- # node.variational_parameters["locals"][
- # "unobserved_factors_kernel_log_std"
- # ] = -2 * np.ones((n_genes,))
-
- root_node = nodes[0]
- if not local:
- root_node = None
- elbos = self.tree.optimize_elbo(
- root_node=root_node,
- mc_samples=mc_samples,
- n_iters=n_iters,# * 5,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=False,
- debug=debug,
- opt=opt,
- opt_triplet=self.opt_triplet,
- callback=callback,
- **callback_kwargs,
- )
- else:
- logger.debug(f"Trying to clean {node.label}...")
-
- node.variational_parameters["locals"]["unobserved_factors_mean"] *= 0.0
- # node.variational_parameters["locals"][
- # "unobserved_factors_log_std"
- # ] = -2 * np.ones((n_genes,))
- node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
- ] = np.log(node.unobserved_factors_kernel_concentration_caller()) * np.ones(
- (n_genes,)
- )
- # node.variational_parameters["locals"][
- # "unobserved_factors_kernel_log_std"
- # ] = -2 * np.ones((n_genes,))
- root_node = node
- if not local:
- root_node = None
- elbos = self.tree.optimize_elbo(
- root_node=root_node,
- mc_samples=mc_samples,
- n_iters=n_iters,
- thin=thin,
- tol=tol,
- lr=lr,
- mb_size=mb_size,
- max_nodes=max_nodes,
- init=False,
- debug=debug,
- opt=opt,
- opt_triplet=self.opt_triplet,
- callback=callback,
- **callback_kwargs,
- )
-
- return success, elbos
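The clean_node move above targets nodes whose unobserved-factor kernels have become overly dense: it measures, per node, the fraction of genes whose kernel exceeds a threshold, samples a node in proportion to that fraction, and falls back to resetting every node when too many look bad. A condensed sketch of that decision, with the thresholds taken from the code above and a hypothetical helper name:

import numpy as np

def cleanup_plan(nodes, n_genes, kernel_thres=0.2, rng=np.random):
    """Return the list of nodes whose unobserved factors should be reset."""
    frac_events = np.array([
        np.sum(np.exp(n.variational_parameters["locals"]
                      ["unobserved_factors_kernel_log_mean"]) > kernel_thres) / n_genes
        for n in nodes
    ])
    n_bad = np.sum(frac_events > 1 / 3)
    if n_bad / len(nodes) > 1 / 3 or n_bad > 3:
        return list(nodes)  # too many dense kernels: reset everything
    probs = 1e-6 + frac_events
    return [nodes[rng.choice(len(nodes), p=probs / probs.sum())]]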
diff --git a/scatrex/ntssb/tssb.py b/scatrex/ntssb/tssb.py
index f82978b..46faf40 100644
--- a/scatrex/ntssb/tssb.py
+++ b/scatrex/ntssb/tssb.py
@@ -15,7 +15,7 @@
import numpy as np
import matplotlib
-from ..util import *
+from ..utils.math_utils import *
import logging
@@ -35,13 +35,16 @@ def __init__(
root_node,
label,
ntssb=None,
+ parent=None,
dp_alpha=1.0,
dp_gamma=1.0,
min_depth=0,
max_depth=15,
- alpha_decay=0.5,
- eta=0.0,
+ alpha_decay=1.,
+ eta=1.0,
+ children_root_nodes=[], # Root nodes of any children TSSBs
color="black",
+ seed=42,
):
if root_node is None:
raise Exception("Root node must be specified.")
@@ -52,7 +55,10 @@ def __init__(
        self.dp_gamma = dp_gamma  # smaller dp_gamma => larger psi => fewer nodes
self.alpha_decay = alpha_decay
self.weight = 1.0
+ self.children_root_nodes = children_root_nodes
self.color = color
+ self.seed = seed
+ self._parent = parent
        self.eta = eta  # whether to put more weight at the top or bottom of the tree in case we want fixed weights
@@ -62,7 +68,7 @@ def __init__(
self.root = {
"node": root_node,
- "main": boundbeta(1.0, dp_alpha)
+ "main": boundbeta(1.0, dp_alpha, np.random.default_rng(self.seed))
if self.min_depth == 0
else 0.0, # if min_depth > 0, no data can be added to the root (main stick is nu)
"sticks": empty((0, 1)), # psi sticks
@@ -82,8 +88,56 @@ def __init__(
self.kl = -1e6
self._data = set()
- def add_data(self, l):
- self._data.update(l)
+ self.n_nodes = 1
+
+ self.variational_parameters = {'delta_1': 1., 'delta_2': 1., # nu stick
+ 'sigma_1': 1., 'sigma_2': 1., # psi stick
+ 'q_c': [], # prob of assigning each cell to this TSSB
+ 'LSE_z': [], # normalizing constant for prob of assigning each cell to nodes
+ 'LSE_rho': [], # normalizing constant for each child TSSB's probs of assigning each node to nodes in this TSSB (this is a list)
+ 'll': [], # auxiliary quantity
+ 'sum_E_log_1_nu': 0., # auxiliary quantity
+ 'E_log_phi': 0., # auxiliary quantity
+ }
+
+ def parent(self):
+ return self._parent
+
+ def get_param_dict(self):
+ """
+        Go from a tree of node objects to a nested dictionary where each node is a plain dictionary,
+        with `param`, `weight`, `children` and related keys
+ """
+ param_dict = {
+ "param": self.root['node'].get_params(),
+ "mean": self.root['node'].get_mean(),
+ "weight": self.root['weight'],
+ "children": [],
+ "label": self.root['label'],
+ "color": self.root['color'],
+ "size": len(self.root['node'].data),
+ "pivot_probs": self.root['node'].variational_parameters['q_rho'],
+ }
+ def descend(root, root_new):
+ for child in root["children"]:
+ child_new = {
+ "param": child['node'].get_params(),
+ "mean": child['node'].get_mean(),
+ "weight": child['weight'],
+ "children": [],
+ "label": child['label'],
+ "color": child['color'],
+ "size": len(child['node'].data),
+ "pivot_probs": child['node'].variational_parameters['q_rho'],
+ }
+ root_new['children'].append(child_new)
+ descend(child, root_new['children'][-1])
+
+ descend(self.root, param_dict)
+ return param_dict
+
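The nested dictionary returned by get_param_dict is self-contained, so downstream code can walk it without touching node objects. For instance, a flat list of (label, weight) pairs could be collected like this (illustrative usage only):

def collect_weights(param_dict):
    """Flatten the nested dict produced by get_param_dict into (label, weight) pairs."""
    pairs = [(param_dict["label"], param_dict["weight"])]
    for child in param_dict["children"]:
        pairs.extend(collect_weights(child))
    return pairs

# e.g. weights = collect_weights(tssb.get_param_dict())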
+ def add_datum(self, id):
+ self._data.add(id)
def remove_data(self):
self._data.clear()
@@ -106,6 +160,7 @@ def descend(root):
self.root["node"].remove_data()
descend(self.root)
self.assignments = []
+ self.remove_data()
def reset_tree(self):
# Clear tree
@@ -122,27 +177,167 @@ def reset_tree(self):
}
def reset_node_parameters(
- self, root_params=True, down_params=True, node_hyperparams=None
+ self, min_dist=0.7, **node_hyperparams
):
+ def get_distance(nodeA, nodeB):
+ return np.sqrt(np.sum((nodeA.get_mean() - nodeB.get_mean())**2))
+
# Reset node parameters
def descend(root):
- root["node"].reset_parameters(
- root_params=root_params, down_params=down_params, **node_hyperparams
+ for i, child in enumerate(root["children"]):
+ accepted = False
+ while not accepted:
+ child["node"].reset_parameters(
+ **node_hyperparams
+ )
+ dist_to_parent = get_distance(root["node"], child["node"])
+ # Reject sample if too close to any other child
+ dists = []
+ for j, child2 in enumerate(root["children"]):
+ if j < i:
+ dists.append(get_distance(child["node"], child2["node"]))
+ if np.all(np.array(dists) >= min_dist*dist_to_parent):
+ accepted = True
+ else:
+ child["node"].seed += 1
+ descend(child)
+
+ self.root["node"].reset_parameters(**node_hyperparams)
+ descend(self.root)
+
+ def set_node_hyperparams(self, **kwargs):
+ def descend(root):
+ root['node'].set_node_hyperparams(**kwargs)
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
+ def set_weights(self, node_weights_dict):
+ def descend(root):
+ root['weight'] = node_weights_dict[root['label']]
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
+ def set_sticks_from_weights(self):
+        def subtree_weight(root):
+            total = root['weight']
+            for child in root['children']:
+                total += subtree_weight(child)
+            return total
+
+        def descend(root):
+            # Psi stick for child i is its subtree weight relative to the total weight of the remaining siblings
+            subtree_weights = [subtree_weight(child) for child in root['children']]
+            sticks = empty((0, 1))
+            for i, child in enumerate(root['children']):
+                if i < len(root['children']) - 1:
+                    stick = subtree_weights[i] / np.sum(subtree_weights[i:])
+                else:
+                    stick = 1.0
+                sticks = vstack(
+                    [
+                        sticks,
+                        stick
+                    ]
)
- for child in root["children"]:
+            root['sticks'] = sticks
+            # The main (nu) stick is this node's own weight relative to the total weight of its subtree
+            root['main'] = root["weight"] / (root["weight"] + np.sum(subtree_weights))
+            for child in root['children']:
                 descend(child)
+        descend(self.root)
+ def set_pivot_priors(self):
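+        # Prior probability of a node being the pivot of a child TSSB decays with depth as eta**depth and is normalized over the nodes below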
+ def descend(root, depth=0):
+ root["node"].pivot_prior_prob = self.eta ** depth
+ prior_norm = root["node"].pivot_prior_prob
+ for child in root['children']:
+ child_prior_norm = descend(child, depth=depth+1)
+ prior_norm += child_prior_norm
+ return prior_norm
+
+ prior_norm = descend(self.root)
+
+ def descend(root, depth=0):
+ root["node"].pivot_prior_prob = root["node"].pivot_prior_prob/prior_norm
+ for child in root['children']:
+ descend(child, depth=depth+1)
+
+ # Normalize
+ descend(self.root)
+
+ def sample_variational_distributions(self, **kwargs):
+ def descend(root):
+ root['node'].sample_variational_distributions(**kwargs)
+ for child in root['children']:
+ descend(child)
descend(self.root)
- def reset_node_variational_parameters(self, **kwargs):
+ def set_learned_parameters(self):
+ def descend(root):
+ root['node'].set_learned_parameters()
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
+ def get_pivot_probabilities(self, i):
+ def descend(root):
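+            # Collect q(rho = node) for child TSSB i from this node and all of its descendants, keeping probabilities and nodes in matching order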
+ probs = [root['node'].variational_parameters['q_rho'][i]]
+ nodes = [root['node']]
+ for child in root['children']:
+ child_probs, child_nodes = descend(child)
+ nodes.extend(child_nodes)
+ probs.extend(child_probs)
+ return probs, nodes
+ return descend(self.root)
+
+ def reset_sufficient_statistics(self, num_batches=1):
+ def descend(root):
+ root["node"].reset_sufficient_statistics(num_batches=num_batches)
+ for child in root["children"]:
+ descend(child)
+ descend(self.root)
+
+ self.suff_stats = {
+ 'mass': {'total': 0, 'batch': [0.] * num_batches}, # \sum_n q(c_n = this tree)
+ 'ent': {'total': 0, 'batch': [0.] * num_batches}, # \sum_n q(c_n = this tree) log q(c_n = this tree)
+ }
+
+ def reset_variational_parameters(self, alpha_nu=1., beta_nu=1., alpha_psi=1., beta_psi=1., **kwargs):
+ # Reset NTSSB parameters relative to this TSSB
+ self.variational_parameters = {'delta_1': alpha_nu, 'delta_2': beta_nu, # nu stick
+ 'sigma_1': alpha_psi, 'sigma_2': beta_psi, # psi stick
+ 'q_c': jnp.ones(self.ntssb.num_data,) * alpha_nu/(alpha_nu+beta_nu), # prob of assigning each cell to this TSSB
+ 'LSE_z': jnp.ones(self.ntssb.num_data,), # normalizing constant for prob of assigning each cell to nodes
+ 'LSE_rho': [1.] * len(self.children_root_nodes), # normalizing constant for each child TSSB's probs of assigning each node to nodes in this TSSB (this is a list)
+ 'll': jnp.ones(self.ntssb.num_data,), # auxiliary quantity
+ 'sum_E_log_1_nu': 0., # auxiliary quantity
+ 'E_log_phi': 0., # auxiliary quantity
+ }
+
# Reset node parameters
def descend(root):
root["node"].reset_variational_parameters(**kwargs)
+ z_norm = jnp.array(root["node"].variational_parameters['q_z'])
+ rho_norm = jnp.array(root["node"].variational_parameters['q_rho'])
for child in root["children"]:
- descend(child)
+ child_z_norm, child_rho_norm = descend(child)
+ z_norm += child_z_norm
+ rho_norm += child_rho_norm
+ return z_norm, rho_norm
+
+ self.root["node"]._parent = None
+
+ # Get normalizing constants
+ z_norm, rho_norm = descend(self.root)
+ # Apply normalization
+ def descend(root):
+ root["node"].reset_variational_parameters(**kwargs)
+ root["node"].variational_parameters['q_z'] = root["node"].variational_parameters['q_z'] / z_norm
+ root["node"].variational_parameters['q_rho'] = list(root["node"].variational_parameters['q_rho'] / rho_norm)
+ for child in root["children"]:
+ descend(child)
descend(self.root)
+
def sample_new_tree(self, num_data, cull=True):
# Clear current tree
self.reset_tree()
@@ -234,336 +429,694 @@ def add_data(self, data, initialize_assignments=False):
node.add_datum(n)
self.assignments.append(node)
- def resample_node_params(
- self,
- iters=1,
- independent_subtrees=False,
- top_node=None,
- data_indices=None,
- compound=True,
- ):
- """
- Go through all nodes in the tree, starting at the bottom, and resample their parameters iteratively.
- """
- if top_node is None:
- top_node = self.root
+ def add_node(self, target_root, seed=None):
+ if seed is None:
+ seed = self.seed + target_root["node"].seed + len(target_root["children"])
- for iter in range(iters):
+ node = target_root["node"].spawn(target_root["node"].observed_parameters, seed=seed)
- def descend(root):
- for index, child in enumerate(root["children"]):
- descend(child)
- root["node"].resample_params(
- independent_subtrees=independent_subtrees,
- data_indices=data_indices,
- compound=compound,
+ rng = np.random.default_rng(seed)
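+        # Draw the new child's psi stick from Beta(1, dp_gamma) and append it to the target node's stick list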
+ stick_length = boundbeta(1, self.dp_gamma, rng)
+ target_root["sticks"] = np.vstack([target_root["sticks"], stick_length])
+ target_root["children"].append(
+ {
+ "node": node,
+ "main": boundbeta(
+ 1.0, (self.alpha_decay ** (target_root["node"].depth + 1)) * self.dp_alpha, np.random.default_rng(seed+1)
)
+ if self.min_depth <= (target_root["node"].depth + 1)
+ else 0.0,
+ "sticks": np.empty((0, 1)),
+ "children": [],
+ "label": node.label,
+ }
+ )
+
+ # Update pivot prior
+ self.set_pivot_priors()
+
+ self.n_nodes += 1
+
+ return node
+
+ def merge_nodes(self, parent_root, source_root, target_root):
+ # Add mass of source to mass of target
+ target_root['node'].variational_parameters['q_z'] += source_root['node'].variational_parameters['q_z']
+ # Only need to update totals, because after the merges we go back to iterate
+ target_root['node'].merge_suff_stats(source_root['node'].suff_stats)
+
+ # Update pivot probs
+ for i in range(len(self.children_root_nodes)):
+ target_root['node'].variational_parameters['q_rho'][i] += source_root['node'].variational_parameters['q_rho'][i]
+
+ # Set children of source as children of target
+ for i, child in enumerate(source_root['children']):
+ target_root['children'].append(child)
+ target_root["sticks"] = np.vstack([target_root["sticks"], source_root['sticks'][i]])
+ child['node'].set_parent(target_root['node'])
+
+ # Remove source from its parent's dict
+ nodes = np.array([n["node"] for n in parent_root["children"]])
+ tokeep = np.where(nodes != source_root['node'])[0].astype(int).ravel()
+ parent_root["sticks"] = parent_root["sticks"][tokeep]
+ parent_root["children"] = list(np.array(parent_root["children"])[tokeep])
+
+ # Delete source node object and dict
+ source_root["node"].kill()
+ del source_root["node"]
+
+ # Update names
+ self.set_node_names(root=parent_root, root_name=parent_root['node'].label)
+
+ # Update pivot prior
+ self.set_pivot_priors()
- descend(top_node)
-
- def resample_assignment(self, current_node, n, obs):
- def path_lt(path1, path2):
- if len(path1) == 0 and len(path2) == 0:
- return 0
- elif len(path1) == 0:
- return 1
- elif len(path2) == 0:
- return -1
- s1 = "".join(map(lambda i: "%03d" % (i), path1))
- s2 = "".join(map(lambda i: "%03d" % (i), path2))
-
- return cmp(s2, s1)
-
- epsilon = finfo(float64).eps
- lengths = []
- reassign = 0
- better = 0
-
- # Get an initial uniform variate.
- ancestors = current_node.get_ancestors()
- current = self.root
- indices = []
- for anc in ancestors[1:]:
- index = list(map(lambda c: c["node"], current["children"])).index(anc)
- current = current["children"][index]
- indices.append(index)
-
- max_u = 1.0
- min_u = 0.0
- llh_s = log(rand()) + current_node.logprob(obs)
- # llh_s = self.assignments[n].logprob(self.data[n:n+1]) - 0.0000001
- while True:
- new_u = (max_u - min_u) * rand() + min_u
- (new_node, new_path, _) = self.find_node(new_u)
- new_llh = new_node.logprob(obs)
- if new_llh > llh_s:
- if new_node != current_node:
- if new_llh > current_node.logprob(obs):
- better += 1
- current_node.remove_datum(n)
- new_node.add_datum(n)
- current_node = new_node
- reassign += 1
- break
- elif abs(max_u - min_u) < epsilon:
- logger.debug("Slice sampler shrank down. Keep current state.")
- break
+ self.n_nodes -= 1
+
+ def compute_elbo(self, idx):
+ """
+        Compute the ELBO terms for cell `idx` in a tree traversal, abstracting away the model-specific
+        likelihood and kernel functions. Terms for which Eq[log p] is not available analytically (the likelihood
+        and the kernel distribution) are evaluated at the nodes' current Monte Carlo samples.
+ """
+ def descend(root, depth=0, ll_contrib=0, ass_contrib=0, global_contrib=0):
+ self.n_nodes += 1
+ # Assignments
+ ## E[log p(z|nu,psi))]
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ eq_logp_z = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ eq_logp_z += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ eq_logp_z += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
else:
- path_comp = path_lt(indices, new_path)
- if path_comp < 0:
- min_u = new_u
- # if we are at a leaf in a fixed tree, move on, because we can't create more nodes
- if len(new_node.children()) == 0:
- break
- elif path_comp > 0:
- max_u = new_u
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+ ## E[log q(z)]
+ eq_logq_z = jax.lax.select(root['node'].variational_parameters['q_z'][idx] != 0,
+ root['node'].variational_parameters['q_z'][idx] * jnp.log(root['node'].variational_parameters['q_z'][idx]),
+ root['node'].variational_parameters['q_z'][idx])
+ ass_contrib += eq_logp_z*root['node'].variational_parameters['q_z'][idx] - eq_logq_z
+
+ # Sticks
+ E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu
+ nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu
+ nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth)
+ psi_kl = 0.
+ if depth != 0:
+ E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi
+ psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi
+ psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl -= logbeta_func(1, self.dp_gamma)
+ global_contrib += nu_kl + psi_kl
+
+ # Kernel
+ if root['node'].parent() is None and self.parent() is None: # is the root of the root TSSB
+ ## E[log p(kernel)]
+ eq_logp_kernel = root['node'].compute_root_prior()
+ ## -E[log q(kernel)]
+ negeq_logq_kernel = root['node'].compute_root_entropy()
+ else:
+ if root['node'].parent() is not None: # Not a TSSB root. The root param probs are computed in the parent TSSBs, weighted by the pivots
+ ## E[log p(kernel)]
+ eq_logp_kernel = root['node'].compute_kernel_prior()
else:
- raise Exception("Slice sampler weirdness.")
-
- return current_node
-
- def resample_assignments(self):
- def path_lt(path1, path2):
- if len(path1) == 0 and len(path2) == 0:
- return 0
- elif len(path1) == 0:
- return 1
- elif len(path2) == 0:
- return -1
- s1 = "".join(map(lambda i: "%03d" % (i), path1))
- s2 = "".join(map(lambda i: "%03d" % (i), path2))
-
- return cmp(s2, s1)
-
- epsilon = finfo(float64).eps
- lengths = []
- reassign = 0
- better = 0
- for n in range(self.num_data):
+ eq_logp_kernel = 0.
+ ## -E[log q(kernel)]
+ negeq_logq_kernel = root['node'].compute_kernel_entropy()
+ global_contrib += eq_logp_kernel + negeq_logq_kernel
+
+ # Pivots
+ eq_logp_rootkernel = 0.
+ eq_logp_rho = 0.
+ eq_logq_rho = 0.
+ for i, next_tssb_root_node in enumerate(self.children_root_nodes):
+ ## E[log p(root kernel | rho kernel)]
+ eq_logp_rootkernel += root['node'].variational_parameters['q_rho'][i] * next_tssb_root_node.compute_root_kernel_prior(root['node'].samples)
+ ## E[log p(rho))]
+ eq_logp_rho += root['node'].pivot_prior_prob
+ ## E[log q(rho))]
+                eq_logq_rho += jax.lax.select(root['node'].variational_parameters['q_rho'][i] != 0,
+                    root['node'].variational_parameters['q_rho'][i] * jnp.log(root['node'].variational_parameters['q_rho'][i]),
+                    root['node'].variational_parameters['q_rho'][i])
+            global_contrib += eq_logp_rootkernel + eq_logp_rho - eq_logq_rho
+
+ # Likelihood
+ # Use node's kernel sample to evaluate likelihood
+ ll_contrib += root['node'].compute_loglikelihood(idx) * root['node'].variational_parameters['q_z'][idx]
+
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ # Auxiliary quantities
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+ ll_contrib, ass_contrib, global_contrib = descend(child, depth=depth+1,
+ ll_contrib=ll_contrib,
+ ass_contrib=ass_contrib,
+ global_contrib=global_contrib)
+
+ return ll_contrib, ass_contrib, global_contrib
+
+ self.n_nodes = 0
+ return descend(self.root)
- # Get an initial uniform variate.
- ancestors = self.assignments[n].get_ancestors()
- current = self.root
- indices = []
- for anc in ancestors[1:]:
- index = list(map(lambda c: c["node"], current["children"])).index(anc)
- current = current["children"][index]
- indices.append(index)
-
- max_u = 1.0
- min_u = 0.0
- llh_s = log(rand()) + self.assignments[n].logprob(self.data[n : n + 1])
- # llh_s = self.assignments[n].logprob(self.data[n:n+1]) - 0.0000001
- while True:
- new_u = (max_u - min_u) * rand() + min_u
- (new_node, new_path, _) = self.find_node(new_u)
- new_llh = new_node.logprob(self.data[n : n + 1])
- if new_llh > llh_s:
- if new_node != self.assignments[n]:
- if new_llh > self.assignments[n].logprob(self.data[n : n + 1]):
- better += 1
- self.assignments[n].remove_datum(n)
- new_node.add_datum(n)
- self.assignments[n] = new_node
- reassign += 1
- break
- elif abs(max_u - min_u) < epsilon:
- logger.debug("Slice sampler shrank down. Keep current state.")
- break
+ def compute_elbo_suff(self):
+ """
+        Compute the ELBO of the model in a tree traversal using the nodes' stored sufficient statistics
+        instead of per-cell assignment probabilities. Terms for which Eq[log p] is not available analytically
+        (the likelihood and the kernel distribution) are evaluated at the nodes' current Monte Carlo samples.
+ """
+ def descend(root, depth=0, ll_contrib=0, ass_contrib=0, global_contrib=0):
+ self.n_nodes += 1
+ # Assignments
+ ## E[log p(z|nu,psi))]
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ eq_logp_z = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ eq_logp_z += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ eq_logp_z += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ else:
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+ ## E[log q(z)]
+ ass_contrib += eq_logp_z*root['node'].suff_stats['mass']['total'] + root['node'].suff_stats['ent']['total']
+
+ # Sticks
+ E_log_nu = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl = (self.dp_alpha * self.alpha_decay**depth - root['node'].variational_parameters['delta_2']) * E_log_1_nu
+ nu_kl -= (root['node'].variational_parameters['delta_1'] - 1) * E_log_nu
+ nu_kl += logbeta_func(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ nu_kl -= logbeta_func(1, self.dp_alpha * self.alpha_decay**depth)
+ psi_kl = 0.
+ if depth != 0:
+ E_log_psi = E_log_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ E_log_1_psi = E_log_1_beta(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl = (self.dp_gamma - root['node'].variational_parameters['sigma_2']) * E_log_1_psi
+ psi_kl -= (root['node'].variational_parameters['sigma_1'] - 1) * E_log_psi
+ psi_kl += logbeta_func(root['node'].variational_parameters['sigma_1'], root['node'].variational_parameters['sigma_2'])
+ psi_kl -= logbeta_func(1, self.dp_gamma)
+ global_contrib += nu_kl + psi_kl
+
+ # Kernel
+ if root['node'].parent() is None and self.parent() is None: # is the root of the root TSSB
+ ## E[log p(kernel)]
+ eq_logp_kernel = root['node'].compute_root_prior()
+ ## -E[log q(kernel)]
+ negeq_logq_kernel = root['node'].compute_root_entropy()
+ else:
+ if root['node'].parent() is not None: # Not a TSSB root. The root param probs are computed in the parent TSSBs, weighted by the pivots
+ ## E[log p(kernel)]
+ eq_logp_kernel = root['node'].compute_kernel_prior()
else:
- path_comp = path_lt(indices, new_path)
- if path_comp < 0:
- min_u = new_u
- # if we are at a leaf in a fixed tree, move on, because we can't create more nodes
- if len(new_node.children()) == 0:
- break
- elif path_comp > 0:
- max_u = new_u
- else:
- raise Exception("Slice sampler weirdness.")
- lengths.append(len(new_path))
- lengths = array(lengths)
- # logger.debug "reassign: "+str(reassign)+" better: "+str(better)
-
- # def resample_birth(self):
- # # Break sticks to choose node
- # u = rand()
- # node, root = self.find_node(u)
- #
- # # Create child
- # stick_length = boundbeta(1, self.dp_gamma)
- # root['sticks'] = vstack([ root['sticks'], stick_length ])
- # root['children'].append({ 'node' : root['node'].spawn(False, self.root_node.observed_parameters),
- # 'main' : boundbeta(1.0, (self.alpha_decay**(depth+1))*self.dp_alpha) if self.min_depth <= (depth+1) else 0.0,
- # 'sticks' : empty((0,1)),
- # 'children' : [] })
- #
- # # Update parameters of data in parent and child node until convergence:
- # # assignments (starting from parent)
- # # stick lengths
- # # parameters
- #
- # def resample_merge(self):
- # # Choose any node a at random
- #
- # # Compute similarity of parameters of leaf sibilings and parent
- #
- # # Choose closest node b
- #
- #
- #
- # # If the merge is accepted, the child nodes of a are transferred to node b.
-
- def resample_tree_topology(self, children_trees, independent_subtrees=False):
- # x = self.complete_data_log_likelihood_nomix()
- post = self.ntssb.unnormalized_posterior()
- weights, nodes = self.get_mixture()
+ eq_logp_kernel = 0.
+ ## -E[log q(kernel)]
+ negeq_logq_kernel = root['node'].compute_kernel_entropy()
+ global_contrib += eq_logp_kernel + negeq_logq_kernel
+
+ # Pivots
+ eq_logp_rootkernel = 0.
+ eq_logp_rho = 0.
+ eq_logq_rho = 0.
+ for i, next_tssb_root_node in enumerate(self.children_root_nodes):
+ ## E[log p(root kernel | rho kernel)]
+ eq_logp_rootkernel += root['node'].variational_parameters['q_rho'][i] * next_tssb_root_node.compute_root_kernel_prior(root['node'].samples)
+ ## E[log p(rho))]
+ eq_logp_rho += root['node'].pivot_prior_prob
+ ## E[log q(rho))]
+                eq_logq_rho += jax.lax.select(root['node'].variational_parameters['q_rho'][i] != 0,
+                    root['node'].variational_parameters['q_rho'][i] * jnp.log(root['node'].variational_parameters['q_rho'][i]),
+                    root['node'].variational_parameters['q_rho'][i])
+            global_contrib += eq_logp_rootkernel + eq_logp_rho - eq_logq_rho
+
+ # Likelihood
+ # Use node's kernel sample to evaluate likelihood
+ ll_contrib += root['node'].compute_loglikelihood_suff()
+
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ # Auxiliary quantities
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+ ll_contrib, ass_contrib, global_contrib = descend(child, depth=depth+1,
+ ll_contrib=ll_contrib,
+ ass_contrib=ass_contrib,
+ global_contrib=global_contrib)
+
+ return ll_contrib, ass_contrib, global_contrib
+
+ self.n_nodes = 0
+ return descend(self.root)
+
+
+ def update_sufficient_statistics(self, batch_idx=None):
+ def descend(root):
+ root['node'].update_sufficient_statistics(batch_idx=batch_idx)
+ for child in root['children']:
+ descend(child)
+
+ descend(self.root)
+
+ if batch_idx is not None:
+ idx = self.ntssb.batch_indices[batch_idx]
+ else:
+ idx = jnp.arange(self.ntssb.num_data)
+
+ ent = assignment_entropies(self.variational_parameters['q_c'][idx])
+ E_ass = self.variational_parameters['q_c'][idx]
+
+ new_ent = jnp.sum(ent)
+ new_mass = jnp.sum(E_ass)
+
+ if batch_idx is not None:
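+            # Replace this batch's previous contribution in the running totals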
+ self.suff_stats['ent']['total'] -= self.suff_stats['ent']['batch'][batch_idx]
+ self.suff_stats['ent']['batch'][batch_idx] = new_ent
+ self.suff_stats['ent']['total'] += self.suff_stats['ent']['batch'][batch_idx]
+
+ self.suff_stats['mass']['total'] -= self.suff_stats['mass']['batch'][batch_idx]
+ self.suff_stats['mass']['batch'][batch_idx] = new_mass
+ self.suff_stats['mass']['total'] += self.suff_stats['mass']['batch'][batch_idx]
+ else:
+ self.suff_stats['ent']['total'] = new_ent
+ self.suff_stats['mass']['total'] = new_mass
- empty_root = False
- if len(nodes) > 1:
- if len(nodes[0].data) == 0:
- logger.debug("Swapping root")
- empty_root = True
- nodeAnum = 0
+ def update_local_params(self, idx, update_ass=True, take_gradients=False, **kwargs):
+ """
+        Perform a tree traversal to update the cell-to-node attachment probabilities q(z).
+        Returns the sum over nodes of (Eq[log p(y_n|node)] + Eq[log prior]) weighted by the updated
+        q(z_n=node), the corresponding assignment entropy, and the accumulated gradients of the
+        cell-specific (local) parameters.
+ """
+ def descend(root, local_grads=None):
+ E_log_1_nu = E_log_1_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ logprior += root['node'].variational_parameters['E_log_phi']
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu + root['node'].parent().variational_parameters['sum_E_log_1_nu']
else:
- nodeAnum = randint(0, len(nodes))
- nodeBnum = randint(0, len(nodes))
- while nodeAnum == nodeBnum:
- nodeBnum = randint(0, len(nodes))
-
- def swap_nodes(nodeAnum, nodeBnum, verbose=False):
- def findNodes(root, nodeNum, nodeA=False, nodeB=False):
- node = root
- if nodeNum == nodeAnum:
- nodeA = node
- if nodeNum == nodeBnum:
- nodeB = node
- for i, child in enumerate(root["children"]):
- nodeNum = nodeNum + 1
- (nodeA, nodeB, nodeNum) = findNodes(
- child, nodeNum, nodeA, nodeB
- )
- return (nodeA, nodeB, nodeNum)
+ root['node'].variational_parameters['sum_E_log_1_nu'] = E_log_1_nu
+
+ # Take gradient of locals
+ weights = root['node'].variational_parameters['q_z'][idx] * self.variational_parameters['q_c'][idx]
+ if take_gradients:
+ local_grads_down = root["node"].compute_ll_locals_grad(self.ntssb.data[idx], idx, weights) # returns a tuple of grads for each cell-specific param
+ if local_grads is None:
+ local_grads = list(local_grads_down)
+ else:
+ for i, grads in enumerate(list(local_grads_down)):
+ local_grads[i] += grads
+ ll = root['node'].compute_loglikelihood(idx)
+ root['node'].variational_parameters['ll'] = ll
+ root['node'].variational_parameters['logprior'] = logprior
+ new_log_prob = ll + logprior
+ if update_ass:
+ root['node'].variational_parameters['q_z'] = root['node'].variational_parameters['q_z'].at[idx].set(new_log_prob)
+
+ logqs = [new_log_prob]
+ sum_E_log_1_psi = 0.
+ for child in root['children']:
+ E_log_psi = E_log_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ child['node'].variational_parameters['E_log_phi'] = E_log_psi + sum_E_log_1_psi
+ E_log_1_psi = E_log_1_beta(child['node'].variational_parameters['sigma_1'], child['node'].variational_parameters['sigma_2'])
+ sum_E_log_1_psi += E_log_1_psi
+
+ # Go down
+                # The recursive call accumulates the children's gradients into `local_grads` in place,
+                # so only the log probabilities need to be collected here
+                child_log_probs, _ = descend(child, local_grads=local_grads)
+                logqs.extend(child_log_probs)
+
+ return logqs, local_grads
+
+ logqs, local_grads = descend(self.root)
+
+ # Compute LSE
+ logqs = jnp.array(logqs).T # make it obs by nodes
+ self.variational_parameters['LSE_z'] = jax.scipy.special.logsumexp(logqs, axis=1)
+ # Set probs and return sum of weighted likelihoods
+ def descend(root):
+ if update_ass:
+ newvals = jnp.exp(root['node'].variational_parameters['q_z'][idx] - self.variational_parameters['LSE_z'])
+ root['node'].variational_parameters['q_z'] = root['node'].variational_parameters['q_z'].at[idx].set(newvals)
+ ell = (root['node'].variational_parameters['ll'] + root['node'].variational_parameters['logprior']) * root['node'].variational_parameters['q_z'][idx]
+ ent = -jax.lax.select(root['node'].variational_parameters['q_z'][idx] != 0,
+ root['node'].variational_parameters['q_z'][idx] * jnp.log(root['node'].variational_parameters['q_z'][idx]),
+ root['node'].variational_parameters['q_z'][idx])
+
+ for child in root['children']:
+ ell_, ent_ = descend(child)
+ ell += ell_
+ ent += ent_
+ return ell, ent
+
+ ell, ent = descend(self.root)
+ return ell, ent, local_grads
+
+ def get_global_grads(self, idx):
+ """
+        Perform a tree traversal to accumulate the gradients of the data likelihood with respect to the
+        global parameters, weighting each node's contribution by its cell-assignment probabilities.
+ """
+ def descend(root, globals_grads=None):
+ weights = root['node'].variational_parameters['q_z'][idx] * self.variational_parameters['q_c'][idx]
+ globals_grads_down = root["node"].compute_ll_globals_grad(self.ntssb.data[idx], idx, weights)
+ if globals_grads is None:
+ globals_grads = list(globals_grads_down)
+ else:
+ for i, grads in enumerate(list(globals_grads_down)):
+ globals_grads[i] += grads
+ for child in root['children']:
+                # The recursive call accumulates into `globals_grads` in place
+                descend(child, globals_grads=globals_grads)
+ return globals_grads
+
+ # Get gradient of loss of data likelihood weighted by assignment probability to each node wrt current sample of global params
+ return descend(self.root)
- (nodeA, nodeB, nodeNum) = findNodes(self.root, nodeNum=0)
+ def update_stick_params(self, root=None, memoized=True):
+ def descend(root, depth=0):
+ mass_down = 0
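+            # Visit children in reverse order so that mass_down accumulates the expected mass of all later siblings, as required by the psi stick updates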
+ for child in root['children'][::-1]:
+ child_mass = descend(child, depth=depth+1)
- paramsA = nodeA["node"].unobserved_factors
- dataA = set(nodeA["node"].data)
- mainA = nodeA["main"]
+ child['node'].variational_parameters['sigma_1'] = 1.0 + child_mass
+ child['node'].variational_parameters['sigma_2'] = self.dp_gamma + mass_down
+ mass_down += child_mass
- nodeA["node"].unobserved_factors = nodeB["node"].unobserved_factors
- nodeA["node"].node_mean = (
- nodeA["node"].baseline_caller()
- * nodeA["node"].cnvs
- / 2
- * np.exp(nodeA["node"].unobserved_factors)
- )
+ if memoized:
+ mass_here = root['node'].suff_stats['mass']['total']
+ else:
+ mass_here = jnp.sum(root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c'])
+ root['node'].variational_parameters['delta_1'] = 1.0 + mass_here
+ root['node'].variational_parameters['delta_2'] = (self.alpha_decay**depth) * self.dp_alpha + mass_down
- for dataid in list(dataA):
- nodeA["node"].remove_datum(dataid)
- for dataid in nodeB["node"].data:
- nodeA["node"].add_datum(dataid)
- self.ntssb.assignments[dataid]["node"] = nodeA["node"]
- nodeA["main"] = nodeB["main"]
-
- nodeB["node"].unobserved_factors = paramsA
- nodeB["node"].node_mean = (
- nodeB["node"].baseline_caller()
- * nodeB["node"].cnvs
- / 2
- * np.exp(nodeB["node"].unobserved_factors)
- )
+ return mass_here + mass_down
- dataB = set(nodeB["node"].data)
-
- for dataid in list(dataB):
- nodeB["node"].remove_datum(dataid)
- for dataid in dataA:
- nodeB["node"].add_datum(dataid)
- self.ntssb.assignments[dataid]["node"] = nodeB["node"]
- nodeB["main"] = mainA
-
- if not independent_subtrees:
- # Go to subtrees
- # For each subtree, if pivot was swapped, update it
- for child in children_trees:
- if child["pivot_node"] == nodeA["node"]:
- child["pivot_node"] = nodeB["node"]
- child["node"].root["node"].set_parent(
- nodeB["node"], reset=False
- )
- elif child["pivot_node"] == nodeB["node"]:
- child["pivot_node"] = nodeA["node"]
- child["node"].root["node"].set_parent(
- nodeA["node"], reset=False
- )
+ # Update sticks
+ if root is None:
+ root = self.root
+ descend(root)
- logger.debug(
- f"Swapped {nodeA['node'].label} with {nodeB['node'].label}"
- )
+ def update_node_params(self, key, root=None, memoized=True, step_size=0.0001, mc_samples=10, i=0, adaptive=True, **kwargs):
+ """
+ Update variational parameters for kernels, sticks and pivots
- if empty_root:
- logger.debug("checking alternative root")
- nodenum = []
- for ii, nn in enumerate(nodes):
- if len(nodes[ii].data) > 0:
- nodenum.append(ii)
- post_temp = zeros(len(nodenum))
- for idx, nodeBnum in enumerate(nodenum):
- logger.debug(f"nodeBnum: {nodeBnum}")
- logger.debug(f"nodeAnum: {nodeAnum}")
- swap_nodes(nodeAnum, nodeBnum)
- for i in range(5):
- self.resample_sticks()
- self.ntssb.root["node"].root["node"].resample_cell_params()
- post_new = self.ntssb.unnormalized_posterior()
- post_temp[idx] = post_new
-
- accept_prob = np.exp(np.min([0.0, post_new - post]))
-
- if rand() > accept_prob:
- swap_nodes(nodeAnum, nodeBnum)
- for i in range(5):
- self.resample_sticks()
- self.ntssb.root["node"].root["node"].resample_cell_params()
-
- if nodeBnum == len(nodes) - 1:
- logger.debug("forced swapping")
- nodeBnum = post_temp.argmax() + 1
- swap_nodes(nodeAnum, nodeBnum)
- for i in range(5):
- self.resample_sticks()
- self.ntssb.root["node"].root[
- "node"
- ].resample_cell_params()
-
- self.resample_node_params()
- self.resample_stick_orders()
- else:
- logger.debug("Successful swap!")
- self.resample_node_params()
- self.resample_stick_orders()
- break
- # else:
- # swap_nodes(nodeAnum,nodeBnum, verbose=True)
- # for i in range(5):
- # self.resample_sticks()
- # self.ntssb.root['node'].root['node'].resample_cell_params()
- #
- # post_new = self.ntssb.unnormalized_posterior()
- # accept_prob = np.exp(np.min([0., post_new - post]))
- # if (rand() > accept_prob):
- # logger.debug("Unsuccessful swap.")
- # swap_nodes(nodeAnum,nodeBnum) # swap back
- # for i in range(5):
- # self.resample_sticks()
- # self.ntssb.root['node'].root['node'].resample_cell_params()
- #
- # else:
- # logger.debug("Successful swap!")
- # self.resample_node_params()
- # self.resample_stick_orders()
+ Each node must have two parameters for the kernel: a direction and a state.
+ We assume the tree kernel, regardless of the model, is always defined as
+ P(direction|parent_direction) and P(state|direction,parent_state).
+ For each node, we first update the direction and then the state, taking one gradient step for each
+        parameter and then moving on to the next nodes in the tree traversal.
+
+ PSEUDOCODE:
+ def descend(root):
+ alpha, alpha_grad = sample_grad_alpha
+ psi, psi_grad = sample_grad_psi
+
+ alpha_grad += Gradient of logp(alpha|parent_alpha) wrt this alpha
+ alpha_grad += Gradient of logp(psi|parent_psi,alpha) wrt this alpha
+
+ alpha_grad += Gradient of logq(alpha) wrt this alpha
+ psi_grad += Gradient of logq(psi) wrt this psi
+
+ psi_grad += Gradient of logp(x|psi) wrt this psi
+
+ for each child:
+ child_alpha, child_psi = descend(child)
+ alpha_grad += Gradient of logp(child_alpha|alpha) wrt this alpha
+ psi_grad += Gradient of logp(child_psi|psi,child_alpha) wrt this psi
+
+ for each child_root:
+ alpha_grad += Gradient of logp(child_root_alpha|alpha) wrt this alpha
+ psi_grad += Gradient of logp(child_root_psi|psi,child_root_alpha) wrt this psi
+
+ new_alpha_params = alpha_params + alpha_grad * step_size
+ new_alpha = sample_alpha
+
+ psi_grad += Gradient of logp(psi|parent_psi,new_alpha) wrt this psi
+ new_psi_params = psi_params + psi_grad * step_size
+ new_psi = sample_psi
+
+ return new_alpha, new_psi
+
+ """
+ def descend(root, key, depth=0):
+ direction_sample_grad = 0.
+ state_sample_grad = 0.
+
+ if depth != 0:
+ key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples)
+ direction_curr_sample, direction_params_grad = sample_grad
+ key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples)
+ state_curr_sample, state_params_grad = sample_grad
+ else:
+ root['node'].sample_kernel(n_samples=mc_samples)
+ direction_curr_sample = root['node'].get_direction_sample()
+ state_curr_sample = root['node'].get_state_sample()
+
+ if depth != 0:
+ direction_parent_sample = root["node"].parent().get_direction_sample()
+ state_parent_sample = root["node"].parent().get_state_sample()
+
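+                # Gradients of the parent-conditioned kernel priors p(direction|parent) and p(state|parent state, direction) with respect to this node's direction sample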
+ direction_sample_grad += root["node"].compute_direction_prior_grad(direction_curr_sample, direction_parent_sample, state_parent_sample)
+ direction_sample_grad += root["node"].compute_state_prior_grad_wrt_direction(state_curr_sample, state_parent_sample, direction_curr_sample)
+
+ direction_params_entropy_grad = root["node"].compute_direction_entropy_grad()
+ state_params_entropy_grad = root["node"].compute_state_entropy_grad()
+
+ if memoized:
+ state_sample_grad += root["node"].compute_ll_state_grad_suff(state_curr_sample)
+ else:
+ weights = root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c']
+ state_sample_grad += root["node"].compute_ll_state_grad(self.ntssb.data, weights, state_curr_sample)
+
+ mass_down = 0
+ for child in root['children'][::-1]:
+ child_mass, direction_child_sample, state_child_sample = descend(child, key, depth=depth+1)
+
+ child['node'].variational_parameters['sigma_1'] = 1.0 + child_mass
+ child['node'].variational_parameters['sigma_2'] = self.dp_gamma + mass_down
+ mass_down += child_mass
+
+ if depth != 0:
+ direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample)
+ state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, direction_curr_sample, state_curr_sample)
+ state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample)
+
+ if depth != 0:
+ for ii, child_root in enumerate(self.children_root_nodes):
+ direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(child_root.get_direction_sample(), direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][ii]
+ state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(child_root.get_direction_sample(), direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][ii]
+ state_sample_grad += root["node"].compute_root_state_prior_child_grad(child_root.get_state_sample(), state_curr_sample, child_root.get_direction_sample()) * root['node'].variational_parameters['q_rho'][ii]
+
+ if depth != 0:
+ if adaptive and i == 0:
+ root['node'].reset_opt()
+
+ # Combine gradients of functions wrt sample with gradient of sample wrt var params
+ if adaptive:
+ root['node'].update_direction_adaptive(direction_params_grad, direction_sample_grad, direction_params_entropy_grad,
+ step_size=step_size, i=i)
+ else:
+ root['node'].update_direction_params(direction_params_grad, direction_sample_grad, direction_params_entropy_grad,
+ step_size=step_size)
+ key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples)
+ direction_curr_sample, _ = sample_grad
+
+ state_sample_grad += root["node"].compute_state_prior_grad(state_curr_sample, state_parent_sample, direction_curr_sample)
+
+ if adaptive:
+ root['node'].update_state_adaptive(state_params_grad, state_sample_grad, state_params_entropy_grad,
+ step_size=step_size, i=i)
+ else:
+ root['node'].update_state_params(state_params_grad, state_sample_grad, state_params_entropy_grad,
+ step_size=step_size)
+ key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples)
+ state_curr_sample, _ = sample_grad
+ root['node'].samples[0] = state_curr_sample
+ root['node'].samples[1] = direction_curr_sample
+
+ if memoized:
+ mass_here = root['node'].suff_stats['mass']['total']
+ else:
+ mass_here = jnp.sum(root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c'])
+ root['node'].variational_parameters['delta_1'] = 1.0 + mass_here
+ root['node'].variational_parameters['delta_2'] = (self.alpha_decay**depth) * self.dp_alpha + mass_down
+
+ return mass_here + mass_down, direction_curr_sample, state_curr_sample
+
+ # Update kernels and sticks
+ if root is None:
+ root = self.root
+ descend(root, key)
+
+
+ def sample_grad_root_node(self, key, memoized=True, mc_samples=10, **kwargs):
+ """
+ B->B0 --> C
+ Compute gradient of p(stateB0|dirB0,stateB) wrt stateB, p(dirB0|dirB,stateB) wrt dirB, stateB
+ and gradient p(stateC|dirC,stateB) wrt stateB, p(dirC|dirB,stateB) wrt dirB, stateB
+ """
+ # Sample root params and compute initial gradients wrt children
+ root = self.root
+ direction_sample_grad = 0.
+ state_sample_grad = 0.
+
+ key, sample_grad = root['node'].direction_sample_and_grad(key, n_samples=mc_samples)
+ direction_curr_sample, direction_params_grad = sample_grad
+ key, sample_grad = root['node'].state_sample_and_grad(key, n_samples=mc_samples)
+ state_curr_sample, state_params_grad = sample_grad
+
+ # Gradient of entropy
+ direction_params_entropy_grad = root["node"].compute_direction_entropy_grad()
+ state_params_entropy_grad = root["node"].compute_state_entropy_grad()
+
+ # Gradient of likelihood
+ if memoized:
+ ll_grad = root["node"].compute_ll_state_grad_suff(state_curr_sample)
+ else:
+ weights = root['node'].variational_parameters['q_z'] * self.variational_parameters['q_c']
+ ll_grad = root["node"].compute_ll_state_grad(self.ntssb.data, weights, state_curr_sample)
+
+ # Gradient of children in TSSB
+ for child in root['children'][::-1]:
+ direction_child_sample = child['node'].get_direction_sample()
+ state_child_sample = child['node'].get_state_sample()
+ direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample)
+ state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, direction_curr_sample, state_curr_sample)
+ state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample)
+
+ # Gradient of roots of children TSSB
+ for i, child_root in enumerate(self.children_root_nodes):
+ direction_child_sample = child_root.get_direction_sample()
+ state_child_sample = child_root.get_state_sample()
+ # Gradient of the root nodes of children TSSBs wrt to their parameters using this TSSB root as parent
+ direction_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_direction(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i]
+ state_sample_grad += root["node"].compute_direction_prior_child_grad_wrt_state(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i]
+ state_sample_grad += root["node"].compute_state_prior_child_grad(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i]
+
+ direction_locals_grads = [direction_params_grad, direction_params_entropy_grad]
+ state_locals_grads = [state_params_grad, state_params_entropy_grad]
+
+ return ll_grad, [direction_locals_grads, state_locals_grads], [direction_sample_grad, state_sample_grad]
+
+ def compute_children_root_node_grads(self, **kwargs):
+ """
+ A -> B
+ Compute gradient of p(dirB|dirA,stateA) wrt dirB and p(stateB|dirB,stateA) wrt stateB, dirB
+ """
+ def descend(root, children_grads=None):
+ direction_curr_sample = root['node'].samples[1]
+ state_curr_sample = root['node'].samples[0]
+
+ # Compute gradient of children roots wrt their params
+ if children_grads is None:
+                children_grads = [[0., 0.] for _ in self.children_root_nodes]  # independent [direction, state] accumulators per child TSSB
+ for i, child_root in enumerate(self.children_root_nodes):
+ # Gradient of the root nodes of children TSSBs wrt to their parameters using this
+ direction_child_sample = child_root.get_direction_sample()
+ state_child_sample = child_root.get_state_sample()
+ direction_sample_grad = root["node"].compute_direction_prior_grad(direction_child_sample, direction_curr_sample, state_curr_sample) * root['node'].variational_parameters['q_rho'][i]
+ direction_sample_grad += root["node"].compute_state_prior_grad_wrt_direction(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i]
+ state_sample_grad = root["node"].compute_state_prior_grad(state_child_sample, state_curr_sample, direction_child_sample) * root['node'].variational_parameters['q_rho'][i]
+ children_grads[i][0] += direction_sample_grad
+ children_grads[i][1] += state_sample_grad
+
+ for child in root['children']:
+ descend(child, children_grads=children_grads)
+
+ return children_grads
+
+ # Return list of children, for each child a tuple with direction grad sum and sample grad sum
+ return descend(self.root)
+
+
+ def update_root_node_params(self, key, ll_grad, local_grads, children_grads, parent_grads, i=0, adaptive=True, step_size=0.0001, mc_samples=10, **kwargs):
+ """
+ ll_grads: ll_grad_state_D
+ local_grads: params_grad_D, params_entropy_grad_D
+ children_grads: grad_D p(D0|D) + grad_D p(D|D)
+ parent_grads: grad_D p(D|B) + grad_D p(D|B0)
+ """
+ if adaptive and i == 0:
+ self.root['node'].reset_opt()
+
+ direction_params_grad, direction_params_entropy_grad = local_grads[0]
+ direction_sample_grad = children_grads[0] + parent_grads[0]
+
+ # Combine gradients of functions wrt sample with gradient of sample wrt var params
+ if adaptive:
+ self.root['node'].update_direction_adaptive(direction_params_grad, direction_sample_grad, direction_params_entropy_grad,
+ step_size=step_size, i=i)
+ else:
+ self.root['node'].update_direction_params(direction_params_grad, direction_sample_grad, direction_params_entropy_grad,
+ step_size=step_size)
+ key, sample_grad = self.root['node'].direction_sample_and_grad(key, n_samples=mc_samples)
+ direction_curr_sample, _ = sample_grad
+ self.root['node'].samples[1] = direction_curr_sample
+
+ state_params_grad, state_params_entropy_grad = local_grads[1]
+ state_sample_grad = children_grads[1] + parent_grads[1]
+ state_sample_grad += ll_grad
+
+ if adaptive:
+ self.root['node'].update_state_adaptive(state_params_grad, state_sample_grad, state_params_entropy_grad,
+ step_size=step_size, i=i)
+ else:
+ self.root['node'].update_state_params(state_params_grad, state_sample_grad, state_params_entropy_grad,
+ step_size=step_size)
+ key, sample_grad = self.root['node'].state_sample_and_grad(key, n_samples=mc_samples)
+ state_curr_sample, _ = sample_grad
+ self.root['node'].samples[0] = state_curr_sample
+
+
+ def update_pivot_probs(self, **kwargs):
+ def descend(root):
+ # Update pivot assignment probabilities
+ new_log_probs = []
+ root['node'].variational_parameters['q_rho'] = [0.] * len(self.children_root_nodes)
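+            # q(rho_i = node) is proportional to exp(E[log p(child-root kernel | node)] + log pivot prior), normalized below with a logsumexp over this TSSB's nodes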
+ for i, child_root in enumerate(self.children_root_nodes):
+ pivot_direction_score = child_root.compute_root_direction_prior(root['node'].get_direction_sample())
+ pivot_state_score = child_root.compute_root_state_prior(root['node'].get_state_sample())
+ new_log_prob = pivot_direction_score + pivot_state_score + jnp.log(root['node'].pivot_prior_prob)
+ root['node'].variational_parameters['q_rho'][i] = new_log_prob
+ new_log_probs.append([new_log_prob])
+
+ for child in root['children']:
+ children_log_probs = descend(child)
+ for i, child_root in enumerate(self.children_root_nodes):
+ new_log_probs[i].extend(children_log_probs[i])
+
+ return new_log_probs
+
+ logqs = descend(self.root)
+ # Compute LSE for pivot probs
+ for i in range(len(self.children_root_nodes)):
+ this_logqs = jnp.array(logqs[i])
+ self.variational_parameters['LSE_rho'][i] = jax.scipy.special.logsumexp(this_logqs)
+
+ # Set probs and return sum of weighted likelihoods
+ def descend(root):
+ for i, child_root in enumerate(self.children_root_nodes):
+ root['node'].variational_parameters['q_rho'][i] = jnp.exp(root['node'].variational_parameters['q_rho'][i] - self.variational_parameters['LSE_rho'][i])
+ for child in root['children']:
+ descend(child)
+ # Normalize pivot probs
+ descend(self.root)
+
def cull_tree(self, verbose=False, resample_sticks=True):
"""
If a leaf node has no data assigned to it, remove it
@@ -989,10 +1542,9 @@ def descend(root):
return pivot_node, pivot_tssb
- def find_node(self, u, truncated=False):
+ def find_node(self, u, include_leaves=True):
def descend(root, u, depth=0):
if depth >= self.max_depth:
- # logger.debug >>sys.stderr, "WARNING: Reached maximum depth."
return (root["node"], [], root)
elif u < root["main"]:
return (root["node"], [], root)
@@ -1000,36 +1552,19 @@ def descend(root, u, depth=0):
# Rescale the uniform variate to the remaining interval.
u = (u - root["main"]) / (1.0 - root["main"])
- if not truncated:
- # Perhaps break sticks out appropriately.
- while (
- not root["children"] or (1.0 - prod(1.0 - root["sticks"])) < u
- ):
- stick_length = boundbeta(1, self.dp_gamma)
- root["sticks"] = vstack([root["sticks"], stick_length])
- root["children"].append(
- {
- "node": root["node"].spawn(
- False, self.root_node.observed_parameters
- ),
- "main": boundbeta(
- 1.0,
- (self.alpha_decay ** (depth + 1)) * self.dp_alpha,
- )
- if self.min_depth <= (depth + 1)
- else 0.0,
- "sticks": empty((0, 1)),
- "children": [],
- }
- )
- else:
- root["sticks"][-1] = 1.0
+ # Don't break sticks
edges = 1.0 - cumprod(1.0 - root["sticks"])
index = sum(u > edges)
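+            # In a fixed tree the psi sticks need not cover the whole interval; if u falls beyond the last edge, stay at this node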
+ if index >= len(root['sticks']):
+ return (root["node"], [], root)
edges = hstack([0.0, edges])
u = (u - edges[index]) / (edges[index + 1] - edges[index])
+ # Perhaps stop before continuing to a leaf
+ if not include_leaves and len(root["children"][index]["children"]) == 0:
+ return (root["node"], [], root)
+
(node, path, root) = descend(root["children"][index], u, depth + 1)
path.insert(0, index)
@@ -1038,6 +1573,33 @@ def descend(root, u, depth=0):
return descend(self.root, u)
+ def find_node_uniform(self, key, include_leaves=True):
+ def descend(root, key, depth=0):
+ if depth >= self.max_depth:
+ return (root["node"], [], root)
+ elif len(root["children"]) == 0:
+ return (root["node"], [], root)
+ else:
+ key, subkey = jax.random.split(key)
+ n_children = len(root["children"])
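+                # Stop at this node with probability 1/(n_children+1), otherwise descend into a uniformly chosen child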
+ if jax.random.bernoulli(subkey, p=1./(n_children+1)):
+ return (root["node"], [], root)
+ else:
+ key, subkey = jax.random.split(key)
+ index = jax.random.choice(subkey, len(root["children"]))
+
+ # Perhaps stop before continuing to a leaf
+ if not include_leaves and len(root["children"][index]["children"]) == 0:
+ return (root["node"], [], root)
+
+ (node, path, root) = descend(root["children"][index], key, depth + 1)
+
+ path.insert(0, index)
+
+ return (node, path, root)
+
+ return descend(self.root, key)
+
def get_expected_mixture(self, reset_names=False):
"""
Computes the expected weight for each node.
@@ -1085,6 +1647,26 @@ def descend(root):
return sr
return descend(self.root)
+ def set_weights(self):
+ def descend(root, mass):
+ root['weight'] = mass * root["main"]
+ edges = sticks_to_edges(root["sticks"])
+ weights = diff(hstack([0.0, edges]))
+ for i, child in enumerate(root["children"]):
+ descend(child, mass * (1.0 - root["main"]) * weights[i])
+ return descend(self.root, 1.0)
+
+ def set_expected_weights(self):
+ def descend(root):
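+            # Expected weight under the variational posterior: exp(E[log nu] + sum of ancestors' E[log(1-nu)] + E[log phi])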
+ logprior = E_log_beta(root['node'].variational_parameters['delta_1'], root['node'].variational_parameters['delta_2'])
+ if root['node'].parent() is not None:
+ logprior += root['node'].parent().variational_parameters['sum_E_log_1_nu']
+ logprior += root['node'].variational_parameters['E_log_phi']
+ root['weight'] = jnp.exp(logprior)
+ for child in root['children']:
+ descend(child)
+ descend(self.root)
+
def get_mixture(
self, reset_names=False, get_roots=False, get_depths=False, truncate=False
):
@@ -1481,20 +2063,21 @@ def label_nodes(self, counts=False, names=False):
elif not names or counts is True:
self.label_nodes_counts()
- def set_node_names(self, root_name="X"):
- self.root["label"] = str(root_name)
- self.root["node"].label = str(root_name)
+ def set_node_names(self, root=None, root_name="X"):
+ if root is None:
+ root = self.root
+
+ root["label"] = str(root_name)
+ root["node"].label = str(root_name)
def descend(root, name):
for i, child in enumerate(root["children"]):
- child_name = "%s-%d" % (name, i)
-
+ child_name = f"{name}-{i}"
root["children"][i]["label"] = child_name
root["children"][i]["node"].label = child_name
-
descend(child, child_name)
- descend(self.root, root_name)
+ descend(root, root_name)
def set_subcluster_node_names(self):
# Assumes the other fixed nodes have already been named, and ignores the root
diff --git a/scatrex/plotting/__init__.py b/scatrex/plotting/__init__.py
index ee3f060..5da138d 100644
--- a/scatrex/plotting/__init__.py
+++ b/scatrex/plotting/__init__.py
@@ -1 +1,2 @@
from .constants import *
+from .scatterplot import *
\ No newline at end of file
diff --git a/scatrex/plotting/scatterplot.py b/scatrex/plotting/scatterplot.py
index 9f8e71f..8aab3d4 100644
--- a/scatrex/plotting/scatterplot.py
+++ b/scatrex/plotting/scatterplot.py
@@ -3,7 +3,101 @@
"""
import matplotlib
import matplotlib.pyplot as plt
+import networkx as nx
import numpy as np
+from ..utils.tree_utils import tree_to_dict
+
+
+def plot_full_tree(tree, ax=None, figsize=(6,6), **kwargs):
+ if ax is None:
+ plt.figure(figsize=figsize)
+ ax = plt.gca()
+    pos = {}
+
+    def descend(root, graph):
+        pos_out = plot_tree(root['node'], G=graph, ax=ax, alpha=1., draw=False, **kwargs) # Draw subtree
+        pos.update(pos_out)
+        for child in root['children']:
+            descend(child, graph)
+
+    def sub_descend(sub_root, graph):
+        parent = sub_root['label']
+        for i, super_child in enumerate(tree['children']):
+            child = super_child['label']
+            graph.add_edge(parent, child, alpha=sub_root['pivot_probs'][i], ls='--')
+            nx.draw_networkx_edges(graph, pos, edgelist=[(parent, child)], edge_color=sub_root['color'], alpha=sub_root['pivot_probs'][i], style='--', ax=ax)
+        for child in sub_root['children']:
+            sub_descend(child, graph)
+
+    G = nx.DiGraph()
+    descend(tree, G) # Draw each subtree and collect node positions
+
+    if len(tree['children']) > 0:
+        sub_descend(tree['node'], G) # Draw pivot edges from this subtree's nodes to the children subtrees
+
+ ax.margins(0.20) # Set margins for the axes so that nodes aren't clipped
+ ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
+ ax.spines['right'].set_visible(False)
+ ax.spines['top'].set_visible(False)
+
+
+def plot_tree(tree, G = None, param_key='param', data=None, labels=True, alpha=0.5, font_size=12, node_size=1500, edge_width=1., arrows=True, draw=True, ax=None):
+ tree_dict = tree_to_dict(tree, param_key=param_key)
+
+ # Get all positions
+ pos = {}
+ pos[tree['label']] = tree[param_key]
+ for node in tree_dict:
+ if tree_dict[node]['parent'] != '-1':
+ pos[node] = tree_dict[node]['param']
+
+ # Draw graph
+ node_options = {'alpha': alpha,
+ 'node_size': node_size,}
+ edge_options = {'alpha': alpha,
+ 'width': edge_width,
+ 'node_size':node_size,
+ 'arrows': arrows}
+ label_options = {'alpha': alpha,
+ 'font_size': font_size,}
+
+    if ax is None:
+        plt.figure(figsize=(6,6))
+        ax = plt.gca()
+
+    if G is None:
+        G = nx.DiGraph()
+    for node in tree_dict:
+        nx.draw_networkx_nodes(G, pos, nodelist=[node], node_color=tree_dict[node]['color'],
+                        ax=ax, **node_options)
+        if tree_dict[node]['parent'] != '-1':
+            parent = tree_dict[node]['parent']
+            G.add_edge(parent, node)
+            nx.draw_networkx_edges(G, pos, edgelist=[(parent, node)], edge_color=tree_dict[parent]['color'], ax=ax, **edge_options)
+    if labels:
+        labs = dict(zip(list(tree_dict.keys()), list(tree_dict.keys())))
+        nx.draw_networkx_labels(G, pos, labs, ax=ax, **label_options)
+
+ ax.margins(0.20) # Set margins for the axes so that nodes aren't clipped
+ ax.tick_params(left=True, bottom=True, labelleft=True, labelbottom=True)
+ ax.spines['right'].set_visible(False)
+ ax.spines['top'].set_visible(False)
+ if draw:
+ plt.show()
+ else:
+ return pos
+
+def plot_nested_tree(tree, top=True, param_key='param', out_alpha=0.4, in_alpha=1., large_node_size=5000, small_node_size=500, draw=True, ax=None, **kwargs):
+ tree_dict = tree_to_dict(tree, param_key=param_key)
+
+    if top:
+        # Plot the main tree with transparency, large nodes and without labels
+        plot_tree(tree, param_key=param_key, labels=False, node_size=large_node_size, alpha=out_alpha, draw=False, ax=ax, **kwargs)
+    if ax is None:
+        ax = plt.gca()
+
+ for subtree in tree_dict: # Do this in a tree traversal so that we add the pivots
+ # Plot each subtree
+ plot_tree(tree_dict[subtree]['node'], param_key='mean', labels=False, node_size=small_node_size, alpha=in_alpha, ax=ax, draw=False, **kwargs)
+
+ if draw:
+ plt.show()
def plot_tree_proj(
diff --git a/scatrex/plotting/tree_colors.py b/scatrex/plotting/tree_colors.py
new file mode 100644
index 0000000..89e1bde
--- /dev/null
+++ b/scatrex/plotting/tree_colors.py
@@ -0,0 +1,38 @@
+# Functions to make a colormap out of a tree
+import matplotlib
+import seaborn as sns
+from colorsys import rgb_to_hls
+
+def adjust_color(rgba, lightness_scale, saturation_scale):
+    # Scale the lightness and saturation in HLS space (values are clipped to [0, 1])
+ rgb = rgba[:-1]
+ a = rgba[-1]
+ hls = rgb_to_hls(*rgb)
+ lightness = max(0,min(1, hls[1] * lightness_scale))
+ saturation = max(0,min(1,hls[2] * saturation_scale))
+ rgb = sns.set_hls_values(color = rgb, h = None, l = lightness, s = saturation)
+ rgba = rgb + (a,)
+ hex = matplotlib.colors.to_hex(rgba, keep_alpha=True)
+ return hex
+
+
+def make_tree_colormap(tree, base_color, brightness_mult=0.7, saturation_mult=1.3):
+ """
+ Updates the tree dictionary with colors defined from node depth and breadth
+ tree: nested dictionary containing {node: children} with children a list of also dictionaries {child: children}
+ base_color: HEX code for base color
+ saturation_mult: how much to change saturation for each step in depth, centered at 1
+ brightness_mult: how much to change brightness for each step in breadth, centered at 1
+ """
+ base_color_rgba = matplotlib.colors.ColorConverter.to_rgba(base_color)
+
+ tree['color'] = base_color
+
+    # Traverse the tree, assigning each node a color based on its depth and breadth
+ def descend(root, depth=1, breadth=1):
+ for i, child in enumerate(root['children']):
+ breadth += i
+ color = adjust_color(base_color_rgba, brightness_mult**depth, saturation_mult**breadth)
+ child['color'] = color
+ descend(child, depth=depth+1, breadth=breadth)
+ descend(tree)
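+
+# Hedged toy example (hypothetical two-child tree): the root keeps base_color and each
+# child gets a color that darkens with depth and gains saturation with breadth.
+def _toy_colormap_example():
+    toy = {'children': [{'children': []}, {'children': []}]}
+    make_tree_colormap(toy, base_color='#1f77b4')
+    return toy['color'], [child['color'] for child in toy['children']]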
diff --git a/scatrex/scatrex.py b/scatrex/scatrex.py
index c1cf3cc..5082a33 100644
--- a/scatrex/scatrex.py
+++ b/scatrex/scatrex.py
@@ -1,7 +1,11 @@
-from .models import *
+from .models import CNATree, TrajectoryTree
from .ntssb import NTSSB
from .ntssb import StructureSearch
from .plotting import scatterplot, constants
+from .utils import tree_utils
+
+import jax
+import jax.numpy as jnp
import numpy as np
from sklearn.decomposition import PCA
@@ -25,13 +29,13 @@
class SCATrEx(object):
def __init__(
self,
- model=cna,
+ model=CNATree,
model_args=dict(),
verbosity=logging.INFO,
temppath="./temppath",
+ seed=42,
):
-
- self.model = cna
+ self.model = model
self.model_args = model_args
self.observed_tree = None
self.ntssb = None
@@ -42,6 +46,7 @@ def __init__(
self.temppath = temppath
if not os.path.exists(temppath):
os.makedirs(temppath, exist_ok=True)
+ self.seed = seed
logger.setLevel(verbosity)
@@ -93,14 +98,11 @@ def simulate_tree(
observed_tree=None,
observed_tree_args=dict(),
observed_tree_params=dict(),
- model_args=None,
n_genes=50,
n_extra_per_observed=1,
- seed=42,
copy=False,
+ **cmap_kwargs,
):
- np.random.seed(seed)
-
self.observed_tree = observed_tree
if not self.observed_tree:
@@ -111,7 +113,7 @@ def simulate_tree(
for arg in observed_tree_args:
logger.info(f"{arg}: {observed_tree_args[arg]}")
- self.observed_tree = self.model.ObservedTree(**observed_tree_args)
+ self.observed_tree = self.model(**observed_tree_args)
self.observed_tree.generate_tree()
self.observed_tree.add_node_params(n_genes=n_genes, **observed_tree_params)
@@ -123,58 +125,281 @@ def simulate_tree(
logger.info(f"{arg}: {self.model_args[arg]}")
self.ntssb = NTSSB(
- self.observed_tree, self.model.Node, node_hyperparams=self.model_args
+ self.observed_tree, node_hyperparams=self.model_args, seed=self.seed
)
self.ntssb.create_new_tree(n_extra_per_observed=n_extra_per_observed)
+ self.ntssb.set_ntssb_colors(**cmap_kwargs)
+
logger.info("Tree is stored in `self.observed_tree` and `self.ntssb`")
return self.ntssb if copy else None
- def simulate_data(self, n_cells=100, seed=42, copy=False):
- np.random.seed(seed)
- self.ntssb.put_data_in_nodes(n_cells)
- self.ntssb.root["node"].root["node"].generate_data_params()
-
- # Sample observations
- observations = []
- assignments = []
- assignments_labels = []
- assignments_obs_labels = []
- for obs in range(len(self.ntssb.assignments)):
- sample = self.ntssb.assignments[obs].sample_observation(obs).reshape(1, -1)
- observations.append(sample)
- assignments.append(self.ntssb.assignments[obs])
- assignments_labels.append(self.ntssb.assignments[obs].label)
- assignments_obs_labels.append(self.ntssb.assignments[obs].tssb.label)
- assignments = np.array(assignments)
- assignments_labels = np.array(assignments_labels)
- assignments_obs_labels = np.array(assignments_obs_labels)
- observations = np.concatenate(observations)
-
- self.ntssb.data = observations
- self.ntssb.num_data = observations.shape[0]
+ def simulate_data(self, n_cells=100, copy=False):
+ node_assignments, obs_node_assignments = self.ntssb.sample_assignments(n_cells)
+ data = np.array(self.ntssb.simulate_data())
+ noiseless_data = self.ntssb.root['node'].root['node'].remove_noise(data)
- if self.ntssb.root["node"].root["node"].num_batches > 1:
- self.ntssb.covariates = self.ntssb.root["node"].root["node"].cell_covariates
+ self.adata = AnnData(data)
+ self.adata.layers['corrected'] = noiseless_data
+ self.adata.obs["obs_node"] = obs_node_assignments
+ self.adata.uns["obs_node_colors"] = [
+ self.observed_tree.tree_dict[node]["color"]
+ for node in self.observed_tree.tree_dict
+ ]
+ tree = self.ntssb.get_param_dict()
+ tree_dict = tree_utils.tree_to_dict(tree)
+ node_colors = []
+ for node in tree_dict:
+ d = tree_utils.tree_to_dict(tree_dict[node]['node'])
+ for n in d:
+ if d[n]['size'] > 0:
+ node_colors.append(d[n]['color'])
+ self.adata.obs["node"] = node_assignments
+ self.adata.uns["node_colors"] = node_colors[:] # to remove the root
+ self.adata.raw = self.adata
logger.info("Labeled data are stored in `self.adata`")
- self.adata = AnnData(observations)
- self.adata.obs["node"] = assignments_labels
- self.adata.obs["obs_node"] = assignments_obs_labels
- if self.ntssb.root["node"].root["node"].num_batches > 1:
- self.adata.obs["batch"] = np.argmax(self.ntssb.covariates, axis=1)
+ return self.adata if copy else None
- self.adata.uns["obs_node_colors"] = [
+ def learn_scales(self, n_epochs=100, mc_samples=10, step_size=0.01):
+ logger.info("Learning cell and gene scales")
+ n_cells = self.ntssb.data.shape[0]
+ n_genes = self.ntssb.data.shape[1]
+ gs = np.sqrt(np.median(self.ntssb.data))
+
+ root = self.ntssb.root['node'].root['node']
+ gene_scales_alpha_init = 10. * jnp.ones((n_genes,)) #* jnp.exp(np.random.normal(size=self.n_genes))
+ gene_scales_beta_init = 10. * jnp.ones((n_genes,)) * jnp.exp(gs + 0. * np.random.normal(size=n_genes))
+ root.variational_parameters['global']['gene_scales']['log_alpha'] = jnp.log(gene_scales_alpha_init)
+ root.variational_parameters['global']['gene_scales']['log_beta'] = jnp.log(gene_scales_beta_init)
+
+ cell_scales_alpha_init = 10. * jnp.ones((n_cells,1)) #* jnp.exp(np.random.normal(size=[500,1]))
+ cell_scales_beta_init = 10. * jnp.ones((n_cells,1)) * jnp.exp(gs + 0. * np.random.normal(size=[n_cells,1]))
+ root.variational_parameters['local']['cell_scales']['log_alpha'] = jnp.log(cell_scales_alpha_init)
+ root.variational_parameters['local']['cell_scales']['log_beta'] = jnp.log(cell_scales_beta_init)
+
+ # Initialize MC samples
+ self.ntssb.sample_variational_distributions(n_samples=mc_samples)
+ self.ntssb.update_sufficient_statistics()
+
+ # Update cell, gene scales, assignments
+        self.ntssb.learn_globals(n_epochs=n_epochs, step_size=step_size, mc_samples=mc_samples,
+ update_ass=True, update_locals=True, update_roots=False,
+ locals_names=['cell_scales'],
+ globals_names=['gene_scales'])
+
+ # Add noise factors to learn better scales
+ # root.variational_parameters['global']['factor_weights']['mean'] = jnp.array(0.01 * np.random.normal(size=(2, n_genes)))
+ # root.variational_parameters['local']['obs_weights']['mean'] = jnp.array(0.01 * np.random.normal(size=(500, 2)))
+ self.ntssb.sample_variational_distributions(n_samples=mc_samples)
+ self.ntssb.learn_globals(n_epochs=n_epochs, step_size=step_size, mc_samples=mc_samples,
+ update_ass=True, update_locals=True, update_roots=False)
+
+ def learn_roots_and_noise(self, n_iters=10, n_epochs=100, n_merges=10, n_swaps=10, memoized=True, mc_samples=10, step_size=0.01, seed=42):
+ logger.info("Learning roots and noise")
+ # Remove noise and learn roots
+ self.ntssb.set_node_hyperparams(n_factors=0)
+ self.ntssb.root['node'].root['node'].reset_variational_noise_factors()
+ self.ntssb.sample_variational_distributions(n_samples=mc_samples)
+ self.ntssb.update_sufficient_statistics()
+ self.ntssb.learn_roots(n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, return_trace=False)
+
+ # Update assignments
+ self.ntssb.update_local_params(jax.random.PRNGKey(seed), update_ass=True, update_globals=False)
+
+ # Learn a tree with root updates on noiseless data (over-cluster)
+ searcher = StructureSearch(self.ntssb)
+ searcher.tree.set_tssb_params(dp_alpha=1., dp_gamma=1.,)
+        searcher.tree.sample_variational_distributions(n_samples=mc_samples)
+ searcher.tree.update_sufficient_statistics()
+ searcher.tree.compute_elbo(memoized=memoized)
+ searcher.proposed_tree = deepcopy(searcher.tree)
+ searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,
+ memoized=memoized, seed=seed, update_roots=True)
+
+ self.ntssb = deepcopy(searcher.tree)
+ self.ntssb.set_node_hyperparams(n_factors=self.model_args['n_factors'])
+ self.ntssb.root['node'].root['node'].reset_variational_noise_factors()
+ self.ntssb.sample_variational_distributions(n_samples=mc_samples)
+ self.ntssb.update_sufficient_statistics()
+ self.ntssb.compute_elbo(memoized=memoized)
+ # Cleanup parameters: learn noise and update all parameters (including roots) except scales, no memoization needed
+ self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True,
+ locals_names=['obs_weights'],
+ globals_names=['factor_weights', 'factor_precisions'],
+ update_roots=True, step_size=step_size, mc_samples=mc_samples,
+ memoized=False)
+
+ self.ntssb.update_sufficient_statistics()
+ self.ntssb.compute_elbo(memoized=memoized)
+
+ # Propose merges to account for new noise
+ searcher = StructureSearch(self.ntssb)
+ searcher.tree.sample_variational_distributions(n_samples=mc_samples)
+ searcher.tree.update_sufficient_statistics()
+ searcher.tree.compute_elbo(memoized=memoized)
+ searcher.proposed_tree = deepcopy(searcher.tree)
+ key = jax.random.PRNGKey(seed)
+ for i in range(10):
+ key, subkey = jax.random.split(key)
+ searcher.merge(subkey, moves_per_tssb=1, memoized=memoized)
+
+ searcher.tree.update_sufficient_statistics()
+ searcher.tree.compute_elbo(memoized=memoized)
+ searcher.proposed_tree = deepcopy(searcher.tree)
+ # Propose root merges with opt after these merges
+ for i in range(n_merges):
+ key, subkey = jax.random.split(key)
+ searcher.merge_root(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size)
+
+ searcher.tree.update_sufficient_statistics()
+ searcher.tree.compute_elbo(memoized=memoized)
+ searcher.proposed_tree = deepcopy(searcher.tree)
+ # Propose root swaps after these merges
+ for i in range(n_swaps):
+ key, subkey = jax.random.split(key)
+ searcher.swap(subkey, memoized=memoized, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,
+ update_ass=False) # fixed assignments
+
+ # Re-learn noise and parameters except scales
+ self.ntssb = deepcopy(searcher.tree)
+ self.ntssb.sample_variational_distributions(n_samples=mc_samples)
+ self.ntssb.update_sufficient_statistics()
+ self.ntssb.compute_elbo(memoized=memoized)
+ self.ntssb.learn_model(n_epochs=n_epochs, update_ass=True, update_globals=True,
+ locals_names=['obs_weights'],
+ globals_names=['factor_weights', 'factor_precisions'],
+ update_roots=True, step_size=step_size, mc_samples=mc_samples,
+ memoized=False) # If I do this with memoization the noise ends up explaining too much and messing up the scales! Best not to mix
+
+ def learn_tree(self, n_iters=10, n_epochs=100, memoized=True, mc_samples=10, step_size=0.01, dp_alpha=.1, dp_gamma=.1, prune=True, seed=42):
+ logger.info("Learning augmented tree")
+ # Re-learn tree (optionally from scratch) with fixed noise and roots
+ if prune:
+ self.ntssb.prune_subtrees()
+ searcher = StructureSearch(self.ntssb)
+ searcher.tree.set_tssb_params(dp_alpha=dp_alpha, dp_gamma=dp_gamma)
+ searcher.tree.sample_variational_distributions(n_samples=mc_samples)
+ searcher.tree.update_sufficient_statistics()
+ searcher.tree.compute_elbo(memoized=memoized)
+ searcher.proposed_tree = deepcopy(searcher.tree)
+ searcher.run_search(n_iters=n_iters, n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size,
+ memoized=memoized, seed=seed,
+ update_roots=False)
+ self.ntssb = deepcopy(searcher.tree)
+
+ def update_anndata(self, adata):
+ """
+        Use the learned NTSSB to add annotations to the input AnnData object.
+        Cells and genes must be in the same order as the data in the NTSSB!
+ """
+ node_assignments = []
+ for i in range(adata.shape[0]):
+ node_assignments.append(self.ntssb.assignments[i].label)
+
+ obs_node_assignments = []
+ for i in range(adata.shape[0]):
+ obs_node_assignments.append(self.ntssb.assignments[i].tssb.label)
+
+ adata.obs["scatrex_node"] = node_assignments
+ adata.obs["scatrex_obs_node"] = obs_node_assignments
+
+ adata.uns["scatrex_node_colors"] = [
+ self.observed_tree.tree_dict[node.split("-")[0]]["color"]
+ for node in np.unique(adata.obs["scatrex_node"])
+ ]
+ adata.uns["scatrex_obs_node_colors"] = [
self.observed_tree.tree_dict[node]["color"]
- for node in self.observed_tree.tree_dict
+ for node in np.unique(adata.obs["scatrex_obs_node"])
]
- self.adata.raw = self.adata
- return (observations, assignments_labels) if copy else None
+ labels = list(self.observed_tree.tree_dict.keys())
+ sizes = [
+ np.count_nonzero(adata.obs["scatrex_obs_node"] == label)
+ for label in labels
+ ]
+ adata.uns["scatrex_estimated_frequencies"] = dict(zip(labels, sizes))
- def learn_tree(
+ adata.layers["scatrex_noise"] = (
+ self.ntssb.root["node"]
+ .root["node"]
+ .variational_parameters["local"]["obs_weights"]['mean']
+ .dot(
+ self.ntssb.root["node"]
+ .root["node"]
+ .variational_parameters["global"]["factor_weights"]['mean']
+ )
+ )
+
+ genes_pos = np.arange(adata.var_names.size)
+
+ xi_mat = np.zeros(adata.shape)
+ om_mat = np.zeros(adata.shape)
+ cnv_mat = np.zeros(adata.shape)
+ mean_mat = np.zeros(adata.shape)
+ nodes = np.array(self.ntssb.get_nodes())
+ nodes_labels = np.array([node.label for node in nodes])
+ for node_id in np.unique(adata.obs["scatrex_node"]):
+ cells = np.where(adata.obs["scatrex_node"] == node_id)[0]
+ node = nodes[np.where(node_id == nodes_labels)[0][0]]
+ pos = np.meshgrid(cells, genes_pos)
+ xi_mat[tuple(pos)] = (
+ np.array(
+ node.params[0]
+ ).reshape(-1, 1)
+ * np.ones((len(cells), len(genes_pos))).T
+ )
+ om_mat[tuple(pos)] = (
+ np.array(
+ node.params[1]
+ ).reshape(-1, 1)
+ * np.ones((len(cells), len(genes_pos))).T
+ )
+ cnv_mat[tuple(pos)] = (
+ np.array(node.cnvs).reshape(-1, 1)
+ * np.ones((len(cells), len(genes_pos))).T
+ )
+ mean_mat[tuple(pos)] = (
+ np.array(node.get_mean()).reshape(-1, 1)
+ * np.ones((len(cells), len(genes_pos))).T
+ )
+ adata.layers["scatrex_cell_states"] = xi_mat
+ adata.layers["scatrex_cell_state_events"] = om_mat
+ adata.layers["scatrex_cnvs"] = cnv_mat
+ adata.layers["scatrex_mean"] = mean_mat
+
+ def learn(self, adata, observed_tree=None, counts_layer='counts', allow_subtrees=True, allow_root_subtrees=False, root_cells=None,
+ batch_size=None, seed=42,
+ n_epochs=100, mc_samples=10, step_size=0.01, n_iters=10, n_merges=10, n_swaps=10, memoized=True, dp_alpha=.1, dp_gamma=.1):
+ """
+ Complete NTSSB learning procedure.
+ """
+ if observed_tree is not None:
+ self.set_observed_tree(observed_tree)
+
+ # Setup NTSSB
+ self.ntssb = NTSSB(self.observed_tree,
+ node_hyperparams=self.model_args,
+ seed=seed,)
+ self.ntssb.add_data(np.array(adata.layers[counts_layer]))
+ self.ntssb.make_batches(batch_size, seed)
+ self.ntssb.reset_variational_parameters()
+
+ # Learn
+ self.learn_scales(n_epochs=n_epochs, mc_samples=mc_samples, step_size=step_size)
+ self.learn_roots_and_noise(n_iters=n_iters, n_epochs=n_epochs, n_merges=n_merges, n_swaps=n_swaps, memoized=memoized, mc_samples=mc_samples, step_size=step_size, seed=seed)
+ self.learn_tree(n_iters=n_iters, n_epochs=n_epochs, memoized=memoized, mc_samples=mc_samples, step_size=step_size, dp_alpha=dp_alpha, dp_gamma=dp_gamma, prune=True, seed=seed)
+
+ # Create outputs for analysis
+ self.ntssb.set_learned_parameters()
+ self.ntssb.assign_samples()
+ self.ntssb.create_augmented_tree_dict()
+ self.ntssb.initialize_gene_node_colormaps()
+ self.update_anndata(adata)
+
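+    # Hedged usage sketch for the full pipeline (hypothetical AnnData `adata` with a
+    # 'counts' layer and a CNATree `obs_tree` built beforehand; values are illustrative):
+    #
+    #   sca = SCATrEx(model=CNATree, model_args=dict(n_factors=2), seed=0)
+    #   sca.learn(adata, observed_tree=obs_tree, counts_layer='counts',
+    #             n_epochs=50, n_iters=5)
+    #   adata.obs['scatrex_node']      # per-cell node assignments (set by update_anndata)
+    #   adata.layers['scatrex_cnvs']   # per-cell copy numbers of the assigned nodes
+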
+ def _learn_tree(
self,
observed_tree=None,
reset=True,
@@ -363,7 +588,9 @@ def learn_tree(
np.array(
np.exp(
node.variational_parameters["locals"][
- "unobserved_factors_kernel_log_mean"
+ "unobserved_factors_kernel_log_shape"
+ ] - node.variational_parameters["locals"][
+ "unobserved_factors_kernel_log_rate"
]
)
).reshape(-1, 1)
@@ -783,7 +1010,7 @@ def assign_new_data(self, data):
return
def set_node_event_strings(self, **kwargs):
- self.ntssb.set_node_event_strings(var_names=sim_sca.adata.var_names, **kwargs)
+ self.ntssb.set_node_event_strings(var_names=self.adata.var_names, **kwargs)
def plot_tree(
self,
@@ -871,6 +1098,63 @@ def plot_tree(
return g
+ def plot_data(self, level=0, draw=True, layer=None, color=None, remove_noise=False, **kwargs):
+ # Plot data projection using node colors
+ data = self.adata.X
+ if layer is not None:
+ data = self.adata.layers[layer]
+
+ if remove_noise:
+ # Remove noise according to noise model
+ data = self.ntssb.root['node'].root['node'].remove_noise(data)
+
+ def super_descend(super_root, go_down=True):
+ if go_down:
+ descend(super_root['node'].root)
+ else:
+ # Plot data
+ attached_cells = np.array(list(super_root['node']._data))
+                if len(attached_cells) > 0:
+                    cell_color = super_root['color']
+                    if color is not None:
+                        cell_color = color
+                    plt.scatter(data[attached_cells,0], data[attached_cells,1],
+                            color=cell_color, label=super_root['label'], **kwargs)
+
+ for super_child in super_root['children']:
+ super_descend(super_child, go_down=go_down)
+
+ def descend(root):
+ # Plot data
+ attached_cells = np.array(list(root['node'].data))
+            if len(attached_cells) > 0:
+ cell_color = root['color']
+ if color is not None:
+ cell_color = color
+ plt.scatter(data[attached_cells,0], data[attached_cells,1],
+ color=cell_color, label=root['label'], **kwargs)
+ for child in root['children']:
+ descend(child)
+
+ if level == 0: # Unobs
+ super_descend(self.ntssb.root, go_down=True)
+ elif level == 1: # Obs
+ super_descend(self.ntssb.root, go_down=False)
+
+ if draw:
+ plt.show()
+
+ def plot_tree_projection(self, level=None, ax=None, title="", **kwargs):
+
+ tree = self.ntssb.get_param_dict()
+ if level is None: # Both levels
+ ax = scatterplot.plot_nested_tree(tree, param_key='obs_param', top=True, ax=ax, **kwargs)
+ elif level == 1: # Only obs
+ ax = scatterplot.plot_tree(tree, param_key='obs_param', ax=ax, **kwargs)
+ elif level == 0: # Only unobs
+ ax = scatterplot.plot_nested_tree(tree, top=False, ax=ax, **kwargs)
+ return ax
+
def plot_tree_proj(
self,
project=True,
@@ -981,6 +1265,8 @@ def plot_unobserved_parameters(
if gene_names is not None:
# Transform gene names into gene indices
genes = np.array([self.adata.var_names.get_loc(g) for g in gene_names])
+ else:
+ genes = np.arange(len(self.adata.var_names))
if self.search is not None:
if len(self.search.traces["elbo"]) > 0:
@@ -1038,14 +1324,17 @@ def plot_unobserved_parameters(
ls=ls,
)
else:
- plt.plot(
- unobs[genes].ravel() - step * i,
+ plt.bar(
+ np.arange(len(genes)),
+                        unobs[genes].ravel(),
+ bottom=- step * i,
label=node.label,
color=node.tssb.color,
lw=lw,
alpha=alpha,
ls=ls,
)
+ plt.plot(np.zeros((len(genes),)) - step * i, ls='--', color='gray', alpha=alpha)
if gene_names is not None and show_names:
plt.xticks(np.arange(len(gene_names)), labels=gene_names)
else:
@@ -1059,6 +1348,77 @@ def plot_unobserved_parameters(
if ax is None:
plt.show()
+ def plot_estimated_unobserved_kernels(
+ self,
+ node_names=None,
+ gene=None,
+ gene_names=None,
+ ax=None,
+ figsize=(4, 4),
+ lw=4,
+ alpha=0.7,
+ title="",
+ fontsize=18,
+ step=4,
+ x_max=1,
+ show_names=False,
+ save=None,
+ ):
+ nodes, _ = self.ntssb.get_node_mixture()
+
+ if node_names is not None:
+ nodes = [node for node in nodes if node.label in node_names]
+
+ genes = None
+ if gene_names is not None:
+ # Transform gene names into gene indices
+ genes = np.array([self.adata.var_names.get_loc(g) for g in gene_names])
+ else:
+ genes = np.arange(len(self.adata.var_names))
+
+ if self.search is not None:
+ if len(self.search.traces["elbo"]) > 0:
+ estimated = True
+
+ if ax is None:
+ plt.figure(figsize=figsize)
+ else:
+            plt.sca(ax)
+ ticklabs = []
+ tickpos = []
+ for i, node in enumerate(nodes):
+ if node.parent() is not None:
+ ls = "-"
+ shape = np.exp(node.variational_parameters["locals"]["unobserved_factors_kernel_log_shape"])
+ rate = np.exp(node.variational_parameters["locals"]["unobserved_factors_kernel_log_rate"])
+ unobs = shape/rate
+ std = np.sqrt(
+ shape/rate**2
+ )
+ plt.bar(
+ np.arange(len(genes)),
+ unobs[genes].ravel(),
+ bottom= - step * i,
+ label=node.label,
+ color=node.tssb.color,
+ lw=lw,
+ alpha=alpha,
+ ls=ls,
+ )
+ plt.plot(np.zeros((len(genes),)) - step * i, ls='--', color='gray', alpha=alpha)
+ if gene_names is not None and show_names:
+ plt.xticks(np.arange(len(gene_names)), labels=gene_names)
+ else:
+ plt.xticks([])
+ tickpos.append(-step * i)
+ ticklabs.append(rf"{node.label.replace('-', '')}")
+ plt.yticks(tickpos, labels=ticklabs, fontsize=fontsize)
+ plt.title(title, fontsize=fontsize)
+ if save is not None:
+ plt.savefig(save, bbox_inches="tight")
+ if ax is None:
+ plt.show()
+
def bulkify(self):
self.adata.var["raw_bulk"] = np.mean(self.adata.X, axis=0)
try:
diff --git a/scatrex/utils/annotate_utils.py b/scatrex/utils/annotate_utils.py
new file mode 100644
index 0000000..e495ab6
--- /dev/null
+++ b/scatrex/utils/annotate_utils.py
@@ -0,0 +1,105 @@
+from pybiomart import Server
+import numpy as np
+import pandas as pd
+
+
+def convert_tidy_to_matrix(tidy_df, rows="single_cell_id", columns="copy_number"):
+ # Takes a tidy dataframe specifying the CNVs of cells along genomic bins
+ # and converts it to a cell by bin matrix
+ cell_df = tidy_df.loc[tidy_df[rows] == tidy_df[rows][0]]
+ bins_df = cell_df.drop(columns=[columns, rows], inplace=False)
+ tidy_df["bin_id"] = np.tile(bins_df.index, tidy_df[rows].unique().size)
+ matrix = tidy_df[[columns, rows, "bin_id"]].pivot_table(
+ values=columns, index=rows, columns="bin_id"
+ )
+
+ return matrix, bins_df
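+
+# Hedged toy example for convert_tidy_to_matrix (hypothetical column values; the resulting
+# bins would then go through annotate_bins and annotate_matrix, which query Ensembl):
+def _toy_tidy_example():
+    tidy = pd.DataFrame({
+        "single_cell_id": ["c1", "c1", "c2", "c2"],
+        "chr": ["1", "1", "1", "1"],
+        "start": [0, 1000, 0, 1000],
+        "end": [1000, 2000, 1000, 2000],
+        "copy_number": [2, 3, 2, 2],
+    })
+    matrix, bins_df = convert_tidy_to_matrix(tidy)
+    return matrix, bins_df  # matrix: cells x bins; bins_df: per-bin genomic coordinates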
+
+
+def annotate_bins(bins_df):
+    # Takes a dataframe of genomic regions and returns a copy annotated with the ordered list of genes fully contained in each region
+ server = Server("www.ensembl.org", use_cache=False)
+ dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets["hsapiens_gene_ensembl"]
+ gene_coordinates = dataset.query(
+ attributes=[
+ "chromosome_name",
+ "start_position",
+ "end_position",
+ "external_gene_name",
+ ],
+ filters={
+ "chromosome_name": [
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 20,
+ 21,
+ 22,
+ "X",
+ "Y",
+ ]
+ },
+ use_attr_names=True,
+ )
+    # Drop duplicate genes
+    gene_coordinates = gene_coordinates.drop_duplicates(subset="external_gene_name", ignore_index=True)
+
+ annotated_bins = bins_df.copy()
+ annotated_bins["genes"] = [list() for _ in range(annotated_bins.shape[0])]
+
+ bin_size = bins_df["end"][0] - bins_df["start"][0]
+ for index, row in gene_coordinates.iterrows():
+ gene = row["external_gene_name"]
+ if pd.isna(gene):
+ continue
+ start_bin_in_chr = int(row["start_position"] / bin_size)
+ stop_bin_in_chr = int(row["end_position"] / bin_size)
+ chromosome = str(row["chromosome_name"])
+ chr_start = np.where(bins_df["chr"] == chromosome)[0][0]
+ start_bin = start_bin_in_chr + chr_start
+ stop_bin = stop_bin_in_chr + chr_start
+
+ if stop_bin < annotated_bins.shape[0]:
+ if np.all(annotated_bins.iloc[start_bin:stop_bin].chr == chromosome):
+ for bin in range(start_bin, stop_bin + 1):
+ annotated_bins.loc[bin, "genes"].append(gene)
+
+ return annotated_bins
+
+
+def annotate_matrix(matrix, annotated_bins):
+ # Takes a dataframe of cells by bins and a dataframe with gene lists for each bin
+ # and returns a dataframe of cells by genes
+ df_list = []
+ chrs = []
+ for bin, row in annotated_bins.iterrows():
+ genes = row["genes"]
+ chr = row["chr"]
+ if len(genes) > 0:
+ df_list.append(
+ pd.concat([matrix[bin]] * len(genes), axis=1, ignore_index=True).rename(
+ columns=dict(zip(range(len(genes)), genes))
+ )
+ )
+ chrs.append([chr] * df_list[-1].shape[1])
+ chrs = np.concatenate(chrs)
+ df = pd.concat(df_list, axis=1)
+ chrs = chrs[np.where(~df.columns.duplicated())[0]]
+ df = df.loc[:, ~df.columns.duplicated()]
+ return df, chrs
+
diff --git a/scatrex/util.py b/scatrex/utils/math_utils.py
similarity index 71%
rename from scatrex/util.py
rename to scatrex/utils/math_utils.py
index 90eb396..e8544f8 100644
--- a/scatrex/util.py
+++ b/scatrex/utils/math_utils.py
@@ -5,6 +5,7 @@
from functools import partial
import numpy as np
+import jax
from jax import vmap
from jax import random
import jax.numpy as jnp
@@ -12,8 +13,6 @@
from jax.scipy.stats import norm, gamma, laplace, beta, dirichlet, poisson
from jax.scipy.special import digamma, betaln, gammaln
-from pybiomart import Server
-import pandas as pd
def relative_difference(current, prev, eps=1e-6):
@@ -25,7 +24,7 @@ def absolute_difference(current, prev):
def diag_gamma_sample(rng, log_alpha, log_beta):
- return jnp.exp(-log_beta) * random.gamma(rng, jnp.exp(log_alpha))
+ return jnp.clip(jnp.exp(-log_beta) * random.gamma(rng, jnp.exp(log_alpha)), a_min=1e-10, a_max=1e30)
def diag_gamma_logpdf(x, log_alpha, log_beta):
@@ -384,9 +383,9 @@ def betapdfln(x, a, b):
)
-def boundbeta(a, b):
+def boundbeta(a, b, rng):
return (1.0 - numpy.finfo(numpy.float64).eps) * (
- numpy.random.beta(a, b) - 0.5
+ rng.beta(a, b) - 0.5
) + 0.5
# return numpy.random.beta(a,b)
@@ -411,109 +410,70 @@ def logsumexp(X, axis=None):
return numpy.log(numpy.sum(numpy.exp(X - maxes), axis=axis)) + maxes
-def convert_tidy_to_matrix(tidy_df, rows="single_cell_id", columns="copy_number"):
- # Takes a tidy dataframe specifying the CNVs of cells along genomic bins
- # and converts it to a cell by bin matrix
- cell_df = tidy_df.loc[tidy_df[rows] == tidy_df[rows][0]]
- bins_df = cell_df.drop(columns=[columns, rows], inplace=False)
- tidy_df["bin_id"] = np.tile(bins_df.index, tidy_df[rows].unique().size)
- matrix = tidy_df[[columns, rows, "bin_id"]].pivot_table(
- values=columns, index=rows, columns="bin_id"
- )
+# JAX-optimized functions
- return matrix, bins_df
-
-
-def annotate_bins(bins_df):
- # Takes a dataframe of genomic regions and returns an ordered list of full genes in each region
- server = Server("www.ensembl.org", use_cache=False)
- dataset = server.marts["ENSEMBL_MART_ENSEMBL"].datasets["hsapiens_gene_ensembl"]
- gene_coordinates = dataset.query(
- attributes=[
- "chromosome_name",
- "start_position",
- "end_position",
- "external_gene_name",
- ],
- filters={
- "chromosome_name": [
- 1,
- 2,
- 3,
- 4,
- 5,
- 6,
- 7,
- 8,
- 9,
- 10,
- 11,
- 12,
- 13,
- 14,
- 15,
- 16,
- 17,
- 18,
- 19,
- 20,
- 21,
- 22,
- "X",
- "Y",
- ]
- },
- use_attr_names=True,
- )
- # Drop duplicate genes
- gene_coordinates.drop_duplicates(subset="external_gene_name", ignore_index=True)
-
- annotated_bins = bins_df.copy()
- annotated_bins["genes"] = [list() for _ in range(annotated_bins.shape[0])]
-
- bin_size = bins_df["end"][0] - bins_df["start"][0]
- for index, row in gene_coordinates.iterrows():
- gene = row["external_gene_name"]
- if pd.isna(gene):
- continue
- start_bin_in_chr = int(row["start_position"] / bin_size)
- stop_bin_in_chr = int(row["end_position"] / bin_size)
- chromosome = str(row["chromosome_name"])
- chr_start = np.where(bins_df["chr"] == chromosome)[0][0]
- start_bin = start_bin_in_chr + chr_start
- stop_bin = stop_bin_in_chr + chr_start
-
- if stop_bin < annotated_bins.shape[0]:
- if np.all(annotated_bins.iloc[start_bin:stop_bin].chr == chromosome):
- for bin in range(start_bin, stop_bin + 1):
- annotated_bins.loc[bin, "genes"].append(gene)
-
- return annotated_bins
-
-
-def annotate_matrix(matrix, annotated_bins):
- # Takes a dataframe of cells by bins and a dataframe with gene lists for each bin
- # and returns a dataframe of cells by genes
- df_list = []
- chrs = []
- for bin, row in annotated_bins.iterrows():
- genes = row["genes"]
- chr = row["chr"]
- if len(genes) > 0:
- df_list.append(
- pd.concat([matrix[bin]] * len(genes), axis=1, ignore_index=True).rename(
- columns=dict(zip(range(len(genes)), genes))
- )
- )
- chrs.append([chr] * df_list[-1].shape[1])
- chrs = np.concatenate(chrs)
- df = pd.concat(df_list, axis=1)
- chrs = chrs[np.where(~df.columns.duplicated())[0]]
- df = df.loc[:, ~df.columns.duplicated()]
- return df, chrs
-
-
-def convert_phylogeny_to_clonal_tree(threshold):
- # Converts a phylogenetic tree to a clonal tree by choosing the main clades
- # according to some threshold
- raise NotImplementedError
+@jax.jit
+def logbeta_func(a,b):
+ return gammaln(a) + gammaln(b) - gammaln(a+b)
+
+@jax.jit
+def beta_kl(a1, b1, a2, b2):
+ """
+    KL divergence KL(Beta(a1, b1) || Beta(a2, b2))
+    """
+    kl = logbeta_func(a2,b2) - logbeta_func(a1,b1)
+ kl += (a1-a2)*digamma(a1)
+ kl += (b1-b2)*digamma(b1)
+ kl += (a2-a1 + b2-b1)*digamma(a1+b1)
+ return kl
+
+@jax.jit
+def E_log_beta(
+ alpha,
+ beta,
+):
+ return digamma(alpha) - digamma(alpha + beta)
+
+@jax.jit
+def E_log_1_beta(
+ alpha,
+ beta,
+):
+ return digamma(beta) - digamma(alpha + beta)
+
+@jax.jit
+def E_q_log_beta(
+ alpha1,
+ beta1,
+ alpha2,
+ beta2,
+):
+ """
+    Expected value of log p(x), with p = Beta(alpha1, beta1), taken with respect to q(x) = Beta(alpha2, beta2)
+ """
+ return (alpha1-1)*E_log_beta(alpha2,beta2) + (beta1-1)*E_log_1_beta(alpha2,beta2) - betaln(alpha1,beta1)
+
+
+# This computes E_q[log p(z_n=\epsilon | \nu, \psi)]
+@jax.jit
+def compute_expected_weight(
+ nu_alpha,
+ nu_beta,
+ psi_alpha,
+ psi_beta,
+ prev_nu_sticks_sum,
+ prev_psi_sticks_sum,
+):
+ nu_digamma_sum = digamma(nu_alpha + nu_beta)
+ E_log_nu = digamma(nu_alpha) - nu_digamma_sum
+ nu_sticks_sum = digamma(nu_beta) - nu_digamma_sum + prev_nu_sticks_sum
+ psi_sticks_sum = local_psi_sticks_sum(psi_alpha, psi_beta) + prev_psi_sticks_sum
+ weight = E_log_nu + nu_sticks_sum + psi_sticks_sum
+ return weight, nu_sticks_sum, psi_sticks_sum
+
+
+@jax.jit
+def assignment_entropies(probs):
+ return -jax.lax.select(probs != 0,
+ probs * jnp.log(probs),
+ probs)
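+
+# Hedged sanity-check sketch for the Beta expectation helpers above (toy parameters only):
+def _beta_expectations_example():
+    a1, b1, a2, b2 = 2.0, 3.0, 1.0, 1.0
+    # KL(Beta(a1,b1) || Beta(a2,b2)) written in terms of E_q_log_beta
+    kl = E_q_log_beta(a1, b1, a1, b1) - E_q_log_beta(a2, b2, a1, b1)
+    # Entropy of an assignment vector; exact zeros contribute nothing
+    entropy = jnp.sum(assignment_entropies(jnp.array([0.5, 0.5, 0.0])))
+    return kl, entropy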
diff --git a/scatrex/utils/tree_utils.py b/scatrex/utils/tree_utils.py
new file mode 100644
index 0000000..0b3948c
--- /dev/null
+++ b/scatrex/utils/tree_utils.py
@@ -0,0 +1,164 @@
+import numpy as np
+import jax
+
+def tree_to_dict(tree, param_key='param', root_par='-1'):
+ """Converts a tree in a recursive dictionary to a flat dictionary with parent keys
+ """
+ tree_dict = dict()
+
+ def descend(node, par_id):
+ label = node['label']
+ tree_dict[label] = dict()
+ tree_dict[label]["parent"] = par_id
+ for key in node:
+            if key == param_key:
+                tree_dict[label]['param'] = node[param_key]  # stored under the generic 'param' key
+ elif key == 'children':
+ tree_dict[label][key] = [c['label'] for c in node[key]]
+ elif key != 'parent':
+ tree_dict[label][key] = node[key]
+ for child in node['children']:
+ descend(child, node['label'])
+
+ descend(tree, root_par)
+
+ return tree_dict
+
+def dict_to_tree(tree_dict, root_name, param_key='param'):
+ """Converts a tree in a flat dictionary with parent keys to a recursive dictionary
+ """
+ # Make children in the flat dict
+ for i in tree_dict:
+ tree_dict[i]["children"] = []
+ for i in tree_dict:
+ for j in tree_dict:
+ if tree_dict[j]["parent"] == i:
+ tree_dict[i]["children"].append(j)
+
+ # Make tree
+ root = {}
+ root['label'] = root_name
+ for key in tree_dict[root_name]:
+ # if isinstance(tree_dict[root_name][key], list):
+ # if isinstance(tree_dict[root_name][key][0], dict):
+ # continue
+ root[key] = tree_dict[root_name][key]
+ root['children'] = []
+
+ # Recursively construct tree
+ def descend(super_tree, label):
+ for i, child in enumerate(tree_dict[label]["children"]):
+ d = {}
+ for key in tree_dict[child]:
+ # if isinstance(tree_dict[child][key], list):
+ # if isinstance(tree_dict[child][key][0], dict):
+ # continue
+ d[key] = tree_dict[child][key]
+ d['children'] = []
+ super_tree["children"].append(d)
+ descend(super_tree["children"][-1], child)
+
+ descend(root, root_name)
+ return root
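+
+# Hedged round-trip sketch (toy tree; keys follow the conventions used above):
+def _tree_dict_roundtrip_example():
+    toy = {'label': 'A', 'param': 0.0, 'children': [
+        {'label': 'B', 'param': 1.0, 'children': []},
+        {'label': 'C', 'param': 2.0, 'children': [
+            {'label': 'D', 'param': 3.0, 'children': []}]},
+    ]}
+    flat = tree_to_dict(toy)           # flat dict keyed by label, with 'parent' entries
+    rebuilt = dict_to_tree(flat, 'A')  # nested dict again, rooted at 'A'
+    return flat, rebuilt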
+
+def condense_tree(tree, min_weight=0.1):
+ """
+    Traverse the tree and remove nodes whose weight is below min_weight, merging their weight, size and children into the parent
+ """
+ def descend(root):
+ to_keep = []
+ for child in root['children']:
+ descend(child)
+ to_keep.append(int(child['weight'] > min_weight))
+ if len(to_keep) > 0:
+ to_remove_idx = [i for i in range(len(to_keep)) if to_keep[i] == 0]
+ if 'weight' in root:
+ root['weight'] += np.sum([r['weight'] for i, r in enumerate(root['children']) if to_keep[i]==0])
+ if 'size' in root:
+ root['size'] += int(np.sum([r['size'] for i, r in enumerate(root['children']) if to_keep[i]==0]))
+            # Re-attach the children of each removed node to the current node
+ for child_to_remove in list(np.array(root['children'])[to_remove_idx]):
+ for child_of_to_remove in child_to_remove['children']:
+ to_keep.append(1)
+ root['children'].append(child_of_to_remove)
+ # Remove child
+ root['children'] = list(np.array(root['children'])[np.where(np.array(to_keep))[0]])
+ descend(tree)
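+
+# Hedged toy example for condense_tree (hypothetical 'weight' values): the light 'drop'
+# node is absorbed into the root and its heavier child is re-attached to the root.
+def _condense_example():
+    toy = {'label': 'root', 'weight': 0.5, 'children': [
+        {'label': 'keep', 'weight': 0.4, 'children': []},
+        {'label': 'drop', 'weight': 0.05, 'children': [
+            {'label': 'grandchild', 'weight': 0.2, 'children': []}]},
+    ]}
+    condense_tree(toy, min_weight=0.1)
+    return [child['label'] for child in toy['children']]  # ['keep', 'grandchild']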
+
+def subsample_tree(tree, keep_prob=0.5, seed=42):
+ """
+    Traverse the tree and keep each node with probability keep_prob, merging removed nodes into their parent
+ """
+ def descend(root, key):
+ to_keep = []
+ for child in root['children']:
+ key, subkey = jax.random.split(key)
+ descend(child, key)
+ to_keep.append(int(jax.random.bernoulli(subkey, keep_prob)))
+ if len(to_keep) > 0:
+ to_remove_idx = [i for i in range(len(to_keep)) if to_keep[i] == 0]
+ if 'weight' in root:
+ root['weight'] += np.sum([r['weight'] for i, r in enumerate(root['children']) if to_keep[i]==0])
+ if 'size' in root:
+ root['size'] += int(np.sum([r['size'] for i, r in enumerate(root['children']) if to_keep[i]==0]))
+            # Re-attach the children of each removed node to the current node
+ for child_to_remove in list(np.array(root['children'])[to_remove_idx]):
+ for child_of_to_remove in child_to_remove['children']:
+ to_keep.append(1)
+ root['children'].append(child_of_to_remove)
+ # Remove child
+ root['children'] = list(np.array(root['children'])[np.where(np.array(to_keep))[0]])
+ key = jax.random.PRNGKey(seed)
+ descend(tree, key)
+
+def convert_phylogeny_to_clonal_tree(threshold):
+ # Converts a phylogenetic tree to a clonal tree by choosing the main clades
+ # according to some threshold
+ raise NotImplementedError
+
+def obs_rmse(ntssb1, ntssb2, param='observed'):
+ ntssb1_obs = np.zeros(ntssb1.data.shape)
+ ntssb2_obs = np.zeros(ntssb2.data.shape)
+
+ ntssb1_nodes = ntssb1.get_nodes()
+ for node in ntssb1_nodes:
+ idx = np.where(ntssb1.assignments == node)
+ ntssb1_obs[idx] = node.get_param(param)
+
+ ntssb2_nodes = ntssb2.get_nodes()
+ for node in ntssb2_nodes:
+ idx = np.where(ntssb2.assignments == node)
+ ntssb2_obs[idx] = node.get_param(param)
+
+ return np.mean(np.sqrt(np.mean((ntssb1_obs - ntssb2_obs)**2, axis=1)))
+
+def ntssb_distance(ntssb1, ntssb2):
+ pdist1 = ntssb1.get_pairwise_obs_distances()
+ pdist2 = ntssb2.get_pairwise_obs_distances()
+ n_obs = pdist1.shape[0]
+ return np.sqrt(2./(n_obs*(n_obs-1)) * np.sum((pdist1-pdist2)**2))
+
+
+def subtree_to_tree_distance(subtree, tree):
+ """subtree is a TSSB in the NTSSB
+ tree is a dictionary containing the true tree that should be there
+ To check that the substructure we find is close to the real structure and not just something
+ completely different
+ """
+ subtree_nodes = []
+ tree_nodes = []
+ dists = np.zeros((len(subtree_nodes), len(tree_nodes)))
+ for i, subtree_node in enumerate(subtree_nodes):
+ for j, tree_node in enumerate(tree_nodes):
+ dists[i,j] = compute_distance(subtree_node, tree_node)
+ return np.sqrt(np.mean(dists**2))
+
+def print_tree(tree, tab=' '):
+ def descend(root, depth=0):
+        tabs = tab * depth
+ print(f"{tabs}{root['label']}")
+ for child in root['children']:
+ descend(child, depth=depth+1)
+ descend(tree)
\ No newline at end of file