Commit
Update Kauldron sharding
PiperOrigin-RevId: 720487197
Conchylicultor authored and The kauldron Authors committed Jan 31, 2025
1 parent 760e689 commit 37d6253
Showing 6 changed files with 204 additions and 82 deletions.
13 changes: 5 additions & 8 deletions kauldron/data/utils.py
@@ -29,6 +29,7 @@
 from kauldron import kontext
 from kauldron.train import context as context_lib
 from kauldron.typing import ArraySpec, ElementSpec, PyTree  # pylint: disable=g-multiple-import,g-importing-member
+from kauldron.utils import _jax
 from kauldron.utils import sharding_utils
 import numpy as np

@@ -85,15 +86,11 @@ def mock_batch_from_elem_spec(
   """Create a mock batch from the element_spec of a data iterator."""
   elem_spec = etree.spec_like(elem_spec)

-  # We only support FIRST_DIM and REPLICATED sharding for now.
   def _get_global_shape(spec):
-    if elem_sharding is sharding_utils.sharding.FIRST_DIM:
-      shape = (spec.shape[0] * jax.process_count(),) + spec.shape[1:]
-    elif elem_sharding is sharding_utils.sharding.REPLICATED:
-      shape = spec.shape
-    else:
-      raise ValueError(f"Unsupported sharding: {elem_sharding!r}")
-    return ArraySpec(shape=shape, dtype=spec.dtype)
+    return ArraySpec(
+        shape=_jax.local_to_global_shape(spec.shape, sharding=elem_sharding),
+        dtype=spec.dtype,
+    )

   elem_spec = jax.tree.map(_get_global_shape, elem_spec)

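The refactor above swaps the hand-written FIRST_DIM / REPLICATED branch for the shared `_jax.local_to_global_shape` helper added later in this commit. A minimal single-host sketch of the shape arithmetic involved (the inline helper and example shapes below are illustrative, not Kauldron code):

import jax
import jax.numpy as jnp
import numpy as np

# Illustrative stand-in for `_jax.local_to_global_shape` (not the Kauldron
# implementation): a first-dim sharding scales the leading per-process
# dimension by the process count; a replicated spec leaves the shape as-is.
def local_to_global_shape(shape, *, sharding):
  if len(sharding.spec) == 0:  # Replicated: the local shape is already global.
    return shape
  return (shape[0] * jax.process_count(),) + shape[1:]

mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('batch',))
first_dim = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))

# A per-process batch spec of 8 examples becomes 8 * process_count globally.
local_spec = jax.ShapeDtypeStruct((8, 224, 224, 3), jnp.float32)
print(local_to_global_shape(local_spec.shape, sharding=first_dim))
# -> (8, 224, 224, 3) when running with a single process.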
13 changes: 8 additions & 5 deletions kauldron/evals/evaluators.py
@@ -290,11 +290,14 @@ def basic_eval_step(
   """Call the model (pmap version)."""
   # Note that step is train step (from train state), NOT `eval_step`
   ctx = context_lib.Context.from_state_and_batch(state=state, batch=batch)
-  _, ctx = model_with_aux.forward(
-      context=ctx,
-      rngs=rng_streams.eval_rngs(eval_step),
-      is_training=False,
-  )
+
+  with sharding.set_global_mesh():
+    _, ctx = model_with_aux.forward(
+        context=ctx,
+        rngs=rng_streams.eval_rngs(eval_step),
+        is_training=False,
+    )
+
   aux = model_with_aux.get_aux(
       ctx,
       return_losses=True,
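With this change the eval forward pass runs under `sharding.set_global_mesh()`, mirroring the train step, so models that consult an ambient global mesh behave identically at eval time. A rough sketch of such a context manager, assuming it simply enters a `jax.sharding.Mesh` (the real Kauldron helper may do more):

import contextlib
import jax
import numpy as np

@contextlib.contextmanager
def set_global_mesh(mesh: jax.sharding.Mesh):
  # `jax.sharding.Mesh` is itself a context manager: entering it makes the
  # mesh the ambient one for code that resolves named axes against it.
  with mesh:
    yield mesh

mesh = jax.sharding.Mesh(np.array(jax.devices()), axis_names=('batch',))

with set_global_mesh(mesh):
  # Model calls issued here (e.g. `model_with_aux.forward(...)`) would see
  # `mesh` as the active mesh.
  spec = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec('batch'))
  x = jax.device_put(np.zeros((2 * jax.device_count(), 4), np.float32), spec)
  print(x.sharding)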
43 changes: 25 additions & 18 deletions kauldron/train/train_step.py
@@ -141,14 +141,15 @@ def _init_model(
     """Initialize the model and return the initial TrainState."""
     batch = data_utils.mock_batch_from_elem_spec(elem_spec, self.sharding.ds)
     args, kwargs = data_utils.get_model_inputs_from_batch(self.model, batch)
-    collections = self.model.init(
-        self.rng_streams.init_rngs(),
-        *args,
-        method=model_method,
-        is_training_property=True,
-        capture_intermediates=True,
-        **kwargs,
-    )
+    with self.sharding.set_global_mesh():
+      collections = self.model.init(
+          self.rng_streams.init_rngs(),
+          *args,
+          method=model_method,
+          is_training_property=True,
+          capture_intermediates=True,
+          **kwargs,
+      )
     collections = flax.core.unfreeze(collections)
     params = collections.pop("params", {})
     collections.pop("intermediates", None)  # Remove intermediates
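The surrounding context lines follow the usual Flax pattern: `Module.init` returns a dict of variable collections, from which `params` is popped and any captured intermediates are discarded. A standalone toy version (the `nn.Dense` model is a placeholder, not the configured Kauldron model, and the Kauldron-specific `is_training_property` kwarg is omitted):

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp

model = nn.Dense(features=2)  # Placeholder model.

collections = model.init(
    jax.random.PRNGKey(0),
    jnp.ones((1, 4)),
    capture_intermediates=True,
)
collections = flax.core.unfreeze(collections)
params = collections.pop('params', {})
collections.pop('intermediates', None)  # Discard captured intermediates.
print(jax.tree.map(jnp.shape, params))  # e.g. {'kernel': (4, 2), 'bias': (2,)}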
@@ -222,19 +223,25 @@ def step(
       auxiliaries: Auxiliaries containing the losses, metrics and summaries
         states.
     """
-    # This function is just a small wrapper around `_step` for:
-    # * Checkify errors handling
-    # * Select which auxiliaries metrics to return.
-    # * Sharding
-    # If reading the code, you can likely skip this function and go directly
-    # to `_step`.
-
-    if checkify_error_categories:
-      step_fn = checkify.checkify(self._step, errors=checkify_error_categories)
-      error, (state, ctx) = step_fn(state, batch)
-    else:
-      error = None
-      state, ctx = self._step(state, batch)
+    # This function is just a small wrapper around `_step` for:
+    # * Checkify errors handling
+    # * Select which auxiliaries metrics to return.
+    # * Set the output sharding
+    # * Wrap the step function in the `self.sharding.set_global_mesh()` context
+    #   (as some implementations of models rely on a global mesh).
+
+    with self.sharding.set_global_mesh():
+      if checkify_error_categories:
+        step_fn = checkify.checkify(
+            self._step, errors=checkify_error_categories
+        )
+        error, (state, ctx) = step_fn(state, batch)
+      else:
+        error = None
+        state, ctx = self._step(state, batch)

     # TODO(epot): More flexible way to select the subset of context to return ?
     # And have a way to return the full context ?
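As the rewritten comment notes, `checkify.checkify` wraps `_step` so that runtime checks (NaN/division errors, out-of-bounds indexing, explicit `checkify.check` calls) come back as an error value next to the output instead of aborting inside `jit`. A small self-contained illustration with a toy function standing in for `_step`:

import jax
import jax.numpy as jnp
from jax.experimental import checkify

def step(x):
  # Toy stand-in for `_step`: division by zero is a checkable float error.
  return 1.0 / x

# The wrapped function returns (error, output) instead of raising.
step_fn = jax.jit(checkify.checkify(step, errors=checkify.float_checks))

err, out = step_fn(jnp.array(2.0))
print(err.get(), out)  # None 0.5 -> no error recorded.

err, out = step_fn(jnp.array(0.0))
print(err.get())  # Describes the division-by-zero error.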
43 changes: 43 additions & 0 deletions kauldron/utils/_jax.py
@@ -22,6 +22,7 @@
 _P = ParamSpec('_P')


+# TODO(epot): Remove this function once Jax provides this natively.
 def eval_shape_with_sharding(
     fn: Callable[_P, _T],
     *args: _P.args,
@@ -52,3 +53,45 @@ def replace_shape_dtype_struct(
   }
   new_kwargs.update(kwargs)
   return jax.ShapeDtypeStruct(**new_kwargs)
+
+
+def local_to_global_shape(
+    shape: tuple[int, ...],
+    *,
+    sharding: jax.sharding.Sharding,
+) -> tuple[int, ...]:
+  """Convert a per-process shape to a global shape.
+
+  Contrary to the jax version, this function always scales the sharded
+  dimension by the number of processes.
+
+  Example:
+
+  * shape=(x, y), sharding=('a') -> (x * jax.process_count(), y)
+
+  Args:
+    shape: The local shape
+    sharding: The sharding to apply
+
+  Returns:
+    The global shape
+  """
+  if not isinstance(sharding, jax.sharding.NamedSharding):
+    raise ValueError(
+        f'Only NamedSharding is supported for now. Got: {sharding!r}'
+    )
+
+  # TODO(epot): We only support sharding on the first dimension for now. To
+  # support arbitrary axes, the sharding would have to specify which
+  # dimensions are split across hosts and which are simply sharded.
+  match len(sharding.spec):
+    case 0:
+      return shape
+    case 1:
+      return (shape[0] * jax.process_count(),) + shape[1:]
+    case _:
+      # TODO(epot): Supports spec = `('batch', None)`
+      raise ValueError(
+          f'Data can only be sharded on the first dimension. Got: {sharding!r}.'
+          ' Please raise an issue if you need this.'
+      )
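For reference, `len(sharding.spec)` in the `match` above counts how many array dimensions the `PartitionSpec` constrains; a nested tuple still counts as one (composite) sharded dimension, which is also what the tests below exercise. A quick plain-JAX illustration:

import jax

P = jax.sharding.PartitionSpec

print(len(P()))                      # 0 -> replicated, shape returned as-is
print(len(P('batch')))               # 1 -> first dim scaled by process_count()
print(len(P(('batch', 'replica'))))  # 1 -> one composite sharded dimension
print(len(P('batch', None)))         # 2 -> currently rejected with a ValueError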
76 changes: 76 additions & 0 deletions kauldron/utils/_jax_test.py
@@ -0,0 +1,76 @@
+# Copyright 2024 The kauldron Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from unittest import mock
+
+import jax
+from kauldron.utils import _jax
+import numpy as np
+import pytest
+
+
+@pytest.mark.parametrize(
+    'axes, process_count, in_shape, expected',
+    [
+        ((), 1, (3, 4), (3, 4)),
+        ((), 4, (3, 4), (3, 4)),
+        (('batch',), 1, (3, 4), (3, 4)),
+        (('batch',), 4, (3, 4), (12, 4)),
+        # One dim can match to 2 mesh axes (`spec[0] == ('batch', 'replica')`)
+        # by using nested tuple: `(('batch', 'replica'),)`
+        ((('batch', 'replica'),), 1, (3, 4), (3, 4)),
+        ((('batch', 'replica'),), 4, (3, 4), (12, 4)),
+    ],
+)
+def test_local_to_global_shape(
+    axes: tuple[str, ...],
+    process_count: int,
+    in_shape: tuple[int, ...],
+    expected: tuple[int, ...],
+):
+  devices = np.array(jax.devices())
+  devices = devices.reshape((1, 1, 1, 1))
+
+  mesh = jax.sharding.Mesh(
+      devices,
+      axis_names=('replica', 'batch', 'seq', 'model'),
+  )
+
+  sharding = jax.sharding.NamedSharding(
+      mesh,
+      spec=jax.sharding.PartitionSpec(*axes),
+  )
+
+  with mock.patch.object(jax, 'process_count', return_value=process_count):
+    out_shape = _jax.local_to_global_shape(in_shape, sharding=sharding)
+    assert out_shape == expected
+
+
+def test_local_to_global_shape_fail():
+
+  devices = np.array(jax.devices())
+  devices = devices.reshape((1, 1, 1, 1))
+
+  mesh = jax.sharding.Mesh(
+      devices,
+      axis_names=('replica', 'batch', 'seq', 'model'),
+  )
+
+  sharding = jax.sharding.NamedSharding(
+      mesh,
+      spec=jax.sharding.PartitionSpec('batch', 'seq'),
+  )
+
+  with pytest.raises(ValueError, match='Data can only be sharded on the first'):
+    _jax.local_to_global_shape((3, 4), sharding=sharding)