Fix a gradient clipping bug for layer normalization layers with microbatch axes.

The previous code passed the unstacked gradients (a Python list of per-variable tensors) instead of the stacked gradients (a single tensor) to the microbatcher, which led to unexpected behavior. This change passes the correct argument and updates the original unit test so that it catches the bug.
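For intuition, here is a minimal sketch of the failure mode. The helper below is hypothetical, standing in for `common_manip_utils.maybe_add_microbatch_axis`; it assumes the microbatcher is essentially a reshape from `[batch_size, ...]` to `[num_microbatches, batch_size // num_microbatches, ...]`. Passing the raw list does not raise an error, because TensorFlow silently stacks the list along axis 0, so the microbatch axis ends up partitioning variables rather than examples:

    import tensorflow as tf

    # Hypothetical sizes for illustration: 4 examples, 2 per-variable
    # gradients of 3 elements each, and 2 microbatches.
    grads = [tf.zeros([4, 3]), tf.zeros([4, 3])]
    num_microbatches = 2

    def add_microbatch_axis(x, num_microbatches):
      # Assumed behavior of the real microbatcher: reshape
      # [batch_size, ...] -> [num_microbatches, batch_size // num_microbatches, ...].
      x = tf.convert_to_tensor(x)
      return tf.reshape(x, [num_microbatches, -1] + x.shape[1:].as_list())

    # Bug: the list is silently stacked along axis 0 (shape [2, 4, 3]), so
    # the "microbatch" axis splits the two variables, not the four examples.
    print(add_microbatch_axis(grads, num_microbatches).shape)  # (2, 1, 4, 3)

    # Fix: stack along the last axis first, then add the microbatch axis.
    stacked_grads = tf.stack(grads, axis=-1)  # shape (4, 3, 2)
    print(add_microbatch_axis(stacked_grads, num_microbatches).shape)  # (2, 2, 3, 2)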

PiperOrigin-RevId: 669413064
wwkong authored and tensorflower-gardener committed Aug 30, 2024
1 parent b396397 commit 66d05a2
Showing 2 changed files with 8 additions and 3 deletions.
@@ -80,8 +80,11 @@ def sqr_norm_fn(grads):
     stacked_grads = tf.stack(grads, axis=-1)
     if num_microbatches is not None:
       stacked_grads = common_manip_utils.maybe_add_microbatch_axis(
-          grads, num_microbatches
+          stacked_grads, num_microbatches
       )
+      # We will need to sum over the new microbatch size axis (axis=1) in order
+      # to account for microbatch aggregation.
+      stacked_grads = tf.reduce_sum(stacked_grads, axis=1)
     reduction_axes = tf.range(1, tf.rank(stacked_grads))
     return tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
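Read on its own, the corrected `sqr_norm_fn` computes one squared gradient norm per microbatch. A self-contained sketch of the same flow with concrete shapes (the inline reshape stands in for `maybe_add_microbatch_axis`, whose exact behavior is assumed here):

    import tensorflow as tf

    # Illustrative sizes: 6 examples, 2 per-variable gradients of 4 elements
    # each, grouped into 3 microbatches of 2 examples.
    batch_size, num_microbatches = 6, 3
    grads = [tf.ones([batch_size, 4]), tf.ones([batch_size, 4])]

    stacked_grads = tf.stack(grads, axis=-1)                 # (6, 4, 2)
    # Add the microbatch axis to the *stacked tensor*, as the fix does.
    stacked_grads = tf.reshape(
        stacked_grads, [num_microbatches, -1, 4, 2]
    )                                                        # (3, 2, 4, 2)
    # Sum over the microbatch size axis (axis=1) to aggregate each microbatch.
    stacked_grads = tf.reduce_sum(stacked_grads, axis=1)     # (3, 4, 2)
    reduction_axes = tf.range(1, tf.rank(stacked_grads))
    sqr_norms = tf.reduce_sum(tf.square(stacked_grads), axis=reduction_axes)
    print(sqr_norms.shape)  # (3,) -- one squared norm per microbatch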

@@ -134,7 +134,7 @@ def test_op(x_batch):
     atol = 1e-1 if self.using_tpu else 1e-2

     # Each batched input is a reshape of a `tf.range()` call.
-    batch_size = 2
+    batch_size = 6
     example_size = np.prod(input_dims)
     example_values = tf.range(batch_size * example_size, dtype=tf.float32)
     x_batch = tf.reshape(example_values, [batch_size] + input_dims)
@@ -147,7 +147,9 @@ def test_op(x_batch):
       common_test_utils.assert_replica_values_are_close(self, true_norms)
       computed_norms = computed_norms.values[0]
       true_norms = true_norms.values[0]
-    self.assertEqual(tf.shape(computed_norms)[0], batch_size)
+    self.assertEqual(
+        tf.shape(computed_norms)[0], num_microbatches or batch_size
+    )
     self.assertAllClose(computed_norms, true_norms, rtol=rtol, atol=atol)
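The updated assertion leans on `num_microbatches` being `None` when microbatching is disabled, so Python's `or` falls back to `batch_size`; with microbatching enabled, the norms are per-microbatch rather than per-example. (Presumably `batch_size` was also raised from 2 to 6 so it divides evenly by the microbatch counts the test exercises.) In short:

    batch_size = 6

    num_microbatches = None     # microbatching disabled
    assert (num_microbatches or batch_size) == 6   # one norm per example

    num_microbatches = 3        # microbatching enabled
    assert (num_microbatches or batch_size) == 3   # one norm per microbatch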


