Skip to content

Commit

Permalink
Filter non used sample values in QM9 (#316)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaulPPelaez authored Apr 18, 2024
1 parent 72d6e8e commit 0ed2e7c
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions torchmdnet/datasets/qm9.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from torch_geometric.transforms import Compose
from torch_geometric.datasets import QM9 as QM9_geometric
from torch_geometric.nn.models.schnet import qm9_target_dict
from torch_geometric.data import Data


class QM9(QM9_geometric):
Expand All @@ -25,17 +26,27 @@ def __init__(self, root, transform=None, label=None):
else:
transform = Compose([transform, self._filter_label])

super(QM9, self).__init__(root, transform=transform)
# Keep only pos, z and y in each sample
def pre_transform(x):
return Data(
pos=x.pos,
z=x.z,
y=x.y,
)

super(QM9, self).__init__(
root, transform=transform, pre_transform=pre_transform
)

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
Args:
max_z (int): Maximum atomic number
Args:
max_z (int): Maximum atomic number
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
atomref = self.atomref(self.label_idx)
if atomref is None:
return None
Expand Down

0 comments on commit 0ed2e7c

Please sign in to comment.