Skip to content

Commit

Permalink
Move immutabledict to a self-contained library and freeze trainer inputs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 721319856
  • Loading branch information
Conchylicultor authored and The kauldron Authors committed Jan 30, 2025
1 parent 733ceda commit d9cb563
Show file tree
Hide file tree
Showing 10 changed files with 200 additions and 111 deletions.
3 changes: 3 additions & 0 deletions kauldron/evals/evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from kauldron.train import train_step
from kauldron.train import trainer_lib
from kauldron.utils import config_util
from kauldron.utils import immutabledict
from kauldron.utils import kdash
from kauldron.utils import utils
from kauldron.utils.sharding_utils import sharding as sharding_lib # pylint: disable=g-importing-member
Expand Down Expand Up @@ -188,6 +189,8 @@ class Evaluator(EvaluatorBase):
def __post_init__(self) -> None:
super().__post_init__()

immutabledict.freeze_dict_attrs(self, ['losses', 'metrics', 'summaries'])

if self.ds is None:
raise ValueError(
f'Eval dataset missing (`cfg.evals.{self.name}.ds is None`). Please'
Expand Down
4 changes: 2 additions & 2 deletions kauldron/konfig/configdict_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@
from etils import epy
from kauldron.konfig import configdict_base
from kauldron.konfig import fake_import_utils
from kauldron.konfig import immutabledict_lib
from kauldron.konfig import utils
from kauldron.utils import immutabledict
import ml_collections


Expand Down Expand Up @@ -176,7 +176,7 @@ def _resolve_sequence(self, value):
def _resolve_dict(self, value):
cls = type(value)
if self._freeze:
cls = immutabledict_lib.ImmutableDict
cls = immutabledict.ImmutableDict
return cls({
k: _reraise_with_info(self._resolve_value, k)(v)
for k, v in _as_dict(value).items()
Expand Down
107 changes: 7 additions & 100 deletions kauldron/konfig/immutabledict_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,106 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Immutable dict util."""
"""DEPRECATED: Use `from kauldron.utils import immutabledict` instead."""

from __future__ import annotations
from kauldron.utils import immutabledict

from collections.abc import Hashable
import sys
from typing import Any, ClassVar
print(
'ImmutableDict was moved. Please uses `from kauldron.utils import'
' immutabledict` instead.'
)

from etils import epy
import immutabledict as immutabledict_lib
from packaging import version

_IMMUTABLE_DICT_V4 = version.parse(
immutabledict_lib.__version__
) >= version.Version('4.0.0')


class ImmutableDict(immutabledict_lib.immutabledict):
"""Immutable dict abstraction with `getattr` access."""

_dca_jax_tree_registered: ClassVar[bool] = False
_flax_registered: ClassVar[bool] = False

def __new__(cls, *args: Any, **kwargs: Any) -> ImmutableDict:
if not cls._dca_jax_tree_registered and 'jax' in sys.modules:
import jax # pylint: disable=g-import-not-at-top # pytype: disable=import-error

jax.tree_util.register_pytree_with_keys_class(cls)
cls._dca_jax_tree_registered = True

if not cls._flax_registered and 'flax' in sys.modules:
import flax # pylint: disable=g-import-not-at-top,g-bad-import-order # pytype: disable=import-error

for type_ in list(flax.serialization._STATE_DICT_REGISTRY): # pylint: disable=undefined-variable
match type_:
case object(
__name__='ImmutableDict',
__module__='kauldron.konfig.immutabledict_lib',
):
del flax.serialization._STATE_DICT_REGISTRY[type_] # pylint: disable=undefined-variable

def restore_immutable_dict(*args, **kwargs):
d = flax.serialization._restore_dict(*args, **kwargs) # pylint: disable=protected-access
return cls(d)

flax.serialization.register_serialization_state(
cls,
flax.serialization._dict_state_dict, # pylint: disable=protected-access
restore_immutable_dict,
)
cls._flax_registered = True

if _IMMUTABLE_DICT_V4:
# immutabledict 4.0.0 switched from using __init__ to __new__ and thus
# requires passing the args and kwargs along here.
return super().__new__(cls, *args, **kwargs) # pylint: disable=no-value-for-parameter
else:
return super().__new__(cls)

def __getattr__(self, name: str) -> str:
# The base-class has a `dict_cls` attribute, but collisions should be
# extremely rare.
return self[name]

def __repr__(self) -> str:
return epy.Lines.make_block(
header=f'{self.__class__.__name__}',
content={repr(k): v for k, v in self._dict.items()},
braces=('({', '})'),
equal=': ',
)

# Jax tree_utils protocol

def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
"""Flattens this FrozenDict.
Returns:
A flattened version of this FrozenDict instance.
"""
import jax # pylint: disable=g-import-not-at-top # pytype: disable=import-error

sorted_keys = sorted(self)
return tuple(
[(jax.tree_util.DictKey(k), self[k]) for k in sorted_keys]
), tuple(self)

@classmethod
def tree_unflatten(cls, keys, values):
# Flatten sort the keys, so reconstruct the ordered sorted
ordered_items = {k: v for k, v in zip(sorted(keys), values)}
# Restore original dict order
new_items = ((k, ordered_items[k]) for k in keys)

return cls(new_items)

# Pickle protocol

def __getstate__(self):
return self._dict

def __setstate__(self, state):
self.__init__(state)
ImmutableDict = immutabledict.ImmutableDict
4 changes: 4 additions & 0 deletions kauldron/train/auxiliaries.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from kauldron import summaries as kd_summaries
from kauldron.train import context as context_lib
from kauldron.utils import config_util
from kauldron.utils import immutabledict
from kauldron.utils.kdash import dashboard_utils


Expand All @@ -45,6 +46,9 @@ class Auxiliaries(config_util.UpdateFromRootCfg):
config_util.ROOT_CFG_REF.train_summaries
)

def __post_init__(self):
immutabledict.freeze_dict_attrs(self, ["losses", "metrics", "summaries"])

@jax.named_call
def update_context(self, context: context_lib.Context) -> context_lib.Context:
"""Get auxilaries."""
Expand Down
13 changes: 13 additions & 0 deletions kauldron/train/trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from kauldron.utils import _jax
from kauldron.utils import chrono_utils
from kauldron.utils import config_util
from kauldron.utils import immutabledict
from kauldron.utils import kdash
from kauldron.utils.sharding_utils import sharding as sharding_utils # pylint: disable=g-importing-member
import optax
Expand All @@ -73,6 +74,7 @@
# as Type['JaxException'] before class JaxException is declared.
CheckifyErrorCategory = checkify.ErrorCategory if typing.TYPE_CHECKING else Any

# TODO(epot): Should unify to use `immutabledict` everywhere.
FrozenDict = dict if typing.TYPE_CHECKING else flax.core.FrozenDict


Expand Down Expand Up @@ -243,6 +245,17 @@ class Trainer(config_util.BaseConfig):

def __post_init__(self):

# Freeze the mutable fields as they are passed to `jit` functions.
immutabledict.freeze_dict_attrs(
self,
(
'train_losses',
'train_metrics',
'train_summaries',
'schedules',
),
)

# It's convenient to set `cfg.evals = None`,... to disable evaluation
for name, default_factory in {
'evals': FrozenDict,
Expand Down
20 changes: 20 additions & 0 deletions kauldron/utils/immutabledict/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Mini immutabledict library."""

# pylint: disable=g-importing-member

from kauldron.utils.immutabledict.immutabledict_lib import ImmutableDict
from kauldron.utils.immutabledict.utils import freeze_dict_attrs
117 changes: 117 additions & 0 deletions kauldron/utils/immutabledict/immutabledict_lib.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
# Copyright 2024 The kauldron Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Immutable dict util."""

from __future__ import annotations

from collections.abc import Hashable
import sys
from typing import Any, ClassVar

from etils import epy
import immutabledict as immutabledict_lib
from packaging import version

_IMMUTABLE_DICT_V4 = version.parse(
immutabledict_lib.__version__
) >= version.Version('4.0.0')


class ImmutableDict(immutabledict_lib.immutabledict):
"""Immutable dict abstraction with `getattr` access."""

_dca_jax_tree_registered: ClassVar[bool] = False
_flax_registered: ClassVar[bool] = False

def __new__(cls, *args: Any, **kwargs: Any) -> ImmutableDict:
if not cls._dca_jax_tree_registered and 'jax' in sys.modules:
import jax # pylint: disable=g-import-not-at-top # pytype: disable=import-error

jax.tree_util.register_pytree_with_keys_class(cls)
cls._dca_jax_tree_registered = True

if not cls._flax_registered and 'flax' in sys.modules:
import flax # pylint: disable=g-import-not-at-top,g-bad-import-order # pytype: disable=import-error

for type_ in list(flax.serialization._STATE_DICT_REGISTRY): # pylint: disable=undefined-variable
match type_:
case object(
__name__='ImmutableDict',
__module__='kauldron.konfig.immutabledict_lib',
):
del flax.serialization._STATE_DICT_REGISTRY[type_] # pylint: disable=undefined-variable

def restore_immutable_dict(*args, **kwargs):
d = flax.serialization._restore_dict(*args, **kwargs) # pylint: disable=protected-access
return cls(d)

flax.serialization.register_serialization_state(
cls,
flax.serialization._dict_state_dict, # pylint: disable=protected-access
restore_immutable_dict,
)
cls._flax_registered = True

if _IMMUTABLE_DICT_V4:
# immutabledict 4.0.0 switched from using __init__ to __new__ and thus
# requires passing the args and kwargs along here.
return super().__new__(cls, *args, **kwargs) # pylint: disable=no-value-for-parameter
else:
return super().__new__(cls)

def __getattr__(self, name: str) -> str:
# The base-class has a `dict_cls` attribute, but collisions should be
# extremely rare.
return self[name]

def __repr__(self) -> str:
return epy.Lines.make_block(
header=f'{self.__class__.__name__}',
content={repr(k): v for k, v in self._dict.items()},
braces=('({', '})'),
equal=': ',
)

# Jax tree_utils protocol

def tree_flatten_with_keys(self) -> tuple[tuple[Any, ...], Hashable]:
"""Flattens this FrozenDict.
Returns:
A flattened version of this FrozenDict instance.
"""
import jax # pylint: disable=g-import-not-at-top # pytype: disable=import-error

sorted_keys = sorted(self)
return tuple(
[(jax.tree_util.DictKey(k), self[k]) for k in sorted_keys]
), tuple(self)

@classmethod
def tree_unflatten(cls, keys, values):
# Flatten sort the keys, so reconstruct the ordered sorted
ordered_items = {k: v for k, v in zip(sorted(keys), values)}
# Restore original dict order
new_items = ((k, ordered_items[k]) for k in keys)

return cls(new_items)

# Pickle protocol

def __getstate__(self):
return self._dict

def __setstate__(self, state):
self.__init__(state)
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Test."""

import pickle

import cloudpickle
from etils import epy
import jax
from kauldron.konfig import immutabledict_lib
from kauldron.utils import immutabledict
import pytest


def test_dict():
d = immutabledict_lib.ImmutableDict({
d = immutabledict.ImmutableDict({
'z': 1,
'a': 2,
'w': 3,
})
d = jax.tree.map(lambda x: x * 10, d)
assert d == immutabledict_lib.ImmutableDict({
assert d == immutabledict.ImmutableDict({
'z': 10,
'a': 20,
'w': 30,
Expand All @@ -41,7 +39,7 @@ def test_dict():


def test_dict_repr():
d = immutabledict_lib.ImmutableDict({
d = immutabledict.ImmutableDict({
'z': 1,
'a': 2,
'w': 'aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa',
Expand All @@ -58,7 +56,7 @@ def test_dict_repr():

@pytest.mark.parametrize('pkl', [pickle, cloudpickle])
def test_dict_pickle(pkl):
a = immutabledict_lib.ImmutableDict({
a = immutabledict.ImmutableDict({
'z': 1,
'a': 2,
})
Expand Down
Loading

0 comments on commit d9cb563

Please sign in to comment.