Commit

Merge pull request #103 from the16thpythonist/master
ragged tensor gnnexplainer implementation for xai benchmarks
PatReis authored Feb 1, 2023
2 parents f64f53a + 2c0c0c8 commit 3c6873d
Showing 11 changed files with 786 additions and 24 deletions.
190 changes: 190 additions & 0 deletions kgcnn/literature/GNNExplain.py
@@ -1,11 +1,201 @@
"""
"Ying et al. - GNNExplainer: Generating Explanations for Graph Neural Networks"
**Changelog**
??.??.2022 - Initial implementation
30.01.2023 - Added the class "GnnExplainer", which supports RaggedTensors and can thus generate multiple
explanations at once, greatly improving time efficiency when explaining large batches of predictions. However,
the new class does not implement visualization of the explanations; this has to be realized at a
higher abstraction level.
"""
import time
import typing as t

import numpy as np
import tensorflow as tf
ks = tf.keras

from kgcnn.xai.base import ImportanceExplanationMethod

# Keep track of model version from commit date in literature.
# To be updated if model is changed in a significant way.
__model_version__ = "2022.05.31"


# == REDUCED, RAGGED TENSOR IMPLEMENTATION ==

class GnnExplainer(ImportanceExplanationMethod):
"""
Implementation of "ImportanceExplanationMethod", which means that calling an instance of this class
given a model, a ragged input tensor and output predictions, it should return the corresponding
node and edge importance tensors, which provide an explanation by assigning each node and edge of the
input graphs with a 0-1 importance value.
By the nature of the base idea behind GNNExplainer, the number of explanations produced has to be equal
to the number of prediction targets that are generated by the model. Each target will receive its own
explanation.
"""
def __init__(self,
channels: int,
epochs: int = 100,
learning_rate: float = 0.01,
node_sparsity_factor: float = 0.1,
edge_sparsity_factor: float = 0.1,
log_step: int = 10,
verbose: bool = True):
super(GnnExplainer, self).__init__(channels=channels)
self.epochs = epochs
self.learning_rate = learning_rate
self.log_step = log_step
self.verbose = verbose
self.node_sparsity_factor = node_sparsity_factor
self.edge_sparsity_factor = edge_sparsity_factor

def __call__(self,
model: ks.models.Model,
x: t.Tuple[tf.RaggedTensor, tf.RaggedTensor, tf.RaggedTensor],
y: np.ndarray):
"""
Given a model, the input tensor and the output array, this method will return a tuple of two
ragged tensors, which represent the node importances and the edge importances.
Beware that this method executes an entire training process and may take some time.
Reference of tensor shapes. [Brackets] indicate ragged dimension
- V: Number of nodes in graph
- E: Number of edges in graph
- K: Number of explanation channels given in the constructor. This has to be equal to the number of
prediction targets generated by the model.
- N: Number of node attributes
- M: Number of edge attributes
- B: Batch size
Args:
x: A tuple (node_input, edge_input, edge_indices) of 3 RaggedTensors
- node_input: Shape ([B], [V], N)
- edge_input: Shape ([B], [E], M)
- edge_indices: Shape ([B], [E], 2)
y: A numpy array of shape (B, K)
model: Any compatible keras model, i.e. any model which accepts the input tensors described above
and returns an output of the shape described for y.
Returns:
A tuple (node_importances, edge_importances) of RaggedTensors.
- node_importances: Shape ([B], [V], K)
- edge_importances: Shape ([B], [E], K)
"""
# Generally the idea of the implementation is that we use the node_input and edge_input tensors as
# templates to generate the mask variable tensors, which match the graph dimensions but whose final
# dimension is used to represent the number of importance channels (== number of prediction
# targets) instead of the node / edge features.

node_input, edge_input, edge_indices = x

# Here we reduce away the last dimension of node and edge input to get just the ragged graph sizes.
# But we run into a problem with multiple channels: we can't actually use the last dimension to
# represent the number of different explanation channels. Instead, we use a workaround where for
# each channel we extend the batch dimension, i.e. we treat the different channels as just
# additional graphs to be handled like the others. The reason we have to do it like that is
# that later on we need to multiply the masks with the inputs!
node_mask_single = tf.reduce_mean(tf.ones_like(node_input), axis=-1, keepdims=True)
node_mask_ragged = tf.concat([node_mask_single for _ in range(self.channels)], axis=0)
node_mask_variables = tf.Variable(node_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

edge_mask_single = tf.reduce_mean(tf.ones_like(edge_input), axis=-1, keepdims=True)
edge_mask_ragged = tf.concat([edge_mask_single for _ in range(self.channels)], axis=0)
edge_mask_variables = tf.Variable(edge_mask_ragged.flat_values, trainable=True, dtype=tf.float64)

optimizer = ks.optimizers.Nadam(learning_rate=self.learning_rate)

# This is a logical extension of what was previously described. Since we treat the different
# explanation channels as just a batch extension, we have to modify the input and output values
# accordingly so that they match that extended batch size. Naturally, we simply have to
# duplicate the values.
x_extended = (
tf.concat([node_input for _ in range(self.channels)], axis=0),
tf.concat([edge_input for _ in range(self.channels)], axis=0),
tf.concat([edge_indices for _ in range(self.channels)], axis=0),
)
y_extended = []
for c in range(self.channels):
y_mod = np.zeros_like(y)
y_mod[:, c] = y[:, c]
y_extended.append(y_mod)

y_extended = np.concatenate(y_extended)

start_time = time.time()
for epoch in range(self.epochs):

with tf.GradientTape() as tape:
node_mask = tf.RaggedTensor.from_nested_row_splits(
node_mask_variables,
nested_row_splits=node_mask_ragged.nested_row_splits
)

edge_mask = tf.RaggedTensor.from_nested_row_splits(
edge_mask_variables,
nested_row_splits=edge_mask_ragged.nested_row_splits
)

out = model([
x_extended[0] * node_mask,
x_extended[1] * edge_mask,
x_extended[2]
])

# The loss can basically be summarized as: we try to find the smallest subset of nodes and
# edges in the input which still causes the network to get as close as possible to its
# original prediction!
loss = tf.cast(tf.reduce_mean(tf.square(y_extended - out)), dtype=tf.float64)
# Important detail: The reduce_sum here reduces over all the nodes / edges and is necessary!
loss += self.node_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(node_mask), axis=1))
loss += self.edge_sparsity_factor * tf.reduce_mean(tf.reduce_sum(tf.abs(edge_mask), axis=1))

trainable_vars = [node_mask_variables, edge_mask_variables]
gradients = tape.gradient(loss, trainable_vars)
optimizer.apply_gradients(zip(gradients, trainable_vars))

if self.verbose and epoch % self.log_step == 0:
print(f' * epoch ({epoch}/{self.epochs}) '
f' - loss: {loss}'
f' - elapsed time: {time.time()-start_time:.2f} seconds')

# For the training we had to treat the different explanation channels as a batch extension. As per
# the interface, however, we need to return the importances such that the different explanation
# channels are organized into the third dimension of the tensors.

# Sadly this does not work in a more direct fashion. We get the number of flat node and edge
# elements that belong to one explanation channel, iterate in chunks of that size and turn each of
# those chunks into its own ragged explanation tensor. At the end we concatenate all of them along
# the third dimension to produce the desired result.
num_elements_node = node_mask_single.flat_values.shape[0]
num_elements_edge = edge_mask_single.flat_values.shape[0]
node_importances_list = []
edge_importances_list = []
for c in range(self.channels):
node_importances_part = tf.RaggedTensor.from_nested_row_splits(
node_mask_variables[c*num_elements_node:(c+1)*num_elements_node, :],
node_mask_single.nested_row_splits
)
node_importances_list.append(node_importances_part)

edge_importances_part = tf.RaggedTensor.from_nested_row_splits(
edge_mask_variables[c*num_elements_edge:(c+1)*num_elements_edge, :],
edge_mask_single.nested_row_splits
)
edge_importances_list.append(edge_importances_part)

return (
tf.concat(node_importances_list, axis=-1),
tf.concat(edge_importances_list, axis=-1)
)


# == ORIGINAL IMPLEMENTATION ==

class GNNInterface:
"""An interface class which should be implemented by a Graph Neural Network (GNN) model to make it explainable.
This class is just an interface, which is used by the `GNNExplainer` and should be implemented in a subclass.
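To make the "channels as a batch extension" workaround described in the comments of GnnExplainer.__call__ concrete, here is a small standalone sketch. It is not taken from the diff; graph sizes and values are arbitrary, and only the tensor operations mirror the implementation above.

import tensorflow as tf

# Two graphs with 3 and 2 nodes, 2 node attributes each: shape ([2], [V], 2).
node_input = tf.ragged.constant(
    [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
     [[7.0, 8.0], [9.0, 10.0]]],
    ragged_rank=1
)

channels = 2
# One mask value per node and a single channel: shape ([2], [V], 1).
node_mask_single = tf.reduce_mean(tf.ones_like(node_input), axis=-1, keepdims=True)
# Treat each explanation channel as additional graphs in the batch: shape ([4], [V], 1).
node_mask_ragged = tf.concat([node_mask_single for _ in range(channels)], axis=0)
# Duplicate the inputs in the same way, so masks and inputs can simply be multiplied.
node_input_extended = tf.concat([node_input for _ in range(channels)], axis=0)
masked = node_input_extended * node_mask_ragged  # shape ([4], [V], 2)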
56 changes: 56 additions & 0 deletions kgcnn/xai/base.py
@@ -1,6 +1,10 @@
import typing as t

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks

from kgcnn.data.utils import ragged_tensor_from_nested_numpy


class AbstractExplanationMixin:
@@ -20,3 +24,55 @@ def explain_importances(self,
**kwargs
) -> t.Tuple[tf.RaggedTensor, tf.RaggedTensor]:
raise NotImplementedError


class AbstractExplanationMethod:

def __call__(self, model, x, y):
raise NotImplementedError


class ImportanceExplanationMethod(AbstractExplanationMethod):

def __init__(self,
channels: int):
self.channels = channels

def __call__(self,
model: ks.models.Model,
x: tf.Tensor,
y: tf.Tensor
) -> t.Tuple[tf.Tensor, tf.Tensor]:
raise NotImplementedError


class MockImportanceExplanationMethod(ImportanceExplanationMethod):
"""
This is a mock implementation of "ImportanceExplanationMethod". It is purely for testing purposes.
Using this method will result in randomly generated importance values for nodes and edges.
"""
def __init__(self, channels):
super(MockImportanceExplanationMethod, self).__init__(channels=channels)

def __call__(self,
model: ks.models.Model,
x: t.Tuple[tf.Tensor],
y: t.Tuple[tf.Tensor],
) -> t.Tuple[tf.Tensor, tf.Tensor]:
node_input, edge_input, _ = x

# I'm sure you could probably do this in tensorflow directly, but I am just going to go the numpy
# route here because that's just easier.
node_input = node_input.numpy()
edge_input = edge_input.numpy()

node_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels))
for v in node_input]
edge_importances = [np.random.uniform(0, 1, size=(v.shape[0], self.channels))
for v in edge_input]

return (
ragged_tensor_from_nested_numpy(node_importances),
ragged_tensor_from_nested_numpy(edge_importances)
)

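As a quick illustration of how the ImportanceExplanationMethod interface is meant to be called, here is a small sketch using the mock implementation. Graph sizes and attribute dimensions are arbitrary, and since the mock ignores the model and the targets, None is passed for both.

import numpy as np
from kgcnn.data.utils import ragged_tensor_from_nested_numpy
from kgcnn.xai.base import MockImportanceExplanationMethod

# Two small graphs with 3 and 2 nodes and 4 and 2 edges respectively.
node_input = ragged_tensor_from_nested_numpy([np.ones((3, 4)), np.ones((2, 4))])
edge_input = ragged_tensor_from_nested_numpy([np.ones((4, 2)), np.ones((2, 2))])
edge_indices = ragged_tensor_from_nested_numpy([np.zeros((4, 2), dtype=int), np.zeros((2, 2), dtype=int)])

method = MockImportanceExplanationMethod(channels=2)
node_importances, edge_importances = method(None, (node_input, edge_input, edge_indices), None)
# node_importances: RaggedTensor ([2], [V], 2), edge_importances: RaggedTensor ([2], [E], 2),
# filled with random values in [0, 1].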
130 changes: 130 additions & 0 deletions kgcnn/xai/testing.py
@@ -0,0 +1,130 @@
import random
import typing as t

import numpy as np
import tensorflow as tf
import tensorflow.keras as ks

from kgcnn.layers.conv.gat_conv import AttentionHeadGATV2
from kgcnn.layers.modules import DenseEmbedding
from kgcnn.layers.pooling import PoolingGlobalEdges
from kgcnn.data.utils import ragged_tensor_from_nested_numpy


# This is a very simple mock implementation: to test the explanation methods we need some sort of
# model as a basis, and this model will act as such.
class Model(ks.models.Model):

def __init__(self,
num_targets: int = 1):
super(Model, self).__init__()
self.conv_layers = [
AttentionHeadGATV2(units=64, use_edge_features=True, use_bias=True),
]
self.lay_pooling = PoolingGlobalEdges(pooling_method='sum')
self.lay_dense = DenseEmbedding(units=num_targets, activation='linear')

def call(self, inputs, training=False):
node_input, edge_input, edge_index_input = inputs
x = node_input
for lay in self.conv_layers:
x = lay([x, edge_input, edge_index_input])

pooled = self.lay_pooling(x)
out = self.lay_dense(pooled)
return out


class MockContext:

def __init__(self,
num_elements: int = 10,
num_targets: int = 1,
epochs: int = 10,
batch_size: int = 2):
self.num_elements = num_elements
self.num_targets = num_targets
self.epochs = epochs
self.batch_size = batch_size

self.model = Model(num_targets=num_targets)
self.x = None
self.y = None

def generate_graph(self,
num_nodes: int,
num_node_attributes: int = 3,
num_edge_attributes: int = 1):
remaining = list(range(num_nodes))
random.shuffle(remaining)
inserted = [remaining.pop(0)]
node_attributes = [[random.random() for _ in range(num_node_attributes)] for _ in range(num_nodes)]
edge_indices = []
edge_attributes = []
while len(remaining) != 0:
i = remaining.pop(0)
j = random.choice(inserted)
inserted.append(i)

edge_indices += [[i, j], [j, i]]
edge_attribute = [1 for _ in range(num_edge_attributes)]
edge_attributes += [edge_attribute, edge_attribute]

return (
np.array(node_attributes, dtype=float),
np.array(edge_attributes, dtype=float),
np.array(edge_indices, dtype=int)
)

def generate_data(self):
node_attributes_list = []
edge_attributes_list = []
edge_indices_list = []
targets_list = []
for i in range(self.num_elements):
num_nodes = random.randint(5, 20)
node_attributes, edge_attributes, edge_indices = self.generate_graph(num_nodes)
node_attributes_list.append(node_attributes)
edge_attributes_list.append(edge_attributes)
edge_indices_list.append(edge_indices)

# The target value is determined deterministically from the node attributes here, so that our
# network actually has a chance to learn something.
target = np.sum(node_attributes)
targets = [target for _ in range(self.num_targets)]
targets_list.append(targets)

self.x = (
ragged_tensor_from_nested_numpy(node_attributes_list),
ragged_tensor_from_nested_numpy(edge_attributes_list),
ragged_tensor_from_nested_numpy(edge_indices_list)
)

self.y = (
np.array(targets_list, dtype=float)
)

def __enter__(self):
# This method will generate random input and output data and thus populate the internal attributes
# self.x and self.y
self.generate_data()

# Using these we will train our mock model for a few very brief epochs.
self.model.compile(
loss=ks.losses.mean_squared_error,
metrics=ks.metrics.mean_squared_error,
run_eagerly=False,
optimizer=ks.optimizers.Nadam(learning_rate=0.01),
)
hist = self.model.fit(
self.x, self.y,
batch_size=self.batch_size,
epochs=self.epochs,
verbose=0,
)
self.history = hist.history

return self

def __exit__(self, *args, **kwargs):
pass
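Putting the pieces together, a minimal sketch of how the mock context and the new ragged GnnExplainer appear intended to be used together. The concrete parameter values are arbitrary; the only constraint taken from the docstrings above is that channels must equal the number of prediction targets.

from kgcnn.xai.testing import MockContext
from kgcnn.literature.GNNExplain import GnnExplainer

# The context generates random graphs, trains the small GATv2 model on them and exposes
# the trained model together with the ragged inputs (mock.x) and the targets (mock.y).
with MockContext(num_elements=10, num_targets=2, epochs=10, batch_size=2) as mock:
    explainer = GnnExplainer(channels=2, epochs=100, learning_rate=0.01, verbose=False)
    node_importances, edge_importances = explainer(mock.model, mock.x, mock.y)
    # Both results are RaggedTensors:
    # node_importances ~ ([B], [V], 2), edge_importances ~ ([B], [E], 2)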
