Skip to content

Commit

Permalink
Merge pull request #113 from raimis/ds_spice_2
Browse files Browse the repository at this point in the history
Improved SPICE loader
  • Loading branch information
Raimondas Galvelis authored Aug 24, 2022
2 parents e3014b9 + 5f5d660 commit db72e12
Showing 1 changed file with 55 additions and 21 deletions.
76 changes: 55 additions & 21 deletions torchmdnet/datasets/spice.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,65 @@
import hashlib
import h5py
import numpy as np
import os
import torch as pt
from torch_geometric.data import Data, Dataset
from torch_geometric.data import Data, Dataset, download_url
from tqdm import tqdm


class SPICE(Dataset):

"""
SPICE dataset (https://github.com/openmm/spice-dataset)
The dataset consists of several subsets (https://github.com/openmm/spice-dataset/blob/main/downloader/config.yaml).
The subsets can be selected with `subsets`. By default, all the subsets are loaded.
For example, this loads just two subsets:
>>> ds = SPICE(".", subsets=["SPICE PubChem Set 1 Single Points Dataset v1.2", "SPICE PubChem Set 2 Single Points Dataset v1.2"])
The loader can filter conformations with large gradients. The maximum gradient norm threshold
can be set with `max_gradient`. By default, the filter is not applied.
For example, here the threshold is set to 100 eV/A:
>>> ds = SPICE(".", max_gradient=100)
"""

HARTREE_TO_EV = 27.211386246
BORH_TO_ANGSTROM = 0.529177

@property
def raw_file_names(self):
return "SPICE.hdf5"

@property
def raw_url(self):
return f"https://github.com/openmm/spice-dataset/releases/download/1.0/{self.raw_file_names}"

@property
def processed_file_names(self):
return [
f"{self.name}.idx.mmap",
f"{self.name}.z.mmap",
f"{self.name}.pos.mmap",
f"{self.name}.y.mmap",
f"{self.name}.dy.mmap",
]

def __init__(
self,
root=None,
transform=None,
pre_transform=None,
pre_filter=None,
paths=None,
subsets=None,
max_gradient=None,
):

self.name = self.__class__.__name__
self.paths = str(paths)
arg_hash = f"{subsets}{max_gradient}"
arg_hash = hashlib.md5(arg_hash.encode()).hexdigest()
self.name = f"{self.__class__.__name__}-{arg_hash}"
self.subsets = subsets
self.max_gradient = max_gradient
super().__init__(root, transform, pre_transform, pre_filter)

idx_name, z_name, pos_name, y_name, dy_name = self.processed_paths
Expand All @@ -43,13 +77,15 @@ def __init__(
assert self.idx_mm[-1] == len(self.z_mm)
assert len(self.idx_mm) == len(self.y_mm) + 1

@property
def raw_paths(self):
    # Overrides torch_geometric's Dataset.raw_paths and returns self.paths
    # unchanged. NOTE(review): __init__ sets self.paths = str(paths), so this
    # yields a plain string, whereas the parent property yields a list of
    # paths; sample_iter() does len(self.raw_paths) == 1 and self.raw_paths[0],
    # which index a string character-wise — confirm the intended type here.
    return self.paths

def sample_iter(self):

for mol in tqdm(h5py.File(self.raw_paths).values(), desc="Molecules"):
assert len(self.raw_paths) == 1

for mol in tqdm(h5py.File(self.raw_paths[0]).values(), desc="Molecules"):

if self.subsets:
if mol["subset"][0].decode() not in list(self.subsets):
continue

z = pt.tensor(mol["atomic_numbers"], dtype=pt.long)
all_pos = (
Expand Down Expand Up @@ -77,8 +113,9 @@ def sample_iter(self):
for pos, y, dy in zip(all_pos, all_y, all_dy):

# Skip samples with large forces
if dy.norm(dim=1).max() > 100: # eV/A
continue
if self.max_gradient:
if dy.norm(dim=1).max() > float(self.max_gradient):
continue

data = Data(z=z, pos=pos, y=y.view(1, 1), dy=dy)

Expand All @@ -90,18 +127,15 @@ def sample_iter(self):

yield data

@property
def processed_file_names(self):
    # Memory-mapped cache files produced by process(); self.name embeds an
    # MD5 hash of the constructor arguments (see __init__), so different
    # (subsets, max_gradient) settings map to distinct cache files.
    return [
        f"{self.name}.idx.mmap",
        f"{self.name}.z.mmap",
        f"{self.name}.pos.mmap",
        f"{self.name}.y.mmap",
        f"{self.name}.dy.mmap",
    ]
def download(self):
    """Fetch the raw HDF5 release file into ``self.raw_dir``."""
    url = self.raw_url
    download_url(url, self.raw_dir)

def process(self):

print("Arguments")
print(f" subsets: {self.subsets}")
print(f" max_gradient: {self.max_gradient} eV/A\n")

print("Gathering statistics...")
num_all_confs = 0
num_all_atoms = 0
Expand Down

0 comments on commit db72e12

Please sign in to comment.