forked from google-deepmind/deepmind-research
-
Notifications
You must be signed in to change notification settings - Fork 0
/
graph_model.py
185 lines (154 loc) · 6.64 KB
/
graph_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright 2019 Deepmind Technologies Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A graph neural network based model to predict particle mobilities.
The architecture and performance of this model is described in our publication:
"Unveiling the predictive power of static structure in glassy systems".
"""
import functools
from typing import Any, Dict, Text, Tuple, Optional
from graph_nets import graphs
from graph_nets import modules as gn_modules
from graph_nets import utils_tf
import sonnet as snt
import tensorflow.compat.v1 as tf
def make_graph_from_static_structure(
positions: tf.Tensor,
types: tf.Tensor,
box: tf.Tensor,
edge_threshold: float) -> graphs.GraphsTuple:
"""Returns graph representing the static structure of the glass.
Each particle is represented by a node in the graph. The particle type is
stored as a node feature.
Two particles at a distance less than the threshold are connected by an edge.
The relative distance vector is stored as an edge feature.
Args:
positions: particle positions with shape [n_particles, 3].
types: particle types with shape [n_particles].
box: dimensions of the cubic box that contains the particles with shape [3].
edge_threshold: particles at distance less than threshold are connected by
an edge.
"""
# Calculate pairwise relative distances between particles: shape [n, n, 3].
cross_positions = positions[tf.newaxis, :, :] - positions[:, tf.newaxis, :]
# Enforces periodic boundary conditions.
box_ = box[tf.newaxis, tf.newaxis, :]
cross_positions += tf.cast(cross_positions < -box_ / 2., tf.float32) * box_
cross_positions -= tf.cast(cross_positions > box_ / 2., tf.float32) * box_
# Calculates adjacency matrix in a sparse format (indices), based on the given
# distances and threshold.
distances = tf.norm(cross_positions, axis=-1)
indices = tf.where(distances < edge_threshold)
# Defines graph.
nodes = types[:, tf.newaxis]
senders = indices[:, 0]
receivers = indices[:, 1]
edges = tf.gather_nd(cross_positions, indices)
return graphs.GraphsTuple(
nodes=tf.cast(nodes, tf.float32),
n_node=tf.reshape(tf.shape(nodes)[0], [1]),
edges=tf.cast(edges, tf.float32),
n_edge=tf.reshape(tf.shape(edges)[0], [1]),
globals=tf.zeros((1, 1), dtype=tf.float32),
receivers=tf.cast(receivers, tf.int32),
senders=tf.cast(senders, tf.int32)
)
def apply_random_rotation(graph: graphs.GraphsTuple) -> graphs.GraphsTuple:
"""Returns randomly rotated graph representation.
The rotation is an element of O(3) with rotation angles multiple of pi/2.
This function assumes that the relative particle distances are stored in
the edge features.
Args:
graph: The graphs tuple as defined in `graph_nets.graphs`.
"""
# Transposes edge features, so that the axes are in the first dimension.
# Outputs a tensor of shape [3, n_particles].
xyz = tf.transpose(graph.edges)
# Random pi/2 rotation(s)
permutation = tf.random.shuffle(tf.constant([0, 1, 2], dtype=tf.int32))
xyz = tf.gather(xyz, permutation)
# Random reflections.
symmetry = tf.random_uniform([3], minval=0, maxval=2, dtype=tf.int32)
symmetry = 1 - 2 * tf.cast(tf.reshape(symmetry, [3, 1]), tf.float32)
xyz = xyz * symmetry
edges = tf.transpose(xyz)
return graph.replace(edges=edges)
class GraphBasedModel(snt.AbstractModule):
"""Graph based model which predicts particle mobilities from their positions.
This network encodes the nodes and edges of the input graph independently, and
then performs message-passing on this graph, updating its edges based on their
associated nodes, then updating the nodes based on the input nodes' features
and their associated updated edge features.
This update is repeated several times.
Afterwards the resulting node embeddings are decoded to predict the particle
mobility.
"""
def __init__(self,
n_recurrences: int,
mlp_sizes: Tuple[int],
mlp_kwargs: Optional[Dict[Text, Any]] = None,
name='Graph'):
"""Creates a new GraphBasedModel object.
Args:
n_recurrences: the number of message passing steps in the graph network.
mlp_sizes: the number of neurons in each layer of the MLP.
mlp_kwargs: additional keyword aguments passed to the MLP.
name: the name of the Sonnet module.
"""
super(GraphBasedModel, self).__init__(name=name)
self._n_recurrences = n_recurrences
if mlp_kwargs is None:
mlp_kwargs = {}
model_fn = functools.partial(
snt.nets.MLP,
output_sizes=mlp_sizes,
activate_final=True,
**mlp_kwargs)
final_model_fn = functools.partial(
snt.nets.MLP,
output_sizes=mlp_sizes + (1,),
activate_final=False,
**mlp_kwargs)
with self._enter_variable_scope():
self._encoder = gn_modules.GraphIndependent(
node_model_fn=model_fn,
edge_model_fn=model_fn)
if self._n_recurrences > 0:
self._propagation_network = gn_modules.GraphNetwork(
node_model_fn=model_fn,
edge_model_fn=model_fn,
# We do not use globals, hence we just pass the identity function.
global_model_fn=lambda: lambda x: x,
reducer=tf.unsorted_segment_sum,
edge_block_opt=dict(use_globals=False),
node_block_opt=dict(use_globals=False),
global_block_opt=dict(use_globals=False))
self._decoder = gn_modules.GraphIndependent(
node_model_fn=final_model_fn,
edge_model_fn=model_fn)
def _build(self, graphs_tuple: graphs.GraphsTuple) -> tf.Tensor:
"""Connects the model into the tensorflow graph.
Args:
graphs_tuple: input graph tensor as defined in `graphs_tuple.graphs`.
Returns:
tensor with shape [n_particles] containing the predicted particle
mobilities.
"""
encoded = self._encoder(graphs_tuple)
outputs = encoded
for _ in range(self._n_recurrences):
# Adds skip connections.
inputs = utils_tf.concat([outputs, encoded], axis=-1)
outputs = self._propagation_network(inputs)
decoded = self._decoder(outputs)
return tf.squeeze(decoded.nodes, axis=-1)