Commit
Implement and test a registry function for `tf.keras.layers.LayerNormalization`.

PiperOrigin-RevId: 557168617
1 parent 372c934; commit 58e1f00
Showing 7 changed files with 314 additions and 9 deletions.
89 changes: 89 additions & 0 deletions
tensorflow_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization.py
@@ -0,0 +1,89 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Fast clipping function for `tf.keras.layers.LayerNormalization`."""

from typing import Any, Mapping, Tuple, Union
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import type_aliases


# ==============================================================================
# Supported Keras layers
# ==============================================================================
def _sqr_norm_fn(grads):
  stacked_grads = tf.stack(grads, axis=-1)
  reduction_axes = tf.range(1, tf.rank(stacked_grads))
  return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
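# Note (illustration, not part of the original diff): if `grads` holds the
# per-example gradients for the two expanded variables (gamma and beta), each
# of shape [batch_size, *param_dims], then `tf.stack(grads, axis=-1)` adds one
# trailing axis, and reducing over every non-batch axis yields a single
# squared L2 norm per example.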


def layer_normalization_computation(
    layer_instance: tf.keras.layers.LayerNormalization,
    input_args: Tuple[Any, ...],
    input_kwargs: Mapping[str, Any],
    tape: tf.GradientTape,
    num_microbatches: Union[tf.Tensor, None] = None,
) -> type_aliases.RegistryFunctionOutput:
  """Registry function for `tf.keras.layers.LayerNormalization`.

  This function computes actual per-example gradients and computes their
  norms directly, instead of employing a chain-rule trick. This is done using
  some slick reshaping calls.

  Args:
    layer_instance: A `tf.keras.layers.LayerNormalization` instance.
    input_args: See `dense_layer_computation()` in `dense.py`.
    input_kwargs: See `dense_layer_computation()` in `dense.py`.
    tape: See `dense_layer_computation()` in `dense.py`.
    num_microbatches: See `dense_layer_computation()` in `dense.py`.

  Returns:
    See `dense_layer_computation()` in `dense.py`.
  """
  del input_kwargs  # Unused in layer normalization calls.
  if num_microbatches is not None:
    raise NotImplementedError("Microbatching is not currently supported.")

  # To make sure the watched variables (beta, gamma) generate per-example
  # gradients, we need to convert each trainable variable from shape [S] to
  # [batch_size, S] via duplication, and then match it to `tf.shape(inputs)`
  # via broadcasting.
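  # As a concrete illustration: for `inputs` of shape [batch_size, 3, 4] and
  # `axis=[2]`, gamma and beta each have shape [4]; each is tiled to
  # [batch_size, 4] (so that per-example gradients exist) and then reshaped to
  # [batch_size, 1, 4], which broadcasts against `inputs` in the call below.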
  inputs = input_args[0]
  base_vars = []
  batch_size = tf.shape(inputs)[0]

  def process_variable(var):
    """Expands `var` and appends the expanded tensor to `base_vars`."""
    expanded_var = tf.tile(
        tf.expand_dims(var, axis=0), [batch_size] + [1] * len(var.shape)
    )
    tape.watch(expanded_var)
    base_vars.append(expanded_var)
    broadcast_shape = [1] * len(inputs.shape)
    broadcast_shape[0] = batch_size
    for d in layer_instance.axis:
      broadcast_shape[d] = tf.shape(inputs)[d]
    final_var = tf.reshape(expanded_var, broadcast_shape)
    return final_var

  orig_gamma = layer_instance.gamma
  orig_beta = layer_instance.beta
  layer_instance.gamma = process_variable(orig_gamma)
  layer_instance.beta = process_variable(orig_beta)

  # Do the computation, ensure that the output conforms to the unexpanded
  # computation, and restore the state of the original instance.
  outputs = layer_instance.call(inputs)
  layer_instance.gamma = orig_gamma
  layer_instance.beta = orig_beta

  return base_vars, outputs, _sqr_norm_fn
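To make the contract of the new registry function concrete, here is a minimal usage sketch. It is not part of the commit: the toy layer, input shapes, and summed square loss are illustrative assumptions. The idea is that the returned `base_vars` are watched by the caller's tape, the summed loss is differentiated with respect to them, and `sqr_norm_fn` turns those gradients into per-example squared gradient norms.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization

# Hypothetical toy setup: a built LayerNormalization layer and a small batch.
layer = tf.keras.layers.LayerNormalization(axis=[2])
x_batch = tf.reshape(tf.range(24, dtype=tf.float32), [2, 3, 4])
layer.build(x_batch.shape)

with tf.GradientTape() as tape:
  base_vars, outputs, sqr_norm_fn = (
      layer_normalization.layer_normalization_computation(
          layer, (x_batch,), {}, tape
      )
  )
  # Any scalar loss over the layer outputs works for this illustration.
  summed_loss = tf.reduce_sum(tf.square(outputs))

grads = tape.gradient(summed_loss, base_vars)
per_example_sqr_norms = sqr_norm_fn(grads)  # Shape: [batch_size] == [2].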
159 changes: 159 additions & 0 deletions
...low_privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_test.py
@@ -0,0 +1,159 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import parameterized
import numpy as np
import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils
from tensorflow_privacy.privacy.fast_gradient_clipping import layer_registry
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import dense
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization


# ==============================================================================
# Helper functions.
# ==============================================================================
def get_layer_norm_layer_generators():
  return {
      'defaults': lambda x: tf.keras.layers.LayerNormalization(axis=x),
  }


def get_layer_norm_model_generators():
  return {
      # TODO(b/274483956): Test more complex models once we can support
      # `nD` inputs for `tf.keras.layers.Dense`.
      'func1': common_test_utils.make_one_layer_functional_model,
  }


def get_layer_norm_parameter_tuples():
  """Consists of (input_dims, parameter_axes) tuples."""
  return [
      # Rank-2
      ([3], -1),
      ([3], [1]),
      # Rank-3
      ([3, 4], -1),
      ([3, 4], [1]),
      ([3, 4], [2]),
      ([3, 4], [1, 2]),
      # Rank-4
      ([3, 4, 5], -1),
      ([3, 4, 5], [1]),
      ([3, 4, 5], [2]),
      ([3, 4, 5], [3]),
      ([3, 4, 5], [1, 2]),
      ([3, 4, 5], [1, 3]),
      ([3, 4, 5], [2, 3]),
      ([3, 4, 5], [1, 2, 3]),
  ]


def get_layer_norm_registries():
  ln_registry = layer_registry.LayerRegistry()
  ln_registry.insert(tf.keras.layers.Dense, dense.dense_layer_computation)
  ln_registry.insert(
      tf.keras.layers.LayerNormalization,
      layer_normalization.layer_normalization_computation,
  )
  return {
      'layer_norm_only': ln_registry,
  }


# ==============================================================================
# Main tests.
# ==============================================================================
class GradNormTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super().setUp()
    self.strategy = tf.distribute.get_strategy()

  @parameterized.product(
      model_name=list(get_layer_norm_model_generators().keys()),
      layer_name=list(get_layer_norm_layer_generators().keys()),
      parameter_tuple=get_layer_norm_parameter_tuples(),
      layer_registry_name=list(get_layer_norm_registries().keys()),
      is_eager=[True, False],
  )
  def test_gradient_norms_on_various_models(
      self,
      model_name,
      layer_name,
      parameter_tuple,
      layer_registry_name,
      is_eager,
  ):
    # Parse inputs to generate test data.
    input_dims, parameter_axes = parameter_tuple

    def curried_generator(a, b):
      del a, b  # Unused by the generator.
      layer_norm_generator = get_layer_norm_layer_generators()[layer_name]
      return layer_norm_generator(parameter_axes)

    # Load shared assets to all devices.
    with self.strategy.scope():
      dummy_output_dim = 1
      model = common_test_utils.get_model_from_generator(
          model_generator=get_layer_norm_model_generators()[model_name],
          layer_generator=curried_generator,
          input_dims=input_dims,
          output_dims=[dummy_output_dim],
          is_eager=is_eager,
      )

    # Define the main testing ops. These may be later compiled to a Graph op.
    def test_op(x_batch):
      return common_test_utils.get_computed_and_true_norms_from_model(
          model=model,
          per_example_loss_fn=None,
          num_microbatches=None,
          x_batch=[x_batch, x_batch] if model_name == 'tower2' else x_batch,
          weight_batch=None,
          registry=get_layer_norm_registries()[layer_registry_name],
      )

    # TPUs can only run `tf.function`-decorated functions.
    using_tpu = isinstance(self.strategy, tf.distribute.TPUStrategy)
    if using_tpu:
      test_op = tf.function(test_op, jit_compile=True, autograph=False)

    # TPUs use lower precision than CPUs, so we relax our criterion (see
    # `dense_test.py` for additional discussions).
    rtol = 1e-2 if using_tpu else 1e-3
    atol = 1e-1 if using_tpu else 1e-2

    # Each batched input is a reshape of a `tf.range()` call.
    batch_size = 2
    example_size = np.prod(input_dims)
    example_values = tf.range(batch_size * example_size, dtype=tf.float32)
    x_batch = tf.reshape(example_values, [batch_size] + input_dims)
    batch_size = x_batch.shape[0]
    # Set up the device ops and run the test.
    computed_norms, true_norms = self.strategy.run(test_op, args=(x_batch,))
    # TPUs return replica contexts, which must be unwrapped.
    if using_tpu:
      common_test_utils.assert_replica_values_are_close(self, computed_norms)
      common_test_utils.assert_replica_values_are_close(self, true_norms)
      computed_norms = computed_norms.values[0]
      true_norms = true_norms.values[0]
    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
    self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)


if __name__ == '__main__':
  tf.test.main()
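For intuition about what `assertAllClose` checks above: `common_test_utils.get_computed_and_true_norms_from_model` returns the fast-clipping norms together with ground-truth per-example gradient norms. A rough, hypothetical baseline for such ground truth (not the library's helper, just a sketch that differentiates each example's loss separately) could look like the following.

import tensorflow as tf

def naive_per_example_grad_norms(model, x_batch, loss_fn):
  """Per-example gradient norms computed one example at a time (slow baseline)."""
  norms = []
  for i in range(x_batch.shape[0]):
    x_i = x_batch[i : i + 1]
    with tf.GradientTape() as tape:
      loss_i = tf.reduce_sum(loss_fn(model(x_i)))
    grads = tape.gradient(loss_i, model.trainable_variables)
    flat = tf.concat(
        [tf.reshape(g, [-1]) for g in grads if g is not None], axis=0
    )
    norms.append(tf.norm(flat))
  return tf.stack(norms)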
29 changes: 29 additions & 0 deletions
...privacy/privacy/fast_gradient_clipping/registry_functions/layer_normalization_tpu_test.py
@@ -0,0 +1,29 @@
# Copyright 2023, The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf
from tensorflow_privacy.privacy.fast_gradient_clipping import common_test_utils as ctu
from tensorflow_privacy.privacy.fast_gradient_clipping.registry_functions import layer_normalization_test


class GradNormTpuTest(layer_normalization_test.GradNormTest):

  def setUp(self):
    super().setUp()
    self.strategy = ctu.create_tpu_strategy()
    self.assertIn('TPU', self.strategy.extended.worker_devices[0])


if __name__ == '__main__':
  tf.test.main()