diff --git a/efaar_benchmarking/benchmarking.py b/efaar_benchmarking/benchmarking.py
index 6d3f7ab..6aec832 100644
--- a/efaar_benchmarking/benchmarking.py
+++ b/efaar_benchmarking/benchmarking.py
@@ -44,7 +44,7 @@ def pert_stats(
 def benchmark(
     map_data: Bunch,
     benchmark_sources: list = cst.BENCHMARK_SOURCES,
-    pert_label_col: str = cst.PERT_LABEL_COL,
+    pert_label_col: str = cst.REPLOGLE_PERT_LABEL_COL,
     recall_thr_pairs: list = cst.RECALL_PERC_THRS,
     filter_on_pert_prints: bool = False,
     pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
@@ -78,15 +78,15 @@ def benchmark(
     """
     if not len(benchmark_sources) > 0 and all([src in benchmark_data_dir for src in benchmark_sources]):
-        AssertionError("Invalid benchmark source(s) provided.")
+        raise ValueError("Invalid benchmark source(s) provided.")
     md = map_data.metadata
     idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
     features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
     del map_data
     if not len(features) == len(set(features.index)):
-        AssertionError("Duplicate perturbation labels in the map.")
+        raise ValueError("Duplicate perturbation labels in the map.")
     if not len(features) >= min_req_entity_cnt:
-        AssertionError("Not enough entities in the map for benchmarking.")
+        raise ValueError("Not enough entities in the map for benchmarking.")
     print(len(features), "perturbations exist in the map.")

     metrics_lst = []
@@ -98,7 +98,9 @@ def benchmark(
             random_seed_str = f"{rs1}_{rs2}"
             null_cossim = generate_null_cossims(features, n_null_samples, rs1, rs2)
             for s in benchmark_sources:
-                query_cossim = generate_query_cossims(features, get_benchmark_relationships(benchmark_data_dir, s))
+                rels = get_benchmark_relationships(benchmark_data_dir, s)
+                print(len(rels), "relationships exist in the benchmark source.")
+                query_cossim = generate_query_cossims(features, rels)
                 if len(query_cossim) > 0:
                     metrics_lst.append(
                         convert_metrics_to_df(
diff --git a/efaar_benchmarking/constants.py b/efaar_benchmarking/constants.py
index 454606f..951bf09 100644
--- a/efaar_benchmarking/constants.py
+++ b/efaar_benchmarking/constants.py
@@ -2,12 +2,16 @@
 BENCHMARK_DATA_DIR = str(resources.files("efaar_benchmarking").joinpath("benchmark_annotations"))
 BENCHMARK_SOURCES = ["Reactome", "HuMAP", "CORUM", "SIGNOR", "StringDB"]
-PERT_LABEL_COL = "gene"
-CONTROL_PERT_LABEL = "non-targeting"
-PERT_SIG_PVAL_COL = "gene_pvalue"
-PERT_SIG_PVAL_THR = 0.01
 RECALL_PERC_THRS = [(0.05, 0.95), (0.1, 0.9)]
 RANDOM_SEED = 42
 RANDOM_COUNT = 3
 N_NULL_SAMPLES = 5000
 MIN_REQ_ENT_CNT = 20
+PERT_SIG_PVAL_COL = "gene_pvalue"
+PERT_SIG_PVAL_THR = 0.01
+CONTROL_PERT_LABEL = "non-targeting"
+REPLOGLE_PERT_LABEL_COL = "gene"
+REPLOGLE_BATCH_COL = "gem_group"
+JUMP_PERT_LABEL_COL = "Metadata_Symbol"
+JUMP_PLATE_COL = "Metadata_Plate"
+JUMP_BATCH_COL = "Metadata_Batch"
diff --git a/efaar_benchmarking/data_loading.py b/efaar_benchmarking/data_loading.py
index ee9d9e9..542a0f4 100644
--- a/efaar_benchmarking/data_loading.py
+++ b/efaar_benchmarking/data_loading.py
@@ -1,17 +1,95 @@
 import os
+import shutil
+import tempfile
+import zipfile
+from concurrent.futures import ThreadPoolExecutor, as_completed

 import numpy as np
+import pandas as pd
 import scanpy as sc
 import wget


-def load_replogle(gene_type, data_type, data_path="data/"):
+def load_cpg16_crispr(data_path: str = "data/") -> tuple[pd.DataFrame, pd.DataFrame]:
     """
-    Load Replogle et al. 2020 single-cell RNA-seq data for K562 cells.
+    Load and return the JUMP-CP (cpg0016) CRISPR dataset.
+    The metadata is downloaded from here:
+    https://zenodo.org/records/7661296/files/jump-cellpainting/metadata-v0.5.0.zip?download=1
+    The CellProfiler features are downloaded from here:
+    https://cellpainting-gallery.s3.amazonaws.com/index.html#cpg0016-jump/
+    We read the metadata first, filter it to CRISPR plates, and download the features for these plates only.
+
+    Parameters:
+        data_path (str): Path to the directory containing the dataset files.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]: A tuple containing two DataFrames:
+            - features: A DataFrame containing the CRISPR dataset features.
+            - metadata: A DataFrame containing the CRISPR dataset metadata.
+    """
+    metadata_source_path = "https://zenodo.org/records/7661296/files/jump-cellpainting/datasets-v0.5.0.zip?download=1"
+    plate_file_name = "plate.csv.gz"
+    well_file_name = "well.csv.gz"
+    crispr_file_name = "crispr.csv.gz"
+    plate_file_path = os.path.join(data_path, plate_file_name)
+    well_file_path = os.path.join(data_path, well_file_name)
+    crispr_file_path = os.path.join(data_path, crispr_file_name)
+    if not (os.path.exists(plate_file_path) and os.path.exists(well_file_path) and os.path.exists(crispr_file_path)):
+        path_to_zip_file = os.path.join(data_path, "tmp.zip")
+        wget.download(metadata_source_path, path_to_zip_file)
+        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
+            for name in zip_ref.namelist():
+                if name.endswith(plate_file_name) or name.endswith(well_file_name) or name.endswith(crispr_file_name):
+                    with tempfile.TemporaryDirectory() as temp_dir:
+                        zip_ref.extract(name, temp_dir)
+                        shutil.move(os.path.join(temp_dir, name), os.path.join(data_path, os.path.basename(name)))
+        os.remove(path_to_zip_file)
+
+    plates = pd.read_csv(plate_file_path)
+    crispr_plates = plates.query("Metadata_PlateType=='CRISPR'")
+    wells = pd.read_csv(well_file_path)
+    well_plate = wells.merge(crispr_plates, on=["Metadata_Source", "Metadata_Plate"])
+    crispr = pd.read_csv(crispr_file_path)
+    metadata = well_plate.merge(crispr, on="Metadata_JCP2022")
+
+    cp_feature_source_formatter = (
+        "s3://cellpainting-gallery/cpg0016-jump/"
+        "{Metadata_Source}/workspace/profiles/"
+        "{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet"
+    )
+
+    features_file_path = os.path.join(data_path, "cpg_features.parquet")
+    if not os.path.exists(features_file_path):
+        crispr_plates = metadata[
+            ["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_PlateType"]
+        ].drop_duplicates()
+        features = []
+        with ThreadPoolExecutor(max_workers=10) as executor:
+            future_to_plate = {
+                executor.submit(
+                    lambda path: pd.read_parquet(path, storage_options={"anon": True}),
+                    cp_feature_source_formatter.format(**row.to_dict()),
+                ): cp_feature_source_formatter.format(**row.to_dict())
+                for _, row in crispr_plates.iterrows()
+            }
+            for future in as_completed(future_to_plate):
+                features.append(future.result())
+        pd.concat(features).to_parquet(features_file_path)
+    features = pd.read_parquet(features_file_path).dropna(axis=1)
+    return features, metadata
+
+
+def load_replogle(gene_type: str, data_type: str, data_path: str = "data/") -> sc.AnnData:
+    """
+    Load Replogle et al. 2022 single-cell RNA-seq data for K562 cells published here:
+    https://pubmed.ncbi.nlm.nih.gov/35688146/
+    Four types of K562 data are available and are downloaded using the links at:
+    plus.figshare.com/articles/dataset/_Mapping_information-rich_genotype-phenotype_landscapes_with_genome-scale_Perturb-seq_Replogle_et_al_2022_processed_Perturb-seq_datasets/20029387

     Parameters:
         gene_type (str): Type of genes to load. Must be either 'essential' or 'genome_wide'.
         data_type (str): Type of data to load. Must be either 'raw' or 'normalized'.
+            Normalized means Z-normalized by gemgroup.
         data_path (str): Path to the directory where the data will be downloaded and saved.

     Returns:
diff --git a/efaar_benchmarking/efaar.py b/efaar_benchmarking/efaar.py
index 75bbb3d..9f40ba2 100644
--- a/efaar_benchmarking/efaar.py
+++ b/efaar_benchmarking/efaar.py
@@ -1,75 +1,139 @@
+from typing import Optional
+
 import numpy as np
 import pandas as pd
 import scanpy as sc
 from scvi.model import SCVI
+from sklearn.covariance import EllipticEnvelope
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
 from sklearn.utils import Bunch

 import efaar_benchmarking.constants as cst


-def embed_by_scvi(adata, batch_key="gem_group", n_latent=128, n_hidden=256):
+def embed_by_scvi_anndata(adata, batch_col=cst.REPLOGLE_BATCH_COL, n_latent=128, n_hidden=256) -> np.ndarray:
     """
-    Embeds the input AnnData object using scVI.
+    Embed the input AnnData object using scVI.

-    Parameters:
-    adata (anndata.AnnData): The AnnData object to be embedded.
-    batch_key (str): The batch key in the AnnData object. Default is "gem_group".
-    n_latent (int): The number of latent dimensions. Default is 128.
-    n_hidden (int): The number of hidden dimensions. Default is 256.
+    Args:
+        adata (anndata.AnnData): The AnnData object to be embedded.
+        batch_col (str): The batch key in the AnnData object. Default is "gem_group".
+        n_latent (int): The number of latent dimensions. Default is 128.
+        n_hidden (int): The number of hidden dimensions. Default is 256.

     Returns:
-    None
+        numpy.ndarray: Embedding of the input data using scVI.
     """
-    SCVI.setup_anndata(adata, batch_key=batch_key)
+    SCVI.setup_anndata(adata, batch_key=batch_col)
     vae = SCVI(adata, n_hidden=n_hidden, n_latent=n_latent)
     vae.train(use_gpu=True)
     return vae.get_latent_representation()


-def embed_by_pca(adata, n_latent=100):
+def embed_by_pca_anndata(adata, n_latent=100) -> np.ndarray:
     """
-    Embeds the input data using PCA.
+    Embed the input data using principal component analysis (PCA).
+    Note that the data is centered by the `pca` function prior to PCA transformation.

-    Parameters:
-    adata (AnnData): Annotated data matrix.
-    n_latent (int): Number of principal components to use.
+    Args:
+        adata (AnnData): Annotated data matrix.
+        n_latent (int): Number of principal components to use.

     Returns:
-    numpy.ndarray: Embedding of the input data using PCA.
+        numpy.ndarray: Embedding of the input data using PCA.
     """
     sc.pp.pca(adata, n_comps=n_latent)
     return adata.obsm["X_pca"]


-def align_by_centering(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
+def embed_align_by_pca(
+    features: np.ndarray,
+    metadata: Optional[pd.DataFrame] = None,
+    variance_or_ncomp=100,
+    plate_col: Optional[str] = None,
+) -> np.ndarray:
+    """
+    Embed the input data using principal component analysis (PCA).
+    Note that we explicitly center & scale the data before and after calling `PCA`.
+    Centering and scaling are done by plate if `plate_col` is not None, and on the whole data otherwise.
+    Note that the `PCA` transformer also does mean-centering on the whole data prior to the PCA operation.
+
+    Args:
+        features (np.ndarray): Features to transform
+        metadata (pd.DataFrame, optional): Metadata. Defaults to None.
+        variance_or_ncomp (float, optional): Variance or number of components to keep after PCA.
+            Defaults to 100 (n_components). If between 0 and 1, select the number of components such that
+            the amount of variance that needs to be explained is greater than the percentage specified.
+        plate_col (str, optional): Column name for plate metadata. Defaults to None.
+
+    Returns:
+        np.ndarray: Transformed data using PCA.
+    """
+
+    def centerscale(features, metadata, plate_col):
+        if plate_col is None:
+            features = StandardScaler().fit_transform(features)
+        else:
+            if metadata is None:
+                raise ValueError("metadata must be provided if plate_col is not None")
+            unq_plates = metadata[plate_col].unique()
+            for plate in unq_plates:
+                ind = metadata[plate_col] == plate
+                features[ind, :] = StandardScaler().fit_transform(features[ind, :])
+        return features
+
+    features = centerscale(features, metadata, plate_col)
+    features = PCA(variance_or_ncomp).fit_transform(features)
+    features = centerscale(features, metadata, plate_col)
+
+    return features
+
+
+def align_on_controls(
+    embeddings: np.ndarray,
+    metadata: pd.DataFrame,
+    scale: bool = True,
+    pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
+    control_key: str = cst.CONTROL_PERT_LABEL,
+) -> np.ndarray:
     """
-    Applies the centerscale method to align embeddings based on the centering perturbations in the metadata.
+    Center the embeddings by the control perturbation units in the metadata.

     Args:
         embeddings (numpy.ndarray): The embeddings to be aligned.
         metadata (pandas.DataFrame): The metadata containing information about the embeddings.
+        scale (bool): Whether to scale the embeddings besides centering. Defaults to True.
+        pert_col (str, optional): The column in the metadata containing perturbation information.
+            Defaults to cst.REPLOGLE_PERT_LABEL_COL.
         control_key (str, optional): The key for non-targeting controls in the metadata.
             Defaults to cst.CONTROL_PERT_LABEL.
-        pert_col (str, optional): The column in the metadata containing perturbation information.
-            Defaults to cst.PERT_LABEL_COL.

     Returns:
         numpy.ndarray: The aligned embeddings.
     """
-    ntc_idxs = np.where(metadata[pert_col].values == control_key)[0]
-    ntc_center = embeddings[ntc_idxs].mean(0)
-    return embeddings - ntc_center
-
-
-def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
+    ss = StandardScaler() if scale else StandardScaler(with_std=False)
+    ss.fit(embeddings[metadata[pert_col].values == control_key])
+    return ss.transform(embeddings)
+
+
+def aggregate(
+    embeddings: np.ndarray,
+    metadata: pd.DataFrame,
+    pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
+    control_key: str = cst.CONTROL_PERT_LABEL,
+    method: str = "mean",
+) -> Bunch[pd.DataFrame, pd.DataFrame]:
     """
-    Applies the mean aggregation to aggregate replicate embeddings for each perturbation.
+    Apply the mean or median aggregation to replicate embeddings for each perturbation.

     Args:
         embeddings (numpy.ndarray): The embeddings to be aggregated.
         metadata (pandas.DataFrame): The metadata containing information about the embeddings.
-        control_key (str, optional): The key for non-targeting controls in the metadata. Defaults to "non-targeting".
-        PERT_COL (str, optional): The column in the metadata containing perturbation information. Defaults to "gene".
+        pert_col (str, optional): The column in the metadata containing perturbation information.
+            Defaults to cst.REPLOGLE_PERT_LABEL_COL.
+        control_key (str, optional): The key for non-targeting controls in the metadata.
+            Defaults to cst.CONTROL_PERT_LABEL.
+        method (str, optional): The aggregation method to use. Must be either "mean" or "median".
+            Defaults to "mean".

     Returns:
         Bunch: A named tuple containing two pandas DataFrames:
@@ -81,5 +145,50 @@ def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL,
     final_embeddings = np.zeros((len(unique_perts), embeddings.shape[1]))
     for i, pert in enumerate(unique_perts):
         idxs = np.where(metadata[pert_col].values == pert)[0]
-        final_embeddings[i, :] = embeddings[idxs, :].mean(0)
+        if method == "mean":
+            final_embeddings[i, :] = np.mean(embeddings[idxs, :], axis=0)
+        elif method == "median":
+            final_embeddings[i, :] = np.median(embeddings[idxs, :], axis=0)
+        else:
+            raise ValueError(f"Invalid aggregation method: {method}")
     return Bunch(features=pd.DataFrame(final_embeddings), metadata=pd.DataFrame.from_dict({pert_col: unique_perts}))
+
+
+def filter_cpg16_crispr(
+    features: pd.DataFrame,
+    metadata: pd.DataFrame,
+    filter_by_intensity: bool = True,
+    filter_by_cell_count: bool = True,
+    drop_image_cols: bool = True,
+) -> tuple[pd.DataFrame, pd.DataFrame]:
+    """
+    Filter the given features and metadata dataframes based on various criteria.
+
+    Args:
+        features (pd.DataFrame): A dataframe containing the features to filter.
+        metadata (pd.DataFrame): A dataframe containing the metadata to filter.
+        filter_by_intensity (bool, optional): Whether to filter by intensity. Defaults to True.
+        filter_by_cell_count (bool, optional): Whether to filter by cell count. Defaults to True.
+        drop_image_cols (bool, optional): Whether to drop image columns. Defaults to True.
+
+    Returns:
+        tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the filtered features and metadata dataframes.
+ """ + if filter_by_intensity: + features = features.loc[ + EllipticEnvelope(contamination=0.01, random_state=42).fit_predict( + features[[col for col in features.columns if "ImageQuality_MeanIntensity" in col]] + ) + == 1 + ] + if filter_by_cell_count: + mask = np.full(len(features), True) + for colname in ["Cytoplasm_Number_Object_Number", "Nuclei_Number_Object_Number"]: + mask = mask & (features[colname] >= 50) & (features[colname] <= 350) + features = features.loc[mask] + if drop_image_cols: + features = features.drop(columns=[col for col in features.columns if col.startswith("Image_")]) + + metadata_cols = metadata.columns + merged_data = metadata.merge(features, on=["Metadata_Source", "Metadata_Plate", "Metadata_Well"]) + return merged_data.drop(columns=metadata_cols), merged_data[metadata_cols] diff --git a/efaar_benchmarking/plotting.py b/efaar_benchmarking/plotting.py index ee04776..ea52167 100644 --- a/efaar_benchmarking/plotting.py +++ b/efaar_benchmarking/plotting.py @@ -49,7 +49,6 @@ def plot_recall(df): curr_ax.set_ylabel("Recall Value") curr_ax.set_xticks(range(len(x_values_orig))) curr_ax.set_xticklabels(x_values_orig, rotation=45) - curr_ax.set_ylim(0, 0.8) plt.tight_layout() plt.show() diff --git a/notebooks/data/crispr.csv.gz b/notebooks/data/crispr.csv.gz deleted file mode 100644 index ee54baa..0000000 Binary files a/notebooks/data/crispr.csv.gz and /dev/null differ diff --git a/notebooks/data/plate.csv.gz b/notebooks/data/plate.csv.gz deleted file mode 100644 index 0239e67..0000000 Binary files a/notebooks/data/plate.csv.gz and /dev/null differ diff --git a/notebooks/data/well.csv.gz b/notebooks/data/well.csv.gz deleted file mode 100644 index c46e091..0000000 Binary files a/notebooks/data/well.csv.gz and /dev/null differ diff --git a/notebooks/jump_map_building.ipynb b/notebooks/jump_map_building.ipynb index 66a179b..a11b450 100644 --- a/notebooks/jump_map_building.ipynb +++ b/notebooks/jump_map_building.ipynb @@ -19,1098 +19,6 @@ "limitations under the License." ] }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Building and Benchmarking a Biological Map with JUMP-Cell Painting Consortium Cell Profiler Features\n", - "\n", - "In this notebook, we create and benchmark maps based on the JUMP-CP imaging dataset featurized with Cell Profiler. The focus will be on creating simple EFAAR pipelines and benchmarking using the framework to see how changes in the pipeline effect the metrics.\n", - "\n", - "The primary components of an EFAAR Pipeline are \n", - "\n", - "- **Embed:** A high dimensional featurization of a biological perturbation\n", - "- **Filter:** A preprocessing step where embeddings are discarded based on statistical, machine learning thresholds or external metadata or features \n", - "- **Align:** Process to combine the embeddings into the same high dimensional space while increasing the the signal and decreasing noise in the data \n", - "- **Aggregate:** Combining any replicates of perturbations into a single representation for comparison \n", - "- **Relate:** Any metric that allows pairwise comparison between perturbations\n", - "\n", - "Additionally, we will examine the presence of chromosomal proximity bias in the data and alter the pipeline to decrease its effect on relationships. 
" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Imports" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Uncomment the following lines to run this notebook in google Colab\n", - "\n", - "# !git clone https://github.com/recursionpharma/EFAAR_benchmarking.git\n", - "# import sys\n", - "# sys.path.insert(0, \"/content/EFAAR_benchmarking\")\n", - "# import os\n", - "# os.chdir(\"EFAAR_benchmarking/notebooks\")\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_0.parquet data/cpg_0016_crispr_0.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_1.parquet data/cpg_0016_crispr_1.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_2.parquet data/cpg_0016_crispr_2.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_3.parquet data/cpg_0016_crispr_3.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_4.parquet data/cpg_0016_crispr_4.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_5.parquet data/cpg_0016_crispr_5.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_6.parquet data/cpg_0016_crispr_6.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_7.parquet data/cpg_0016_crispr_7.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_8.parquet data/cpg_0016_crispr_8.parquet\n", - "# !gsutil cp gs://rxrx-cytodata2023-public/cpg_0016_crispr_9.parquet data/cpg_0016_crispr_9.parquet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "import concurrent.futures\n", - "from concurrent.futures import ThreadPoolExecutor\n", - "from typing import List, Tuple, Dict, Any, Optional\n", - "\n", - "import pandas as pd\n", - "pd.options.mode.chained_assignment = None\n", - "import numpy as np\n", - "import scipy as sp\n", - "import sklearn\n", - "from sklearn.preprocessing import StandardScaler\n", - "from sklearn.decomposition import PCA\n", - "from sklearn.covariance import EllipticEnvelope\n", - "import matplotlib.pyplot as plt\n", - "from scipy.cluster import hierarchy\n", - "from scipy.spatial.distance import pdist\n", - "from skimage.measure import block_reduce\n", - "\n", - "import seaborn as sns\n", - "import plotly.graph_objects as go\n", - "from sklearn.utils import Bunch\n", - "import scipy.linalg as linalg\n", - "\n", - "from efaar_benchmarking.benchmarking import benchmark as bm\n", - "from efaar_benchmarking.utils import get_benchmark_metrics\n", - "\n", - "import matplotlib as mpl\n", - "import matplotlib.pyplot as plt\n", - "\n", - "sns.set_theme()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Loading\n", - "Load the cpg0016 CRISPR data. The metadata is not available via an API so it has been copied from the Cell Painting gallery for ease of use. But the features can be easily pulled from the S3 bucket provided by the JUMP consortium. We pickle the raw data after loading to avoid repeatedly pulling from the S3 bucket." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "CP_FEATURE_FORMATTER = (\n", - " \"s3://cellpainting-gallery/cpg0016-jump/\"\n", - " \"{Metadata_Source}/workspace/profiles/\"\n", - " \"{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet\"\n", - ")\n", - "\n", - "CPG_METADATA_COLS = [\n", - " \"Metadata_Source\",\n", - " \"Metadata_Plate\",\n", - " \"Metadata_Well\",\n", - " \"Metadata_JCP2022\",\n", - " \"Metadata_Batch\",\n", - " \"Metadata_PlateType\",\n", - " \"Metadata_NCBI_Gene_ID\",\n", - " \"Metadata_Symbol\",\n", - "]\n", - "\n", - "\n", - "def load_cpg_crispr_well_metadata():\n", - " \"\"\"Load well metadata for CRISPR plates from Cell Painting Gallery.\"\"\"\n", - " plates = pd.read_csv(\"data/plate.csv.gz\")\n", - " crispr_plates = plates.query(\"Metadata_PlateType=='CRISPR'\")\n", - " wells = pd.read_csv(\"data/well.csv.gz\")\n", - " crispr = pd.read_csv(\"data/crispr.csv.gz\")\n", - "\n", - " well_plate = wells.merge(crispr_plates, on=[\"Metadata_Source\", \"Metadata_Plate\"])\n", - " crispr_well_metadata = well_plate.merge(crispr, on=\"Metadata_JCP2022\")\n", - " return crispr_well_metadata\n", - "\n", - "\n", - "def _load_plate_features(path: str):\n", - " try:\n", - " df = pd.read_parquet(\"data/\" + path.split(\"/\")[-1])\n", - " except FileNotFoundError:\n", - " df = pd.read_parquet(path, storage_options={\"anon\": True})\n", - " df.to_parquet(\"data/\" + path.split(\"/\")[-1])\n", - " return df\n", - "\n", - "\n", - "def load_feature_data(metadata_df: pd.DataFrame, max_workers=4) -> pd.DataFrame:\n", - " \"\"\"Load feature data from Cell Painting Gallery from metadata dataframe.\n", - "\n", - " Parameters\n", - " ----------\n", - " metadata_df : pd.DataFrame\n", - " Well metadata dataframe\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Well features\n", - " \"\"\"\n", - " cripsr_plates = metadata_df[\n", - " [\"Metadata_Source\", \"Metadata_Batch\", \"Metadata_Plate\", \"Metadata_PlateType\"]\n", - " ].drop_duplicates()\n", - " data = []\n", - " with ThreadPoolExecutor(max_workers=max_workers) as executer:\n", - " future_to_plate = {\n", - " executer.submit(\n", - " _load_plate_features, CP_FEATURE_FORMATTER.format(**row.to_dict())\n", - " ): CP_FEATURE_FORMATTER.format(**row.to_dict())\n", - " for _, row in cripsr_plates.iterrows()\n", - " }\n", - " for future in concurrent.futures.as_completed(future_to_plate):\n", - " data.append(future.result())\n", - " return pd.concat(data)\n", - "\n", - "\n", - "def build_combined_data(metadata: pd.DataFrame, features: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"Join well metadata and well features\n", - "\n", - " Parameters\n", - " ----------\n", - " metadata : pd.DataFrame\n", - " Well metadata\n", - " features : pd.DataFrame\n", - " Well features\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Combined dataframe\n", - " \"\"\"\n", - " return metadata.merge(features, on=[\"Metadata_Source\", \"Metadata_Plate\", \"Metadata_Well\"])\n", - "\n", - "\n", - "def read_parquets_from_gcs():\n", - " data = []\n", - " for i in range(10):\n", - " try:\n", - " data.append(pd.read_parquet(f\"data/cpg_0016_crispr_{i}.parquet\"))\n", - " except FileNotFoundError:\n", - " subset = pd.read_parquet(f\"gs://rxrx-cytodata2023-public/cpg_0016_crispr_{i}.parquet\")\n", - " subset.to_parquet(f\"data/cpg_0016_crispr_{i}.parquet\")\n", - " data.append(subset)\n", - " print(f\"Read in parquet {i+1} of 10\")\n", - " return pd.concat(data)" 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "raw_well_data = read_parquets_from_gcs()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Creating and benchmarking an initial map\n", - "Let's create and benchmark an initial map with a simple EFAAR pipeline to serve as a baseline\n", - "\n", - "- **Embed:** Cell Profiler Embeddings from JUMP cpg0016\n", - "- **Filter:** Filtering out rows with Nan values\n", - "- **Align:** Centering and scaling the individual features followed by PCA keeping a fraction of the variance \n", - "- **Aggregate:** The mean over the well replicates of each gene KO\n", - "- **Relate:** Cosine similarity between the perturbations\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def preprocess_data(\n", - " data: pd.DataFrame, metadata_cols: list = CPG_METADATA_COLS, drop_image_cols: bool = True\n", - ") -> pd.DataFrame:\n", - " \"\"\"Preprocess data by dropping feature columns with nan values,\n", - " and optionaly dropping the Image columns\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " Data to preprocess\n", - " metadata_cols : List[str], optional\n", - " Metadata columns, by default CPG_METADATA_COLS\n", - " drop_image_cols : bool, optional\n", - " Whether to drop Image columns, by default True\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Processed dataframe\n", - " \"\"\"\n", - " metadata = data[metadata_cols]\n", - " features = data[[col for col in data.columns if col not in metadata_cols]]\n", - " features = features.dropna(axis=1)\n", - " if drop_image_cols:\n", - " image_cols = [col for col in features.columns if col.startswith(\"Image_\")]\n", - " features = features.drop(columns=image_cols)\n", - " return metadata.join(features)\n", - "\n", - "\n", - "def pca_transform_data(data, \n", - " metadata_cols: list = CPG_METADATA_COLS,\n", - " variance=0.98,\n", - " pca_fit_fraction = 1.0,) -> pd.DataFrame:\n", - " \"\"\"Align data by centerscaling data and then transforming by PCA\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " Data to preprocess\n", - " metadata_cols : List[str], optional\n", - " Metadata columns, by default CPG_METADATA_COLS\n", - " pca_fit_fraction : bool, optional\n", - " Fraction of wells to sample for PCA, by default 1.0\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Processed dataframe\n", - " \"\"\"\n", - "\n", - " metadata = data[metadata_cols]\n", - " features = data[[col for col in data.columns if col not in metadata_cols]]\n", - "\n", - " scaler = StandardScaler()\n", - " features.loc[:,:] = scaler.fit_transform(\n", - " features\n", - " )\n", - " \n", - " pca = PCA(variance)\n", - " pca.fit(features.sample(frac=pca_fit_fraction, random_state=42))\n", - " features = pd.DataFrame(pca.transform(features), index=metadata.index)\n", - "\n", - " return metadata.join(features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "transformed_well_data = pca_transform_data(preprocess_data(raw_well_data), pca_fit_fraction=0.05)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "features = transformed_well_data[[col for col in transformed_well_data.columns if col not in 
CPG_METADATA_COLS]]\n", - "aggregated_data = features.groupby(transformed_well_data[\"Metadata_Symbol\"], as_index=True).mean()\n", - "aggregated_data.index.name = \"gene\"" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Relationships\n", - "Visualizing relationships of some well known gene sets that we expect to cluster together" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def make_pairwise_cos(\n", - " df: pd.DataFrame,\n", - " convert: bool = True,\n", - " dtype: type = np.float16,\n", - ") -> pd.DataFrame:\n", - " \"\"\"\n", - " Converts a dataframe of samples X features into a square dataframe of samples X samples\n", - " of cosine similarities between rows.\n", - "\n", - " Inputs\n", - " ------\n", - " - df = pd.DataFrame\n", - " - convert = bool. Whether to convert the results to a smaller data type\n", - " - dtype = type. Data type to convert to\n", - " \"\"\"\n", - " mat = (1 - sp.spatial.distance.cdist(df.values, df.values, metric=\"cosine\")).clip(-1, 1)\n", - " if convert:\n", - " mat = mat.astype(dtype)\n", - " return pd.DataFrame(mat, index=df.index, columns=df.index)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def cluster_perturbations(data: pd.DataFrame) -> pd.DataFrame:\n", - " \"\"\"\n", - " Takes a square data frame, and uses euclidian distance to reorder the rows and columns\n", - " into clusters. The index and column values for the data frame must be identical to eachother.\n", - " \"\"\"\n", - " if list(data.index) != list(data.columns):\n", - " raise ValueError(\"index and columns for pldata must be equal\")\n", - "\n", - " order = hierarchy.dendrogram(\n", - " hierarchy.linkage(\n", - " pdist(data)\n", - " ),\n", - " no_plot=True,\n", - " )['ivl']\n", - " data = data.iloc[order]\n", - " data = data[data.index]\n", - " return data\n", - "\n", - "\n", - "def plot_cosine_rectangle_data(pldata: pd.DataFrame, title: str='plot', h=600, w=600):\n", - " \"\"\"\n", - " pldata is expected to be a rectangular dataframe, where all the values are to be plotted\n", - " \"\"\"\n", - " pldata = cluster_perturbations(data=pldata)\n", - "\n", - " fig = go.Figure(data=go.Heatmap(\n", - " z=pldata.values,\n", - " x=pldata.columns,\n", - " y=pldata.index,\n", - " colorscale='RdBu_r',\n", - " zmin=-1,\n", - " zmax=1,\n", - " ))\n", - " fig.update_layout(height=h, width=w, title_text=f\"{title}\")\n", - " fig.show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gene_sets = {\n", - " 'PROTEASOME': 'PSMB2,PSMB7,PSMB4,PSMA7,PSMA4,PSMB6,PSMA5,PSMB3,PSMA6,PSMA1,PSMB1,PSMA3'.split(','),\n", - " 'EXOSOME': 'DIS3,EXOSC4,EXOSC8,EXOSC7,EXOSC9,EXOSC5'.split(','),\n", - " 'VATPASE': 'ATP6V1B2,ATP6V1H,ATP6V1D,ATP6V1A,ATP6V1F,ATP6V1E1'.split(','),\n", - " 'AUTOPHAGY': 'ATG12,ATG5'.split(','),\n", - " 'REPLICATION_FACTOR_COMPLEX': 'RFC3,RFC4,RFC2,RFC5'.split(','), \n", - " 'DYNEIN': 'DYNC1I2,DYNC1H1,DYNC1LI1,DYNC1LI2'.split(','),\n", - " 'RNA': 'POLA2,POLA1,POLR2L,POLR2B,POLR2I,POLR2G,POLR2C'.split(','), \n", - " 'EGFR': 'PRKCE,BRAF,HRAS,SHC1,RAF1,EGFR,MAPK1'.split(','),\n", - " 'TGFB/ACTIVIN': 'ACVR1B,TGFBR2,TGFBR1'.split(','),\n", - " }\n", - "\n", - "cluster_genes = [g for gene_set in gene_sets.values() for g in gene_set]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - 
"plot_cosine_rectangle_data(make_pairwise_cos(aggregated_data.loc[cluster_genes]), title = \"JUMP Gene Sets - Initial EFAAR Pipeline\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Benchmarking\n", - "Plots provide a nice visualization but we can quantify how well the map is recapitulating biology with these metrics" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Multivariate Benchmarks\n", - "The multivariate benchmarks compare annotated sets of known relationships against a distribution of randomly selected embeddings and computes recall based on the percentiles of the random distribution. In this notebook we will use the 5th and 95th percentil. Hence, embeddings that don't have biological signal are expected to have around " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = aggregated_data.reset_index()[['gene']]\n", - "features = aggregated_data.reset_index(drop=True)\n", - "results_dict = bm(Bunch(metadata=metadata, features=features), pert_label_col='gene', run_count=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "results_dict['Seeds_1608637542_3421126067']['HuMAP']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "benchmark_results_df = get_benchmark_metrics(results_dict)\n", - "benchmark_results_df['map_version'] = 'JUMP_initial_EFAAR'\n", - "sns.barplot(data=benchmark_results_df, x='source' ,y='recall')\n", - "plt.axhline(y=0.1, color='red', linestyle='--') \n", - "plt.title('Recall at 5th and 95th Percentiles')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Univariate Benchmarks\n", - "For the univariate benchmarks we can compare a metric that measures the consistency of the well level replicates of a perturbation against an empirical null where the metrics is computed on random sets of replicates. This computation can be very sensitive to the null distribution. Incorporating the experiment design in the null is import but can become computationally expensive very quickly. The following is an example of the idea but a production implementation needs more refinement and comptue. We are not matching plates or well address for example which can make replicates look more similar to each other. 
" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def avg_angle(df):\n", - " cosine_sim = sklearn.metrics.pairwise.cosine_similarity(df.values)\n", - " if cosine_sim.shape[0]==1:\n", - " print(\"Whatt???\")\n", - " return np.arccos(cosine_sim[np.tril_indices(cosine_sim.shape[0], k=-1)]).mean()\n", - "\n", - "def generate_null_distribution(feature_data, n_samples = 10000, cardinality = 5):\n", - " \"\"\"Generate null distribution for univariate avg cosine sim metric\"\"\"\n", - " null_metrics = []\n", - " for i in range(n_samples):\n", - " null_metrics.append(avg_angle(feature_data.sample(n=cardinality)))\n", - " return null_metrics\n", - "\n", - "\n", - "def compute_cosine_sim_metric(data: pd.DataFrame, metadata_cols=CPG_METADATA_COLS, cardinality=5):\n", - " \"\"\"Compute univariate cosine similarity metrics on well level data\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " Data to process\n", - " metadata_cols : List[str], optional\n", - " Metadata columns, by default CPG_METADATA_COLS\n", - " cardinality : bool, optional\n", - " Cardinatity of the wells to match when generating the null distribution\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Processed dataframe\n", - " \"\"\"\n", - " data = data.query('Metadata_Symbol != \"no-guide\" and Metadata_Symbol != \"non-targeting\" and Metadata_Symbol != \"PLK1\"')\n", - " gene_count = data.groupby('Metadata_Symbol').Metadata_Well.count()\n", - " genes = list(gene_count[gene_count==cardinality].index)\n", - " data = data.query('Metadata_Symbol.isin(@genes)')\n", - " \n", - " perts = data[['Metadata_Symbol']]\n", - " features = data[[col for col in data.columns if col not in metadata_cols]]\n", - "\n", - " df = perts.join(features)\n", - " query_metrics = df.drop(columns = \"Metadata_Symbol\").groupby(df['Metadata_Symbol']).apply(avg_angle)\n", - " query_metrics.name = \"avg_cossim\"\n", - " query_metrics = query_metrics.reset_index()\n", - "\n", - " null = []\n", - " for i in range(5):\n", - " df['Metadata_Symbol'] = df['Metadata_Symbol'].sample(frac=1).values\n", - " null.extend(df.drop(columns = \"Metadata_Symbol\").groupby(df['Metadata_Symbol']).apply(avg_angle).values)\n", - " sorted_null = np.sort(null)\n", - " query_metrics['avg_cossim_pval'] = np.searchsorted(sorted_null, query_metrics.avg_cossim)/len(sorted_null)\n", - " return query_metrics\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "initial_univariate_metrics = compute_cosine_sim_metric(transformed_well_data)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "num01genes = len(initial_univariate_metrics.query(\"avg_cossim_pval<=0.01\"))\n", - "num05genes = len(initial_univariate_metrics.query(\"avg_cossim_pval<=0.05\"))\n", - "print(f\"{num01genes} significant genes at a 0.01 significance threshold\")\n", - "print(f\"{num05genes} significant genes at a 0.05 significance threshold\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Improving the Map\n", - "Let's iterate on the EFAAR pipeline to try to improve our both the univariate and multivariate benchmarks" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Filtering\n", - "With CellProfiler features, our features are identifiable so we can filter directly based on them. 
With other types of embeddings, this could be based on statistical or machine learning models predicting whether a well should be removed."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "cols = raw_well_data.columns"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "[col for col in cols if \"ImageQuality_MeanIntensity\" in col]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "[col for col in cols if \"Object_Number\" in col]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "raw_well_data['Image_ImageQuality_MeanIntensity_OrigDNA'].plot(kind='kde')\n",
-    "plt.title('Mean Intensity')\n",
-    "plt.show()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "raw_well_data['Cytoplasm_Number_Object_Number'].plot(kind='kde')\n",
-    "plt.title('Object Number')\n",
-    "plt.show()"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def filter_by_image_intensity(df, contamination=0.01):\n",
-    "    \"\"\"Filter dataframe by image intensity threshold.\"\"\"\n",
-    "\n",
-    "    intensity_cols = [col for col in df.columns if \"ImageQuality_MeanIntensity\" in col]\n",
-    "    image_data = df[intensity_cols]\n",
-    "    envelope = EllipticEnvelope(contamination=contamination, random_state=42)\n",
-    "    labels = envelope.fit_predict(image_data)\n",
-    "    return df.loc[labels == 1]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def filter_by_cell_count(df, lower_threshold, upper_threshold):\n",
-    "    \"\"\"Filter dataframe by cell count\"\"\"\n",
-    "    mask = (df['Cytoplasm_Number_Object_Number'] >= lower_threshold) & (df['Nuclei_Number_Object_Number'] >= lower_threshold)\n",
-    "    mask = mask & (df['Cytoplasm_Number_Object_Number'] <= upper_threshold) & (df['Nuclei_Number_Object_Number'] <= upper_threshold)\n",
-    "    return df.loc[mask]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def filter_and_preprocess_data(\n",
-    "    data: pd.DataFrame, metadata_cols: list = CPG_METADATA_COLS, drop_image_cols: bool = True\n",
-    ") -> pd.DataFrame:\n",
-    "    \"\"\"Preprocess data by dropping feature columns with nan values, dropping the\n",
-    "    Number_Object columns, and optionally dropping the Image columns\n",
-    "\n",
-    "    Parameters\n",
-    "    ----------\n",
-    "    data : pd.DataFrame\n",
-    "        Data to preprocess\n",
-    "    metadata_cols : List[str], optional\n",
-    "        Metadata columns, by default CPG_METADATA_COLS\n",
-    "    drop_image_cols : bool, optional\n",
-    "        Whether to drop Image columns, by default True\n",
-    "\n",
-    "    Returns\n",
-    "    -------\n",
-    "    pd.DataFrame\n",
-    "        Processed dataframe\n",
-    "    \"\"\"\n",
-    "\n",
-    "    metadata = data[metadata_cols]\n",
-    "    features = data[[col for col in data.columns if col not in metadata_cols]]\n",
-    "    features = features.dropna(axis=1)\n",
-    "    data = metadata.join(features)\n",
-    "    data = filter_by_image_intensity(data)\n",
-    "    data = filter_by_cell_count(data, lower_threshold=50, upper_threshold=350)\n",
-    "\n",
-    "    data = data.drop(columns=[col for col in data.columns if col.endswith(\"Object_Number\")])\n",
-    "\n",
-    "    if drop_image_cols:\n",
-    "        image_cols = [col for col in data.columns if col.startswith(\"Image_\")]\n",
-    "        data = data.drop(columns=image_cols)\n",
-    "    return data"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# Run preprocessing on the raw data\n",
-    "print(len(raw_well_data), \"wells before filtering\")\n",
-    "processed_data = filter_and_preprocess_data(raw_well_data)\n",
-    "print(len(processed_data), \"wells after filtering\")"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Alignment"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "def _compute_modified_cov(data: pd.DataFrame):\n",
-    "    return np.cov(data, rowvar=False, ddof=1) + 0.5 * np.eye(\n",
-    "        np.shape(data)[1], dtype=np.float32\n",
-    "    )\n",
-    "\n",
-    "\n",
-    "def tvn_transform(\n",
-    "    data: pd.DataFrame,\n",
-    "    metadata_cols: list = CPG_METADATA_COLS,\n",
-    "    variance=0.98,\n",
-    "    pca_fit_fraction = 1.0\n",
-    ") -> pd.DataFrame:\n",
-    "    \"\"\"Transform data by scaling and applying PCA, followed by CORAL by batch. Data is scaled by plate\n",
-    "    before and after PCA is applied.\n",
-    "\n",
-    "    Parameters\n",
-    "    ----------\n",
-    "    data : pd.DataFrame\n",
-    "        Data to transform\n",
-    "    metadata_cols : list, optional\n",
-    "        Metadata columns, by default CPG_METADATA_COLS\n",
-    "    variance : float, optional\n",
-    "        Variance to keep after PCA, by default 0.98\n",
-    "    pca_fit_fraction : float, optional\n",
-    "        Fraction of wells to sample for PCA, by default 1.0\n",
-    "\n",
-    "    Returns\n",
-    "    -------\n",
-    "    pd.DataFrame\n",
-    "        Transformed data\n",
-    "    \"\"\"\n",
-    "    metadata = data[metadata_cols]\n",
-    "    features = data[[col for col in data.columns if col not in metadata_cols]]\n",
-    "\n",
-    "    for plate in metadata.Metadata_Plate.unique():\n",
-    "        scaler = StandardScaler()\n",
-    "        features.loc[metadata.Metadata_Plate == plate, :] = scaler.fit_transform(\n",
-    "            features.loc[metadata.Metadata_Plate == plate, :]\n",
-    "        )\n",
-    "\n",
-    "    pca = PCA(variance)\n",
-    "    pca.fit(features.sample(frac=pca_fit_fraction, random_state=42))\n",
-    "    features = pd.DataFrame(pca.transform(features), index= metadata.index)\n",
-    "\n",
-    "    for plate in metadata.Metadata_Plate.unique():\n",
-    "        scaler = StandardScaler()\n",
-    "        features.loc[metadata.Metadata_Plate == plate, :] = scaler.fit_transform(\n",
-    "            features.loc[metadata.Metadata_Plate == plate, :]\n",
-    "        )\n",
-    "\n",
-    "    for batch in metadata.Metadata_Batch.unique():\n",
-    "        source_cov = _compute_modified_cov(features.loc[metadata.query(f\"Metadata_Batch == '{batch}'\").index])\n",
-    "        source_cov_half_inv = linalg.fractional_matrix_power(source_cov, -0.5)\n",
-    "        \n",
-    "        features.loc[metadata.Metadata_Batch == batch, :] = np.matmul(\n",
-    "            features.loc[metadata.Metadata_Batch == batch, :],\n",
-    "            source_cov_half_inv,\n",
-    "        )\n",
-    "    return metadata.join(features)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {
-    "scrolled": true
-   },
-   "outputs": [],
-   "source": [
-    "# Run alignment on the processed data\n",
-    "# This takes a while, so we again only use a fraction of the data to fit the PCA\n",
-    "tvn_transformed_well_data = tvn_transform(processed_data, pca_fit_fraction= .1)"
-   ]
-  },
-  {
-   "attachments": {},
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Aggregation\n",
-    "Combine representations by taking the median across the replicates"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata":
{}, - "outputs": [], - "source": [ - "features = tvn_transformed_well_data[[col for col in tvn_transformed_well_data.columns if col not in CPG_METADATA_COLS]]\n", - "aggregated_tvn_data = features.groupby(tvn_transformed_well_data[\"Metadata_Symbol\"], as_index=True).median()\n", - "aggregated_tvn_data.index.name = \"gene\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# inspect the data\n", - "aggregated_tvn_data.head()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Relationships" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_cosine_rectangle_data(make_pairwise_cos(aggregated_tvn_data.loc[list(set(cluster_genes).intersection(aggregated_tvn_data.index))]), \n", - " title = \"JUMP Gene Sets - Improved EFAAR Pipeline\")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Multivariate Benchmarks" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data_dict={}" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = aggregated_tvn_data.reset_index()[['gene']]\n", - "features = aggregated_tvn_data.reset_index(drop=True)\n", - "data_dict['improved_EFAAR'] = Bunch(metadata=metadata, features=features)\n", - "results = bm(data_dict['improved_EFAAR'], pert_label_col='gene', run_count=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "new_results_df = get_benchmark_metrics(results)\n", - "new_results_df['map_version'] = 'JUMP_improved_EFAAR'\n", - "benchmark_results_df = pd.concat([benchmark_results_df, new_results_df])\n", - "sns.barplot(data=benchmark_results_df, x='source' ,y='recall', hue='map_version')\n", - "plt.axhline(y=0.1, color='red', linestyle='--') \n", - "plt.title('Recall at 5th and 95th Percentiles')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Proximity Bias in CRISPR Knockout Data" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Visualizing Proximity Bias in JUMP Data\n", - "We create a genome wide heatmap ordered by gene chromosome position" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def _get_data_path(filename):\n", - " return \"data/\"+filename\n", - "\n", - "\n", - "def _chr_to_int(chr):\n", - " if chr == \"x\":\n", - " return 24\n", - " elif chr == \"y\":\n", - " return 25\n", - " return int(chr)\n", - "\n", - "\n", - "def _chrom_int(chrom):\n", - " return chrom.copy().str.split(\"chr\").str[1].str.lower().map(_chr_to_int)\n", - "\n", - "VALID_CHROMS = [f\"chr{i}\" for i in range(1, 23)] + [\"chrX\", \"chrY\"]\n", - "\n", - "def _load_centromeres() -> pd.DataFrame:\n", - " centros = pd.read_csv(_get_data_path(\"centromeres_hg38.tsv\"), sep=\"\\t\", usecols=[\"chrom\", \"chromStart\", \"chromEnd\"])\n", - " centros[[\"chromStart\", \"chromEnd\"]] = centros[[\"chromStart\", \"chromEnd\"]].astype(int)\n", - " centros = centros.groupby(\"chrom\", as_index=False).agg(\n", - " centromere_start=(\"chromStart\", \"min\"),\n", - " centromere_end=(\"chromEnd\", \"max\"),\n", - " )\n", - " return centros\n", - "\n", - "\n", - "def _load_chromosomes(centromeres: 
Optional[pd.DataFrame] = None) -> pd.DataFrame:\n", - " chroms = pd.read_csv(_get_data_path(\"hg38_scaffolds.tsv\"), sep=\"\\t\", usecols=[\"chrom\", \"chromStart\", \"chromEnd\"])\n", - " chroms[[\"chromStart\", \"chromEnd\"]] = chroms[[\"chromStart\", \"chromEnd\"]].astype(int)\n", - " chroms = chroms.loc[chroms.chrom.isin(VALID_CHROMS)].rename(columns={\"chromStart\": \"start\", \"chromEnd\": \"end\"})\n", - " chroms[\"chrom_int\"] = _chrom_int(chroms.chrom)\n", - "\n", - " # Merge in centromere data if available\n", - " if isinstance(centromeres, pd.DataFrame) and not centromeres.empty:\n", - " chroms = chroms.merge(centromeres, on=\"chrom\", how=\"left\")\n", - "\n", - " chroms = chroms.set_index(\"chrom\").sort_values(\"chrom_int\", ascending=True)\n", - " return chroms\n", - "\n", - "\n", - "def _load_bands() -> pd.DataFrame:\n", - " bands = pd.read_csv(\n", - " _get_data_path(\"hg38_cytoband.tsv.gz\"), sep=\"\\t\", usecols=[\"name\", \"#chrom\", \"chromStart\", \"chromEnd\"]\n", - " )\n", - " bands = bands.rename(columns={\"#chrom\": \"chrom\"})\n", - " bands = bands.groupby([\"chrom\", \"name\"], as_index=False).agg(\n", - " band_start=(\"chromStart\", \"min\"),\n", - " band_end=(\"chromEnd\", \"max\"),\n", - " band_chrom_arm=(\"name\", lambda x: x.str[:1].min()),\n", - " )\n", - " bands[\"chrom_int\"] = _chrom_int(bands.chrom)\n", - " return bands\n", - "\n", - "\n", - "def _load_genes(chromosomes: Optional[pd.DataFrame] = None) -> pd.DataFrame:\n", - " genes = pd.read_csv(\n", - " _get_data_path(\"ncbirefseq_hg38.tsv.gz\"), sep=\"\\t\", usecols=[\"name2\", \"chrom\", \"txStart\", \"txEnd\"]\n", - " ).rename(columns={\"name2\": \"gene\"})\n", - "\n", - " genes = genes.loc[genes.chrom.isin(VALID_CHROMS)]\n", - " genes[\"chrom_int\"] = _chrom_int(genes.chrom)\n", - " genes = genes.groupby(\"gene\", as_index=False).agg(\n", - " start=(\"txStart\", \"min\"),\n", - " end=(\"txEnd\", \"max\"),\n", - " chrom_int=(\"chrom_int\", \"min\"),\n", - " chrom_count=(\"chrom\", \"nunique\"),\n", - " chrom=(\"chrom\", \"first\"),\n", - " )\n", - "\n", - " # Filter out (psuedo-)genes of unknown function and genes that show up on multiple chromosomes\n", - " genes = genes.loc[~genes.gene.str.contains(\"^LOC*\", regex=True)]\n", - " genes = genes.loc[genes.chrom_count == 1].drop(columns=\"chrom_count\").set_index(\"gene\")\n", - " genes = genes.sort_values([\"chrom_int\", \"start\", \"end\"], ascending=True)\n", - "\n", - " # Use the middle of the centromere as the way to determine if a gene is on the 0/1 chromosome\n", - " if isinstance(chromosomes, pd.DataFrame) and not chromosomes.empty:\n", - " chroms_centromere_mid = chromosomes.copy().set_index(\"chrom_int\")\n", - " chrom_centromere_mid = (\n", - " (chroms_centromere_mid.centromere_start + chroms_centromere_mid.centromere_end) / 2\n", - " ).to_dict()\n", - " genes[\"chrom_arm_int\"] = genes.apply(lambda x: x.end > chrom_centromere_mid[x.chrom_int], axis=1).astype(int)\n", - "\n", - " # NOTE: Assumes that p is the first chromosome\n", - " genes[\"chrom_arm\"] = genes[\"chrom_arm_int\"].apply(lambda x: \"p\" if x == 0 else \"q\")\n", - " genes[\"chrom_arm_name\"] = genes[\"chrom\"] + genes[\"chrom_arm\"]\n", - " return genes\n", - "\n", - "\n", - "def get_chromosome_info_as_dfs() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n", - " \"\"\"\n", - " Get structured information about the chromosomes that genes lie on as three dataframes:\n", - " - Genes, including start and end and which chromosome arm they are on\n", - " - Chromosomes, 
including the centromere start and end genomic coordinates\n", - " - Cytogenic bands, including the name and start and end genomic\n", - "\n", - " Returns\n", - " -------\n", - " Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]\n", - " Genes, chromosomes, cytogenic bands\n", - " \"\"\"\n", - "\n", - " # TODO: refactor into more composable functions\n", - " bands = _load_bands()\n", - " chroms = _load_chromosomes(centromeres=_load_centromeres())\n", - " genes = _load_genes(chromosomes=chroms)\n", - "\n", - " return genes, chroms, bands\n", - "\n", - "def get_chromosome_info_as_dicts(\n", - " legacy_bands: bool = False,\n", - ") -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:\n", - " \"\"\"\n", - " Convert the output of `get_chromosome_info_as_dfs` to dictionary form for compatibility\n", - " with legacy notebooks.\n", - "\n", - " Returns\n", - " -------\n", - " Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]\n", - " Dict corresponding to genes, chromosomes, and cytogenic bands\n", - " \"\"\"\n", - " gene_df, chrom_df, band_df = get_chromosome_info_as_dfs()\n", - "\n", - " # Extra composite key for convenience\n", - " gene_df[\"arm\"] = gene_df.chrom + gene_df.chrom_arm\n", - " gene_dict = gene_df.to_dict(orient=\"index\")\n", - "\n", - " chrom_dict = chrom_df.to_dict(orient=\"index\")\n", - "\n", - " # Extra composite key for convenience\n", - " band_df[\"region\"] = band_df.chrom + band_df.name\n", - " if legacy_bands:\n", - " band_df = band_df[[\"region\", \"chrom\", \"band_start\", \"band_end\"]].set_index(\"region\")\n", - " band_dict = {str(k): tuple(v) for k, v in band_df.iterrows()}\n", - " else:\n", - " band_dict = band_df.to_dict(orient=\"index\")\n", - " return gene_dict, chrom_dict, band_dict\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "gene_dict, chrom_dict, band_dict = get_chromosome_info_as_dicts()\n", - "genes = list(gene_dict.keys())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "n_feats = aggregated_tvn_data.shape[1]\n", - "aggregated_tvn_data =aggregated_tvn_data.reset_index()\n", - "idx = aggregated_tvn_data.query(f\"gene.isin({genes})\").index\n", - "print(f'Full data has {aggregated_tvn_data.shape[0]} genes, {len(idx)} of which are in hg38 annotations')\n", - "data_t = aggregated_tvn_data.rename({'Metadata_Symbol': 'gene'}, axis=1)\n", - "data_t = data_t.loc[idx].reset_index(drop=True)\n", - "\n", - "# Add in chromomsome information\n", - "data_t['chromosome'] = data_t.gene.apply(lambda x: gene_dict[x]['chrom'] if x in gene_dict else \"no info\" )\n", - "data_t['chr_idx'] = data_t.gene.apply(lambda x: gene_dict[x]['chrom_int'] if x in gene_dict else \"no info\" )\n", - "data_t['chromosome_arm'] = data_t.gene.apply(lambda x: gene_dict[x]['arm'] if x in gene_dict else \"no info\" )\n", - "data_t['gene_bp'] = data_t.gene.apply(lambda x: gene_dict[x][\"start\"] if x in gene_dict else \"no info\" )" - ] - }, { "cell_type": "code", "execution_count": null, @@ -1119,436 +27,23 @@ }, "outputs": [], "source": [ - "data_t.head()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cols = ['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp',] + list(range(0,n_feats))\n", - "data_t = data_t.loc[:, cols]\n", - "data_t = data_t.sort_values(['chr_idx', 'gene_bp']).reset_index(drop=True)\n", - "data_t = data_t.set_index(['gene', 'chromosome', 'chr_idx', 
'chromosome_arm', 'gene_bp',])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def crunch_square_df(\n", - " sims: pd.DataFrame,\n", - " crunch_factor: int,\n", - ") -> pd.DataFrame:\n", - " \"\"\"\n", - " Compress `sims` dataframe by `crunch_factor` to make visualizations reasonable. This takes averages of squares of\n", - " size `crunch_factor` X `crunch_factor`. Indices are replaced by the first value in the crunch block.\n", - "\n", - " Inputs:\n", - " -------\n", - " - sims: pd.DataFrame() with matching row and column indices\n", - " - crunch_fctor: int to compress the data by.\n", - " \"\"\"\n", - "\n", - " idx = [(i % crunch_factor == 0) for i, x in enumerate(sims.index)]\n", - " new_index = sims.index[idx] # type: ignore\n", - " crunched = block_reduce(sims.values, (crunch_factor, crunch_factor), np.mean)\n", - "\n", - " return pd.DataFrame(crunched, index=new_index, columns=new_index)\n", - "\n", - "\n", - "def plot_heatmap(\n", - " sims: pd.DataFrame,\n", - " f_name: Optional[str] = None,\n", - " format: str = \"png\",\n", - " crunch_factor: int = 1,\n", - " show_chr_lines: bool = True,\n", - " show_cent_lines: bool = True,\n", - " show_chroms: bool = True,\n", - " show_chrom_arms: bool = False,\n", - " figsize: tuple = (20, 20),\n", - " title: Optional[str] = None,\n", - " label_locy: Optional[float] = None,\n", - " lab_s: int = 12,\n", - " drop_chry: bool = True,\n", - " lw: float = 0.5,\n", - " lab_rot: int = 0,\n", - "):\n", - " \"\"\"\n", - " Plotting function for heatmaps (full-genome or subsets) can be sorted/clustered or ordered by chromosome\n", - "\n", - " Inputs:\n", - " - sims: Square dataframe with matching row and column indices and similarity values. This can be \"split\" along\n", - " the diagonal to show different datasets. Index should include `chromosome` and `chromosome_arm` if\n", - " ordering by genomic position. 
Each row/column should represent one gene ordered by genomic position\n", - " - f_name: file name\n", - " - format: file format for saving a file\n", - " - crunch_factor: if > 1 will apply a average smoothing to reduce the size of the output file\n", - " - show_chr_lines: whether to show lines at chromosome boundaries\n", - " - show_cent_lines: whether to show lines at centromeres\n", - " - show_chroms: Whether to label chromosomes on top and right\n", - " - show_chrom_arms: whether to label chromosome arms on top and right\n", - " - figsize: size of the resulting figure\n", - " - title: plot title\n", - " - label_locy: location of labels in y\n", - " - lab_s: Font size of labels\n", - " - drop_chry: Whether to remove Chromosome Y values\n", - " - lw: line width\n", - " - lab_rot: rotation of labels\n", - " \"\"\"\n", - " color_norm = mpl.colors.Normalize(vmin=-1, vmax=1)\n", - " cmap = mpl.colormaps[\"RdBu_r\"]\n", - " # This sets nan values to white\n", - " # cmap.set_bad(color='white')\n", - "\n", - " if crunch_factor > 1:\n", - " # Downsample the data to make the file size less crazy\n", - " # Every `sample_factor`th row/column will be kept\n", - " sims = crunch_square_df(sims, crunch_factor=crunch_factor)\n", - "\n", - " image_data = cmap(color_norm(sims.values))\n", - "\n", - " plt.figure(figsize=figsize)\n", - " plt.imshow(image_data)\n", - "\n", - " if drop_chry and \"chromosome\" in sims.index.names:\n", - " noy_idx = sims.index.get_level_values(\"chromosome\") != \"chrY\"\n", - " sims = sims.copy().loc[noy_idx, noy_idx] # type: ignore\n", - "\n", - " if show_chr_lines or show_chroms or show_cent_lines or show_chrom_arms:\n", - " # Get the position of all the chromosomes and centromeres\n", - " index_df = sims.index.to_frame(index=False).reset_index().rename({\"index\": \"pos\"}, axis=1)\n", - " chr_pos = index_df.groupby(\"chromosome\").pos.max().sort_values()\n", - " cent_pos = index_df.groupby(\"chromosome_arm\").pos.max().sort_values()\n", - " # Filter to just p-arms\n", - " cent_pos_p = cent_pos[[x[-1] == \"p\" for x in cent_pos.index]]\n", - " # Get midpoints for annotations\n", - " chr_mids = pd.DataFrame(\n", - " (np.insert(chr_pos.values[:-1], 0, 0) + chr_pos.values) / 2, index=chr_pos.index # type: ignore\n", - " ).to_dict()[\n", - " 0 # type: ignore\n", - " ]\n", - " cent_mids = pd.DataFrame(\n", - " (np.insert(cent_pos.values[:-1], 0, 0) + cent_pos.values) / 2, index=cent_pos.index # type: ignore\n", - " ).to_dict()[\n", - " 0 # type: ignore\n", - " ]\n", - "\n", - " # Hide X and Y axes label marks\n", - " ax = plt.gca()\n", - " ax.xaxis.set_tick_params(labelbottom=False)\n", - " ax.yaxis.set_tick_params(labelleft=False)\n", - " # Hide X and Y axes tick marks\n", - " ax.set_xticks([])\n", - " ax.set_yticks([])\n", - "\n", - " xm, xM = ax.get_xlim()\n", - " ym, yM = ax.get_ylim()\n", - "\n", - " if show_chr_lines:\n", - " for x in chr_pos.values:\n", - " plt.plot([x + 0.5, x + 0.5], [ym, yM], color=\"k\", ls=\"-\", lw=lw)\n", - " plt.plot([xm, xM], [x + 0.5, x + 0.5], color=\"k\", ls=\"-\", lw=lw)\n", - " if show_cent_lines:\n", - " for x in cent_pos_p.values:\n", - " plt.plot([x + 0.5, x + 0.5], [ym, yM], color=\"k\", ls=\":\", lw=lw)\n", - " plt.plot([xm, xM], [x + 0.5, x + 0.5], color=\"k\", ls=\":\", lw=lw)\n", - "\n", - " if show_chroms:\n", - " # Label chromosomes on top/right to not clash with coords\n", - " ax = plt.gca()\n", - " s = sims.shape[0]\n", - " for ch in chr_mids:\n", - " # Labels across the top\n", - " if label_locy is None:\n", - " label_locy = 
-0.008 * s\n", - " # Labels on top\n", - " ax.text(\n", - " chr_mids[ch], label_locy, ch.replace(\"chr\", \"\"), ha=\"center\", va=\"bottom\", rotation=lab_rot, size=lab_s\n", - " )\n", - " # Labels on the right\n", - " ax.text(sims.shape[0] + 0.008 * s, chr_mids[ch], ch.replace(\"chr\", \"\"), ha=\"left\", va=\"center\", size=lab_s)\n", - "\n", - " if show_chrom_arms:\n", - " # Label chromosome arms on top/right\n", - " ax = plt.gca()\n", - " s = sims.shape[0]\n", - " for cent in cent_mids:\n", - " # Labels across the top\n", - " if label_locy is None:\n", - " label_locy = -0.008 * s\n", - " ax.text(cent_mids[cent], label_locy, cent, ha=\"left\", rotation=lab_rot, size=lab_s)\n", - " # Labels on the right\n", - " ax.text(sims.shape[0] + 0.001 * s, cent_mids[cent], cent, ha=\"left\", size=lab_s)\n", - "\n", - " plt.title(title, size=\"xx-large\", y=1.1)\n", - " plt.gcf().set_facecolor(\"white\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_heatmap(make_pairwise_cos(data_t), f_name=('data/cpg0016_split_prenorm.svg'),\n", - " format='svg', crunch_factor=10, title='JUMP Whole Genome Heatmap')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Using Arm Subtraction to Correct for Proximity Bias\n", - "We add an additional alignment step to our pipeline where we subtract off the mean of the corresponding chromosome arm from each transformed embedding." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Load U2OS expression data\n", - "u2os_exp = pd.read_csv('data/u2os.csv')\n", - "u2os_exp = u2os_exp.groupby('gene',as_index=False).zfpkm.median()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def build_arm_centering_df(\n", - " data: pd.DataFrame,\n", - " metadata_cols: List[str],\n", - " arm_column=\"chromosome_arm\",\n", - " subset_query=\"zfpkm <-3\",\n", - " min_num_gene=20,\n", - ") -> pd.DataFrame:\n", - " \"\"\"Build a dataframe with the mean feature values for each chromosome arm\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " DataFrame with metadata and features\n", - " metadata_cols : List[str]\n", - " Metadata columns\n", - " arm_column : str, optional\n", - " Metadata column with arm identifier, by default \"chromosome_arm\"\n", - " subset_query : str, optional\n", - " Query to subset genes, by default \"zfpkm <-3\"\n", - " min_num_gene : int, optional\n", - " Minimum number of genes required. 
Dataframe returned will only\n", - " include arms that meet this threshold, by default 20\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " DataFrame with mean feature values for each chromosome arm\n", - " \"\"\"\n", - " subset = data.query(subset_query)\n", - " if arm_column not in metadata_cols:\n", - " metadata_cols = metadata_cols + [arm_column]\n", - " features = subset.drop(metadata_cols, axis=\"columns\")\n", - " return features.groupby(subset[arm_column]).mean()[\n", - " subset.groupby(arm_column)[metadata_cols[0]].size() > min_num_gene\n", - " ]\n", - "\n", - "\n", - "def perform_arm_centering(\n", - " data: pd.DataFrame,\n", - " metadata_cols: List[str],\n", - " arm_centering_df: pd.DataFrame,\n", - " arm_column: str = \"chromosome_arm\",\n", - ") -> pd.DataFrame:\n", - " \"\"\"Apply arm centering to data\n", - "\n", - " Parameters\n", - " ----------\n", - " data : pd.DataFrame\n", - " Data DataFrame\n", - " metadata_cols : List[str]\n", - " List of metadata columns\n", - " arm_centering_df : pd.DataFrame\n", - " Arm centering dataframe\n", - " arm_column : str, optional\n", - " Column that identifies chromosome arm, by default \"chromosome_arm\"\n", - "\n", - " Returns\n", - " -------\n", - " pd.DataFrame\n", - " Arm centered data\n", - " \"\"\"\n", - " metadata = data[metadata_cols]\n", - " features = data.drop(metadata_cols, axis=\"columns\")\n", - " for chromosome_arm in arm_centering_df.index:\n", - " arm_features = features[metadata[arm_column] == chromosome_arm]\n", - " arm_features = arm_features - arm_centering_df.loc[chromosome_arm]\n", - " features[metadata[arm_column] == chromosome_arm] = arm_features\n", - " return metadata.join(features)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "data = data_t.reset_index()\n", - "data= data.merge(u2os_exp, how='left', left_on='gene', right_on = 'gene')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "arm_centering_df = build_arm_centering_df(data, metadata_cols=['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']) " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "arm_centered_data = perform_arm_centering(data, metadata_cols=['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm'], \n", - " arm_centering_df=arm_centering_df)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plot_heatmap(make_pairwise_cos(arm_centered_data.set_index(['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm'])), \n", - " f_name=('data/cpg0016_arm_centered.svg'), format='svg', crunch_factor=10, title='JUMP Whole Genome Arm Centered Heatmap')" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Recomputing Metrics" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "metadata = arm_centered_data[['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm']]\n", - "features = arm_centered_data.drop(columns=['gene', 'chromosome', 'chr_idx', 'chromosome_arm', 'gene_bp', 'zfpkm'])\n", - "data_dict['arm_centered_EFAAR'] = Bunch(metadata=metadata, features=features)\n", - "arm_centered_results = bm(data_dict['arm_centered_EFAAR'], pert_label_col='gene', run_count=1)" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "arm_results_df = get_benchmark_metrics(arm_centered_results)\n", - "arm_results_df['map_version'] = 'JUMP_arm_centered_EFAAR'\n", - "benchmark_results_df = pd.concat([benchmark_results_df, arm_results_df])\n", - "sns.barplot(data=benchmark_results_df, x='source' ,y='recall', hue='map_version')\n", - "plt.axhline(y=0.1, color='red', linestyle='--') \n", - "plt.title('Recall at 5th and 95th Percentiles')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Arm Stratified Metrics\n", - "We also compute the multiperturbation metrics stratified by chromosome arm. This is because there may have been some relationships where they were only recalled because the genes happened to be on the same chromosome arm. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from efaar_benchmarking.constants import RANDOM_SEED, BENCHMARK_SOURCES, N_NULL_SAMPLES\n", - "from efaar_benchmarking.utils import (\n", - " generate_null_cossims,\n", - " generate_query_cossims,\n", - " get_benchmark_data,\n", - " get_feats_w_indices,\n", - ")\n", - "\n", - "def compute_within_cross_arm_pairwise_metrics(\n", - " data: Bunch,\n", - " pert_label_col: str = \"gene\",\n", - " pct_thresholds: list = [0.05, 0.95],\n", - ") -> tuple:\n", - " \"\"\"Compute known biology benchmarks stratified by whether the pairs of genes\n", - " are on the same chromosome arm or not.\n", - "\n", - " Parameters\n", - " ----------\n", - " data : Bunch\n", - " Metadata-features bunch\n", - " pert_label_col : str, optional\n", - " Column in the metadata that defines the perturbation, by default \"gene\"\n", - " pct_thresholds : list, optional\n", - " Percentile thresholds for the recall computation, by default [0.05, 0.95]\n", - "\n", - " Returns\n", - " -------\n", - " tuple\n", - " Results for within-arm and cross-arm pairs, respectively.\n", - " \"\"\"\n", - " np.random.seed(RANDOM_SEED)\n", - " within = {}\n", - " between = {}\n", - " for source in BENCHMARK_SOURCES:\n", - " random_seed_pair = np.random.randint(2**32, size=2)\n", - " gt_data = get_benchmark_data(source)\n", - " gene_dict, _, _ = get_chromosome_info_as_dicts()\n", - "\n", - " feats = get_feats_w_indices(data, pert_label_col)\n", + "from efaar_benchmarking.data_loading import load_cpg16_crispr\n", + "from efaar_benchmarking.efaar import *\n", + "from efaar_benchmarking.constants import *\n", + "from efaar_benchmarking.benchmarking import benchmark\n", + "from efaar_benchmarking.plotting import plot_recall\n", "\n", - " gt_data[\"entity1_chrom\"] = gt_data.entity1.apply(lambda x: gene_dict[x][\"arm\"] if x in gene_dict else \"no info\")\n", - " gt_data[\"entity2_chrom\"] = gt_data.entity2.apply(lambda x: gene_dict[x][\"arm\"] if x in gene_dict else \"no info\")\n", - " gt_data = gt_data.query(\"entity1_chrom != 'no info' and entity2_chrom != 'no info'\")\n", - " df_gg_null = generate_null_cossims(\n", - " feats,\n", - " feats,\n", - " rseed_entity1=random_seed_pair[0],\n", - " rseed_entity2=random_seed_pair[1],\n", - " n_entity1=N_NULL_SAMPLES,\n", - " n_entity2=N_NULL_SAMPLES,\n", - " )\n", + "recall_threshold_pairs = []\n", + "start = 0.01\n", + "end = 0.99\n", + "step = 0.01\n", "\n", - " within_gt_subset = gt_data.query(\"entity1_chrom == entity2_chrom\")\n", - " between_gt_subset = gt_data.query(\"entity1_chrom != entity2_chrom\")\n", + "while start <= .105 and end >= .895:\n", + " 
recall_threshold_pairs.append((round(start,2), round(end,2)))\n",
+ "    start += step\n",
+ "    end -= step\n",
 "\n",
- "    df_gg_within = generate_query_cossims(feats, feats, within_gt_subset)\n",
- "    df_gg_between = generate_query_cossims(feats, feats, between_gt_subset)\n",
- "\n",
- "    within[source] = _compute_recall(df_gg_null, df_gg_within, pct_thresholds)\n",
- "\n",
- "    between[source] = _compute_recall(df_gg_null, df_gg_between, pct_thresholds)\n",
- "    return within, between\n",
- "\n",
- "def _compute_recall(null_cossims, query_cossims, pct_thresholds) -> dict:\n",
- "    null_sorted = np.sort(null_cossims)\n",
- "    percentiles = np.searchsorted(null_sorted, query_cossims) / len(null_sorted)\n",
- "    return sum((percentiles <= np.min(pct_thresholds)) | (percentiles >= np.max(pct_thresholds))) / len(percentiles)"
+ "print(recall_threshold_pairs)"
 ]
},
{
@@ -1557,31 +52,14 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "arm_stratified_results = {}\n",
- "for k, v in data_dict.items():\n",
- "    arm_stratified_results[k] = compute_within_cross_arm_pairwise_metrics(v)\n",
- "\n",
- "\n",
- "result_records = []\n",
- "for map_label, v in arm_stratified_results.items():\n",
- "    for name, result in zip((\"within arm\", \"between arms\"), v):\n",
- "        for source, recall in result.items():\n",
- "            result_records.append((map_label, name, source, recall))\n",
- "\n",
- "stratified_results_df = pd.DataFrame.from_records(result_records, columns=[\"Map Name\", \"arms\", \"source\", \"recall\"])\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "sns.catplot(data = stratified_results_df, x = \"arms\", y=\"recall\", hue = \"Map Name\", col = \"source\", kind = \"bar\")\n",
- "plt.title(\"Arm Stratified Benchmarks\")\n",
- "plt.show()"
+ "features, metadata = load_cpg16_crispr()  # the first run downloads the data; later runs read the local cache\n",
+ "features, metadata = filter_cpg16_crispr(features, metadata)\n",
+ "embeddings = embed_align_by_pca(features.values, metadata, variance_or_ncomp=.98, plate_col=JUMP_PLATE_COL)\n",
+ "embeddings = align_on_controls(embeddings, metadata, pert_col=JUMP_PERT_LABEL_COL)\n",
+ "map_data = aggregate(embeddings, metadata, pert_col=JUMP_PERT_LABEL_COL)\n",
+ "metrics = benchmark(map_data, recall_thr_pairs=recall_threshold_pairs, pert_label_col=JUMP_PERT_LABEL_COL)\n",
+ "plot_recall(metrics)\n",
+ "metrics.groupby('source')['recall_0.05_0.95'].mean()"
 ]
}
],
@@ -1601,7 +79,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
- "version": "3.9.10"
+ "version": "3.11.5"
 },
 "vscode": {
 "interpreter": {
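A note on the threshold-setup cell added above: the `while` loop relies on its `<= .105` and `>= .895` guards to absorb floating-point drift from the repeated `+= step`. A minimal equivalent sketch (assuming the intent is the ten pairs from (0.01, 0.99) down to (0.10, 0.90)) that avoids accumulating error by deriving each pair from integer percentiles:

```python
# Same ten (lower, upper) recall-threshold pairs as the while loop above,
# built from integers so no floating-point error accumulates.
recall_threshold_pairs = [(round(p / 100, 2), round(1 - p / 100, 2)) for p in range(1, 11)]
print(recall_threshold_pairs)  # [(0.01, 0.99), (0.02, 0.98), ..., (0.1, 0.9)]
```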
diff --git a/notebooks/replogle_map_building.ipynb b/notebooks/replogle_map_building.ipynb
index fc2c5eb..2777aca 100644
--- a/notebooks/replogle_map_building.ipynb
+++ b/notebooks/replogle_map_building.ipynb
@@ -28,7 +28,7 @@
 "outputs": [],
 "source": [
 "from efaar_benchmarking.data_loading import load_replogle\n",
- "from efaar_benchmarking.efaar import embed_by_scvi, embed_by_pca, align_by_centering, aggregate_by_mean\n",
+ "from efaar_benchmarking.efaar import *\n",
 "from efaar_benchmarking.benchmarking import benchmark\n",
 "from efaar_benchmarking.plotting import plot_recall\n",
 "\n",
@@ -60,18 +60,20 @@
 "metadata": {},
 "outputs": [],
 "source": [
- "adata = load_replogle(\"essential\", \"normalized\")\n",
+ "adata = load_replogle(\"genome_wide\", \"normalized\")\n",
 "metadata = adata.obs\n",
- "embeddings_pca = embed_by_pca(adata)\n",
- "embeddings_aligned = align_by_centering(embeddings_pca, metadata)\n",
- "map_data = aggregate_by_mean(embeddings_aligned, metadata)\n",
+ "embeddings = embed_by_pca_anndata(adata)\n",
+ "del adata\n",
+ "embeddings = align_on_controls(embeddings, metadata)\n",
+ "map_data = aggregate(embeddings, metadata)\n",
+ "del embeddings, metadata\n",
 "metrics = benchmark(map_data, recall_thr_pairs=recall_threshold_pairs)\n",
- "plot_recall(metrics)"
+ "plot_recall(metrics)\n",
+ "metrics.groupby('source')['recall_0.05_0.95'].mean()"
 ]
},
{
 "cell_type": "markdown",
- "id": "1eaf464a",
 "metadata": {},
 "source": [
 "## scVI Embeddings"
 ]
},
@@ -86,9 +88,9 @@
 "source": [
 "adata = load_replogle(\"essential\", \"raw\")\n",
 "metadata = adata.obs\n",
- "embeddings_scvi = embed_by_scvi(adata)\n",
- "embeddings_aligned = align_by_centering(embeddings_scvi, metadata)\n",
- "map_data = aggregate_by_mean(embeddings_aligned, metadata)\n",
+ "embeddings = embed_by_scvi_anndata(adata)\n",
+ "embeddings = align_on_controls(embeddings, metadata)\n",
+ "map_data = aggregate(embeddings, metadata)\n",
 "metrics = benchmark(map_data, recall_thr_pairs=recall_threshold_pairs)\n",
 "plot_recall(metrics)"
 ]
diff --git a/requirements/dev_3.10.txt b/requirements/dev_3.10.txt
index 30eed72..603c85b 100644
--- a/requirements/dev_3.10.txt
+++ b/requirements/dev_3.10.txt
@@ -566,7 +566,7 @@ ptyprocess==0.7.0
 #   terminado
 pure-eval==0.2.2
 # via stack-data
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
diff --git a/requirements/dev_3.11.txt b/requirements/dev_3.11.txt
index 61a0e1b..21d13ab 100644
--- a/requirements/dev_3.11.txt
+++ b/requirements/dev_3.11.txt
@@ -560,7 +560,7 @@ ptyprocess==0.7.0
 #   terminado
 pure-eval==0.2.2
 # via stack-data
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
diff --git a/requirements/dev_3.9.txt b/requirements/dev_3.9.txt
index 05429de..7559aa8 100644
--- a/requirements/dev_3.9.txt
+++ b/requirements/dev_3.9.txt
@@ -576,7 +576,7 @@ ptyprocess==0.7.0
 #   terminado
 pure-eval==0.2.2
 # via stack-data
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
diff --git a/requirements/main_3.10.txt b/requirements/main_3.10.txt
index b15014b..7d5d601 100644
--- a/requirements/main_3.10.txt
+++ b/requirements/main_3.10.txt
@@ -344,7 +344,7 @@ protobuf==4.24.4
 #   orbax-checkpoint
 psutil==5.9.6
 # via lightning
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
diff --git a/requirements/main_3.11.txt b/requirements/main_3.11.txt
index 585b680..d842965 100644
--- a/requirements/main_3.11.txt
+++ b/requirements/main_3.11.txt
@@ -340,7 +340,7 @@ protobuf==4.24.4
 #   orbax-checkpoint
 psutil==5.9.6
 # via lightning
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
diff --git a/requirements/main_3.9.txt b/requirements/main_3.9.txt
index 64f6d93..94c217a 100644
--- a/requirements/main_3.9.txt
+++ b/requirements/main_3.9.txt
@@ -348,7 +348,7 @@ protobuf==4.24.4
 #   orbax-checkpoint
 psutil==5.9.6
 # via lightning
-pyarrow==13.0.0
+pyarrow==14.0.1
 # via efaar-benchmarking (pyproject.toml)
 pyasn1==0.5.0
 # via
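Two closing notes. First, the `pyarrow` pin is raised from 13.0.0 to 14.0.1 in all six requirements files, presumably to pick up the fix for CVE-2023-47248 (arbitrary code execution when deserializing untrusted Arrow IPC or Parquet data in earlier releases). Second, both updated notebooks now end with `metrics.groupby('source')['recall_0.05_0.95'].mean()`; a toy sketch of what that summary computes, using hypothetical source names and recall values (only the column names come from the notebook cells above):

```python
import pandas as pd

# Hypothetical benchmark output: one row per (benchmark source, random seed),
# with one recall column per threshold pair; the values are made up.
metrics = pd.DataFrame({
    "source": ["source_a", "source_a", "source_b", "source_b"],
    "recall_0.05_0.95": [0.31, 0.35, 0.42, 0.40],
})

# Mean recall per benchmark source, averaged over random seeds.
print(metrics.groupby("source")["recall_0.05_0.95"].mean())
```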