Fix and enhance geometric node features (#305)
* Remove obsolete `remove_insertions`

* Fix case when beta-carbon is missing in PDB

* Fix case when side-chain atoms are missing in PDB

* Test handling of missing beta-carbons

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Merge

* Remove `main`

* Fix data types.

Without this, `vec`s are of type `object`, which, for example, breaks the smooth conversion to torch tensors.

* Implement `add_virtual_beta_carbon_vector`

* Enhance visualisation of PyG data

* Test `test_add_virtual_beta_carbon_vector`

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update changelog

* Update changelog

* Generalize vector visualization for PyG format

* Do not convert Tensor to Tensor

* Fix k-NN graph on df with dropped residues

* Update changelog

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix outdated test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Arian Jamasb <[email protected]>
3 people authored Jul 31, 2023
1 parent 44a32a3 commit 7c99e57
Showing 10 changed files with 242 additions and 27 deletions.
15 changes: 12 additions & 3 deletions CHANGELOG.md
@@ -1,8 +1,17 @@
### 1.7.1 - UNRELEASED

#### New Features
* [Feature] - [#305](https://github.com/a-r-j/graphein/pull/305) Adds the `add_virtual_beta_carbon_vector` function inspired by [RFdiffusion](https://github.com/RosettaCommons/RFdiffusion/blob/main/rfdiffusion/coords6d.py#L37) and ProteinMPNN.

#### API Changes
* Chain selections are now specified with either `"all"` or a list of strings (e.g. `["A", "B"]`) rather than a single selection string (e.g. `"AB"`). This is a necessary change for MMTF support, as MMTF files can have multi-character chain identifiers. [#307](https://github.com/a-r-j/graphein/pull/307)

#### Improvements
* [Bugfix] - [#305](https://github.com/a-r-j/graphein/pull/305) Fixes `add_k_nn_edges` for the case where some residues were previously dropped (e.g. when some alt_locs are removed).
* [Bugfix] - [#305](https://github.com/a-r-j/graphein/pull/305) Removes obsolete `remove_insertions` in [`rgroup_df` construction](https://github.com/a-r-j/graphein/blob/649a490505740a266b26976807e7f303c2a32ff0/graphein/protein/graphs.py#L540).
* [Bugfix] - [#305](https://github.com/a-r-j/graphein/pull/305) Fixes the construction of geometric features when beta-carbons or side chains are missing in non-glycine residues (for example in `H:CYS:104` in 3SE8).
* [Bugfix] - [#305](https://github.com/a-r-j/graphein/pull/305) Fixes data types of geometric feature vectors: `object` -> `float`.
* [Bugfix] - [#301](https://github.com/a-r-j/graphein/pull/301) Fixes the conversion of undirected NetworkX graph to directed PyG data.

#### Bugfixes
* Adds missing `stage` parameter to `graphein.ml.datasets.foldcomp_data.FoldCompDataModule.setup()`. [#310](https://github.com/a-r-j/graphein/pull/310)
@@ -15,17 +24,19 @@
* Fixes incorrect start padding in pNeRF output [#321](https://github.com/a-r-j/graphein/pull/321)

#### Other Changes
* Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
* Adds entry point for biopandas dataframes in `graphein.protein.tensor.io.protein_to_pyg`. [#310](https://github.com/a-r-j/graphein/pull/310)
* Adds support for `.ent` files to `graphein.protein.graphs.read_pdb_to_dataframe`. [#310](https://github.com/a-r-j/graphein/pull/310)
* Obsolete residues with no replacement are now returned by `graphein.protein.utils.get_obsolete_mapping`. [#310](https://github.com/a-r-j/graphein/pull/310)
* Adds the ability to store a dictionary of HETATM positions in `Data`/`Protein` objects created in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Improved handling of non-standard residues in the `graphein.protein.tensor` module. [#307](https://github.com/a-r-j/graphein/pull/307)
* Insertions are retained by default in the `graphein.protein.tensor` module, i.e. `insertions=True` is now the default behaviour. [#307](https://github.com/a-r-j/graphein/pull/307)
* `plot_pyg_data` now also plots some geometric features if present. [#305](https://github.com/a-r-j/graphein/pull/305)
* Adds transform composition to FoldComp Dataset [#312](https://github.com/a-r-j/graphein/pull/312)
* Improve FoldComp dataloading performance and include B factors (pLDDT) in output. [#313](https://github.com/a-r-j/graphein/pull/313) [#315](https://github.com/a-r-j/graphein/pull/315)
* Add new helper functions to PDBManager [#322](https://github.com/a-r-j/graphein/pull/322) (@amorehead)

### 1.7.0 - UNRELEASED
### 1.7.0 - 10/04/2023

#### New Features

@@ -34,8 +45,6 @@
* [ESM] - [#284](https://github.com/a-r-j/graphein/pull/284) - Wrapper for ESMFold batch folding & embedding.
* [Downloads] MMTF downloading now supported in download utilities. [#272](https://github.com/a-r-j/graphein/pull/272)

#### Improvements
* [Bugfix] - [#301](https://github.com/a-r-j/graphein/pull/301) Fixes the conversion of undirected NetworkX graph to directed PyG data.

#### API Changes
* The `pdb_path` argument to many functions (e.g. `graphein.protein.graphs.construct_graph`) has been renamed to `path` as this can now accept MMTF files in addition to PDB files.
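A minimal, editor-added sketch of the chain-selection API change noted in the changelog above (PR #307). It assumes `graphein.protein.graphs.construct_graph` accepts the new list-based `chain_selection`; the PDB code and chains are illustrative only.

```python
# Minimal sketch of the list-based chain selection described in the changelog
# (PR #307). Assumes construct_graph accepts chain_selection as shown; the
# PDB code and chain IDs are illustrative only.
from graphein.protein.graphs import construct_graph

# Previously: chain_selection="AB"; now either "all" or a list of chain IDs.
g = construct_graph(pdb_code="4hhb", chain_selection=["A", "B"])
print(g.number_of_nodes())
```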
3 changes: 2 additions & 1 deletion graphein/ml/conversion.py
@@ -326,7 +326,8 @@ def convert_nx_to_pyg(self, G: nx.Graph) -> Data:
# Convert everything possible to torch.Tensors
for key, val in data.items():
try:
data[key] = torch.tensor(np.array(val))
if not isinstance(val, torch.Tensor):
data[key] = torch.tensor(np.array(val))
except Exception as e:
log.warning(e)
pass
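For context, a small editor-added sketch of the guard introduced above: values that are already tensors are left untouched rather than being round-tripped through NumPy. The behavioural notes in the comments are assumptions about the torch/NumPy interplay, not taken from the PR.

```python
# Sketch of the guarded conversion above: plain arrays/lists are converted,
# values that are already torch.Tensors are kept as-is. Round-tripping a
# tensor through np.array/torch.tensor copies the data and can fail for
# e.g. CUDA tensors or tensors that require grad (an assumption, not from
# the PR itself).
import numpy as np
import torch

data = {
    "coords": np.random.rand(4, 3),             # NumPy array -> convert
    "virtual_c_beta_vector": torch.rand(4, 3),   # already a tensor -> keep
}

for key, val in data.items():
    if not isinstance(val, torch.Tensor):
        data[key] = torch.tensor(np.array(val))

print({k: type(v).__name__ for k, v in data.items()})  # both are tensors now
```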
33 changes: 31 additions & 2 deletions graphein/ml/visualisation.py
@@ -1,12 +1,15 @@
"""Visualisation utils for ML."""
from __future__ import annotations

from typing import Optional, Tuple
import itertools
from typing import Iterable, Optional, Tuple

import matplotlib.pyplot as plt
import networkx as nx
import plotly.graph_objects as go

from graphein.protein.features.nodes.geometry import VECTOR_FEATURE_NAMES
from graphein.protein.visualisation import add_vector_to_plot
from graphein.utils.dependencies import import_message

from ..protein.visualisation import plotly_protein_structure_graph
@@ -48,6 +51,8 @@ def plot_pyg_data(
edge_colour_map=plt.cm.plasma,
colour_nodes_by: str = "residue_name",
colour_edges_by: Optional[str] = None,
node_vector_features: Iterable[str] = tuple(VECTOR_FEATURE_NAMES),
node_vector_feature_colours: Iterable[str] = ("red", "green", "blue"),
) -> go.Figure:
"""
Plots protein structure graph from ``torch_geometric.data.Data``
@@ -94,6 +99,11 @@
:param colour_edges_by: Specifies how to colour edges. Currently only
``"kind"`` or ``None`` are supported.
:type colour_edges_by: Optional[str]
:param node_vector_features: Specifies node vector features to visualize. By default, all
present features are plotted.
:type node_vector_features: Iterable[str]
:param node_vector_feature_colours: Specifies colours of the plotted vectors.
:type node_vector_feature_colours: Iterable[str]
:returns: Plotly Graph Objects plot
:rtype: go.Figure
"""
@@ -115,13 +125,18 @@
d["coords"] = x.coords[i]
if node_colour_tensor is not None:
d["colour"] = float(node_colour_tensor[i])
for node_vec in node_vector_features:
if hasattr(x, node_vec):
d[node_vec] = getattr(x, node_vec)[i]

# Preprocess edge colours
if edge_colour_tensor is not None:
# TODO add edge types
for i, (_, _, d) in enumerate(nx_graph.edges(data=True)):
d["colour"] = float(edge_colour_tensor[i])

return plotly_protein_structure_graph(
# Plot nx graph
fig = plotly_protein_structure_graph(
nx_graph,
plot_title,
figsize,
Expand All @@ -135,3 +150,17 @@ def plot_pyg_data(
colour_nodes_by if node_colour_tensor is None else "colour",
colour_edges_by if edge_colour_tensor is None else "colour",
)

# Add vectors to visualize
node_vector_feature_colours = itertools.cycle(node_vector_feature_colours)
for node_vec in node_vector_features:
if hasattr(x, node_vec):
fig = add_vector_to_plot(
nx_graph,
fig,
node_vec,
colour=next(node_vector_feature_colours),
scale=1.5,
)

return fig
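A hypothetical end-to-end usage of the extended `plot_pyg_data`, added here for illustration. It assumes the `GraphFormatConvertor` call signature and column handling from `graphein.ml.conversion`, and that the listed columns are sufficient for plotting; treat it as a sketch rather than a canonical recipe.

```python
# Hypothetical sketch: attach a vector feature, convert to PyG, then plot it.
# Assumes GraphFormatConvertor copies the listed columns into the Data object
# and that plot_pyg_data picks up the vector via node_vector_features.
from graphein.ml.conversion import GraphFormatConvertor
from graphein.ml.visualisation import plot_pyg_data
from graphein.protein.features.nodes.geometry import (
    add_virtual_beta_carbon_vector,
)
from graphein.protein.graphs import construct_graph

g = construct_graph(pdb_code="4hhb")
add_virtual_beta_carbon_vector(g)

convertor = GraphFormatConvertor(
    src_format="nx",
    dst_format="pyg",
    columns=[
        "coords",
        "edge_index",
        "node_id",
        "residue_name",
        "virtual_c_beta_vector",
    ],
)
data = convertor(g)

fig = plot_pyg_data(data, node_vector_features=["virtual_c_beta_vector"])
fig.show()
```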
15 changes: 8 additions & 7 deletions graphein/protein/edges/distance.py
@@ -1097,6 +1097,7 @@ def add_k_nn_edges(
or pdb_df["z_coord"].isna().sum()
):
raise ValueError("Coordinates contain a NaN value.")
pdb_df = pdb_df.reset_index(drop=True)

# Construct distance matrix
dist_mat = compute_distmat(pdb_df)
@@ -1124,7 +1125,7 @@
nn = neigh.kneighbors_graph()

# Create iterable of node indices
outgoing = np.repeat(np.array(range(len(G.graph["pdb_df"]))), k)
outgoing = np.repeat(np.array(range(len(pdb_df))), k)
incoming = nn.indices
interacting_nodes = list(zip(outgoing, incoming))
log.info(f"Found: {len(interacting_nodes)} KNN edges")
@@ -1133,16 +1134,16 @@
continue

# Get nodes IDs from indices
n1 = G.graph["pdb_df"].loc[a1, "node_id"]
n2 = G.graph["pdb_df"].loc[a2, "node_id"]
n1 = pdb_df.loc[a1, "node_id"]
n2 = pdb_df.loc[a2, "node_id"]

# Get chains
n1_chain = G.graph["pdb_df"].loc[a1, "chain_id"]
n2_chain = G.graph["pdb_df"].loc[a2, "chain_id"]
n1_chain = pdb_df.loc[a1, "chain_id"]
n2_chain = pdb_df.loc[a2, "chain_id"]

# Get sequence position
n1_position = G.graph["pdb_df"].loc[a1, "residue_number"]
n2_position = G.graph["pdb_df"].loc[a2, "residue_number"]
n1_position = pdb_df.loc[a1, "residue_number"]
n2_position = pdb_df.loc[a2, "residue_number"]

# Check residues are not on same chain
condition_1 = n1_chain != n2_chain
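To make the fix above concrete, here is a small editor-added illustration (not part of the commit) of why the `reset_index` call matters: the k-NN search returns positional indices, while `.loc` looks up the DataFrame's labels.

```python
# kneighbors_graph returns *positional* indices, but .loc looks up *labels*.
# If residues were dropped earlier, the index has gaps and the two disagree
# unless it is reset (illustrative data, mirroring the new test further down).
import pandas as pd

pdb_df = pd.DataFrame(
    {"node_id": ["A:HIS:1", "A:TYR:2", "B:ALA:4:altB"]},
    index=[0, 1, 3],  # label 2 was dropped with "B:ALA:4:altA"
)

positional_neighbour = 2  # as returned by the k-NN search
# Without the reset this raises KeyError (there is no label 2):
# pdb_df.loc[positional_neighbour, "node_id"]

pdb_df = pdb_df.reset_index(drop=True)
print(pdb_df.loc[positional_neighbour, "node_id"])  # -> "B:ALA:4:altB"
```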
94 changes: 86 additions & 8 deletions graphein/protein/features/nodes/geometry.py
@@ -5,12 +5,23 @@
# Project Website: https://github.com/a-r-j/graphein
# Code Repository: https://github.com/a-r-j/graphein

from typing import List

import networkx as nx
import numpy as np
from loguru import logger as log

from graphein.protein.utils import compute_rgroup_dataframe, filter_dataframe

VECTOR_FEATURE_NAMES: List[str] = [
"sidechain_vector",
"c_beta_vector",
"sequence_neighbour_vector_n_to_c",
"sequence_neighbour_vector_c_to_n",
"virtual_c_beta_vector",
]
"""Names of all vector features from the module."""


def add_sidechain_vector(
g: nx.Graph, scale: bool = True, reverse: bool = False
@@ -41,16 +52,23 @@ def add_sidechain_vector(
for n, d in g.nodes(data=True):
if d["residue_name"] == "GLY":
# If GLY, set vector to 0
vec = np.array([0, 0, 0])
vec = np.array([0.0, 0.0, 0.0])
elif n not in sc_centroid.index:
vec = np.array([0.0, 0.0, 0.0])
log.warning(
f"Non-glycine residue {n} does not have side-chain atoms."
)
else:
if reverse:
vec = d["coords"] - np.array(
sc_centroid.loc[n][["x_coord", "y_coord", "z_coord"]]
sc_centroid.loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
else:
vec = (
np.array(
sc_centroid.loc[n][["x_coord", "y_coord", "z_coord"]]
sc_centroid.loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
- d["coords"]
)
@@ -68,7 +86,7 @@ def add_beta_carbon_vector(
carbon.
Glycine does not have a beta carbon, so we set it to
``np.array([0, 0, 0])``. We extract the position of the beta carbon from the
``np.array([0., 0., 0.])``. We extract the position of the beta carbon from the
unprocessed atomic PDB dataframe. For this we use the ``raw_pdb_df``
DataFrame. If ``scale``, we scale the vector to the unit vector. If
``reverse`` is ``True``, we reverse the vector (``C beta - node``).
@@ -93,16 +111,25 @@
# Iterate over nodes and compute vector
for n, d in g.nodes(data=True):
if d["residue_name"] == "GLY":
vec = np.array([0, 0, 0])
vec = np.array([0.0, 0.0, 0.0])
elif n not in c_beta_coords.index:
vec = np.array([0.0, 0.0, 0.0])
log.warning(
f"Non-glycine residue {n} does not have a beta-carbon."
)
else:
if reverse:
vec = d["coords"] - np.array(
c_beta_coords.loc[n][["x_coord", "y_coord", "z_coord"]]
c_beta_coords.loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
else:
vec = (
np.array(
c_beta_coords.loc[n][["x_coord", "y_coord", "z_coord"]]
c_beta_coords.loc[n][
["x_coord", "y_coord", "z_coord"]
],
dtype=float,
)
- d["coords"]
)
@@ -148,7 +175,7 @@ def add_sequence_neighbour_vector(
# Checks not at chain terminus - is this versatile enough?
if i == len(chain_residues) - 1:
residue[1][f"sequence_neighbour_vector_{suffix}"] = np.array(
[0, 0, 0]
[0.0, 0.0, 0.0]
)
continue
# Asserts residues are on the same chain
@@ -174,3 +201,54 @@
vec = vec / np.linalg.norm(vec)

residue[1][f"sequence_neighbour_vector_{suffix}"] = vec


def add_virtual_beta_carbon_vector(
g: nx.Graph, scale: bool = False, reverse: bool = False
):
"""For each node adds a vector from alpha carbon to virtual beta carbon.
:param g: Graph to add vector to.
:type g: nx.Graph
:param scale: Scale vector to unit vector. Defaults to ``False``.
:type scale: bool
:param reverse: Reverse vector. Defaults to ``False``.
:type reverse: bool
"""
# Get coords of backbone atoms
coord_dfs = {}
for atom_type in ["N", "CA", "C"]:
df = filter_dataframe(
g.graph["raw_pdb_df"], "atom_name", [atom_type], boolean=True
)
df.index = df["node_id"]
coord_dfs[atom_type] = df

# Iterate over nodes and compute vector
for n, d in g.nodes(data=True):
if any([n not in df.index for df in coord_dfs.values()]):
vec = np.array([0, 0, 0], dtype=float)
log.warning(f"Missing backbone atom in residue {n}.")
else:
N = np.array(
coord_dfs["N"].loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
Ca = np.array(
coord_dfs["CA"].loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
C = np.array(
coord_dfs["C"].loc[n][["x_coord", "y_coord", "z_coord"]],
dtype=float,
)
b = Ca - N
c = C - Ca
a = np.cross(b, c)
Cb = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + Ca
vec = Cb - Ca

if reverse:
vec = -vec
if scale:
vec = vec / np.linalg.norm(vec)
d["virtual_c_beta_vector"] = vec
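A brief editor-added sketch of how the new feature can be used; it assumes the default `ProteinGraphConfig`, and the PDB code is illustrative. The constants in the reconstruction above correspond to an idealised placement of the C-beta from the backbone N, CA and C atoms, as in the RFdiffusion code referenced in the changelog.

```python
# Minimal sketch (assumes the default ProteinGraphConfig; PDB code is
# illustrative). The virtual C-beta is reconstructed from backbone N, CA, C,
# so glycines also get a direction, unlike add_beta_carbon_vector.
from graphein.protein.features.nodes.geometry import (
    add_virtual_beta_carbon_vector,
)
from graphein.protein.graphs import construct_graph

g = construct_graph(pdb_code="4hhb")
add_virtual_beta_carbon_vector(g, scale=True)  # unit vectors CA -> virtual CB

node, d = next(iter(g.nodes(data=True)))
print(node, d["virtual_c_beta_vector"])
```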
2 changes: 1 addition & 1 deletion graphein/protein/graphs.py
@@ -534,7 +534,7 @@ def initialise_graph_with_metadata(
chain_ids=list(protein_df["chain_id"].unique()),
pdb_df=protein_df,
raw_pdb_df=raw_pdb_df,
rgroup_df=compute_rgroup_dataframe(remove_insertions(raw_pdb_df)),
rgroup_df=compute_rgroup_dataframe(raw_pdb_df),
coords=np.asarray(protein_df[["x_coord", "y_coord", "z_coord"]]),
)

9 changes: 9 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,9 @@
import pytest
from loguru import logger


@pytest.fixture
def caplog(caplog):
handler_id = logger.add(caplog.handler, format="{message}")
yield caplog
logger.remove(handler_id)
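A hypothetical test, added for illustration: the fixture above routes loguru output into pytest's `caplog`, which is what allows the new geometry tests to assert on the missing-atom warnings. The logged message below is illustrative.

```python
# Hypothetical usage of the fixture above: loguru messages (e.g. the
# missing beta-carbon warnings from the geometry module) become visible
# to pytest's caplog. The message here is illustrative only.
from loguru import logger


def test_caplog_captures_loguru(caplog):
    logger.warning("Non-glycine residue H:CYS:104 does not have a beta-carbon.")
    assert "does not have a beta-carbon" in caplog.text
```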
2 changes: 1 addition & 1 deletion tests/ppi/test_graphs.py
@@ -45,7 +45,7 @@ def test_construct_graph():

# Check nodes and edges
assert len(g.nodes()) == 8
assert len(g.edges()) == 23
assert len(g.edges()) == 21

# Check edge types are from string/biogrid
# Check nodes are in our list
25 changes: 25 additions & 0 deletions tests/protein/edges/test_distance.py
@@ -428,3 +428,28 @@ def test_add_k_nn_edges():
add_k_nn_edges(g, **args)
edges_real = list(g.edges())
assert set(edges_real) == set(edges_expected)


def test_insertion_codes_in_add_k_nn_edges():
pdb_df = pd.DataFrame(
{
"residue_number": [1, 2, 3],
"node_id": ["A:HIS:1", "A:TYR:2", "B:ALA:4:altB"],
"chain_id": ["A", "A", "B"],
"x_coord": [1.0, 2.0, 3.0],
"y_coord": [4.0, 5.0, 6.0],
"z_coord": [7.0, 8.0, 9.0],
},
index=[0, 1, 3], # simulating dropped "B:ALA:4:altA"
)
g = nx.empty_graph(pdb_df["node_id"])
g.graph["pdb_df"] = pdb_df
add_k_nn_edges(g, k=3)

edges_expected = [
("A:HIS:1", "A:TYR:2"),
("A:HIS:1", "B:ALA:4:altB"),
("A:TYR:2", "B:ALA:4:altB"),
]

assert set(g.edges()) == set(edges_expected)