Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 9, 2024
1 parent 2e9cde0 commit 8c9da93
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 45 deletions.
2 changes: 1 addition & 1 deletion src/scirpy/pl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from ._clonotypes import COLORMAP_EDGES, clonotype_network
from ._diversity import alpha_diversity
from ._group_abundance import group_abundance
from ._logoplots import logoplot_cdr3_motif
from ._repertoire_overlap import repertoire_overlap
from ._spectratype import spectratype
from ._vdj_usage import vdj_usage
from ._logoplots import logoplot_cdr3_motif
from .base import embedding
102 changes: 58 additions & 44 deletions src/scirpy/pl/_logoplots.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from scirpy.util import DataHandler
from typing import Callable, Literal, Union
from collections.abc import Sequence
from scirpy.get import obs_context
from scirpy.get import airr as get_airr

import numpy as np
import pandas as pd
from typing import Literal, Union

import palmotif as palm
import pandas as pd
from IPython.display import SVG

from scirpy.get import airr as get_airr
from scirpy.util import DataHandler


@DataHandler.inject_param_docs()
def logoplot_cdr3_motif(
adata: DataHandler.TYPE,
chains: Union[
Literal["VJ_1", "VDJ_1", "VJ_2", "VDJ_2"],
Sequence[Literal["VJ_1", "VDJ_1", "VJ_2", "VDJ_2"]],
Literal["VJ_1", "VDJ_1", "VJ_2", "VDJ_2"],
Sequence[Literal["VJ_1", "VDJ_1", "VJ_2", "VDJ_2"]],
] = "VDJ_1",
airr_mod="airr",
airr_key="airr",
Expand All @@ -29,8 +28,9 @@ def logoplot_cdr3_motif(
clonotype_key: Union[None, str] = None,
cdr_len: int,
plot: bool = True,
color_scheme: Sequence[Literal["nucleotide", "base_pairing", "hydrophobicity", "chemistry", "charge", "taylor",
"logojs", "shapely"]] = "taylor"
color_scheme: Sequence[
Literal["nucleotide", "base_pairing", "hydrophobicity", "chemistry", "charge", "taylor", "logojs", "shapely"]
] = "taylor",
):
"""
A user friendly wrapper function for the palmotif python package.
Expand Down Expand Up @@ -68,38 +68,42 @@ def logoplot_cdr3_motif(
set to false to retrieve the raw sequence motif for customised use
color_scheme
different color schemes used by palmotif. see https://github.com/agartland/palmotif/blob/master/palmotif/aacolors.py for more details
Returns
----------
-------
Depending on `plot` either returns a SVG object or the calculated sequence motif as a pd.DataFrame
"""
params = DataHandler(adata, airr_mod, airr_key, chain_idx_key)

if by is "length":

if by == "length":
airr_df = get_airr(params, [cdr3_col], chains)
if type(chains) is list:
if len(chains) > 2:
raise Exception("Only two different chains are allowed e.g. VDJ_1 and VDJ_2")

else:
cdr3_list = airr_df[airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len][chains[0] + "_" + cdr3_col].to_list()
cdr3_list += airr_df[airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len][chains[1] + "_" + cdr3_col].to_list()
cdr3_list = airr_df[airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len][
chains[0] + "_" + cdr3_col
].to_list()
cdr3_list += airr_df[airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len][
chains[1] + "_" + cdr3_col
].to_list()
motif = palm.compute_motif(cdr3_list)
else:
motif = palm.compute_motif(
airr_df[airr_df[chains + "_" + cdr3_col].str.len() == cdr_len][chains + "_" + cdr3_col].to_list()
)
)

if plot:
return SVG(palm.svg_logo(motif, return_str = False, color_scheme=color_scheme))
return SVG(palm.svg_logo(motif, return_str=False, color_scheme=color_scheme))
else:
return motif


if by is "gene_segment":
if by == "gene_segment":
if target_col is None or gene_annotation is None:
raise Exception("Please specify where the gene information is stored (`target_col`) and which genes to include (`gene_annotation`) as a list")
raise Exception(
"Please specify where the gene information is stored (`target_col`) and which genes to include (`gene_annotation`) as a list"
)
if type(gene_annotation) is not list:
gene_annotation = list(gene_annotation.split(" "))

Expand All @@ -108,27 +112,35 @@ def logoplot_cdr3_motif(
if len(chains) > 2:
raise Exception("Only two different chains are allowed e.g. VDJ_1 and VDJ_2")

cdr3_list = airr_df[(airr_df[chains[0] + "_" + target_col].isin(gene_annotation)) &
(airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len)][chains[0] + "_" + cdr3_col].to_list()
cdr3_list += airr_df[(airr_df[chains[1] + "_" + target_col].isin(gene_annotation)) &
(airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len)][chains[1] + "_" + cdr3_col].to_list()
cdr3_list = airr_df[
(airr_df[chains[0] + "_" + target_col].isin(gene_annotation))
& (airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len)
][chains[0] + "_" + cdr3_col].to_list()
cdr3_list += airr_df[
(airr_df[chains[1] + "_" + target_col].isin(gene_annotation))
& (airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len)
][chains[1] + "_" + cdr3_col].to_list()
motif = palm.compute_motif(cdr3_list)

else:
motif = palm.compute_motif(
airr_df[(airr_df[chains + "_" + target_col].isin(gene_annotation)) &
(airr_df[chains + "_" + cdr3_col].str.len() == cdr_len)][chains + "_" + cdr3_col].to_list()
)
airr_df[
(airr_df[chains + "_" + target_col].isin(gene_annotation))
& (airr_df[chains + "_" + cdr3_col].str.len() == cdr_len)
][chains + "_" + cdr3_col].to_list()
)

if plot:
return SVG(palm.svg_logo(motif, return_str = False, color_scheme=color_scheme))
return SVG(palm.svg_logo(motif, return_str=False, color_scheme=color_scheme))
else:
return motif
if by is "clonotype":

if by == "clonotype":
if clonotype_id is None or clonotype_key is None:
raise Exception("Please select desired clonotype cluster and the name of the column where this information is stored!")

raise Exception(
"Please select desired clonotype cluster and the name of the column where this information is stored!"
)

if type(clonotype_id) is not list:
clonotype_id = list(clonotype_id.split(" "))

Expand All @@ -139,25 +151,27 @@ def logoplot_cdr3_motif(
airr_df = pd.concat([airr_df, params.get_obs(clonotype_key)])
airr_df = airr_df.loc[params.get_obs(clonotype_key).isin(clonotype_id)]



if type(chains) is list:
if len(chains) > 2:
raise Exception("Only two different chains are allowed e.g. VDJ_1 and VDJ_2")

else:
cdr3_list = airr_df[airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len][chains[0] + "_" + cdr3_col].to_list()
cdr3_list += airr_df[airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len][chains[1] + "_" + cdr3_col].to_list()
cdr3_list = airr_df[airr_df[chains[0] + "_" + cdr3_col].str.len() == cdr_len][
chains[0] + "_" + cdr3_col
].to_list()
cdr3_list += airr_df[airr_df[chains[1] + "_" + cdr3_col].str.len() == cdr_len][
chains[1] + "_" + cdr3_col
].to_list()
motif = palm.compute_motif(cdr3_list)
else:
motif = palm.compute_motif(
airr_df[airr_df[chains + "_" + cdr3_col].str.len() == cdr_len][chains + "_" + cdr3_col].to_list()
)
)

if plot:
return SVG(palm.svg_logo(motif, return_str = False, color_scheme=color_scheme))
return SVG(palm.svg_logo(motif, return_str=False, color_scheme=color_scheme))
else:
return motif

else:
raise Exception("Invalid input for parameter `by`!")
raise Exception("Invalid input for parameter `by`!")

0 comments on commit 8c9da93

Please sign in to comment.