Skip to content

Commit

Permalink
refactor to new recommended jaxtyping/beartype syntax
Browse files Browse the repository at this point in the history
  • Loading branch information
Jamasb committed Dec 26, 2023
1 parent 595b5ac commit 381040c
Show file tree
Hide file tree
Showing 25 changed files with 154 additions and 208 deletions.
5 changes: 2 additions & 3 deletions docs/source/tutorials.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ Evaluating a pre-trained Model
from graphein.protein.tensor.data import ProteinBatch
from proteinworkshop.models.utils import get_aggregation
from jaxtyping import jaxtyped
from beartype import beartype
from beartype import beartype as typechecker
class IdentityModel(nn.Module):
Expand All @@ -58,8 +58,7 @@ Evaluating a pre-trained Model
"""This property describes the required attributes of the input batch."""
return {"x", "batch"}
@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def forward(self, batch: Union[Batch, ProteinBatch]) -> Dict[str, torch.Tensor]:
"""
This method does the forward pass of the model.
Expand Down
2 changes: 1 addition & 1 deletion notebooks/adding_new_task_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
" def required_attributes(self) -> Set[str]:\n",
" return {\"residue_type\"}\n",
"\n",
" @beartype\n",
" @jaxtyped(typechecker=typechecker)\n",
" def __call__(self, x: Union[Data, Protein]) -> Union[Data, Protein]:\n",
" x.residue_type_uncorrupted = copy.deepcopy(x.residue_type)\n",
" # Get indices of residues to corrupt\n",
Expand Down
10 changes: 5 additions & 5 deletions proteinworkshop/datasets/atom3d_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import atom3d.datasets.datasets as da
import lightning as L
import torch
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.dataloader import ProteinDataLoader
from loguru import logger as log
from torch.utils.data import Dataset
Expand All @@ -34,7 +34,7 @@ def set_worker_sharing_strategy(worker_id: int):
torch.multiprocessing.set_sharing_strategy(SHARING_STRATEGY)


@beartype
@typechecker
def get_data_path(
dataset: str,
lba_split: int = 30,
Expand All @@ -57,7 +57,7 @@ def get_data_path(
return data_paths[dataset]


@beartype
@typechecker
def get_test_data_path(
dataset: str,
lba_split: int = 30,
Expand Down Expand Up @@ -85,7 +85,7 @@ def get_test_data_path(
return data_paths[dataset]


@beartype
@typechecker
def get_task_split(
task: str,
lba_split: int = 30,
Expand Down Expand Up @@ -317,7 +317,7 @@ def setup(self, stage: Optional[str] = None):
self.data_test,
) = self.get_datasets()

@beartype
@typechecker
def get_dataloader(
self,
dataset: Union[da.LMDBDataset, PPIDataset, RESDataset],
Expand Down
2 changes: 1 addition & 1 deletion proteinworkshop/datasets/components/atom3d_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
NUM_ATOM_TYPES = len(_atom_types_dict)


@beartype
@typechecker
def _element_mapping(x: str) -> int:
    """Look up the integer code for ``x`` in ``_atom_types_dict``.

    Keys that are not present in ``_atom_types_dict`` map to ``8``.
    """
    fallback_code = 8
    return _atom_types_dict.get(x, fallback_code)

Expand Down
16 changes: 8 additions & 8 deletions proteinworkshop/datasets/components/ppi_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy
import torch
from atom3d.datasets import LMDBDataset
from beartype import beartype
from beartype import beartype as typechecker
from torch.utils.data import IterableDataset
from torch_geometric.data import Data

Expand All @@ -23,14 +23,14 @@
]


@beartype
@typechecker
def get_res(df: pd.DataFrame) -> pd.DataFrame:
    """Return the distinct residue rows of ``df``.

    The frame is first indexed with ``PPI_DF_INDEX_COLUMNS`` and duplicate
    rows of the result are then dropped.
    """
    # Adapted from: https://github.com/drorlab/atom3d/blob/master/examples/ppi/dataset/neighbors.py
    residue_rows = df[PPI_DF_INDEX_COLUMNS]
    return residue_rows.drop_duplicates()


@beartype
@typechecker
def _get_idx_to_res_mapping(
df: pd.DataFrame,
) -> Tuple[pd.DataFrame, pd.Series]:
Expand All @@ -43,7 +43,7 @@ def _get_idx_to_res_mapping(
return idx_to_res, res_to_idx


@beartype
@typechecker
def get_subunits(
ensemble: pd.DataFrame,
) -> Tuple[
Expand Down Expand Up @@ -79,7 +79,7 @@ def get_subunits(
return names, (bdf0, bdf1, udf0, udf1)


@beartype
@typechecker
def get_negatives(
neighbors, df0: pd.DataFrame, df1: pd.DataFrame
) -> pd.DataFrame:
Expand Down Expand Up @@ -170,7 +170,7 @@ def __iter__(self):
)
return gen

@beartype
@typechecker
def _df_to_graph(
self, struct_df: pd.DataFrame, chain_res: Iterable, label: float
) -> Optional[Data]:
Expand Down Expand Up @@ -208,7 +208,7 @@ def _df_to_graph(

return data

@beartype
@typechecker
def _dataset_generator(
self, indices: List[int], shuffle: bool = True
) -> Generator[Tuple[Data, Data], None, None]:
Expand Down Expand Up @@ -247,7 +247,7 @@ def _dataset_generator(
continue
yield graph1, graph2

@beartype
@typechecker
def _create_labels(
self,
positives: pd.DataFrame,
Expand Down
14 changes: 5 additions & 9 deletions proteinworkshop/features/edge_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np
import torch
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.types import CoordTensor, EdgeTensor
from jaxtyping import jaxtyped
from omegaconf import ListConfig
Expand All @@ -20,8 +20,7 @@
"""List of edge features that can be computed."""


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_scalar_edge_features(
x: Union[Data, Batch], features: Union[List[str], ListConfig]
) -> torch.Tensor:
Expand Down Expand Up @@ -55,8 +54,7 @@ def compute_scalar_edge_features(
return torch.cat(feats, dim=1)


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_vector_edge_features(
x: Union[Data, Batch], features: Union[List[str], ListConfig]
) -> Union[Data, Batch]:
Expand All @@ -71,8 +69,7 @@ def compute_vector_edge_features(
return x


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_edge_distance(
pos: CoordTensor, edge_index: EdgeTensor
) -> torch.Tensor:
Expand All @@ -91,8 +88,7 @@ def compute_edge_distance(
)


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def pos_emb(edge_index: EdgeTensor, num_pos_emb: int = 16):
# From https://github.com/jingraham/neurips19-graph-protein-design
d = edge_index[0] - edge_index[1]
Expand Down
6 changes: 3 additions & 3 deletions proteinworkshop/features/edges.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

import graphein.protein.tensor.edges as gp
import torch
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.data import Protein, ProteinBatch
from omegaconf import ListConfig
from torch_geometric.data import Batch, Data


@beartype
@typechecker
def compute_edges(
x: Union[Data, Batch, Protein, ProteinBatch],
edge_types: Union[ListConfig, List[str]],
Expand Down Expand Up @@ -81,7 +81,7 @@ def compute_edges(
return edges, indxs


@beartype
@typechecker
def sequence_edges(
b: Union[Data, Batch, Protein, ProteinBatch],
chains: Optional[torch.Tensor] = None,
Expand Down
5 changes: 2 additions & 3 deletions proteinworkshop/features/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import torch.nn as nn
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.data import ProteinBatch, get_random_batch
from jaxtyping import jaxtyped
from loguru import logger
Expand Down Expand Up @@ -75,8 +75,7 @@ def __init__(
if "sequence_positional_encoding" in self.scalar_node_features:
self.positional_encoding = PositionalEncoding(16)

@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def forward(
self, batch: Union[Batch, ProteinBatch]
) -> Union[Batch, ProteinBatch]:
Expand Down
14 changes: 5 additions & 9 deletions proteinworkshop/features/node_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
import torch.nn.functional as F
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.angles import (
alpha,
dihedrals,
Expand All @@ -25,8 +25,7 @@
from .utils import _normalize


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_scalar_node_features(
x: Union[Batch, Data, Protein, ProteinBatch],
node_features: Union[ListConfig, List[ScalarNodeFeature]],
Expand Down Expand Up @@ -86,8 +85,7 @@ def compute_scalar_node_features(
return torch.cat(feats, dim=1) if feats else x.x


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_vector_node_features(
x: Union[Batch, Data, Protein, ProteinBatch],
vector_features: Union[ListConfig, List[str]],
Expand All @@ -114,8 +112,7 @@ def compute_vector_node_features(
return x


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def compute_surface_feat(
coords: Union[CoordTensor, AtomTensor], k: int, sigma: List[float]
):
Expand Down Expand Up @@ -150,8 +147,7 @@ def compute_surface_feat(
return torch.cat(feat, dim=1)


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def orientations(
X: Union[CoordTensor, AtomTensor], ca_idx: int = 1
) -> OrientationTensor:
Expand Down
22 changes: 10 additions & 12 deletions proteinworkshop/features/representation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Literal, Tuple

import torch
from beartype import beartype
from beartype import beartype as typechecker
from graphein.protein.tensor.types import AtomTensor, CoordTensor
from jaxtyping import jaxtyped
from torch_geometric.data import Batch, Data
Expand All @@ -11,8 +11,7 @@
from proteinworkshop.configs.config import ExperimentConfigurationError


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def get_full_atom_coords(
atom_tensor: AtomTensor, fill_value: float = 1e-5
) -> Tuple[CoordTensor, torch.Tensor, torch.Tensor]:
Expand Down Expand Up @@ -40,8 +39,7 @@ def get_full_atom_coords(
return coords, residue_index, atom_type


@jaxtyped
@beartype
@jaxtyped(typechecker=typechecker)
def transform_representation(
x: Batch, representation_type: Literal["CA", "BB", "FA", "BB_SC", "CA_SC"]
) -> Batch:
Expand Down Expand Up @@ -87,7 +85,7 @@ def transform_representation(
)


@beartype
@typechecker
def _ca_to_fa_repr(x: Data) -> Data:
"""Converts CA representation to full atom representation."""
coords, residue_index, atom_type = get_full_atom_coords(x.coords)
Expand All @@ -101,7 +99,7 @@ def _ca_to_fa_repr(x: Data) -> Data:
return x


@beartype
@typechecker
def _ca_to_bb_repr(x: Data) -> Data:
"""Converts CA representation to backbone representation."""
x.pos = x.coords[:, :4, :].reshape(-1, 3)
Expand All @@ -117,7 +115,7 @@ def _ca_to_bb_repr(x: Data) -> Data:
return x


@beartype
@typechecker
def ca_to_bb_repr(batch: Batch) -> Batch: # sourcery skip: assign-if-exp
"""
Converts a batch of CA representations to backbone representations. I.e.
Expand Down Expand Up @@ -180,7 +178,7 @@ def ca_to_bb_repr(batch: Batch) -> Batch: # sourcery skip: assign-if-exp
return batch


@beartype
@typechecker
def ca_to_bb_sc_repr(batch: Batch) -> Batch:
"""Converts a batch of CA representations to backbone + sidechain representations."""
# Get centroids
Expand All @@ -190,7 +188,7 @@ def ca_to_bb_sc_repr(batch: Batch) -> Batch:
return ca_to_fa_repr(batch)


@beartype
@typechecker
def ca_to_ca_sc_repr(batch: Batch) -> Batch:
"""Converts a batch of CA representations to C + sidechain representations."""
# Get centroids
Expand All @@ -200,7 +198,7 @@ def ca_to_ca_sc_repr(batch: Batch) -> Batch:
return batch


@beartype
@typechecker
def coarsen_sidechain(x: Data, aggr: str = "mean") -> CoordTensor:
"""Returns tensor of sidechain centroids: L x 3"""
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
Expand All @@ -216,7 +214,7 @@ def coarsen_sidechain(x: Data, aggr: str = "mean") -> CoordTensor:
return sc_points


@beartype
@typechecker
def ca_to_fa_repr(batch: Batch) -> Batch: # sourcery skip: assign-if-exp
"""Converts a batch of CA representations to full atom representations."""
if "sidechain_torsion" in batch.keys:
Expand Down
Loading

0 comments on commit 381040c

Please sign in to comment.