Skip to content

Commit

Permalink
bumped version, added some comments
Browse files Browse the repository at this point in the history
  • Loading branch information
gerkone committed Jan 20, 2023
1 parent 51e8402 commit d22df64
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 13 deletions.
6 changes: 3 additions & 3 deletions experiments/nbody/data/generate_dataset.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Generate chargedcd and gravity datasets.
Generate charged and gravity datasets.
charged: python3 -u generate_dataset.py --simulation=charged --num-train 10000 --seed 43
gravity: python3 -u generate_dataset.py --simulation=gravity --num-train 10000 --seed 43 --n-balls=100
charged: python3 generate_dataset.py --simulation=charged --num-train=10000 --seed=43
gravity: python3 generate_dataset.py --simulation=gravity --num-train=10000 --seed=43 --n-balls=100
"""
import argparse
import time
Expand Down
11 changes: 7 additions & 4 deletions experiments/nbody/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,23 @@
from typing import Callable, List, Optional, Tuple

import e3nn_jax as e3nn
import jax.numpy as jnp
import jax.tree_util as tree
import numpy as np
import torch
import e3nn_jax as e3nn
from jraph import GraphsTuple, segment_mean
from torch.utils.data import DataLoader
from torch_geometric.nn import knn_graph

from segnn_jax import SteerableGraphsTuple

from .datasets import ChargedDataset, GravityDataset


def O3Transform(
node_features_irreps: e3nn.Irreps, edge_features_irreps: e3nn.Irreps, lmax_attributes: int
node_features_irreps: e3nn.Irreps,
edge_features_irreps: e3nn.Irreps,
lmax_attributes: int,
) -> Callable:
"""
Build a transformation function that includes (nbody) O3 attributes to a graph.
Expand All @@ -25,7 +28,7 @@ def _o3_transform(
st_graph: SteerableGraphsTuple,
loc: jnp.ndarray,
vel: jnp.ndarray,
charges: jnp.ndarray
charges: jnp.ndarray,
) -> SteerableGraphsTuple:

graph = st_graph.graph
Expand All @@ -40,7 +43,7 @@ def _o3_transform(

vel_abs = jnp.sqrt(jnp.power(vel, 2).sum(1, keepdims=True))
mean_loc = loc.mean(1, keepdims=True)

nodes = e3nn.IrrepsArray(
node_features_irreps,
jnp.concatenate((loc - mean_loc, vel, vel_abs), axis=-1),
Expand Down
7 changes: 4 additions & 3 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import time
from functools import partial
from typing import Tuple, Union

import e3nn_jax as e3nn
import haiku as hk
import jax
Expand All @@ -20,7 +21,7 @@ def predict(
state: hk.State,
graph: SteerableGraphsTuple,
mean_shift: Union[jnp.array, float] = 0,
mad_shift: Union[jnp.array, float] = 1
mad_shift: Union[jnp.array, float] = 1,
) -> Tuple[jnp.ndarray, hk.State]:
pred, state = segnn.apply(params, state, graph)
return (jnp.multiply(pred, mad_shift) + mean_shift), state
Expand All @@ -34,7 +35,7 @@ def mae(
target: jnp.ndarray,
mean_shift: Union[jnp.array, float] = 0,
mad_shift: Union[jnp.array, float] = 1,
mask_last: bool = False
mask_last: bool = False,
) -> Tuple[float, hk.State]:
pred, state = predict(params, state, graph, mean_shift, mad_shift)
assert target.shape == pred.shape
Expand All @@ -53,7 +54,7 @@ def mse(
target: jnp.ndarray,
mean_shift: Union[jnp.array, float] = 0,
mad_shift: Union[jnp.array, float] = 1,
mask_last: bool = False
mask_last: bool = False,
) -> Tuple[float, hk.State]:
pred, state = predict(params, state, graph, mean_shift, mad_shift)
assert target.shape == pred.shape
Expand Down
2 changes: 1 addition & 1 deletion segnn_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
"SteerableGraphsTuple",
]

__version__ = "0.2"
__version__ = "0.3"
2 changes: 1 addition & 1 deletion segnn_jax/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def O3TensorProductGate(
x (IrrepsArray): Left tensor
y (IrrepsArray): Right tensor
output_irreps: Output representation
biases: If set ot true will add biases
biases: Add biases
scalar_activation: Activation function for scalars
gate_activation: Activation function for higher order
name: Name of the linear layer params
Expand Down
4 changes: 3 additions & 1 deletion segnn_jax/segnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _message(
msg = O3TensorProductGate(
msg, edge_attribute, output_irreps, name=f"message_{i}_{layer_num}"
)
# NOTE: original implementation only applied batch norm to messages
if norm == "batch":
msg = e3nn.BatchNorm(irreps=output_irreps)(msg)
return msg
Expand All @@ -136,6 +137,7 @@ def _update(
)
# residual connection
nodes += update
# message norm
if norm in ["batch", "instance"]:
nodes = e3nn.BatchNorm(irreps=output_irreps, instance=(norm == "instance"))(
nodes
Expand Down Expand Up @@ -172,7 +174,7 @@ def __init__(
blocks_per_layer: int = 2,
embed_msg_features: bool = False,
):
super(SEGNN, self).__init__() # noqa
super(SEGNN, self).__init__() # noqa # pylint: disable=R1725

if isinstance(hidden_irreps, e3nn.Irreps):
self._hidden_irreps_units = num_layers * [hidden_irreps]
Expand Down

0 comments on commit d22df64

Please sign in to comment.