Skip to content

Commit

Permalink
add tvn alignment
Browse files Browse the repository at this point in the history
  • Loading branch information
safiyecelik committed Dec 21, 2023
1 parent af78cee commit 7b92f4d
Showing 1 changed file with 45 additions and 21 deletions.
66 changes: 45 additions & 21 deletions efaar_benchmarking/efaar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.linalg as linalg
from scvi.model import SCVI
from sklearn.covariance import EllipticEnvelope
from sklearn.decomposition import PCA
Expand Down Expand Up @@ -50,9 +51,11 @@ def embed_by_pca_anndata(adata, n_latent: int = 100) -> np.ndarray:
return adata.obsm["X_pca"]


def centerscale(features: np.ndarray, metadata: pd.DataFrame = None, plate_col: Optional[str] = None) -> np.ndarray:
def centerscale_on_plate(
features: np.ndarray, metadata: pd.DataFrame = None, plate_col: Optional[str] = None
) -> np.ndarray:
"""
Center and scale the input features.
Center and scale the input features based on the plate information.
Args:
features (np.ndarray): Input features to be centered and scaled.
Expand All @@ -74,14 +77,14 @@ def centerscale(features: np.ndarray, metadata: pd.DataFrame = None, plate_col:
return features


def embed_align_by_pca(
def embed_by_pca(
features: np.ndarray,
metadata: pd.DataFrame = None,
variance_or_ncomp=100,
plate_col: Optional[str] = None,
) -> np.ndarray:
"""
Embed the input data using principal component analysis (PCA).
Embed the whole input data using principal component analysis (PCA).
Note that we explicitly center & scale the data by plate before and after calling `PCA`.
Centering and scaling is done by plate if `plate_col` is not None, and on the whole data otherwise.
Note that `PCA` transformer also does mean-centering on the whole data prior to the PCA operation.
Expand All @@ -98,45 +101,69 @@ def embed_align_by_pca(
np.ndarray: Transformed data using PCA.
"""

features = centerscale(features, metadata, plate_col)
features = centerscale_on_plate(features, metadata, plate_col)
features = PCA(variance_or_ncomp).fit_transform(features)
features = centerscale(features, metadata, plate_col)
features = centerscale_on_plate(features, metadata, plate_col)

return features


def align_on_controls(
def centerscale_on_controls(
embeddings: np.ndarray,
metadata: pd.DataFrame,
pert_col: str,
control_key: str,
scale: bool = True,
pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
control_key: str = cst.REPLOGLE_CONTROL_PERT_LABEL,
) -> np.ndarray:
"""
Center the embeddings by the control perturbation units in the metadata.
Center and scale the embeddings by the control perturbation units in the metadata.
Args:
embeddings (numpy.ndarray): The embeddings to be aligned.
metadata (pandas.DataFrame): The metadata containing information about the embeddings.
scale (bool): Whether to scale the embeddings besides centering. Defaults to True.
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
control_key (str, optional): The key for non-targeting controls in the metadata.
Defaults to cst.CONTROL_PERT_LABEL.
scale (bool): Whether to scale the embeddings besides centering. Defaults to True.
Returns:
numpy.ndarray: The aligned embeddings.
"""
control_ind = metadata[pert_col] == control_key
ss = StandardScaler() if scale else StandardScaler(with_std=False)
ss.fit(embeddings[metadata[pert_col].values == control_key])
return ss.transform(embeddings)
return ss.fit(embeddings[control_ind]).transform(embeddings)


def tvn_on_controls(
embeddings: np.ndarray,
metadata: pd.DataFrame,
pert_col: str,
control_key: str,
) -> np.ndarray:
"""
Apply TVN (Typical Variation Normalization) to the data based on the control perturbation units.
Args:
embeddings (np.ndarray): The embeddings to be normalized.
metadata (pd.DataFrame): The metadata containing information about the samples.
pert_col (str, optional): The column name in the metadata DataFrame that represents the perturbation labels.
control_key (str, optional): The control perturbation label.
Returns:
np.ndarray: The normalized embeddings.
"""
ctrl_ind = metadata[pert_col] == control_key
embeddings = PCA().fit(embeddings[ctrl_ind]).transform(embeddings)
embeddings = centerscale_on_controls(embeddings, metadata, pert_col, control_key)
source_cov = np.cov(embeddings[ctrl_ind], rowvar=False, ddof=1) + 0.5 * np.eye(embeddings.shape[1])
source_cov_half_inv = linalg.fractional_matrix_power(source_cov, -0.5)
return np.matmul(embeddings, source_cov_half_inv)


def aggregate(
embeddings: np.ndarray,
metadata: pd.DataFrame,
pert_col: str = cst.REPLOGLE_PERT_LABEL_COL,
control_key: str = cst.REPLOGLE_CONTROL_PERT_LABEL,
pert_col: str,
control_key: str,
method="mean",
) -> Bunch[pd.DataFrame, pd.DataFrame]:
"""
Expand All @@ -146,9 +173,7 @@ def aggregate(
embeddings (numpy.ndarray): The embeddings to be aggregated.
metadata (pandas.DataFrame): The metadata containing information about the embeddings.
pert_col (str, optional): The column in the metadata containing perturbation information.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
control_key (str, optional): The key for non-targeting controls in the metadata.
Defaults to cst.CONTROL_PERT_LABEL.
method (str, optional): The aggregation method to use. Must be either "mean" or "median".
Defaults to "mean".
Expand All @@ -172,7 +197,7 @@ def aggregate(


def filter_to_perturbations(
features: pd.DataFrame, metadata: pd.DataFrame, perts: list[str], pert_col: str = cst.REPLOGLE_PERT_LABEL_COL
features: pd.DataFrame, metadata: pd.DataFrame, perts: list[str], pert_col: str
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""
Filters the features and metadata dataframes based on a list of perturbations.
Expand All @@ -182,7 +207,6 @@ def filter_to_perturbations(
metadata (pd.DataFrame): The metadata dataframe.
perts (list[str]): A list of perturbations to filter.
pert_col (str, optional): The column name in the metadata dataframe that contains the perturbation labels.
Defaults to cst.REPLOGLE_PERT_LABEL_COL.
Returns:
tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the filtered features and metadata dataframes.
Expand Down

0 comments on commit 7b92f4d

Please sign in to comment.