Skip to content

Commit

Permalink
Merge branch 'trunk' into add_tvn
Browse files Browse the repository at this point in the history
  • Loading branch information
safiyecelik committed Dec 21, 2023
2 parents 7b92f4d + 61e3ad8 commit af3260b
Show file tree
Hide file tree
Showing 3 changed files with 184 additions and 21 deletions.
22 changes: 16 additions & 6 deletions efaar_benchmarking/benchmarking.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def univariate_consistency_benchmark(
metadata: pd.DataFrame,
pert_col: str,
keys_to_drop: str,
n_samples: int = 10000,
n_samples: int = cst.N_NULL_SAMPLES,
random_seed: int = cst.RANDOM_SEED,
) -> pd.DataFrame:
"""
Expand All @@ -51,7 +51,8 @@ def univariate_consistency_benchmark(
metadata (pd.DataFrame): The metadata dataframe.
pert_col (str): The column name in the metadata dataframe representing the perturbations.
keys_to_drop (str): The perturbation keys to be dropped from the analysis.
n_samples (int, optional): The number of samples to generate for null distribution. Defaults to 5000.
n_samples (int, optional): The number of samples to generate for null distribution.
Defaults to cst.N_NULL_SAMPLES.
random_seed (int, optional): The random seed to use for generating null distribution.
Defaults to cst.RANDOM_SEED.
Expand Down Expand Up @@ -197,6 +198,7 @@ def compute_recall(
null_distribution: np.ndarray,
query_distribution: np.ndarray,
recall_threshold_pairs: list,
right_sided: bool = False,
) -> dict:
"""Compute recall at given percentage thresholds for a query distribution with respect to a null distribution.
Each recall threshold is a pair of floats (left, right) where left and right are floats between 0 and 1.
Expand All @@ -206,6 +208,8 @@ def compute_recall(
query_distribution (np.ndarray): The query distribution
recall_threshold_pairs (list): A list of pairs of floats (left, right) that represent different recall threshold
pairs, where left and right are floats between 0 and 1.
right_sided (bool, optional): Whether to consider only the right tail of the distribution or both tails when
computing recall. Defaults to False (i.e., both tails).
Returns:
dict: A dictionary of metrics with the following keys:
Expand All @@ -227,9 +231,14 @@ def compute_recall(
)
for threshold_pair in recall_threshold_pairs:
left_threshold, right_threshold = np.min(threshold_pair), np.max(threshold_pair)
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
(query_percentage_ranks_right <= left_threshold) | (query_percentage_ranks_left >= right_threshold)
) / len(query_distribution)
if right_sided:
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
query_percentage_ranks_left >= right_threshold
) / len(query_distribution)
else:
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
(query_percentage_ranks_right <= left_threshold) | (query_percentage_ranks_left >= right_threshold)
) / len(query_distribution)
return metrics


Expand Down Expand Up @@ -270,6 +279,7 @@ def multivariate_benchmark(
n_iterations: int = cst.RANDOM_COUNT,
min_req_entity_cnt: int = cst.MIN_REQ_ENT_CNT,
benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR,
right_sided: bool = False,
) -> pd.DataFrame:
"""Perform benchmarking on map data.
Expand Down Expand Up @@ -321,7 +331,7 @@ def multivariate_benchmark(
if len(query_cossim) > 0:
metrics_lst.append(
convert_metrics_to_df(
metrics=compute_recall(null_cossim, query_cossim, recall_thr_pairs),
metrics=compute_recall(null_cossim, query_cossim, recall_thr_pairs, right_sided),
source=s,
random_seed_str=random_seed_str,
filter_on_pert_prints=filter_on_pert_prints,
Expand Down
12 changes: 10 additions & 2 deletions efaar_benchmarking/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import seaborn as sns


def plot_recall(metric_dfs: dict):
def plot_recall(metric_dfs: dict, right_sided: bool = False, title=""):
"""
Plots line plots of recall values for several threshold pairs for each benchmark source and each map.
Expand All @@ -14,14 +14,19 @@ def plot_recall(metric_dfs: dict):
Each dataframe needs to have a "source" column and "recall_X_Y" columns for several [X, Y] pairs.
Each metric dataframe in the dictionary corresponds to a different map. All the dataframes
need to have the exact same structure (ie same column names and same source values).
right_sided (bool): Whether to plot the right-sided recall values. Default is False.
title (str): Title for the entire plot. Default is an empty string.
Returns:
None
"""
df_template = list(metric_dfs.values())[0] # this is used as the template for the structure and labels of the plots
recall_thr_pairs = [col.split("_")[1:] for col in df_template.columns if col.startswith("recall_")]
x_values = [f"{x}, {y}" for x, y in recall_thr_pairs]
random_recall_values = [float(x) + 1 - float(y) for x, y in recall_thr_pairs]
if right_sided:
random_recall_values = [1 - float(y) for x, y in recall_thr_pairs]
else:
random_recall_values = [float(x) + 1 - float(y) for x, y in recall_thr_pairs]

col_cnt = 5
sns.set_style("whitegrid")
Expand All @@ -30,6 +35,9 @@ def plot_recall(metric_dfs: dict):
)
palette = dict(zip(metric_dfs.keys(), sns.color_palette("tab10", len(metric_dfs))))

# Set the title for the entire plot
fig.suptitle(title)

# Plot each source as a separate subplot
for i, source in enumerate(df_template["source"].unique()):
for m in metric_dfs.keys():
Expand Down
171 changes: 158 additions & 13 deletions notebooks/periscope_map_building.ipynb

Large diffs are not rendered by default.

0 comments on commit af3260b

Please sign in to comment.