Add plotting #14

Merged · 8 commits · Oct 25, 2023
109 changes: 57 additions & 52 deletions efaar_benchmarking/benchmarking.py
@@ -1,10 +1,14 @@
from collections import defaultdict

import numpy as np
from sklearn.utils import Bunch

from efaar_benchmarking.utils import (
generate_null_cossims,
generate_query_cossims,
get_benchmark_data,
compute_recall,
convert_metrics_to_df,
)
import efaar_benchmarking.constants as cst
from efaar_benchmarking.utils import compute_pairwise_metrics, get_feats_w_indices
from sklearn.utils import Bunch
import pandas as pd
import random


def pert_stats(
@@ -45,62 +49,63 @@ def pert_stats(

def benchmark(
map_data: Bunch,
pert_label_col: str = cst.PERT_LABEL_COL,
benchmark_sources: list = cst.BENCHMARK_SOURCES,
filter_on_pert_type: bool = False,
filter_on_well_type: bool = False,
pert_label_col: str = cst.PERT_LABEL_COL,
recall_thr_pairs: list = cst.RECALL_PERC_THRS,
filter_on_pert_prints: bool = False,
run_count: int = cst.RANDOM_COUNT,
) -> dict:
pert_print_pvalue_thr: float = cst.PERT_SIG_PVAL_THR,
n_null_samples: int = cst.N_NULL_SAMPLES,
random_seed: int = cst.RANDOM_SEED,
n_iterations: int = cst.RANDOM_COUNT,
) -> pd.DataFrame:
"""Perform benchmarking on map data.

Args:
map_data (Bunch): The map data containing features and metadata.
pert_label_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
map_data (Bunch): The map data containing `features` and `metadata` attributes.
benchmark_sources (list, optional): List of benchmark sources. Defaults to cst.BENCHMARK_SOURCES.
filter_on_pert_type (bool, optional): Flag to filter map data based on perturbation type. Defaults to False.
filter_on_well_type (bool, optional): Flag to filter map data based on well type. Defaults to False.
pert_label_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
recall_thr_pairs (list, optional): List of recall percentage threshold pairs. Defaults to cst.RECALL_PERC_THRS.
filter_on_pert_prints (bool, optional): Flag to filter map data based on perturbation prints. Defaults to False.
run_count (int, optional): Number of random seed pairs to use. Defaults to cst.RANDOM_COUNT.
pert_print_pvalue_thr (float, optional): P-value threshold for perturbation filtering. Defaults to cst.PERT_SIG_PVAL_THR.
n_null_samples (int, optional): Number of null samples to generate. Defaults to cst.N_NULL_SAMPLES.
random_seed (int, optional): Random seed to use for generating null samples. Defaults to cst.RANDOM_SEED.
n_iterations (int, optional): Number of random seed pairs to use. Defaults to cst.RANDOM_COUNT.

Returns:
dict: A dictionary containing the benchmark results for each seed pair and benchmark source.
pd.DataFrame: a dataframe with benchmarking results. The columns are:
"source": benchmark source name
"random_seed": random seed string from random seeds 1 and 2
"recall_{low}_{high}": recall at requested thresholds
"""

assert len(benchmark_sources) > 0 and all(
[src in cst.BENCHMARK_SOURCES for src in benchmark_sources]
), "Invalid benchmark source(s) provided."
md = map_data.metadata
idx = [True] * len(md)
if filter_on_pert_type:
idx = idx & (md[cst.PERT_TYPE_COL] == cst.PERT_TYPE)
if filter_on_well_type:
idx = idx & (md[cst.WELL_TYPE_COL] == cst.WELL_TYPE)
if filter_on_pert_prints:
pval_thresh = cst.PERT_SIG_PVAL_THR if filter_on_pert_prints else 1
idx = idx & (md[cst.PERT_SIG_PVAL_COL] <= pval_thresh)
print(sum(idx), "gene perturbations in the map.")
map_data = Bunch(features=map_data.features[idx], metadata=md[idx])
res = defaultdict(dict) # type: ignore
feats_w_indices = get_feats_w_indices(map_data, pert_label_col)
if len(set(feats_w_indices.index)) >= cst.MIN_REQ_ENT_CNT:
np.random.seed(cst.RANDOM_SEED)
# numpy requires seeds to be between 0 and 2 ** 32 - 1
random_seed_pairs = np.random.randint(2**32, size=run_count * 2).reshape(run_count, 2)
for rs1, rs2 in random_seed_pairs:
res_seed = res[f"Seeds_{rs1}_{rs2}"]
for src in benchmark_sources:
if src not in res_seed:
res_curr = compute_pairwise_metrics(
feats_w_indices,
src,
cst.RECALL_PERC_THR_PAIR,
rs1,
rs2,
cst.N_NULL_SAMPLES,
cst.N_NULL_SAMPLES,
)
if res_curr is not None:
res_seed[src] = res_curr
idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_print_pvalue_thr) if filter_on_pert_prints else [True] * len(md)
features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
del map_data
assert len(features) == len(set(features.index)), "Duplicate perturbation labels in the map."
assert len(features) >= cst.MIN_REQ_ENT_CNT, "Not enough entities in the map for benchmarking."
print(len(features), "perturbations in the map.")

res[f"Seeds_{rs1}_{rs2}"] = res_seed
else:
print("Not enough entities in the map for benchmarking")
return res
metrics_lst = []
random.seed(random_seed)
random_seed_pairs = [
(random.randint(0, 2**31 - 1), random.randint(0, 2**31 - 1)) for _ in range(n_iterations)
] # numpy requires seeds to be between 0 and 2 ** 32 - 1
for rs1, rs2 in random_seed_pairs:
random_seed_str = f"{rs1}_{rs2}"
null_cossim = generate_null_cossims(features, n_null_samples, rs1, rs2)
for s in benchmark_sources:
query_cossim = generate_query_cossims(features, get_benchmark_data(s))
single_seed_result = compute_recall(null_cossim, query_cossim, recall_thr_pairs)
metrics_lst.append(
convert_metrics_to_df(
metrics=single_seed_result,
source=s,
random_seed_str=random_seed_str,
filter_on_pert_prints=filter_on_pert_prints,
)
)
return pd.concat(metrics_lst, ignore_index=True)
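
As a usage sketch (not part of the diff): the refactored `benchmark` now returns a tidy DataFrame instead of a nested dict, so its output can be passed straight to the new plotting helper. The toy `map_data` below is illustrative only — the labels and shapes are assumptions based on this diff, and a real run needs a map whose perturbation labels overlap the benchmark annotation files.

```python
import numpy as np
import pandas as pd
from sklearn.utils import Bunch

from efaar_benchmarking.benchmarking import benchmark
from efaar_benchmarking.plotting import plot_recall

# Toy map: 200 hypothetical perturbations x 128 embedding dimensions.
# A real map would come from the package's loading/embedding steps, with
# labels that actually appear in the benchmark annotations.
rng = np.random.default_rng(0)
features = pd.DataFrame(rng.normal(size=(200, 128)))
metadata = pd.DataFrame({
    "perturbation": [f"GENE_{i}" for i in range(200)],  # hypothetical labels
    "perturbation_pvalue": rng.uniform(size=200),  # used when filter_on_pert_prints=True
})
map_data = Bunch(features=features, metadata=metadata)

metrics = benchmark(map_data, n_iterations=3)
# Expected columns: "source", "random_seed", plus one "recall_{low}_{high}"
# column per pair in RECALL_PERC_THRS, e.g. "recall_0.05_0.95".
print(metrics.head())
plot_recall(metrics)
```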
9 changes: 3 additions & 6 deletions efaar_benchmarking/constants.py
@@ -4,14 +4,11 @@
"benchmark_annotations"
)
BENCHMARK_SOURCES = ["Reactome", "HuMAP", "CORUM", "SIGNOR", "StringDB"]
PERT_LABEL_COL = "gene"
PERT_LABEL_COL = "perturbation"
CONTROL_PERT_LABEL = "non-targeting"
PERT_SIG_PVAL_COL = "perturbation_pvalue"
PERT_SIG_PVAL_THR = 0.01
PERT_TYPE_COL = "perturbation_type"
PERT_TYPE = "GENE"
WELL_TYPE_COL = "well_type"
WELL_TYPE = "query_guides"
RECALL_PERC_THR_PAIR = (0.05, 0.95)
RECALL_PERC_THRS = [(0.05, 0.95), (0.1, 0.9)]
RANDOM_SEED = 42
RANDOM_COUNT = 3
N_NULL_SAMPLES = 5000
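
A quick illustration (not in the diff) of how the new `RECALL_PERC_THRS` list is expected to surface downstream, given the `recall_{low}_{high}` column naming described in `benchmark`'s docstring:

```python
RECALL_PERC_THRS = [(0.05, 0.95), (0.1, 0.9)]

# Expected recall columns in the benchmark DataFrame:
# ['recall_0.05_0.95', 'recall_0.1_0.9']
print([f"recall_{low}_{high}" for low, high in RECALL_PERC_THRS])
```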
9 changes: 5 additions & 4 deletions efaar_benchmarking/data_loading.py
@@ -1,8 +1,7 @@
import os

import scanpy as sc
import wget

import scanpy as sc
import numpy as np

def load_replogle(gene_type, data_type, data_path="data/"):
"""
@@ -42,4 +41,6 @@ def load_replogle(gene_type, data_type, data_path="data/"):
if not os.path.exists(data_path + filename):
wget.download(src, data_path + filename)

return sc.read_h5ad(data_path + filename)
adata = sc.read_h5ad(data_path + filename)
adata.X = np.nan_to_num(adata.X)
return adata
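
A brief aside on the added `np.nan_to_num` call (the snippet below only illustrates NumPy's default behavior, it is not repository code): NaNs are replaced with 0.0 and infinities with the largest finite floats, so the loaded expression matrix is safe for downstream steps such as `embed_by_pca`.

```python
import numpy as np

X = np.array([[1.0, np.nan], [np.inf, -2.0]])
# Defaults: nan -> 0.0, +inf -> largest finite float, -inf -> smallest.
print(np.nan_to_num(X))
```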
1 change: 0 additions & 1 deletion efaar_benchmarking/efaar.py
@@ -32,7 +32,6 @@ def embed_by_pca(adata, N_LATENT=100):

Parameters:
adata (AnnData): Annotated data matrix.
BATCH_KEY (str): Key for batch information in adata.obs.
N_LATENT (int): Number of principal components to use.

Returns:
39 changes: 39 additions & 0 deletions efaar_benchmarking/plotting.py
@@ -0,0 +1,39 @@
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

def plot_recall(df):
"""
Plots a line plot of recall values across several threshold pairs for each source.

Parameters:
df (pandas.DataFrame): A dataframe with "source" column and "recall_X_Y" columns for several [X, Y] pairs.

Returns:
None
"""
xy_pairs = [col.split("_")[1:] for col in df.columns if col.startswith("recall_")]

col_cnt = 5
sns.set_style("whitegrid")
fig, axs = plt.subplots(nrows=int(np.ceil(len(df['source'].unique()) / col_cnt)), ncols=col_cnt, figsize=(15, 3), squeeze=False)

color = "r"
# Plot each source as a separate subplot
for i, source in enumerate(df['source'].unique()):
source_df = df[df['source'] == source]
x_values_orig = [f"{x}, {y}" for x, y in xy_pairs]
y_values = [list(source_df[f"recall_{x}_{y}"]) for x, y in xy_pairs]
x_values = [i for i, sublist in zip(x_values_orig, y_values) for _ in sublist]
tmp_data = pd.DataFrame({'x': x_values, 'y': sum(y_values, [])})
curr_ax = axs[i//col_cnt, i%col_cnt]
sns.lineplot(ax=curr_ax, x="x", y="y", data=tmp_data, color=color, marker="o", markersize=8, markerfacecolor=color, markeredgewidth=2, errorbar=('ci', 99))
curr_ax.set_title(source)
curr_ax.set_xlabel("Recall Thresholds")
curr_ax.set_ylabel("Recall Value")
curr_ax.set_xticks(range(len(x_values_orig)))
curr_ax.set_xticklabels(x_values_orig, rotation=45)

plt.tight_layout()
plt.show()
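
For completeness, a hedged usage example (not part of the PR): a synthetic DataFrame shaped like `benchmark`'s output — a "source" column plus "recall_{low}_{high}" columns — can be passed directly to `plot_recall`. The numbers below are made up.

```python
import numpy as np
import pandas as pd

from efaar_benchmarking.plotting import plot_recall

# Fake results: one row per (source, seed pair), matching benchmark()'s shape.
rng = np.random.default_rng(0)
sources = ["Reactome", "HuMAP", "CORUM", "SIGNOR", "StringDB"]
rows = []
for source in sources:
    for seed_pair in ["11_12", "21_22", "31_32"]:  # made-up seed strings
        rows.append({
            "source": source,
            "random_seed": seed_pair,
            "recall_0.05_0.95": rng.uniform(0.1, 0.4),
            "recall_0.1_0.9": rng.uniform(0.2, 0.5),
        })

plot_recall(pd.DataFrame(rows))  # one subplot per source, x-axis = threshold pairs
```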