
Merge pull request #22 from choderalab/add-gat-model
Add GAT model
kaminow authored Apr 21, 2023
2 parents c9dc6e4 + d5adef2 commit 42893ce
Showing 7 changed files with 207 additions and 17 deletions.
4 changes: 4 additions & 0 deletions devtools/conda-envs/test_env.yaml
@@ -1,6 +1,7 @@
name: test
channels:
- conda-forge
- dglteam
dependencies:
- pytorch
- pytorch_geometric
@@ -10,6 +11,9 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
- rdkit
# testing dependencies
- pytest
- pytest-cov
3 changes: 3 additions & 0 deletions environment-gpu.yml
@@ -1,6 +1,7 @@
name: mtenn-gpu
channels:
- conda-forge
- dglteam
dependencies:
- pytorch
- pytorch-gpu
@@ -11,3 +12,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
3 changes: 3 additions & 0 deletions environment.yml
@@ -1,6 +1,7 @@
name: mtenn
channels:
- conda-forge
- dglteam
dependencies:
- pytorch
- pytorch_geometric
@@ -10,3 +11,5 @@ dependencies:
- numpy
- h5py
- e3nn
- dgllife
- dgl
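
All three environment files gain the same DGL stack: the dglteam channel plus the dgl and dgllife packages (with rdkit also added to the test environment). A quick sanity check that the new dependencies resolve after rebuilding an environment (a sketch, not part of the diff; versions depend on what conda solves):

import dgl
import dgllife
import rdkit

print(dgl.__version__, dgllife.__version__, rdkit.__version__)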
2 changes: 1 addition & 1 deletion mtenn/_version.py
@@ -1 +1 @@
__version__ = "0.1.0"
__version__ = "0.2.0"
5 changes: 3 additions & 2 deletions mtenn/conversion_utils/__init__.py
@@ -1,4 +1,5 @@
from .schnet import SchNet
from .e3nn import E3NN
from .gat import GAT
from .schnet import SchNet

__all__ = ["SchNet", "E3NN"]
__all__ = ["E3NN", "GAT", "SchNet"]
155 changes: 155 additions & 0 deletions mtenn/conversion_utils/gat.py
@@ -0,0 +1,155 @@
"""
Representation and strategy for GAT model.
"""
from copy import deepcopy
import torch
from dgllife.model import GAT as GAT_dgl
from dgllife.model import WeightedSumAndMax

from ..model import (
BoltzmannCombination,
ConcatStrategy,
DeltaStrategy,
GroupedModel,
MeanCombination,
LigandOnlyModel,
PIC50Readout,
)


class GAT(torch.nn.Module):
def __init__(self, *args, model=None, **kwargs):
## If no model is passed, construct model based on passed args, otherwise copy
## all parameters and weights over
        super().__init__()
        if model is None:
            self.gnn = GAT_dgl(*args, **kwargs)
        else:
# Parameters that are conveniently accessible from the top level
in_feats = model.gnn_layers[0].gat_conv.fc.in_features
hidden_feats = model.hidden_feats
num_heads = model.num_heads
agg_modes = model.agg_modes
            # Parameters that can only be accessed layer-wise
layer_params = [
(
l.gat_conv.feat_drop.p,
l.gat_conv.attn_drop.p,
l.gat_conv.leaky_relu.negative_slope,
bool(l.gat_conv.res_fc),
l.gat_conv.activation,
bool(l.gat_conv.bias),
)
for l in model.gnn_layers
]
            (
                feat_drops,
                attn_drops,
                alphas,
                residuals,
                activations,
                biases,
            ) = zip(*layer_params)
self.gnn = GAT_dgl(
in_feats=in_feats,
hidden_feats=hidden_feats,
num_heads=num_heads,
feat_drops=feat_drops,
attn_drops=attn_drops,
alphas=alphas,
residuals=residuals,
agg_modes=agg_modes,
activations=activations,
biases=biases,
)
self.gnn.load_state_dict(model.state_dict())

        # Copied from the GATPredictor class: figure out how many features the
        # last layer of the GNN will output
if self.gnn.agg_modes[-1] == "flatten":
gnn_out_feats = self.gnn.hidden_feats[-1] * self.gnn.num_heads[-1]
else:
gnn_out_feats = self.gnn.hidden_feats[-1]
self.readout = WeightedSumAndMax(gnn_out_feats)

        # Use the given predictor_hidden_feats if supplied, otherwise gnn_out_feats // 2
if "predictor_hidden_feats" in kwargs:
predictor_hidden_feats = kwargs["predictor_hidden_feats"]
else:
predictor_hidden_feats = gnn_out_feats // 2

# 2 layer MLP with ReLU activation (borrowed from GATPredictor)
self.predict = torch.nn.Sequential(
torch.nn.Linear(2 * gnn_out_feats, predictor_hidden_feats),
torch.nn.ReLU(),
torch.nn.Linear(predictor_hidden_feats, 1),
)

def forward(self, data):
g = data["g"]
node_feats = self.gnn(g, g.ndata["h"])
graph_feats = self.readout(g, node_feats)
return self.predict(graph_feats)

def _get_representation(self):
"""
        Return a copy of the GNN portion of the model (without the readout or
        prediction layers), to be used as the Representation block.

        Returns
        -------
        dgllife.model.GAT
            Copy of the underlying GNN
"""

## Copy model so initial model isn't affected
model_copy = deepcopy(self.gnn)

return model_copy

def _get_energy_func(self):
"""
        Return the last two blocks of the model (readout and prediction MLP).

Returns
-------
torch.nn.Sequential
Sequential module calling copy of `model`'s last two layers
"""

return torch.nn.Sequential(deepcopy(self.readout), deepcopy(self.predict))

@staticmethod
def get_model(
*args,
model=None,
fix_device=False,
pred_readout=None,
**kwargs,
):
"""
Exposed function to build a LigandOnlyModel object from a GAT object
(or args/kwargs).
Parameters
----------
model: GAT, optional
            GAT model to use to build the LigandOnlyModel object. If left as
            None, a default model will be initialized and used
fix_device: bool, default=False
If True, make sure the input is on the same device as the model,
copying over as necessary.
        pred_readout: Readout, optional
            Readout object for the energy predictions
Returns
-------
LigandOnlyModel
LigandOnlyModel object containing the desired Representation and Strategy
"""
if model is None:
model = GAT(*args, **kwargs)

return LigandOnlyModel(model=model, readout=pred_readout, fix_device=fix_device)
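
For context, a minimal usage sketch of the new wrapper (not part of the diff; the SMILES string and layer sizes are illustrative, and it assumes dgllife's CanonicalAtomFeaturizer, which stores atom features under g.ndata["h"] as the forward pass expects):

from dgllife.utils import CanonicalAtomFeaturizer, smiles_to_bigraph

from mtenn.conversion_utils import GAT

# Featurize a molecule; features land in g.ndata["h"]
featurizer = CanonicalAtomFeaturizer()
g = smiles_to_bigraph("CCO", node_featurizer=featurizer)

# Standalone wrapper: args/kwargs are forwarded to dgllife's GAT
gat = GAT(in_feats=featurizer.feat_size(), hidden_feats=[32, 32])
pred = gat({"g": g})  # forward expects the graph under the "g" key

# Or wrap it in the Model API as a LigandOnlyModel
model = GAT.get_model(model=gat, pred_readout=None)
pred = model({"g": g})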
52 changes: 38 additions & 14 deletions mtenn/model.py
@@ -12,9 +12,7 @@ class Model(torch.nn.Module):
representations, and convert to a final scalar value.
"""

def __init__(
self, representation, strategy, readout=None, fix_device=False
):
def __init__(self, representation, strategy, readout=None, fix_device=False):
"""
Parameters
----------
@@ -51,9 +49,7 @@ def forward(self, comp, *parts):

if len(parts) == 0:
parts = Model._split_parts(tmp_comp)
parts_rep = [
self.get_representation(self._fix_device(p)) for p in parts
]
parts_rep = [self.get_representation(self._fix_device(p)) for p in parts]

energy_val = self.strategy(complex_rep, *parts_rep)
if self.readout:
@@ -205,6 +201,40 @@ def forward(self, input_list):
return comb_pred


class LigandOnlyModel(Model):
"""
A ligand-only version of the Model. In this case, the `representation` block will
hold the entire model, while the `strategy` block will simply be set as an Identity
module.
"""

def __init__(self, model, readout=None, fix_device=False):
"""
        Parameters
        ----------
        model: torch.nn.Module
            Model to hold as the Representation block; it is called directly
            on the input
        readout: Readout, optional
            Readout object to apply to the model output
        fix_device: bool, default=False
            If True, make sure the input is on the same device as the model,
            copying over as necessary.
"""
super(LigandOnlyModel, self).__init__(
representation=model,
strategy=torch.nn.Identity(),
readout=readout,
fix_device=fix_device,
)

def forward(self, rep):
## This implementation of the forward function assumes the
## get_representation function takes a single data object
tmp_rep = self._fix_device(rep)
pred = self.get_representation(tmp_rep)

if self.readout:
return self.readout(pred)
else:
return pred


class Representation(torch.nn.Module):
pass

@@ -234,9 +264,7 @@ def __init__(self, energy_func, pic50=True):

def forward(self, comp, *parts):
        ## Calculate delta G
return self.energy_func(comp) - sum(
[self.energy_func(p) for p in parts]
)
return self.energy_func(comp) - sum([self.energy_func(p) for p in parts])


class ConcatStrategy(Strategy):
@@ -367,8 +395,4 @@ def forward(self, delta_g):
## IC50 value = exp(dG/kT) => pic50 = -log10(exp(dg/kT))
## Rearrange a bit more to avoid disappearing floats:
## pic50 = -dg/kT / ln(10)
return (
-delta_g
/ self.kT
/ torch.log(torch.tensor(10, dtype=delta_g.dtype))
)
return -delta_g / self.kT / torch.log(torch.tensor(10, dtype=delta_g.dtype))
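
As a numeric check of the rearranged formula (a sketch, not part of the diff; the kT value here is an assumed 298 K constant in kcal/mol and may differ from mtenn's internal units):

import torch

kT = 0.593  # kcal/mol at ~298 K (assumed for illustration)
delta_g = torch.tensor(-8.0)  # kcal/mol
pic50 = -delta_g / kT / torch.log(torch.tensor(10, dtype=delta_g.dtype))
print(pic50)  # tensor(5.8589...) -> pIC50 of about 5.86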
