diff --git a/torchmdnet/datasets/qm9.py b/torchmdnet/datasets/qm9.py index 6b97787d..92c16511 100644 --- a/torchmdnet/datasets/qm9.py +++ b/torchmdnet/datasets/qm9.py @@ -6,6 +6,7 @@ from torch_geometric.transforms import Compose from torch_geometric.datasets import QM9 as QM9_geometric from torch_geometric.nn.models.schnet import qm9_target_dict +from torch_geometric.data import Data class QM9(QM9_geometric): @@ -25,17 +26,27 @@ def __init__(self, root, transform=None, label=None): else: transform = Compose([transform, self._filter_label]) - super(QM9, self).__init__(root, transform=transform) + # Keep only pos, z and y in each sample + def pre_transform(x): + return Data( + pos=x.pos, + z=x.z, + y=x.y, + ) + + super(QM9, self).__init__( + root, transform=transform, pre_transform=pre_transform + ) def get_atomref(self, max_z=100): """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior. - Args: - max_z (int): Maximum atomic number + Args: + max_z (int): Maximum atomic number - Returns: - torch.Tensor: Atomic energy reference values for each element in the dataset. - """ + Returns: + torch.Tensor: Atomic energy reference values for each element in the dataset. + """ atomref = self.atomref(self.label_idx) if atomref is None: return None