Implement and test a registry function for `tf.keras.layers.LayerNormalization`.

PiperOrigin-RevId: 557168617
tensorflower-gardener committed Aug 30, 2023
1 parent 372c934 commit 58e1f00
Showing 7 changed files with 314 additions and 9 deletions.
@@ -48,3 +48,25 @@ py_test(
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)

py_library(
name = "layer_normalization",
srcs = ["layer_normalization.py"],
srcs_version = "PY3",
deps = ["//tensorflow_privacy/privacy/fast_gradient_clipping:type_aliases"],
)

py_test(
name = "layer_normalization_test",
srcs = ["layer_normalization_test.py"],
python_version = "PY3",
shard_count = 8,
srcs_version = "PY3",
deps = [
":dense",
":layer_normalization",
"//tensorflow_privacy/privacy/fast_gradient_clipping:clip_grads",
"//tensorflow_privacy/privacy/fast_gradient_clipping:common_test_utils",
"//tensorflow_privacy/privacy/fast_gradient_clipping:layer_registry",
],
)
@@ -13,7 +13,7 @@
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Dense`."""

from typing import Any, Dict, Optional, Text, Tuple
from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_manip_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases
@@ -22,9 +22,9 @@
def dense_layer_computation(
layer_instance: tf.keras.layers.Dense,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Dense`.
@@ -13,17 +13,17 @@
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.Embedding`."""

from typing import Any, Dict, Optional, Text, Tuple
from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


def embedding_layer_computation(
layer_instance: tf.keras.layers.Embedding,
input_args: Tuple[Any, ...],
input_kwargs: Dict[Text, Any],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Optional[tf.Tensor] = None,
num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.Embedding`.
@@ -0,0 +1,89 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""

from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


# ==============================================================================
# Supported Keras layers
# ==============================================================================
def _sqr_norm_fn(grads):
stacked_grads = tf.stack(grads, axis=-1)
reduction_axes = tf.range(1, tf.rank(stacked_grads))
return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)


def layer_normalization_computation(
layer_instance: tf.keras.layers.LayerNormalization,
input_args: Tuple[Any, ...],
input_kwargs: Mapping[str, Any],
tape: tf.GradientTape,
num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
"""Registry function for `tf.keras.layers.LayerNormalization`.
This function computes actual per-example gradients and computes their
norms directly, instead of employing a chain-rule trick. This is done using
some slick reshaping calls.
Args:
layer_instance: A `tf.keras.layers.LayerNormalization` instance.
input_args: See `dense_layer_computation()` in `dense.py`.
input_kwargs: See `dense_layer_computation()` in `dense.py`.
tape: See `dense_layer_computation()` in `dense.py`.
num_microbatches: See `dense_layer_computation()` in `dense.py`.
Returns:
See `dense_layer_computation()` in `dense.py`.
"""
del input_kwargs  # Unused in layer normalization calls.
if num_microbatches is not None:
raise NotImplementedError("Microbatching is not currently supported.")

# To make sure the watched variables (beta, gamma) generate per-example
# gradients, we need to convert the trainable variables from shape [S] to
# [batch_size, S] via duplication, and then reshape them so that they
# broadcast against `tf.shape(inputs)`.
inputs = input_args[0]
base_vars = []
batch_size = tf.shape(inputs)[0]

def process_variable(var):
"""Expand univariate `var` and the expanded tensor to `base_vars`."""
expanded_var = tf.tile(
tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape)
)
tape.watch(expanded_var)
base_vars.append(expanded_var)
broadcast_shape = [1] * len(inputs.shape)
broadcast_shape[0] = batch_size
for d in layer_instance.axis:
broadcast_shape[d] = tf.shape(inputs)[d]
final_var = tf.reshape(expanded_var, broadcast_shape)
return final_var

orig_gamma = layer_instance.gamma
orig_beta = layer_instance.beta
layer_instance.gamma = process_variable(orig_gamma)
layer_instance.beta = process_variable(orig_beta)

# Do the computation, ensure that the output conforms to the unexpanded
# computation, and restore the state of the original instance.
outputs = layer_instance.call(inputs)
layer_instance.gamma = orig_gamma
layer_instance.beta = orig_beta

return base_vars, outputs, _sqr_norm_fn
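
The registry function above can be exercised on its own, outside of the end-to-end clipping pipeline. Below is a minimal sketch (not part of this commit) that registers it in a `LayerRegistry`, mirroring `get_layer_norm_registries()` in the test file that follows, and then calls it directly on a small pre-built layer; the toy input values, the summed loss, and the direct `tape.gradient` call are illustrative assumptions only.

# Minimal usage sketch (not part of this commit); assumes a TF 2.x build where
# `tf.keras.layers.LayerNormalization` exposes `gamma` and `beta` after build.
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization

# Registration mirrors `get_layer_norm_registries()` in the test file below.
registry = layer_registry.LayerRegistry()
registry.insert(
    tf.keras.layers.LayerNormalization,
    layer_normalization.layer_normalization_computation,
)

# Direct invocation on a toy batch of two examples.
layer = tf.keras.layers.LayerNormalization(axis=-1)
x_batch = tf.reshape(tf.range(6, dtype=tf.float32), [2, 3])
layer.build(x_batch.shape)  # creates `gamma` and `beta`

with tf.GradientTape() as tape:
  # The registry function calls `tape.watch()` on the expanded variables,
  # so it must run while the tape is recording.
  base_vars, outputs, sqr_norm_fn = (
      layer_normalization.layer_normalization_computation(
          layer, (x_batch,), {}, tape
      )
  )
  summed_loss = tf.reduce_sum(outputs)

# Gradients w.r.t. the per-example copies of gamma and beta, reduced to
# per-example squared norms by the returned `sqr_norm_fn`.
grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # shape: [batch_size]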
@@ -0,0 +1,159 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization


# ==============================================================================
# Helper functions.
# ==============================================================================
def get_layer_norm_layer_generators():
return {
'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
}


def get_layer_norm_model_generators():
return {
# TODO(b/274483956): Test more complex models once we can support
# `nD` inputs for `tf.keras.layers.Dense`.
'func1': common_test_utils.make_one_layer_functional_model,
}


def get_layer_norm_parameter_tuples():
"""Consists of (input_dims, parameter_axes)."""
return [
# Rank-2
([3], -1),
([3], [1]),
# Rank-3
([3, 4], -1),
([3, 4], [1]),
([3, 4], [2]),
([3, 4], [1, 2]),
# Rank-4
([3, 4, 5], -1),
([3, 4, 5], [1]),
([3, 4, 5], [2]),
([3, 4, 5], [3]),
([3, 4, 5], [1, 2]),
([3, 4, 5], [1, 3]),
([3, 4, 5], [2, 3]),
([3, 4, 5], [1, 2, 3]),
]


def get_layer_norm_registries():
ln_registry = layer_registry.LayerRegistry()
ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
ln_registry.insert(
tf.keras.layers.LayerNormalization,
layer_normalization.layer_normalization_computation,
)
return {
'layer_norm_only': ln_registry,
}


# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):

def setUp(self):
super().setUp()
self.strategy = tf.distribute.get_strategy()

@parameterized.product(
model_name=list(get_layer_norm_model_generators().keys()),
layer_name=list(get_layer_norm_layer_generators().keys()),
parameter_tuple=get_layer_norm_parameter_tuples(),
layer_registry_name=list(get_layer_norm_registries().keys()),
is_eager=[True, False],
)
def test_gradient_norms_on_various_models(
self,
model_name,
layer_name,
parameter_tuple,
layer_registry_name,
is_eager,
):
# Parse inputs to generate test data.
input_dims, parameter_axes = parameter_tuple

def curried_generator(a, b):
del a, b # Unused by the generator.
layer_norm_generator = get_layer_norm_layer_generators()[layer_name]
return layer_norm_generator(parameter_axes)

# Load shared assets to all devices.
with self.strategy.scope():
dummy_output_dim = 1
model = common_test_utils.get_model_from_generator(
model_generator=get_layer_norm_model_generators()[model_name],
layer_generator=curried_generator,
input_dims=input_dims,
output_dims=[dummy_output_dim],
is_eager=is_eager,
)

# Define the main testing ops. These may be later compiled to a Graph op.
def test_op(x_batch):
return common_test_utils.get_computed_and_true_norms_from_model(
model=model,
per_example_loss_fn=None,
num_microbatches=None,
x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
weight_batch=None,
registry=get_layer_norm_registries()[layer_registry_name],
)

# TPUs can only run `tf.function`-decorated functions.
using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
if using_tpu:
test_op = tf.function(test_op, jit_compile=True, autograph=False)

# TPUs use lower precision than CPUs, so we relax our criterion (see
# `dense_test.py` for additional discussions).
rtol = 1e-2 if using_tpu else 1e-3
atol = 1e-1 if using_tpu else 1e-2

# Each batched input is a reshape of a `tf.range()` call.
batch_size = 2
example_size = np.prod(input_dims)
example_values = tf.range(batch_size * example_size, dtype=tf.float32)
x_batch = tf.reshape(example_values, [batch_size] + input_dims)
batch_size = x_batch.shape[0]
# Set up the device ops and run the test.
computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
# TPUs return replica contexts, which must be unwrapped.
if using_tpu:
common_test_utils.assert_replica_values_are_close(self, computed_norms)
common_test_utils.assert_replica_values_are_close(self, true_norms)
computed_norms = computed_norms.values[0]
true_norms = true_norms.values[0]
self.assertEqual(tf.shape(computed_norms)[0], batch_size)
self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


if __name__ == '__main__':
tf.test.main()
@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test


class GradNormTpuTest(layer_normalization_test.GradNormTest):

def setUp(self):
super().setUp()
self.strategy = ctu.create_tpu_strategy()
self.assertIn('TPU', self.strategy.extended.worker_devices[0])


if __name__ == '__main__':
tf.test.main()
@@ -13,12 +13,12 @@
# limitations under the License.
"""A collection of type aliases used throughout the clipping library."""

from typing import Any, Callable, Dict, Iterable, List, Optional, Text, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
import tensorflow as tf


# Tensorflow aliases.
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[Text, tf.Tensor]]
PackedTensors = Union[tf.Tensor, Iterable[tf.Tensor], Dict[str, tf.Tensor]]

InputTensors = PackedTensors

@@ -34,7 +34,13 @@
RegistryFunctionOutput = Tuple[Any, OutputTensors, SquareNormFunction]

RegistryFunction = Callable[
[Any, Tuple[Any, ...], Dict[Text, Any], tf.GradientTape],
[
Any,
Tuple[Any, ...],
Mapping[str, Any],
tf.GradientTape,
Union[tf.Tensor, None],
],
RegistryFunctionOutput,
]

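
For reference, a registry function matching the updated `RegistryFunction` alias now spells out the fifth `num_microbatches` argument explicitly. The stub below is a hypothetical skeleton (not part of this commit) whose name and body are placeholders; only the signature is taken from the alias above.

# Hypothetical skeleton (not in this commit) conforming to the updated
# `RegistryFunction` alias; only the signature is drawn from the diff above.
from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases

def my_layer_computation(
    layer_instance: Any,
    input_args: Tuple[Any, ...],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Placeholder registry function; returns (base_vars, outputs, sqr_norm_fn)."""
  raise NotImplementedError("Illustrative stub only.")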
