diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 32c58af842..9a8bb1e742 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -33,15 +33,8 @@ from .nnx.module import M as M from .nnx.module import Module as Module from .nnx.graph_utils import merge as merge -from .nnx.graph_utils import full_merge as full_merge from .nnx.graph_utils import split as split -from .nnx.graph_utils import full_split as full_split from .nnx.graph_utils import update as update -from .nnx.graph_utils import full_update as full_update -from .nnx.graph_utils import clone as clone -from .nnx.graph_utils import pop as pop -from .nnx.rnglib import init as init -from .nnx.rnglib import empty as empty from .nnx.nn import initializers as initializers from .nnx.nn.activations import celu as celu from .nnx.nn.activations import elu as elu diff --git a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py index 3e95fc22de..c81db63eef 100644 --- a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py @@ -52,10 +52,5 @@ def __call__(self, x): # split the parameters into trainable and non-trainable parameters trainable_params, non_trainable, static = model.split(is_trainable, ...) -print( - 'trainable_params =', - jax.tree_util.tree_map(jax.numpy.shape, trainable_params), -) -print( - 'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable) -) +print('trainable_params =', jax.tree_util.tree_map(jax.numpy.shape, trainable_params)) +print('non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable)) diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph_utils.py index e5d7ccb25f..8d56776fe9 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph_utils.py @@ -37,13 +37,8 @@ CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import ( - FlatState, - State, - StateLeaf, - is_state_leaf, -) -from flax.experimental.nnx.nnx.variables import EMPTY, Variable +from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf +from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable from flax.typing import PathParts, Key A = tp.TypeVar('A') @@ -104,11 +99,8 @@ def __eq__(self, other: tp.Any) -> bool: class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): """A mapping that uses object id as the hash for the keys.""" - def __init__( - self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / - ): + def __init__(self): self._mapping: dict[_HashById[A], B] = {} - self.update(mapping) def __getitem__(self, key: A) -> B: return self._mapping[_HashById(key)] @@ -132,7 +124,6 @@ def __str__(self) -> str: return repr(self) - @dataclasses.dataclass(frozen=True) class NodeImplBase(tp.Generic[Node, Leaf, AuxData]): type: type @@ -323,19 +314,20 @@ def __eq__(self, other): and self._metadata == other._metadata ) -@dataclasses.dataclass(frozen=True) -class NodeDef(tp.Generic[Node], reprlib.Representable): - type: tp.Type[Node] - index: int - attributes: tuple[Key, ...] 
- subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', int]] - static_fields: _HashableMapping[Key, tp.Any] - variables: _HashableMapping[Key, VariableDef | int] - metadata: tp.Any - @classmethod - def create( - cls, +class GraphDef(tp.Generic[Node], reprlib.Representable): + __slots__ = ( + '_type', + '_index', + '_attributes', + '_subgraphs', + '_static_fields', + '_variables', + '_metadata', + ) + + def __init__( + self, type: tp.Type[Node], index: int, attributes: tuple[Key, ...], @@ -344,51 +336,60 @@ def create( variables: tp.Iterable[tuple[Key, VariableDef | int]], metadata: tp.Any, ): - return cls( - type=type, - index=index, - attributes=attributes, - subgraphs=_HashableMapping(subgraphs), - static_fields=_HashableMapping(static_fields), - variables=_HashableMapping(variables), - metadata=metadata, - ) + self._type: type[Node] = type + self._index = index + self._attributes = attributes + self._subgraphs = _HashableMapping(subgraphs) + self._static_fields = _HashableMapping(static_fields) + self._variables = _HashableMapping(variables) + self._metadata = metadata def __nnx_repr__(self): yield reprlib.Object(type=type(self)) - yield reprlib.Attr('type', self.type.__name__) - yield reprlib.Attr('index', self.index) - yield reprlib.Attr('attributes', self.attributes) - yield reprlib.Attr('subgraphs', _MappingRepr(self.subgraphs)) - yield reprlib.Attr('static_fields', _MappingRepr(self.static_fields)) - yield reprlib.Attr('variables', _MappingRepr(self.variables)) - yield reprlib.Attr('metadata', self.metadata) + yield reprlib.Attr('type', self._type.__name__) + yield reprlib.Attr('index', self._index) + yield reprlib.Attr('attributes', self._attributes) + yield reprlib.Attr('subgraphs', _MappingRepr(self._subgraphs)) + yield reprlib.Attr('static_fields', _MappingRepr(self._static_fields)) + yield reprlib.Attr('variables', _MappingRepr(self._variables)) + yield reprlib.Attr('metadata', self._metadata) + def __hash__(self) -> int: + return hash((self._type, self._subgraphs)) -@dataclasses.dataclass(frozen=True) -class GraphDef(tp.Generic[Node], reprlib.Representable): - nodedef: NodeDef[Node] - index_mapping: dict[Index, Index] | None + def __eq__(self, other: tp.Any) -> bool: + if not isinstance(other, GraphDef): + return False + return self._type == other._type and self._subgraphs == other._subgraphs - def __nnx_repr__(self): - yield reprlib.Object(type=type(self)) + @property + def type(self) -> tp.Type[Node]: + return self._type - yield reprlib.Attr('nodedef', self.nodedef) - yield reprlib.Attr('index_mapping', self.index_mapping) + @property + def index(self) -> int: + return self._index - def __deepcopy__(self, memo=None): - nodedef = deepcopy(self.nodedef, memo) - index_mapping = deepcopy(self.index_mapping, memo) - return GraphDef(nodedef, index_mapping) + @property + def attributes(self) -> tuple[str, ...]: + return self._attributes - def __hash__(self): - # refmap is opaque - return hash(self.nodedef) + @property + def subgraphs(self): + return self._subgraphs - def __eq__(self, other): - # refmap is opaque - return isinstance(other, GraphDef) and self.nodedef == other.nodedef + @property + def static_fields(self): + return self._static_fields + + @property + def variables(self): + return self._variables + + @property + def metadata(self) -> tp.Any: + return self._metadata def merge(self, state: State, /, *states: State) -> Node: if states: @@ -403,7 +404,7 @@ def apply( def _apply( accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tuple[GraphDef[Node], 
State]]:
-      module = merge(self, state, *states)
+      module = self.merge(state, *states)
       fn = accessor(module)
       out = fn(*args, **kwargs)
       return out, graph_flatten(module)[:2]
@@ -411,83 +412,88 @@ def _apply(
     return CallableProxy(_apply, accessor)  # type: ignore

   def make_empty(self) -> Node:
-    return merge(self, State({}))
-
-
-def _graphdef_flatten(graphdef: GraphDef[Node]):
-  # refmap is opaque, we don't propagate it
-  static = (graphdef.nodedef, graphdef.index_mapping)
-  return (), static
+    return self.merge(State({}))
+
+
+def _graphdef_flatten(graphdef: GraphDef[tp.Any]):
+  return (), (
+    graphdef._type,
+    graphdef._index,
+    graphdef._attributes,
+    graphdef._subgraphs,
+    graphdef._static_fields,
+    graphdef._variables,
+    graphdef._metadata,
+  )


 def _graphdef_unflatten(
-  static: tuple[NodeDef[Node], dict[Index, Index] | None], _nodes: tuple[()]
-):
-  nodedef, index_mapping = static
-  return GraphDef(nodedef, index_mapping)
+  metadata: tuple[
+    tp.Type[Node],
+    int,
+    tuple[Key, ...],
+    tuple[tuple[Key, GraphDef[Node] | int], ...],
+    tuple[tuple[Key, tp.Any], ...],
+    tuple[tuple[Key, Variable[Empty] | int], ...],
+    tp.Any,
+  ],
+  _,
+) -> GraphDef[Node]:
+  return GraphDef(*metadata)


 jax.tree_util.register_pytree_node(
-  GraphDef,
-  _graphdef_flatten,
-  _graphdef_unflatten,
+  GraphDef, _graphdef_flatten, _graphdef_unflatten
 )


 def graph_flatten(
   x: Node,
   /,
-  *,
-  idxmap: dict[Index, tp.Any] | None = None,
-) -> tuple[GraphDef[Node], State, RefMap[tp.Any, Index]]:
-  refmap = RefMap[tp.Any, Index]()
+) -> tuple[GraphDef[Node], State, tp.Mapping[tp.Any, Index]]:
+  ref_to_index = RefMap[tp.Any, Index]()
   flat_state: dict[PathParts, StateLeaf] = {}
-  nodedef = _graph_flatten((), refmap, flat_state, x)
-  assert not isinstance(nodedef, int)
-  if idxmap is not None:
-    index_to_index = compose_mapping(idxmap, refmap)
-  else:
-    index_to_index = None
-  graphdef = GraphDef(nodedef, index_to_index)
-  return graphdef, State.from_flat_path(flat_state), refmap
+  graphdef = _graph_flatten((), ref_to_index, flat_state, x)
+  assert not isinstance(graphdef, int)
+  return graphdef, State.from_flat_path(flat_state), ref_to_index


 def _graph_flatten(
   path: PathParts,
-  refmap: RefMap[tp.Any, Index],
+  ref_to_index: RefMap[tp.Any, Index],
   flat_state: dict[PathParts, StateLeaf],
   node: Node,
-) -> NodeDef[Node] | int:
+) -> GraphDef[Node] | int:
   if not is_node(node):
     raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.')

-  if node in refmap:
-    return refmap[node]
+  if node in ref_to_index:
+    return ref_to_index[node]

   node_impl = get_node_impl(node)

   # only cache graph nodes
   if isinstance(node_impl, GraphNodeImpl):
-    index = len(refmap)
-    refmap[node] = index
+    index = len(ref_to_index)
+    ref_to_index[node] = index
   else:
     index = -1

-  subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], int]]] = []
+  subgraphs: list[tuple[Key, tp.Union[GraphDef[Node], int]]] = []
   static_fields: list[tuple[Key, tp.Any]] = []
   variables: list[tuple[Key, VariableDef | int]] = []

   values, metadata = node_impl.flatten(node)
   for key, value in values:
     if is_node(value):
-      nodedef = _graph_flatten((*path, key), refmap, flat_state, value)
-      subgraphs.append((key, nodedef))
+      graphdef = _graph_flatten((*path, key), ref_to_index, flat_state, value)
+      subgraphs.append((key, graphdef))
     elif isinstance(value, Variable):
-      if value in refmap:
-        variables.append((key, refmap[value]))
+      if value in ref_to_index:
+        variables.append((key, ref_to_index[value]))
       else:
         flat_state[(*path, key)] = value.copy()
-        variable_index = refmap[value] = len(refmap)
+        variable_index = ref_to_index[value] = len(ref_to_index)
         variables.append(
           (key, VariableDef.from_variable(value, variable_index))
         )
@@ -496,7 +502,7 @@ def _graph_flatten(
     else:
       static_fields.append((key, value))

-  nodedef = NodeDef.create(
+  graphdef = GraphDef(
     type=node_impl.type,
     index=index,
     attributes=tuple(key for key, _ in values),
@@ -505,7 +511,7 @@ def _graph_flatten(
     variables=variables,
     metadata=metadata,
   )
-  return nodedef
+  return graphdef


 def graph_unflatten(
@@ -513,12 +519,12 @@
   state: State,
   /,
   *,
-  idxmap: dict[Index, tp.Any] | None = None,
+  ref_cache: dict[Index, tp.Any] | None = None,
 ) -> tuple[Node, dict[Index, tp.Any]]:
   """Unflattens a graphdef into a node with the given state.

   Args:
-    graphdef: A NodeDef instance.
+    graphdef: A GraphDef instance.
     state: A State instance.
     ref_cache: A mapping from indexes to existing nodes that can be reused.
       When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
@@ -527,22 +533,20 @@
     specified by the graphdef.
   """
   index_to_ref: dict[Index, tp.Any] = {}
-  node = _graph_unflatten(
-    graphdef.nodedef, state.raw_mapping, index_to_ref, idxmap
-  )
+  node = _graph_unflatten(graphdef, state.raw_mapping, index_to_ref, ref_cache)
   return node, index_to_ref


 def _graph_unflatten(
-  nodedef: tp.Union[NodeDef[Node], int],
-  state: dict[Key, StateLeaf | dict[Key, tp.Any]],
+  graphdef: tp.Union[GraphDef[Node], int],
+  state: dict[str, StateLeaf | dict[str, tp.Any]],
   index_to_ref: dict[Index, tp.Any],
-  idxmap: dict[Index, tp.Any] | None,
+  ref_cache: dict[Index, tp.Any] | None,
 ) -> Node:
   """Recursive helper for graph_unflatten.

   Args:
-    nodedef: A NodeDef instance or an index to a node in the cache.
+    graphdef: A GraphDef instance or an index to a node in the cache.
     state: A mapping from attribute names to variables or subgraphs.
     index_to_ref: A mapping from indexes to nodes that have been traversed.
      If a node is already in the cache, it won't be traversed again.
@@ -550,31 +554,31 @@
      When a reference is reused, ``GraphNodeImpl.clear`` is called to leave the
      object in an empty state and then filled by the unflatten process; as a result,
      existing graph nodes are mutated to have the new content/topology
-      specified by the nodedef.
+      specified by the graphdef.
   """
-  if isinstance(nodedef, int):
-    return index_to_ref[nodedef]
+  if isinstance(graphdef, int):
+    return index_to_ref[graphdef]

-  if not is_node_type(nodedef.type):
-    raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.')
+  if not is_node_type(graphdef.type):
+    raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.')

-  if nodedef.index in index_to_ref:
-    raise RuntimeError(f'NodeDef index {nodedef.index} already used.')
+  if graphdef.index in index_to_ref:
+    raise RuntimeError(f'GraphDef index {graphdef.index} already used.')

-  node_impl = get_node_impl_for_type(nodedef.type)
+  node_impl = get_node_impl_for_type(graphdef.type)

   def _get_children():
     children: dict[str, StateLeaf | Node] = {}

-    for key in nodedef.attributes:
-      if key in nodedef.static_fields:
-        children[key] = nodedef.static_fields[key]
+    for key in graphdef.attributes:
+      if key in graphdef.static_fields:
+        children[key] = graphdef.static_fields[key]
       elif key not in state:
        # TODO(cgarciae): maybe we shouldn't support unflattening with missing keys?
# if key is not present create an empty types - if key in nodedef.subgraphs: + if key in graphdef.subgraphs: # if the key is a subgraph we create an empty node - subgraphdef = nodedef.subgraphs[key] + subgraphdef = graphdef.subgraphs[key] if isinstance(subgraphdef, int): # subgraph exists, take it from the cache children[key] = index_to_ref[subgraphdef] @@ -582,17 +586,17 @@ def _get_children(): # create an empty node and add it to the cache substate = {} node = children[key] = _graph_unflatten( - subgraphdef, substate, index_to_ref, idxmap + subgraphdef, substate, index_to_ref, ref_cache ) - elif key in nodedef.variables: - variable_def = nodedef.variables[key] + elif key in graphdef.variables: + variable_def = graphdef.variables[key] if isinstance(variable_def, int): # variable exists, take it from the cache children[key] = index_to_ref[variable_def] else: # create an empty variable and add it to the cache - if idxmap is not None and variable_def.index in idxmap: - node = idxmap[variable_def.index] + if ref_cache is not None and variable_def.index in ref_cache: + node = ref_cache[variable_def.index] if type(node) != variable_def.type: raise ValueError( f'Expected a node of type {variable_def.type.__name__} for ' @@ -609,23 +613,23 @@ def _get_children(): raise RuntimeError(f'Unknown static field: {key!r}') else: value = state[key] - if key in nodedef.subgraphs: + if key in graphdef.subgraphs: if is_state_leaf(value): raise ValueError( f'Expected a subgraph for {key!r}, but got a Variable.' ) assert isinstance(value, dict) - subgraphdef = nodedef.subgraphs[key] + subgraphdef = graphdef.subgraphs[key] if isinstance(subgraphdef, int): node = index_to_ref[subgraphdef] else: node = children[key] = _graph_unflatten( - subgraphdef, value, index_to_ref, idxmap + subgraphdef, value, index_to_ref, ref_cache ) - elif key in nodedef.variables: - variable_def = nodedef.variables[key] + elif key in graphdef.variables: + variable_def = graphdef.variables[key] if isinstance(variable_def, int): children[key] = index_to_ref[variable_def] else: @@ -635,8 +639,8 @@ def _get_children(): f'for {key!r}, but got a Variable of type {type(value)}.' ) assert isinstance(value, Variable) - if idxmap is not None and variable_def.index in idxmap: - variable = idxmap[variable_def.index] + if ref_cache is not None and variable_def.index in ref_cache: + variable = ref_cache[variable_def.index] if type(variable) != variable_def.type: raise ValueError( f'Expected a Variable of type {variable_def.type} for ' @@ -650,7 +654,7 @@ def _get_children(): index_to_ref[variable_def.index] = variable elif is_state_leaf(value): children[key] = value - for new_key in set(state) - set(nodedef.attributes): + for new_key in set(state) - set(graphdef.attributes): raise ValueError(f'Unknown key: {new_key!r}') return children @@ -658,24 +662,24 @@ def _get_children(): if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle - if idxmap is not None and nodedef.index in idxmap: - node = idxmap[nodedef.index] - if type(node) != nodedef.type: + if ref_cache is not None and graphdef.index in ref_cache: + node = ref_cache[graphdef.index] + if type(node) != graphdef.type: raise ValueError( - f'Expected a node of type {nodedef.type} for index ' - f'{nodedef.index}, but got a node of type {type(node)}.' + f'Expected a node of type {graphdef.type} for index ' + f'{graphdef.index}, but got a node of type {type(node)}.' 
) - node_impl.clear(node, nodedef.metadata) + node_impl.clear(node, graphdef.metadata) else: - node = node_impl.create_empty(nodedef.metadata) - index_to_ref[nodedef.index] = node + node = node_impl.create_empty(graphdef.metadata) + index_to_ref[graphdef.index] = node children = _get_children() node_impl.init(node, tuple(children.items())) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first children = _get_children() - node = node_impl.unflatten(tuple(children.items()), nodedef.metadata) + node = node_impl.unflatten(tuple(children.items()), graphdef.metadata) return node @@ -712,11 +716,7 @@ def _graph_pop( for name, value in node_dict.items(): if is_node(value): _graph_pop( - node=value, - id_to_index=id_to_index, - path_parts=(*path_parts, name), - flat_states=flat_states, - predicates=predicates, + value, id_to_index, (*path_parts, name), flat_states, predicates ) continue elif not is_state_leaf(value): @@ -743,7 +743,23 @@ def _graph_pop( pass -def _graph_update_dynamic(node: tp.Any, state: dict[Key, tp.Any]): +def graph_update_dynamic( + node: tp.Any, + updates: State | tp.Sequence[State], +) -> None: + if not is_node(node): + raise ValueError(f'Unsupported type: {type(node)}') + + if isinstance(updates, State): + new_states = (updates,) + else: + new_states = updates + + for state in new_states: + _graph_update_dynamic(node, state.raw_mapping) + + +def _graph_update_dynamic(node: tp.Any, state: dict[str, tp.Any]): if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -796,14 +812,7 @@ class _StaticModuleStatus(enum.Enum): NEW = enum.auto() UPDATED = enum.auto() -# TODO(cgarciae): remove once transform init are reimplemented -def update_from(node: Node, updates: Node) -> None: - graph_update_static(node, updates) - _, state = split(updates) - update(node, state) - -# TODO(cgarciae): remove once transform init are reimplemented def graph_update_static(node: Node, updates: Node) -> None: cache: dict[int, _StaticModuleStatus] = {} _graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ()) @@ -894,49 +903,33 @@ def _graph_update_static( node_impl.set_key(node, name, value_updates) - @tp.overload -def full_split( - graph_node: A, - *, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[RefMap[tp.Any, Index], GraphDef[A], State]: +def split(graph_node: A) -> tuple[GraphDef[A], State]: ... @tp.overload -def full_split( - graph_node: A, - first: filterlib.Filter, - /, - *, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[RefMap[tp.Any, Index], GraphDef[A], State]: +def split( + graph_node: A, first: filterlib.Filter, / +) -> tuple[GraphDef[A], State]: ... @tp.overload -def full_split( +def split( graph_node: A, first: filterlib.Filter, second: filterlib.Filter, /, *filters: filterlib.Filter, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[ - RefMap[tp.Any, Index], GraphDef[A], State, tpe.Unpack[tuple[State, ...]] -]: +) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: ... 
-def full_split( - graph_node: A, - *filters: filterlib.Filter, - idxmap: dict[Index, tp.Any] | None = None, -) -> tuple[ - RefMap[tp.Any, Index], GraphDef[A], State, tpe.Unpack[tuple[State, ...]] -]: - graphdef, state, refmap = graph_flatten(graph_node, idxmap=idxmap) +def split( + graph_node: A, *filters: filterlib.Filter +) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: + graphdef, state, _ = graph_flatten(graph_node) if len(filters) == 0: states = (state,) @@ -945,71 +938,7 @@ def full_split( else: states = state.split(filters[0], filters[1], *filters[2:]) - return refmap, graphdef, states[0], *states[1:] - - -def full_merge( - graphdef: GraphDef[A], - state: State, - *states: State, -) -> tuple[A, dict[Index, tp.Any]]: - # TODO: add docstring of example usage - if states: - state = State.merge(state, *states) - - return graph_unflatten(graphdef, state) - - -def full_update( - refmap: RefMap[tp.Any, Index], - new_graphdef: GraphDef[A], - state: State, - /, - *states: State, -): - if refmap is None: - raise ValueError('Cannot update a graphdef without refmap.') - if new_graphdef.index_mapping is None: - raise ValueError('Cannot update a graphdef without index_mapping.') - - if states: - state = State.merge(state, *states) - - index_to_ref = compose_mapping_reversed(refmap, new_graphdef.index_mapping) - return graph_unflatten(new_graphdef, state, idxmap=index_to_ref)[0] - - -@tp.overload -def split(graph_node: A, /) -> tuple[GraphDef[A], State]: - ... - - -@tp.overload -def split( - graph_node: A, - first: filterlib.Filter, - /, -) -> tuple[GraphDef[A], State]: - ... - - -@tp.overload -def split( - graph_node: A, - first: filterlib.Filter, - second: filterlib.Filter, - /, - *filters: filterlib.Filter, -) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: - ... - - -def split( - graph_node: A, - *filters: filterlib.Filter, -) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: - _, graphdef, state, *states = full_split(graph_node, *filters) - return graphdef, state, *states + return graphdef, states[0], *states[1:] def merge( @@ -1018,94 +947,57 @@ def merge( *states: State, ) -> A: # TODO: add docstring of example usage - return full_merge(graphdef, state, *states)[0] - - -def update(node, state: State, *states: State) -> None: - if states: - state = State.merge(state, *states) - - _graph_update_dynamic(node, state.raw_mapping) - - -@tp.overload -def extract(node, first: filterlib.Filter, /) -> State: - ... - - -@tp.overload -def extract( - node, - first: filterlib.Filter, - second: filterlib.Filter, - /, - *filters: filterlib.Filter, -) -> tuple[State, ...]: - ... - + return graphdef.merge(state, *states) -def extract( - node, - first: filterlib.Filter, - /, - *filters: filterlib.Filter, -) -> tp.Union[State, tuple[State, ...]]: - state = graph_flatten(node)[1] - - if len(filters) == 0: - states = state.extract(first) - else: - states = state.extract(first, filters[0], *filters[1:]) - - return states - - -@tp.overload -def pop( - node, - filter: filterlib.Filter, - /, -) -> State: - ... +def update(graph_node: A, update: Updates[A], /, *updates: Updates[A]) -> None: + updates = (update, *updates) -@tp.overload -def pop( - node, - filter: filterlib.Filter, - filter2: filterlib.Filter, - /, - *filters: filterlib.Filter, -) -> tuple[State, ...]: - ... 
+ # find states and module_update + leaves = jax.tree_util.tree_leaves( + updates, is_leaf=lambda x: isinstance(x, (GraphDef, State)) + ) + states: list[State] = [] + module_update: tp.Optional[A] = None + for leaf in leaves: + if is_graph_node(leaf) or isinstance(leaf, GraphDef): + if module_update is not None: + raise ValueError( + 'Expected only one GraphDef or GraphNode in the updates' + ) -def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]: - if len(filters) == 0: - raise ValueError('Expected at least one filter') + if is_graph_node(leaf): + if not isinstance(leaf, type(graph_node)): + raise ValueError( + 'Expected a GraphNode of the same type as the input, ' + f'got {type(leaf).__name__} instead.' + ) + module_update = leaf + states.append(split(leaf)[1]) + elif isinstance(leaf, GraphDef): + module_update = leaf.make_empty() + else: + raise ValueError( + 'Expected a GraphDef or graph node, got' f' {type(leaf).__name__}' + ) + elif isinstance(leaf, State): + states.append(leaf) + else: + raise ValueError( + 'Expected a GraphDef, GraphNode or State, got' f' {type(leaf).__name__}' + ) - id_to_index: dict[int, Index] = {} - path_parts: PathParts = () - predicates = tuple(filterlib.to_predicate(filter) for filter in filters) - flat_states: tuple[FlatState, ...] = tuple({} for _ in predicates) - _graph_pop( - node=node, - id_to_index=id_to_index, - path_parts=path_parts, - flat_states=flat_states, - predicates=predicates, - ) - states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + if module_update is not None: + graph_update_static(graph_node, module_update) - if len(states) == 1: - return states[0] - else: - return states + if states: + graph_update_dynamic(graph_node, states) def clone(node: Node) -> Node: - static, state = split(node) - return merge(static, state) + static, state, _ = graph_flatten(node) + return static.merge(state) def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[PathParts, tp.Any]]: @@ -1117,31 +1009,16 @@ def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[PathParts, tp.Any]]: def _iter_nodes( node: tp.Any, visited: set[int], path_parts: PathParts ) -> tp.Iterator[tuple[PathParts, tp.Any]]: - for path_parts, value in _iter_all(node, visited, path_parts): - if is_node(value): - yield path_parts, value - - -def _iter_node_or_variable( - x: tp.Any, visited: set[int], path_parts: PathParts -) -> tp.Iterator[tuple[PathParts, tp.Any]]: - for path_parts, value in _iter_all(x, visited, path_parts): - if is_node(value) or isinstance(value, Variable): - yield path_parts, value - - -def _iter_all( - x: tp.Any, visited: set[int], path_parts: PathParts -) -> tp.Iterator[tuple[PathParts, tp.Any]]: - if id(x) in visited: + if not is_node(node): + return + if id(node) in visited: return - visited.add(id(x)) - yield path_parts, x - if is_node(x): - node_impl = get_node_impl(x) - node_dict = node_impl.node_dict(x) - for key, value in node_dict.items(): - yield from _iter_all(value, visited, (*path_parts, key)) + visited.add(id(node)) + yield path_parts, node + node_impl = get_node_impl(node) + node_dict = node_impl.node_dict(node) + for key, value in node_dict.items(): + yield from _iter_nodes(value, visited, (*path_parts, key)) def compose_mapping( @@ -1280,10 +1157,10 @@ def check_valid_context(self, error_msg: str) -> None: raise errors.TraceContextError(error_msg) def __deepcopy__(self: G, memo=None) -> G: - graphdef, state = graph_utils.split(self) + graphdef, state, _ = graph_utils.graph_flatten(self) graphdef = 
deepcopy(graphdef) state = deepcopy(state) - return merge(graphdef, state) + return graphdef.merge(state) def __hash__(self) -> int: return hash(self._graph_node__state.id) diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index 6a231601cc..cb1c54dd99 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -20,6 +20,7 @@ import jax import jax.tree_util as jtu +import typing_extensions as tpe from flax.experimental.nnx.nnx import ( filterlib, @@ -71,6 +72,10 @@ def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M: class Module(graph_utils.GraphNode, metaclass=ModuleMeta): + @classmethod + def init(cls: type[M], *args, **kwargs) -> tuple[GraphDef[M], State]: + return cls(*args, **kwargs).split() + @classmethod @property def create_abstract(cls: type[M]) -> type[M]: @@ -139,13 +144,13 @@ def _partial_init(accessor: DelayedAccessor, *args, **kwargs): def _partial_init_constructor(): module = constructor(*args, **lift_rngs(kwargs)) - graph_utils.update(module, *states) - return graph_utils.split(module) + module.update(*states) + return module.split() graphdef: GraphDef[M] state: State graphdef, state = jax.jit(_partial_init_constructor)() - module = graph_utils.merge(graphdef, state) + module = graphdef.merge(state) return module return CallableProxy(_partial_init) # type: ignore @@ -177,11 +182,11 @@ def split( return graph_utils.split(self, *filters) def get_state(self) -> State: - _, state = graph_utils.split(self) + _, state = self.split() return state def get_graphdef(self: M) -> GraphDef[M]: - graphdef, _ = graph_utils.split(self) + graphdef, _ = self.split() return graphdef @tp.overload @@ -247,15 +252,17 @@ def pop( @property def apply(self: M) -> ApplyCaller[M]: def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[tp.Any, M]: - module = graph_utils.clone(self) + module = self.clone() fn = accessor(module) out = fn(*args, **kwargs) return out, module return CallableProxy(_apply) # type: ignore - def update(self, state: State, *states: State) -> None: - graph_utils.update(self, state, *states) + def update( + self: M, update: graph_utils.Updates[M], /, *updates: graph_utils.Updates[M] + ) -> None: + graph_utils.update(self, update, *updates) def sow( self, @@ -361,7 +368,7 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None: # Pytree Definition # ------------------------- def _module_flatten(module: Module, *, with_keys: bool): - graphdef, state = graph_utils.split(module) + graphdef, state = module.split() key_values = sorted(state.raw_mapping.items()) keys = tuple(key for key, _ in key_values) diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py index 6de6caff1b..84e91a29b7 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/experimental/nnx/nnx/rnglib.py @@ -34,13 +34,10 @@ import jax import jax.numpy as jnp -from flax.experimental.nnx.nnx import graph_utils from flax.experimental.nnx.nnx.variables import Variable from flax.experimental.nnx.nnx import filterlib from flax.experimental.nnx.nnx.graph_utils import GraphNode -from flax.typing import Dtype, Shape -A = tp.TypeVar('A') Counts = list[int] AxesValue = tp.Union[int, None] Pattern = tp.Union[AxesValue, tuple[AxesValue, ...]] @@ -279,46 +276,3 @@ def _split_rng_unflatten( _split_rng_unflatten, flatten_func=functools.partial(_split_rng_flatten, with_keys=False), ) - -@tp.runtime_checkable -class _HasRngInit(tp.Protocol): - def rng_init(self, rngs: Rngs): - ... 
- - -@tp.overload -def init(node: A, rngs: Rngs, /) -> A: - ... - - -@tp.overload -def init( - node: A, - default: RngValue | RngDict | None = None, - /, - **rngs: RngValue, -) -> A: - ... - - -def init(node: A, *args, **kwargs) -> A: - if len(args) > 0 and isinstance(args[0], Rngs): - if len(args) > 1: - raise ValueError( - 'Too many positional arguments, expected at most 1 Rngs.' - ) - if len(kwargs) > 0: - raise ValueError( - 'Cannot use keyword arguments with Rngs positional argument.' - ) - rngs = args[0] - else: - rngs = Rngs(*args, **kwargs) - for _, value in graph_utils._iter_node_or_variable(node, set(), ()): - if isinstance(value, _HasRngInit): - value.rng_init(rngs) - return node - - -def empty(shape: Shape, dtype: Dtype = jax.numpy.float32, /) -> jax.Array: - return jax.ShapeDtypeStruct(shape, dtype) # type: ignore diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py index 3b721f31ce..0554901487 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/experimental/nnx/nnx/spmd.py @@ -106,7 +106,7 @@ def f(x): return _maybe_replicate(x) - return jax.tree_util.tree_map( + return jax.tree_map( f, tree, is_leaf=lambda x: isinstance(x, variables.Variable) ) diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index 24a231b5f3..905abc5cec 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -77,7 +77,7 @@ def __init__( super().__setattr__('_mapping', dict(mapping)) @property - def raw_mapping(self) -> dict[Key, dict[Key, tp.Any] | tp.Any]: + def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: return self._mapping def __getitem__(self, key: Key) -> State | StateLeaf: diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index f1bfa6881d..400a94a3d0 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -349,7 +349,7 @@ def jit_apply( outer_ref_outer_idx, outer_idx_inner_idx ) (input_graph_nodes, output_graph_nodes), _ = graph_utils.graph_unflatten( - output_graphdef, output_state, idxmap=inner_idx_outer_ref + output_graphdef, output_state, ref_cache=inner_idx_outer_ref ) out = graph_utils.insert_graph_nodes(out, output_graph_nodes) @@ -717,7 +717,7 @@ def grad_fn(*args): _args = list(args) for i, graph_node in diff_graph_nodes.items(): diff_state: State = _args[i] - graph_utils.update(graph_node, diff_state) + graph_utils.graph_update_dynamic(graph_node, diff_state) _args[i] = graph_node out = f(*_args) @@ -750,7 +750,7 @@ def grad_fn(*args): else: out, updates = out - graph_utils.update((input_nodes, out_nodes), updates) + graph_utils.graph_update_dynamic((input_nodes, out_nodes), updates) return out @@ -1080,7 +1080,7 @@ def _init_state(split_keys, broadcast_keys): for state, index in zip(axes_states, options.variable_axes.values()) ] - module = graph_utils.merge(graphdef, *axes_states, carry_state) + module = graphdef.merge(*axes_states, carry_state) return module @@ -1097,9 +1097,7 @@ def scan_apply( # split module state filters = (*options.variable_axes.keys(), ...) 
- refmap, graphdef, *scan_states, carry_state = graph_utils.full_split( - module, *filters - ) + graphdef, *scan_states, carry_state = module.split(*filters) # transpose axes state scan_states = tuple( @@ -1220,7 +1218,7 @@ def scan_fn( ] # merge module state - module, idxmap = graph_utils.full_merge(graphdef, *scan_states, carry_state) + module = graphdef.merge(*scan_states, carry_state) output = f(module, carry_arg, *args, **kwargs) @@ -1238,12 +1236,7 @@ def scan_fn( scan_out = None # split module state - ( - _, - moduledef_out, - *scan_states_out, - carry_state_out, - ) = graph_utils.full_split(module, *filters, idxmap=idxmap) + moduledef_out, *scan_states_out, carry_state_out = module.split(*filters) carry_state_new = carry_state_out - carry_state # remove new carry state @@ -1292,10 +1285,7 @@ def scan_fn( # slice new carry state carry_state_new = jax.tree_util.tree_map(lambda x: x[0], carry_state_new) - # module.update(((*scan_states, carry_state, carry_state_new), moduledef_out)) - graph_utils.full_update( - refmap, moduledef_out, *scan_states, carry_state, carry_state_new - ) + module.update(((*scan_states, carry_state, carry_state_new), moduledef_out)) if options.scan_output: return carry_out, scan_out @@ -1344,7 +1334,7 @@ def module_constructor(*args, **kwargs): return module lifted_module = scan_init(options, module_constructor, args, kwargs) - graph_utils.update_from(module, lifted_module) + module.update(lifted_module) wrapper = scan_init_wrapper @@ -1475,7 +1465,7 @@ def remat_apply( ): _check_args(args) - refmap, graphdef, state = graph_utils.full_split(module) + graphdef, state = module.split() keys = rngs.fork() if rngs is not None else None def _remat_fn( @@ -1487,11 +1477,11 @@ def _remat_fn( if keys is not None: kwargs['rngs'] = rnglib.Rngs(keys) - module, idxmap = graph_utils.full_merge(graphdef, state) + module = graphdef.merge(state) out = f(module, *args, **kwargs) - _, new_graphdef, new_state = graph_utils.full_split(module, idxmap=idxmap) - return (new_graphdef, new_state), out + def_and_state = module.split() + return def_and_state, out def_and_state: tuple[GraphDef[Module], State] def_and_state, out = jax.checkpoint( @@ -1500,9 +1490,8 @@ def _remat_fn( static_argnums=options.static_argnums, policy=options.policy, )(state, keys, *args) - new_graphdef, new_state = def_and_state - graph_utils.full_update(refmap, new_graphdef, new_state) + module.update(def_and_state) return out @@ -1727,7 +1716,7 @@ def _init_state(split_keys, broadcast_keys): for state, index in zip(axes_states, options.variable_axes.values()) ] - module = graph_utils.merge(graphdef, *axes_states, carry_state) + module = graphdef.merge(*axes_states, carry_state) return module @@ -1742,12 +1731,7 @@ def vmap_apply( # split module state filters = (*options.variable_axes.keys(), ...) 
- ( - refmap, - graphdef, - *vectorized_states, - broadcast_state, - ) = graph_utils.full_split(module, *filters) + graphdef, *vectorized_states, broadcast_state = module.split(*filters) # infer length axis_sizes: tp.Set[int] = set() @@ -1840,19 +1824,14 @@ def vmap_fn( ] # merge module state - module, idxmap = graph_utils.full_merge( - graphdef, *vectorized_states, broadcast_state - ) + module = graphdef.merge(*vectorized_states, broadcast_state) output = f(module, *args, **kwargs) # split module state - ( - _, - moduledef_out, - *vectorized_states_out, - broadcast_state_out, - ) = graph_utils.full_split(module, *filters, idxmap=idxmap) + moduledef_out, *vectorized_states_out, broadcast_state_out = module.split( + *filters + ) # add metadata axis name to Variable.sharding if spmd.PARTITION_NAME in options.vmap_metadata: @@ -1870,9 +1849,7 @@ def vmap_fn( ) assert moduledef_out is not None - graph_utils.full_update( - refmap, moduledef_out, *vectorized_states, broadcast_state - ) + module.update(((*vectorized_states, broadcast_state), moduledef_out)) return output @@ -1916,7 +1893,7 @@ def module_constructor(*args, **kwargs): return module lifted_module = vmap_init(options, module_constructor, args, kwargs) - graph_utils.update_from(module, lifted_module) + module.update(lifted_module) wrapper = vmap_init_wrapper diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 2355c62d17..2d1589d3e7 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -25,7 +25,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations import dataclasses import functools @@ -39,7 +38,6 @@ from flax.experimental.nnx.nnx import reprlib, tracers from flax.experimental import nnx -from flax.typing import Initializer A = tp.TypeVar('A') B = tp.TypeVar('B') @@ -75,9 +73,6 @@ def __hash__(self): EMPTY = Empty() -@tp.runtime_checkable -class _HashInitializer(tp.Protocol): - initializer: Initializer @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): @@ -160,6 +155,31 @@ def __init__( else: remove_axis_hooks = () + if isinstance(value, VariableMetadata): + value_metadata = dict(value.metadata) + if set_value_hooks and value.set_value_hooks: + set_value_hooks = set_value_hooks + value.set_value_hooks + elif value.set_value_hooks: + set_value_hooks = value.set_value_hooks + if get_value_hooks and value.get_value_hooks: + get_value_hooks = get_value_hooks + value.get_value_hooks + elif value.get_value_hooks: + get_value_hooks = value.get_value_hooks + if create_value_hooks and value.create_value_hooks: + create_value_hooks = create_value_hooks + value.create_value_hooks + elif value.create_value_hooks: + create_value_hooks = value.create_value_hooks + if add_axis_hooks and value.add_axis_hooks: + add_axis_hooks = add_axis_hooks + value.add_axis_hooks + elif value.add_axis_hooks: + add_axis_hooks = value.add_axis_hooks + if remove_axis_hooks and value.remove_axis_hooks: + remove_axis_hooks = remove_axis_hooks + value.remove_axis_hooks + elif value.remove_axis_hooks: + remove_axis_hooks = value.remove_axis_hooks + + metadata.update(value_metadata) + value = tp.cast(A, value.raw_value) if hasattr(self, 'on_get_value'): on_get_value = getattr(type(self), 'on_get_value') @@ -186,6 +206,7 @@ def __init__( if on_remove_axis not in remove_axis_hooks: remove_axis_hooks = 
(on_remove_axis, *remove_axis_hooks) + self.raw_value = value self.get_value_hooks = get_value_hooks self.set_value_hooks = set_value_hooks self.create_value_hooks = create_value_hooks @@ -194,29 +215,7 @@ def __init__( vars(self).update(metadata) # run create_value hooks - if isinstance(value, jax.ShapeDtypeStruct): - self.raw_value = value - else: - self.raw_value = self.create_value(value) - - def _setup_value(self, value: A | VariableMetadata[A]) -> A: - if isinstance(value, VariableMetadata): - value_metadata = dict(value.metadata) - if value.set_value_hooks: - self.set_value_hooks += value.set_value_hooks - elif value.get_value_hooks: - self.get_value_hooks += value.get_value_hooks - elif value.create_value_hooks: - self.create_value_hooks += value.create_value_hooks - elif value.add_axis_hooks: - self.add_axis_hooks += value.add_axis_hooks - elif value.remove_axis_hooks: - self.remove_axis_hooks += value.remove_axis_hooks - - vars(self).update(value_metadata) - value = tp.cast(A, value.raw_value) - - return value + self.raw_value = self.create_value(self.raw_value) if tp.TYPE_CHECKING: @@ -234,13 +233,6 @@ def _setattr(self, name: str, value: tp.Any): object.__setattr__(self, name, value) - def rng_init(self, rngs: 'nnx.Rngs'): - value = self.raw_value - if isinstance(value, jax.ShapeDtypeStruct) and isinstance(self, _HashInitializer): - value = self.initializer(rngs(), value.shape, value.dtype) - self.raw_value = self.create_value(value) - - def copy_from(self, other: 'Variable[A]') -> None: if not self.is_equivalent(other): raise ValueError( @@ -278,8 +270,7 @@ def value(self, value: A): value = hook(self, value) self.raw_value = value - def create_value(self, value: A | VariableMetadata[A]): - value = self._setup_value(value) + def create_value(self, value: A): for hook in self.create_value_hooks: value = hook(self, value) return value diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index 300d696849..80ad57a9fe 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -25,22 +25,21 @@ def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - refmap, static, state = nnx.full_split(g) - assert refmap is not None + static, state, ref_idx = nnx.graph_utils.graph_flatten(g) state[0]['b'].raw_value = 2 state[3].raw_value = 4 - assert len(refmap) == 2 - assert a['b'] in refmap - assert g[3] in refmap + assert len(ref_idx) == 2 + assert a['b'] in ref_idx + assert g[3] in ref_idx def test_unflatten(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) - static, state = nnx.split(g) - g = nnx.merge(static, state) + static, state, _ = nnx.graph_utils.graph_flatten(g) + g = static.merge(state) assert g[0] is g[2] @@ -48,8 +47,8 @@ def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - static, state = nnx.split(g) - g = nnx.merge(static, state) + static, state, _ = nnx.graph_utils.graph_flatten(g) + g = static.merge(state) assert g[0] is not g[2] @@ -57,7 +56,7 @@ def test_unflatten_empty(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) - static, state = nnx.split(g) + static, state, _ = nnx.graph_utils.graph_flatten(g) g = static.merge(nnx.State({})) assert g[0] is g[2] @@ -68,10 +67,10 @@ def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - static, state = nnx.split(g) + static, state, _ = 
nnx.graph_utils.graph_flatten(g) state[0]['b'].raw_value = 3 - nnx.graph_utils.update(g, state) + nnx.graph_utils.graph_update_dynamic(g, state) assert g[0]['b'].raw_value == 3 assert g[2]['b'].raw_value == 3 @@ -124,7 +123,7 @@ def test_module_list(self): nnx.BatchNorm(2, rngs=rngs), ] - static, state = nnx.split(ls) + static, state, _ = nnx.graph_utils.graph_flatten(ls) assert state[0]['kernel'].raw_value.shape == (2, 2) assert state[0]['bias'].raw_value.shape == (2,) @@ -137,7 +136,7 @@ def test_shared_variables(self): v = nnx.Param(1) g = [v, v] - static, state = nnx.split(g) + static, state, _ = nnx.graph_utils.graph_flatten(g) assert len(state.flat_state()) == 1 @@ -155,7 +154,7 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: self.baz.kernel = self.bar.kernel node = Foo(rngs=nnx.Rngs(0)) - static, state = nnx.split(node) + static, state, _ = nnx.graph_utils.graph_flatten(node) assert len(state.flat_state()) == 3 # 2 bias + 1 kernel @@ -283,7 +282,7 @@ def __init__(self): assert 'tree' in state assert 'a' in state.tree - assert static.nodedef.subgraphs['tree'].type is nnx.graph_utils.PytreeType + assert static.subgraphs['tree'].type is nnx.graph_utils.PytreeType m2 = static.merge(state) @@ -328,7 +327,7 @@ def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): ref_out_idx_out, idx_out_idx_in ) m2, _ = nnx.graph_utils.graph_unflatten( - static, state, idxmap=idx_in_ref_out + static, state, ref_cache=idx_in_ref_out ) assert m2 is m assert m2.a is b @@ -369,7 +368,7 @@ def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): ref_out_idx_out, idx_out_idx_in ) m2, _ = nnx.graph_utils.graph_unflatten( - static, state, idxmap=idx_in_ref_out + static, state, ref_cache=idx_in_ref_out ) assert m2 is m assert m2.a is b @@ -407,61 +406,7 @@ def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): ref_out_idx_out, idx_out_idx_in ) m2, _ = nnx.graph_utils.graph_unflatten( - static, state, idxmap=idx_in_ref_out + static, state, ref_cache=idx_in_ref_out ) assert m2 is m assert m2.ref is m2 - - def test_init_rngs(self): - class Linear(nnx.Module): - def __init__(self, din: int, dout: int): - self.kernel = nnx.Param( - nnx.empty((din, dout)), initializer=nnx.initializers.lecun_normal() - ) - self.bias = nnx.Param( - nnx.empty((dout,)), initializer=nnx.initializers.zeros - ) - - def __call__(self, x): - return x @ self.kernel.value + self.bias.value - - m = Linear(2, 2) - nnx.init(m, nnx.Rngs(0)) - assert isinstance(m.kernel.value, jax.Array) - assert isinstance(m.bias.value, jax.Array) - - def test_init_seed(self): - class Linear(nnx.Module): - def __init__(self, din: int, dout: int): - self.kernel = nnx.Param( - nnx.empty((din, dout)), initializer=nnx.initializers.lecun_normal() - ) - self.bias = nnx.Param( - nnx.empty((dout,)), initializer=nnx.initializers.zeros - ) - - def __call__(self, x): - return x @ self.kernel.value + self.bias.value - - m = Linear(2, 2) - nnx.init(m, 0) - assert isinstance(m.kernel.value, jax.Array) - assert isinstance(m.bias.value, jax.Array) - - def test_init_key(self): - class Linear(nnx.Module): - def __init__(self, din: int, dout: int): - self.kernel = nnx.Param( - nnx.empty((din, dout)), initializer=nnx.initializers.lecun_normal() - ) - self.bias = nnx.Param( - nnx.empty((dout,)), initializer=nnx.initializers.zeros - ) - - def __call__(self, x): - return x @ self.kernel.value + self.bias.value - - m = Linear(2, 2) - nnx.init(m, jax.random.key(0)) - assert isinstance(m.kernel.value, jax.Array) - assert isinstance(m.bias.value, jax.Array) diff --git 
a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 9d71c2b879..ef44281217 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -51,14 +51,14 @@ def f(): def test_tree_map(self): m = nnx.Dict(a=nnx.Param(1)) - static, state = nnx.split(m) + static, state = m.split() state = jax.tree_util.tree_map(lambda x: x + 1, state) def test_split_2(self): m = nnx.Dict(a=nnx.Param(1)) - static, empty, some = nnx.split(m, None, ...) + empty, some, static = m.split(None, ...) some = jax.tree_util.tree_map(lambda x: x + 1, some) @@ -69,9 +69,9 @@ def test_split_merge(self): def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): m = graphdef.merge(state) m.a = 2 - return nnx.split(m) + return m.split() - graphdef, state = g(*nnx.split(m)) + graphdef, state = g(*m.split()) m2 = graphdef.merge(state) assert m2.a == 2 @@ -109,7 +109,7 @@ def test_shared_module(self): m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) - m3 = nnx.merge(*nnx.split(m2)) + m3 = nnx.merge(*m2.split()) assert m3['x'] is m3['y'] assert m3['x']['a'] is m3['y']['a'] @@ -123,7 +123,7 @@ def __init__(self): m = Foo() - graphdef, state = nnx.split(m) + graphdef, state = m.split() assert len(state) == 1 m2 = graphdef.merge(state) @@ -142,9 +142,9 @@ def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] - return nnx.split(m) + return m.split() - graphdef, state = f(*nnx.split(m)) + graphdef, state = f(*m.split()) m = graphdef.merge(state) assert m['a'][0] is m['b'] @@ -162,9 +162,9 @@ def test_cross_barrier(self): def g(graphdef: nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): m = graphdef.merge(state) m.a.value += 1 - return nnx.split(m) + return m.split() - graphdef, state = g(*nnx.split(m)) + graphdef, state = g(*m.split()) m2 = graphdef.merge(state) assert m2 is not m assert m.a.value == 1 @@ -180,23 +180,23 @@ def g(state_and_def): n += 1 m = nnx.merge(*state_and_def) m.a.value += 1 - return nnx.split(m) + return m.split() - m2 = nnx.merge(*g(nnx.split(m))) + m2 = nnx.merge(*g(m.split())) assert n == 1 assert m2 is not m assert m.a.value == 1 assert m2.a.value == 2 - g(nnx.split(m)) + g(m.split()) assert n == 1 - g(nnx.split(m2)) + g(m2.split()) assert n == 1 m2.b = nnx.Param(10) - g(nnx.split(m2)) + g(m2.split()) assert n == 2 @@ -211,7 +211,7 @@ def test_deref_number_of_fields(self): } ) - graphdef, p = nnx.split(m) + graphdef, p = m.split() assert len(p.flat_state()) == 2 assert len(jax.tree_util.tree_leaves(p)) == 2 @@ -221,7 +221,7 @@ def test_clone(self): b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), ) - m2 = nnx.clone(m) + m2 = m.clone() assert m is not m2 assert m2.a[0] == m2.b.c @@ -247,7 +247,7 @@ def __call__(self, x): assert y2 == 11 assert m.y.value == (3, 11) - intermediates = nnx.pop(m, nnx.Intermediate) + intermediates = m.pop(nnx.Intermediate) assert isinstance(intermediates.y, nnx.Intermediate) assert intermediates['y'].raw_value == (3, 11) @@ -284,6 +284,32 @@ def __call__(self, x): with pytest.raises(ValueError, match='to be of type'): m(2) + def test_update_static_state(self): + class Foo(nnx.Module): + def add_field(self): + self.a = 1 + + m1 = Foo() + m2 = Foo() + m2.add_field() + + m1.update(m2) + + assert m1.a == 1 + + def test_update_moduledef(self): + class Foo(nnx.Module): + def add_field(self): + self.a = 1 + + m1 = Foo() + m2 = Foo() + m2.add_field() + + m1.update(m2.get_graphdef()) + + 
assert m1.a == 1 + def test_update_static_state_submodules(self): class Bar(nnx.Module): def __init__(self) -> None: @@ -298,12 +324,10 @@ def __init__(self) -> None: self.b = self.a m1 = Foo() - refmap, graphdef, state = nnx.full_split(m1) - m2, idxmap = nnx.full_merge(graphdef, state) + m2 = Foo() m2.a.add_field() - _, new_graphdef, state = nnx.full_split(m2, idxmap=idxmap) - nnx.full_update(refmap, new_graphdef, state) + m1.update(m2) assert m1.a.x == 1 assert m1.a.y == 2 @@ -323,12 +347,10 @@ def add_module(self): self.b = Bar() m1 = Foo() - refmap, graphdef, state = nnx.full_split(m1) - m2, idxmap = nnx.full_merge(graphdef, state) + m2 = Foo() m2.add_module() - _, new_graphdef, state = nnx.full_split(m2, idxmap=idxmap) - nnx.full_update(refmap, new_graphdef, state) + m1.update(m2) assert m1.a.x == 1 assert m1.b.x == 1 @@ -344,16 +366,15 @@ def __init__(self) -> None: self.b = self.a m1 = Foo() - refmap, graphdef, state = nnx.full_split(m1) - m2, idxmap = nnx.full_merge(graphdef, state) + m2 = Foo() m2.a.x = 2 - _, new_graphdef, state = nnx.full_split(m2, idxmap=idxmap) - nnx.full_update(refmap, new_graphdef, state) + + m1.update(m2) assert m1.a.x == 2 assert m1.b.x == 2 - def test_update_add_shared(self): + def test_update_add_shared_error(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 @@ -367,13 +388,37 @@ def add_submodule(self): self.c = self.a m1 = Foo() - refmap, graphdef, state = nnx.full_split(m1) - m2, idxmap = nnx.full_merge(graphdef, state) + m2 = Foo() + m2.add_submodule() + + assert hasattr(m2, 'c') + + with pytest.raises(ValueError, match='Trying to add a new node at path'): + m1.update(m2) + + def test_update_add_shared_error_new_first(self): + class Bar(nnx.Module): + def __init__(self) -> None: + self.x = 1 + + class Foo(nnx.Module): + def __init__(self) -> None: + self.b = Bar() + self.c = self.b + + def add_submodule(self): + self.a = self.b + + m1 = Foo() + m2 = Foo() m2.add_submodule() - _, new_graphdef, state = nnx.full_split(m2, idxmap=idxmap) - nnx.full_update(refmap, new_graphdef, state) - assert hasattr(m1, 'c') + assert hasattr(m2, 'a') + + m2 = m2.clone() # clone to sort the fields + + with pytest.raises(ValueError, match='Trying to update a node at path'): + m1.update(m2) def test_create_abstract(self): linear = nnx.Linear.create_abstract(2, 3, rngs=nnx.Rngs(0)) @@ -497,7 +542,7 @@ class Foo(nnx.Module): f=6, # static int ) - graphdef, state = nnx.split(m) + graphdef, state = m.split() assert len(state) == 4 assert state.b == nnx.Variable(2)
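
A minimal usage sketch of the split/merge/update API this change settles on, following the round-trip patterns in the updated tests. The `Counter` module and its `count` field are illustrative names, not taken from the patch:

import jax

from flax.experimental import nnx


class Counter(nnx.Module):  # hypothetical example module, mirroring the test fixtures
  def __init__(self):
    self.count = nnx.Param(0)  # plain int leaf, as in the tests


m = Counter()

# Module.split() returns a static GraphDef plus a State pytree
graphdef, state = m.split()

# State is a pytree, so it composes with jax.tree_util transforms
state = jax.tree_util.tree_map(lambda x: x + 1, state)

# GraphDef.merge reconstructs a module from the (possibly transformed) state
m2 = graphdef.merge(state)
assert m2.count.value == 1

# Module.update applies a State back onto an existing module in place
m.update(state)
assert m.count.value == 1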