Merge trunk #25

Merged · 2 commits · Nov 14, 2023
12 changes: 7 additions & 5 deletions efaar_benchmarking/benchmarking.py
@@ -44,7 +44,7 @@ def pert_stats(
def benchmark(
    map_data: Bunch,
    benchmark_sources: list = cst.BENCHMARK_SOURCES,
-   pert_label_col: str = cst.PERT_LABEL_COL,
+   pert_label_col: str = cst.REPLOGLE_PERT_LABEL_COL,
    recall_thr_pairs: list = cst.RECALL_PERC_THRS,
    filter_on_pert_prints: bool = False,
    pert_pval_thr: float = cst.PERT_SIG_PVAL_THR,
@@ -78,15 +78,15 @@ def benchmark(
"""

if not len(benchmark_sources) > 0 and all([src in benchmark_data_dir for src in benchmark_sources]):
AssertionError("Invalid benchmark source(s) provided.")
ValueError("Invalid benchmark source(s) provided.")
md = map_data.metadata
idx = (md[cst.PERT_SIG_PVAL_COL] <= pert_pval_thr) if filter_on_pert_prints else [True] * len(md)
features = map_data.features[idx].set_index(md[idx][pert_label_col]).rename_axis(index=None)
del map_data
if not len(features) == len(set(features.index)):
AssertionError("Duplicate perturbation labels in the map.")
ValueError("Duplicate perturbation labels in the map.")
if not len(features) >= min_req_entity_cnt:
AssertionError("Not enough entities in the map for benchmarking.")
ValueError("Not enough entities in the map for benchmarking.")
print(len(features), "perturbations exist in the map.")

    metrics_lst = []
Expand All @@ -98,7 +98,9 @@ def benchmark(
        random_seed_str = f"{rs1}_{rs2}"
        null_cossim = generate_null_cossims(features, n_null_samples, rs1, rs2)
        for s in benchmark_sources:
-           query_cossim = generate_query_cossims(features, get_benchmark_relationships(benchmark_data_dir, s))
+           rels = get_benchmark_relationships(benchmark_data_dir, s)
+           print(len(rels), "relationships exist in the benchmark source.")
+           query_cossim = generate_query_cossims(features, rels)
            if len(query_cossim) > 0:
                metrics_lst.append(
                    convert_metrics_to_df(
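Note for readers of this hunk: both before and after the change, the exception objects are constructed but never raised, so none of these checks can actually abort the run. The first condition also parses as `(not (len(benchmark_sources) > 0)) and all(...)`, which is false whenever sources are provided. A hedged sketch of what the validation presumably intends (an editorial reconstruction, not what the PR ships):

if not (len(benchmark_sources) > 0 and all(src in benchmark_data_dir for src in benchmark_sources)):
    raise ValueError("Invalid benchmark source(s) provided.")  # raise, not just construct
if len(features) != len(set(features.index)):
    raise ValueError("Duplicate perturbation labels in the map.")
if len(features) < min_req_entity_cnt:
    raise ValueError("Not enough entities in the map for benchmarking.")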
12 changes: 8 additions & 4 deletions efaar_benchmarking/constants.py
@@ -2,12 +2,16 @@

BENCHMARK_DATA_DIR = str(resources.files("efaar_benchmarking").joinpath("benchmark_annotations"))
BENCHMARK_SOURCES = ["Reactome", "HuMAP", "CORUM", "SIGNOR", "StringDB"]
- PERT_LABEL_COL = "gene"
- CONTROL_PERT_LABEL = "non-targeting"
- PERT_SIG_PVAL_COL = "gene_pvalue"
- PERT_SIG_PVAL_THR = 0.01
RECALL_PERC_THRS = [(0.05, 0.95), (0.1, 0.9)]
RANDOM_SEED = 42
RANDOM_COUNT = 3
N_NULL_SAMPLES = 5000
MIN_REQ_ENT_CNT = 20
+ PERT_SIG_PVAL_COL = "gene_pvalue"
+ PERT_SIG_PVAL_THR = 0.01
+ CONTROL_PERT_LABEL = "non-targeting"
+ REPLOGLE_PERT_LABEL_COL = "gene"
+ REPLOGLE_BATCH_COL = "gem_group"
+ JUMP_PERT_LABEL_COL = "Metadata_Symbol"
+ JUMP_PLATE_COL = "Metadata_Plate"
+ JUMP_BATCH_COL = "Metadata_Batch"
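
These dataset-specific constants are meant to flow through the rest of the API: a map built from JUMP data overrides the Replogle defaults at the call site. A small hedged sketch (assumes `map_data` is a Bunch produced by the aggregation step below):

import efaar_benchmarking.constants as cst
from efaar_benchmarking.benchmarking import benchmark

# JUMP keys perturbations on gene symbol rather than the Replogle "gene" column.
metrics = benchmark(map_data, pert_label_col=cst.JUMP_PERT_LABEL_COL)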
82 changes: 80 additions & 2 deletions efaar_benchmarking/data_loading.py
@@ -1,17 +1,95 @@
import os
import shutil
import tempfile
import zipfile
from concurrent.futures import ThreadPoolExecutor, as_completed

import numpy as np
import pandas as pd
import scanpy as sc
import wget


- def load_replogle(gene_type, data_type, data_path="data/"):
+ def load_cpg16_crispr(data_path: str = "data/") -> tuple[pd.DataFrame, pd.DataFrame]:
    """
-   Load Replogle et al. 2020 single-cell RNA-seq data for K562 cells.
+   Load and return the JUMP-CP (cpg0016) CRISPR dataset.
    The metadata is downloaded from here:
    https://zenodo.org/records/7661296/files/jump-cellpainting/metadata-v0.5.0.zip?download=1
    The cellprofiler features are downloaded from here:
    https://cellpainting-gallery.s3.amazonaws.com/index.html#cpg0016-jump/
    We read the metadata first, filter it to CRISPR plates, and download the features for these plates only.

    Parameters:
        data_path (str): Path to the directory containing the dataset files.

    Returns:
        tuple[pd.DataFrame, pd.DataFrame]: A tuple containing two DataFrames:
            - features: A DataFrame containing the CRISPR dataset features.
            - metadata: A DataFrame containing the CRISPR dataset metadata.
    """
    metadata_source_path = "https://zenodo.org/records/7661296/files/jump-cellpainting/datasets-v0.5.0.zip?download=1"
    plate_file_name = "plate.csv.gz"
    well_file_name = "well.csv.gz"
    crispr_file_name = "crispr.csv.gz"
    plate_file_path = os.path.join(data_path, plate_file_name)
    well_file_path = os.path.join(data_path, well_file_name)
    crispr_file_path = os.path.join(data_path, crispr_file_name)
    if not (os.path.exists(plate_file_path) and os.path.exists(well_file_path) and os.path.exists(crispr_file_path)):
        path_to_zip_file = data_path + "tmp.zip"
        wget.download(metadata_source_path, path_to_zip_file)
        with zipfile.ZipFile(path_to_zip_file, "r") as zip_ref:
            for name in zip_ref.namelist():
                if name.endswith(plate_file_name) or name.endswith(well_file_name) or name.endswith(crispr_file_name):
                    with tempfile.TemporaryDirectory() as temp_dir:
                        zip_ref.extract(name, temp_dir)
                        shutil.move(os.path.join(temp_dir, name), os.path.join(data_path, os.path.basename(name)))
        os.remove(path_to_zip_file)

    plates = pd.read_csv(plate_file_path)
    crispr_plates = plates.query("Metadata_PlateType=='CRISPR'")
    wells = pd.read_csv(well_file_path)
    well_plate = wells.merge(crispr_plates, on=["Metadata_Source", "Metadata_Plate"])
    crispr = pd.read_csv(crispr_file_path)
    metadata = well_plate.merge(crispr, on="Metadata_JCP2022")

    cp_feature_source_formatter = (
        "s3://cellpainting-gallery/cpg0016-jump/"
        "{Metadata_Source}/workspace/profiles/"
        "{Metadata_Batch}/{Metadata_Plate}/{Metadata_Plate}.parquet"
    )

    features_file_path = os.path.join(data_path, "cpg_features.parquet")
    if not os.path.exists(features_file_path):
        cripsr_plates = metadata[
            ["Metadata_Source", "Metadata_Batch", "Metadata_Plate", "Metadata_PlateType"]
        ].drop_duplicates()
        features = []
        with ThreadPoolExecutor(max_workers=10) as executer:
            future_to_plate = {
                executer.submit(
                    lambda path: pd.read_parquet(path, storage_options={"anon": True}),
                    cp_feature_source_formatter.format(**row.to_dict()),
                ): cp_feature_source_formatter.format(**row.to_dict())
                for _, row in cripsr_plates.iterrows()
            }
            for future in as_completed(future_to_plate):
                features.append(future.result())
        pd.concat(features).to_parquet(features_file_path)
    features = pd.read_parquet(features_file_path).dropna(axis=1)
    return features, metadata
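
For orientation, a minimal usage sketch of the new loader (assumes network access to Zenodo and the Cell Painting Gallery S3 bucket; the first call downloads, later calls reuse the cached parquet):

from efaar_benchmarking.data_loading import load_cpg16_crispr

# Metadata CSVs land in data/; per-plate profiles are concatenated into data/cpg_features.parquet.
features, metadata = load_cpg16_crispr(data_path="data/")
print(features.shape, metadata.shape)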


def load_replogle(gene_type: str, data_type: str, data_path: str = "data/") -> sc.AnnData:
    """
    Load Replogle et al. 2022 single-cell RNA-seq data for K562 cells published here:
    https://pubmed.ncbi.nlm.nih.gov/35688146/
    Four types of K562 data are available and are downloaded using the links at:
    https://plus.figshare.com/articles/dataset/_Mapping_information-rich_genotype-phenotype_landscapes_with_genome-scale_Perturb-seq_Replogle_et_al_2022_processed_Perturb-seq_datasets/20029387

    Parameters:
        gene_type (str): Type of genes to load. Must be either 'essential' or 'genome_wide'.
        data_type (str): Type of data to load. Must be either 'raw' or 'normalized'.
            Normalized means Z-normalized by gemgroup.
        data_path (str): Path to the directory where the data will be downloaded and saved.

    Returns:
167 changes: 138 additions & 29 deletions efaar_benchmarking/efaar.py
@@ -1,75 +1,139 @@
from typing import Optional

import numpy as np
import pandas as pd
import scanpy as sc
from scvi.model import SCVI
from sklearn.covariance import EllipticEnvelope
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.utils import Bunch

import efaar_benchmarking.constants as cst


- def embed_by_scvi(adata, batch_key="gem_group", n_latent=128, n_hidden=256):
+ def embed_by_scvi_anndata(adata, batch_col=cst.REPLOGLE_BATCH_COL, n_latent=128, n_hidden=256) -> np.ndarray:
    """
-   Embeds the input AnnData object using scVI.
+   Embed the input AnnData object using scVI.

-   Parameters:
-       adata (anndata.AnnData): The AnnData object to be embedded.
-       batch_key (str): The batch key in the AnnData object. Default is "gem_group".
-       n_latent (int): The number of latent dimensions. Default is 128.
-       n_hidden (int): The number of hidden dimensions. Default is 256.
+   Args:
+       adata (anndata.AnnData): The AnnData object to be embedded.
+       batch_col (str): The batch key in the AnnData object. Default is "gem_group".
+       n_latent (int): The number of latent dimensions. Default is 128.
+       n_hidden (int): The number of hidden dimensions. Default is 256.

    Returns:
-       None
+       numpy.ndarray: Embedding of the input data using scVI.
    """
-   SCVI.setup_anndata(adata, batch_key=batch_key)
+   SCVI.setup_anndata(adata, batch_key=batch_col)
    vae = SCVI(adata, n_hidden=n_hidden, n_latent=n_latent)
    vae.train(use_gpu=True)
    return vae.get_latent_representation()


- def embed_by_pca(adata, n_latent=100):
+ def embed_by_pca_anndata(adata, n_latent=100) -> np.ndarray:
    """
-   Embeds the input data using PCA.
+   Embed the input data using principal component analysis (PCA).
+   Note that the data is centered by the `pca` function prior to PCA transformation.

-   Parameters:
-       adata (AnnData): Annotated data matrix.
-       n_latent (int): Number of principal components to use.
+   Args:
+       adata (AnnData): Annotated data matrix.
+       n_latent (int): Number of principal components to use.

    Returns:
-       numpy.ndarray: Embedding of the input data using PCA.
+       numpy.ndarray: Embedding of the input data using PCA.
    """
    sc.pp.pca(adata, n_comps=n_latent)
    return adata.obsm["X_pca"]


- def align_by_centering(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
+ def embed_align_by_pca(
+     features: np.ndarray,
+     metadata: pd.DataFrame = None,
+     variance_or_ncomp=100,
+     plate_col: Optional[str] = None,
+ ) -> np.ndarray:
    """
    Embed the input data using principal component analysis (PCA).
    Note that we explicitly center & scale the data by plate before and after calling `PCA`.
    Centering and scaling is done by plate if `plate_col` is not None, and on the whole data otherwise.
    Note that `PCA` transformer also does mean-centering on the whole data prior to the PCA operation.

    Args:
        features (np.ndarray): Features to transform
        metadata (pd.DataFrame): Metadata. Defaults to None.
        variance_or_ncomp (float, optional): Variance or number of components to keep after PCA.
            Defaults to 100 (n_components). If between 0 and 1, select the number of components such that
            the amount of variance that needs to be explained is greater than the percentage specified.
        plate_col (str, optional): Column name for plate metadata. Defaults to None.

    Returns:
        np.ndarray: Transformed data using PCA.
    """

    def centerscale(features, metadata, plate_col):
        if plate_col is None:
            features = StandardScaler().fit_transform(features)
        else:
            if metadata is None:
                raise ValueError("metadata must be provided if plate_col is not None")
            unq_plates = metadata[plate_col].unique()
            for plate in unq_plates:
                ind = metadata[plate_col] == plate
                features[ind, :] = StandardScaler().fit_transform(features[ind, :])
        return features

    features = centerscale(features, metadata, plate_col)
    features = PCA(variance_or_ncomp).fit_transform(features)
    features = centerscale(features, metadata, plate_col)

    return features
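
A quick usage sketch of embed_align_by_pca with per-plate standardization (toy shapes; the plate column name is the JUMP constant added in this PR):

import numpy as np
import pandas as pd
import efaar_benchmarking.constants as cst
from efaar_benchmarking.efaar import embed_align_by_pca

feats = np.random.default_rng(0).normal(size=(6, 5))  # 6 wells, 5 features
meta = pd.DataFrame({cst.JUMP_PLATE_COL: ["P1", "P1", "P1", "P2", "P2", "P2"]})
emb = embed_align_by_pca(feats, meta, variance_or_ncomp=2, plate_col=cst.JUMP_PLATE_COL)
print(emb.shape)  # (6, 2)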


def align_on_controls(
    embeddings: np.ndarray,
    metadata: pd.DataFrame,
    scale: bool = True,
    pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
    control_key: str = cst.CONTROL_PERT_LABEL,
) -> np.ndarray:
    """
-   Applies the centerscale method to align embeddings based on the centering perturbations in the metadata.
+   Center the embeddings by the control perturbation units in the metadata.

    Args:
        embeddings (numpy.ndarray): The embeddings to be aligned.
        metadata (pandas.DataFrame): The metadata containing information about the embeddings.
+       scale (bool): Whether to scale the embeddings besides centering. Defaults to True.
+       pert_col (str, optional): The column in the metadata containing perturbation information.
+           Defaults to cst.REPLOGLE_PERT_LABEL_COL.
        control_key (str, optional): The key for non-targeting controls in the metadata.
            Defaults to cst.CONTROL_PERT_LABEL.
-       pert_col (str, optional): The column in the metadata containing perturbation information.
-           Defaults to cst.PERT_LABEL_COL.

    Returns:
        numpy.ndarray: The aligned embeddings.
    """
-   ntc_idxs = np.where(metadata[pert_col].values == control_key)[0]
-   ntc_center = embeddings[ntc_idxs].mean(0)
-   return embeddings - ntc_center
+   ss = StandardScaler() if scale else StandardScaler(with_std=False)
+   ss.fit(embeddings[metadata[pert_col].values == control_key])
+   return ss.transform(embeddings)


- def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL, pert_col=cst.PERT_LABEL_COL):
+ def aggregate(
+     embeddings: np.ndarray,
+     metadata: pd.DataFrame,
+     pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
+     control_key: str = cst.CONTROL_PERT_LABEL,
+     method="mean",
+ ) -> Bunch[pd.DataFrame, pd.DataFrame]:
"""
Applies the mean aggregation to aggregate replicate embeddings for each perturbation.
Apply the mean or median aggregation to replicate embeddings for each perturbation.

Args:
embeddings (numpy.ndarray): The embeddings to be aggregated.
metadata (pandas.DataFrame): The metadata containing information about the embeddings.
control_key (str, optional): The key for non-targeting controls in the metadata. Defaults to "non-targeting".
PERT_COL (str, optional): The column in the metadata containing perturbation information. Defaults to "gene".
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
control_key (str, optional): The key for non-targeting controls in the metadata.
Defaults to cst.CONTROL_PERT_LABEL.
method (str, optional): The aggregation method to use. Must be either "mean" or "median".
Defaults to "mean".

Returns:
Bunch: A named tuple containing two pandas DataFrames:
@@ -81,5 +145,50 @@ def aggregate_by_mean(embeddings, metadata, control_key=cst.CONTROL_PERT_LABEL,
    final_embeddings = np.zeros((len(unique_perts), embeddings.shape[1]))
    for i, pert in enumerate(unique_perts):
        idxs = np.where(metadata[pert_col].values == pert)[0]
-       final_embeddings[i, :] = embeddings[idxs, :].mean(0)
+       if method == "mean":
+           final_embeddings[i, :] = np.mean(embeddings[idxs, :], axis=0)
+       elif method == "median":
+           final_embeddings[i, :] = np.median(embeddings[idxs, :], axis=0)
+       else:
+           raise ValueError(f"Invalid aggregation method: {method}")
    return Bunch(features=pd.DataFrame(final_embeddings), metadata=pd.DataFrame.from_dict({pert_col: unique_perts}))
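
Taken together, alignment and aggregation compose into a map; a hedged toy sketch (synthetic embeddings with Replogle-style metadata; actual shapes depend on how aggregate treats the control label):

import numpy as np
import pandas as pd
import efaar_benchmarking.constants as cst
from efaar_benchmarking.efaar import align_on_controls, aggregate

emb = np.random.default_rng(1).normal(size=(8, 4))
meta = pd.DataFrame({cst.REPLOGLE_PERT_LABEL_COL: ["non-targeting"] * 4 + ["GENE1"] * 2 + ["GENE2"] * 2})
aligned = align_on_controls(emb, meta)                # standardize against control units
map_data = aggregate(aligned, meta, method="median")  # aggregated profile per perturbation
print(map_data.features.shape)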


def filter_cpg16_crispr(
    features: pd.DataFrame,
    metadata: pd.DataFrame,
    filter_by_intensity: bool = True,
    filter_by_cell_count: bool = True,
    drop_image_cols: bool = True,
) -> tuple[pd.DataFrame, pd.DataFrame]:
    """
    Filter the given features and metadata dataframes based on various criteria.

    Args:
        features (pd.DataFrame): A dataframe containing the features to filter.
        metadata (pd.DataFrame): A dataframe containing the metadata to filter.
        filter_by_intensity (bool, optional): Whether to filter by intensity. Defaults to True.
        filter_by_cell_count (bool, optional): Whether to filter by cell count. Defaults to True.
        drop_image_cols (bool, optional): Whether to drop image columns. Defaults to True.

    Returns:
        Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the filtered features and metadata dataframes.
    """
    if filter_by_intensity:
        features = features.loc[
            EllipticEnvelope(contamination=0.01, random_state=42).fit_predict(
                features[[col for col in features.columns if "ImageQuality_MeanIntensity" in col]]
            )
            == 1
        ]
    if filter_by_cell_count:
        mask = np.full(len(features), True)
        for colname in ["Cytoplasm_Number_Object_Number", "Nuclei_Number_Object_Number"]:
            mask = mask & (features[colname] >= 50) & (features[colname] <= 350)
        features = features.loc[mask]
    if drop_image_cols:
        features = features.drop(columns=[col for col in features.columns if col.startswith("Image_")])

    metadata_cols = metadata.columns
    merged_data = metadata.merge(features, on=["Metadata_Source", "Metadata_Plate", "Metadata_Well"])
    return merged_data.drop(columns=metadata_cols), merged_data[metadata_cols]
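
The loader and this filter are meant to compose; a short hedged sketch (assumes `features` and `metadata` as returned by load_cpg16_crispr above):

from efaar_benchmarking.efaar import filter_cpg16_crispr

filtered_features, filtered_metadata = filter_cpg16_crispr(features, metadata)
print(len(filtered_features), "wells retained after filtering")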
1 change: 0 additions & 1 deletion efaar_benchmarking/plotting.py
@@ -49,7 +49,6 @@ def plot_recall(df):
        curr_ax.set_ylabel("Recall Value")
        curr_ax.set_xticks(range(len(x_values_orig)))
        curr_ax.set_xticklabels(x_values_orig, rotation=45)
-       curr_ax.set_ylim(0, 0.8)

    plt.tight_layout()
    plt.show()
Binary file removed notebooks/data/crispr.csv.gz
Binary file removed notebooks/data/plate.csv.gz
Binary file removed notebooks/data/well.csv.gz