Skip to content

Commit

Permalink
Merge pull request #3861 from gnecula:nnx_doc1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 625686942
  • Loading branch information
Flax Authors committed Apr 17, 2024
2 parents bf4dfff + 7f1b36a commit 9973693
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 142 deletions.
162 changes: 76 additions & 86 deletions docs/experimental/nnx/nnx_basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
"source": [
"# NNX Basics\n",
"\n",
"NNX is a **N**eural **N**etworks JA**X** library that embraces Python's object-oriented programming \n",
"model to provide an intuitive and highly simplified user experience. It uses PyGraphs (instead of PyTrees)\n",
"to represent stateful objects, which allows it to express reference sharing and mutability in Python itself. \n",
"This makes NNX code look like regular Python code that users from frameworkslike Pytorch and Keras will \n",
"be familiar with.\n",
"NNX is a **N**eural **N**etworks JA**X** library that embraces Pythons object-oriented \n",
"programming model to provide an intuitive and highly simplified user experience. It\n",
"represents objects as PyGraphs (instead of PyTrees), which allows NNX to handle reference\n",
"sharing and mutability, making model code be regular Python code that users from frameworks\n",
"like Pytorch will be familiar with.be familiar with.\n",
"\n",
"NNX is also designed to support \n",
"all the patterns that allowed Linen to scale to large code bases while building upon a much simpler \n",
"foundation."
"all the patterns that allowed Linen to scale to large code bases while having a much simpler\n",
"implementation."
]
},
{
Expand All @@ -34,8 +34,8 @@
"source": [
"## The Module System\n",
"To begin lets see how to create a `Linear` Module using NNX. The main noticeable\n",
"different between Module systems like Haiku or Linen and NNX is that in NNX everything is\n",
"**explicit**. This means amongst other things that 1) the Module itself holds the state \n",
"difference between NNX and Module systems like Haiku or Linen is that in NNX everything is\n",
"**explicit**. This means among other things that 1) the Module itself holds the state\n",
"(e.g. parameters) directly, 2) the RNG state is threaded by the user, and 3) all shape information\n",
"must be provided on initialization (no shape inference)."
]
Expand All @@ -61,16 +61,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As shown above dynamic state is usually stored in `nnx.Variable`s such as `nnx.Param`,\n",
"As shown above dynamic state is usually stored in `nnx.Param`s,\n",
"and static state (all types not handled by NNX) such as integers or strings \n",
"are stored directly. JAX array and Numpy array attributes are also treated as dynamic state,\n",
"although storing them inside `nnx.Variable`s is preferred. Also, RNG keys can be requested from the \n",
"`nnx.Rngs` object by calling `rngs.<stream_name>()` where the stream name show match on of \n",
"the names provided to the `Rngs` constructor (shown below).\n",
"are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic state,\n",
"although storing them inside `nnx.Variable`s is preferred. Also, the `nnx.Rngs` object by can be used to\n",
"get new unique keys based on a root key passed to the constructor (see below).\n",
"\n",
"To actually initialize a Module is very easy: simply call the constructor. All of the\n",
"To actually initialize a Module is very easy: simply call the constructor. All the\n",
"parameters of a Module will be created right then and there, and are immediately available\n",
"for inspection."
"for inspection using regular Python attribute access."
]
},
{
Expand All @@ -86,8 +85,8 @@
" din=2,\n",
" dout=3\n",
")\n",
"model.w.value = Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
"model.w.value = Array([[0.9913868 , 0.45571804, 0.7215481 ],\n",
" [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n",
"model.b.value = Array([0., 0., 0.], dtype=float32)\n"
]
}
Expand All @@ -105,7 +104,7 @@
"metadata": {},
"source": [
"This is very handy for debugging as it allows accessing the entire structure or\n",
"modify it. Similarly, computation can be ran directly."
"modifying it. Similarly, computations can be ran directly."
]
},
{
Expand All @@ -116,7 +115,7 @@
{
"data": {
"text/plain": [
"Array([[0.9763588 , 0.34776556, 0.87546587]], dtype=float32)"
"Array([[1.878783 , 0.65652764, 1.4469249 ]], dtype=float32)"
]
},
"execution_count": 4,
Expand All @@ -134,7 +133,8 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Since Modules hold their own state there is no need for a separate `apply` method."
"Since Modules hold their own state there is no need for a separate `apply` method, as in\n",
"Linen or Haiku."
]
},
{
Expand All @@ -151,7 +151,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand Down Expand Up @@ -181,12 +181,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"**This looks too easy, what is the catch?** \n",
"JAX frameworks have avoided mutable references until now. The key innovations which \n",
"allows their usage in NNX is that 1) there is a clear boundary between reference \n",
"semantics and value semantics, defined by [The Functional API](#the-functional-api),\n",
"and 2) there are guards in place to avoid updating NNX objects from a `MainTrace`, \n",
"thus preventing tracer leakage."
"allows their usage in NNX is that 1) there is a clear boundary between code that uses \n",
"reference semantics and code that uses value semantics, defined by \n",
"[The Functional API](#the-functional-api), and 2) there are guards in place to avoid \n",
"updating NNX objects from a `MainTrace`, thus preventing tracer leakage."
]
},
{
Expand All @@ -195,17 +194,17 @@
"source": [
"### Nested Modules\n",
"\n",
"As expected, Modules can used to compose other Modules in a nested\n",
"structure, this includes standard Modules such as `nnx.Linear`,\n",
"`nnx.Conv`, etc, or any custom Module created by users. Modules can \n",
"be assigned as attributes of a Module, but shown by `MLP.blocks` in the\n",
"As expected, Modules can be used to compose other Modules in a nested\n",
"structure, including standard Modules such as `nnx.Linear`,\n",
"`nnx.Conv`, etc., or any custom Module created by users. Modules can\n",
"be assigned as attributes of a Module, but as shown by `MLP.blocks` in the\n",
"example below, they can also be stored in attributes of type `list`, `dict`, `tuple`, \n",
"or nested structues of the previous."
"or in nested structures of the same."
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -221,9 +220,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x169773f70>,\n",
" bias_init=<function zeros at 0x1353b8ca0>,\n",
" dot_general=<function dot_general at 0x126dc5700>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x13cfa4040>,\n",
" bias_init=<function zeros at 0x128869430>,\n",
" dot_general=<function dot_general at 0x11ff55430>\n",
" ),\n",
" bn=BatchNorm(\n",
" num_features=2,\n",
Expand Down Expand Up @@ -258,20 +257,20 @@
"metadata": {},
"source": [
"One of the benefits of NNX is that nested Modules as easy to inspect and\n",
"static analyzers can help you while doing so."
"static analyzers, e.g., code completion, can help you while doing so."
]
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"model.blocks[1].linear.kernel.value = Array([[0.992858 , 0.9711272],\n",
" [1.4061186, 0.4704619]], dtype=float32)\n",
"model.blocks[1].linear.kernel.value = Array([[-0.31410056, -0.9153769 ],\n",
" [-0.38879898, -0.12699318]], dtype=float32)\n",
"model.blocks[0].bn.scale.value = Array([1., 1.], dtype=float32)\n"
]
}
Expand All @@ -290,8 +289,8 @@
"at any time. Also, NNX's Module system supports reference sharing of Modules and\n",
"Variables.\n",
"\n",
"The previous makes Model Surgery quite easy as any submodule could be replace by\n",
"e.g. a pretrained Module, a shared Module, or even just a Module/function that\n",
"This makes Model Surgery quite easy as any submodule could be replaced by\n",
"e.g., a pretrained Module, a shared Module, or even just a Module/function that\n",
"uses the same signature. More over, Variables can also be modified or shared."
]
},
Expand All @@ -315,9 +314,9 @@
"# Module replacement\n",
"pretrained = Block(dim=2, rngs=nnx.Rngs(42)) # imagine this is pretrained\n",
"model.blocks[0] = pretrained\n",
"# Module sharing\n",
"# adhoc Module sharing\n",
"model.blocks[3] = model.blocks[1]\n",
"# Monkey patching\n",
"# monkey patching\n",
"def awesome_layer(x): return x\n",
"model.blocks[2] = awesome_layer\n",
"\n",
Expand All @@ -333,12 +332,12 @@
"source": [
"## The Functional API\n",
"\n",
"The Functional API established a clear boundary between reference/object semantics and \n",
"The Functional API establishes a clear boundary between reference/object semantics and\n",
"value/pytree semantics. It also allows same amount of fine-grained control over the \n",
"state Linen/Haiku users are used to. The Functional API consists of 3 basic methods: \n",
"state that Linen/Haiku users are used to. The Functional API consists of 3 basic methods:\n",
"`split`, `merge`, and `update`.\n",
"\n",
"The `StatefulLinear` Module shown below will serve as an example to learn to use the \n",
"The `StatefulLinear` Module shown below will serve as an example for the use of the\n",
"Functional API. It contains some `nnx.Param` Variables and a custom `Count` Variable\n",
"type which is used to keep track of integer scalar state that increases on every \n",
"forward pass."
Expand Down Expand Up @@ -371,11 +370,10 @@
"source": [
"### State and GraphDef\n",
"\n",
"A Module can be decomposed into a `State` and `GraphDef` pytrees using the \n",
"A Module can be decomposed into `GraphDef` and `State` using the\n",
"`.split()` method. State is a Mapping from strings to Variables or nested \n",
"States. GraphDef is contains all the static information needed to reconstruct \n",
"a Module graph, its analogous to JAX's `PyTreeDef`, and for convenience it \n",
"implements an empty pytree."
"States. GraphDef contains all the static information needed to reconstruct \n",
"a Module graph, it is analogous to JAX's `PyTreeDef`."
]
},
{
Expand All @@ -388,37 +386,29 @@
"output_type": "stream",
"text": [
"state = State({\n",
" 'w': Param(\n",
" raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
" ),\n",
" 'b': Param(\n",
" raw_value=Array([0., 0., 0.], dtype=float32)\n",
" ),\n",
" 'count': Count(\n",
" raw_value=0\n",
" ),\n",
" 'w': Param(\n",
" raw_value=Array([[0.9913868 , 0.45571804, 0.7215481 ],\n",
" [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n",
" )\n",
"})\n",
"\n",
"static = GraphDef(\n",
" type=StatefulLinear,\n",
" index=0,\n",
" attributes=('w', 'b', 'count'),\n",
" subgraphs={},\n",
" static_fields={},\n",
" variables={\n",
" 'w': VariableDef(\n",
" type=Param,\n",
" index=1,\n",
" me...\n"
"graphdef = GraphDef(nodedef=NodeDef(type=<class '__main__.StatefulLinear'>, index=0, attributes=('b', 'count', 'w'), subgraphs={}, static_fields={}, variables={'b': VariableDef(\n",
" type=Param,\n",
" index=...\n"
]
}
],
"source": [
"static, state = model.split()\n",
"graphdef, state = model.split()\n",
"\n",
"print(f'{state = }\\n')\n",
"print(f'{static = }'[:200] + '...')"
"print(f'{graphdef = }'[:200] + '...')"
]
},
{
Expand All @@ -428,9 +418,9 @@
"### Split, Merge, and Update\n",
"\n",
"`merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs\n",
"the Module. As shown in the example below, by using split and merge in sequence \n",
"the Module. As shown in the example below, by using `split` and `merge` in sequence\n",
"any Module can be lifted to be used in any JAX transform. `update` can\n",
"update a Module strucure from a compatible State, this is often used to propagate the state\n",
"update a Module structure from a compatible State. This is often used to propagate the state\n",
"updates from a transform back to the source object outside."
]
},
Expand All @@ -454,19 +444,19 @@
"print(f'{model.count = }')\n",
"\n",
"# 1. Use split to create a pytree representation of the Module\n",
"static, state = model.split()\n",
"graphdef, state = model.split()\n",
"\n",
"@jax.jit\n",
"def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n",
"def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:\n",
" # 2. Use merge to create a new model inside the JAX transformation\n",
" model = static.merge(state)\n",
" model = graphdef.merge(state)\n",
" # 3. Call the Module\n",
" y = model(x)\n",
" # 4. Use split to propagate State updates\n",
" _, state = model.split()\n",
" return y, state\n",
"\n",
"y, state = forward(static, state, x=jnp.ones((1, 2)))\n",
"y, state = forward(graphdef, state, x=jnp.ones((1, 2)))\n",
"# 5. Update the state of the original Module\n",
"model.update(state)\n",
"\n",
Expand All @@ -481,11 +471,11 @@
"fine within a transform context (including the base eager interpreter)\n",
"but its necessary to use the Functional API when crossing boundaries.\n",
"\n",
"**Why aren't Module's just Pytrees?** The main reason is that its very\n",
"**Why aren't Module's just Pytrees?** The main reason is that it is very\n",
"easy to lose track of shared references by accident this way, for example\n",
"if you pass two Module that have a shared Module through a JAX boundary\n",
"you will silently lose that shared reference. The Functional API makes this\n",
"behavior explicit and thus its much easier to reason about."
"you will silently lose that sharing. The Functional API makes this\n",
"behavior explicit, and thus it is much easier to reason about."
]
},
{
Expand All @@ -497,14 +487,14 @@
"Seasoned Linen and Haiku users might recognize that having all the state in\n",
"a single structure is not always the best choice as there are cases in which\n",
"you might want to handle different subsets of the state differently. This a\n",
"common occurrence when interacting with JAX transform, for example, not all\n",
"common occurrence when interacting with JAX transforms, for example, not all\n",
"the model's state can or should be differentiated when interacting which `grad`,\n",
"or sometimes there is a need to specify what part of the model's state is a\n",
"carry and what part is not when using `scan`.\n",
"\n",
"To solve this `split` allows you to pass one or more `Filter`s to partition\n",
"To solve this, `split` allows you to pass one or more `Filter`s to partition\n",
"the Variables into mutually exclusive States. The most common Filter being\n",
"Variable types as shown below."
"types as shown below."
]
},
{
Expand All @@ -517,12 +507,12 @@
"output_type": "stream",
"text": [
"params = State({\n",
" 'w': Param(\n",
" raw_value=Array([[0.19007349, 0.31424356, 0.3686391 ],\n",
" [0.7862853 , 0.03352201, 0.50682676]], dtype=float32)\n",
" ),\n",
" 'b': Param(\n",
" raw_value=Array([0., 0., 0.], dtype=float32)\n",
" ),\n",
" 'w': Param(\n",
" raw_value=Array([[0.9913868 , 0.45571804, 0.7215481 ],\n",
" [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n",
" )\n",
"})\n",
"\n",
Expand All @@ -536,7 +526,7 @@
],
"source": [
"# use Variable type filters to split into multiple States\n",
"static, params, counts = model.split(nnx.Param, Count)\n",
"graphdef, params, counts = model.split(nnx.Param, Count)\n",
"\n",
"print(f'{params = }\\n')\n",
"print(f'{counts = }')"
Expand All @@ -558,7 +548,7 @@
"outputs": [],
"source": [
"# merge multiple States\n",
"model = static.merge(params, counts)\n",
"model = graphdef.merge(params, counts)\n",
"# update with multiple States\n",
"model.update(params, counts)"
]
Expand Down
Loading

0 comments on commit 9973693

Please sign in to comment.