diff --git a/efaar_benchmarking/benchmarking.py b/efaar_benchmarking/benchmarking.py
index 6aec832..be7b03d 100644
--- a/efaar_benchmarking/benchmarking.py
+++ b/efaar_benchmarking/benchmarking.py
@@ -1,6 +1,8 @@
 import random
 
+import numpy as np
 import pandas as pd
+from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.utils import Bunch
 
 import efaar_benchmarking.constants as cst
@@ -13,38 +15,70 @@
 )
 
 
-def pert_stats(
-    map_data: Bunch,
-    pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
-):
+def univariate_cons_metric(arr: np.ndarray, null: dict | None = None) -> tuple[float | None, float | None]:
     """
-    Calculate perturbation statistics based on the provided map data.
+    Calculate the univariate consistency metric, i.e. average cosine angle and associated p-value, for a given array.
 
     Args:
-        map_data (Bunch): Map data containing metadata.
-        pert_pval_thr (float): Perturbation significance threshold. Defaults to cst.PERT_SIG_PVAL_THR.
+        arr (numpy.ndarray): The input array.
+        null (dict, optional): Null distributions of the metric, keyed by replicate count. Default is None.
 
     Returns:
-        dict: Dictionary containing perturbation statistics:
-            - "all_pert_count": Total count of perturbations.
-            - "pp_pert_count": Count of perturbations that meet the significance threshold.
-            - "pp_pert_percent": Percentage of perturbations that meet the significance threshold.
+        tuple: A tuple containing the average angle (avg_angle) and p-value (pval) of the metric.
+            If the length of the input array is less than 3, returns (None, None).
+            If null is None, returns (avg_angle, None).
     """
+    if len(arr) < 3:
+        return None, None
+    cosine_sim = cosine_similarity(arr)
+    avg_angle = np.arccos(cosine_sim[np.tril_indices(cosine_sim.shape[0], k=-1)]).mean()
+    if null is None:
+        return avg_angle, None
+    else:
+        sorted_null = np.sort(null[len(arr)])
+        pval = np.searchsorted(sorted_null, avg_angle) / len(sorted_null)
+        return avg_angle, pval
 
-    md = map_data.metadata
-    idx = [True] * len(md)
-    pidx = md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr
-    return {
-        "all_pert_count": sum(idx),
-        "pp_pert_count": sum(idx & pidx),
-        "pp_pert_percent": sum(idx & pidx) / sum(idx),
+
+def univariate_cons_benchmark(
+    features: np.ndarray, metadata: pd.DataFrame, pert_col: str, keys_to_drop: list, n_samples: int = 5000
+) -> pd.DataFrame:
+    """
+    Perform univariate consistency benchmarking on the given features and metadata.
+
+    Args:
+        features (np.ndarray): The array of features.
+        metadata (pd.DataFrame): The metadata dataframe.
+        pert_col (str): The column name in the metadata dataframe representing the perturbations.
+        keys_to_drop (list): The perturbation keys to be dropped from the analysis.
+        n_samples (int, optional): The number of samples to generate for the null distribution. Defaults to 5000.
+
+    Returns:
+        pd.DataFrame: One row per perturbation, with the consistency p-value in the "avg_cossim_pval" column.
+    """
+    indices = ~metadata[pert_col].isin(keys_to_drop)
+    features = features[indices]
+    metadata = metadata[indices]
+
+    unique_cardinalities = metadata.groupby(pert_col).count().iloc[:, 0].unique()
+    rng = np.random.default_rng()
+    null = {
+        x: [univariate_cons_metric(rng.choice(features, x, replace=False))[0] for _ in range(n_samples)]
+        for x in unique_cardinalities
+        if x >= 3
     }
+    features_df = pd.DataFrame(features, index=metadata[pert_col])
+    query_metrics = features_df.groupby(features_df.index).apply(lambda x: univariate_cons_metric(x.values, null)[1])
+    query_metrics.name = "avg_cossim_pval"
+    query_metrics = query_metrics.reset_index()
+
+    return query_metrics
 
 
 def benchmark(
     map_data: Bunch,
     benchmark_sources: list = cst.BENCHMARK_SOURCES,
-    pert_label_col: str = cst.REPLOGLE_PERT_LABEL_COL,
+    pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
     recall_thr_pairs: list = cst.RECALL_PERC_THRS,
     filter_on_pert_prints: bool = False,
     pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
@@ -59,7 +93,7 @@ def benchmark(
     Args:
         map_data (Bunch): The map data containing `features` and `metadata` attributes.
         benchmark_sources (list, optional): List of benchmark sources. Defaults to cst.BENCHMARK_SOURCES.
-        pert_label_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
+        pert_col (str, optional): Column name for perturbation labels. Defaults to cst.REPLOGLE_PERT_LABEL_COL.
         recall_thr_pairs (list, optional): List of recall percentage threshold pairs. Defaults to cst.RECALL_PERC_THRS.
         filter_on_pert_prints (bool, optional): Flag to filter map data based on perturbation prints. Defaults to False.
         pert_pval_thr (float, optional): pvalue threshold for perturbation filtering. Defaults to cst.PERT_SIG_PVAL_THR.
@@ -81,7 +115,7 @@ def benchmark(
         ValueError("Invalid benchmark source(s) provided.")
     md = map_data.metadata
     idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
-    features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
+    features = map_data.features[idx].set_index(md[idx][pert_col]).rename_axis(index=None)
     del map_data
     if not len(features) == len(set(features.index)):
         ValueError("Duplicate perturbation labels in the map.")