diff --git a/unimol/unimol/losses/docking_pose.py b/unimol/unimol/losses/docking_pose.py index e9b87c0..b541f70 100644 --- a/unimol/unimol/losses/docking_pose.py +++ b/unimol/unimol/losses/docking_pose.py @@ -1,7 +1,7 @@ # Copyright (c) DP Technology. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. - +import torch import torch.nn.functional as F from unicore import metrics from unicore.losses import UnicoreLoss, register_loss @@ -11,6 +11,9 @@ class DockingPossLoss(UnicoreLoss): def __init__(self, task): super().__init__(task) + self.eos_idx = task.dictionary.eos() + self.bos_idx = task.dictionary.bos() + self.padding_idx = task.dictionary.pad() def forward(self, model, sample, reduce=True): """Compute the loss for the given sample. @@ -24,7 +27,8 @@ def forward(self, model, sample, reduce=True): cross_distance_predict, holo_distance_predict = net_outputs[0], net_outputs[1] ### distance loss - distance_mask = sample["target"]["distance_target"].ne(0) # 0 is padding + distance_mask = sample["target"]["distance_target"].ne(0) # 0 for padding, BOS and EOS + # 0 is impossible in the cross distance matrix, all the relevant cross distances are kept if self.args.dist_threshold > 0: distance_mask &= ( sample["target"]["distance_target"] < self.args.dist_threshold @@ -36,9 +40,10 @@ def forward(self, model, sample, reduce=True): ) ### holo distance loss - holo_distance_mask = sample["target"]["holo_distance_target"].ne( - 0 - ) # 0 is padding + token_mask = sample["net_input"]["mol_src_tokens"].ne(self.padding_idx) & \ + sample["net_input"]["mol_src_tokens"].ne(self.eos_idx) & \ + sample["net_input"]["mol_src_tokens"].ne(self.bos_idx) + holo_distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1) holo_distance_predict_train = holo_distance_predict[holo_distance_mask] holo_distance_target = sample["target"]["holo_distance_target"][ holo_distance_mask diff --git a/unimol/unimol/losses/unimol.py b/unimol/unimol/losses/unimol.py index f377540..ef58333 100644 --- a/unimol/unimol/losses/unimol.py +++ b/unimol/unimol/losses/unimol.py @@ -162,14 +162,18 @@ def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False masked_distance_target = sample[target_key]["distance_target"][ dist_masked_tokens ] - non_pad_pos = masked_distance_target > 0 + # padding distance + nb_masked_tokens = dist_masked_tokens.sum(dim=-1) + masked_src_tokens = sample["net_input"]["src_tokens"].ne(self.padding_idx) + masked_src_tokens_expanded = torch.repeat_interleave(masked_src_tokens, nb_masked_tokens, dim=0) + # if normalize: masked_distance_target = ( masked_distance_target.float() - self.dist_mean ) / self.dist_std masked_dist_loss = F.smooth_l1_loss( - masked_distance[non_pad_pos].view(-1).float(), - masked_distance_target[non_pad_pos].view(-1), + masked_distance[masked_src_tokens_expanded].view(-1).float(), + masked_distance_target[masked_src_tokens_expanded].view(-1), reduction="mean", beta=1.0, ) diff --git a/unimol/unimol/models/unimol.py b/unimol/unimol/models/unimol.py index aa9644b..ae8329c 100644 --- a/unimol/unimol/models/unimol.py +++ b/unimol/unimol/models/unimol.py @@ -229,14 +229,18 @@ def get_dist_features(dist, et): if self.args.masked_coord_loss > 0: coords_emb = src_coord if padding_mask is not None: - atom_num = (torch.sum(1 - padding_mask.type_as(x), dim=1) - 1).view( + atom_num = torch.sum(1 - padding_mask.type_as(x), dim=1).view( -1, 1, 1, 1 - ) + ) # consider BOS and EOS as part of the object else: - atom_num = src_coord.shape[1] - 1 + atom_num = src_coord.shape[1] delta_pos = coords_emb.unsqueeze(1) - coords_emb.unsqueeze(2) attn_probs = self.pair2coord_proj(delta_encoder_pair_rep) coord_update = delta_pos / atom_num * attn_probs + # Mask padding + pair_coords_mask = (1 - padding_mask.float()).unsqueeze(-1) * (1 - padding_mask.float()).unsqueeze(1) + coord_update = coord_update * pair_coords_mask.unsqueeze(-1) + # coord_update = torch.sum(coord_update, dim=2) encoder_coord = coords_emb + coord_update if self.args.masked_dist_loss > 0: diff --git a/unimol/unimol/utils/docking_utils.py b/unimol/unimol/utils/docking_utils.py index cc63a8c..af3b31b 100644 --- a/unimol/unimol/utils/docking_utils.py +++ b/unimol/unimol/utils/docking_utils.py @@ -180,6 +180,9 @@ def docking_data_pre(raw_data_path, predict_path): pocket_coords = pocket_coords.numpy().astype(np.float32) distance_predict = distance_predict.numpy().astype(np.float32) holo_distance_predict = holo_distance_predict.numpy().astype(np.float32) + # Fill diagonal with 0, issue with the model not learning to predict 0 distance + holo_distance_predict = np.fill_diagonal(holo_distance_predict, 0) + # holo_coords = holo_coordinates.numpy().astype(np.float32) pocket_coords_list.append(pocket_coords)