add right-sided benchmarks for periscope #32

Merged: 1 commit, Dec 19, 2023
22 changes: 16 additions & 6 deletions efaar_benchmarking/benchmarking.py
@@ -40,7 +40,7 @@ def univariate_consistency_benchmark(
metadata: pd.DataFrame,
pert_col: str,
keys_to_drop: str,
n_samples: int = 10000,
n_samples: int = cst.N_NULL_SAMPLES,
random_seed: int = cst.RANDOM_SEED,
) -> pd.DataFrame:
"""
@@ -51,7 +51,8 @@ def univariate_consistency_benchmark(
metadata (pd.DataFrame): The metadata dataframe.
pert_col (str): The column name in the metadata dataframe representing the perturbations.
keys_to_drop (str): The perturbation keys to be dropped from the analysis.
n_samples (int, optional): The number of samples to generate for null distribution. Defaults to 5000.
n_samples (int, optional): The number of samples to generate for the null distribution.
Defaults to cst.N_NULL_SAMPLES.
random_seed (int, optional): The random seed to use for generating the null distribution.
Defaults to cst.RANDOM_SEED.

@@ -197,6 +198,7 @@ def compute_recall(
null_distribution: np.ndarray,
query_distribution: np.ndarray,
recall_threshold_pairs: list,
right_sided: bool = False,
) -> dict:
"""Compute recall at given percentage thresholds for a query distribution with respect to a null distribution.
Each recall threshold is a pair of floats (left, right) where left and right are floats between 0 and 1.
@@ -206,6 +208,8 @@
query_distribution (np.ndarray): The query distribution
recall_threshold_pairs (list): A list of pairs of floats (left, right) that represent different recall threshold
pairs, where left and right are floats between 0 and 1.
right_sided (bool, optional): Whether to consider only the right tail of the distribution or both tails when
computing recall. Defaults to False (i.e., both tails).

Returns:
dict: A dictionary of metrics with the following keys:
@@ -227,9 +231,14 @@
)
for threshold_pair in recall_threshold_pairs:
left_threshold, right_threshold = np.min(threshold_pair), np.max(threshold_pair)
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
(query_percentage_ranks_right <= left_threshold) | (query_percentage_ranks_left >= right_threshold)
) / len(query_distribution)
if right_sided:
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
query_percentage_ranks_left >= right_threshold
) / len(query_distribution)
else:
metrics[f"recall_{left_threshold}_{right_threshold}"] = sum(
(query_percentage_ranks_right <= left_threshold) | (query_percentage_ranks_left >= right_threshold)
) / len(query_distribution)
return metrics
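
A minimal usage sketch of the new flag (not part of the diff; the import path, the single threshold pair, and the synthetic similarity values below are illustrative assumptions):

```python
import numpy as np

# Assumed import path, based on the file this PR touches.
from efaar_benchmarking.benchmarking import compute_recall

rng = np.random.default_rng(0)
null_cossim = rng.uniform(-1, 1, size=10_000)   # synthetic null cosine similarities
query_cossim = rng.uniform(0.0, 1.0, size=500)  # synthetic query, shifted toward the right tail of the null

# Default behaviour: a query value is recalled if it falls in either tail of the null.
two_sided = compute_recall(null_cossim, query_cossim, [(0.05, 0.95)])

# Right-sided behaviour: only query values at or above the right threshold are recalled.
right_only = compute_recall(null_cossim, query_cossim, [(0.05, 0.95)], right_sided=True)

print(two_sided)
print(right_only)
```

With a right-shifted query like this, the two results should be essentially equal, since almost no query values land in the left tail of the null.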


@@ -270,6 +279,7 @@ def multivariate_benchmark(
n_iterations: int = cst.RANDOM_COUNT,
min_req_entity_cnt: int = cst.MIN_REQ_ENT_CNT,
benchmark_data_dir: str = cst.BENCHMARK_DATA_DIR,
right_sided: bool = False,
) -> pd.DataFrame:
"""Perform benchmarking on map data.

@@ -321,7 +331,7 @@
if len(query_cossim) > 0:
metrics_lst.append(
convert_metrics_to_df(
metrics=compute_recall(null_cossim, query_cossim, recall_thr_pairs),
metrics=compute_recall(null_cossim, query_cossim, recall_thr_pairs, right_sided),
source=s,
random_seed_str=random_seed_str,
filter_on_pert_prints=filter_on_pert_prints,
12 changes: 10 additions & 2 deletions efaar_benchmarking/plotting.py
@@ -5,7 +5,7 @@
import seaborn as sns


def plot_recall(metric_dfs: dict):
def plot_recall(metric_dfs: dict, right_sided: bool = False, title=""):
"""
Plots line plots of recall values for several threshold pairs for each benchmark source and each map.

@@ -14,14 +14,19 @@ def plot_recall(metric_dfs: dict):
Each dataframe needs to have a "source" column and "recall_X_Y" columns for several [X, Y] pairs.
Each metric dataframe in the dictionary corresponds to a different map. All the dataframes
need to have the exact same structure (i.e., same column names and same source values).
right_sided (bool): Whether to plot the right-sided recall values. Default is False.
title (str): Title for the entire plot. Default is an empty string.

Returns:
None
"""
df_template = list(metric_dfs.values())[0] # this is used as the template for the structure and labels of the plots
recall_thr_pairs = [col.split("_")[1:] for col in df_template.columns if col.startswith("recall_")]
x_values = [f"{x}, {y}" for x, y in recall_thr_pairs]
random_recall_values = [float(x) + 1 - float(y) for x, y in recall_thr_pairs]
if right_sided:
random_recall_values = [1 - float(y) for x, y in recall_thr_pairs]
else:
random_recall_values = [float(x) + 1 - float(y) for x, y in recall_thr_pairs]

col_cnt = 5
sns.set_style("whitegrid")
@@ -30,6 +35,9 @@
)
palette = dict(zip(metric_dfs.keys(), sns.color_palette("tab10", len(metric_dfs))))

# Set the title for the entire plot
fig.suptitle(title)

# Plot each source as a separate subplot
for i, source in enumerate(df_template["source"].unique()):
for m in metric_dfs.keys():
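
To make the baseline change in plot_recall concrete, here is a small worked example (not part of the diff; the threshold pairs are illustrative) of the recall expected from a uniformly distributed random query under each mode:

```python
# For a threshold pair (left, right), a random query has probability `left` of landing
# in the left tail of the null and `1 - right` of landing in the right tail.
recall_thr_pairs = [(0.01, 0.99), (0.05, 0.95), (0.10, 0.90)]
for left, right in recall_thr_pairs:
    two_sided_baseline = left + (1 - right)   # both tails count
    right_sided_baseline = 1 - right          # only the right tail counts
    print(f"({left}, {right}): two-sided {two_sided_baseline:.2f}, right-sided {right_sided_baseline:.2f}")
```

This is why the right-sided random baseline drops to 1 - y while the two-sided baseline stays at x + 1 - y.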
171 changes: 158 additions & 13 deletions notebooks/periscope_map_building.ipynb

Large diffs are not rendered by default.