From 2de296a95f50db047f391524e2559dce3d70e3b4 Mon Sep 17 00:00:00 2001 From: RaulPPealez Date: Thu, 11 Apr 2024 13:29:58 +0200 Subject: [PATCH] Filter unused sample values in QM9 --- torchmdnet/datasets/qm9.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/torchmdnet/datasets/qm9.py b/torchmdnet/datasets/qm9.py index 6b97787da..92c165110 100644 --- a/torchmdnet/datasets/qm9.py +++ b/torchmdnet/datasets/qm9.py @@ -6,6 +6,7 @@ from torch_geometric.transforms import Compose from torch_geometric.datasets import QM9 as QM9_geometric from torch_geometric.nn.models.schnet import qm9_target_dict +from torch_geometric.data import Data class QM9(QM9_geometric): @@ -25,17 +26,27 @@ def __init__(self, root, transform=None, label=None): else: transform = Compose([transform, self._filter_label]) - super(QM9, self).__init__(root, transform=transform) + # Keep only pos, z and y in each sample + def pre_transform(x): + return Data( + pos=x.pos, + z=x.z, + y=x.y, + ) + + super(QM9, self).__init__( + root, transform=transform, pre_transform=pre_transform + ) def get_atomref(self, max_z=100): """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - Args: - max_z (int): Maximum atomic number + Args: + max_z (int): Maximum atomic number - Returns: - torch.Tensor: Atomic energy reference values for each element in the dataset. - """ + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ atomref = self.atomref(self.label_idx) if atomref is None: return None