move pca-cs to efaar, update readme, and re-run notebook with feedback
kian-kd committed Nov 11, 2024
1 parent 0775226 commit 66bf5af
Showing 3 changed files with 89 additions and 98 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -81,3 +81,13 @@ _Gilson,M.K., Liu,T., Baitaluk,M., Nicola,G., Hwang, L. and Chong,J. BindingDB i
**Guide to Pharmacology:**

_Harding SD, Armstrong JF, Faccenda E, Southan C, Alexander SPH, Davenport AP, Spedding M, Davies JA. (2023) The IUPHAR/BPS Guide to PHARMACOLOGY in 2024. Nucl. Acids Res. 2024; 52(D1):D1438-D1449. doi:10.1093/nar/gkad944. [Full text]. PMID: 37897341._


## Gene-compound relationship benchmark
In `notebooks/rxrx3_core_benchmarks_openphenom.ipynb` we use a dedicated benchmark that measures how well the embeddings capture compound activity against known target genes.

This benchmark evaluates the zero-shot prediction of compound-gene activity using cosine similarities between model embeddings. Specifically, for each compound, we assess whether the cosine similarities correctly rank the compound's known target genes higher than a randomly sampled set of other genes from the ground truth dataset.

To do this, we compute the cosine similarity between each compound and gene at every available concentration and keep the maximum similarity for each compound-gene pair. This captures the strongest potential interaction regardless of concentration, even when negative pairs are drawn from different concentrations than positive pairs.
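
A rough sketch of this step, assuming hypothetical `compound_emb` and `gene_emb` DataFrames with one aggregated embedding per row; the column names here are illustrative, not the notebook's actual variables:

```python
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity


def max_similarity_over_concentrations(
    compound_emb: pd.DataFrame,  # columns: "compound", "concentration", plus feature columns
    gene_emb: pd.DataFrame,      # columns: "gene", plus feature columns
    feature_cols: list[str],
) -> pd.DataFrame:
    """One row per (compound, gene) with the max cosine similarity across concentrations."""
    sims = cosine_similarity(
        compound_emb[feature_cols].to_numpy(),
        gene_emb[feature_cols].to_numpy(),
    )
    long = pd.DataFrame(sims, columns=gene_emb["gene"].to_numpy())
    long["compound"] = compound_emb["compound"].to_numpy()
    long = long.melt(id_vars="compound", var_name="gene", value_name="cosine_sim")
    # keep the strongest potential interaction for each pair, regardless of concentration
    return long.groupby(["compound", "gene"], as_index=False)["cosine_sim"].max()
```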

We then treat the absolute value of the cosine similarity as a confidence measure—similar to a classifier's probability score—and compute the AUC (Area Under the ROC Curve) and average precision for each compound. The final results report the median AUC and average precision across all compounds, compared against a random baseline.
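
The per-compound scoring could then look roughly like the sketch below, assuming a ground-truth table of known compound-gene pairs; `n_negatives`, the seed, and the names are placeholders, and the repository's `compound_gene_benchmark` may differ in detail:

```python
import numpy as np
from sklearn.metrics import average_precision_score, roc_auc_score


def score_compound(pair_sims, known_targets, all_genes, n_negatives=100, seed=0):
    """AUC and average precision for one compound.

    pair_sims: dict mapping gene -> max cosine similarity with this compound.
    known_targets: set of genes annotated as targets of the compound.
    all_genes: list of candidate genes to sample negatives from.
    """
    rng = np.random.default_rng(seed)
    negatives = rng.choice(
        [g for g in all_genes if g not in known_targets], size=n_negatives, replace=False
    )
    genes = list(known_targets) + list(negatives)
    labels = np.array([1] * len(known_targets) + [0] * len(negatives))
    # the absolute cosine similarity plays the role of a classifier's confidence score
    scores = np.array([abs(pair_sims[g]) for g in genes])
    return roc_auc_score(labels, scores), average_precision_score(labels, scores)


# final numbers: median AUC and median average precision across all compounds,
# compared against a random baseline, e.g.
# aucs, aps = zip(*(score_compound(...) for compound in compounds))
# np.median(aucs), np.median(aps)
```
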
29 changes: 29 additions & 0 deletions efaar_benchmarking/efaar.py
@@ -139,6 +139,35 @@ def centerscale_on_controls(
return StandardScaler().fit(embeddings[control_ind]).transform(embeddings)


def pca_centerscale_on_controls(
    embeddings: np.ndarray,
    metadata: pd.DataFrame,
    pert_col: str,
    control_key: str,
    batch_col: str | None = None,
) -> np.ndarray:
    """
    Fit PCA on the control units, project all embeddings, then center and scale them on the controls.

    Args:
        embeddings (numpy.ndarray): The embeddings to be aligned.
        metadata (pandas.DataFrame): The metadata containing information about the embeddings.
        pert_col (str): The column in the metadata containing perturbation information.
        control_key (str): The key for non-targeting controls in the metadata.
        batch_col (str, optional): Column name in the metadata representing the batch labels.
            Defaults to None.

    Returns:
        numpy.ndarray: The aligned embeddings.
    """
    X = embeddings.copy()
    X_controls = X[metadata[pert_col] == control_key]
    if not len(X_controls):
        raise ValueError(f"No control samples found for {control_key}")
    # fit PCA on control units only, then project all embeddings into that space
    pca = PCA().fit(X_controls)
    X_pca = pca.transform(X)
    return centerscale_on_controls(X_pca, metadata, pert_col, control_key, batch_col=batch_col)


def tvn_on_controls(
    embeddings: np.ndarray,
    metadata: pd.DataFrame,
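
A minimal usage sketch of the new helper on synthetic data; the embeddings, perturbation labels, control key, and batch labels below are made up for illustration:

```python
import numpy as np
import pandas as pd

from efaar_benchmarking.efaar import pca_centerscale_on_controls

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 32))  # 200 wells x 32 embedding dimensions
metadata = pd.DataFrame(
    {
        "perturbation": ["EMPTY_control"] * 100 + [f"compound_{i}" for i in range(100)],
        "experiment_name": ["expA", "expB"] * 100,
    }
)

aligned = pca_centerscale_on_controls(
    embeddings,
    metadata,
    pert_col="perturbation",
    control_key="EMPTY_control",
    batch_col="experiment_name",
)
assert aligned.shape[0] == embeddings.shape[0]
```
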
148 changes: 50 additions & 98 deletions notebooks/rxrx3_core_benchmarks_openphenom.ipynb
@@ -13,10 +13,9 @@
"import pandas as pd\n",
"\n",
"from sklearn.utils import Bunch\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"from efaar_benchmarking.efaar import centerscale_on_controls\n",
"from efaar_benchmarking.constants import COMPOUND_CONCENTRATIONS\n",
"from efaar_benchmarking.efaar import pca_centerscale_on_controls\n",
"from efaar_benchmarking.benchmarking import known_relationship_benchmark\n",
"from efaar_benchmarking.benchmarking import compound_gene_benchmark, BenchmarkConfig"
]
@@ -30,7 +29,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
"/var/folders/rh/2pzrb9394871v5556j113wwh0000gr/T/ipykernel_9537/1030283650.py:2: DtypeWarning: Columns (6) have mixed types. Specify dtype option on import or set low_memory=False.\n",
"/var/folders/rh/2pzrb9394871v5556j113wwh0000gr/T/ipykernel_12082/1030283650.py:2: DtypeWarning: Columns (6) have mixed types. Specify dtype option on import or set low_memory=False.\n",
" rxrx3_metadata = pd.read_csv(\"data/metadata_rxrx3_core.csv\") # visit https://rxrx3.rxrx.ai/downloads\n"
]
}
@@ -714,36 +713,6 @@
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['well_id',\n",
" 'experiment_name',\n",
" 'plate',\n",
" 'address',\n",
" 'gene',\n",
" 'treatment',\n",
" 'SMILES',\n",
" 'concentration',\n",
" 'perturbation_type',\n",
" 'cell_type',\n",
" 'perturbation']"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"metadata_columns"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
@@ -756,17 +725,10 @@
"source": [
"print(\"fitting aligner...\")\n",
"X = embeddings_mrged[feature_columns].astype(float).values\n",
"pca = PCA().fit(X[embeddings_mrged[pert_colname] == control_key])\n",
"X_pca = pca.transform(X)\n",
"embeddings_pcacs = centerscale_on_controls(X_pca, embeddings_mrged[metadata_columns], pert_col=pert_colname, batch_col=experiment_colname, control_key=control_key)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"embeddings_pcacs = pca_centerscale_on_controls(\n",
" X, embeddings_mrged[metadata_columns], pert_col=pert_colname, batch_col=experiment_colname, control_key=control_key\n",
")\n",
"\n",
"assert embeddings_mrged[metadata_columns].shape[0] == embeddings_pcacs.shape[0]\n",
"\n",
"new_metadata = embeddings_mrged[metadata_columns].copy().reset_index()\n",
Expand All @@ -776,12 +738,12 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"assert aligned_embeddings.feature_0.isna().sum() == 0\n",
" \n",
"\n",
"# remove controls from henceforth analysis\n",
"merged = aligned_embeddings[\n",
" ~(\n",
@@ -790,10 +752,7 @@
" )\n",
"]\n",
"\n",
"assert merged.feature_0.isna().sum() == 0\n",
"\n",
"COMPOUND_CONCENTRATIONS = [0.0025, 0.01, 0.025, 0.1, 0.25, 1.0, 2.5, 10.0]\n",
"\n",
"# aggregate to perturbation-level\n",
"agg_func = {col: \"mean\" for col in merged.columns if col.startswith(\"feature_\")}\n",
"map_data = (\n",
" merged.groupby([\"perturbation_type\", pert_colname, \"concentration\"], dropna=False)\n",
@@ -809,7 +768,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -848,7 +807,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
@@ -866,7 +825,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
@@ -889,7 +848,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -902,15 +861,15 @@
{
"error_y": {
"array": [
0.005427170341253574,
0.003968513891837956,
0.006861122433321745,
0.006448842852500994,
0.007338013865538547,
0.007403709711191922,
0.007201013100503151,
0.010578876163379507,
0.0122309780032886
0.010025315663722797,
0.012151388474140257,
0.01137420598220815,
0.011715270828890607,
0.010711572746344678,
0.009954499433996455,
0.010518603812478829,
0.012086549375033832,
0.013847861813751746
],
"type": "data"
},
@@ -929,15 +888,15 @@
"max"
],
"y": [
0.2392185099845927,
0.23287536370433584,
0.2519754153504984,
0.2639562672085235,
0.256128562599406,
0.25830869301236825,
0.2751461825793374,
0.25763465957518206,
0.2709983702686366
0.2390593257318816,
0.23586980534623878,
0.2520869334858552,
0.262519263998251,
0.2571483621603731,
0.2590521375426392,
0.2752186265277348,
0.2583475495464243,
0.266096852004128
]
},
{
@@ -1837,15 +1796,15 @@
{
"error_y": {
"array": [
0.008681647108043332,
0.006493802234714557,
0.008257830374794215,
0.005426473836768565,
0.006857108878281468,
0.008300603784993879,
0.0069804181707088945,
0.010583634500496602,
0.011247044944917134
0.014643989170915733,
0.016543434232447352,
0.015399697256825986,
0.015921518156114742,
0.01767328027273109,
0.015542969791387078,
0.017807323580870175,
0.015586358353045004,
0.021054205333749153
],
"type": "data"
},
@@ -1864,15 +1823,15 @@
"max"
],
"y": [
0.512941720916961,
0.512415769393954,
0.5406047372244298,
0.5427598690938227,
0.5268041483840833,
0.5334087431891422,
0.5358461121840867,
0.535814581524895,
0.5388132765203012
0.516564315515448,
0.513473719498011,
0.5431615538882418,
0.5439217673129079,
0.5282081366334797,
0.5357906399857973,
0.5395486843475453,
0.5374198480954333,
0.5408863882491511
]
},
{
@@ -2819,13 +2778,6 @@
" fig.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
