Clean jump notebook #21

Merged
merged 10 commits on Nov 10, 2023
12 changes: 7 additions & 5 deletions efaar_benchmarking/benchmarking.py
@@ -44,7 +44,7 @@ def pert_stats(
def benchmark(
map_data: Bunch,
benchmark_sources: list = cst.BENCHMARK_SOURCES,
pert_label_col: str = cst.PERT_LABEL_COL,
pert_label_col: str = cst.REPLOGLE_PERT_LABEL_COL,
recall_thr_pairs: list = cst.RECALL_PERC_THRS,
filter_on_pert_prints: bool = False,
pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
@@ -78,15 +78,15 @@ def benchmark(
"""

if not (len(benchmark_sources) > 0 and all([src in benchmark_data_dir for src in benchmark_sources])):
AssertionError("Invalid benchmark source(s) provided.")
ValueError("Invalid benchmark source(s) provided.")
md = map_data.metadata
idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
del map_data
if not len(features) == len(set(features.index)):
AssertionError("Duplicate perturbation labels in the map.")
ValueError("Duplicate perturbation labels in the map.")
if not len(features) >= min_req_entity_cnt:
AssertionError("Not enough entities in the map for benchmarking.")
ValueError("Not enough entities in the map for benchmarking.")
print(len(features), "perturbations exist in the map.")

metrics_lst = []
@@ -98,7 +98,9 @@
random_seed_str = f"{rs1}_{rs2}"
null_cossim = generate_null_cossims(features, n_null_samples, rs1, rs2)
for s in benchmark_sources:
query_cossim = generate_query_cossims(features, get_benchmark_relationships(benchmark_data_dir, s))
rels = get_benchmark_relationships(benchmark_data_dir, s)
print(len(rels), "relationships exist in the benchmark source.")
query_cossim = generate_query_cossims(features, rels)
if len(query_cossim) > 0:
metrics_lst.append(
convert_metrics_to_df(
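For context, here is a minimal, hypothetical sketch of how the updated benchmark() entry point might be called; map_data is a sklearn.utils.Bunch with features and metadata DataFrames (e.g. from aggregate() in efaar.py), and the exact call site is not part of this diff:

import efaar_benchmarking.constants as cst
from efaar_benchmarking.benchmarking import benchmark

# map_data: Bunch(features=DataFrame, metadata=DataFrame), e.g. built by aggregate().
metrics = benchmark(
    map_data,
    benchmark_sources=["Reactome", "HuMAP"],
    pert_label_col=cst.REPLOGLE_PERT_LABEL_COL,  # use cst.JUMP_PERT_LABEL_COL for JUMP maps
)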
12 changes: 8 additions & 4 deletions efaar_benchmarking/constants.py
@@ -2,12 +2,16 @@

BENCHMARK_DATA_DIR = str(resources.files("efaar_benchmarking").joinpath("benchmark_annotations"))
BENCHMARK_SOURCES = ["Reactome", "HuMAP", "CORUM", "SIGNOR", "StringDB"]
PERT_LABEL_COL = "gene"
CONTROL_PERT_LABEL = "non-targeting"
PERT_SIG_PVAL_COL = "gene_pvalue"
PERT_SIG_PVAL_THR = 0.01
RECALL_PERC_THRS = [(0.05, 0.95), (0.1, 0.9)]
RANDOM_SEED = 42
RANDOM_COUNT = 3
N_NULL_SAMPLES = 5000
MIN_REQ_ENT_CNT = 20
PERT_SIG_PVAL_COL = "gene_pvalue"
PERT_SIG_PVAL_THR = 0.01
CONTROL_PERT_LABEL = "non-targeting"
REPLOGLE_PERT_LABEL_COL = "gene"
REPLOGLE_BATCH_COL = "gem_group"
JUMP_PERT_LABEL_COL = "Metadata_Symbol"
JUMP_PLATE_COL = "Metadata_Plate"
JUMP_BATCH_COL = "Metadata_Batch"
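The dataset-specific prefixes make the Replogle/JUMP pairing explicit. A small hypothetical sketch of how a caller might pick columns per dataset:

import efaar_benchmarking.constants as cst

dataset = "jump"  # or "replogle"
pert_col = cst.JUMP_PERT_LABEL_COL if dataset == "jump" else cst.REPLOGLE_PERT_LABEL_COL
batch_col = cst.JUMP_BATCH_COL if dataset == "jump" else cst.REPLOGLE_BATCH_COL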
83 changes: 82 additions & 1 deletion efaar_benchmarking/data_loading.py
@@ -1,17 +1,98 @@
import os
import shutil
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
import scanpy as sc
import wget
from sklearn.covariance import EllipticEnvelope

CP_FEATURE_FORMATTER = (
"s3://cellpainting-gallery/cpg0016-jump/"
"{Metadata_Source}/workspace/profiles/"
"{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet"
)
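For reference, an invented metadata row formats to an S3 key like the following; the specific source, batch, and plate values are illustrative only:

row = {"Metadata_Source": "source_4", "Metadata_Batch": "2021_08_23_Batch12", "Metadata_Plate": "BR00126113"}
path = CP_FEATURE_FORMATTER.format(**row)
# s3://cellpainting-gallery/cpg0016-jump/source_4/workspace/profiles/2021_08_23_Batch12/BR00126113/BR00126113.parquet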


def load_cpg16_crispr(data_path="data/", filter_by_intensity=True, filter_by_cell_count=True, drop_image_cols=True):
"""
Load and preprocess CRISPR well profiles from the JUMP cpg0016 dataset.
Downloads the plate, well, and CRISPR metadata tables from Zenodo and the per-plate feature
parquet files from the Cell Painting Gallery, then optionally filters wells flagged as
image-intensity outliers, filters wells by cell count, and drops image-level feature columns.

Returns:
Tuple of two row-aligned pandas DataFrames: (features, metadata).
"""
plate_file_name = "plate.csv.gz"
well_file_name = "well.csv.gz"
crispr_file_name = "crispr.csv.gz"
plate_file_path = os.path.join(data_path, plate_file_name)
well_file_path = os.path.join(data_path, well_file_name)
crispr_file_path = os.path.join(data_path, crispr_file_name)
if not (os.path.exists(plate_file_path) and os.path.exists(well_file_path) and os.path.exists(crispr_file_path)):
source_path = "https://zenodo.org/records/7661296/files/jump-cellpainting/datasets-v0.5.0.zip?download=1"
path_to_zip_file = os.path.join(data_path, "tmp.zip")
wget.download(source_path, path_to_zip_file)
with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
for name in zip_ref.namelist():
if name.endswith(plate_file_name) or name.endswith(well_file_name) or name.endswith(crispr_file_name):
with tempfile.TemporaryDirectory() as temp_dir:
zip_ref.extract(name, temp_dir)
shutil.move(os.path.join(temp_dir, name), os.path.join(data_path, os.path.basename(name)))
os.remove(path_to_zip_file)

plates = pd.read_csv(plate_file_path)
crispr_plates = plates.query("Metadata_PlateType=='CRISPR'")
wells = pd.read_csv(well_file_path)
well_plate = wells.merge(crispr_plates, on=["Metadata_Source", "Metadata_Plate"])
crispr = pd.read_csv(crispr_file_path)
metadata = well_plate.merge(crispr, on="Metadata_JCP2022")

features_file_path = os.path.join(data_path, "cpg_features.parquet")
if not os.path.exists(features_file_path):

def load_plate_features(path: str):
return pd.read_parquet(path, storage_options={"anon": True})

crispr_plates = metadata[
["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_PlateType"]
].drop_duplicates()
features = []
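# Download the per-plate parquet profiles from S3 concurrently; each future resolves to one plate's feature DataFrame.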
with ThreadPoolExecutor(max_workers=10) as executor:
future_to_plate = {
executor.submit(
load_plate_features, CP_FEATURE_FORMATTER.format(**row.to_dict())
): CP_FEATURE_FORMATTER.format(**row.to_dict())
for _, row in crispr_plates.iterrows()
}
for future in as_completed(future_to_plate):
features.append(future.result())
pd.concat(features).to_parquet(features_file_path)
features = pd.read_parquet(features_file_path).dropna(axis=1)
if filter_by_intensity:
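# EllipticEnvelope fits a robust covariance model to the ImageQuality_MeanIntensity columns; with
# contamination=0.01, roughly 1% of rows are flagged as outliers (fit_predict returns -1) and only inliers (1) are kept.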
features = features.loc[
EllipticEnvelope(contamination=0.01, random_state=42).fit_predict(
features[[col for col in features.columns if "ImageQuality_MeanIntensity" in col]]
)
== 1
]
if filter_by_cell_count:
mask = np.full(len(features), True)
for colname in ["Cytoplasm_Number_Object_Number", "Nuclei_Number_Object_Number"]:
mask = mask & (features[colname] >= 50) & (features[colname] <= 350)
features = features.loc[mask]
if drop_image_cols:
features = features.drop(columns=[col for col in features.columns if col.startswith("Image_")])

metadata_cols = metadata.columns
merged_data = metadata.merge(features, on=["Metadata_Source", "Metadata_Plate", "Metadata_Well"])
return merged_data.drop(columns=metadata_cols), merged_data[metadata_cols]
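
A hypothetical call to the new loader (the data path and default flags are illustrative):

features, metadata = load_cpg16_crispr(data_path="data/")
print(features.shape, metadata.shape)  # one row per retained well in both frames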


def load_replogle(gene_type, data_type, data_path="data/"):
"""
Load Replogle et al. 2020 single-cell RNA-seq data for K562 cells.
Load Replogle et al. 2022 single-cell RNA-seq data for K562 cells from
plus.figshare.com/articles/dataset/_Mapping_information-rich_genotype-phenotype_landscapes_with_genome-scale_Perturb-seq_Replogle_et_al_2022_processed_Perturb-seq_datasets/20029387

Parameters:
gene_type (str): Type of genes to load. Must be either 'essential' or 'genome_wide'.
data_type (str): Type of data to load. Must be either 'raw' or 'normalized'.
Normalized means Z-normalized by gemgroup.
data_path (str): Path to the directory where the data will be downloaded and saved.

Returns:
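A hypothetical call matching the documented parameters:

adata = load_replogle(gene_type="essential", data_type="normalized", data_path="data/")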
112 changes: 85 additions & 27 deletions efaar_benchmarking/efaar.py
@@ -1,75 +1,128 @@
from typing import Optional

import numpy as np
import pandas as pd
import scanpy as sc
from scvi.model import SCVI
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.utils import Bunch

import efaar_benchmarking.constants as cst


def embed_by_scvi(adata, batch_key="gem_group", n_latent=128, n_hidden=256):
def embed_by_scvi_anndata(adata, batch_col=cst.REPLOGLE_BATCH_COL, n_latent=128, n_hidden=256) -> np.ndarray:
"""
Embeds the input AnnData object using scVI.
Embed the input AnnData object using scVI.

Parameters:
adata (anndata.AnnData): The AnnData object to be embedded.
batch_key (str): The batch key in the AnnData object. Default is "gem_group".
n_latent (int): The number of latent dimensions. Default is 128.
n_hidden (int): The number of hidden dimensions. Default is 256.
Args:
adata (anndata.AnnData): The AnnData object to be embedded.
batch_col (str): The batch column in the AnnData object. Defaults to cst.REPLOGLE_BATCH_COL ("gem_group").
n_latent (int): The number of latent dimensions. Default is 128.
n_hidden (int): The number of hidden dimensions. Default is 256.

Returns:
None
numpy.ndarray: Embedding of the input data using scVI.
"""
SCVI.setup_anndata(adata, batch_key=batch_key)
SCVI.setup_anndata(adata, batch_key=batch_col)
vae = SCVI(adata, n_hidden=n_hidden, n_latent=n_latent)
vae.train(use_gpu=True)
return vae.get_latent_representation()
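
A minimal usage sketch, assuming adata is an AnnData whose .obs carries a "gem_group" column; the h5ad path is illustrative:

import scanpy as sc

adata = sc.read_h5ad("data/replogle_essential_normalized.h5ad")  # illustrative path
emb = embed_by_scvi_anndata(adata, batch_col="gem_group", n_latent=128, n_hidden=256)
print(emb.shape)  # (n_cells, 128)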


def embed_by_pca(adata, n_latent=100):
def embed_by_pca_anndata(adata, n_latent=100) -> np.ndarray:
"""
Embeds the input data using PCA.
Embed the input data using principal component analysis (PCA).
Note that the data is centered by the `pca` function prior to PCA transformation.

Parameters:
adata (AnnData): Annotated data matrix.
n_latent (int): Number of principal components to use.
Args:
adata (AnnData): Annotated data matrix.
n_latent (int): Number of principal components to use.

Returns:
numpy.ndarray: Embedding of the input data using PCA.
numpy.ndarray: Embedding of the input data using PCA.
"""
sc.pp.pca(adata, n_comps=n_latent)
return adata.obsm["X_pca"]
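
Usage is analogous to the scVI variant, assuming the same adata as above:

emb = embed_by_pca_anndata(adata, n_latent=100)
print(emb.shape)  # (n_cells, 100)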


def align_by_centering(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
def embed_align_by_pca(
features: np.ndarray,
metadata: Optional[pd.DataFrame] = None,
variance_or_ncomp=100,
plate_col: Optional[str] = None,
) -> np.ndarray:
"""
Applies the centerscale method to align embeddings based on the centering perturbations in the metadata.
Embed the input data using principal component analysis (PCA).
Note that we explicitly center and scale the data before and after calling `PCA`: by plate if
`plate_col` is not None, and on the whole data otherwise.
The `PCA` transformer itself also mean-centers the whole data prior to the PCA operation.
Args:
features (np.ndarray): Features to transform
metadata (pd.DataFrame): Metadata. Defaults to None.
variance_or_ncomp (float, optional): Number of principal components to keep, or, if between
0 and 1, the fraction of variance that the selected components must explain.
Defaults to 100 (components).
plate_col (str, optional): Column name for plate metadata. Defaults to None.
Returns:
np.ndarray: Transformed data using PCA.
"""

def centerscale(features, metadata, plate_col):
if plate_col is None:
features = StandardScaler().fit_transform(features)
else:
if metadata is None:
raise ValueError("metadata must be provided if plate_col is not None")
unq_plates = metadata[plate_col].unique()
for plate in unq_plates:
ind = metadata[plate_col] == plate
features[ind, :] = StandardScaler().fit_transform(features[ind, :])
return features

features = centerscale(features, metadata, plate_col)
features = PCA(variance_or_ncomp).fit_transform(features)
features = centerscale(features, metadata, plate_col)

return features
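
A self-contained sketch of the center-scale / PCA / center-scale sandwich on synthetic data; the plate labels are invented:

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
feats = rng.normal(size=(200, 50))
meta = pd.DataFrame({"Metadata_Plate": ["plateA"] * 100 + ["plateB"] * 100})
emb = embed_align_by_pca(feats, meta, variance_or_ncomp=0.95, plate_col="Metadata_Plate")
print(emb.shape)  # (200, k), where k components explain at least 95% of the variance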


def align_on_controls(embeddings, metadata, pert_col=cst.REPLOGLE_PERT_LABEL_COL, control_key=cst.CONTROL_PERT_LABEL):
"""
Center the embeddings by the control perturbation units in the metadata.

Args:
embeddings (numpy.ndarray): The embeddings to be aligned.
metadata (pandas.DataFrame): The metadata containing information about the embeddings.
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
control_key (str, optional): The key for non-targeting controls in the metadata.
Defaults to cst.CONTROL_PERT_LABEL.
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.PERT_LABEL_COL.

Returns:
numpy.ndarray: The aligned embeddings.
"""
ntc_idxs = np.where(metadata[pert_col].values == control_key)[0]
ntc_center = embeddings[ntc_idxs].mean(0)
return embeddings - ntc_center
# return embeddings - embeddings[metadata[pert_col].values == control_key].mean(0)
ss = StandardScaler()
ss.fit(embeddings[metadata[pert_col].values == control_key])
return ss.transform(embeddings)
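
Functionally, the new version standardizes each dimension by the control population's mean and standard deviation rather than only subtracting the control mean; an equivalent numpy sketch (np.std defaults to ddof=0, matching StandardScaler):

ctrl = embeddings[metadata[pert_col].values == control_key]
aligned = (embeddings - ctrl.mean(axis=0)) / ctrl.std(axis=0)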


def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
def aggregate(
embeddings, metadata, pert_col=cst.REPLOGLE_PERT_LABEL_COL, control_key=cst.CONTROL_PERT_LABEL, method="mean"
):
"""
Applies the mean aggregation to aggregate replicate embeddings for each perturbation.
Apply the mean or median aggregation to replicate embeddings for each perturbation.

Args:
embeddings (numpy.ndarray): The embeddings to be aggregated.
metadata (pandas.DataFrame): The metadata containing information about the embeddings.
control_key (str, optional): The key for non-targeting controls in the metadata. Defaults to "non-targeting".
PERT_COL (str, optional): The column in the metadata containing perturbation information. Defaults to "gene".
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
control_key (str, optional): The key for non-targeting controls in the metadata.
Defaults to cst.CONTROL_PERT_LABEL.
method (str, optional): The aggregation method to use. Must be either "mean" or "median".
Defaults to "mean".

Returns:
Bunch: A named tuple containing two pandas DataFrames:
@@ -81,5 +81,10 @@ def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL,
final_embeddings = np.zeros((len(unique_perts), embeddings.shape[1]))
for i, pert in enumerate(unique_perts):
idxs = np.where(metadata[pert_col].values == pert)[0]
final_embeddings[i, :] = embeddings[idxs, :].mean(0)
if method == "mean":
final_embeddings[i, :] = embeddings[idxs, :].mean(0)
elif method == "median":
final_embeddings[i, :] = np.median(embeddings[idxs, :], axis=0)
else:
raise ValueError(f"Invalid aggregation method: {method}")
return Bunch(features=pd.DataFrame(final_embeddings), metadata=pd.DataFrame.from_dict({pert_col: unique_perts}))
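
Putting the pieces together, a hypothetical end-to-end pass over the JUMP CRISPR data using the functions touched by this PR; the control label and parameter choices are assumptions, not prescribed by the diff:

import efaar_benchmarking.constants as cst
from efaar_benchmarking.benchmarking import benchmark
from efaar_benchmarking.data_loading import load_cpg16_crispr
from efaar_benchmarking.efaar import aggregate, align_on_controls, embed_align_by_pca

features, metadata = load_cpg16_crispr(data_path="data/")
emb = embed_align_by_pca(features.to_numpy(), metadata, variance_or_ncomp=100, plate_col=cst.JUMP_PLATE_COL)
emb = align_on_controls(emb, metadata, pert_col=cst.JUMP_PERT_LABEL_COL)  # assumes "non-targeting" control labels
map_data = aggregate(emb, metadata, pert_col=cst.JUMP_PERT_LABEL_COL, method="mean")
metrics = benchmark(map_data, pert_label_col=cst.JUMP_PERT_LABEL_COL)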
1 change: 0 additions & 1 deletion efaar_benchmarking/plotting.py
@@ -49,7 +49,6 @@ def plot_recall(df):
curr_ax.set_ylabel("Recall Value")
curr_ax.set_xticks(range(len(x_values_orig)))
curr_ax.set_xticklabels(x_values_orig, rotation=45)
curr_ax.set_ylim(0, 0.8)

plt.tight_layout()
plt.show()