
Commit

No public description
PiperOrigin-RevId: 572248552
Flaxformer Team committed Oct 11, 2023
1 parent f40decd commit ea17eb0
Showing 10 changed files with 41 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ as needed!
First, we recommend installing a few dependencies manually,

```
-pip3 install numpy sentencepiece tensorflow==2.8.1
+pip3 install numpy sentencepiece tensorflow>=2.14.0
```

This is a workaround to prevent pip backtracking on package versions; we
2 changes: 1 addition & 1 deletion flaxformer/activation_partitioning.py
@@ -19,7 +19,7 @@
from absl import logging
from flax.linen import partitioning as flax_partitioning
import jax
-from jax.experimental.pjit import with_sharding_constraint as jax_pjit_wsc
+from jax.lax import with_sharding_constraint as jax_pjit_wsc


def global_mesh_defined():
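This hunk, and the identical one in `flaxformer/components/transforms.py` further down, moves the `with_sharding_constraint` import from the deprecated `jax.experimental.pjit` path to `jax.lax`, where the function lives in current JAX releases; the local alias `jax_pjit_wsc` is unchanged. A minimal sketch of the relocated API, assuming a recent JAX (>= 0.4); the one-axis mesh and the `'data'` axis name are illustrative, not Flaxformer's actual partitioning setup:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Illustrative one-axis device mesh over whatever devices are available.
mesh = Mesh(np.array(jax.devices()), axis_names=('data',))


@jax.jit
def double_sharded(x):
  # Ask the compiler to keep x sharded along the 'data' mesh axis.
  x = jax.lax.with_sharding_constraint(
      x, NamedSharding(mesh, PartitionSpec('data')))
  return x * 2.0


y = double_sharded(jnp.ones((8, 4)))
```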
6 changes: 3 additions & 3 deletions flaxformer/architectures/dual_encoder/similarity_functions.py
@@ -207,7 +207,7 @@ def __call__(self, left_encodings: Array, right_encodings: Array,
# Implement the dot product as module to be consistent to other similarity
# functions.
del self
-return jnp.sum(left_encodings * right_encodings, axis=-1, keepdims=True)
+return jnp.sum(left_encodings * right_encodings, axis=-1, keepdims=True) # pytype: disable=bad-return-type # jnp-type


# ============================ Batch Similarity ================================
@@ -266,7 +266,7 @@ def __call__(self,
# so shape is [batch_size, batch_size * (1 + num_hard_negatives)].
logits = jnp.dot(left_encodings, right_encodings.transpose())

-return logits
+return logits # pytype: disable=bad-return-type # jnp-type


class DoNothing(nn.Module):
@@ -300,7 +300,7 @@ def __call__(self,
del right_encodings
del right_additional_encodings
del params
-return jnp.zeros((), dtype=jnp.int32)
+return jnp.zeros((), dtype=jnp.int32) # pytype: disable=bad-return-type # jnp-type


class BatchAttentionSimilarity(nn.Module):
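The three suppressions above silence pytype where the annotated return types disagree with what it infers for `jnp` expressions; the computations themselves are unchanged. As a shape reference for the two similarity styles touched in this file, a self-contained sketch (array names and sizes are illustrative, not Flaxformer's):

```python
import jax.numpy as jnp

left = jnp.ones((4, 16))   # [batch_size, encoding_dim]
right = jnp.ones((4, 16))  # [batch_size, encoding_dim]

# Pointwise dot-product similarity: one score per aligned (left_i, right_i) pair.
pointwise = jnp.sum(left * right, axis=-1, keepdims=True)  # shape (4, 1)

# Batch similarity: logits for every left/right pairing in the batch.
batch_logits = jnp.dot(left, right.transpose())  # shape (4, 4)

print(pointwise.shape, batch_logits.shape)
```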
6 changes: 6 additions & 0 deletions flaxformer/architectures/h_transformer/token_hierarchy.py
@@ -84,6 +84,9 @@ class TokenCoarseningMethod(str, enum.Enum):
SUM = 'sum'
CONST_AVERAGE = 'const_average'

+def __format__(self, format_spec: str) -> str:
+  return self.value.__format__(format_spec)


@gin.constants_from_enum
class ConvKernelType(str, enum.Enum):
@@ -92,6 +95,9 @@ class ConvKernelType(str, enum.Enum):
CONST = 'const'
LINEAR = 'linear'

+def __format__(self, format_spec: str) -> str:
+  return self.value.__format__(format_spec)


class OneDimTokenCoarsening:
"""Coarsening class for one-dimension sequence token hierarchy."""
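The `__format__` override added to both enums here (and to `LayerLayout` in `moe_enums.py` below) delegates formatting to the member's string value, because recent Python releases changed how members of mixed-in `(str, enum.Enum)` classes respond to `format()`; the new test in `token_hierarchy_test.py` exercises exactly this. A standalone sketch of the pattern, assuming nothing beyond the standard library (the enum mirrors `ConvKernelType` but is illustrative, not the Flaxformer module):

```python
import enum


class ConvKernelType(str, enum.Enum):
  """String-valued enum whose format() output tracks the value."""

  CONST = 'const'
  LINEAR = 'linear'

  def __format__(self, format_spec: str) -> str:
    # Format the underlying string value, not the enum member itself.
    return self.value.__format__(format_spec)


# Pads the value to 10 characters regardless of Python version.
print('conv: {:10s}.'.format(ConvKernelType.CONST))  # -> conv: const     .
```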
13 changes: 13 additions & 0 deletions flaxformer/architectures/h_transformer/token_hierarchy_test.py
@@ -26,6 +26,19 @@
class OneDimTokenCoarseningTest(parameterized.TestCase):
"""Test cases for OneDimTokenCoarsening."""

+def test_enum_format(self):
+  # Python 3.10 and 3.11 changed behavior. We just want to make sure that enum
+  # items can be formatted.
+  self.assertIsInstance(
+      'token coarsening: {:20s}.'.format(
+          token_hierarchy.TokenCoarseningMethod.CONST_AVERAGE
+      ),
+      str,
+  )
+  self.assertIsInstance(
+      'conv: {:10s}.'.format(token_hierarchy.ConvKernelType.CONST), str
+  )

@parameterized.named_parameters(
('sample', token_hierarchy.TokenCoarseningMethod.SAMPLE,
np.array([[[[1.], [2.]], [[5.], [6.]]]])),
3 changes: 3 additions & 0 deletions flaxformer/architectures/moe/moe_enums.py
@@ -35,3 +35,6 @@ class LayerLayout(str, enum.Enum):
MIDDLE = 'middle'
MIXED = 'mixed'
TOP = 'top'

+def __format__(self, format_spec: str) -> str:
+  return self.value.__format__(format_spec)
2 changes: 1 addition & 1 deletion flaxformer/architectures/moe/moe_layers.py
@@ -445,7 +445,7 @@ def _mask_and_dispatch_to_experts(
total_expert_capacity = self.num_experts * expert_capacity * num_groups
expert_usage = num_tokens_dispatched / total_expert_capacity

-self._sow_expert_metrics(
+self._sow_expert_metrics( # pytype: disable=wrong-arg-types # jnp-type
router_mask.auxiliary_loss,
router_mask.router_z_loss,
fraction_tokens_left_behind,
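The suppressed call sows router diagnostics computed immediately above it. A small numeric illustration of the capacity bookkeeping (the numbers are made up, not from any Flaxformer config):

```python
# Hypothetical dispatch statistics for one batch.
num_experts, expert_capacity, num_groups = 4, 2, 8

num_tokens_dispatched = 48.0  # tokens that actually landed in an expert slot
total_expert_capacity = num_experts * expert_capacity * num_groups  # 64 slots
expert_usage = num_tokens_dispatched / total_expert_capacity  # 0.75

print(total_expert_capacity, expert_usage)
```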
12 changes: 8 additions & 4 deletions flaxformer/architectures/moe/routing.py
@@ -745,9 +745,13 @@ def _load_balancing_loss(router_probs: Array, expert_indices: Array) -> float:
expert_mask, dtype=jnp.float32, axis=-2)
router_prob_per_group_and_expert = jnp.mean(
router_probs, dtype=jnp.float32, axis=-2)
-return jnp.mean(
-    tokens_per_group_and_expert * router_prob_per_group_and_expert,
-    dtype=jnp.float32) * num_experts**2
+return (
+    jnp.mean( # pytype: disable=bad-return-type # jnp-type
+        tokens_per_group_and_expert * router_prob_per_group_and_expert,
+        dtype=jnp.float32,
+    )
+    * num_experts**2
+)
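The reformatted return above assembles the auxiliary load-balancing loss from the two per-group, per-expert means computed just before it. A self-contained sketch with toy shapes (not Flaxformer's configuration), showing that perfectly balanced routing with uniform router probabilities gives a loss of 1.0:

```python
import jax.numpy as jnp

# router_probs: [num_groups, tokens_per_group, num_experts]
router_probs = jnp.full((1, 4, 2), 0.5)
# expert_mask: one-hot expert assignment per token, same leading dims.
expert_mask = jnp.array([[[1., 0.], [1., 0.], [0., 1.], [0., 1.]]])

num_experts = router_probs.shape[-1]
tokens_per_group_and_expert = jnp.mean(expert_mask, dtype=jnp.float32, axis=-2)
router_prob_per_group_and_expert = jnp.mean(router_probs, dtype=jnp.float32, axis=-2)
aux_loss = jnp.mean(
    tokens_per_group_and_expert * router_prob_per_group_and_expert,
    dtype=jnp.float32) * num_experts**2

print(aux_loss)  # 1.0 for this balanced assignment
```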


def _router_z_loss(router_logits: Array) -> float:
@@ -767,4 +771,4 @@ def _router_z_loss(router_logits: Array) -> float:
num_groups, tokens_per_group, _ = router_logits.shape
log_z = jax.nn.logsumexp(router_logits, axis=-1)
z_loss = log_z**2
-return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group)
+return jnp.sum(z_loss, dtype=jnp.float32) / (num_groups * tokens_per_group) # pytype: disable=bad-return-type # jnp-type
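`_router_z_loss` penalizes large router logits by averaging the squared log-sum-exp over groups and tokens; the trailing pytype suppression only affects type checking. A standalone sketch with toy shapes (illustrative, not Flaxformer defaults):

```python
import jax
import jax.numpy as jnp

# router_logits: [num_groups, tokens_per_group, num_experts]
router_logits = jnp.zeros((2, 8, 4))

num_groups, tokens_per_group, _ = router_logits.shape
log_z = jax.nn.logsumexp(router_logits, axis=-1)  # [num_groups, tokens_per_group]
z_loss = jnp.sum(log_z**2, dtype=jnp.float32) / (num_groups * tokens_per_group)

print(z_loss)  # log(4)**2 ≈ 1.92 for all-zero logits over 4 experts
```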
2 changes: 1 addition & 1 deletion flaxformer/components/transforms.py
@@ -22,7 +22,7 @@
from flax.core.lift import Out as ScanOut # pylint: disable=unused-import
from flax.linen import partitioning
import jax
-from jax.experimental.pjit import with_sharding_constraint as jax_pjit_wsc
+from jax.lax import with_sharding_constraint as jax_pjit_wsc

# TODO: this file contains JAX transform workarounds to fix/move
# upstream, primarily concerning the JAX checkpoint/remat transform and
7 changes: 4 additions & 3 deletions setup.py
@@ -29,20 +29,21 @@
"numpy>=1.12",
"jax>=0.2.21",
"flax>=0.6.9",
"aqtp[jax_legacy]>=0.0.10, <=0.1.0",
"aqtp>=0.1.0",
]

tests_require = [
"absl-py",
"pytest",
"tensorflow>=2.12.0",
"tensorflow>=2.14.0",
"tensorflow-text>=2.14.0rc0",
"gin-config",
"t5x @ git+https://github.com/google-research/t5x",
]

setup(
name="flaxformer",
version="0.8.3",
version="0.8.4",
description="Flaxformer: Transformer implementations in Flax",
long_description="\n\n".join([README]),
long_description_content_type="text/markdown",
