diff --git a/torchmdnet/datasets/spice.py b/torchmdnet/datasets/spice.py
index de385d63e..3bda70c20 100644
--- a/torchmdnet/datasets/spice.py
+++ b/torchmdnet/datasets/spice.py
@@ -1,8 +1,9 @@
+import hashlib
 import h5py
 import numpy as np
 import os
 import torch as pt
-from torch_geometric.data import Data, Dataset
+from torch_geometric.data import Data, Dataset, download_url
 from tqdm import tqdm
 
 
@@ -10,22 +11,55 @@ class SPICE(Dataset):
     """
     SPICE dataset (https://github.com/openmm/spice-dataset)
+
+    The dataset consists of several subsets (https://github.com/openmm/spice-dataset/blob/main/downloader/config.yaml).
+    The subsets can be selected with `subsets`. By default, all the subsets are loaded.
+
+    For example, this loads just two subsets:
+    >>> ds = SPICE(".", subsets=["SPICE PubChem Set 1 Single Points Dataset v1.2", "SPICE PubChem Set 2 Single Points Dataset v1.2"])
+
+    The loader can filter conformations with large gradients. The maximum gradient norm threshold
+    can be set with `max_gradient`. By default, the filter is not applied.
+
+    For example, here the threshold is set to 100 eV/A:
+    >>> ds = SPICE(".", max_gradient=100)
     """
 
     HARTREE_TO_EV = 27.211386246
     BORH_TO_ANGSTROM = 0.529177
 
+    @property
+    def raw_file_names(self):
+        return "SPICE.hdf5"
+
+    @property
+    def raw_url(self):
+        return f"https://github.com/openmm/spice-dataset/releases/download/1.0/{self.raw_file_names}"
+
+    @property
+    def processed_file_names(self):
+        return [
+            f"{self.name}.idx.mmap",
+            f"{self.name}.z.mmap",
+            f"{self.name}.pos.mmap",
+            f"{self.name}.y.mmap",
+            f"{self.name}.dy.mmap",
+        ]
+
     def __init__(
         self,
         root=None,
         transform=None,
         pre_transform=None,
         pre_filter=None,
-        paths=None,
+        subsets=None,
+        max_gradient=None,
     ):
-
-        self.name = self.__class__.__name__
-        self.paths = str(paths)
+        arg_hash = f"{subsets}{max_gradient}"
+        arg_hash = hashlib.md5(arg_hash.encode()).hexdigest()
+        self.name = f"{self.__class__.__name__}-{arg_hash}"
+        self.subsets = subsets
+        self.max_gradient = max_gradient
         super().__init__(root, transform, pre_transform, pre_filter)
 
         idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
@@ -43,13 +77,15 @@ def __init__(
         assert self.idx_mm[-1] == len(self.z_mm)
         assert len(self.idx_mm) == len(self.y_mm) + 1
 
-    @property
-    def raw_paths(self):
-        return self.paths
-
     def sample_iter(self):
 
-        for mol in tqdm(h5py.File(self.raw_paths).values(), desc="Molecules"):
+        assert len(self.raw_paths) == 1
+
+        for mol in tqdm(h5py.File(self.raw_paths[0]).values(), desc="Molecules"):
+
+            if self.subsets:
+                if mol["subset"][0].decode() not in list(self.subsets):
+                    continue
 
             z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
             all_pos = (
@@ -77,8 +113,9 @@ def sample_iter(self):
             for pos, y, dy in zip(all_pos, all_y, all_dy):
 
                 # Skip samples with large forces
-                if dy.norm(dim=1).max() > 100:  # eV/A
-                    continue
+                if self.max_gradient:
+                    if dy.norm(dim=1).max() > float(self.max_gradient):
+                        continue
 
                 data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy)
 
@@ -90,18 +127,15 @@ def sample_iter(self):
 
             yield data
 
-    @property
-    def processed_file_names(self):
-        return [
-            f"{self.name}.idx.mmap",
-            f"{self.name}.z.mmap",
-            f"{self.name}.pos.mmap",
-            f"{self.name}.y.mmap",
-            f"{self.name}.dy.mmap",
-        ]
+    def download(self):
+        download_url(self.raw_url, self.raw_dir)
 
     def process(self):
 
+        print("Arguments")
+        print(f"  subsets: {self.subsets}")
+        print(f"  max_gradient: {self.max_gradient} eV/A\n")
+
         print("Gathering statistics...")
 
         num_all_confs = 0
         num_all_atoms = 0
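A minimal usage sketch of the revised loader, in the doctest style of the docstring above. The subset name is taken from that docstring; the `root` directory "data" and the 50 eV/A threshold are illustrative choices, not defaults:

>>> from torchmdnet.datasets.spice import SPICE
>>> # First use downloads SPICE.hdf5 into root/raw, then writes the
>>> # processed memmap files under a name derived from the argument hash.
>>> ds = SPICE(
...     root="data",
...     subsets=["SPICE PubChem Set 1 Single Points Dataset v1.2"],
...     max_gradient=50,  # eV/A; conformations with a larger force norm are skipped
... )
>>> ds[0]  # Data(z=[n_atoms], pos=[n_atoms, 3], y=[1, 1], dy=[n_atoms, 3])

Because the processed file names embed an MD5 hash of `subsets` and `max_gradient`, each argument combination gets its own cache, so changing either value triggers a fresh `process()` run instead of silently reusing memmaps built with different filters.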