Force gradient regularization #331

Draft: wants to merge 5 commits into main
2 changes: 2 additions & 0 deletions torchmdnet/datasets/__init__.py
@@ -23,6 +23,7 @@
from .qm9q import QM9q
from .spice import SPICE
from .genentech import GenentechTorsions
from .artificial_short_range import ShortRange

__all__ = [
"Ace",
@@ -47,4 +48,5 @@
"SPICE",
"Tripeptides",
"WaterBox",
"ShortRange",
]
117 changes: 117 additions & 0 deletions torchmdnet/datasets/artificial_short_range.py
@@ -0,0 +1,117 @@
import numpy as np
import torch
from torch_geometric.data import Dataset, Data


def random_vectors_in_sphere_box_muller(radius, count):
# Generate uniformly distributed random numbers for Box-Muller
u1 = np.random.uniform(low=0.0, high=1.0, size=count)
u2 = np.random.uniform(low=0.0, high=1.0, size=count)
u3 = np.random.uniform(low=0.0, high=1.0, size=count)
u4 = np.random.uniform(low=0.0, high=1.0, size=count)

# Box-Muller transform for normal distribution
normal1 = np.sqrt(-2.0 * np.log(u1)) * np.cos(2.0 * np.pi * u2)
normal2 = np.sqrt(-2.0 * np.log(u1)) * np.sin(2.0 * np.pi * u2)
normal3 = np.sqrt(-2.0 * np.log(u3)) * np.cos(
2.0 * np.pi * u4
)  # Use an independent uniform for the third component so it is uncorrelated with normal1

# Stack the normals
vectors = np.column_stack((normal1, normal2, normal3))

# Normalize each vector to have magnitude 1
norms = np.linalg.norm(vectors, axis=1)
vectors_normalized = vectors / norms[:, np.newaxis]

# Scale vectors by random radii up to 'radius'
scale = np.random.uniform(0, radius**3, count) ** (
1 / 3
) # Cube root to ensure uniform distribution in volume
vectors_scaled = vectors_normalized * scale[:, np.newaxis]

return vectors_scaled


def compute_energy(pos, max_dist):
dist = torch.linalg.norm(pos[:, 0, :] - pos[:, 1, :], dim=1)  # shape (size,)
y = 20 + 80 * (1 - dist / max_dist) # shape (size,)
return y


def compute_forces(pos, max_dist):
pos = pos.clone().detach().requires_grad_(True)
y = compute_energy(pos, max_dist)
y_sum = y.sum()
y_sum.backward()
forces = -pos.grad
return forces


class ShortRange(Dataset):
def __init__(self, root, max_dist, size, max_z, transform=None, pre_transform=None):
super(ShortRange, self).__init__(root, transform, pre_transform)
self.max_dist = max_dist
self.size = size
self.max_z = max_z
# Generate random data in memory. The dataset consists of pairs of atoms with their positions, atomic numbers, energy and forces.
# Positions are sampled inside a sphere of radius max_dist.
self.pos = random_vectors_in_sphere_box_muller(max_dist, 2 * size)
self.pos = self.pos.reshape(size, 2, 3)
# Atomic numbers
self.z = np.random.randint(1, max_z, size=2 * size).reshape(size, 2)
# Energy
self.y = compute_energy(torch.tensor(self.pos), max_dist).detach().numpy() * 0
assert self.y.shape == (size,)
assert self.z.shape == (size, 2)
# Negative gradient of the energy with respect to the positions, should have the same shape as pos
self.neg_dy = (
compute_forces(torch.tensor(self.pos, dtype=torch.float), max_dist)
.detach()
.numpy()
* 0
)

def get(self, idx):
y = torch.tensor(self.y[idx], dtype=torch.float).view(1, 1)
z = torch.tensor(self.z[idx], dtype=torch.long).view(2)
pos = torch.tensor(self.pos[idx], dtype=torch.float).view(2, 3)
neg_dy = torch.tensor(self.neg_dy[idx], dtype=torch.float).view(2, 3)
data = Data(
z=z,
pos=pos,
y=y,
neg_dy=neg_dy,
)
return data

def len(self):
return self.size

# Taken from https://github.com/isayev/ASE_ANI/blob/master/ani_models/ani-2x_8x/sae_linfit.dat

_ELEMENT_ENERGIES = {
1: -0.5978583943827134, # H
6: -38.08933878049795, # C
7: -54.711968298621066, # N
8: -75.19106774742086, # O
9: -99.80348506781634, # F
16: -398.1577125334925, # S
17: -460.1681939421027, # Cl
}
HARTREE_TO_EV = 27.211386246 #::meta private:

def get_atomref(self, max_z=100):
"""Atomic energy reference values for the :py:mod:`torchmdnet.priors.Atomref` prior.

Args:
max_z (int): Maximum atomic number

Returns:
torch.Tensor: Atomic energy reference values for each element in the dataset.
"""
refs = torch.zeros(max_z)
for key, val in self._ELEMENT_ENERGIES.items():
refs[key] = val * self.HARTREE_TO_EV * 0

return refs.view(-1, 1)
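
For context (not part of the diff), a minimal usage sketch of the synthetic dataset, assuming root points at a writable scratch directory; the shape checks and the analytic force check follow from the code above:

import torch
from torchmdnet.datasets import ShortRange
from torchmdnet.datasets.artificial_short_range import compute_forces

max_dist = 5.0
dataset = ShortRange(root="/tmp/short_range", max_dist=max_dist, size=100, max_z=10)

sample = dataset.get(0)
print(sample.z.shape, sample.pos.shape, sample.y.shape, sample.neg_dy.shape)
# torch.Size([2]) torch.Size([2, 3]) torch.Size([1, 1]) torch.Size([2, 3])

# Both atoms are sampled inside a sphere of radius max_dist around the origin.
assert torch.linalg.norm(sample.pos, dim=1).max() <= max_dist + 1e-6

# For the linear pair potential above, the force on atom 0 is analytically
# (80 / max_dist) * (p0 - p1) / |p0 - p1|, a constant-magnitude repulsion.
pos = sample.pos.unsqueeze(0)  # shape (1, 2, 3)
forces = compute_forces(pos, max_dist)
r = pos[:, 0, :] - pos[:, 1, :]
expected = 80.0 / max_dist * r / torch.linalg.norm(r, dim=1, keepdim=True)
assert torch.allclose(forces[:, 0, :], expected, atol=1e-4)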
32 changes: 32 additions & 0 deletions torchmdnet/module.py
@@ -167,6 +167,29 @@ def _compute_losses(self, y, neg_y, batch, loss_fn, stage):
loss_y = self._update_loss_with_ema(stage, "y", loss_name, loss_y)
return {"y": loss_y, "neg_dy": loss_neg_y}

def _compute_second_derivative_regularization(self, y, neg_dy, batch):
# Regularize the second derivative of the energy by penalizing the gradient of the squared force norm:
# reg = || d (neg_dy**2).sum() / d pos || * regularization_weight
# Args:
# y: predicted value
# neg_dy: predicted negative derivative
# batch: batch of data
# Returns:
# regularization: regularization term
assert "pos" in batch
force_sum = (neg_dy**2).sum()
grad_outputs = [torch.ones_like(force_sum)]
assert batch.pos.requires_grad
ddy = torch.autograd.grad(
[force_sum],
[batch.pos],
grad_outputs=grad_outputs,
create_graph=True,
retain_graph=True,
)[0]
regularization = ddy.norm() * self.hparams.regularization_weight
print(f"Regularization: {regularization}")
return regularization

def _update_loss_with_ema(self, stage, type, loss_name, loss):
# Update the loss using an exponential moving average when applicable
# Args:
@@ -235,6 +258,15 @@ def step(self, batch, loss_fn_list, stage):
step_losses["y"] * self.hparams.y_weight
+ step_losses["neg_dy"] * self.hparams.neg_dy_weight
)
if (
self.hparams.regularize_second_gradient
and self.hparams.derivative
and stage == "train"
):
total_loss = (
total_loss
+ self._compute_second_derivative_regularization(y, neg_dy, batch)
)
self.losses[stage]["total"][loss_name].append(total_loss.detach())
return total_loss

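Outside the Lightning module, the new term reduces to a second application of autograd. A minimal self-contained sketch (a toy quartic energy in place of the TorchMD-Net model; regularization_weight stands in for hparams.regularization_weight) of the quantity added to the training loss:

import torch

# Toy setup: positions with gradients enabled and a smooth scalar "energy".
pos = torch.randn(8, 3, requires_grad=True)
energy = (pos ** 4).sum()

# Forces as the negative gradient; create_graph=True keeps the graph so it can be differentiated again.
neg_dy = -torch.autograd.grad(energy, pos, create_graph=True)[0]

# Gradient of the squared force norm w.r.t. the positions (a second derivative of the energy);
# its norm, scaled by the weight, is the regularization term.
force_sum = (neg_dy ** 2).sum()
ddy = torch.autograd.grad(force_sum, pos, create_graph=True)[0]
regularization_weight = 0.1
regularization = ddy.norm() * regularization_weight
print(regularization)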
3 changes: 3 additions & 0 deletions torchmdnet/scripts/train.py
@@ -59,6 +59,9 @@ def get_argparse():
parser.add_argument('--redirect', type=bool, default=False, help='Redirect stdout and stderr to log_dir/log')
parser.add_argument('--gradient-clipping', type=float, default=0.0, help='Gradient clipping norm')
parser.add_argument('--remove-ref-energy', action='store_true', help='If true, remove the reference energy from the dataset for delta-learning. Total energy can still be predicted by the model during inference by turning this flag off when loading. The dataset must be compatible with Atomref for this to be used.')

parser.add_argument('--regularize-second-gradient', action="store_true", help='If true, regularize the second derivative of the energy w.r.t. the coordinates')
parser.add_argument('--regularization-weight', type=float, default=0.0, help='Weight for the force regularization term')
# dataset specific
parser.add_argument('--dataset', default=None, type=str, choices=datasets.__all__, help='Name of the torch_geometric dataset')
parser.add_argument('--dataset-root', default='~/data', type=str, help='Data storage directory (not used if dataset is "CG")')
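
For reference, a self-contained sketch (plain argparse, not the project's get_argparse) of how the two new flags parse; names, types and defaults mirror the lines added above:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--regularize-second-gradient', action='store_true',
                    help='If true, regularize the second derivative of the energy w.r.t. the coordinates')
parser.add_argument('--regularization-weight', type=float, default=0.0,
                    help='Weight for the force regularization term')

args = parser.parse_args(['--regularize-second-gradient', '--regularization-weight', '0.1'])
print(args.regularize_second_gradient, args.regularization_weight)  # True 0.1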