From 7b92f4d14eba6638935575562c3c5d46e78930b6 Mon Sep 17 00:00:00 2001 From: Safiye Celik Date: Thu, 21 Dec 2023 10:42:38 -0500 Subject: [PATCH] add tvn alignment --- efaar_benchmarking/efaar.py | 66 +++++++++++++++++++++++++------------ 1 file changed, 45 insertions(+), 21 deletions(-) diff --git a/efaar_benchmarking/efaar.py b/efaar_benchmarking/efaar.py index 2a64c73..c4273d9 100644 --- a/efaar_benchmarking/efaar.py +++ b/efaar_benchmarking/efaar.py @@ -3,6 +3,7 @@ import numpy as np import pandas as pd import scanpy as sc +import scipy.linalg as linalg from scvi.model import SCVI from sklearn.covariance import EllipticEnvelope from sklearn.decomposition import PCA @@ -50,9 +51,11 @@ def embed_by_pca_anndata(adata, n_latent: int = 100) -> np.ndarray: return adata.obsm["X_pca"] -def centerscale(features: np.ndarray, metadata: pd.DataFrame = None, plate_col: Optional[str] = None) -> np.ndarray: +def centerscale_on_plate( + features: np.ndarray, metadata: pd.DataFrame = None, plate_col: Optional[str] = None +) -> np.ndarray: """ - Center and scale the input features. + Center and scale the input features based on the plate information. Args: features (np.ndarray): Input features to be centered and scaled. @@ -74,14 +77,14 @@ def centerscale(features: np.ndarray, metadata: pd.DataFrame = None, plate_col: return features -def embed_align_by_pca( +def embed_by_pca( features: np.ndarray, metadata: pd.DataFrame = None, variance_or_ncomp=100, plate_col: Optional[str] = None, ) -> np.ndarray: """ - Embed the input data using principal component analysis (PCA). + Embed the whole input data using principal component analysis (PCA). Note that we explicitly center & scale the data by plate before and after calling `PCA`. Centering and scaling is done by plate if `plate_col` is not None, and on the whole data otherwise. Note that `PCA` transformer also does mean-centering on the whole data prior to the PCA operation. @@ -98,45 +101,69 @@ def embed_align_by_pca( np.ndarray: Transformed data using PCA. """ - features = centerscale(features, metadata, plate_col) + features = centerscale_on_plate(features, metadata, plate_col) features = PCA(variance_or_ncomp).fit_transform(features) - features = centerscale(features, metadata, plate_col) + features = centerscale_on_plate(features, metadata, plate_col) return features -def align_on_controls( +def centerscale_on_controls( embeddings: np.ndarray, metadata: pd.DataFrame, + pert_col: str, + control_key: str, scale: bool = True, - pert_col: str = cst.REPLOGLE_PERT_LABEL_COL, - control_key: str = cst.REPLOGLE_CONTROL_PERT_LABEL, ) -> np.ndarray: """ - Center the embeddings by the control perturbation units in the metadata. + Center and scale the embeddings by the control perturbation units in the metadata. Args: embeddings (numpy.ndarray): The embeddings to be aligned. metadata (pandas.DataFrame): The metadata containing information about the embeddings. - scale (bool): Whether to scale the embeddings besides centering. Defaults to True. pert_col (str, optional): The column in the metadata containing perturbation information. - Defaults to cst.REPLOGLE_PERT_LABEL_COL. control_key (str, optional): The key for non-targeting controls in the metadata. - Defaults to cst.CONTROL_PERT_LABEL. + scale (bool): Whether to scale the embeddings besides centering. Defaults to True. Returns: numpy.ndarray: The aligned embeddings. """ + control_ind = metadata[pert_col] == control_key ss = StandardScaler() if scale else StandardScaler(with_std=False) - ss.fit(embeddings[metadata[pert_col].values == control_key]) - return ss.transform(embeddings) + return ss.fit(embeddings[control_ind]).transform(embeddings) + + +def tvn_on_controls( + embeddings: np.ndarray, + metadata: pd.DataFrame, + pert_col: str, + control_key: str, +) -> np.ndarray: + """ + Apply TVN (Typical Variation Normalization) to the data based on the control perturbation units. + + Args: + embeddings (np.ndarray): The embeddings to be normalized. + metadata (pd.DataFrame): The metadata containing information about the samples. + pert_col (str, optional): The column name in the metadata DataFrame that represents the perturbation labels. + control_key (str, optional): The control perturbation label. + + Returns: + np.ndarray: The normalized embeddings. + """ + ctrl_ind = metadata[pert_col] == control_key + embeddings = PCA().fit(embeddings[ctrl_ind]).transform(embeddings) + embeddings = centerscale_on_controls(embeddings, metadata, pert_col, control_key) + source_cov = np.cov(embeddings[ctrl_ind], rowvar=False, ddof=1) + 0.5 * np.eye(embeddings.shape[1]) + source_cov_half_inv = linalg.fractional_matrix_power(source_cov, -0.5) + return np.matmul(embeddings, source_cov_half_inv) def aggregate( embeddings: np.ndarray, metadata: pd.DataFrame, - pert_col: str = cst.REPLOGLE_PERT_LABEL_COL, - control_key: str = cst.REPLOGLE_CONTROL_PERT_LABEL, + pert_col: str, + control_key: str, method="mean", ) -> Bunch[pd.DataFrame, pd.DataFrame]: """ @@ -146,9 +173,7 @@ def aggregate( embeddings (numpy.ndarray): The embeddings to be aggregated. metadata (pandas.DataFrame): The metadata containing information about the embeddings. pert_col (str, optional): The column in the metadata containing perturbation information. - Defaults to cst.REPLOGLE_PERT_LABEL_COL. control_key (str, optional): The key for non-targeting controls in the metadata. - Defaults to cst.CONTROL_PERT_LABEL. method (str, optional): The aggregation method to use. Must be either "mean" or "median". Defaults to "mean". @@ -172,7 +197,7 @@ def aggregate( def filter_to_perturbations( - features: pd.DataFrame, metadata: pd.DataFrame, perts: list[str], pert_col: str = cst.REPLOGLE_PERT_LABEL_COL + features: pd.DataFrame, metadata: pd.DataFrame, perts: list[str], pert_col: str ) -> tuple[pd.DataFrame, pd.DataFrame]: """ Filters the features and metadata dataframes based on a list of perturbations. @@ -182,7 +207,6 @@ def filter_to_perturbations( metadata (pd.DataFrame): The metadata dataframe. perts (list[str]): A list of perturbations to filter. pert_col (str, optional): The column name in the metadata dataframe that contains the perturbation labels. - Defaults to cst.REPLOGLE_PERT_LABEL_COL. Returns: tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the filtered features and metadata dataframes.