-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #80 from mila-iqia/repulsive_force
Repulsive force
- Loading branch information
Showing
4 changed files
with
481 additions
and
46 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
211 changes: 211 additions & 0 deletions
211
crystal_diffusion/models/score_networks/force_field_augmented_score_network.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
from dataclasses import dataclass | ||
from typing import AnyStr, Dict, Optional | ||
|
||
import einops | ||
import torch | ||
|
||
from crystal_diffusion.models.score_networks import ScoreNetwork | ||
from crystal_diffusion.namespace import NOISY_RELATIVE_COORDINATES, UNIT_CELL | ||
from crystal_diffusion.utils.basis_transformations import ( | ||
get_positions_from_coordinates, get_reciprocal_basis_vectors, | ||
get_relative_coordinates_from_cartesian_positions) | ||
from crystal_diffusion.utils.neighbors import ( | ||
AdjacencyInfo, get_periodic_adjacency_information) | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class ForceFieldParameters: | ||
"""Force field parameters. | ||
The force field is based on a potential of the form: | ||
phi(r) = strength * (r - radial_cutoff)^2 | ||
The corresponding force is thus of the form | ||
F(r) = -nabla phi(r) = -2 strength * ( r - radial_cutoff) r_hat. | ||
""" | ||
|
||
radial_cutoff: float # Cutoff to the interaction, in Angstrom | ||
strength: float # Strength of the repulsion | ||
|
||
|
||
class ForceFieldAugmentedScoreNetwork(torch.nn.Module): | ||
"""Force Field-Augmented Score Network. | ||
This class wraps around an arbitrary score network in order to augment | ||
its output with an effective "force field". The intuition behind this is that | ||
atoms should never be very close to each other, but random numbers can lead | ||
to such proximity: a repulsive force field will encourage atoms to separate during | ||
diffusion. | ||
""" | ||
def __init__( | ||
self, score_network: ScoreNetwork, force_field_parameters: ForceFieldParameters | ||
): | ||
"""Init method. | ||
Args: | ||
score_network : a score network, to be augmented with a repulsive force. | ||
force_field_parameters : parameters for the repulsive force. | ||
""" | ||
super().__init__() | ||
|
||
self._score_network = score_network | ||
self._force_field_parameters = force_field_parameters | ||
|
||
def forward( | ||
self, batch: Dict[AnyStr, torch.Tensor], conditional: Optional[bool] = None | ||
) -> torch.Tensor: | ||
"""Model forward. | ||
Args: | ||
batch : dictionary containing the data to be processed by the model. | ||
conditional: if True, do a conditional forward, if False, do a unconditional forward. If None, choose | ||
randomly with probability conditional_prob | ||
Returns: | ||
computed_scores : the scores computed by the model. | ||
""" | ||
raw_scores = self._score_network(batch, conditional) | ||
forces = self.get_relative_coordinates_pseudo_force(batch) | ||
return raw_scores + forces | ||
|
||
def _get_cartesian_pseudo_forces_contributions( | ||
self, cartesian_displacements: torch.Tensor | ||
): | ||
"""Get cartesian pseudo forces. | ||
The potential is given by | ||
phi(r) = s * (r - r0)^2 | ||
Args: | ||
cartesian_displacements : vectors (r_i - r_j). Dimension [number_of_edges, spatial_dimension] | ||
Returns: | ||
cartesian_pseudo_forces_contributions: Force contributions for each displacement, for the | ||
chosen potential. F(r_i - r_j) = - d/dr phi(r) (r_i - r_j) / ||r_i - r_j|| | ||
""" | ||
s = self._force_field_parameters.strength | ||
r0 = self._force_field_parameters.radial_cutoff | ||
|
||
number_of_edges, spatial_dimension = cartesian_displacements.shape | ||
|
||
r = torch.linalg.norm(cartesian_displacements, dim=1) | ||
|
||
# Add a small epsilon value in case r is close to zero, to avoid NaNs. | ||
epsilon = torch.tensor(1.0e-8).to(r) | ||
|
||
pseudo_force_prefactors = 2.0 * s * (r - r0) / (r + epsilon) | ||
# Repeat so we can multiply by r_hat | ||
repeat_pseudo_force_prefactors = einops.repeat( | ||
pseudo_force_prefactors, "e -> e d", d=spatial_dimension | ||
) | ||
contributions = repeat_pseudo_force_prefactors * cartesian_displacements | ||
return contributions | ||
|
||
def _get_adjacency_information( | ||
self, batch: Dict[AnyStr, torch.Tensor] | ||
) -> AdjacencyInfo: | ||
basis_vectors = batch[UNIT_CELL] | ||
relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] | ||
cartesian_positions = get_positions_from_coordinates( | ||
relative_coordinates, basis_vectors | ||
) | ||
|
||
adj_info = get_periodic_adjacency_information( | ||
cartesian_positions, | ||
basis_vectors, | ||
radial_cutoff=self._force_field_parameters.radial_cutoff, | ||
) | ||
return adj_info | ||
|
||
def _get_cartesian_displacements( | ||
self, adj_info: AdjacencyInfo, batch: Dict[AnyStr, torch.Tensor] | ||
): | ||
# The following are 1D arrays of length equal to the total number of neighbors for all batch elements | ||
# and all atoms. | ||
# bch: which batch does an edge belong to | ||
# src: at which atom does an edge start | ||
# dst: at which atom does an edge end | ||
bch = adj_info.edge_batch_indices | ||
src, dst = adj_info.adjacency_matrix | ||
|
||
relative_coordinates = batch[NOISY_RELATIVE_COORDINATES] | ||
basis_vectors = batch[UNIT_CELL] | ||
cartesian_positions = get_positions_from_coordinates( | ||
relative_coordinates, basis_vectors | ||
) | ||
|
||
cartesian_displacements = ( | ||
cartesian_positions[bch, dst] | ||
- cartesian_positions[bch, src] | ||
+ adj_info.shifts | ||
) | ||
return cartesian_displacements | ||
|
||
def _get_cartesian_pseudo_forces( | ||
self, | ||
cartesian_pseudo_force_contributions: torch.Tensor, | ||
adj_info: AdjacencyInfo, | ||
batch: Dict[AnyStr, torch.Tensor], | ||
): | ||
# The following are 1D arrays of length equal to the total number of neighbors for all batch elements | ||
# and all atoms. | ||
# bch: which batch does an edge belong to | ||
# src: at which atom does an edge start | ||
# dst: at which atom does an edge end | ||
bch = adj_info.edge_batch_indices | ||
src, dst = adj_info.adjacency_matrix | ||
|
||
batch_size, natoms, spatial_dimension = batch[NOISY_RELATIVE_COORDINATES].shape | ||
|
||
# Combine the bch and src index into a single global index | ||
node_idx = natoms * bch + src | ||
|
||
list_pseudo_force_components = [] | ||
|
||
for space_idx in range(spatial_dimension): | ||
pseudo_force_component = torch.zeros(natoms * batch_size).to(cartesian_pseudo_force_contributions) | ||
pseudo_force_component.scatter_add_( | ||
dim=0, | ||
index=node_idx, | ||
src=cartesian_pseudo_force_contributions[:, space_idx], | ||
) | ||
list_pseudo_force_components.append(pseudo_force_component) | ||
|
||
cartesian_pseudo_forces = einops.rearrange( | ||
list_pseudo_force_components, | ||
pattern="d (b n) -> b n d", | ||
b=batch_size, | ||
n=natoms, | ||
) | ||
return cartesian_pseudo_forces | ||
|
||
def get_relative_coordinates_pseudo_force( | ||
self, batch: Dict[AnyStr, torch.Tensor] | ||
) -> torch.Tensor: | ||
"""Get relative coordinates pseudo force. | ||
Args: | ||
batch : dictionary containing the data to be processed by the model. | ||
Returns: | ||
relative_pseudo_forces : repulsive force in relative coordinates. | ||
""" | ||
adj_info = self._get_adjacency_information(batch) | ||
|
||
cartesian_displacements = self._get_cartesian_displacements(adj_info, batch) | ||
cartesian_pseudo_force_contributions = ( | ||
self._get_cartesian_pseudo_forces_contributions(cartesian_displacements) | ||
) | ||
|
||
cartesian_pseudo_forces = self._get_cartesian_pseudo_forces( | ||
cartesian_pseudo_force_contributions, adj_info, batch | ||
) | ||
|
||
basis_vectors = batch[UNIT_CELL] | ||
reciprocal_basis_vectors = get_reciprocal_basis_vectors(basis_vectors) | ||
relative_pseudo_forces = get_relative_coordinates_from_cartesian_positions( | ||
cartesian_pseudo_forces, reciprocal_basis_vectors | ||
) | ||
|
||
return relative_pseudo_forces |
93 changes: 48 additions & 45 deletions
93
examples/config_files/diffusion/config_diffusion_egnn.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,91 +1,94 @@ | ||
# general | ||
exp_name: dev_debug | ||
run_name: run1 | ||
max_epoch: 100 | ||
max_epoch: 50 | ||
log_every_n_steps: 1 | ||
gradient_clipping: 0 | ||
gradient_clipping: 0.0 | ||
accumulate_grad_batches: 1 # make this number of forward passes before doing a backprop step | ||
|
||
|
||
# set to null to avoid setting a seed (can speed up GPU computation, but | ||
# results will not be reproducible) | ||
seed: 1234 | ||
|
||
# data | ||
data: | ||
batch_size: 1024 | ||
num_workers: 0 | ||
max_atom: 8 | ||
batch_size: 128 | ||
num_workers: 8 | ||
max_atom: 64 | ||
|
||
# architecture | ||
spatial_dimension: 3 | ||
model: | ||
loss: | ||
algorithm: mse | ||
score_network: | ||
architecture: egnn | ||
message_n_hidden_dimensions: 1 | ||
message_hidden_dimensions_size: 16 | ||
node_n_hidden_dimensions: 1 | ||
node_hidden_dimensions_size: 32 | ||
coordinate_n_hidden_dimensions: 1 | ||
coordinate_hidden_dimensions_size: 32 | ||
residual: True | ||
n_layers: 4 | ||
coordinate_hidden_dimensions_size: 128 | ||
coordinate_n_hidden_dimensions: 4 | ||
coords_agg: "mean" | ||
message_hidden_dimensions_size: 128 | ||
message_n_hidden_dimensions: 4 | ||
node_hidden_dimensions_size: 128 | ||
node_n_hidden_dimensions: 4 | ||
attention: False | ||
normalize: False | ||
normalize: True | ||
residual: True | ||
tanh: False | ||
coords_agg: mean | ||
n_layers: 4 | ||
noise: | ||
total_time_steps: 100 | ||
sigma_min: 0.005 # default value | ||
sigma_max: 0.5 # default value' | ||
total_time_steps: 1000 | ||
sigma_min: 0.0001 | ||
sigma_max: 0.2 | ||
corrector_step_epsilon: 2.0e-7 | ||
|
||
# optimizer and scheduler | ||
optimizer: | ||
name: adamw | ||
learning_rate: 0.001 | ||
weight_decay: 1.0e-6 | ||
weight_decay: 5.0e-8 | ||
|
||
|
||
scheduler: | ||
name: ReduceLROnPlateau | ||
factor: 0.1 | ||
patience: 10 | ||
name: CosineAnnealingLR | ||
T_max: 50 | ||
eta_min: 0.0 | ||
|
||
# early stopping | ||
early_stopping: | ||
metric: validation_epoch_loss | ||
mode: min | ||
patience: 10 | ||
patience: 100 | ||
|
||
model_checkpoint: | ||
monitor: validation_epoch_loss | ||
monitor: validation_ks_distance_structure | ||
mode: min | ||
|
||
# A callback to check the loss vs. sigma | ||
loss_monitoring: | ||
number_of_bins: 50 | ||
sample_every_n_epochs: 1 | ||
|
||
# Sampling from the generative model | ||
diffusion_sampling: | ||
noise: | ||
total_time_steps: 100 | ||
sigma_min: 0.001 # default value | ||
sigma_max: 0.5 # default value | ||
total_time_steps: 1000 | ||
sigma_min: 0.0001 | ||
sigma_max: 0.2 | ||
corrector_step_epsilon: 2.0e-7 | ||
sampling: | ||
algorithm: predictor_corrector | ||
number_of_corrector_steps: 1 | ||
spatial_dimension: 3 | ||
number_of_atoms: 8 | ||
number_of_samples: 128 | ||
sample_batchsize: 128 | ||
sample_every_n_epochs: 1 | ||
record_samples: True | ||
cell_dimensions: [5.43, 5.43, 5.43] | ||
spatial_dimension: 3 | ||
number_of_corrector_steps: 1 | ||
number_of_atoms: 64 | ||
number_of_samples: 32 | ||
record_samples: False | ||
cell_dimensions: [10.86, 10.86, 10.86] | ||
metrics: | ||
compute_energies: True | ||
compute_structure_factor: True | ||
structure_factor_max_distance: 10.0 | ||
|
||
sampling_visualization: | ||
record_every_n_epochs: 1 | ||
first_record_epoch: 1 | ||
record_trajectories: False | ||
record_energies: True | ||
record_structure: True | ||
|
||
|
||
logging: | ||
# - comet | ||
- tensorboard | ||
#- csv | ||
- comet |
Oops, something went wrong.