Skip to content

Commit

Permalink
Add independent model
Browse files Browse the repository at this point in the history
  • Loading branch information
pawel-czyz committed Apr 3, 2024
1 parent 54db486 commit 936a64b
Showing 1 changed file with 195 additions and 13 deletions.
208 changes: 195 additions & 13 deletions workflows/presentation.smk
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,14 @@ from lifelines.calibration import survival_probability_calibration
from jnotype.pyramids import TwoLayerPyramidSampler, TwoLayerPyramidSamplerNonparametric
from jnotype.sampling import ListDataset

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive




matplotlib.use("agg")

Expand All @@ -40,17 +48,18 @@ class ModelSettings:
true_loc: tuple[int, int]

# Per-model plot settings, keyed by the model's directory name under
# `analysis/{analysis}/`. `true_loc` is the (row, column) index of the subplot
# that is cleared and redrawn with the observed data in the panel grids.
MODELS = {
    "one_parameter_model": ModelSettings(true_loc=(1, 2)),
    "independent": ModelSettings(true_loc=(-1, 1)),
}


# Default target: request the completion marker of the COAD analysis.
rule all:
    input: "analysis/COAD/everything.done"

# Aggregation rule for one analysis: once the basic statistics and the
# completion marker of every model exist, touch `everything.done`.
rule run_all:
    input:
        basic_info = "analysis/{analysis}/basic_info/summary.json",
        one_parameter_model = "analysis/{analysis}/one_parameter_model/done.done",
        independent_model = "analysis/{analysis}/independent/done.done",
    output: touch("analysis/{analysis}/everything.done")


Expand Down Expand Up @@ -99,6 +108,7 @@ rule basic_statistics:


# === One-parameter model ===
# All genes share a single parameter theta, the common mutation frequency.

rule one_parameter_all:
input:
Expand Down Expand Up @@ -175,9 +185,101 @@ rule one_parameter_sample_posterior_predictive:
joblib.dump(samples, str(output))


# === Independent-probabilities model ===
# Each gene has a separate mutation probability theta[g]

# Aggregation rule for the independent-probabilities model: requests every
# posterior-predictive figure and then touches the model's completion marker.
rule independent_all:
    input:
        posterior_predictive_matrices_raw =
            "analysis/{analysis}/independent/posterior_predictive_matrices.pdf",
        posterior_predictive_matrices_ordered =
            "analysis/{analysis}/independent/posterior_predictive_matrices_ordered.pdf",
        posterior_predictive_occurrences =
            "analysis/{analysis}/independent/posterior_predictive_occurrence.pdf",
        posterior_predictive_histograms_many_panels =
            "analysis/{analysis}/independent/histogram_number_of_mutations_many_panels.pdf",
        posterior_predictive_histograms_single_panel =
            "analysis/{analysis}/independent/histogram_number_of_mutations_single_panel.pdf",
    output:
        touch("analysis/{analysis}/independent/done.done")


def independent_model(Y=None, N=None, G=None):
    """NumPyro model in which every column (gene) has its own probability.

    Column ``g`` of the binary matrix gets a probability ``theta[g]`` with a
    ``Beta(1, 1)`` prior, and every entry of that column is modelled as a
    ``Bernoulli(theta[g])`` draw at the ``obs`` site.

    Args:
        Y: Observed binary matrix of shape ``(N, G)``. Pass ``None`` to leave
            ``obs`` unobserved, e.g. for posterior-predictive simulation.
        N: Number of rows. Taken from ``Y.shape`` when ``N`` or ``G`` is None.
        G: Number of columns. Taken from ``Y.shape`` when ``N`` or ``G`` is
            None.

    Raises:
        ValueError: If the shape is given neither through ``N`` and ``G`` nor
            through ``Y``.
    """
    if N is None or G is None:
        if Y is None:
            # Fail early with a clear message instead of an AttributeError
            # on `None.shape`.
            raise ValueError(
                "independent_model needs either the data matrix Y "
                "or both N and G."
            )
        N, G = Y.shape

    # Beta(1, 1) prior (uniform on [0, 1]) for each entry of theta.
    alpha = np.ones(G)
    beta = np.ones(G)
    with numpyro.plate('features', G):
        theta = numpyro.sample('theta', dist.Beta(alpha, beta))

    # Rows are placed on dim -2 and columns on dim -1, so `obs` has shape
    # (N, G); theta is given a leading axis to broadcast over the rows.
    with numpyro.plate('data', N, dim=-2):
        with numpyro.plate("features", G, dim=-1):
            numpyro.sample('obs', dist.Bernoulli(theta[None, :]), obs=Y)


# Draw posterior samples of the independent-probabilities model with NUTS
# (500 warm-up steps, 1000 retained samples, fixed PRNG key 0) and dump them
# with joblib.
rule mcmc_independent_model:
    input: "data/preprocessed/{analysis}/mutation-matrix.csv"
    output: "analysis/{analysis}/independent/posterior_samples.joblib"
    run:
        # The first CSV column is the index; only the values are modelled.
        mutation_matrix = pd.read_csv(str(input), index_col=0).values

        sampler = MCMC(NUTS(independent_model), num_warmup=500, num_samples=1000)
        sampler.run(jax.random.PRNGKey(0), Y=mutation_matrix)

        joblib.dump(sampler.get_samples(), str(output))


# Simulate data sets from the fitted independent-probabilities model, one per
# posterior sample, with the same shape as the observed mutation matrix.
rule posterior_predictive_independent_model:
    input:
        samples = "analysis/{analysis}/independent/posterior_samples.joblib",
        mutations = "data/preprocessed/{analysis}/mutation-matrix.csv"
    output: "analysis/{analysis}/independent/posterior_predictive.joblib"
    run:
        # Only the shape of the observed matrix is needed here.
        n_rows, n_cols = pd.read_csv(input.mutations, index_col=0).values.shape
        draws = joblib.load(input.samples)

        simulate = Predictive(independent_model, posterior_samples=draws)
        simulated = simulate(jax.random.PRNGKey(12), N=n_rows, G=n_cols)

        # Keep only the simulated observations, not the other sites.
        joblib.dump(simulated["obs"], str(output))


# === Model-independent-rules ===

def optionally_order_matrix(data: np.ndarray, ordered: bool = True) -> np.ndarray:
    """Optionally sort the rows and columns of a matrix by their sums.

    Args:
        data: Two-dimensional array.
        ordered: If True, rows are reordered by ascending row sum and columns
            by ascending column sum. If False, ``data`` is returned as is.

    Returns:
        The reordered array, or ``data`` itself when ``ordered`` is False.
    """
    if not ordered:
        # Nothing to compute: skip the two argsorts entirely.
        return data

    row_order = np.argsort(np.sum(data, axis=1))
    column_order = np.argsort(np.sum(data, axis=0))
    return data[row_order, :][:, column_order]


def plot_posterior_predictive_matrices(
    axs,
    samples,
    mutations,
    model: str,
    ordered: bool,
    seed: int = 42,
) -> None:
    """Draw simulated matrices as heatmaps, with the observed one among them.

    Every axis in ``axs`` first receives one randomly chosen simulated
    matrix. The axis at ``MODELS[model].true_loc`` is then cleared and redrawn
    with the observed ``mutations`` matrix.

    Args:
        axs: Two-dimensional array of matplotlib axes.
        samples: Simulated matrices, indexed along the first axis.
        mutations: Observed matrix.
        model: Key into ``MODELS`` giving the panel used for the observed data.
        ordered: Whether each matrix is passed through
            ``optionally_order_matrix`` with ordering enabled.
        seed: Seed for choosing which simulated matrices are shown.
    """
    label_suffix = " (ordered)" if ordered else ""

    def draw_panel(panel, matrix) -> None:
        # One heatmap plus axis labels; shared by simulated and observed data.
        sns.heatmap(
            optionally_order_matrix(matrix, ordered),
            ax=panel,
            cmap="Blues",
            xticklabels=False,
            yticklabels=False,
            cbar=False,
            square=False,
        )
        panel.set_xlabel("Genes" + label_suffix)
        panel.set_ylabel("Patients" + label_suffix)

    panels = axs.ravel()
    generator = np.random.default_rng(seed)
    # Sample without replacement so that no simulated matrix appears twice.
    chosen = generator.choice(samples.shape[0], size=len(panels), replace=False)

    for panel, draw_index in zip(panels, chosen):
        draw_panel(panel, samples[draw_index, ...])

    # Replace one of the simulated panels with the observed data.
    row, column = MODELS[model].true_loc
    observed_panel = axs[row, column]
    observed_panel.clear()
    draw_panel(observed_panel, mutations)


rule plot_posterior_predictive_matrices:
input:
mutations = "data/preprocessed/{analysis}/mutation-matrix.csv",
Expand All @@ -189,21 +291,101 @@ rule plot_posterior_predictive_matrices:
samples = joblib.load(input.posterior_predictive)

fig, axs = plt.subplots(3, 4, sharex=True, sharey=True, dpi=250, figsize=(4*2, 3*2))
for ax, data in zip(axs.ravel(), samples[:len(axs.ravel()), ...]):
sns.heatmap(data, ax=ax, cmap="Blues", xticklabels=False, yticklabels=False, cbar=False, square=False)
ax.set_xlabel("Patients")
ax.set_ylabel("Genes")
plot_posterior_predictive_matrices(axs, samples, mutations=mutations, model=wildcards.model, ordered=False)
fig.tight_layout()
fig.savefig(output.matrices)


# Grid of posterior-predictive heatmaps with rows and columns sorted by their
# sums; one panel shows the observed matrix.
rule plot_posterior_predictive_matrices_ordered:
    input:
        mutations = "data/preprocessed/{analysis}/mutation-matrix.csv",
        posterior_predictive = "analysis/{analysis}/{model}/posterior_predictive.joblib"
    output:
        matrices = "analysis/{analysis}/{model}/posterior_predictive_matrices_ordered.pdf",
    run:
        observed = pd.read_csv(input.mutations, index_col=0).values
        simulated = joblib.load(input.posterior_predictive)

        n_rows, n_cols = 3, 4
        fig, axs = plt.subplots(
            n_rows,
            n_cols,
            sharex=True,
            sharey=True,
            dpi=250,
            figsize=(n_cols * 2, n_rows * 2),
        )
        plot_posterior_predictive_matrices(
            axs,
            simulated,
            mutations=observed,
            model=wildcards.model,
            ordered=True,
        )
        fig.tight_layout()
        fig.savefig(output.matrices)


# Grid of histograms of the per-row mutation counts: each panel shows one
# randomly chosen posterior-predictive sample (dark blue), except the panel at
# `MODELS[model].true_loc`, which shows the observed data (goldenrod).
rule plot_posterior_number_of_mutations_histograms_many_panels:
    input:
        mutations = "data/preprocessed/{analysis}/mutation-matrix.csv",
        posterior_predictive = "analysis/{analysis}/{model}/posterior_predictive.joblib"
    output: "analysis/{analysis}/{model}/histogram_number_of_mutations_many_panels.pdf"
    run:
        mutations = pd.read_csv(input.mutations, index_col=0).values
        samples = joblib.load(input.posterior_predictive)

        fig, axs = plt.subplots(3, 4, sharex=True, sharey=True, dpi=250, figsize=(4*2, 3*2))

        # Fixed seed so that the same samples are shown on every run.
        seed = 42
        rng = np.random.default_rng(seed)
        indices = rng.choice(samples.shape[0], size=len(axs.ravel()), replace=False)

        # Shared unit-width bins centred on the integer counts, wide enough
        # for both the simulated and the observed data.
        max_mut = max(samples.sum(axis=-1).max(), mutations.sum(axis=-1).max())
        bins = np.arange(-0.5, max_mut + 1.5)

        for ax, index in zip(axs.ravel(), indices):
            data = samples[index, ...]
            n_mutations = data.sum(axis=-1)
            ax.hist(n_mutations, bins=bins, rasterized=True, color="darkblue")

            ax.spines[["top", "right"]].set_visible(False)

        # Replace one simulated panel with the histogram of the observed data.
        x_true, y_true = MODELS[wildcards.model].true_loc
        ax = axs[x_true, y_true]
        ax.clear()
        n_mutations = mutations.sum(axis=-1)
        ax.hist(n_mutations, bins=bins, rasterized=True, color="goldenrod")

        for ax in axs.ravel():
            ax.set_xlabel("Num. of mutations")
            ax.set_xticks(np.arange(0, max_mut + 1, 20))

        # Only the first column carries a y-axis label.
        for ax in axs[:, 0]:
            ax.set_ylabel("Num. of patients")

        fig.tight_layout()
        # The rule has a single unnamed output, so address it via str(output).
        fig.savefig(str(output))


# Single-panel overlay: step histograms of the per-row mutation counts for up
# to 150 randomly chosen posterior-predictive samples (thin, dark blue), with
# the observed data drawn on top (goldenrod).
rule plot_posterior_number_of_mutations_histograms_single_panel:
    input:
        mutations = "data/preprocessed/{analysis}/mutation-matrix.csv",
        posterior_predictive = "analysis/{analysis}/{model}/posterior_predictive.joblib"
    output: "analysis/{analysis}/{model}/histogram_number_of_mutations_single_panel.pdf"
    run:
        observed = pd.read_csv(input.mutations, index_col=0).values
        simulated = joblib.load(input.posterior_predictive)

        fig, ax = plt.subplots(figsize=FIGSIZE, dpi=DPI)
        ax.spines[["top", "right"]].set_visible(False)
        ax.set_xlabel("Number of mutations")
        ax.set_ylabel("Number of patients")

        # Fixed seed; never request more draws than are available.
        n_draws = simulated.shape[0]
        generator = np.random.default_rng(42)
        chosen = generator.choice(n_draws, size=min(n_draws, 150), replace=False)

        # Unit-width bins centred on the integer counts, shared by all curves.
        observed_counts = observed.sum(axis=-1)
        largest_count = max(simulated.sum(axis=-1).max(), observed_counts.max())
        bin_edges = np.arange(-0.5, largest_count + 1.5)

        for draw_index in chosen:
            simulated_counts = simulated[draw_index, ...].sum(axis=-1)
            ax.hist(
                simulated_counts,
                bins=bin_edges,
                rasterized=True,
                color="darkblue",
                linewidth=0.2,
                alpha=0.15,
                histtype="step",
            )

        ax.hist(
            observed_counts,
            bins=bin_edges,
            rasterized=True,
            color="goldenrod",
            linewidth=1,
            histtype="step",
        )

        fig.tight_layout()
        fig.savefig(str(output))


rule plot_posterior_predictive_gene_occurrence:
Expand All @@ -217,15 +399,15 @@ rule plot_posterior_predictive_gene_occurrence:

fig, axs = plt.subplots(3, 4, sharex=True, sharey=True, dpi=250, figsize=(4*2, 3*2))
for ax, data in zip(axs.ravel(), samples[:len(axs.ravel()), ...]):
ax.scatter(np.arange(data.shape[1]), np.sort(data.sum(axis=0)), s=2, c="darkblue")
ax.scatter(np.arange(data.shape[1]), np.sort(data.mean(axis=0)), s=2, c="darkblue")

# Add the ground-truth values
x_true, y_true = MODELS[wildcards.model].true_loc
ax = axs[x_true, y_true]
data = pd.read_csv(input.mutations, index_col=0).values
ax.clear()

ax.scatter(np.arange(data.shape[1]), np.sort(data.sum(axis=0)), s=2, c="darkblue")
ax.scatter(np.arange(data.shape[1]), np.sort(data.mean(axis=0)), s=2, c="darkblue")

for ax in axs.ravel():
ax.set_xticks([])
Expand Down

0 comments on commit 936a64b

Please sign in to comment.