
Commit

flake8
safiyecelik committed Oct 25, 2023
1 parent 920d3cb commit 459af8a
Showing 3 changed files with 10 additions and 13 deletions.
12 changes: 6 additions & 6 deletions efaar_benchmarking/benchmarking.py
@@ -13,14 +13,14 @@

 def pert_stats(
     map_data: Bunch,
-    pert_print_pvalue_thr: float = cst.PERT_SIG_PVAL_THR,
+    pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
 ):
     """
     Calculate perturbation statistics based on the provided map data.

     Args:
         map_data (Bunch): Map data containing metadata.
-        pert_print_pvalue_thr (float): Perturbation significance threshold. Defaults to cst.PERT_SIG_PVAL_THR.
+        pert_pval_thr (float): Perturbation significance threshold. Defaults to cst.PERT_SIG_PVAL_THR.

     Returns:
         dict: Dictionary containing perturbation statistics:
@@ -31,7 +31,7 @@ def pert_stats(

     md = map_data.metadata
     idx = [True] * len(md)
-    pidx = md[cst.PERT_SIG_PVAL_COL] <= pert_print_pvalue_thr
+    pidx = md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr
     return {
         "all_pert_count": sum(idx),
         "pp_pert_count": sum(idx & pidx),
@@ -45,7 +45,7 @@ def benchmark(
     pert_label_col: str = cst.PERT_LABEL_COL,
     recall_thr_pairs: list = cst.RECALL_PERC_THRS,
     filter_on_pert_prints: bool = False,
-    pert_print_pvalue_thr: float = cst.PERT_SIG_PVAL_THR,
+    pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
     n_null_samples: int = cst.N_NULL_SAMPLES,
     random_seed: int = cst.RANDOM_SEED,
     n_iterations: int = cst.RANDOM_COUNT,
@@ -58,7 +58,7 @@
         pert_label_col (str, optional): Column name for perturbation labels. Defaults to cst.PERT_LABEL_COL.
         recall_thr_pairs (list, optional): List of recall percentage threshold pairs. Defaults to cst.RECALL_PERC_THRS.
         filter_on_pert_prints (bool, optional): Flag to filter map data based on perturbation prints. Defaults to False.
-        pert_print_pvalue_thr (float, optional): P-value threshold for perturbation filtering. Defaults to cst.PERT_SIG_PVAL_THR.
+        pert_pval_thr (float, optional): pvalue threshold for perturbation filtering. Defaults to cst.PERT_SIG_PVAL_THR.
         n_null_samples (int, optional): Number of null samples to generate. Defaults to cst.N_NULL_SAMPLES.
         random_seed (int, optional): Random seed to use for generating null samples. Defaults to cst.RANDOM_SEED.
         n_iterations (int, optional): Number of random seed pairs to use. Defaults to cst.RANDOM_COUNT.
@@ -74,7 +74,7 @@
         [src in cst.BENCHMARK_SOURCES for src in benchmark_sources]
     ), "Invalid benchmark source(s) provided."
     md = map_data.metadata
-    idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_print_pvalue_thr) if filter_on_pert_prints else [True] * len(md)
+    idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
     features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
     del map_data
     assert len(features) == len(set(features.index)), "Duplicate perturbation labels in the map."
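The last hunk is the only behavioral touch point of the rename inside benchmark: when filter_on_pert_prints is enabled, the metadata p-value column is compared against pert_pval_thr before the feature matrix is re-indexed by perturbation label. A minimal sketch of that filtering step, with placeholder names standing in for the cst constants:

# Sketch of the filtering step from the hunk above; column/constant names are placeholders.
import pandas as pd
from sklearn.utils import Bunch

PERT_SIG_PVAL_COL = "pert_pval"  # stand-in for cst.PERT_SIG_PVAL_COL
PERT_LABEL_COL = "gene"          # stand-in for cst.PERT_LABEL_COL


def filter_map_features(map_data: Bunch, filter_on_pert_prints: bool, pert_pval_thr: float) -> pd.DataFrame:
    md = map_data.metadata
    # Keep only perturbations passing the p-value filter, or everything if filtering is off.
    idx = (md[PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
    features = map_data.features[idx].set_index(md[idx][PERT_LABEL_COL]).rename_axis(index=None)
    assert len(features) == len(set(features.index)), "Duplicate perturbation labels in the map."
    return features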
1 change: 0 additions & 1 deletion efaar_benchmarking/plotting.py
@@ -1,7 +1,6 @@
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np


def plot_recall(df):
10 changes: 4 additions & 6 deletions efaar_benchmarking/utils.py
@@ -1,10 +1,7 @@
from typing import Optional

import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.utils import Bunch

import efaar_benchmarking.constants as cst


@@ -51,7 +48,8 @@ def generate_null_cossims(
         rseed_entity2 (int): Random seed for sampling subset from entity2_feats.

     Returns:
-        np.ndarray: A NumPy array containing the null cosine similarity values between the randomly sampled subsets of entities.
+        np.ndarray: A NumPy array containing the null cosine similarity values between the randomly sampled subsets
+            of entities.
     """

     np.random.seed(rseed_entity1)
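Only the docstring and the first seeding call of generate_null_cossims appear in this diff. As a rough illustration of the described behavior, cosine similarities between randomly sampled subsets of two feature matrices, a sketch might look like the following; the subset size and exact sampling scheme are assumptions, not the repository's implementation.

# Hypothetical sketch of a null cosine-similarity distribution; only the parameter names
# visible in the diff (entity feature frames and rseed_entity1/2) are taken from the source.
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity


def null_cossims_sketch(
    entity1_feats: pd.DataFrame,
    entity2_feats: pd.DataFrame,
    n_samples: int,
    rseed_entity1: int,
    rseed_entity2: int,
) -> np.ndarray:
    np.random.seed(rseed_entity1)
    rows1 = np.random.choice(len(entity1_feats), size=n_samples, replace=True)
    np.random.seed(rseed_entity2)
    rows2 = np.random.choice(len(entity2_feats), size=n_samples, replace=True)
    # Pair the two random subsets row-by-row and keep the diagonal of the similarity
    # matrix, giving one null similarity value per sampled pair.
    sims = cosine_similarity(entity1_feats.values[rows1], entity2_feats.values[rows2])
    return np.diag(sims)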
Expand Down Expand Up @@ -109,8 +107,8 @@ def compute_recall(
Parameters:
null_distribution (np.ndarray): The null distribution to compare against
query_distribution (np.ndarray): The query distribution
recall_threshold_pairs (list) A list of pairs of floats (left, right) that represent different recall threshold pairs, where
left and right are floats between 0 and 1.
recall_threshold_pairs (list) A list of pairs of floats (left, right) that represent different recall threshold
pairs, where left and right are floats between 0 and 1.
Returns:
dict: A dictionary of metrics with the following keys:
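compute_recall's body is not part of this change either; one plausible reading of the docstring, recall as the fraction of query similarities landing in the tails of the null defined by each (left, right) percentile pair, can be sketched as below. The tail-based definition and the metric key format are assumptions.

# Hypothetical sketch of recall against a null distribution; only the three documented
# inputs are taken from the diff, the metric definition and key names are assumed.
import numpy as np


def recall_sketch(
    null_distribution: np.ndarray,
    query_distribution: np.ndarray,
    recall_threshold_pairs: list,
) -> dict:
    sorted_null = np.sort(null_distribution)
    # Percentile rank (0..1) of every query value within the null distribution.
    ranks = np.searchsorted(sorted_null, query_distribution) / len(sorted_null)
    metrics = {}
    for left, right in recall_threshold_pairs:
        in_tails = (ranks <= left) | (ranks >= right)
        metrics[f"recall_{left}_{right}"] = float(np.mean(in_tails))
    return metrics


rng = np.random.default_rng(0)
null = rng.normal(size=10_000)
query = rng.normal(loc=1.5, size=500)  # shifted queries should score higher recall
print(recall_sketch(null, query, [(0.05, 0.95)]))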
Expand Down
