Skip to content

Commit

Permalink
Add the second set of EinsumDense utility functions for implementing …
Browse files Browse the repository at this point in the history
…fast gradient norm computation.

PiperOrigin-RevId: 568063831
  • Loading branch information
tensorflower-gardener committed Sep 24, 2023
1 parent 1be6e02 commit 62a2d43
Show file tree
Hide file tree
Showing 3 changed files with 270 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ py_test(
name = "einsum_utils_test",
srcs = ["einsum_utils_test.py"],
python_version = "PY3",
shard_count = 4,
srcs_version = "PY3",
deps = [":einsum_utils"],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import enum
import itertools
import os
import re

import numpy as np
Expand All @@ -37,6 +38,36 @@ def _is_batch_of_vectors(t: tf.Tensor) -> bool:
return num_nontrivial_indices <= 1


def _is_valid_einsum_equation(
maybe_ab: str,
maybe_bc: str,
maybe_ac: str,
) -> bool:
"""Checks if three input strings form a valid einsum dense equation.
Given three substrings `maybe_ab`, `maybe_bc`, and `maybe_ac`, this function
checks if
```
maybe_ab + ',' + maybe_bc + '->' + maybe_ac
```
is an einsum equation of the form `ab,bc->ac`.
Args:
maybe_ab: The proposed `ab` substring.
maybe_bc: The proposed `bc` substring.
maybe_ac: The proposed `ac` substring.
Returns:
`True` if the three input strings form an einsum equation of the form
`ab,bc->ac` and `False` otherwise.
"""
a_substr = os.path.commonprefix([maybe_ab, maybe_ac])
a_len = len(a_substr)
b_substr = maybe_ab[a_len:]
c_substr = maybe_ac[a_len:]
return maybe_bc == b_substr + c_substr


def _parse_einsum_equation(
equation: str,
) -> tuple[EquationType, tuple[str, str, str]]:
Expand All @@ -61,22 +92,33 @@ def _try_match(regex_str):
maybe_match = re.fullmatch(regex_str, equation)
return maybe_match.groups() if maybe_match is not None else None

groups1 = _try_match(r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)")
if groups1 is not None:
return EquationType.NO_ELLIPSES, groups1
groups2 = _try_match(r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)")
if groups2 is not None:
return EquationType.LEFT_ELLIPSES, groups2
groups3 = _try_match(r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.")
if groups3 is not None:
return EquationType.RIGHT_ELLIPSES, groups3
raise ValueError(
error_message = (
"Invalid Einsum equation string "
+ equation
+ " ."
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
"{ab...,bc->ac...}"
)
case_pairs = [
# equation_type, regex_str
(EquationType.NO_ELLIPSES, r"([a-zA-Z]+),([a-zA-Z]+)->([a-zA-Z]+)"),
(
EquationType.LEFT_ELLIPSES,
r"\.\.\.([a-zA-Z]+),([a-zA-Z]+)->\.\.\.([a-zA-Z]+)",
),
(
EquationType.RIGHT_ELLIPSES,
r"([a-zA-Z]+)\.\.\.,([a-zA-Z]+)->([a-zA-Z]+)\.\.\.",
),
]
for equation_type, regex_str in case_pairs:
groups = _try_match(regex_str)
if groups is not None:
if not _is_valid_einsum_equation(*groups):
raise ValueError(error_message)
return equation_type, groups
# No valid cases found. Raise an error.
raise ValueError(error_message)


def _reshape_einsum_inputs(
Expand All @@ -100,13 +142,17 @@ def _reshape_einsum_inputs(
and the number of output columns is the second dimension of the input. The
product of the non-trivial dimensions of the output should be equal to
the product of the dimensions of `input_tensor`.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
the `tf.keras.layers.EinsumDense` layer.
"""
# Find the components `ab`, `bc`, and `ac` given that `equation` can only be
# one of the following mutually exclusive forms:
#
# (C1) ab,bc->ac,
# (C2) ...ab,bc->...ac
# (C3) ab...,bc->ac...
# C1. ab,bc->ac,
# C2. ...ab,bc->...ac
# C3. ab...,bc->ac...
#
# NOTE: `a`, `b`, and `c` are (possibly) also substrings.

Expand All @@ -115,7 +161,7 @@ def _reshape_einsum_inputs(
input_len = len(input_shape)
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
if equation_type == EquationType.LEFT_ELLIPSES:
# In case (C2), the `a` part of this component can be empty, so we have no
# In case C2, the `a` part of this component can be empty, so we have no
# choice but to compare the `c` part of `ac` with the `bc` component.
c_len = 0
for s1, s2 in itertools.zip_longest(reversed(bc_str), reversed(ac_str)):
Expand All @@ -135,7 +181,7 @@ def _reshape_einsum_inputs(
else:
break
# Prepare `input_tensor` for reshaping and get the pivot index of the prepped
# tensor. Note that case (C3) requires a transpose to ensure that matrix
# tensor. Note that case C3 requires a transpose to ensure that matrix
# multiplication is performed by the caller.
if equation_type == EquationType.RIGHT_ELLIPSES:
ellipses_idx = len(ab_str)
Expand Down Expand Up @@ -178,17 +224,124 @@ def _reshape_einsum_outputs(
A rank-3 `tf.Tensor` whose first dimension is the batch dimension. The
product of the non-trivial dimensions of the output should be equal to
the product of the non-trivial dimensions of `output_tensor`.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
the `tf.keras.layers.EinsumDense` layer.
"""
match = re.fullmatch(r"([a-zA-Z.]+),([a-zA-Z.]+)->([a-zA-Z.]+)", equation)
if match is not None:
s1, s2, s3 = match.groups()
else:
raise ValueError(
"Invalid Einsum equation string "
+ equation
+ " ."
"Must be one of the forms {ab,bc->ac}, {...ab,bc->...ac}, "
"{ab...,bc->ac...}"
)
reversed_equation = s3 + "," + s2[::-1] + "->" + s1
# Get the raw components of the reversed equation.
equation_type, (ab_str, bc_str, ac_str) = _parse_einsum_equation(equation)
prefix = "..." if equation_type == EquationType.LEFT_ELLIPSES else ""
suffix = "..." if equation_type == EquationType.RIGHT_ELLIPSES else ""
ellided_ab_str = prefix + ab_str + suffix
ellided_ac_str = prefix + ac_str + suffix
# Swap the `b` and `c` components.
c_str = os.path.commonprefix([bc_str[::-1], ac_str[::-1]])[::-1]
b_len = len(bc_str) - len(c_str)
b_str = bc_str[:b_len]
cb_str = c_str + b_str
reversed_equation = ellided_ac_str + "," + cb_str + "->" + ellided_ab_str
return _reshape_einsum_inputs(output_tensor, reversed_equation)


def _get_einsum_bias_adjoint_reduction_axes(
equation: str,
bias_axes: str,
einsum_rank: int,
) -> list[int]:
"""Computes axes related to the per-example adjoint of the einsum bias op.
To describe the output of this computation, first recall that for each
example the `EinsumDense` layer performs the following transformation:
```
F(W, bias | X) = Einsum(W, X) + Q(bias)
```
where `W` is a tensor of trainable variables, `bias` is a tensor of rank
`len(bias_axes)`, `X` is a batch of inputs, and `Q` is a linear broadcast
operator that roughly corresponds to `Q(bias) ~= tf.broadcast_to(bias, S)` for
`S := tf.shape(Einsum(W, X))`.
It is straightforward to show that the per-example adjoint of `Q` is given by
`Q'(Y) := tf.reduce_sum(Y, axes=R)` where `R` contains the broadcasting
indices. This function returns `R` as an unordered list of `int`s.
Assumptions:
A1. `equation` is one of the following forms:
C1. `ab,bc->ac`
C2. `...ab,bc->...ac`
C3. `ab...,bc->ac...`
A2. The first character in the substring `a` (or `...a` in C2)
in assumption A1 corresponds to the batch dimension.
A3. The characters in `bias_axes` must be subset of the non-batch dimension
characters in the substring `ac` (or `...ac` in C2) in
assumption A1.
A4. `einsum_rank` is the length of the substring `ac` (or `...ac` in C2) in
assumption A1. This includes the batch dimension.
Examples:
1. equation = 'ab,bc->ac', bias_axes = 'c', einsum_rank = 2 -> []
2. equation = 'ab,bce->ace', bias_axes = 'ce', einsum_rank = 3, -> []
3. equation = 'ab,bce->ace', bias_axes = 'c', einsum_rank = 3, -> [2]
4. equation = 'ab,bce->ace', bias_axes = 'e', einsum_rank = 3, -> [1]
5. equation = 'ab,bced->aced', bias_axes = 'ced', einsum_rank = 4 -> []
6. equation = 'ab,bced->aced', bias_axes = 'ce', einsum_rank = 4, -> [3],
7. equation = 'ab,bced->aced', bias_axes = 'c', einsum_rank = 4, -> [2, 3]
8. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 4
-> [1, 3]
9. equation = '...ab,bce->...ace', bias_axes = 'c', einsum_rank = 10
-> [1, 2, 3, 4, 5, 6, 7, 9]
10. equation = 'ab...,bce->ace...', bias_axes = 'e', einsum_rank = 4
-> [1, 3]
Args:
equation: The einsum equation `string`.
bias_axes: A substring of the output part of `equation` specifying which
axes a bias `tf.Tensor` is added to.
einsum_rank: The rank of the tensor that the per-example adjoint operator is
being applied to.
Returns:
A list of `int` containing axes in the `input` corresponding to
`input_rank`. Each `int` is at most `input_rank-1` and excludes zero.
Raises:
ValueError: If `equation` is not a valid einsum equation in the context of
the `tf.keras.layers.EinsumDense` layer.
"""
reduction_axes = []
bias_char_set = set(bias_axes)
equation_type, (_, _, ac_str) = _parse_einsum_equation(equation)
# Do not allow the bias axes to be the batch axis, since we want the adjoint
# of the bias broadcast op to apply the same operation to all examples in a
# batch.
if equation_type != EquationType.LEFT_ELLIPSES and ac_str[0] in bias_axes:
raise ValueError(f"Bias axis '{bias_axes}' cannot also be the batch axis.")
# If `equation` of the form `...ab,bc->...ac`, i.e., case C2, we do a
# right to left traversal; the other cases do a left to right traversal.
input_indices = range(einsum_rank)
traversal_zip = (
itertools.zip_longest(reversed(input_indices), reversed(ac_str))
if equation_type == EquationType.LEFT_ELLIPSES
else itertools.zip_longest(input_indices, ac_str)
)
# Traverse the output part of `equation` and add an index to the output if
# the corresponding `char` in the `ac` part is NOT in `bias_axes` and the
# index is not zero (batch dimension). Add all indices except index zero in
# the `...` part of the output substring (if present).
for idx, output_char in traversal_zip:
# Exclude the batch dimension (idx == 0), since we want the per-example
# adjoint.
if idx != 0:
if output_char is not None and bias_char_set:
if output_char not in bias_char_set:
reduction_axes.append(idx)
else:
bias_char_set.remove(output_char)
else:
reduction_axes.append(idx)
return reduction_axes
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,21 @@ def test_is_batch_of_vectors(self, experiment_params):
computed_result = einsum_utils._is_batch_of_vectors(t)
self.assertEqual(computed_result, true_result)

@parameterized.product(
experiment_params=[
(('ab', 'bc', 'ac'), True),
(('ab', 'a', 'b'), False),
(('ab', 'ca', 'bc'), False),
(('b', 'bc', 'c'), True),
(('ab', 'bc', 'bc'), False),
(('abc', 'cde', 'abde'), True),
]
)
def test_is_valid_einsum_equation(self, experiment_params):
inputs, true_result = experiment_params
computed_result = einsum_utils._is_valid_einsum_equation(*inputs)
self.assertEqual(computed_result, true_result)

@parameterized.product(
experiment_params=[
(
Expand All @@ -66,7 +81,7 @@ def test_is_batch_of_vectors(self, experiment_params):
)
def test_parse_einsum_equation(self, experiment_params):
equation, true_eqn_type, true_groups = experiment_params
(computed_eqn_type, computed_groups) = einsum_utils._parse_einsum_equation(
computed_eqn_type, computed_groups = einsum_utils._parse_einsum_equation(
equation
)
self.assertEqual(computed_eqn_type, true_eqn_type)
Expand Down Expand Up @@ -98,7 +113,7 @@ def test_parse_einsum_equation(self, experiment_params):
]
)
def test_reshape_einsum_inputs(self, experiment_params):
(equation, input_shape, true_permutations, true_parsed_shape) = (
equation, input_shape, true_permutations, true_parsed_shape = (
experiment_params
)
num_entries = int(np.prod(input_shape))
Expand Down Expand Up @@ -141,7 +156,7 @@ def test_reshape_einsum_inputs(self, experiment_params):
]
)
def test_reshape_einsum_outputs(self, experiment_params):
(equation, output_shape, true_permutations, true_parsed_shape) = (
equation, output_shape, true_permutations, true_parsed_shape = (
experiment_params
)
num_entries = int(np.prod(output_shape))
Expand All @@ -158,6 +173,77 @@ def test_reshape_einsum_outputs(self, experiment_params):
true_parsed_tensor = tf.reshape(true_parsed_tensor, true_parsed_shape)
self.assertAllEqual(computed_parsed_tensor, true_parsed_tensor)

@parameterized.product(
experiment_params=[
# einsum_utils.EquationType.NO_ELLIPSES
('ab,bc->ac', 'c', 2, []),
('ab,bce->ace', 'ce', 3, []),
('ab,bce->ace', 'ec', 3, []),
('ab,bce->ace', 'c', 3, [2]),
('ab,bce->ace', 'e', 3, [1]),
('ab,bced->aced', 'ced', 4, []),
('ab,bced->aced', 'edc', 4, []),
('ab,bced->aced', 'ce', 4, [3]),
('ab,bced->aced', 'ec', 4, [3]),
('ab,bced->aced', 'cd', 4, [2]),
('ab,bced->aced', 'ed', 4, [1]),
('ab,bced->aced', 'c', 4, [2, 3]),
('ab,bced->aced', 'e', 4, [1, 3]),
('ab,bced->aced', 'd', 4, [1, 2]),
# einsum_utils.EquationType.LEFT_ELLIPSES
('...b,bc->...c', 'c', 2, []),
('...b,bce->...ce', 'c', 3, [2]),
('...b,bce->...ce', 'e', 3, [1]),
('...ab,bc->...ac', 'c', 3, [1]),
('...ab,bce->...ace', 'ac', 4, [3]),
('...ab,bce->...ace', 'ae', 4, [2]),
('...ab,bce->...ace', 'ce', 4, [1]),
('...ab,bce->...ace', 'ec', 4, [1]),
('...ab,bce->...ace', 'a', 4, [2, 3]),
('...ab,bce->...ace', 'c', 4, [1, 3]),
('...ab,bce->...ace', 'e', 4, [1, 2]),
('...ab,bce->...ace', 'c', 5, [1, 2, 4]),
('...ab,bce->...ace', 'c', 10, [1, 2, 3, 4, 5, 6, 7, 9]),
# einsum_utils.EquationType.RIGHT_ELLIPSES
('ab...,bc->ac...', 'c', 3, [2]),
('ab...,bce->ace...', 'ce', 4, [3]),
('ab...,bce->ace...', 'ec', 4, [3]),
('ab...,bce->ace...', 'c', 4, [2, 3]),
('ab...,bce->ace...', 'e', 4, [1, 3]),
]
)
def test_get_einsum_bias_adjoint_reduction_axes(self, experiment_params):
equation, bias_axes, einsum_rank, true_reduction_axes = experiment_params
computed_reduction_axes = (
einsum_utils._get_einsum_bias_adjoint_reduction_axes(
equation, bias_axes, einsum_rank
)
)
computed_reduction_axes.sort()
true_reduction_axes.sort()
self.assertAllEqual(computed_reduction_axes, true_reduction_axes)

@parameterized.product(
experiment_params=[
# einsum_utils.EquationType.NO_ELLIPSES
('ab,bc->ac', 'a', 2),
# einsum_utils.EquationType.RIGHT_ELLIPSES
('ab...,bc->ac...', 'a', 3),
('ab...,bc->ac...', 'a', 4),
('ab...,bcde->acde...', 'acd', 4),
]
)
def test_bias_axis_eq_batch_axis_throws_error(self, experiment_params):
equation, bias_axes, einsum_rank = experiment_params
with self.assertRaises(ValueError) as context:
einsum_utils._get_einsum_bias_adjoint_reduction_axes(
equation, bias_axes, einsum_rank
)
self.assertEqual(
f"Bias axis '{bias_axes}' cannot also be the batch axis.",
str(context.exception),
)


if __name__ == '__main__':
tf.test.main()

0 comments on commit 62a2d43

Please sign in to comment.