remove_torque option added.
knc6 committed Jan 11, 2025
1 parent a56933b commit e5622f7
Showing 2 changed files with 113 additions and 1 deletion.
9 changes: 8 additions & 1 deletion alignn/models/ealignn_atomwise.py
@@ -21,6 +21,7 @@
compute_pair_vector_and_distance,
MLPLayer,
lightweight_line_graph,
remove_net_torque,
)
from alignn.graphs import compute_bond_cosines
from alignn.utils import BaseSettings
@@ -45,7 +46,8 @@ class eALIGNNAtomWiseConfig(BaseSettings):
stresswise_weight: float = 0.0
atomwise_weight: float = 0.0
classification: bool = False
energy_mult_natoms: bool = True
energy_mult_natoms: bool = True # Make it false for regression only
remove_torque: bool = True
inner_cutoff: float = 2.8  # Angstrom
use_penalty: bool = True
extra_features: int = 0
@@ -390,6 +392,11 @@ def forward(
forces = torch.squeeze(
g.ndata["forces_ji"] - rg.ndata["forces_ij"]
)
if self.config.remove_torque:
# print('forces1',forces,forces.shape)
# print('natoms',natoms,natoms.shape)
forces = remove_net_torque(g, forces, natoms)
# print('forces2',forces,forces.shape)

if self.config.stresswise_weight != 0:
stresses = []
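A minimal usage sketch of the new flag (not part of this diff); it assumes the value of the model-name field and that the remaining eALIGNNAtomWiseConfig fields keep the defaults shown above:

from alignn.models.ealignn_atomwise import eALIGNNAtomWiseConfig

# Hypothetical override: keep net-torque removal on (its default in this commit is True).
config = eALIGNNAtomWiseConfig(
    name="ealignn_atomwise",  # assumed required model-name value; not shown in this hunk
    remove_torque=True,       # new flag added by this commit
    energy_mult_natoms=True,  # set False for energy-only regression, per the comment above
)
print(config.remove_torque, config.inner_cutoff)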
105 changes: 105 additions & 0 deletions alignn/models/utils.py
@@ -5,6 +5,7 @@
import torch
import torch.nn as nn
import dgl
from typing import Tuple


class RBFExpansion(nn.Module):
@@ -171,3 +172,107 @@ def forward(self, x):
"""Linear, Batchnorm, silu layer."""
# print('xtype',x.dtype)
return self.layer(x)


def compute_net_torque(
positions: torch.Tensor, forces: torch.Tensor, n_nodes: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Compute the net torque on a system of particles."""
total_mass = n_nodes.float().sum() # Total number of particles
com = torch.sum(positions, dim=0) / total_mass # Center of mass

# Compute the relative positions of particles with respect to CoM
com_repeat = com.repeat(
positions.size(0), 1
) # Repeat CoM for each particle
com_relative_positions = (
positions - com_repeat
) # Relative position to the CoM

# Compute individual torques (cross product of r_i and F_i)
torques = torch.cross(com_relative_positions, forces) # Shape: (N, 3)

# Aggregate torques to get the net torque (sum all torques)
net_torque = torch.sum(torques, dim=0) # Sum of all individual torques

return net_torque, com_relative_positions


def remove_net_torque(
g: dgl.DGLGraph,
forces: torch.Tensor,
n_nodes: torch.Tensor,
) -> torch.Tensor:
"""Adjust the predicted forces to eliminate net torque for each graph in the batch.
Args:
g : dgl.DGLGraph
The graph representing a batch of particles (atoms).
forces : torch.Tensor of shape (N, 3)
Predicted forces on atoms.
n_nodes : torch.Tensor of shape (B,)
Number of nodes in each graph, where B is the number of graphs in the batch.
Returns:
adjusted_forces : torch.Tensor of shape (N, 3)
Adjusted forces with zero net torque and net force for each graph.
"""
# Step 1: Get positions from the graph (assuming 'cart_coords' holds the positions)
positions = g.ndata["cart_coords"]

# Compute the net torque and relative positions
tau_total, r = compute_net_torque(positions, forces, n_nodes)

# Step 2: Compute scalar s per graph: sum_i ||r_i||^2
r_squared = torch.sum(r**2, dim=1) # Shape: (N,)

# Sum over nodes to aggregate r_squared for each graph
s = torch.zeros(n_nodes.size(0), device=positions.device)
start_idx = 0
for i, num_nodes in enumerate(n_nodes):
end_idx = start_idx + num_nodes
s[i] = torch.sum(r_squared[start_idx:end_idx])
start_idx = end_idx

# Step 3: Compute matrix S per graph: sum_i outer(r_i, r_i)
r_unsqueezed = r.unsqueeze(2) # Shape: (N, 3, 1)
r_T_unsqueezed = r.unsqueeze(1) # Shape: (N, 1, 3)
outer_products = r_unsqueezed @ r_T_unsqueezed # Shape: (N, 3, 3)

# Aggregate outer products for each graph
S = torch.zeros(n_nodes.size(0), 3, 3, device=positions.device)
start_idx = 0
for i, num_nodes in enumerate(n_nodes):
end_idx = start_idx + num_nodes
S[i] = torch.sum(outer_products[start_idx:end_idx], dim=0)
start_idx = end_idx

# Step 4: Compute M = S - sI
I = (
torch.eye(3, device=positions.device)
.unsqueeze(0)
.expand(n_nodes.size(0), -1, -1)
) # Identity matrix
M = S - s.view(-1, 1, 1) * I # Shape: (B, 3, 3)

# Step 5: Right-hand side vector b per graph
b = -tau_total # Shape: (B, 3)

# Step 6: Solve M * mu = b for mu per graph
try:
mu = torch.linalg.solve(
M, b
) # Shape: (B, 3) -- No need for unsqueeze(2)
except RuntimeError:
# Handle singular matrix M by using the pseudo-inverse
M_pinv = torch.linalg.pinv(M) # Shape: (B, 3, 3)
mu = torch.bmm(M_pinv, b.unsqueeze(2)).squeeze(2) # Shape: (B, 3)

# Step 7: Compute adjustments to forces
mu_batch = torch.repeat_interleave(mu, n_nodes, dim=0) # Shape: (N, 3)
forces_delta = torch.cross(r, mu_batch) # Shape: (N, 3)

# Step 8: Adjust forces
adjusted_forces = forces + forces_delta # Shape: (N, 3)

return adjusted_forces
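For reference, the linear system solved in Steps 2-7 follows from requiring zero net torque after adding a correction \Delta f_i = r_i \times \mu to each atom (a sketch of the algebra, not stated in the commit):

\tau + \sum_i r_i \times (r_i \times \mu)
    = \tau + \sum_i \left( r_i r_i^{\top} - \lVert r_i \rVert^2 I \right) \mu = 0
\;\Longrightarrow\; (S - s I)\,\mu = -\tau ,
\qquad S = \sum_i r_i r_i^{\top}, \quad s = \sum_i \lVert r_i \rVert^2 .

Because the r_i are measured from the centroid (so they sum to zero for a single structure), the corrections also cancel, \sum_i \Delta f_i = \left(\sum_i r_i\right) \times \mu = 0, leaving the net force unchanged.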

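A small smoke test (again a sketch, not part of the commit) that calls the new helpers directly; it assumes a single structure with at least three non-collinear atoms so that M is invertible:

import torch
import dgl
from alignn.models.utils import compute_net_torque, remove_net_torque

# Hypothetical example: one structure with four atoms and random positions/forces.
n = 4
g = dgl.graph(([], []), num_nodes=n)  # edges are not used by remove_net_torque
g.ndata["cart_coords"] = torch.randn(n, 3)
forces = torch.randn(n, 3)
n_nodes = torch.tensor([n])

adjusted = remove_net_torque(g, forces, n_nodes)
residual, _ = compute_net_torque(g.ndata["cart_coords"], adjusted, n_nodes)
print(residual)  # expected to be close to zero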