diff --git a/README.md b/README.md index 9a43ebf..4e3c473 100644 --- a/README.md +++ b/README.md @@ -81,3 +81,13 @@ _Gilson,M.K., Liu,T., Baitaluk,M., Nicola,G., Hwang, L. and Chong,J. BindingDB i **Guide to Pharmacology:** _Harding SD, Armstrong JF, Faccenda E, Southan C, Alexander SPH, Davenport AP, Spedding M, Davies JA. (2023) The IUPHAR/BPS Guide to PHARMACOLOGY in 2024. Nucl. Acids Res. 2024; 52(D1):D1438-D1449. doi:10.1093/nar/gkad944. [Full text]. PMID: 37897341._ + + +## Gene-compound relationship benchmark +In `notebooks/rxrx3_core_benchmarks_openphenom.ipynb` we leverage a specialized benchmark to measure compound activity against a gene. + +This benchmark evaluates the zero-shot prediction of compound-gene activity using cosine similarities between model embeddings. Specifically, for each compound, we assess whether the cosine similarities correctly rank the compound's known target genes higher than a randomly sampled set of other genes from the ground truth dataset. + +To achieve this, we compute the cosine similarity between each compound and gene across all available concentrations and take the maximum similarity score for each pair. This approach captures the strongest potential interaction regardless of concentration, even if negatives come from different concentrations than positives. + +We then treat the absolute value of the cosine similarity as a confidence measure—similar to a classifier's probability score—and compute the AUC (Area Under the ROC Curve) and average precision for each compound. The final results report the median AUC and average precision across all compounds, compared against a random baseline. diff --git a/efaar_benchmarking/efaar.py b/efaar_benchmarking/efaar.py index eed8788..74eff3c 100644 --- a/efaar_benchmarking/efaar.py +++ b/efaar_benchmarking/efaar.py @@ -139,6 +139,35 @@ def centerscale_on_controls( return StandardScaler().fit(embeddings[control_ind]).transform(embeddings) +def pca_centerscale_on_controls( + embeddings: np.ndarray, + metadata: pd.DataFrame, + pert_col: str, + control_key: str, + batch_col: str | None = None, +) -> np.ndarray: + """ + Fit PCA on controls then center and scale the embeddings on the control perturbation units. + + Args: + embeddings (numpy.ndarray): The embeddings to be aligned. + metadata (pandas.DataFrame): The metadata containing information about the embeddings. + pert_col (str, optional): The column in the metadata containing perturbation information. + control_key (str, optional): The key for non-targeting controls in the metadata. + batch_col (str, optional): Column name in the metadata representing the batch labels. + Defaults to None. + Returns: + numpy.ndarray: The aligned embeddings. + """ + X = embeddings.copy() + X_controls = X[metadata[pert_col] == control_key] + if not len(X_controls): + raise ValueError(f"No control samples found for {control_key}") + pca = PCA().fit(X_controls) + X_pca = pca.transform(X) + return centerscale_on_controls(X_pca, metadata, pert_col, control_key, batch_col=batch_col) + + def tvn_on_controls( embeddings: np.ndarray, metadata: pd.DataFrame, diff --git a/notebooks/rxrx3_core_benchmarks_openphenom.ipynb b/notebooks/rxrx3_core_benchmarks_openphenom.ipynb index 67c3001..6e91862 100644 --- a/notebooks/rxrx3_core_benchmarks_openphenom.ipynb +++ b/notebooks/rxrx3_core_benchmarks_openphenom.ipynb @@ -13,10 +13,9 @@ "import pandas as pd\n", "\n", "from sklearn.utils import Bunch\n", - "from sklearn.decomposition import PCA\n", - "from sklearn.preprocessing import StandardScaler\n", "\n", - "from efaar_benchmarking.efaar import centerscale_on_controls\n", + "from efaar_benchmarking.constants import COMPOUND_CONCENTRATIONS\n", + "from efaar_benchmarking.efaar import pca_centerscale_on_controls\n", "from efaar_benchmarking.benchmarking import known_relationship_benchmark\n", "from efaar_benchmarking.benchmarking import compound_gene_benchmark, BenchmarkConfig" ] @@ -30,7 +29,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "/var/folders/rh/2pzrb9394871v5556j113wwh0000gr/T/ipykernel_9537/1030283650.py:2: DtypeWarning: Columns (6) have mixed types. Specify dtype option on import or set low_memory=False.\n", + "/var/folders/rh/2pzrb9394871v5556j113wwh0000gr/T/ipykernel_12082/1030283650.py:2: DtypeWarning: Columns (6) have mixed types. Specify dtype option on import or set low_memory=False.\n", " rxrx3_metadata = pd.read_csv(\"data/metadata_rxrx3_core.csv\") # visit https://rxrx3.rxrx.ai/downloads\n" ] } @@ -714,36 +713,6 @@ "cell_type": "code", "execution_count": 8, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['well_id',\n", - " 'experiment_name',\n", - " 'plate',\n", - " 'address',\n", - " 'gene',\n", - " 'treatment',\n", - " 'SMILES',\n", - " 'concentration',\n", - " 'perturbation_type',\n", - " 'cell_type',\n", - " 'perturbation']" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "metadata_columns" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, "outputs": [ { "name": "stdout", @@ -756,17 +725,10 @@ "source": [ "print(\"fitting aligner...\")\n", "X = embeddings_mrged[feature_columns].astype(float).values\n", - "pca = PCA().fit(X[embeddings_mrged[pert_colname] == control_key])\n", - "X_pca = pca.transform(X)\n", - "embeddings_pcacs = centerscale_on_controls(X_pca, embeddings_mrged[metadata_columns], pert_col=pert_colname, batch_col=experiment_colname, control_key=control_key)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ + "embeddings_pcacs = pca_centerscale_on_controls(\n", + " X, embeddings_mrged[metadata_columns], pert_col=pert_colname, batch_col=experiment_colname, control_key=control_key\n", + ")\n", + "\n", "assert embeddings_mrged[metadata_columns].shape[0] == embeddings_pcacs.shape[0]\n", "\n", "new_metadata = embeddings_mrged[metadata_columns].copy().reset_index()\n", @@ -776,12 +738,12 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "assert aligned_embeddings.feature_0.isna().sum() == 0\n", - " \n", + "\n", "# remove controls from henceforth analysis\n", "merged = aligned_embeddings[\n", " ~(\n", @@ -790,10 +752,7 @@ " )\n", "]\n", "\n", - "assert merged.feature_0.isna().sum() == 0\n", - "\n", - "COMPOUND_CONCENTRATIONS = [0.0025, 0.01, 0.025, 0.1, 0.25, 1.0, 2.5, 10.0]\n", - "\n", + "# aggregate to perturbation-level\n", "agg_func = {col: \"mean\" for col in merged.columns if col.startswith(\"feature_\")}\n", "map_data = (\n", " merged.groupby([\"perturbation_type\", pert_colname, \"concentration\"], dropna=False)\n", @@ -809,7 +768,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -848,7 +807,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -866,7 +825,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -889,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -902,15 +861,15 @@ { "error_y": { "array": [ - 0.005427170341253574, - 0.003968513891837956, - 0.006861122433321745, - 0.006448842852500994, - 0.007338013865538547, - 0.007403709711191922, - 0.007201013100503151, - 0.010578876163379507, - 0.0122309780032886 + 0.010025315663722797, + 0.012151388474140257, + 0.01137420598220815, + 0.011715270828890607, + 0.010711572746344678, + 0.009954499433996455, + 0.010518603812478829, + 0.012086549375033832, + 0.013847861813751746 ], "type": "data" }, @@ -929,15 +888,15 @@ "max" ], "y": [ - 0.2392185099845927, - 0.23287536370433584, - 0.2519754153504984, - 0.2639562672085235, - 0.256128562599406, - 0.25830869301236825, - 0.2751461825793374, - 0.25763465957518206, - 0.2709983702686366 + 0.2390593257318816, + 0.23586980534623878, + 0.2520869334858552, + 0.262519263998251, + 0.2571483621603731, + 0.2590521375426392, + 0.2752186265277348, + 0.2583475495464243, + 0.266096852004128 ] }, { @@ -1837,15 +1796,15 @@ { "error_y": { "array": [ - 0.008681647108043332, - 0.006493802234714557, - 0.008257830374794215, - 0.005426473836768565, - 0.006857108878281468, - 0.008300603784993879, - 0.0069804181707088945, - 0.010583634500496602, - 0.011247044944917134 + 0.014643989170915733, + 0.016543434232447352, + 0.015399697256825986, + 0.015921518156114742, + 0.01767328027273109, + 0.015542969791387078, + 0.017807323580870175, + 0.015586358353045004, + 0.021054205333749153 ], "type": "data" }, @@ -1864,15 +1823,15 @@ "max" ], "y": [ - 0.512941720916961, - 0.512415769393954, - 0.5406047372244298, - 0.5427598690938227, - 0.5268041483840833, - 0.5334087431891422, - 0.5358461121840867, - 0.535814581524895, - 0.5388132765203012 + 0.516564315515448, + 0.513473719498011, + 0.5431615538882418, + 0.5439217673129079, + 0.5282081366334797, + 0.5357906399857973, + 0.5395486843475453, + 0.5374198480954333, + 0.5408863882491511 ] }, { @@ -2819,13 +2778,6 @@ " fig.show()" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "code", "execution_count": null,