Skip to content

Commit

Permalink
Remove torque option added.
Browse files Browse the repository at this point in the history
  • Loading branch information
knc6 committed Jan 11, 2025
1 parent e5622f7 commit a3baa7a
Showing 1 changed file with 8 additions and 6 deletions.
14 changes: 8 additions & 6 deletions alignn/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,21 +203,23 @@ def remove_net_torque(
forces: torch.Tensor,
n_nodes: torch.Tensor,
) -> torch.Tensor:
"""Adjust the predicted forces to eliminate net torque for each graph in the batch.
"""Adjust the predicted forces to eliminate net torque.
Args:
g : dgl.DGLGraph
The graph representing a batch of particles (atoms).
forces : torch.Tensor of shape (N, 3)
Predicted forces on atoms.
n_nodes : torch.Tensor of shape (B,)
Number of nodes in each graph, where B is the number of graphs in the batch.
Number of nodes in each graph,
where B is the number of graphs in the batch.
Returns:
adjusted_forces : torch.Tensor of shape (N, 3)
Adjusted forces with zero net torque and net force for each graph.
Adjusted forces with zero net torque
and net force for each graph.
"""
# Step 1: Get positions from the graph (assuming 'cart_coords' holds the positions)
# Step 1: Get positions from the graph
positions = g.ndata["cart_coords"]

# Compute the net torque and relative positions
Expand Down Expand Up @@ -248,12 +250,12 @@ def remove_net_torque(
start_idx = end_idx

# Step 4: Compute M = S - sI
I = (
Imat = (
torch.eye(3, device=positions.device)
.unsqueeze(0)
.expand(n_nodes.size(0), -1, -1)
) # Identity matrix
M = S - s.view(-1, 1, 1) * I # Shape: (B, 3, 3)
M = S - s.view(-1, 1, 1) * Imat # Shape: (B, 3, 3)

# Step 5: Right-hand side vector b per graph
b = -tau_total # Shape: (B, 3)
Expand Down

0 comments on commit a3baa7a

Please sign in to comment.