From ea17eb012a1d340ddff017b7a534c2162aaec34c Mon Sep 17 00:00:00 2001
From: Flaxformer Team
Date: Tue, 10 Oct 2023 14:53:57 +0000
Subject: [PATCH] No public description

PiperOrigin-RevId: 572248552
---
 README.md                                          |  2 +-
 flaxformer/activation_partitioning.py              |  2 +-
 .../dual_encoder/similarity_functions.py           |  6 +++---
 .../architectures/h_transformer/token_hierarchy.py |  6 ++++++
 .../h_transformer/token_hierarchy_test.py          | 13 +++++++++++++
 flaxformer/architectures/moe/moe_enums.py          |  3 +++
 flaxformer/architectures/moe/moe_layers.py         |  2 +-
 flaxformer/architectures/moe/routing.py            | 12 ++++++++----
 flaxformer/components/transforms.py                |  2 +-
 setup.py                                           |  7 ++++---
 10 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index df3e5d3..9ed49fc 100644
--- a/README.md
+++ b/README.md
@@ -54,7 +54,7 @@ as needed!
 First, we recommend installing a few dependencies manually,
 
 ```
-pip3 install numpy sentencepiece tensorflow==2.8.1
+pip3 install numpy sentencepiece tensorflow>=2.14.0
 ```
 
 This is a workaround to prevent pip backtracking on package versions; we
diff --git a/flaxformer/activation_partitioning.py b/flaxformer/activation_partitioning.py
index 6ae0800..39a1303 100644
--- a/flaxformer/activation_partitioning.py
+++ b/flaxformer/activation_partitioning.py
@@ -19,7 +19,7 @@
 from absl import logging
 from flax.linen import partitioning as flax_partitioning
 import jax
-from jax.experimental.pjit import with_sharding_constraint as jax_pjit_wsc
+from jax.lax import with_sharding_constraint as jax_pjit_wsc
 
 
 def global_mesh_defined():
diff --git a/flaxformer/architectures/dual_encoder/similarity_functions.py b/flaxformer/architectures/dual_encoder/similarity_functions.py
index a0a3619..49627c2 100644
--- a/flaxformer/architectures/dual_encoder/similarity_functions.py
+++ b/flaxformer/architectures/dual_encoder/similarity_functions.py
@@ -207,7 +207,7 @@ def __call__(self, left_encodings: Array, right_encodings: Array,
     # Implement the dot product as module to be consistent to other similarity
     # functions.
     del self
-    return jnp.sum(left_encodings * right_encodings, axis=-1, keepdims=True)
+    return jnp.sum(left_encodings * right_encodings, axis=-1, keepdims=True)  # pytype: disable=bad-return-type  # jnp-type
 
 
 # ============================ Batch Similarity ================================
@@ -266,7 +266,7 @@ def __call__(self,
 
     # so shape is [batch_size, batch_size * (1 + num_hard_negatives)].
     logits = jnp.dot(left_encodings, right_encodings.transpose())
-    return logits
+    return logits  # pytype: disable=bad-return-type  # jnp-type
 
 
 class DoNothing(nn.Module):
@@ -300,7 +300,7 @@ def __call__(self,
     del right_encodings
     del right_additional_encodings
     del params
-    return jnp.zeros((), dtype=jnp.int32)
+    return jnp.zeros((), dtype=jnp.int32)  # pytype: disable=bad-return-type  # jnp-type
 
 
 class BatchAttentionSimilarity(nn.Module):
diff --git a/flaxformer/architectures/h_transformer/token_hierarchy.py b/flaxformer/architectures/h_transformer/token_hierarchy.py
index 4aeeccc..337c64e 100644
--- a/flaxformer/architectures/h_transformer/token_hierarchy.py
+++ b/flaxformer/architectures/h_transformer/token_hierarchy.py
@@ -84,6 +84,9 @@ class TokenCoarseningMethod(str, enum.Enum):
   SUM = 'sum'
   CONST_AVERAGE = 'const_average'
 
+  def __format__(self, format_spec: str) -> str:
+    return self.value.__format__(format_spec)
+
 
 @gin.constants_from_enum
 class ConvKernelType(str, enum.Enum):
@@ -92,6 +95,9 @@ class ConvKernelType(str, enum.Enum):
   CONST = 'const'
   LINEAR = 'linear'
 
+  def __format__(self, format_spec: str) -> str:
+    return self.value.__format__(format_spec)
+
 
 class OneDimTokenCoarsening:
   """Coarsening class for one-dimension sequence token hierarchy."""
diff --git a/flaxformer/architectures/h_transformer/token_hierarchy_test.py b/flaxformer/architectures/h_transformer/token_hierarchy_test.py
index 00e6d8e..b1acf83 100644
--- a/flaxformer/architectures/h_transformer/token_hierarchy_test.py
+++ b/flaxformer/architectures/h_transformer/token_hierarchy_test.py
@@ -26,6 +26,19 @@
 class OneDimTokenCoarseningTest(parameterized.TestCase):
   """Test cases for OneDimTokenCoarsening."""
 
+  def test_enum_format(self):
+    # Python 3.10 and 3.11 changed behavior. We just want to make sure that enum
+    # items can be formatted.
+    self.assertIsInstance(
+        'token coarsening: {:20s}.'.format(
+            token_hierarchy.TokenCoarseningMethod.CONST_AVERAGE
+        ),
+        str,
+    )
+    self.assertIsInstance(
+        'conv: {:10s}.'.format(token_hierarchy.ConvKernelType.CONST), str
+    )
+
   @parameterized.named_parameters(
       ('sample', token_hierarchy.TokenCoarseningMethod.SAMPLE,
        np.array([[[[1.], [2.]], [[5.], [6.]]]])),
diff --git a/flaxformer/architectures/moe/moe_enums.py b/flaxformer/architectures/moe/moe_enums.py
index ed08c23..5abcaf2 100644
--- a/flaxformer/architectures/moe/moe_enums.py
+++ b/flaxformer/architectures/moe/moe_enums.py
@@ -35,3 +35,6 @@ class LayerLayout(str, enum.Enum):
   MIDDLE = 'middle'
   MIXED = 'mixed'
   TOP = 'top'
+
+  def __format__(self, format_spec: str) -> str:
+    return self.value.__format__(format_spec)
diff --git a/flaxformer/architectures/moe/moe_layers.py b/flaxformer/architectures/moe/moe_layers.py
index 3a33b85..fbf33af 100644
--- a/flaxformer/architectures/moe/moe_layers.py
+++ b/flaxformer/architectures/moe/moe_layers.py
@@ -445,7 +445,7 @@ def _mask_and_dispatch_to_experts(
     total_expert_capacity = self.num_experts * expert_capacity * num_groups
     expert_usage = num_tokens_dispatched / total_expert_capacity
 
-    self._sow_expert_metrics(
+    self._sow_expert_metrics(  # pytype: disable=wrong-arg-types  # jnp-type
         router_mask.auxiliary_loss,
         router_mask.router_z_loss,
         fraction_tokens_left_behind,
diff --git a/flaxformer/architectures/moe/routing.py b/flaxformer/architectures/moe/routing.py
index 82aece6..ea0c5f5 100644
--- a/flaxformer/architectures/moe/routing.py
+++ b/flaxformer/architectures/moe/routing.py
@@ -745,9 +745,13 @@ def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float:
       expert_mask, dtype=jnp.float32, axis=-2)
   router_prob_per_group_and_expert = jnp.mean(
       router_probs, dtype=jnp.float32, axis=-2)
-  return jnp.mean(
-      tokens_per_group_and_expert * router_prob_per_group_and_expert,
-      dtype=jnp.float32) * num_experts**2
+  return (
+      jnp.mean(  # pytype: disable=bad-return-type  # jnp-type
+          tokens_per_group_and_expert * router_prob_per_group_and_expert,
+          dtype=jnp.float32,
+      )
+      * num_experts**2
+  )
 
 
 def _router_z_loss(router_logits: Array) -> float:
@@ -767,4 +771,4 @@ def _router_z_loss(router_logits: Array) -> float:
   num_groups, tokens_per_group, _ = router_logits.shape
   log_z = jax.nn.logsumexp(router_logits, axis=-1)
   z_loss = log_z**2
-  return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group)
+  return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group)  # pytype: disable=bad-return-type  # jnp-type
diff --git a/flaxformer/components/transforms.py b/flaxformer/components/transforms.py
index 1974fd8..5a25e80 100644
--- a/flaxformer/components/transforms.py
+++ b/flaxformer/components/transforms.py
@@ -22,7 +22,7 @@
 from flax.core.lift import Out as ScanOut  # pylint: disable=unused-import
 from flax.linen import partitioning
 import jax
-from jax.experimental.pjit import with_sharding_constraint as jax_pjit_wsc
+from jax.lax import with_sharding_constraint as jax_pjit_wsc
 
 # TODO: this file contains JAX transform workarounds to fix/move
 # upstream, primarily concerning the JAX checkpoint/remat transform and
diff --git a/setup.py b/setup.py
index 9ac9437..92b1808 100644
--- a/setup.py
+++ b/setup.py
@@ -29,20 +29,21 @@
     "numpy>=1.12",
     "jax>=0.2.21",
     "flax>=0.6.9",
-    "aqtp[jax_legacy]>=0.0.10, <=0.1.0",
+    "aqtp>=0.1.0",
 ]
 
 tests_require = [
     "absl-py",
     "pytest",
-    "tensorflow>=2.12.0",
+    "tensorflow>=2.14.0",
+    "tensorflow-text>=2.14.0rc0",
     "gin-config",
     "t5x @ git+https://github.com/google-research/t5x",
 ]
 
 setup(
     name="flaxformer",
-    version="0.8.3",
+    version="0.8.4",
     description="Flaxformer: Transformer implementations in Flax",
     long_description="\n\n".join([README]),
     long_description_content_type="text/markdown",