Commit 54f034c

add p-value computation
safiyecelik committed Dec 4, 2023
1 parent d0901f1 commit 54f034c
Showing 1 changed file with 55 additions and 21 deletions.
efaar_benchmarking/benchmarking.py
@@ -1,6 +1,8 @@
 import random

 import numpy as np
+import pandas as pd
+from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.utils import Bunch

 import efaar_benchmarking.constants as cst
@@ -13,38 +15,70 @@
 )


-def pert_stats(
-    map_data: Bunch,
-    pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
-):
+def univariate_cons_metric(arr: np.ndarray, null=None) -> tuple[float, float]:
     """
-    Calculate perturbation statistics based on the provided map data.
+    Calculate the univariate consistency metric, i.e. average cosine angle and associated p-value, for a given array.

     Args:
-        map_data (Bunch): Map data containing metadata.
-        pert_pval_thr (float): Perturbation significance threshold. Defaults to cst.PERT_SIG_PVAL_THR.
+        arr (numpy.ndarray): The input array.
+        null (dict, optional): Null distributions of the metric, keyed by sample cardinality. Default is None.

     Returns:
-        dict: Dictionary containing perturbation statistics:
-            - "all_pert_count": Total count of perturbations.
-            - "pp_pert_count": Count of perturbations that meet the significance threshold.
-            - "pp_pert_percent": Percentage of perturbations that meet the significance threshold.
+        tuple: A tuple containing the average angle (avg_angle) and p-value (pval) of the metric.
+            If the length of the input array is less than 3, returns (None, None).
+            If null is None, returns (avg_angle, None).
     """
-    md = map_data.metadata
-    idx = [True] * len(md)
-    pidx = md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr
-    return {
-        "all_pert_count": sum(idx),
-        "pp_pert_count": sum(idx & pidx),
-        "pp_pert_percent": sum(idx & pidx) / sum(idx),
-    }
+    if len(arr) < 3:
+        return None, None
+    cosine_sim = cosine_similarity(arr)
+    avg_angle = np.arccos(cosine_sim[np.tril_indices(cosine_sim.shape[0], k=-1)]).mean()
+    if null is None:
+        return avg_angle, None
+    else:
+        sorted_null = np.sort(null[len(arr)])
+        pval = np.searchsorted(sorted_null, avg_angle) / len(sorted_null)
+        return avg_angle, pval
+
+
+def univariate_cons_benchmark(
+    features: np.ndarray, metadata: pd.DataFrame, pert_col: str, keys_to_drop: list, n_samples: int = 5000
+) -> pd.DataFrame:
+    """
+    Perform univariate benchmarking on the given features and metadata.
+
+    Args:
+        features (np.ndarray): The array of features.
+        metadata (pd.DataFrame): The metadata dataframe.
+        pert_col (str): The column name in the metadata dataframe representing the perturbations.
+        keys_to_drop (list): The perturbation keys to be dropped from the analysis.
+        n_samples (int, optional): The number of samples to generate for each null distribution. Defaults to 5000.
+
+    Returns:
+        pd.DataFrame: The dataframe containing the query metrics.
+    """
+    indices = ~metadata[pert_col].isin(keys_to_drop)
+    features = features[indices]
+    metadata = metadata[indices]
+
+    unique_cardinalities = metadata.groupby(pert_col).count().iloc[:, 0].unique()
+    print(unique_cardinalities)
+    null = {
+        x: [univariate_cons_metric(np.random.default_rng().choice(features, x, False))[0] for i in range(n_samples)]
+        for x in unique_cardinalities
+    }
+
+    features_df = pd.DataFrame(features, index=metadata[pert_col])
+    query_metrics = features_df.groupby(features_df.index).apply(lambda x: univariate_cons_metric(x.values, null)[1])
+    query_metrics.name = "avg_cossim_pval"
+    query_metrics = query_metrics.reset_index()
+
+    return query_metrics


 def benchmark(
     map_data: Bunch,
     benchmark_sources: list = cst.BENCHMARK_SOURCES,
-    pert_label_col: str = cst.REPLOGLE_PERT_LABEL_COL,
+    pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
     recall_thr_pairs: list = cst.RECALL_PERC_THRS,
     filter_on_pert_prints: bool = False,
     pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
@@ -59,7 +93,7 @@ def benchmark(
     Args:
         map_data (Bunch): The map data containing `features` and `metadata` attributes.
         benchmark_sources (list, optional): List of benchmark sources. Defaults to cst.BENCHMARK_SOURCES.
-        pert_label_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
+        pert_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
         recall_thr_pairs (list, optional): List of recall percentage threshold pairs. Defaults to cst.RECALL_PERC_THRS.
         filter_on_pert_prints (bool, optional): Flag to filter map data based on perturbation prints. Defaults to False.
         pert_pval_thr (float, optional): p-value threshold for perturbation filtering. Defaults to cst.PERT_SIG_PVAL_THR.
@@ -81,7 +115,7 @@ def benchmark(
ValueError("Invalid benchmark source(s) provided.")
md = map_data.metadata
idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
features = map_data.features[idx].set_index(md[idx][pert_col]).rename_axis(index=None)
del map_data
if not len(features) == len(set(features.index)):
ValueError("Duplicate perturbation labels in the map.")
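For context on what the new metric computes, here is a minimal standalone sketch: the mean pairwise arccos of within-group cosine similarities, compared against an empirical null built from random groups of the same size, mirroring the body of univariate_cons_metric above. The feature matrix, group size, and sample counts below are illustrative assumptions, not values from this repository.

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity


def avg_cosine_angle(a: np.ndarray) -> float:
    # Mean arccos over the below-diagonal pairwise cosine similarities,
    # mirroring the body of univariate_cons_metric.
    cos = cosine_similarity(a)
    return np.arccos(cos[np.tril_indices(cos.shape[0], k=-1)]).mean()


rng = np.random.default_rng(0)
features = rng.normal(size=(1000, 16))  # stand-in for the full feature matrix
group = features[:4]                    # 4 replicates of one hypothetical perturbation

observed = avg_cosine_angle(group)

# Empirical null: the same statistic on randomly drawn groups of identical size.
null = np.sort([avg_cosine_angle(rng.choice(features, 4, replace=False)) for _ in range(500)])

# Left-tail p-value: a smaller average angle means tighter, more consistent replicates.
pval = np.searchsorted(null, observed) / len(null)
print(f"avg angle = {observed:.3f}, p-value = {pval:.3f}")

Under this random-baseline null, replicates of a real perturbation should sit at a smaller average angle than random groups, pushing the p-value toward zero.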

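And a hypothetical end-to-end call into the new univariate_cons_benchmark entry point; the column name, control key, and data shapes are assumptions chosen for illustration only.

import numpy as np
import pandas as pd

from efaar_benchmarking.benchmarking import univariate_cons_benchmark

# Hypothetical map: 300 wells with 16-dimensional embeddings, 50 perturbations
# ("g0".."g49") plus a negative control that is excluded from the analysis.
rng = np.random.default_rng(1)
features = rng.normal(size=(300, 16))
metadata = pd.DataFrame({"gene": [f"g{i % 50}" for i in range(294)] + ["non-targeting"] * 6})

# Returns one row per perturbation; the null is built separately for every
# replicate count (cardinality) observed after the control wells are dropped.
pvals = univariate_cons_benchmark(
    features, metadata, pert_col="gene", keys_to_drop=["non-targeting"], n_samples=1000
)
print(pvals.head())  # perturbation key plus the "avg_cossim_pval" column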