From e5622f75316be808dc1a22932ed04a87f0267285 Mon Sep 17 00:00:00 2001 From: knc6 Date: Sat, 11 Jan 2025 14:33:59 -0500 Subject: [PATCH] Remove torque option added. --- alignn/models/ealignn_atomwise.py | 9 ++- alignn/models/utils.py | 105 ++++++++++++++++++++++++++++++ 2 files changed, 113 insertions(+), 1 deletion(-) diff --git a/alignn/models/ealignn_atomwise.py b/alignn/models/ealignn_atomwise.py index 3aae354..0318138 100644 --- a/alignn/models/ealignn_atomwise.py +++ b/alignn/models/ealignn_atomwise.py @@ -21,6 +21,7 @@ compute_pair_vector_and_distance, MLPLayer, lightweight_line_graph, + remove_net_torque, ) from alignn.graphs import compute_bond_cosines from alignn.utils import BaseSettings @@ -45,7 +46,8 @@ class eALIGNNAtomWiseConfig(BaseSettings): stresswise_weight: float = 0.0 atomwise_weight: float = 0.0 classification: bool = False - energy_mult_natoms: bool = True + energy_mult_natoms: bool = True # Make it false for regression only + remove_torque: bool = True inner_cutoff: float = 2.8 # Ansgtrom use_penalty: bool = True extra_features: int = 0 @@ -390,6 +392,11 @@ def forward( forces = torch.squeeze( g.ndata["forces_ji"] - rg.ndata["forces_ij"] ) + if self.config.remove_torque: + # print('forces1',forces,forces.shape) + # print('natoms',natoms,natoms.shape) + forces = remove_net_torque(g, forces, natoms) + # print('forces2',forces,forces.shape) if self.config.stresswise_weight != 0: stresses = [] diff --git a/alignn/models/utils.py b/alignn/models/utils.py index 184f583..acd5e33 100644 --- a/alignn/models/utils.py +++ b/alignn/models/utils.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn import dgl +from typing import Tuple class RBFExpansion(nn.Module): @@ -171,3 +172,107 @@ def forward(self, x): """Linear, Batchnorm, silu layer.""" # print('xtype',x.dtype) return self.layer(x) + + +def compute_net_torque( + positions: torch.Tensor, forces: torch.Tensor, n_nodes: torch.Tensor +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute the net torque on a system of particles.""" + total_mass = n_nodes.float().sum() # Total number of particles + com = torch.sum(positions, dim=0) / total_mass # Center of mass + + # Compute the relative positions of particles with respect to CoM + com_repeat = com.repeat( + positions.size(0), 1 + ) # Repeat CoM for each particle + com_relative_positions = ( + positions - com_repeat + ) # Relative position to the CoM + + # Compute individual torques (cross product of r_i and F_i) + torques = torch.cross(com_relative_positions, forces) # Shape: (N, 3) + + # Aggregate torques to get the net torque (sum all torques) + net_torque = torch.sum(torques, dim=0) # Sum of all individual torques + + return net_torque, com_relative_positions + + +def remove_net_torque( + g: dgl.DGLGraph, + forces: torch.Tensor, + n_nodes: torch.Tensor, +) -> torch.Tensor: + """Adjust the predicted forces to eliminate net torque for each graph in the batch. + + Args: + g : dgl.DGLGraph + The graph representing a batch of particles (atoms). + forces : torch.Tensor of shape (N, 3) + Predicted forces on atoms. + n_nodes : torch.Tensor of shape (B,) + Number of nodes in each graph, where B is the number of graphs in the batch. + + Returns: + adjusted_forces : torch.Tensor of shape (N, 3) + Adjusted forces with zero net torque and net force for each graph. + """ + # Step 1: Get positions from the graph (assuming 'cart_coords' holds the positions) + positions = g.ndata["cart_coords"] + + # Compute the net torque and relative positions + tau_total, r = compute_net_torque(positions, forces, n_nodes) + + # Step 2: Compute scalar s per graph: sum_i ||r_i||^2 + r_squared = torch.sum(r**2, dim=1) # Shape: (N,) + + # Sum over nodes to aggregate r_squared for each graph + s = torch.zeros(n_nodes.size(0), device=positions.device) + start_idx = 0 + for i, num_nodes in enumerate(n_nodes): + end_idx = start_idx + num_nodes + s[i] = torch.sum(r_squared[start_idx:end_idx]) + start_idx = end_idx + + # Step 3: Compute matrix S per graph: sum_i outer(r_i, r_i) + r_unsqueezed = r.unsqueeze(2) # Shape: (N, 3, 1) + r_T_unsqueezed = r.unsqueeze(1) # Shape: (N, 1, 3) + outer_products = r_unsqueezed @ r_T_unsqueezed # Shape: (N, 3, 3) + + # Aggregate outer products for each graph + S = torch.zeros(n_nodes.size(0), 3, 3, device=positions.device) + start_idx = 0 + for i, num_nodes in enumerate(n_nodes): + end_idx = start_idx + num_nodes + S[i] = torch.sum(outer_products[start_idx:end_idx], dim=0) + start_idx = end_idx + + # Step 4: Compute M = S - sI + I = ( + torch.eye(3, device=positions.device) + .unsqueeze(0) + .expand(n_nodes.size(0), -1, -1) + ) # Identity matrix + M = S - s.view(-1, 1, 1) * I # Shape: (B, 3, 3) + + # Step 5: Right-hand side vector b per graph + b = -tau_total # Shape: (B, 3) + + # Step 6: Solve M * mu = b for mu per graph + try: + mu = torch.linalg.solve( + M, b + ) # Shape: (B, 3) -- No need for unsqueeze(2) + except RuntimeError: + # Handle singular matrix M by using the pseudo-inverse + M_pinv = torch.linalg.pinv(M) # Shape: (B, 3, 3) + mu = torch.bmm(M_pinv, b.unsqueeze(2)).squeeze(2) # Shape: (B, 3) + + # Step 7: Compute adjustments to forces + mu_batch = torch.repeat_interleave(mu, n_nodes, dim=0) # Shape: (N, 3) + forces_delta = torch.cross(r, mu_batch) # Shape: (N, 3) + + # Step 8: Adjust forces + adjusted_forces = forces + forces_delta # Shape: (N, 3) + + return adjusted_forces