Skip to content

Commit

Permalink
Merge pull request #80 from mila-iqia/repulsive_force
Browse files Browse the repository at this point in the history
Repulsive force
  • Loading branch information
rousseab authored Sep 28, 2024
2 parents e9f18cc + cc3f750 commit 8c2c6d2
Show file tree
Hide file tree
Showing 4 changed files with 481 additions and 46 deletions.
2 changes: 1 addition & 1 deletion crystal_diffusion/generators/sde_position_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class SDESamplingParameters(SamplingParameters):
class SDE(torch.nn.Module):
"""SDE.
This class computes the drift and the diffusion coefficients in order to be consisent with the expectations
This class computes the drift and the diffusion coefficients in order to be consistent with the expectations
of the torchsde library.
"""
noise_type = 'diagonal' # we assume that there is a distinct Wiener process for each component.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
from dataclasses import dataclass
from typing import AnyStr, Dict, Optional

import einops
import torch

from crystal_diffusion.models.score_networks import ScoreNetwork
from crystal_diffusion.namespace import NOISY_RELATIVE_COORDINATES, UNIT_CELL
from crystal_diffusion.utils.basis_transformations import (
get_positions_from_coordinates, get_reciprocal_basis_vectors,
get_relative_coordinates_from_cartesian_positions)
from crystal_diffusion.utils.neighbors import (
AdjacencyInfo, get_periodic_adjacency_information)


@dataclass(kw_only=True)
class ForceFieldParameters:
"""Force field parameters.
The force field is based on a potential of the form:
phi(r) = strength * (r - radial_cutoff)^2
The corresponding force is thus of the form
F(r) = -nabla phi(r) = -2 strength * ( r - radial_cutoff) r_hat.
"""

radial_cutoff: float # Cutoff to the interaction, in Angstrom
strength: float # Strength of the repulsion


class ForceFieldAugmentedScoreNetwork(torch.nn.Module):
"""Force Field-Augmented Score Network.
This class wraps around an arbitrary score network in order to augment
its output with an effective "force field". The intuition behind this is that
atoms should never be very close to each other, but random numbers can lead
to such proximity: a repulsive force field will encourage atoms to separate during
diffusion.
"""
def __init__(
self, score_network: ScoreNetwork, force_field_parameters: ForceFieldParameters
):
"""Init method.
Args:
score_network : a score network, to be augmented with a repulsive force.
force_field_parameters : parameters for the repulsive force.
"""
super().__init__()

self._score_network = score_network
self._force_field_parameters = force_field_parameters

def forward(
self, batch: Dict[AnyStr, torch.Tensor], conditional: Optional[bool] = None
) -> torch.Tensor:
"""Model forward.
Args:
batch : dictionary containing the data to be processed by the model.
conditional: if True, do a conditional forward, if False, do a unconditional forward. If None, choose
randomly with probability conditional_prob
Returns:
computed_scores : the scores computed by the model.
"""
raw_scores = self._score_network(batch, conditional)
forces = self.get_relative_coordinates_pseudo_force(batch)
return raw_scores + forces

def _get_cartesian_pseudo_forces_contributions(
self, cartesian_displacements: torch.Tensor
):
"""Get cartesian pseudo forces.
The potential is given by
phi(r) = s * (r - r0)^2
Args:
cartesian_displacements : vectors (r_i - r_j). Dimension [number_of_edges, spatial_dimension]
Returns:
cartesian_pseudo_forces_contributions: Force contributions for each displacement, for the
chosen potential. F(r_i - r_j) = - d/dr phi(r) (r_i - r_j) / ||r_i - r_j||
"""
s = self._force_field_parameters.strength
r0 = self._force_field_parameters.radial_cutoff

number_of_edges, spatial_dimension = cartesian_displacements.shape

r = torch.linalg.norm(cartesian_displacements, dim=1)

# Add a small epsilon value in case r is close to zero, to avoid NaNs.
epsilon = torch.tensor(1.0e-8).to(r)

pseudo_force_prefactors = 2.0 * s * (r - r0) / (r + epsilon)
# Repeat so we can multiply by r_hat
repeat_pseudo_force_prefactors = einops.repeat(
pseudo_force_prefactors, "e -> e d", d=spatial_dimension
)
contributions = repeat_pseudo_force_prefactors * cartesian_displacements
return contributions

def _get_adjacency_information(
self, batch: Dict[AnyStr, torch.Tensor]
) -> AdjacencyInfo:
basis_vectors = batch[UNIT_CELL]
relative_coordinates = batch[NOISY_RELATIVE_COORDINATES]
cartesian_positions = get_positions_from_coordinates(
relative_coordinates, basis_vectors
)

adj_info = get_periodic_adjacency_information(
cartesian_positions,
basis_vectors,
radial_cutoff=self._force_field_parameters.radial_cutoff,
)
return adj_info

def _get_cartesian_displacements(
self, adj_info: AdjacencyInfo, batch: Dict[AnyStr, torch.Tensor]
):
# The following are 1D arrays of length equal to the total number of neighbors for all batch elements
# and all atoms.
# bch: which batch does an edge belong to
# src: at which atom does an edge start
# dst: at which atom does an edge end
bch = adj_info.edge_batch_indices
src, dst = adj_info.adjacency_matrix

relative_coordinates = batch[NOISY_RELATIVE_COORDINATES]
basis_vectors = batch[UNIT_CELL]
cartesian_positions = get_positions_from_coordinates(
relative_coordinates, basis_vectors
)

cartesian_displacements = (
cartesian_positions[bch, dst]
- cartesian_positions[bch, src]
+ adj_info.shifts
)
return cartesian_displacements

def _get_cartesian_pseudo_forces(
self,
cartesian_pseudo_force_contributions: torch.Tensor,
adj_info: AdjacencyInfo,
batch: Dict[AnyStr, torch.Tensor],
):
# The following are 1D arrays of length equal to the total number of neighbors for all batch elements
# and all atoms.
# bch: which batch does an edge belong to
# src: at which atom does an edge start
# dst: at which atom does an edge end
bch = adj_info.edge_batch_indices
src, dst = adj_info.adjacency_matrix

batch_size, natoms, spatial_dimension = batch[NOISY_RELATIVE_COORDINATES].shape

# Combine the bch and src index into a single global index
node_idx = natoms * bch + src

list_pseudo_force_components = []

for space_idx in range(spatial_dimension):
pseudo_force_component = torch.zeros(natoms * batch_size).to(cartesian_pseudo_force_contributions)
pseudo_force_component.scatter_add_(
dim=0,
index=node_idx,
src=cartesian_pseudo_force_contributions[:, space_idx],
)
list_pseudo_force_components.append(pseudo_force_component)

cartesian_pseudo_forces = einops.rearrange(
list_pseudo_force_components,
pattern="d (b n) -> b n d",
b=batch_size,
n=natoms,
)
return cartesian_pseudo_forces

def get_relative_coordinates_pseudo_force(
self, batch: Dict[AnyStr, torch.Tensor]
) -> torch.Tensor:
"""Get relative coordinates pseudo force.
Args:
batch : dictionary containing the data to be processed by the model.
Returns:
relative_pseudo_forces : repulsive force in relative coordinates.
"""
adj_info = self._get_adjacency_information(batch)

cartesian_displacements = self._get_cartesian_displacements(adj_info, batch)
cartesian_pseudo_force_contributions = (
self._get_cartesian_pseudo_forces_contributions(cartesian_displacements)
)

cartesian_pseudo_forces = self._get_cartesian_pseudo_forces(
cartesian_pseudo_force_contributions, adj_info, batch
)

basis_vectors = batch[UNIT_CELL]
reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors)
relative_pseudo_forces = get_relative_coordinates_from_cartesian_positions(
cartesian_pseudo_forces, reciprocal_basis_vectors
)

return relative_pseudo_forces
93 changes: 48 additions & 45 deletions examples/config_files/diffusion/config_diffusion_egnn.yaml
Original file line number Diff line number Diff line change
@@ -1,91 +1,94 @@
# general
exp_name: dev_debug
run_name: run1
max_epoch: 100
max_epoch: 50
log_every_n_steps: 1
gradient_clipping: 0
gradient_clipping: 0.0
accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step


# set to null to avoid setting a seed (can speed up GPU computation, but
# results will not be reproducible)
seed: 1234

# data
data:
batch_size: 1024
num_workers: 0
max_atom: 8
batch_size: 128
num_workers: 8
max_atom: 64

# architecture
spatial_dimension: 3
model:
loss:
algorithm: mse
score_network:
architecture: egnn
message_n_hidden_dimensions: 1
message_hidden_dimensions_size: 16
node_n_hidden_dimensions: 1
node_hidden_dimensions_size: 32
coordinate_n_hidden_dimensions: 1
coordinate_hidden_dimensions_size: 32
residual: True
n_layers: 4
coordinate_hidden_dimensions_size: 128
coordinate_n_hidden_dimensions: 4
coords_agg: "mean"
message_hidden_dimensions_size: 128
message_n_hidden_dimensions: 4
node_hidden_dimensions_size: 128
node_n_hidden_dimensions: 4
attention: False
normalize: False
normalize: True
residual: True
tanh: False
coords_agg: mean
n_layers: 4
noise:
total_time_steps: 100
sigma_min: 0.005 # default value
sigma_max: 0.5 # default value'
total_time_steps: 1000
sigma_min: 0.0001
sigma_max: 0.2
corrector_step_epsilon: 2.0e-7

# optimizer and scheduler
optimizer:
name: adamw
learning_rate: 0.001
weight_decay: 1.0e-6
weight_decay: 5.0e-8


scheduler:
name: ReduceLROnPlateau
factor: 0.1
patience: 10
name: CosineAnnealingLR
T_max: 50
eta_min: 0.0

# early stopping
early_stopping:
metric: validation_epoch_loss
mode: min
patience: 10
patience: 100

model_checkpoint:
monitor: validation_epoch_loss
monitor: validation_ks_distance_structure
mode: min

# A callback to check the loss vs. sigma
loss_monitoring:
number_of_bins: 50
sample_every_n_epochs: 1

# Sampling from the generative model
diffusion_sampling:
noise:
total_time_steps: 100
sigma_min: 0.001 # default value
sigma_max: 0.5 # default value
total_time_steps: 1000
sigma_min: 0.0001
sigma_max: 0.2
corrector_step_epsilon: 2.0e-7
sampling:
algorithm: predictor_corrector
number_of_corrector_steps: 1
spatial_dimension: 3
number_of_atoms: 8
number_of_samples: 128
sample_batchsize: 128
sample_every_n_epochs: 1
record_samples: True
cell_dimensions: [5.43, 5.43, 5.43]
spatial_dimension: 3
number_of_corrector_steps: 1
number_of_atoms: 64
number_of_samples: 32
record_samples: False
cell_dimensions: [10.86, 10.86, 10.86]
metrics:
compute_energies: True
compute_structure_factor: True
structure_factor_max_distance: 10.0

sampling_visualization:
record_every_n_epochs: 1
first_record_epoch: 1
record_trajectories: False
record_energies: True
record_structure: True


logging:
# - comet
- tensorboard
#- csv
- comet
Loading

0 comments on commit 8c2c6d2

Please sign in to comment.