diff --git a/torchmdnet/datasets/__init__.py b/torchmdnet/datasets/__init__.py
index b57cd95a..dff30016 100644
--- a/torchmdnet/datasets/__init__.py
+++ b/torchmdnet/datasets/__init__.py
@@ -23,6 +23,7 @@
 from .qm9q import QM9q
 from .spice import SPICE
 from .genentech import GenentechTorsions
+from .artificial_short_range import ShortRange
 
 __all__ = [
     "Ace",
@@ -47,4 +48,5 @@
     "SPICE",
     "Tripeptides",
     "WaterBox",
+    "ShortRange",
 ]
diff --git a/torchmdnet/datasets/artificial_short_range.py b/torchmdnet/datasets/artificial_short_range.py
new file mode 100644
index 00000000..38e964f0
--- /dev/null
+++ b/torchmdnet/datasets/artificial_short_range.py
@@ -0,0 +1,117 @@
+import numpy as np
+import torch
+from torch_geometric.data import Data, Dataset
+
+
+def random_vectors_in_sphere_box_muller(radius, count):
+    # Generate uniformly distributed random numbers for Box-Muller
+    u1 = np.random.uniform(low=0.0, high=1.0, size=count)
+    u2 = np.random.uniform(low=0.0, high=1.0, size=count)
+    u3 = np.random.uniform(low=0.0, high=1.0, size=count)
+    u4 = np.random.uniform(low=0.0, high=1.0, size=count)
+
+    # Box-Muller transform for normally distributed components; a second
+    # pair of uniforms supplies the third component so that all three are independent
+    normal1 = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
+    normal2 = np.sqrt(-2.0 * np.log(u1)) * np.sin(2.0 * np.pi * u2)
+    normal3 = np.sqrt(-2.0 * np.log(u3)) * np.cos(2.0 * np.pi * u4)
+
+    # Stack the normals
+    vectors = np.column_stack((normal1, normal2, normal3))
+
+    # Normalize each vector to unit length, giving uniformly distributed directions
+    norms = np.linalg.norm(vectors, axis=1)
+    vectors_normalized = vectors / norms[:, np.newaxis]
+
+    # Scale by random radii up to 'radius'; the cube root makes the points
+    # uniformly distributed in volume
+    scale = np.random.uniform(0, radius**3, count) ** (1 / 3)
+    vectors_scaled = vectors_normalized * scale[:, np.newaxis]
+
+    return vectors_scaled
+
+
+def compute_energy(pos, max_dist):
+    # pos has shape (size, 2, 3); the energy decreases linearly with the pair distance
+    dist = torch.linalg.norm(pos[:, 0, :] - pos[:, 1, :], dim=1)  # shape (size,)
+    y = 20 + 80 * (1 - dist / max_dist)  # shape (size,)
+    return y
+
+
+def compute_forces(pos, max_dist):
+    pos = pos.clone().detach().requires_grad_(True)
+    y = compute_energy(pos, max_dist)
+    y_sum = y.sum()
+    y_sum.backward()
+    forces = -pos.grad
+    return forces
+
+
+class ShortRange(Dataset):
+    def __init__(self, root, max_dist, size, max_z, transform=None, pre_transform=None):
+        super().__init__(root, transform, pre_transform)
+        self.max_dist = max_dist
+        self.size = size
+        self.max_z = max_z
+        # Generate random data in memory. The dataset consists of pairs of atoms
+        # with their positions, atomic numbers, energies and forces.
+        # Positions inside a sphere of radius max_dist
+        self.pos = random_vectors_in_sphere_box_muller(max_dist, 2 * size)
+        self.pos = self.pos.reshape(size, 2, 3)
+        # Atomic numbers
+        self.z = np.random.randint(1, max_z, size=2 * size).reshape(size, 2)
+        # Energy (multiplied by zero, so the stored labels are all zeros)
+        self.y = compute_energy(torch.tensor(self.pos), max_dist).detach().numpy() * 0
+        assert self.y.shape == (size,)
+        assert self.z.shape == (size, 2)
+        # Negative gradient of the energy w.r.t. the positions, same shape as pos (also zeroed)
+        self.neg_dy = (
+            compute_forces(torch.tensor(self.pos, dtype=torch.float), max_dist)
+            .detach()
+            .numpy()
+            * 0
+        )
+
+    def get(self, idx):
+        y = torch.tensor(self.y[idx], dtype=torch.float).view(1, 1)
+        z = torch.tensor(self.z[idx], dtype=torch.long).view(2)
+        pos = torch.tensor(self.pos[idx], dtype=torch.float).view(2, 3)
+        neg_dy = torch.tensor(self.neg_dy[idx], dtype=torch.float).view(2, 3)
+        data = Data(
+            z=z,
+            pos=pos,
+            y=y,
+            neg_dy=neg_dy,
+        )
+        return data
+
+    def len(self):
+        return self.size
+
+    # Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat
+    _ELEMENT_ENERGIES = {
+        1: -0.5978583943827134,  # H
+        6: -38.08933878049795,  # C
+        7: -54.711968298621066,  # N
+        8: -75.19106774742086,  # O
+        9: -99.80348506781634,  # F
+        16: -398.1577125334925,  # S
+        17: -460.1681939421027,  # Cl
+    }
+    HARTREE_TO_EV = 27.211386246  #::meta private:
+
+    def get_atomref(self, max_z=100):
+        """Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.
+
+        Args:
+            max_z (int): Maximum atomic number.
+
+        Returns:
+            torch.Tensor: Atomic energy reference values for each element in the dataset.
+        """
+        refs = torch.zeros(max_z)
+        for key, val in self._ELEMENT_ENERGIES.items():
+            # Multiplied by zero, so all references are zero
+            refs[key] = val * self.HARTREE_TO_EV * 0
+
+        return refs.view(-1, 1)
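Note: the toy potential above is E(d) = 20 + 80 * (1 - d / max_dist), so the force on each atom has constant magnitude 80 / max_dist directed along the interatomic axis. A minimal sanity check of compute_forces against this closed form (a standalone sketch, not part of the patch; it assumes the new module is importable as torchmdnet.datasets.artificial_short_range):

    import torch
    from torchmdnet.datasets.artificial_short_range import compute_forces

    max_dist = 5.0
    pos = torch.randn(4, 2, 3)  # 4 samples, 2 atoms each, 3D coordinates
    forces = compute_forces(pos, max_dist)
    # dE/dd = -80 / max_dist, so every force vector has magnitude 80 / max_dist
    expected = torch.full((4, 2), 80.0 / max_dist)
    assert torch.allclose(forces.norm(dim=-1), expected, atol=1e-5)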
diff --git a/torchmdnet/module.py b/torchmdnet/module.py
index d5ea73cf..8f663134 100644
--- a/torchmdnet/module.py
+++ b/torchmdnet/module.py
@@ -167,6 +167,29 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
         loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y)
         return {"y": loss_y, "neg_dy": loss_neg_y}
 
+    def _compute_second_derivative_regularization(self, y, neg_dy, batch):
+        # Penalize the norm of the gradient of the summed squared forces w.r.t. the positions
+        # Args:
+        #   y: predicted value
+        #   neg_dy: predicted negative derivative
+        #   batch: batch of data
+        # Returns:
+        #   regularization: regularization term
+        assert "pos" in batch
+        force_sum = (neg_dy**2).sum()
+        grad_outputs = [torch.ones_like(force_sum)]
+        assert batch.pos.requires_grad
+        ddy = torch.autograd.grad(
+            [force_sum],
+            [batch.pos],
+            grad_outputs=grad_outputs,
+            create_graph=True,
+            retain_graph=True,
+        )[0]
+        regularization = ddy.norm() * self.hparams.regularization_weight
+        print(f"Regularization: {regularization}")
+        return regularization
+
     def _update_loss_with_ema(self, stage, type, loss_name, loss):
         # Update the loss using an exponential moving average when applicable
         # Args:
@@ -235,6 +258,15 @@ def step(self, batch, loss_fn_list, stage):
                 step_losses["y"] * self.hparams.y_weight
                 + step_losses["neg_dy"] * self.hparams.neg_dy_weight
             )
+            if (
+                self.hparams.regularize_second_gradient
+                and self.hparams.derivative
+                and stage == "train"
+            ):
+                total_loss = (
+                    total_loss
+                    + self._compute_second_derivative_regularization(y, neg_dy, batch)
+                )
 
         self.losses[stage]["total"][loss_name].append(total_loss.detach())
         return total_loss
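Note: the regularizer computes the gradient of sum(neg_dy**2) with respect to the positions, which is a Hessian-vector product of the energy, so penalizing its norm does constrain the energy's second derivative; it requires batch.pos to be on the autograd graph, hence the coupling to the derivative flag. The same computation in isolation, on a toy quadratic energy (all names below are illustrative, not from the patch):

    import torch

    pos = torch.randn(5, 3, requires_grad=True)
    energy = (pos**2).sum()  # toy energy surface
    # Forces, kept on the graph so they can be differentiated again
    neg_dy = -torch.autograd.grad(energy, pos, create_graph=True)[0]
    force_sum = (neg_dy**2).sum()
    ddy = torch.autograd.grad(force_sum, pos, create_graph=True, retain_graph=True)[0]
    regularization = ddy.norm() * 0.1  # regularization_weight = 0.1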
diff --git a/torchmdnet/scripts/train.py b/torchmdnet/scripts/train.py
index 7f2d8e07..f4f08058 100644
--- a/torchmdnet/scripts/train.py
+++ b/torchmdnet/scripts/train.py
@@ -59,6 +59,9 @@ def get_argparse():
     parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
     parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
     parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')
+
+    parser.add_argument('--regularize-second-gradient', action='store_true', help='If true, regularize the second derivative of the energy w.r.t. the coordinates')
+    parser.add_argument('--regularization-weight', type=float, default=0.0, help='Weight of the second-derivative regularization term')
     # dataset specific
     parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
    parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
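A possible invocation exercising the new flags (the dataset arguments and hyperparameter values are illustrative, and this assumes the ShortRange constructor arguments can be passed through the existing --dataset-arg mechanism):

    python torchmdnet/scripts/train.py \
        --dataset ShortRange \
        --dataset-arg '{"max_dist": 5.0, "size": 1000, "max_z": 10}' \
        --derivative true \
        --regularize-second-gradient \
        --regularization-weight 0.1

Note that --regularization-weight defaults to 0.0, so enabling --regularize-second-gradient without also setting a positive weight makes the added term a no-op.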