Speed up clonotype distance calculation (#470)
* compute_distances new version added

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* data type for csr max computation set to int16 to allow negative values (a sketch of the int16 rationale follows the changed-files summary below)

* data type also for csr_min set to int16 to allow negative values

* raise AssertionError instead of assert False

* matrix shape changed in lookup function

* array size changed for v-gene and column matching

* naming conventions

* use self._chain_count2 for asymmetric matrix for filter_chain_count

* reverse table assignment adjusted in lookup function

* changed AssertionError to NotImplementedError

* removed unnecessary for loop

* refactored lookup function

* refactored lookup function and added docstring

* adapted docstring of lookup function

* j_gene matching implemented

* fixes typo in j_gene parameter

* bind loop variables in function match_gene_segment

* set data types of csr matrix arrays explicitly

* docstring and data type checks added for csr_min and csr_max

* Changed data types of indices and indptr arrays for the csr matrices in _dist_for_clonotype and lookup to int64 instead of int32.

* Changed data type checks to max value checks

* documentation adaptations

* update pre-commit config

* reformat

* deleted print statement

* Implemented support for the graph partitioning method "fastgreedy"

* test case for j_gene added

* refactored filter_chain_count and added documentation

* documentation and refactoring

* documentation

* Rerun tutorial

* Update CHANGELOG

* Attempt to fix conda CI

* update conda CI

* Fix python version conda ci

* override python version

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gregor Sturm <[email protected]>
3 people authored Sep 8, 2024
1 parent 83e0961 commit cbc01d1
Showing 8 changed files with 402 additions and 237 deletions.
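Two of the commits above set the data type for the `csr_max`/`csr_min` computation to int16 so that negative values are allowed. A minimal sketch of why a signed intermediate type matters when taking an elementwise maximum of uint8-valued CSR matrices (the helper name and the clipped-difference formulation here are illustrative assumptions, not scirpy's actual implementation):

```python
import numpy as np
import scipy.sparse as sp

def csr_max(a: sp.csr_matrix, b: sp.csr_matrix) -> sp.csr_matrix:
    """Elementwise max via max(a, b) = b + clip(a - b, 0, inf) (illustrative).

    The difference a - b can be negative, so the uint8 values are widened
    to a signed type (int16) before subtracting.
    """
    a16, b16 = a.astype(np.int16), b.astype(np.int16)
    diff = (a16 - b16).maximum(0)  # negative entries clipped to zero
    return (b16 + diff).astype(np.uint8)

a = sp.csr_matrix(np.array([[1, 0], [3, 2]], dtype=np.uint8))
b = sp.csr_matrix(np.array([[2, 0], [1, 2]], dtype=np.uint8))
print(csr_max(a, b).toarray())  # [[2 0], [3 2]]
```

scipy.sparse matrices also provide built-in `.maximum()`/`.minimum()` methods; a dedicated helper with explicit dtype control is presumably what the commits refer to.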
13 changes: 6 additions & 7 deletions .github/workflows/conda.yaml
@@ -22,27 +22,26 @@ jobs:
       matrix:
         include:
           - os: ubuntu-latest
-            python: "3.12"
+            python: "3.11"
 
     env:
       OS: ${{ matrix.os }}
       PYTHON: ${{ matrix.python }}
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
 
       - name: Setup Miniconda
-        uses: conda-incubator/setup-miniconda@v2
+        uses: conda-incubator/setup-miniconda@v3
         with:
           miniforge-variant: Mambaforge
           miniforge-version: latest
           mamba-version: "*"
           channels: conda-forge,bioconda
           channel-priority: strict
-          python-version: ${{ matrix.python-version }}
+          python-version: ${{ matrix.python }}
 
       - name: install conda build
         run: |
-          mamba install -y boa conda-verify
+          mamba install -y boa conda-verify python=${{ matrix.python }}
         shell: bash
 
       - name: build and test package
10 changes: 6 additions & 4 deletions CHANGELOG.md
@@ -8,20 +8,22 @@ and this project adheres to [Semantic Versioning][].
 [keep a changelog]: https://keepachangelog.com/en/1.0.0/
 [semantic versioning]: https://semver.org/spec/v2.0.0.html
 
-## [Unreleased]
+## v0.18.0
 
 ### Additions
 
 - Isotypically included B cells are now labelled as `receptor_subtype="IGH+IGK/L"` instead of `ambiguous` in `tl.chain_qc` ([#537](https://github.com/scverse/scirpy/pull/537)).
 - Added the `normalized_hamming` metric to `pp.ir_dist` that accounts for differences in CDR3 sequence length ([#512](https://github.com/scverse/scirpy/pull/512)).
+- `tl.define_clonotype_clusters` now has an option to require J genes to match (`same_j_gene=True`) in addition to `same_v_gene` ([#470](https://github.com/scverse/scirpy/pull/470)).
 
 ### Performance improvements
 
-- The hamming distance was reimplemented with numba, achieving a significant speedup ([#512](https://github.com/scverse/scirpy/pull/512)).
+- The hamming distance has been reimplemented with numba, achieving a significant speedup ([#512](https://github.com/scverse/scirpy/pull/512)).
+- Clonotype clustering has been accelerated leveraging sparse matrix operations ([#470](https://github.com/scverse/scirpy/pull/470)).
 
 ### Fixes
 
-- Fix that pl.clonotype_network couldn't use non-standard obsm key ([#545](https://github.com/scverse/scirpy/pull/545)).
+- Fix that `pl.clonotype_network` couldn't use non-standard obsm key ([#545](https://github.com/scverse/scirpy/pull/545)).
 
 ### Other changes
 
@@ -54,7 +56,7 @@ and this project adheres to [Semantic Versioning][].
 
 ### Fixes
 
-- Fix issue with detecting the number of available CPUs on MacOD ([#518](https://github.com/scverse/scirpy/pull/502))
+- Fix issue with detecting the number of available CPUs on MacOS ([#518](https://github.com/scverse/scirpy/pull/502))
 
 ## v0.16.1
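The `normalized_hamming` metric and the numba hamming kernel cited above come from #512. A minimal sketch of the idea, assuming equal-length sequences and a distance expressed as a percentage of the sequence length (the actual scirpy kernel works on batches of encoded sequences and differs in detail):

```python
import numba
import numpy as np

@numba.njit(cache=True)
def hamming_mismatches(a: np.ndarray, b: np.ndarray) -> int:
    """Count mismatching positions of two equal-length uint8 arrays."""
    n = 0
    for i in range(a.shape[0]):
        if a[i] != b[i]:
            n += 1
    return n

def normalized_hamming(s1: str, s2: str) -> float:
    """Mismatches as a percentage of sequence length (assumed semantics)."""
    if len(s1) != len(s2):
        raise ValueError("Hamming distance requires equal-length sequences")
    a = np.frombuffer(s1.encode("ascii"), dtype=np.uint8)
    b = np.frombuffer(s2.encode("ascii"), dtype=np.uint8)
    return 100.0 * hamming_mismatches(a, b) / len(s1)

print(normalized_hamming("CASSLGTDTQYF", "CASSLGADTQYF"))  # 1 of 12 positions -> ~8.33
```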
108 changes: 17 additions & 91 deletions docs/tutorials/tutorial_3k_tcr.ipynb

Large diffs are not rendered by default.

369 changes: 273 additions & 96 deletions src/scirpy/ir_dist/_clonotype_neighbors.py

Large diffs are not rendered by default.
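The diff above is not rendered, but the CHANGELOG describes it as accelerating clonotype clustering through sparse matrix operations. A toy sketch of one such operation, combining two chain-level distance matrices under the convention that a stored value is distance + 1 and an implicit 0 means "no edge" (an assumption based on scirpy's documented distance encoding; the helper name is invented):

```python
import numpy as np
import scipy.sparse as sp

# Chain-level "distance" matrices: stored value = distance + 1, 0 = no edge.
d_vj = sp.csr_matrix(np.array([[1, 3, 0],
                               [3, 1, 0],
                               [0, 0, 1]], dtype=np.uint8))
d_vdj = sp.csr_matrix(np.array([[1, 0, 0],
                                [0, 1, 2],
                                [0, 2, 1]], dtype=np.uint8))

def combine_arms_all(a: sp.csr_matrix, b: sp.csr_matrix) -> sp.csr_matrix:
    """receptor_arms="all": keep a pair only if both arms are connected,
    and take the larger of the two distances (invented helper)."""
    both = (a > 0).multiply(b > 0)              # connected in both arms
    return a.maximum(b).multiply(both).tocsr()

print(combine_arms_all(d_vj, d_vdj).toarray())  # only the diagonal survives
```

The "any" semantics would instead take the smallest positive value per pair, which implicit-zero sparse arithmetic does not express directly; that is one plausible motivation for the dedicated `csr_min`/`csr_max` helpers in the commit list.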

101 changes: 65 additions & 36 deletions src/scirpy/ir_dist/_util.py
@@ -233,61 +233,90 @@ def n_cols(self):
 
     def lookup(
         self,
-        object_id: int,
-        forward_lookup_table: str,
-        reverse_lookup_table: str | None = None,
-    ) -> coo_matrix | np.ndarray:
-        """Get ids of neighboring objects from a lookup table.
-
-        Performs the following lookup:
-
-            object_id -> dist_mat -> neighboring features -> neighboring objects.
-
-        where an object is a clonotype in our case (but could be used for something else)
-
-        "nan"s are not looked up via the distance matrix, they return a row of zeros
+        object_ids: np.ndarray[int],
+        forward_lookup_table_name: str,
+        reverse_lookup_table_name: str | None = None,
+    ) -> sp.csr_matrix:
+        """
+        Creates a distance matrix between objects with the given ids based on a feature distance matrix.
+
+        To get the distance between two objects we need to look up the features of the two objects.
+        The distance between those two features is then the distance between the two objects.
+
+        To do so, we first use the `object_ids` together with the `forward_lookup_table` to look up
+        the indices of the objects in the feature `distance_matrix`. Afterwards we pick the corresponding row for each object
+        out of the `distance_matrix` and construct a `rows` matrix (n_object_ids x n_features).
+        "nan"s (index = -1) are not looked up in the feature `distance_matrix`, they return a row of zeros
         instead.
+
+        Then we use the entries of the `reverse_lookup_table` to construct a `reverse_lookup_matrix` (n_features x n_object_ids).
+        By multiplying the `rows` matrix with the `reverse_lookup_matrix` we get the final `object_distance_matrix` that shows
+        the distances between the objects with the given `object_ids` regarding a certain feature column.
+
+        It might not be obvious at first sight that the matrix multiplication between `rows` and `reverse_lookup_matrix` gives
+        us the desired result. But this trick allows us to use the built-in sparse matrix multiplication of `scipy.sparse`
+        for enhanced performance.
 
         Parameters
         ----------
-        object_id
-            The row index of the feature_table.
-        forward_lookup_table
+        object_ids
+            The row indices of the feature_table.
+        forward_lookup_table_name
             The unique identifier of a lookup table previously added via
             `add_lookup_table`.
-        reverse_lookup_table
+        reverse_lookup_table_name
             The unique identifier of the lookup table used for the reverse lookup.
             If not provided will use the same lookup table for forward and reverse
             lookup. This is useful to calculate distances across features from
             different columns of the feature table (e.g. primary and secondary VJ chains).
+
+        Returns
+        -------
+        object_distance_matrix
+            A CSR matrix containing the pairwise distances between objects with the
+            given `object_ids` regarding a certain feature column.
         """
-        distance_matrix_name, forward, reverse = self.lookups[forward_lookup_table]
+        distance_matrix_name, forward_lookup_table, reverse_lookup_table = self.lookups[forward_lookup_table_name]
 
-        if reverse_lookup_table is not None:
-            distance_matrix_name_reverse, _, reverse = self.lookups[reverse_lookup_table]
+        if reverse_lookup_table_name is not None:
+            distance_matrix_name_reverse, _, reverse_lookup_table = self.lookups[reverse_lookup_table_name]
             if distance_matrix_name != distance_matrix_name_reverse:
                 raise ValueError("Forward and reverse lookup tables must be defined on the same distance matrices.")
 
         distance_matrix = self.distance_matrices[distance_matrix_name]
-        idx_in_dist_mat = forward[object_id]
-        if idx_in_dist_mat == -1:  # nan
-            return reverse.empty()
-        else:
-            # get distances from the distance matrix...
-            row = distance_matrix[idx_in_dist_mat, :]
-
-            if reverse.is_boolean:
-                assert (
-                    len(row.indices) == 1  # type: ignore
-                ), "Boolean reverse lookup only works for identity distance matrices."
-                return reverse[row.indices[0]]  # type: ignore
-            else:
-                # ... and get column indices directly from sparse row
-                # sum concatenates coo matrices
-                return merge_coo_matrices(
-                    (reverse[i] * multiplier for i, multiplier in zip(row.indices, row.data, strict=False)),  # type: ignore
-                    shape=(1, reverse.size),
-                )
+
+        if np.max(distance_matrix.data) > np.iinfo(np.uint8).max:
+            raise OverflowError(
+                "The data values in the distance scipy.sparse.csr_matrix exceed the maximum value for uint8 (255)"
+            )
+
+        indices_in_dist_mat = forward_lookup_table[object_ids]
+        indptr = np.empty(distance_matrix.indptr.shape[0] + 1, dtype=np.int64)
+        indptr[:-1] = distance_matrix.indptr
+        indptr[-1] = indptr[-2]
+        distance_matrix_extended = sp.csr_matrix(
+            (distance_matrix.data.astype(np.uint8), distance_matrix.indices, indptr),
+            shape=(distance_matrix.shape[0] + 1, distance_matrix.shape[1]),
+        )
+        rows = distance_matrix_extended[indices_in_dist_mat, :]
+
+        reverse_matrix_data = [np.array([], dtype=np.uint8)] * rows.shape[1]
+        reverse_matrix_col = [np.array([], dtype=np.int64)] * rows.shape[1]
+        nnz_array = np.zeros(rows.shape[1], dtype=np.int64)
+
+        for key, value in reverse_lookup_table.lookup.items():
+            reverse_matrix_data[key] = value.data
+            reverse_matrix_col[key] = value.col
+            nnz_array[key] = value.nnz
+
+        data = np.concatenate(reverse_matrix_data)
+        col = np.concatenate(reverse_matrix_col)
+        indptr = np.concatenate([np.array([0], dtype=np.int64), np.cumsum(nnz_array)])
+
+        reverse_matrix = sp.csr_matrix((data, col, indptr), shape=(rows.shape[1], reverse_lookup_table.size))
+        object_distance_matrix = rows * reverse_matrix
+        return object_distance_matrix
 
     def add_distance_matrix(
         self,
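The two tricks in the new `lookup` above can be seen in isolation in a self-contained toy example: appending an all-zero row so that index -1 (the "nan" marker) resolves to an empty row, and multiplying the gathered rows with a reverse one-hot matrix to get object-to-object distances in a single sparse product (data and variable names invented for illustration):

```python
import numpy as np
import scipy.sparse as sp

# Feature-level distances between 3 unique features (value = distance + 1).
distance_matrix = sp.csr_matrix(np.array([[1, 2, 0],
                                          [2, 1, 0],
                                          [0, 0, 1]], dtype=np.uint8))

# Forward lookup: object (clonotype) index -> feature row; -1 encodes "nan".
forward = np.array([0, 2, -1, 1], dtype=np.int64)
object_ids = np.arange(4)

# Trick 1: append an all-zero row by repeating the last indptr entry, so
# fancy indexing with -1 selects a row with no stored entries.
indptr = np.empty(distance_matrix.indptr.shape[0] + 1, dtype=np.int64)
indptr[:-1] = distance_matrix.indptr
indptr[-1] = indptr[-2]
extended = sp.csr_matrix(
    (distance_matrix.data, distance_matrix.indices, indptr),
    shape=(distance_matrix.shape[0] + 1, distance_matrix.shape[1]),
)
rows = extended[forward[object_ids], :]  # (n_objects x n_features)

# Trick 2: reverse lookup as a one-hot (n_features x n_objects) matrix;
# a single sparse product then yields object-to-object distances.
valid = forward >= 0
reverse = sp.csr_matrix(
    (np.ones(valid.sum(), dtype=np.uint8), (forward[valid], np.flatnonzero(valid))),
    shape=(distance_matrix.shape[1], len(object_ids)),
)
print((rows @ reverse).toarray())
# [[1 0 0 2]
#  [0 1 0 0]
#  [0 0 0 0]   <- the "nan" object has no neighbors
#  [2 0 0 1]]
```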
Binary file not shown.
21 changes: 21 additions & 0 deletions src/scirpy/tests/test_clonotypes.py
@@ -2,6 +2,7 @@
 import sys
 from typing import cast
 
+import anndata as ad
 import numpy as np
 import numpy.testing as npt
 import pandas as pd
@@ -339,3 +340,23 @@ def test_clonotype_convergence(adata_clonotype):
             categories=["convergent", "not convergent"],
         ),
     )
+
+
+def test_j_gene_matching():
+    from . import TESTDATA
+
+    data = ad.read_h5ad(TESTDATA / "clonotypes_test_data/j_gene_test_data.h5ad")
+
+    ir.tl.define_clonotype_clusters(
+        data,
+        sequence="nt",
+        metric="normalized_hamming",
+        receptor_arms="all",
+        dual_ir="any",
+        same_j_gene=True,
+        key_added="test_j_gene",
+    )
+
+    clustering = data.obs["test_j_gene"].tolist()
+    expected = ["0", "0", "0", "0", "0", "1", "1", "1", "1", "1", "1", "1", "1", "2", "2", "2", "2", "2"]
+    assert np.array_equal(clustering, expected)
17 changes: 14 additions & 3 deletions src/scirpy/tl/_clonotypes.py
@@ -197,9 +197,10 @@ def define_clonotype_clusters(
     receptor_arms: Literal["VJ", "VDJ", "all", "any"] = "all",
     dual_ir: Literal["primary_only", "all", "any"] = "any",
     same_v_gene: bool = False,
+    same_j_gene: bool = False,
     within_group: Sequence[str] | str | None = "receptor_type",
     key_added: str | None = None,
-    partitions: Literal["connected", "leiden"] = "connected",
+    partitions: Literal["connected", "leiden", "fastgreedy"] = "connected",
     resolution: float = 1,
     n_iterations: int = 5,
     distance_key: str | None = None,
@@ -249,12 +250,19 @@
     partitions
         How to find graph partitions that define a clonotype.
-        Possible values are `leiden`, for using the "Leiden" algorithm and
+        Possible values are `leiden`, for using the "Leiden" algorithm,
+        `fastgreedy` for using the "Fastgreedy" algorithm and
         `connected` to find fully connected sub-graphs.
-        The difference is that the Leiden algorithm further divides
+        The difference is that the Leiden and Fastgreedy algorithms further divide
         fully connected subgraphs into highly-connected modules.
+
+        "Leiden" finds the community structure of the graph using the
+        Leiden algorithm of Traag, van Eck & Waltman.
+
+        "Fastgreedy" finds the community structure of the graph according to the
+        algorithm of Clauset et al. based on the greedy optimization of modularity.
     resolution
         `resolution` parameter for the leiden algorithm.
     n_iterations
@@ -289,6 +297,7 @@
         receptor_arms=receptor_arms,  # type: ignore
         dual_ir=dual_ir,  # type: ignore
         same_v_gene=same_v_gene,
+        same_j_gene=same_j_gene,
         match_columns=within_group,
         distance_key=distance_key,
         sequence_key="junction_aa" if sequence == "aa" else "junction",
@@ -304,6 +313,8 @@
             resolution_parameter=resolution,
             n_iterations=n_iterations,
         )
+    elif partitions == "fastgreedy":
+        part = g.community_fastgreedy().as_clustering()
     else:
        part = g.clusters(mode="weak")
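The new `fastgreedy` branch above calls directly into python-igraph; a minimal standalone use of the same calls on a toy graph:

```python
import igraph as ig

g = ig.Graph.Famous("Zachary")  # toy stand-in for the clonotype network

# As in the branch above: greedy modularity optimization (Clauset et al.),
# then cut the dendrogram at the level of maximum modularity.
part = g.community_fastgreedy().as_clustering()
print(len(part), part.membership[:10])  # number of communities, first labels
```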
