pad fixing for unimol and docking pose (#211)
* Pad fixing for unimol and docking pose

* Address review comments on the pad fixing

* Fill the diagonal with 0 in docking_utils.py, because the model does not learn to predict a 0 distance

* Remove the -1 on atom_num in the else branch

---------

Co-authored-by: nicolas.brosse <[email protected]>
nbrosse and nicolas.brosse authored Apr 26, 2024
1 parent 37c6ddc commit 7394647
Showing 4 changed files with 27 additions and 11 deletions.
15 changes: 10 additions & 5 deletions unimol/unimol/losses/docking_pose.py
@@ -1,7 +1,7 @@
# Copyright (c) DP Technology.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from unicore import metrics
from unicore.losses import UnicoreLoss, register_loss
@@ -11,6 +11,9 @@
class DockingPossLoss(UnicoreLoss):
    def __init__(self, task):
        super().__init__(task)
+        self.eos_idx = task.dictionary.eos()
+        self.bos_idx = task.dictionary.bos()
+        self.padding_idx = task.dictionary.pad()

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.
@@ -24,7 +27,8 @@ def forward(self, model, sample, reduce=True):
        cross_distance_predict, holo_distance_predict = net_outputs[0], net_outputs[1]

        ### distance loss
-        distance_mask = sample["target"]["distance_target"].ne(0) # 0 is padding
+        distance_mask = sample["target"]["distance_target"].ne(0) # 0 for padding, BOS and EOS
+        # 0 is impossible in the cross distance matrix, all the relevant cross distances are kept
        if self.args.dist_threshold > 0:
            distance_mask &= (
                sample["target"]["distance_target"] < self.args.dist_threshold
@@ -36,9 +40,10 @@
        )

        ### holo distance loss
-        holo_distance_mask = sample["target"]["holo_distance_target"].ne(
-            0
-        ) # 0 is padding
+        token_mask = sample["net_input"]["mol_src_tokens"].ne(self.padding_idx) & \
+            sample["net_input"]["mol_src_tokens"].ne(self.eos_idx) & \
+            sample["net_input"]["mol_src_tokens"].ne(self.bos_idx)
+        holo_distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1)
        holo_distance_predict_train = holo_distance_predict[holo_distance_mask]
        holo_distance_target = sample["target"]["holo_distance_target"][
            holo_distance_mask
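Below is a minimal, self-contained sketch of the new holo-distance masking. The token ids and special-token indices (bos_idx=0, padding_idx=1, eos_idx=2) are invented for illustration; in the patch they come from task.dictionary.bos()/eos()/pad().

import torch

# Hypothetical special-token indices, for illustration only.
bos_idx, eos_idx, padding_idx = 0, 2, 1

# One molecule with 3 real atoms: [BOS, atom, atom, atom, EOS, PAD]
mol_src_tokens = torch.tensor([[0, 10, 11, 12, 2, 1]])

# Keep only real atom positions (drop padding, BOS and EOS).
token_mask = (
    mol_src_tokens.ne(padding_idx)
    & mol_src_tokens.ne(eos_idx)
    & mol_src_tokens.ne(bos_idx)
)  # [[False, True, True, True, False, False]]

# Pairwise mask: entry (i, j) survives only if both i and j are real atoms.
holo_distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1)
print(holo_distance_mask.sum().item())  # 9 = 3 x 3 surviving atom pairs
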
10 changes: 7 additions & 3 deletions unimol/unimol/losses/unimol.py
@@ -162,14 +162,18 @@ def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False
        masked_distance_target = sample[target_key]["distance_target"][
            dist_masked_tokens
        ]
-        non_pad_pos = masked_distance_target > 0
+        # padding distance
+        nb_masked_tokens = dist_masked_tokens.sum(dim=-1)
+        masked_src_tokens = sample["net_input"]["src_tokens"].ne(self.padding_idx)
+        masked_src_tokens_expanded = torch.repeat_interleave(masked_src_tokens, nb_masked_tokens, dim=0)
+        #
        if normalize:
            masked_distance_target = (
                masked_distance_target.float() - self.dist_mean
            ) / self.dist_std
        masked_dist_loss = F.smooth_l1_loss(
-            masked_distance[non_pad_pos].view(-1).float(),
-            masked_distance_target[non_pad_pos].view(-1),
+            masked_distance[masked_src_tokens_expanded].view(-1).float(),
+            masked_distance_target[masked_src_tokens_expanded].view(-1),
            reduction="mean",
            beta=1.0,
        )
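A toy example (masks invented for illustration) of why torch.repeat_interleave yields one padding row per masked token, in the same order that boolean indexing flattens the batch:

import torch

# Toy batch: 2 molecules, 4 token slots each (True = real token, False = padding).
masked_src_tokens = torch.tensor([
    [True, True, True, False],   # molecule 0: 3 real tokens
    [True, True, False, False],  # molecule 1: 2 real tokens
])

# Hypothetical masked positions used for the distance loss.
dist_masked_tokens = torch.tensor([
    [True, False, True, False],   # 2 masked tokens in molecule 0
    [False, True, False, False],  # 1 masked token in molecule 1
])

# dist[dist_masked_tokens, :] flattens to one row per masked token, so the
# padding mask must be repeated once per masked token of each molecule.
nb_masked_tokens = dist_masked_tokens.sum(dim=-1)  # tensor([2, 1])
expanded = torch.repeat_interleave(masked_src_tokens, nb_masked_tokens, dim=0)
print(expanded.shape)  # torch.Size([3, 4]): rows 0-1 from molecule 0, row 2 from molecule 1
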
10 changes: 7 additions & 3 deletions unimol/unimol/models/unimol.py
@@ -229,14 +229,18 @@ def get_dist_features(dist, et):
        if self.args.masked_coord_loss > 0:
            coords_emb = src_coord
            if padding_mask is not None:
-                atom_num = (torch.sum(1 - padding_mask.type_as(x), dim=1) - 1).view(
+                atom_num = torch.sum(1 - padding_mask.type_as(x), dim=1).view(
                    -1, 1, 1, 1
-                )
+                ) # consider BOS and EOS as part of the object
            else:
-                atom_num = src_coord.shape[1] - 1
+                atom_num = src_coord.shape[1]
            delta_pos = coords_emb.unsqueeze(1) - coords_emb.unsqueeze(2)
            attn_probs = self.pair2coord_proj(delta_encoder_pair_rep)
            coord_update = delta_pos / atom_num * attn_probs
+            # Mask padding
+            pair_coords_mask = (1 - padding_mask.float()).unsqueeze(-1) * (1 - padding_mask.float()).unsqueeze(1)
+            coord_update = coord_update * pair_coords_mask.unsqueeze(-1)
+            #
            coord_update = torch.sum(coord_update, dim=2)
            encoder_coord = coords_emb + coord_update
        if self.args.masked_dist_loss > 0:
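A small sketch (toy shapes, all values invented) of how the new pairwise mask zeroes coordinate updates that involve padded positions before the sum over dim=2:

import torch

B, N = 1, 4
padding_mask = torch.tensor([[False, False, True, True]])  # last 2 slots are padding

# Pretend pairwise coordinate updates of shape (B, N, N, 3).
coord_update = torch.ones(B, N, N, 3)

# Zero every pair (i, j) in which either position is padding.
keep = 1 - padding_mask.float()                            # (B, N), 1 for real atoms
pair_coords_mask = keep.unsqueeze(-1) * keep.unsqueeze(1)  # (B, N, N)
coord_update = coord_update * pair_coords_mask.unsqueeze(-1)

# Each real atom now sums contributions from the 2 real atoms only.
print(coord_update.sum(dim=2)[0, :, 0])  # tensor([2., 2., 0., 0.])
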
3 changes: 3 additions & 0 deletions unimol/unimol/utils/docking_utils.py
@@ -180,6 +180,9 @@ def docking_data_pre(raw_data_path, predict_path):
            pocket_coords = pocket_coords.numpy().astype(np.float32)
            distance_predict = distance_predict.numpy().astype(np.float32)
            holo_distance_predict = holo_distance_predict.numpy().astype(np.float32)
+            # Fill the diagonal with 0: the model does not learn to predict a 0 self-distance
+            np.fill_diagonal(holo_distance_predict, 0)  # in-place; np.fill_diagonal returns None
+            #
            holo_coords = holo_coordinates.numpy().astype(np.float32)

            pocket_coords_list.append(pocket_coords)
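For reference, a tiny standalone example (toy values) of the diagonal fix: np.fill_diagonal mutates the array in place and returns None, so its result must not be assigned back to the array it modifies.

import numpy as np

# Toy predicted intra-molecular distance matrix.
holo_distance_predict = np.array([[0.3, 1.5],
                                  [1.5, 0.3]], dtype=np.float32)

# In-place: zero the self-distances the model never learns to predict.
np.fill_diagonal(holo_distance_predict, 0)
print(holo_distance_predict)
# [[0.  1.5]
#  [1.5 0. ]]
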
