diff --git a/.readthedocs.yml b/.readthedocs.yml index fa87e6d31f..d466631bbb 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -8,7 +8,7 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.9" + python: "3.10" # Build documentation in the docs/ directory with Sphinx sphinx: diff --git a/docs/api_reference/flax.experimental.nnx/index.rst b/docs/api_reference/flax.experimental.nnx/index.rst index fa2dca6b8d..bc0a376894 100644 --- a/docs/api_reference/flax.experimental.nnx/index.rst +++ b/docs/api_reference/flax.experimental.nnx/index.rst @@ -14,4 +14,5 @@ Experimental API. See the `NNX page ,\n", - " precision=None,\n", - " kernel_init=.init at 0x35cbd31f0>,\n", - " bias_init=,\n", - " conv_general_dilated= /* penzai.treescope rendering of a Python object (compressed) */ (()=>{ let observer; let lastStep = new Promise((resolve, reject) => { observer = new IntersectionObserver((entries) => { for (const entry of entries) { if (entry.isIntersecting) { resolve(); observer.disconnect(); return; } } }, {rootMargin: \"1000px\"}); }); window.treescope_decompress_enqueue = (encoded, destId) => { const previous = lastStep; const destElt = document.getElementById(destId); lastStep = (async () => { await previous; let blob = new Blob([ Uint8Array.from(atob(encoded), (m) => m.codePointAt(0)) ]); let reader = blob.stream().pipeThrough( new DecompressionStream(\"deflate\") ).pipeThrough( new TextDecoderStream(\"utf-8\") ).getReader(); let parts = []; while (true) { let step = await reader.read(); if (step.done) { break; } parts.push(step.value); } let newElt = document.createElement(\"div\"); newElt.innerHTML = parts.join(\"\"); destElt.parentNode.replaceChild(newElt, destElt); for (let oldScript of newElt.querySelectorAll(\"script\")) { let newScript = document.createElement(\"script\"); newScript.type = oldScript.type; newScript.textContent = oldScript.textContent; oldScript.parentNode.replaceChild(newScript, oldScript); } })(); requestAnimationFrame(() => { 
observer.observe(destElt); }); } })();
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "from flax.experimental import nnx # NNX API\n", + "from functools import partial\n", "\n", "class CNN(nnx.Module):\n", " \"\"\"A simple CNN model.\"\"\"\n", "\n", " def __init__(self, *, rngs: nnx.Rngs):\n", - " self.conv1 = nnx.Conv(\n", - " in_features=1, out_features=32, kernel_size=(3, 3), rngs=rngs\n", - " )\n", - " self.conv2 = nnx.Conv(\n", - " in_features=32, out_features=64, kernel_size=(3, 3), rngs=rngs\n", - " )\n", - " self.linear1 = nnx.Linear(in_features=3136, out_features=256, rngs=rngs)\n", - " self.linear2 = nnx.Linear(in_features=256, out_features=10, rngs=rngs)\n", + " self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)\n", + " self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)\n", + " self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))\n", + " self.linear1 = nnx.Linear(3136, 256, rngs=rngs)\n", + " self.linear2 = nnx.Linear(256, 10, rngs=rngs)\n", "\n", " def __call__(self, x):\n", - " x = self.conv1(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = self.conv2(x)\n", - " x = nnx.relu(x)\n", - " x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2))\n", - " x = x.reshape((x.shape[0], -1)) # flatten\n", - " x = self.linear1(x)\n", - " x = nnx.relu(x)\n", + " x = self.avg_pool(nnx.relu(self.conv1(x)))\n", + " x = self.avg_pool(nnx.relu(self.conv2(x)))\n", + " x = x.reshape(x.shape[0], -1) # flatten\n", + " x = nnx.relu(self.linear1(x))\n", " x = self.linear2(x)\n", " return x\n", "\n", - "\n", "model = CNN(rngs=nnx.Rngs(0))\n", - "\n", - "print(f'model = {model}'[:500] + '\\n...\\n') # print a part of the model\n", - "print(\n", - " f'{model.conv1.kernel.value.shape = }'\n", - ") # inspect the shape of the kernel of the first convolutional layer" + "nnx.display(model)" ] }, { @@ -186,7 +169,7 @@ }, { "cell_type": "code", - "execution_count": 3, + 
"execution_count": 4, "id": "8", "metadata": { "outputId": "2c580f41-bf5d-40ec-f1cf-ab7f319a84da" @@ -194,21 +177,22 @@ "outputs": [ { "data": { + "text/html": [ + "
(Loading...)
" + ], "text/plain": [ - "Array([[-0.06820839, -0.14743432, 0.00265857, -0.2173656 , 0.16673787,\n", - " -0.00923921, -0.06636689, 0.28341877, 0.33754364, -0.20142877]], dtype=float32)" + "" ] }, - "execution_count": 3, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ "import jax.numpy as jnp # JAX NumPy\n", "\n", "y = model(jnp.ones((1, 28, 28, 1)))\n", - "y" + "nnx.display(y)" ] }, { @@ -216,33 +200,9 @@ "id": "9", "metadata": {}, "source": [ - "## 4. Create the `TrainState`\n", + "## 4. Create Optimizer and Metrics\n", "\n", - "In Flax, a common practice is to use a dataclass to encapsulate the entire training state, which would allow you to simply pass only two arguments (the train state and batched data) to functions like `train_step`. The training state would typically contain an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/optimizer.html#flax.experimental.nnx.optimizer.Optimizer) (which contains the step number, model and optimizer state) and an `nnx.Module` (for easier access to the model from the top-level of the train state). The training state can also be easily extended to add training and test metrics, as you will see in this tutorial (see [`nnx.metrics`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/metrics.html#module-flax.experimental.nnx.metrics) for more detail on NNX's metric classes)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "10", - "metadata": {}, - "outputs": [], - "source": [ - "import dataclasses\n", - "\n", - "@dataclasses.dataclass\n", - "class TrainState(nnx.GraphNode):\n", - " optimizer: nnx.Optimizer\n", - " model: CNN\n", - " metrics: nnx.MultiMetric" - ] - }, - { - "cell_type": "markdown", - "id": "11", - "metadata": {}, - "source": [ - "We use `optax` to create an optimizer ([`adamw`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw)) and initialize the `nnx.Optimizer`. We use `nnx.MultiMetric` to keep track of both the accuracy and average loss for both training and test batches." + "In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss." ] }, { @@ -250,21 +210,33 @@ "execution_count": 5, "id": "12", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "import optax\n", "\n", "learning_rate = 0.005\n", "momentum = 0.9\n", - "tx = optax.adamw(learning_rate, momentum)\n", - "\n", - "state = TrainState(\n", - " optimizer=nnx.Optimizer(model=model, tx=tx),\n", - " model=model,\n", - " metrics=nnx.MultiMetric(\n", - " accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average()\n", - " ),\n", - ")" + "\n", + "optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum))\n", + "metrics = nnx.MultiMetric(\n", + " accuracy=nnx.metrics.Accuracy(), \n", + " loss=nnx.metrics.Average('loss'),\n", + ")\n", + "\n", + "nnx.display(optimizer)" ] }, { @@ -284,7 +256,7 @@ "metadata": {}, "outputs": [], "source": [ - "def loss_fn(model, batch):\n", + "def loss_fn(model: CNN, batch):\n", " logits = model(batch['image'])\n", " loss = optax.softmax_cross_entropy_with_integer_labels(\n", " logits=logits, labels=batch['label']\n", @@ -297,11 +269,11 @@ "id": "15", "metadata": {}, "source": [ - "Next, we create the training step function. This function takes the `state` and a data `batch` and does the following:\n", + "Next, we create the training step function. This function takes the `model` and a data `batch` and does the following:\n", "\n", "* Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`.\n", - "* Updates the training loss using the loss and updates the training accuracy using the logits and batch labels\n", - "* Updates model parameters and optimizer state by applying the gradient pytree to the optimizer." + "* Updates training accuracy using the loss, logits, and batch labels.\n", + "* Updates model parameters via the optimizer by applying the gradient updates." 
] }, { @@ -312,12 +284,12 @@ "outputs": [], "source": [ "@nnx.jit\n", - "def train_step(state: TrainState, batch):\n", + "def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch):\n", " \"\"\"Train for a single step.\"\"\"\n", " grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)\n", - " (loss, logits), grads = grad_fn(state.model, batch)\n", - " state.metrics.update(values=loss, logits=logits, labels=batch['label'])\n", - " state.optimizer.update(grads=grads)" + " (loss, logits), grads = grad_fn(model, batch)\n", + " metrics.update(loss=loss, logits=logits, labels=batch['label'])\n", + " optimizer.update(grads)" ] }, { @@ -328,9 +300,9 @@ "The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with \n", "[XLA](https://www.tensorflow.org/xla), optimizing performance on \n", "hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit),\n", - "except it can decorate functions that make stateful updates to NNX classes.\n", + "except it can transform functions that contain NNX objects as inputs and outputs.\n", "\n", - "## 6. Metric Computation\n", + "## 6. Evaluation step\n", "\n", "Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier."
] @@ -343,9 +315,9 @@ "outputs": [], "source": [ "@nnx.jit\n", - "def compute_test_metrics(*, state: TrainState, batch):\n", - " loss, logits = loss_fn(state.model, batch)\n", - " state.metrics.update(values=loss, logits=logits, labels=batch['label'])" + "def eval_step(model: CNN, metrics: nnx.MultiMetric, batch):\n", + " loss, logits = loss_fn(model, batch)\n", + " metrics.update(loss=loss, logits=logits, labels=batch['label'])" ] }, { @@ -376,20 +348,9 @@ "source": [ "## 8. Train and Evaluate\n", "\n", - "**Dataset Preparation:** create a \"shuffled\" dataset\n", - "- Repeat the dataset for the desired number of training epochs.\n", - "- Establish a 1024-sample buffer (holding the dataset's initial 1024 samples).\n", - " Randomly draw batches from this buffer.\n", - "- As samples are drawn, replenish the buffer with subsequent dataset samples.\n", - "\n", - "**Training Loop:** Iterate through epochs\n", - "- Sample batches randomly from the dataset.\n", - "- Execute an optimization step for each training batch.\n", - "- Calculate mean training metrics across batches within the epoch.\n", - "- With updated parameters, compute metrics on the test set.\n", - "- Log train and test metrics for visualization.\n", - "\n", - "After 10 training and testing epochs, your model should reach approximately 99% accuracy." + "Now we train a model using batches of data for 10 epochs, evaluate its performance \n", + "on the test set after each epoch, and log the training and testing metrics (loss and\n", + "accuracy) throughout the process. Typically this leads to a model with around 99% accuracy." 
] }, { @@ -400,31 +361,150 @@ "outputId": "258a2c76-2c8f-4a9e-d48b-dde57c342a87" }, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:11:51.147408: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, { "name": "stdout", "output_type": "stream", "text": [ "train epoch: 1, loss: 0.10209392756223679, accuracy: 96.92666625976562\n", - "test epoch: 1, loss: 0.05703972652554512, accuracy: 98.10697174072266\n", + "test epoch: 1, loss: 0.05703972652554512, accuracy: 98.10697174072266\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:12:16.589051: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 2, loss: 0.04372011497616768, accuracy: 98.63666534423828\n", - "test epoch: 2, loss: 0.041248343884944916, accuracy: 98.73797607421875\n", + "test epoch: 2, loss: 0.041248343884944916, accuracy: 98.73797607421875\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:12:41.074941: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 3, loss: 0.030999813228845596, accuracy: 99.0433349609375\n", - "test epoch: 3, loss: 0.05681844428181648, accuracy: 98.49759674072266\n", + "test epoch: 3, loss: 0.05681844428181648, accuracy: 98.49759674072266\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:13:06.820973: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 4, 
loss: 0.026122156530618668, accuracy: 99.25333404541016\n", - "test epoch: 4, loss: 0.04033380746841431, accuracy: 98.68789672851562\n", + "test epoch: 4, loss: 0.04033380746841431, accuracy: 98.68789672851562\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:13:32.306590: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 5, loss: 0.023744497448205948, accuracy: 99.31500244140625\n", - "test epoch: 5, loss: 0.05083772540092468, accuracy: 98.76802825927734\n", + "test epoch: 5, loss: 0.05083772540092468, accuracy: 98.76802825927734\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:13:57.767435: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 6, loss: 0.01850314810872078, accuracy: 99.45500183105469\n", - "test epoch: 6, loss: 0.04953562840819359, accuracy: 98.85816955566406\n", + "test epoch: 6, loss: 0.04953562840819359, accuracy: 98.85816955566406\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:14:23.140815: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 7, loss: 0.015862680971622467, accuracy: 99.51166534423828\n", - "test epoch: 7, loss: 0.0707646906375885, accuracy: 98.40745544433594\n", + "test epoch: 7, loss: 0.0707646906375885, accuracy: 98.40745544433594\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:14:49.081003: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with 
status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 8, loss: 0.018966104835271835, accuracy: 99.47333526611328\n", - "test epoch: 8, loss: 0.061334095895290375, accuracy: 98.89823913574219\n", + "test epoch: 8, loss: 0.061334095895290375, accuracy: 98.89823913574219\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:15:14.341633: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ "train epoch: 9, loss: 0.015244723297655582, accuracy: 99.6050033569336\n", "test epoch: 9, loss: 0.07078084349632263, accuracy: 98.78805541992188\n", "train epoch: 10, loss: 0.013812240213155746, accuracy: 99.61500549316406\n", "test epoch: 10, loss: 0.09043453633785248, accuracy: 98.818115234375\n" ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-04-25 15:15:39.885448: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n", + "2024-04-25 15:15:39.886727: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence\n" + ] } ], "source": [ @@ -442,22 +522,22 @@ " # - the train state's model parameters\n", " # - the optimizer state\n", " # - the training loss and accuracy batch metrics\n", - " train_step(state, batch)\n", + " train_step(model, optimizer, metrics, batch)\n", "\n", " if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed\n", " # Log training metrics\n", - " for metric, value in state.metrics.compute().items(): # compute metrics\n", + " for metric, value in metrics.compute().items(): # compute metrics\n", " metrics_history[f'train_{metric}'].append(value) # record metrics\n", - " state.metrics.reset() # reset metrics for test set\n", + 
" metrics.reset() # reset metrics for test set\n", "\n", " # Compute metrics on the test set after each training epoch\n", " for test_batch in test_ds.as_numpy_iterator():\n", - " compute_test_metrics(state=state, batch=test_batch)\n", + " eval_step(model, metrics, test_batch)\n", "\n", " # Log test metrics\n", - " for metric, value in state.metrics.compute().items():\n", + " for metric, value in metrics.compute().items():\n", " metrics_history[f'test_{metric}'].append(value)\n", - " state.metrics.reset() # reset metrics for next training epoch\n", + " metrics.reset() # reset metrics for next training epoch\n", "\n", " print(\n", " f\"train epoch: {(step+1) // num_steps_per_epoch}, \"\n", @@ -491,22 +571,13 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" } ], "source": [ @@ -521,8 +592,7 @@ " ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')\n", "ax1.legend()\n", "ax2.legend()\n", - "plt.show()\n", - "plt.clf()" + "plt.show()" ] }, { @@ -543,8 +613,8 @@ "outputs": [], "source": [ "@nnx.jit\n", - "def pred_step(state: TrainState, batch):\n", - " logits = state.model(batch['image'])\n", + "def pred_step(model: CNN, batch):\n", + " logits = model(batch['image'])\n", " return logits.argmax(axis=1)" ] }, @@ -558,7 +628,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -569,7 +639,7 @@ ], "source": [ "test_batch = test_ds.as_numpy_iterator().next()\n", - "pred = pred_step(state, test_batch)\n", + "pred = pred_step(model, test_batch)\n", "\n", "fig, axs = plt.subplots(5, 5, figsize=(12, 12))\n", "for i, ax in enumerate(axs.flatten()):\n", @@ -602,7 +672,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.6" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/docs/experimental/nnx/mnist_tutorial.md b/docs/experimental/nnx/mnist_tutorial.md index e3e40a3a51..e6510d2397 100644 --- a/docs/experimental/nnx/mnist_tutorial.md +++ b/docs/experimental/nnx/mnist_tutorial.md @@ -28,13 +28,14 @@ Since NNX is under active development, we recommend using the latest version fro ```{code-cell} ipython3 :tags: [skip-execution] -# TODO: Fix text descriptions in this tutorial -!pip install git+https://github.com/google/flax.git +# !pip install git+https://github.com/google/flax.git ``` ## 2. Load the MNIST Dataset -We'll use TensorFlow Datasets (TFDS) for loading and preparing the MNIST dataset: +First, the MNIST dataset is loaded and prepared for training and testing using +Tensorflow Datasets. Image values are normalized, the data is shuffled and divided +into batches, and samples are prefetched to enhance performance. ```{code-cell} ipython3 import tensorflow_datasets as tfds # TFDS for MNIST @@ -77,40 +78,28 @@ Create a convolutional neural network with NNX by subclassing `nnx.Module`. 
```{code-cell} ipython3 from flax.experimental import nnx # NNX API +from functools import partial class CNN(nnx.Module): """A simple CNN model.""" def __init__(self, *, rngs: nnx.Rngs): - self.conv1 = nnx.Conv( - in_features=1, out_features=32, kernel_size=(3, 3), rngs=rngs - ) - self.conv2 = nnx.Conv( - in_features=32, out_features=64, kernel_size=(3, 3), rngs=rngs - ) - self.linear1 = nnx.Linear(in_features=3136, out_features=256, rngs=rngs) - self.linear2 = nnx.Linear(in_features=256, out_features=10, rngs=rngs) + self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs) + self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs) + self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2)) + self.linear1 = nnx.Linear(3136, 256, rngs=rngs) + self.linear2 = nnx.Linear(256, 10, rngs=rngs) def __call__(self, x): - x = self.conv1(x) - x = nnx.relu(x) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = self.conv2(x) - x = nnx.relu(x) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = self.linear1(x) - x = nnx.relu(x) + x = self.avg_pool(nnx.relu(self.conv1(x))) + x = self.avg_pool(nnx.relu(self.conv2(x))) + x = x.reshape(x.shape[0], -1) # flatten + x = nnx.relu(self.linear1(x)) x = self.linear2(x) return x - model = CNN(rngs=nnx.Rngs(0)) - -print(f'model = {model}'[:500] + '\n...\n') # print a part of the model -print( - f'{model.conv1.kernel.value.shape = }' -) # inspect the shape of the kernel of the first convolutional layer +nnx.display(model) ``` ### Run model @@ -123,39 +112,26 @@ Let's put our model to the test! We'll perform a forward pass with arbitrary da import jax.numpy as jnp # JAX NumPy y = model(jnp.ones((1, 28, 28, 1))) -y +nnx.display(y) ``` -## 4. 
Create the `TrainState` - -In Flax, a common practice is to use a dataclass to encapsulate the entire training state, which would allow you to simply pass only two arguments (the train state and batched data) to functions like `train_step`. The training state would typically contain an [`nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/optimizer.html#flax.experimental.nnx.optimizer.Optimizer) (which contains the step number, model and optimizer state) and an `nnx.Module` (for easier access to the model from the top-level of the train state). The training state can also be easily extended to add training and test metrics, as you will see in this tutorial (see [`nnx.metrics`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/training/metrics.html#module-flax.experimental.nnx.metrics) for more detail on NNX's metric classes). - -```{code-cell} ipython3 -import dataclasses - -@dataclasses.dataclass -class TrainState(nnx.GraphNode): - optimizer: nnx.Optimizer - model: CNN - metrics: nnx.MultiMetric -``` +## 4. Create Optimizer and Metrics -We use `optax` to create an optimizer ([`adamw`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.adamw)) and initialize the `nnx.Optimizer`. We use `nnx.MultiMetric` to keep track of both the accuracy and average loss for both training and test batches. +In NNX, we create an `Optimizer` object to manage the model's parameters and apply gradients during training. `Optimizer` receives the model parameters and an `optax` optimizer that will define the update rules. Additionally, we'll define a `MultiMetric` object to keep track of the `Accuracy` and the `Average` loss. 
```{code-cell} ipython3 import optax learning_rate = 0.005 momentum = 0.9 -tx = optax.adamw(learning_rate, momentum) - -state = TrainState( - optimizer=nnx.Optimizer(model=model, tx=tx), - model=model, - metrics=nnx.MultiMetric( - accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() - ), + +optimizer = nnx.Optimizer(model, optax.adamw(learning_rate, momentum)) +metrics = nnx.MultiMetric( + accuracy=nnx.metrics.Accuracy(), + loss=nnx.metrics.Average('loss'), ) + +nnx.display(optimizer) ``` ## 5. Training step @@ -163,7 +139,7 @@ state = TrainState( We define a loss function using cross entropy loss (see more details in [`optax.softmax_cross_entropy_with_integer_labels()`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.softmax_cross_entropy_with_integer_labels)) that our model will optimize over. In addition to the loss, the logits are also outputted since they will be used to calculate the accuracy metric during training and testing. ```{code-cell} ipython3 -def loss_fn(model, batch): +def loss_fn(model: CNN, batch): logits = model(batch['image']) loss = optax.softmax_cross_entropy_with_integer_labels( logits=logits, labels=batch['label'] @@ -171,36 +147,36 @@ def loss_fn(model, batch): return loss, logits ``` -Next, we create the training step function. This function takes the `state` and a data `batch` and does the following: +Next, we create the training step function. This function takes the `model` and a data `batch` and does the following: * Computes the loss, logits and gradients with respect to the loss function using `nnx.value_and_grad`. -* Updates the training loss using the loss and updates the training accuracy using the logits and batch labels -* Updates model parameters and optimizer state by applying the gradient pytree to the optimizer. +* Updates training accuracy using the loss, logits, and batch labels. +* Updates model parameters via the optimizer by applying the gradient updates. 
```{code-cell} ipython3 @nnx.jit -def train_step(state: TrainState, batch): +def train_step(model: CNN, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch): """Train for a single step.""" grad_fn = nnx.value_and_grad(loss_fn, has_aux=True) - (loss, logits), grads = grad_fn(state.model, batch) - state.metrics.update(values=loss, logits=logits, labels=batch['label']) - state.optimizer.update(grads=grads) + (loss, logits), grads = grad_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) + optimizer.update(grads) ``` The [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.experimental.nnx/transforms.html#flax.experimental.nnx.jit) decorator traces the `train_step` function for just-in-time compilation with [XLA](https://www.tensorflow.org/xla), optimizing performance on hardware accelerators. `nnx.jit` is similar to [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit), -except it can decorate functions that make stateful updates to NNX classes. +except it can transform functions that contain NNX objects as inputs and outputs. -## 6. Metric Computation +## 6. Evaluation step Create a separate function to calculate loss and accuracy metrics for the test batch, since this will be outside the `train_step` function. Loss is determined using the `optax.softmax_cross_entropy_with_integer_labels` function, since we're reusing the loss function defined earlier. ```{code-cell} ipython3 @nnx.jit -def compute_test_metrics(*, state: TrainState, batch): - loss, logits = loss_fn(state.model, batch) - state.metrics.update(values=loss, logits=logits, labels=batch['label']) +def eval_step(model: CNN, metrics: nnx.MultiMetric, batch): + loss, logits = loss_fn(model, batch) + metrics.update(loss=loss, logits=logits, labels=batch['label']) ``` ## 7. Seed randomness @@ -213,20 +189,9 @@ tf.random.set_seed(0) ## 8.
Train and Evaluate -**Dataset Preparation:** create a "shuffled" dataset -- Repeat the dataset for the desired number of training epochs. -- Establish a 1024-sample buffer (holding the dataset's initial 1024 samples). - Randomly draw batches from this buffer. -- As samples are drawn, replenish the buffer with subsequent dataset samples. - -**Training Loop:** Iterate through epochs -- Sample batches randomly from the dataset. -- Execute an optimization step for each training batch. -- Calculate mean training metrics across batches within the epoch. -- With updated parameters, compute metrics on the test set. -- Log train and test metrics for visualization. - -After 10 training and testing epochs, your model should reach approximately 99% accuracy. +Now we train a model using batches of data for 10 epochs, evaluate its performance +on the test set after each epoch, and log the training and testing metrics (loss and +accuracy) throughout the process. Typically this leads to a model with around 99% accuracy. 
```{code-cell} ipython3 :outputId: 258a2c76-2c8f-4a9e-d48b-dde57c342a87 @@ -245,22 +210,22 @@ for step, batch in enumerate(train_ds.as_numpy_iterator()): # - the train state's model parameters # - the optimizer state # - the training loss and accuracy batch metrics - train_step(state, batch) + train_step(model, optimizer, metrics, batch) if (step + 1) % num_steps_per_epoch == 0: # one training epoch has passed # Log training metrics - for metric, value in state.metrics.compute().items(): # compute metrics + for metric, value in metrics.compute().items(): # compute metrics metrics_history[f'train_{metric}'].append(value) # record metrics - state.metrics.reset() # reset metrics for test set + metrics.reset() # reset metrics for test set # Compute metrics on the test set after each training epoch for test_batch in test_ds.as_numpy_iterator(): - compute_test_metrics(state=state, batch=test_batch) + eval_step(model, metrics, test_batch) # Log test metrics - for metric, value in state.metrics.compute().items(): + for metric, value in metrics.compute().items(): metrics_history[f'test_{metric}'].append(value) - state.metrics.reset() # reset metrics for next training epoch + metrics.reset() # reset metrics for next training epoch print( f"train epoch: {(step+1) // num_steps_per_epoch}, " @@ -293,7 +258,6 @@ for dataset in ('train', 'test'): ax1.legend() ax2.legend() plt.show() -plt.clf() ``` ## 10. 
Perform inference on test set @@ -302,8 +266,8 @@ Define a jitted inference function, `pred_step`, to generate predictions on the ```{code-cell} ipython3 @nnx.jit -def pred_step(state: TrainState, batch): - logits = state.model(batch['image']) +def pred_step(model: CNN, batch): + logits = model(batch['image']) return logits.argmax(axis=1) ``` @@ -311,7 +275,7 @@ def pred_step(state: TrainState, batch): :outputId: 1db5a01c-9d70-4f7d-8c0d-0a3ad8252d3e test_batch = test_ds.as_numpy_iterator().next() -pred = pred_step(state, test_batch) +pred = pred_step(model, test_batch) fig, axs = plt.subplots(5, 5, figsize=(12, 12)) for i, ax in enumerate(axs.flatten()): diff --git a/docs/experimental/nnx/nnx_basics.ipynb b/docs/experimental/nnx/nnx_basics.ipynb index d7828ebe00..3a90e0d836 100644 --- a/docs/experimental/nnx/nnx_basics.ipynb +++ b/docs/experimental/nnx/nnx_basics.ipynb @@ -6,15 +6,15 @@ "source": [ "# NNX Basics\n", "\n", - "NNX is a **N**eural **N**etworks JA**X** library that embraces Python’s object-oriented \n", - "programming model to provide an intuitive and highly simplified user experience. It\n", - "represents objects as PyGraphs (instead of PyTrees), which allows NNX to handle reference\n", - "sharing and mutability, making model code be regular Python code that users from frameworks\n", - "like Pytorch will be familiar with.be familiar with.\n", - "\n", - "NNX is also designed to support \n", - "all the patterns that allowed Linen to scale to large code bases while having a much simpler\n", - "implementation." + "NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best \n", + "development experience, so building and experimenting with neural networks is easy and\n", + "intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), \n", + "enabling reference sharing and mutability. 
This design allows your models to resemble \n", + "familiar Python object-oriented code, particularly appealing to users of frameworks\n", + "like PyTorch.\n", + "\n", + "Despite its simplified implementation, NNX supports the same powerful design patterns \n", + "that have allowed Linen to scale effectively to large codebases." ] }, { @@ -33,11 +33,18 @@ "metadata": {}, "source": [ "## The Module System\n", - "To begin lets see how to create a `Linear` Module using NNX. The main noticeable\n", - "difference between NNX and Module systems like Haiku or Linen is that in NNX everything is\n", - "**explicit**. This means among other things that 1) the Module itself holds the state\n", - "(e.g. parameters) directly, 2) the RNG state is threaded by the user, and 3) all shape information\n", - "must be provided on initialization (no shape inference)." + "To begin lets see how to create a `Linear` Module using NNX. The main difference between \n", + "NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**. This \n", + "means among other things that 1) the Module itself holds the state (e.g. parameters) directly, \n", + "2) the RNG state is threaded by the user, and 3) all shape information must be provided on \n", + "initialization (no shape inference).\n", + "\n", + "As shown next, dynamic state is usually stored in `nnx.Param`s, and static state \n", + "(all types not handled by NNX) such as integers or strings are stored directly. \n", + "Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic \n", + "state, although storing them inside `nnx.Variable`s such as `Param` is preferred.\n", + "Also, the `nnx.Rngs` object can be used to get new unique keys based on a root \n", + "key passed to the constructor." 
] }, { @@ -54,22 +61,23 @@ " self.din, self.dout = din, dout\n", "\n", " def __call__(self, x: jax.Array):\n", - " return x @ self.w.value + self.b.value" + " return x @ self.w + self.b" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "As shown above dynamic state is usually stored in `nnx.Param`s,\n", - "and static state (all types not handled by NNX) such as integers or strings \n", - "are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic state,\n", - "although storing them inside `nnx.Variable`s is preferred. Also, the `nnx.Rngs` object by can be used to\n", - "get new unique keys based on a root key passed to the constructor (see below).\n", - "\n", - "To actually initialize a Module is very easy: simply call the constructor. All the\n", - "parameters of a Module will be created right then and there, and are immediately available\n", - "for inspection using regular Python attribute access." + "`nnx.Variable`'s inner values can be accessed using the `.value` property, however\n", + "for convenience they implement all numeric operators and can be used directly in\n", + "arithmetic expressions (as shown above). Additionally, Variables can be passed\n", + "to any JAX function as they implement the `__jax_array__` protocol (as long as their\n", + "inner value is a JAX array).\n", + "\n", + "To actually initialize a Module you simply call the constructor, all the parameters \n", + "of a Module are usually created eagerly. Since Modules hold their own state methods \n", + "can be called directly without the need for a separate `apply` method, this is very \n", + "convenient for debugging as the entire structure of the model can be inspected directly." 
] }, { @@ -81,60 +89,35 @@ "name": "stdout", "output_type": "stream", "text": [ - "model = Linear(\n", - " din=2,\n", - " dout=3\n", - ")\n", - "model.w.value = Array([[0.9913868 , 0.45571804, 0.7215481 ],\n", - " [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n", - "model.b.value = Array([0., 0., 0.], dtype=float32)\n" + "[[1.245453 0.74195766 0.8553282 0.6763327 1.2617068 ]]\n" ] - } - ], - "source": [ - "model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0))\n", - "\n", - "print(f'{model = }')\n", - "print(f'{model.w.value = }')\n", - "print(f'{model.b.value = }')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "This is very handy for debugging as it allows accessing the entire structure or\n", - "modifying it. Similarly, computations can be ran directly." - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ + }, { "data": { + "text/html": [ + "
(Loading...)
" + ], "text/plain": [ - "Array([[1.878783 , 0.65652764, 1.4469249 ]], dtype=float32)" + "" ] }, - "execution_count": 4, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "x = jnp.ones((1, 2))\n", + "model = Linear(2, 5, rngs=nnx.Rngs(params=0))\n", + "y = model(x=jnp.ones((1, 2)))\n", "\n", - "model(x)" + "print(y)\n", + "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Since Modules hold their own state there is no need for a separate `apply` method, as in\n", - "Linen or Haiku." + "The above visualization by `nnx.display` is generated using the awesome [Penzai](https://penzai.readthedocs.io/en/stable/index.html#) library." ] }, { @@ -143,30 +126,31 @@ "source": [ "### Stateful Computation\n", "\n", - "When implementing layers like Batch Normalization or Multi Head Attention with \n", - "autoregressive decoding you often need to store and update state inside a Module \n", - "during the forward pass. The way to do this in NNX is simply to store the state \n", - "inside a `Variable` and update it in-place when need it." + "Implementing layers such as `BatchNorm` requires performing state updates during the \n", + "forward pass. To implement this in NNX you just create a `Variable` and update its \n", + "`.value` during the forward pass." 
] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "counter.count.value = 0\n", - "counter.count.value = 1\n" + "counter.count.value = Array(0, dtype=int32, weak_type=True)\n", + "counter.count.value = Array(1, dtype=int32, weak_type=True)\n" ] } ], "source": [ + "class Count(nnx.Variable): pass\n", + "\n", "class Counter(nnx.Module):\n", " def __init__(self):\n", - " self.count = nnx.Variable(0)\n", + " self.count = Count(jnp.array(0))\n", "\n", " def __call__(self):\n", " self.count.value += 1\n", @@ -181,11 +165,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "JAX frameworks have avoided mutable references until now. The key innovations which \n", - "allows their usage in NNX is that 1) there is a clear boundary between code that uses \n", - "reference semantics and code that uses value semantics, defined by \n", - "[The Functional API](#the-functional-api), and 2) there are guards in place to avoid \n", - "updating NNX objects from a `MainTrace`, thus preventing tracer leakage." + "Mutable references are usually avoided in JAX, however as we'll see in later sections\n", + "NNX provides sound mechanisms to handle them." ] }, { @@ -194,70 +175,133 @@ "source": [ "### Nested Modules\n", "\n", - "As expected, Modules can be used to compose other Modules in a nested\n", - "structure, including standard Modules such as `nnx.Linear`,\n", - "`nnx.Conv`, etc., or any custom Module created by users. Modules can\n", - "be assigned as attributes of a Module, but as shown by `MLP.blocks` in the\n", - "example below, they can also be stored in attributes of type `list`, `dict`, `tuple`, \n", - "or in nested structures of the same." 
+ "As expected, Modules can be used to compose other Modules in a nested structure, these can \n", + "be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g.\n", + " `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that\n", + "consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "model = MLP(\n", - " blocks=[Block(\n", - " linear=Linear(\n", - " in_features=2,\n", - " out_features=2,\n", - " use_bias=True,\n", - " dtype=None,\n", - " param_dtype=,\n", - " precision=None,\n", - " kernel_init=.init at 0x13cfa4040>,\n", - " bias_init=,\n", - " dot_general=\n", - " ),\n", - " bn=BatchNorm(\n", - " num_features=2,\n", - " ...\n" - ] + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "class Block(nnx.Module):\n", - " def __init__(self, dim: int, *, rngs: nnx.Rngs):\n", - " self.linear = nnx.Linear(dim, dim, rngs=rngs)\n", - " self.bn = nnx.BatchNorm(dim, use_running_average=True, rngs=rngs)\n", + "class MLP(nnx.Module):\n", + " def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs):\n", + " self.linear1 = Linear(din, dmid, rngs=rngs)\n", + " self.dropout = nnx.Dropout(rate=0.1, rngs=rngs)\n", + " self.bn = nnx.BatchNorm(dmid, rngs=rngs)\n", + " self.linear2 = Linear(dmid, dout, rngs=rngs)\n", "\n", " def __call__(self, x: jax.Array):\n", - " return nnx.relu(self.bn(self.linear(x)))\n", - " \n", - "class MLP(nnx.Module):\n", - " def __init__(self, num_layers: int, dim: int, *, rngs: nnx.Rngs):\n", - " self.blocks = [Block(dim, rngs=rngs) for _ in range(num_layers)]\n", + " x = nnx.gelu(self.dropout(self.bn(self.linear1(x))))\n", + " return self.linear2(x)\n", " \n", + "model = MLP(2, 16, 5, rngs=nnx.Rngs(0))\n", + "\n", + "y = model(x=jnp.ones((3, 2)))\n", + "\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In NNX `Dropout` is a stateful module that stores an `Rngs` object so that it can generate\n", + "new masks during the forward pass without the need for the user to pass a new key each time." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Model Surgery\n", + "NNX Modules are mutable by default, this means their structure can be changed at any time, \n", + "this makes model surgery quite easy as any submodule attribute can be replaced with anything\n", + "else e.g. new Modules, existing shared Modules, Modules of different types, etc. 
Moreover, \n", + "`Variable`s can also be modified or replaced / shared.\n", + "\n", + "The following example shows how to replace the `Linear` layers in the `MLP` model\n", + "from before with `LoraLinear` layers." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "class LoraParam(nnx.Param): pass\n", + "\n", + "class LoraLinear(nnx.Module):\n", + " def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs):\n", + " self.linear = linear\n", + " self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank)))\n", + " self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout)))\n", + "\n", " def __call__(self, x: jax.Array):\n", - " for block in self.blocks:\n", - " x = block(x)\n", - " return x\n", - " \n", - "model = MLP(num_layers=5, dim=2, rngs=nnx.Rngs(0))\n", - "print(f'{model = }'[:500] + '...')" + " return self.linear(x) + x @ self.A @ self.B\n", + "\n", + "rngs = nnx.Rngs(0)\n", + "model = MLP(2, 32, 5, rngs=rngs)\n", + "\n", + "# model surgery\n", + "model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs)\n", + "model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs)\n", + "\n", + "y = model(x=jnp.ones((3, 2)))\n", + "\n", + "nnx.display(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "One of the benefits of NNX is that nested Modules as easy to inspect and\n", - "static analyzers, e.g., code completion, can help you while doing so." + "## NNX Transforms\n", + "\n", + "NNX Transforms extend JAX transforms to support Modules and other objects.\n", + "They are supersets of their equivalent JAX counterpart with the addition of\n", + "being aware of the object's state and providing additional APIs to transform \n", + "it. One of the main features of NNX Transforms is the preservation of reference semantics, \n", + "meaning that any mutation of the object graph that occurs inside the transform is\n", + "propagated outisde as long as its legal within the transform rules. 
In practice this\n", + "means that NNX programs can be expressed using imperative code, highly simplifying\n", + "the user experience.\n", + "\n", + "In the following example we define a `train_step` function that takes a `MLP` model,\n", + "an `Optimizer`, and a batch of data, and returns the loss for that step. The loss\n", + "and the gradients are computed using the `nnx.value_and_grad` transform over the\n", + "`loss_fn`. The gradients are passed to the optimizer's `update` method to update\n", + "the `model`'s parameters." ] }, { @@ -269,29 +313,59 @@ "name": "stdout", "output_type": "stream", "text": [ - "model.blocks[1].linear.kernel.value = Array([[-0.31410056, -0.9153769 ],\n", - " [-0.38879898, -0.12699318]], dtype=float32)\n", - "model.blocks[0].bn.scale.value = Array([1., 1.], dtype=float32)\n" + "loss = Array(1.0000279, dtype=float32)\n", + "optimizer.step.value = Array(1, dtype=uint32)\n" ] } ], "source": [ - "print(f'{model.blocks[1].linear.kernel.value = }')\n", - "print(f'{model.blocks[0].bn.scale.value = }')" + "import optax\n", + "\n", + "# MLP contains 2 Linear layers, 1 Dropout layer, 1 BatchNorm layer\n", + "model = MLP(2, 16, 10, rngs=nnx.Rngs(0))\n", + "optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing\n", + "\n", + "@nnx.jit # automatic state management\n", + "def train_step(model, optimizer, x, y):\n", + " def loss_fn(model: MLP):\n", + " y_pred = model(x)\n", + " return jnp.mean((y_pred - y) ** 2)\n", + "\n", + " loss, grads = nnx.value_and_grad(loss_fn)(model)\n", + " optimizer.update(grads) # inplace updates\n", + "\n", + " return loss\n", + "\n", + "x, y = jnp.ones((5, 2)), jnp.ones((5, 10))\n", + "loss = train_step(model, optimizer, x, y)\n", + "\n", + "print(f'{loss = }')\n", + "print(f'{optimizer.step.value = }')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Model Surgery\n", - "NNX Modules are mutable by default, this means their structure can be changed\n", - "at any time. 
Also, NNX's Module system supports reference sharing of Modules and\n", - "Variables.\n", + "There are a couple of things happening in this example that are worth mentioning:\n", + "1. The updates to the `BatchNorm` and `Dropout` layers' state are automatically propagated\n", + " from within `loss_fn` to `train_step` all the way to the `model` reference outside.\n", + "2. `optimizer` holds a mutable reference to `model`, this relationship is preserved\n", + " inside the `train_step` function making it possible to update the model's parameters\n", + " using the optimizer alone.\n", "\n", - "This makes Model Surgery quite easy as any submodule could be replaced by\n", - "e.g., a pretrained Module, a shared Module, or even just a Module/function that\n", - "uses the same signature. More over, Variables can also be modified or shared." + "#### Scan over layers\n", + "Next lets take a look at a different example using `nnx.vmap` to create an\n", + "`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the\n", + "input (scan over layers). \n", + "\n", + "Notice the following:\n", + "1. The `create_model` function creates a (single) `MLP` object that is lifted by\n", + " `nnx.vmap` to have an additional dimension of size `axis_size`.\n", + "2. The `forward` function indexes the `MLP` object's state to get a different set of\n", + " parameters at each step.\n", + "3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and \n", + "`Dropout` layers from within `forward` to the `model` reference outside." ] }, { @@ -299,31 +373,53 @@ "execution_count": 8, "metadata": {}, "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "y.shape = (3, 10)\n" + ] + }, { "data": { + "text/html": [ + "
(Loading...)
" + ], "text/plain": [ - "Array([[0., 0.]], dtype=float32)" + "" ] }, - "execution_count": 8, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "# Module replacement\n", - "pretrained = Block(dim=2, rngs=nnx.Rngs(42)) # imagine this is pretrained\n", - "model.blocks[0] = pretrained\n", - "# adhoc Module sharing\n", - "model.blocks[3] = model.blocks[1]\n", - "# monkey patching\n", - "def awesome_layer(x): return x\n", - "model.blocks[2] = awesome_layer\n", - "\n", - "# Variable sharing (weight tying)\n", - "model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel\n", - "\n", - "model(jnp.ones((1, 2)))" + "from functools import partial\n", + "\n", + "@partial(nnx.vmap, axis_size=5)\n", + "def create_model(rngs: nnx.Rngs):\n", + " return MLP(10, 32, 10, rngs=rngs)\n", + "\n", + "model = create_model(nnx.Rngs(0))\n", + "\n", + "@nnx.scan\n", + "def forward(x, model: MLP):\n", + " x = model(x)\n", + " return x, None\n", + "\n", + "x = jnp.ones((3, 10))\n", + "y, _ = forward(x, model)\n", + "\n", + "print(f'{y.shape = }')\n", + "nnx.display(model)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "How do NNX transforms achieve this? To understand how NNX objects interact with\n", + "JAX transforms lets take a look at the Functional API." ] }, { @@ -347,7 +443,20 @@ "cell_type": "code", "execution_count": 9, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "class Count(nnx.Variable): pass\n", "\n", @@ -361,7 +470,8 @@ " self.count.value += 1\n", " return x @ self.w.value + self.b.value\n", " \n", - "model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0))" + "model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0))\n", + "nnx.display(model)" ] }, { @@ -371,7 +481,7 @@ "### State and GraphDef\n", "\n", "A Module can be decomposed into `GraphDef` and `State` using the\n", - "`.split()` method. State is a Mapping from strings to Variables or nested \n", + "`split` function. State is a Mapping from strings to Variables or nested \n", "States. GraphDef contains all the static information needed to reconstruct \n", "a Module graph, it is analogous to JAX's `PyTreeDef`." ] @@ -382,33 +492,34 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "state = State({\n", - " 'b': Param(\n", - " raw_value=Array([0., 0., 0.], dtype=float32)\n", - " ),\n", - " 'count': Count(\n", - " raw_value=0\n", - " ),\n", - " 'w': Param(\n", - " raw_value=Array([[0.9913868 , 0.45571804, 0.7215481 ],\n", - " [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n", - " )\n", - "})\n", - "\n", - "graphdef = GraphDef(nodedef=NodeDef(type=, index=0, attributes=('b', 'count', 'w'), subgraphs={}, static_fields={}, variables={'b': VariableDef(\n", - " type=Param,\n", - " index=...\n" - ] + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ - "graphdef, state = model.split()\n", + "graphdef, state = nnx.split(model)\n", "\n", - "print(f'{state = }\\n')\n", - "print(f'{graphdef = }'[:200] + '...')" + "nnx.display(graphdef, state)" ] }, { @@ -420,8 +531,8 @@ "`merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs\n", "the Module. As shown in the example below, by using `split` and `merge` in sequence\n", "any Module can be lifted to be used in any JAX transform. `update` can\n", - "update a Module structure from a compatible State. This is often used to propagate the state\n", - "updates from a transform back to the source object outside." + "update an object inplace with the content of a given State. This pattern is used to \n", + "propagate the state from a transform back to the source object outside." ] }, { @@ -433,32 +544,30 @@ "name": "stdout", "output_type": "stream", "text": [ - "model.count = Count(\n", - " raw_value=0\n", - ")\n", + "model.count.value = 0\n", "model.count.value = Array(1, dtype=int32, weak_type=True)\n" ] } ], "source": [ - "print(f'{model.count = }')\n", + "print(f'{model.count.value = }')\n", "\n", "# 1. Use split to create a pytree representation of the Module\n", - "graphdef, state = model.split()\n", + "graphdef, state = nnx.split(model)\n", "\n", "@jax.jit\n", "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]:\n", " # 2. Use merge to create a new model inside the JAX transformation\n", - " model = graphdef.merge(state)\n", + " model = nnx.merge(graphdef, state)\n", " # 3. Call the Module\n", " y = model(x)\n", " # 4. Use split to propagate State updates\n", - " _, state = model.split()\n", + " _, state = nnx.split(model)\n", " return y, state\n", "\n", - "y, state = forward(graphdef, state, x=jnp.ones((1, 2)))\n", + "y, state = forward(graphdef, state, x=jnp.ones((1, 3)))\n", "# 5. 
Update the state of the original Module\n", - "model.update(state)\n", + "nnx.update(model, state)\n", "\n", "print(f'{model.count.value = }')" ] @@ -503,40 +612,42 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "params = State({\n", - " 'b': Param(\n", - " raw_value=Array([0., 0., 0.], dtype=float32)\n", - " ),\n", - " 'w': Param(\n", - " raw_value=Array([[0.9913868 , 0.45571804, 0.7215481 ],\n", - " [0.8873962 , 0.2008096 , 0.72537684]], dtype=float32)\n", - " )\n", - "})\n", - "\n", - "counts = State({\n", - " 'count': Count(\n", - " raw_value=Array(1, dtype=int32, weak_type=True)\n", - " )\n", - "})\n" - ] + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
(Loading...)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ "# use Variable type filters to split into multiple States\n", - "graphdef, params, counts = model.split(nnx.Param, Count)\n", + "graphdef, params, counts = nnx.split(model, nnx.Param, Count)\n", "\n", - "print(f'{params = }\\n')\n", - "print(f'{counts = }')" + "nnx.display(params, counts)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "**Note**: filters must be exhaustive, if a Variable is not matched an error will be raised.\n", + "Note that filters must be exhaustive, if a value is not matched an error will be raised.\n", "\n", "As expected the `merge` and `update` methods naturally consume multiple States:" ] @@ -548,9 +659,9 @@ "outputs": [], "source": [ "# merge multiple States\n", - "model = graphdef.merge(params, counts)\n", + "model = nnx.merge(graphdef, params, counts)\n", "# update with multiple States\n", - "model.update(params, counts)" + "nnx.update(model, params, counts)" ] } ], @@ -568,7 +679,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.18" + "version": "3.10.13" } }, "nbformat": 4, diff --git a/docs/experimental/nnx/nnx_basics.md b/docs/experimental/nnx/nnx_basics.md index 1eed38744a..c27ae068ac 100644 --- a/docs/experimental/nnx/nnx_basics.md +++ b/docs/experimental/nnx/nnx_basics.md @@ -10,15 +10,15 @@ jupytext: # NNX Basics -NNX is a **N**eural **N**etworks JA**X** library that embraces Python’s object-oriented -programming model to provide an intuitive and highly simplified user experience. It -represents objects as PyGraphs (instead of PyTrees), which allows NNX to handle reference -sharing and mutability, making model code be regular Python code that users from frameworks -like Pytorch will be familiar with.be familiar with. 
+NNX is a **N**eural **N**etwork library for JA**X** that focuses on providing the best +development experience, so building and experimenting with neural networks is easy and +intuitive. It achieves this by representing objects as PyGraphs (instead of PyTrees), +enabling reference sharing and mutability. This design allows your models to resemble +familiar Python object-oriented code, particularly appealing to users of frameworks +like PyTorch. -NNX is also designed to support -all the patterns that allowed Linen to scale to large code bases while having a much simpler -implementation. +Despite its simplified implementation, NNX supports the same powerful design patterns +that have allowed Linen to scale effectively to large codebases. ```{code-cell} ipython3 from flax.experimental import nnx @@ -27,11 +27,18 @@ import jax.numpy as jnp ``` ## The Module System -To begin lets see how to create a `Linear` Module using NNX. The main noticeable -difference between NNX and Module systems like Haiku or Linen is that in NNX everything is -**explicit**. This means among other things that 1) the Module itself holds the state -(e.g. parameters) directly, 2) the RNG state is threaded by the user, and 3) all shape information -must be provided on initialization (no shape inference). +To begin lets see how to create a `Linear` Module using NNX. The main difference between +NNX and Module systems like Haiku or Linen is that in NNX everything is **explicit**. This +means among other things that 1) the Module itself holds the state (e.g. parameters) directly, +2) the RNG state is threaded by the user, and 3) all shape information must be provided on +initialization (no shape inference). + +As shown next, dynamic state is usually stored in `nnx.Param`s, and static state +(all types not handled by NNX) such as integers or strings are stored directly. 
+Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic +state, although storing them inside `nnx.Variable`s such as `Param` is preferred. +Also, the `nnx.Rngs` object can be used to get new unique keys based on a root +key passed to the constructor. ```{code-cell} ipython3 class Linear(nnx.Module): @@ -42,52 +49,44 @@ class Linear(nnx.Module): self.din, self.dout = din, dout def __call__(self, x: jax.Array): - return x @ self.w.value + self.b.value + return x @ self.w + self.b ``` -As shown above dynamic state is usually stored in `nnx.Param`s, -and static state (all types not handled by NNX) such as integers or strings -are stored directly. Attributes of type `jax.Array` and `numpy.ndarray` are also treated as dynamic state, -although storing them inside `nnx.Variable`s is preferred. Also, the `nnx.Rngs` object by can be used to -get new unique keys based on a root key passed to the constructor (see below). - -To actually initialize a Module is very easy: simply call the constructor. All the -parameters of a Module will be created right then and there, and are immediately available -for inspection using regular Python attribute access. - -```{code-cell} ipython3 -model = Linear(din=2, dout=3, rngs=nnx.Rngs(params=0)) - -print(f'{model = }') -print(f'{model.w.value = }') -print(f'{model.b.value = }') -``` +`nnx.Variable`'s inner values can be accessed using the `.value` property, however +for convenience they implement all numeric operators and can be used directly in +arithmetic expressions (as shown above). Additionally, Variables can be passed +to any JAX function as they implement the `__jax_array__` protocol (as long as their +inner value is a JAX array). -This is very handy for debugging as it allows accessing the entire structure or -modifying it. Similarly, computations can be ran directly. +To actually initialize a Module you simply call the constructor, all the parameters +of a Module are usually created eagerly. 
Since Modules hold their own state methods +can be called directly without the need for a separate `apply` method, this is very +convenient for debugging as the entire structure of the model can be inspected directly. ```{code-cell} ipython3 -x = jnp.ones((1, 2)) +model = Linear(2, 5, rngs=nnx.Rngs(params=0)) +y = model(x=jnp.ones((1, 2))) -model(x) +print(y) +nnx.display(model) ``` -Since Modules hold their own state there is no need for a separate `apply` method, as in -Linen or Haiku. +The above visualization by `nnx.display` is generated using the awesome [Penzai](https://penzai.readthedocs.io/en/stable/index.html#) library. +++ ### Stateful Computation -When implementing layers like Batch Normalization or Multi Head Attention with -autoregressive decoding you often need to store and update state inside a Module -during the forward pass. The way to do this in NNX is simply to store the state -inside a `Variable` and update it in-place when need it. +Implementing layers such as `BatchNorm` requires performing state updates during the +forward pass. To implement this in NNX you just create a `Variable` and update its +`.value` during the forward pass. ```{code-cell} ipython3 +class Count(nnx.Variable): pass + class Counter(nnx.Module): def __init__(self): - self.count = nnx.Variable(0) + self.count = Count(jnp.array(0)) def __call__(self): self.count.value += 1 @@ -98,78 +97,163 @@ counter() print(f'{counter.count.value = }') ``` -JAX frameworks have avoided mutable references until now. The key innovations which -allows their usage in NNX is that 1) there is a clear boundary between code that uses -reference semantics and code that uses value semantics, defined by -[The Functional API](#the-functional-api), and 2) there are guards in place to avoid -updating NNX objects from a `MainTrace`, thus preventing tracer leakage. +Mutable references are usually avoided in JAX, however as we'll see in later sections +NNX provides sound mechanisms to handle them. 
+++ ### Nested Modules -As expected, Modules can be used to compose other Modules in a nested -structure, including standard Modules such as `nnx.Linear`, -`nnx.Conv`, etc., or any custom Module created by users. Modules can -be assigned as attributes of a Module, but as shown by `MLP.blocks` in the -example below, they can also be stored in attributes of type `list`, `dict`, `tuple`, -or in nested structures of the same. +As expected, Modules can be used to compose other Modules in a nested structure, these can +be assigned directly as attributes, or inside an attribute of any (nested) pytree type e.g. + `list`, `dict`, `tuple`, etc. In the example below we define a simple `MLP` Module that +consists of two `Linear` layers, a `Dropout` layer, and a `BatchNorm` layer. ```{code-cell} ipython3 -class Block(nnx.Module): - def __init__(self, dim: int, *, rngs: nnx.Rngs): - self.linear = nnx.Linear(dim, dim, rngs=rngs) - self.bn = nnx.BatchNorm(dim, use_running_average=True, rngs=rngs) - - def __call__(self, x: jax.Array): - return nnx.relu(self.bn(self.linear(x))) - class MLP(nnx.Module): - def __init__(self, num_layers: int, dim: int, *, rngs: nnx.Rngs): - self.blocks = [Block(dim, rngs=rngs) for _ in range(num_layers)] - + def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): + self.linear1 = Linear(din, dmid, rngs=rngs) + self.dropout = nnx.Dropout(rate=0.1, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.linear2 = Linear(dmid, dout, rngs=rngs) + def __call__(self, x: jax.Array): - for block in self.blocks: - x = block(x) - return x + x = nnx.gelu(self.dropout(self.bn(self.linear1(x)))) + return self.linear2(x) -model = MLP(num_layers=5, dim=2, rngs=nnx.Rngs(0)) -print(f'{model = }'[:500] + '...') +model = MLP(2, 16, 5, rngs=nnx.Rngs(0)) + +y = model(x=jnp.ones((3, 2))) + +nnx.display(model) ``` -One of the benefits of NNX is that nested Modules as easy to inspect and -static analyzers, e.g., code completion, can help you while doing 
so. +In NNX `Dropout` is a stateful module that stores an `Rngs` object so that it can generate +new masks during the forward pass without the need for the user to pass a new key each time. + ++++ + +#### Model Surgery +NNX Modules are mutable by default, this means their structure can be changed at any time, +this makes model surgery quite easy as any submodule attribute can be replaced with anything +else e.g. new Modules, existing shared Modules, Modules of different types, etc. Moreover, +`Variable`s can also be modified or replaced / shared. + +The following example shows how to replace the `Linear` layers in the `MLP` model +from before with `LoraLinear` layers. ```{code-cell} ipython3 -print(f'{model.blocks[1].linear.kernel.value = }') -print(f'{model.blocks[0].bn.scale.value = }') +class LoraParam(nnx.Param): pass + +class LoraLinear(nnx.Module): + def __init__(self, linear: Linear, rank: int, rngs: nnx.Rngs): + self.linear = linear + self.A = LoraParam(jax.random.normal(rngs(), (linear.din, rank))) + self.B = LoraParam(jax.random.normal(rngs(), (rank, linear.dout))) + + def __call__(self, x: jax.Array): + return self.linear(x) + x @ self.A @ self.B + +rngs = nnx.Rngs(0) +model = MLP(2, 32, 5, rngs=rngs) + +# model surgery +model.linear1 = LoraLinear(model.linear1, 4, rngs=rngs) +model.linear2 = LoraLinear(model.linear2, 4, rngs=rngs) + +y = model(x=jnp.ones((3, 2))) + +nnx.display(model) ``` -#### Model Surgery -NNX Modules are mutable by default, this means their structure can be changed -at any time. Also, NNX's Module system supports reference sharing of Modules and -Variables. +## NNX Transforms + +NNX Transforms extend JAX transforms to support Modules and other objects. +They are supersets of their equivalent JAX counterpart with the addition of +being aware of the object's state and providing additional APIs to transform +it. 
One of the main features of NNX Transforms is the preservation of reference semantics, +meaning that any mutation of the object graph that occurs inside the transform is +propagated outside as long as it is legal within the transform rules. In practice this +means that NNX programs can be expressed using imperative code, highly simplifying +the user experience. + +In the following example we define a `train_step` function that takes a `MLP` model, +an `Optimizer`, and a batch of data, and returns the loss for that step. The loss +and the gradients are computed using the `nnx.value_and_grad` transform over the +`loss_fn`. The gradients are passed to the optimizer's `update` method to update +the `model`'s parameters. + +```{code-cell} ipython3 +import optax + +# MLP contains 2 Linear layers, 1 Dropout layer, 1 BatchNorm layer +model = MLP(2, 16, 10, rngs=nnx.Rngs(0)) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing + +@nnx.jit # automatic state management +def train_step(model, optimizer, x, y): + def loss_fn(model: MLP): + y_pred = model(x) + return jnp.mean((y_pred - y) ** 2) + + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) # inplace updates + + return loss -This makes Model Surgery quite easy as any submodule could be replaced by -e.g., a pretrained Module, a shared Module, or even just a Module/function that -uses the same signature. More over, Variables can also be modified or shared. +x, y = jnp.ones((5, 2)), jnp.ones((5, 10)) +loss = train_step(model, optimizer, x, y) + +print(f'{loss = }') +print(f'{optimizer.step.value = }') +``` + +There are a couple of things happening in this example that are worth mentioning: +1. The updates to the `BatchNorm` and `Dropout` layers' state are automatically propagated + from within `loss_fn` to `train_step` all the way to the `model` reference outside. +2. 
`optimizer` holds a mutable reference to `model`; this relationship is preserved + inside the `train_step` function making it possible to update the model's parameters + using the optimizer alone. + +#### Scan over layers +Next, let's take a look at a different example using `nnx.vmap` to create an +`MLP` stack and `nnx.scan` to iteratively apply each layer in the stack to the +input (scan over layers). + +Notice the following: +1. The `create_model` function creates a (single) `MLP` object that is lifted by + `nnx.vmap` to have an additional dimension of size `axis_size`. +2. The `forward` function indexes the `MLP` object's state to get a different set of + parameters at each step. +3. `nnx.scan` automatically propagates the state updates for the `BatchNorm` and +`Dropout` layers from within `forward` to the `model` reference outside. ```{code-cell} ipython3 -# Module replacement -pretrained = Block(dim=2, rngs=nnx.Rngs(42)) # imagine this is pretrained -model.blocks[0] = pretrained -# adhoc Module sharing -model.blocks[3] = model.blocks[1] -# monkey patching -def awesome_layer(x): return x -model.blocks[2] = awesome_layer - -# Variable sharing (weight tying) -model.blocks[-1].linear.kernel = model.blocks[0].linear.kernel - -model(jnp.ones((1, 2))) +from functools import partial + +@partial(nnx.vmap, axis_size=5) +def create_model(rngs: nnx.Rngs): + return MLP(10, 32, 10, rngs=rngs) + +model = create_model(nnx.Rngs(0)) + +@nnx.scan +def forward(x, model: MLP): + x = model(x) + return x, None + +x = jnp.ones((3, 10)) +y, _ = forward(x, model) + +print(f'{y.shape = }') +nnx.display(model) ``` +How do NNX transforms achieve this? To understand how NNX objects interact with +JAX transforms, let's take a look at the Functional API. 
+ ++++ + ## The Functional API The Functional API establishes a clear boundary between reference/object semantics and @@ -195,21 +279,21 @@ class StatefulLinear(nnx.Module): self.count.value += 1 return x @ self.w.value + self.b.value -model = StatefulLinear(din=2, dout=3, rngs=nnx.Rngs(0)) +model = StatefulLinear(din=3, dout=5, rngs=nnx.Rngs(0)) +nnx.display(model) ``` ### State and GraphDef A Module can be decomposed into `GraphDef` and `State` using the -`.split()` method. State is a Mapping from strings to Variables or nested +`split` function. State is a Mapping from strings to Variables or nested States. GraphDef contains all the static information needed to reconstruct a Module graph, it is analogous to JAX's `PyTreeDef`. ```{code-cell} ipython3 -graphdef, state = model.split() +graphdef, state = nnx.split(model) -print(f'{state = }\n') -print(f'{graphdef = }'[:200] + '...') +nnx.display(graphdef, state) ``` ### Split, Merge, and Update @@ -217,28 +301,28 @@ print(f'{graphdef = }'[:200] + '...') `merge` is the reverse of `split`, it takes the GraphDef + State and reconstructs the Module. As shown in the example below, by using `split` and `merge` in sequence any Module can be lifted to be used in any JAX transform. `update` can -update a Module structure from a compatible State. This is often used to propagate the state -updates from a transform back to the source object outside. +update an object inplace with the content of a given State. This pattern is used to +propagate the state from a transform back to the source object outside. ```{code-cell} ipython3 -print(f'{model.count = }') +print(f'{model.count.value = }') # 1. Use split to create a pytree representation of the Module -graphdef, state = model.split() +graphdef, state = nnx.split(model) @jax.jit def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array) -> tuple[jax.Array, nnx.State]: # 2. 
Use merge to create a new model inside the JAX transformation - model = graphdef.merge(state) + model = nnx.merge(graphdef, state) # 3. Call the Module y = model(x) # 4. Use split to propagate State updates - _, state = model.split() + _, state = nnx.split(model) return y, state -y, state = forward(graphdef, state, x=jnp.ones((1, 2))) +y, state = forward(graphdef, state, x=jnp.ones((1, 3))) # 5. Update the state of the original Module -model.update(state) +nnx.update(model, state) print(f'{model.count.value = }') ``` @@ -271,19 +355,18 @@ types as shown below. ```{code-cell} ipython3 # use Variable type filters to split into multiple States -graphdef, params, counts = model.split(nnx.Param, Count) +graphdef, params, counts = nnx.split(model, nnx.Param, Count) -print(f'{params = }\n') -print(f'{counts = }') +nnx.display(params, counts) ``` -**Note**: filters must be exhaustive, if a Variable is not matched an error will be raised. +Note that filters must be exhaustive, if a value is not matched an error will be raised. As expected the `merge` and `update` methods naturally consume multiple States: ```{code-cell} ipython3 # merge multiple States -model = graphdef.merge(params, counts) +model = nnx.merge(graphdef, params, counts) # update with multiple States -model.update(params, counts) +nnx.update(model, params, counts) ``` diff --git a/docs/experimental/nnx/transforms.rst b/docs/experimental/nnx/transforms.rst index c49f438e29..1e6bee1c59 100644 --- a/docs/experimental/nnx/transforms.rst +++ b/docs/experimental/nnx/transforms.rst @@ -47,34 +47,33 @@ the transformed function. 
def loss_fn(model): return ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn)(model) - - model.update( - jax.tree_util.tree_map( - lambda p, g: p - 0.1 * g, model.extract(nnx.Param), grads - ) + params = nnx.state(model, nnx.Param) + params = jax.tree_util.tree_map( + lambda p, g: p - 0.1 * g, params, grads ) + nnx.update(model, params) model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) train_step(model, x, y) --- @jax.jit #! - def train_step(static, state, x, y): #! - def loss_fn(static, state): #! - model = static.merge(state) #! + def train_step(graphdef, state, x, y): #! + def loss_fn(graphdef, state): #! + model = nnx.merge(graphdef, state) #! return ((model(x) - y) ** 2).mean() - grads = jax.grad(loss_fn, argnums=1)(static, state) #! + grads = jax.grad(loss_fn, argnums=1)(graphdef, state) #! - model = static.merge(state) #! - model.update( - jax.tree_util.tree_map( - lambda p, g: p - 0.1 * g, model.extract(nnx.Param), grads - ) + model = nnx.merge(graphdef, state) #! + params = nnx.state(model, nnx.Param) + params = jax.tree_util.tree_map( + lambda p, g: p - 0.1 * g, params, grads ) - return model.split() #! + nnx.update(model, params) + return nnx.split(model) #! - static, state = nnx.Linear(2, 3, rngs=nnx.Rngs(0)).split() #! - static, state = train_step(static, state, x, y) #! + graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) #! + graphdef, state = train_step(graphdef, state, x, y) #! Mixing NNX and JAX transformations @@ -90,36 +89,34 @@ pure and has valid argument types that are recognized by JAX. @nnx.jit def train_step(model, x, y): - def loss_fn(static, state): #! - model = static.merge(state) + def loss_fn(graphdef, state): #! + model = nnx.merge(graphdef, state) return ((model(x) - y) ** 2).mean() - grads = jax.grad(loss_fn, 1)(*model.split()) #! - - model.update( - jax.tree_util.tree_map( - lambda p, g: p - 0.1 * g, model.extract(nnx.Param), grads - ) + grads = jax.grad(loss_fn, 1)(*nnx.split(model)) #! 
+ params = nnx.state(model, nnx.Param) + params = jax.tree_util.tree_map( + lambda p, g: p - 0.1 * g, params, grads ) + nnx.update(model, params) model = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) train_step(model, x, y) --- @jax.jit #! - def train_step(static, state, x, y): #! - model = static.merge(state) + def train_step(graphdef, state, x, y): #! + model = nnx.merge(graphdef, state) def loss_fn(model): return ((model(x) - y) ** 2).mean() grads = nnx.grad(loss_fn)(model) - - model.update( - jax.tree_util.tree_map( - lambda p, g: p - 0.1 * g, model.extract(nnx.Param), grads - ) + params = nnx.state(model, nnx.Param) + params = jax.tree_util.tree_map( + lambda p, g: p - 0.1 * g, params, grads ) - return model.split() + nnx.update(model, params) + return nnx.split(model) - static, state = nnx.Linear(2, 3, rngs=nnx.Rngs(0)).split() - static, state = train_step(static, state, x, y) + graphdef, state = nnx.split(nnx.Linear(2, 3, rngs=nnx.Rngs(0))) + graphdef, state = train_step(graphdef, state, x, y) diff --git a/docs/requirements.txt b/docs/requirements.txt index 600735f0b0..f9b5ce9a99 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -14,6 +14,7 @@ ipython_genutils sphinx-design jupytext==1.13.8 dm-haiku +penzai; python_version>='3.10' # Need to pin docutils to 0.16 to make bulleted lists appear correctly on # ReadTheDocs: https://stackoverflow.com/a/68008428 @@ -32,3 +33,4 @@ tensorflow_text>=2.11.0 # WMT example # notebooks einops + diff --git a/flax/experimental/nnx/README.md b/flax/experimental/nnx/README.md index 9a78ab20c1..cc00e1358e 100644 --- a/flax/experimental/nnx/README.md +++ b/flax/experimental/nnx/README.md @@ -4,135 +4,67 @@ _**N**eural **N**etworks for JA**X**_ - | [docs](https://flax.readthedocs.io/en/latest/experimental/nnx/index.html) | -NNX is a JAX-based neural network library designed for simplicity and power. 
Its modular approach follows standard Python conventions, making it both intuitive and compatible with the broader JAX ecosystem. +NNX is a JAX-based neural network library that focuses on providing the best development experience to make +building and experimenting with neural networks as easy and intuitive as possible. * **Pythonic**: Modules are standard Python classes, promoting ease of use and a more familiar development experience. -* **Compatible**: Effortlessly convert between Modules and pytrees using the Functional API for maximum flexibility. -* **Control**: Manage a Module's state with precision using typed Variable collections, enabling fine-grained control - on JAX transformations. -* **User-friendly**: NNX prioritizes simplicity for common use cases, building upon lessons learned from Linen - to provide a streamlined experience. +* **Easy-to-use**: NNX provides a set of transforms that take care of state management, allowing + users to focus on building their models and training loops. +* **Expressive**: NNX allows fine-grained control over the Module state with lifted transforms, enabling + users to define complex architectures. +* **Compatible**: NNX allows functionalizing Module state, making it possible to directly use JAX + transformations when needed. > [!NOTE] > NNX is currently in an experimental state and is subject to change. Linen is still the recommended option for large-scale projects. Feedback and contributions are welcome! -## Installation - -To get started with `nnx`, install Flax from GitHub: -``` -pip install git+https://github.com/google/flax.git -``` - ## What does NNX look like? -We provide three examples using the NNX API: a simple multi-layer perceptron, a CNN and an auto-encoder. - -To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#) guide. +NNX removes most of the friction from building and training neural networks in JAX. 
It provides +a Module system that uses standard Python classes, and a set of transforms that extend +JAX to handle objects. ```python -import jax -import jax.numpy as jnp - from flax.experimental import nnx +import optax - -class MLP(nnx.Module): - def __init__(self, features: list[int], *, rngs: nnx.Rngs): - self.layers = [ - nnx.Linear(din, dout, rngs=rngs) - for din, dout in zip(features[:-1], features[1:]) - ] - - def __call__(self, x: jax.Array) -> jax.Array: - for layer in self.layers[:-1]: - x = nnx.relu(layer(x)) - x = self.layers[-1](x) - return x - - -model = MLP([784, 64, 32, 10], rngs=nnx.Rngs(0)) -y = model(jnp.ones((1, 784))) -``` - -```python -class CNN(nnx.Module): - def __init__(self, *, rngs: nnx.Rngs): - self.conv1 = nnx.Conv(1, 64, kernel_size=(3, 3), rngs=rngs) - self.conv2 = nnx.Conv(64, 32, kernel_size=(3, 3), rngs=rngs) - self.linear1 = nnx.Linear(7 * 7 * 32, 256, rngs=rngs) - self.linear2 = nnx.Linear(256, 10, rngs=rngs) +class Model(nnx.Module): + def __init__(self, din, dmid, dout, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dmid, rngs=rngs) + self.bn = nnx.BatchNorm(dmid, rngs=rngs) + self.dropout = nnx.Dropout(0.2, rngs=rngs) + self.linear_out = nnx.Linear(dmid, dout, rngs=rngs) def __call__(self, x): - x = nnx.relu(self.conv1(x)) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = nnx.relu(self.conv2(x)) - x = nnx.avg_pool(x, window_shape=(2, 2), strides=(2, 2)) - x = x.reshape((x.shape[0], -1)) # flatten - x = nnx.relu(self.linear1(x)) - logits = self.linear2(x) - return logits - - -model = CNN(rngs=nnx.Rngs(0)) -x = jnp.ones((1, 28, 28, 1)) # (N, H, W, C) format -logits = model(x) -``` + x = nnx.relu(self.dropout(self.bn(self.linear(x)))) + return self.linear_out(x) -```python -class AutoEncoder(nnx.Module): - def __init__( - self, - input_features: int, - encoder_features: list[int], - decoder_features: list[int], - *, - rngs: nnx.Rngs, - ): - self.encoder = MLP([input_features, *encoder_features], rngs=rngs) - 
self.decoder = MLP([*decoder_features, input_features], rngs=rngs) - def __call__(self, x): - return self.decode(self.encode(x)) - - def encode(self, x): - return self.encoder(x) +model = Model(2, 64, 3, rngs=nnx.Rngs(0)) # eager initialization +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) # reference sharing - def decode(self, z): - return nnx.sigmoid(self.decoder(z)) +@nnx.jit # automatic state management +def train_step(model, optimizer, x, y): + def loss_fn(model): + y_pred = model(x) # call methods directly + return ((y_pred - y) ** 2).mean() + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) # inplace updates -model = AutoEncoder( - input_features=784, - encoder_features=[64, 32], - decoder_features=[32, 64], - rngs=nnx.Rngs(0), -) -x = jnp.ones((1, 784)) -z = model.encode(x) -y = model.decode(z) + return loss ``` -### Interacting with JAX - -To interact with JAX NNX provides the [Functional API](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#the-functional-api) which consists of 3 simple methods: `split`, `merge`, and `update`. Using these methods any Module can be lifted to be used in JAX transformations. Here is a simple jitted `forward` function as an example: - -```pythonthon -state, static = model.split() - -@jax.jit -def forward(static: nnx.ModuleDef, state: nnx.State, x: jax.Array): - model = static.merge(state) - y = model(x) - _, state = model.split() - return y, state +To learn more about the `Module` abstraction, check out our [NNX Basics](https://flax.readthedocs.io/en/latest/experimental/nnx/nnx_basics.html#) guide. 
-x = jnp.ones((2, 4)) -y, state = forward(static, state, x) +## Installation -model.update(state) +To get started with `nnx`, install Flax from GitHub: +``` +pip install git+https://github.com/google/flax.git ``` ### Examples @@ -140,9 +72,7 @@ model.update(state) * [LM1B](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/lm1b): A language model trained on the 1 Billion Word Benchmark dataset. #### Toy Examples +* [Basic Example](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using NNX. * [Using the Functional API](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/01_functional_api.py): Shows how to train a simple model using the functional API. -* [Using Lifted Transforms](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/02_lifted_transforms.py): Shows how to train a simple model using lifted transforms. -* [Using TrainState](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/03_train_state.py): Shows how to train a simple model using the functional API with the help of `TrainState`. -* [Training a VAE](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset, uses the functional API, `TrainState`, and shows how to use capture intermediate values to retrieve `kl_loss`. +* [Training a VAE](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/05_vae.py): Shows how to train a VAE on the binarized MNIST dataset. * [Scan over layers](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py): An contrived example that implements scan over layers with dropout and a share BatcNorm layer to showcase how lifted transforms can be implemented. 
It uses the functional API along with `jax.vmap` and `jax.lax.scan`. -* [Creating a Transformer](https://github.com/google/flax/tree/main/flax/experimental/nnx/examples/toy_examples/07_transformer.py): Shows how to create a Transformer with an auto-regressive decoder that uses scan over layers and a kv-cache for fast inference. Credits to @levskaya. diff --git a/flax/experimental/nnx/__init__.py b/flax/experimental/nnx/__init__.py index 9a8bb1e742..827542835f 100644 --- a/flax/experimental/nnx/__init__.py +++ b/flax/experimental/nnx/__init__.py @@ -19,22 +19,27 @@ from flax.typing import Initializer as Initializer from .nnx import compatibility as compatibility -from .nnx import graph_utils as graph_utils +from .nnx import graph as graph from .nnx import errors as errors from .nnx import errors as helpers from .nnx.filterlib import All as All from .nnx.filterlib import Not as Not -from .nnx.graph_utils import GraphDef as GraphDef -from .nnx.graph_utils import GraphNode as GraphNode +from .nnx.graph import GraphDef as GraphDef +from .nnx.graph import GraphNode as GraphNode from .nnx.helpers import Dict as Dict from .nnx.helpers import List as List from .nnx.helpers import Sequential as Sequential from .nnx.helpers import TrainState as TrainState from .nnx.module import M as M from .nnx.module import Module as Module -from .nnx.graph_utils import merge as merge -from .nnx.graph_utils import split as split -from .nnx.graph_utils import update as update +from .nnx.graph import merge as merge +from .nnx.graph import UpdateContext as UpdateContext +from .nnx.graph import split as split +from .nnx.graph import update as update +from .nnx.graph import clone as clone +from .nnx.graph import pop as pop +from .nnx.graph import state as state +from .nnx.graph import graphdef as graphdef from .nnx.nn import initializers as initializers from .nnx.nn.activations import celu as celu from .nnx.nn.activations import elu as elu @@ -79,6 +84,7 @@ from .nnx.rnglib import RngState as 
RngState from .nnx.rnglib import RngKey as RngKey from .nnx.rnglib import RngCount as RngCount +from .nnx.rnglib import fork as fork from .nnx.spmd import PARTITION_NAME as PARTITION_NAME from .nnx.spmd import get_partition_spec as get_partition_spec from .nnx.spmd import get_named_sharding as get_named_sharding @@ -90,7 +96,7 @@ from .nnx.training.metrics import Metric as Metric from .nnx.training.metrics import MultiMetric as MultiMetric from .nnx.training.optimizer import Optimizer as Optimizer -from .nnx.transforms import JIT as JIT +from .nnx.transforms import Jit as Jit from .nnx.transforms import Remat as Remat from .nnx.transforms import Scan as Scan from .nnx.transforms import Vmap as Vmap @@ -100,6 +106,7 @@ from .nnx.transforms import scan as scan from .nnx.transforms import value_and_grad as value_and_grad from .nnx.transforms import vmap as vmap +from .nnx.transforms import eval_shape as eval_shape from .nnx.variables import EMPTY as EMPTY from .nnx.variables import A as A from .nnx.variables import BatchStat as BatchStat @@ -107,7 +114,8 @@ from .nnx.variables import Empty as Empty from .nnx.variables import Intermediate as Intermediate from .nnx.variables import Param as Param -from .nnx.variables import Rng as Rng from .nnx.variables import Variable as Variable +from .nnx.variables import VariableState as VariableState from .nnx.variables import VariableMetadata as VariableMetadata from .nnx.variables import with_metadata as with_metadata +from .nnx.visualization import display as display diff --git a/flax/experimental/nnx/docs/demo.ipynb b/flax/experimental/nnx/docs/demo.ipynb index 4679af7468..ae71ad479a 100644 --- a/flax/experimental/nnx/docs/demo.ipynb +++ b/flax/experimental/nnx/docs/demo.ipynb @@ -204,7 +204,7 @@ " \n", "...\n", "\n", - "static = GraphDef(\n", + "graphdef = GraphDef(\n", " type=MLP,\n", " index=0,\n", " attributes=('blocks', 'count'),\n", @@ -223,13 +223,13 @@ } ], "source": [ - "static, state = model.split()\n", + "graphdef, 
state = model.split()\n", "\n", "# state is a dictionary-like JAX pytree\n", "print(f'{state = }'[:500] + '\\n...')\n", "\n", - "# static is also a JAX pytree, but just metadata\n", - "print(f'\\n{static = }'[:300] + '\\n...')" + "# graphdef is also a JAX pytree, but just metadata\n", + "print(f'\\n{graphdef = }'[:300] + '\\n...')" ] }, { @@ -250,17 +250,17 @@ } ], "source": [ - "static, state = model.split()\n", + "graphdef, state = model.split()\n", "\n", "@jax.jit\n", - "def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", - " model = static.merge(state)\n", + "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", + " model = graphdef.merge(state)\n", " y = model(x)\n", " state, _ = model.split()\n", " return y, state\n", "\n", "x = jnp.ones((2, 4))\n", - "y, state = forward(static,state, x)\n", + "y, state = forward(graphdef,state, x)\n", "\n", "model.update(state)\n", "\n", @@ -284,17 +284,17 @@ } ], "source": [ - "params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count)\n", + "params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count)\n", "\n", "@jax.jit\n", - "def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", - " model = static.merge(params, batch_stats, counts)\n", + "def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", + " model = graphdef.merge(params, batch_stats, counts)\n", " y = model(x, train=True)\n", " params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", " return y, params, batch_stats, counts\n", "\n", "x = jnp.ones((2, 4))\n", - "y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n", + "y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)\n", "\n", "model.update(params, batch_stats, counts)\n", "\n", @@ -323,16 +323,16 @@ " self.model = model\n", "\n", " def __call__(self, x):\n", - " params, 
batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count)\n", + " params, batch_stats, counts, graphdef = self.model.split(nnx.Param, nnx.BatchStat, Count)\n", "\n", " @jax.jit\n", - " def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", - " model = static.merge(params, batch_stats, counts)\n", + " def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array):\n", + " model = graphdef.merge(params, batch_stats, counts)\n", " y = model(x)\n", " params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count)\n", " return y, params, batch_stats, counts\n", "\n", - " y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x)\n", + " y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x)\n", "\n", " self.model.update(params, batch_stats, counts)\n", " return y\n", diff --git a/flax/experimental/nnx/docs/demo.md b/flax/experimental/nnx/docs/demo.md index e4a511dddc..5d02e5da7c 100644 --- a/flax/experimental/nnx/docs/demo.md +++ b/flax/experimental/nnx/docs/demo.md @@ -86,29 +86,29 @@ print(f'{y.shape = }') ```{code-cell} ipython3 :outputId: 9a3f378b-739e-4f45-9968-574651200ede -static, state = model.split() +graphdef, state = model.split() # state is a dictionary-like JAX pytree print(f'{state = }'[:500] + '\n...') -# static is also a JAX pytree, but just metadata -print(f'\n{static = }'[:300] + '\n...') +# graphdef is also a JAX pytree, but just metadata +print(f'\n{graphdef = }'[:300] + '\n...') ``` ```{code-cell} ipython3 :outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d -static, state = model.split() +graphdef, state = model.split() @jax.jit -def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): - model = static.merge(state) +def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array): + model = graphdef.merge(state) y = model(x) state, _ = model.split() return y, state x = jnp.ones((2, 4)) -y, state = 
forward(static,state, x) +y, state = forward(graphdef,state, x) model.update(state) @@ -117,17 +117,17 @@ print(f'{model.count.value = }') ``` ```{code-cell} ipython3 -params, batch_stats, counts, static = model.split(nnx.Param, nnx.BatchStat, Count) +params, batch_stats, counts, graphdef = model.split(nnx.Param, nnx.BatchStat, Count) @jax.jit -def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): - model = static.merge(params, batch_stats, counts) +def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): + model = graphdef.merge(params, batch_stats, counts) y = model(x, train=True) params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count) return y, params, batch_stats, counts x = jnp.ones((2, 4)) -y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x) +y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x) model.update(params, batch_stats, counts) @@ -141,16 +141,16 @@ class Parent(nnx.Module): self.model = model def __call__(self, x): - params, batch_stats, counts, static = self.model.split(nnx.Param, nnx.BatchStat, Count) + params, batch_stats, counts, graphdef = self.model.split(nnx.Param, nnx.BatchStat, Count) @jax.jit - def forward(static: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): - model = static.merge(params, batch_stats, counts) + def forward(graphdef: nnx.GraphDef, params, batch_stats, counts, x: jax.Array): + model = graphdef.merge(params, batch_stats, counts) y = model(x) params, batch_stats, counts, _ = model.split(nnx.Param, nnx.BatchStat, Count) return y, params, batch_stats, counts - y, params, batch_stats, counts = forward(static, params, batch_stats, counts, x) + y, params, batch_stats, counts = forward(graphdef, params, batch_stats, counts, x) self.model.update(params, batch_stats, counts) return y diff --git a/flax/experimental/nnx/docs/quick_start.ipynb b/flax/experimental/nnx/docs/quick_start.ipynb index 
b0b19c5b56..fc617db8a8 100644 --- a/flax/experimental/nnx/docs/quick_start.ipynb +++ b/flax/experimental/nnx/docs/quick_start.ipynb @@ -335,7 +335,7 @@ "\n", "Now that we have a working model, lets see how to train it with `jax.jit` using NNX's Functional API. The `Module.split` method allows you to convert a Module into pytrees with functional semantics, this allows you to integrate with JAX's functional APIs like `jax.jit` and `jax.grad`.\n", "\n", - "In this next example we will use the `.split` method to split the model into a `params: State` and `static: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contain `Variables` with the `params` collection. Having `params` and `static` its pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." + "In this next example we will use the `.split` method to split the model into a `params: State` and `graphdef: GraphDef` objects. We pass the `\"params\"` filter to check that the Module's state only contain `Variables` with the `params` collection. Having `params` and `graphdef` its pretty easy to implement a jitted `train_step` much like you would in Flax or Haiku. `GraphDef` exposes an `apply` method which accepts some `State` and creates a function that runs the Module's `__call__` method. This function then returns the output of the Module along with the updated state." 
] }, { @@ -344,13 +344,13 @@ "metadata": {}, "outputs": [], "source": [ - "static, params = model.split(\"params\")\n", + "graphdef, params = model.split(\"params\")\n", "\n", "\n", "@jax.jit\n", "def train_step(params: nnx.State, x, y):\n", " def loss_fn(params):\n", - " logits, _updates = static.apply(params)(x)\n", + " logits, _updates = graphdef.apply(params)(x)\n", " return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()\n", "\n", " loss, grads = jax.value_and_grad(loss_fn)(params)\n", @@ -420,7 +420,7 @@ "outputs": [], "source": [ "state = nnx.TrainState(\n", - " static,\n", + " graphdef,\n", " params=params,\n", " tx=optax.adam(0.001),\n", ")\n", diff --git a/flax/experimental/nnx/docs/why.ipynb b/flax/experimental/nnx/docs/why.ipynb index fe6cfb9a16..f661f64bc5 100644 --- a/flax/experimental/nnx/docs/why.ipynb +++ b/flax/experimental/nnx/docs/why.ipynb @@ -255,7 +255,7 @@ "\n", "NNX has two very simple APIs to interact with JAX: `split` and `merge`.\n", "\n", - "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static structure of the Module." + "The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static structure of the Module." 
] }, { @@ -278,7 +278,7 @@ " [-0.06992685, -0.64693886, 0.20232596, 1.1200062 ]], dtype=float32)\n", "})\n", "\n", - "static = GraphDef(\n", + "graphdef = GraphDef(\n", " type=CounterLinear,\n", " index=0,\n", " static_fields=(),\n", @@ -305,13 +305,13 @@ "source": [ "model = CounterLinear(4, 4, rngs=nnx.Rngs(0))\n", "\n", - "static, state = model.split()\n", + "graphdef, state = model.split()\n", "\n", "# state is a dictionary-like JAX pytree\n", "print(f'{state = }')\n", "\n", - "# static is also a JAX pytree, but containing no data, just metadata\n", - "print(f'\\n{static = }')" + "# graphdef is also a JAX pytree, but containing no data, just metadata\n", + "print(f'\\n{graphdef = }')" ] }, { @@ -341,14 +341,14 @@ ], "source": [ "@jax.jit\n", - "def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", - " model = static.merge(state)\n", + "def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array):\n", + " model = graphdef.merge(state)\n", " y = model(x)\n", " state, _ = model.split()\n", " return y, state\n", "\n", "x = jnp.ones((2, 4))\n", - "y, state = forward(static,state, x)\n", + "y, state = forward(graphdef,state, x)\n", "\n", "print(f'{y.shape = }')\n", "print(f'{state[\"count\"] = }')" @@ -401,34 +401,34 @@ " return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split(\n", " nnx.Param, Count\n", " )\n", - " params, counts, static = jax.vmap(\n", + " params, counts, graphdef = jax.vmap(\n", " vmap_init, in_axes=(0,), out_axes=(0, None, None)\n", " )(keys)\n", "\n", " # update wrapped submodule reference\n", - " self.models = static.merge(params, counts)\n", + " self.models = graphdef.merge(params, counts)\n", "\n", " def __call__(self, x):\n", " # get module values, define pure fn,\n", " # notice that we split the data into two collections by their types.\n", - " params, counts, static = self.models.split(nnx.Param, Count)\n", + " params, counts, graphdef = self.models.split(nnx.Param, Count)\n", "\n", " # define pure init fn and 
vmap\n", - " def vmap_apply(x, params, counts, static):\n", - " model = static.merge(params, counts)\n", + " def vmap_apply(x, params, counts, graphdef):\n", + " model = graphdef.merge(params, counts)\n", " y = model(x)\n", - " params, counts, static = model.split(nnx.Param, Count)\n", - " return y, params, counts, static\n", + " params, counts, graphdef = model.split(nnx.Param, Count)\n", + " return y, params, counts, graphdef\n", "\n", - " y, params, counts, static = jax.vmap(\n", + " y, params, counts, graphdef = jax.vmap(\n", " vmap_apply,\n", " in_axes=(None, 0, None, None),\n", " out_axes=(0, 0, None, None)\n", - " )(x, params, counts, static)\n", + " )(x, params, counts, graphdef)\n", "\n", " # update wrapped module\n", " # uses `update` to integrate the new state\n", - " self.models.update(params, counts, static)\n", + " self.models.update(params, counts, graphdef)\n", " return y\n", "\n", "x = jnp.ones((4,))\n", @@ -680,7 +680,7 @@ "model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0))\n", "y = model(jnp.ones((2, 4)))\n", "\n", - "static, state = model.split()\n", + "graphdef, state = model.split()\n", "\n", "print(f\"{state.variables['kernel'].meta=}\\n{state.variables['kernel'].other_meta=}\")\n", "print(f\"{state.variables['bias'].meta=}\\n{state.variables['bias'].other_meta=}\")" @@ -751,7 +751,7 @@ " input_shape=(2, 6, 6, 3),\n", " rngs=nnx.Rngs(0))\n", "\n", - "static, state = model.split()\n", + "graphdef, state = model.split()\n", "jax.tree_util.tree_map(jnp.shape, state)" ] } diff --git a/flax/experimental/nnx/docs/why.md b/flax/experimental/nnx/docs/why.md index ace7c6f35d..aef9107811 100644 --- a/flax/experimental/nnx/docs/why.md +++ b/flax/experimental/nnx/docs/why.md @@ -145,20 +145,20 @@ While NNX Modules inherently follow reference semantics, they can be easily conv NNX has two very simple APIs to interact with JAX: `split` and `merge`. 
-The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static structure of the Module. +The `Module.split` method allows you to convert into a `State` dict-like object that contains the dynamic state of the Module, and a `GraphDef` object that contains the static graph structure of the Module. ```{code-cell} :outputId: 9a3f378b-739e-4f45-9968-574651200ede model = CounterLinear(4, 4, rngs=nnx.Rngs(0)) -static, state = model.split() +graphdef, state = model.split() # state is a dictionary-like JAX pytree print(f'{state = }') -# static is also a JAX pytree, but containing no data, just metadata -print(f'\n{static = }') +# graphdef is also a JAX pytree, but containing no data, just metadata +print(f'\n{graphdef = }') ``` The `GraphDef.merge` method allows you to take a `GraphDef` and one or more `State` objects and merge them back into a `Module` object. @@ -169,14 +169,14 @@ Using `split` and `merge` in conjunction allows you to carry your Module in and :outputId: 0007d357-152a-449e-bcb9-b1b5a91d2d8d @jax.jit -def forward(static: nnx.GraphDef, state: nnx.State, x: jax.Array): - model = static.merge(state) +def forward(graphdef: nnx.GraphDef, state: nnx.State, x: jax.Array): + model = graphdef.merge(state) y = model(x) state, _ = model.split() return y, state x = jnp.ones((2, 4)) -y, state = forward(static,state, x) +y, state = forward(graphdef,state, x) print(f'{y.shape = }') print(f'{state["count"] = }') @@ -205,34 +205,34 @@ class LinearEnsemble(nnx.Module): return CounterLinear(din, dout, rngs=nnx.Rngs(keys)).split( nnx.Param, Count ) - params, counts, static = jax.vmap( + params, counts, graphdef = jax.vmap( vmap_init, in_axes=(0,), out_axes=(0, None, None) )(keys) # update wrapped submodule reference - self.models = static.merge(params, counts) + self.models = graphdef.merge(params, counts) def __call__(self, x): # get module values, define pure fn, # 
notice that we split the data into two collections by their types. - params, counts, static = self.models.split(nnx.Param, Count) + params, counts, graphdef = self.models.split(nnx.Param, Count) # define pure init fn and vmap - def vmap_apply(x, params, counts, static): - model = static.merge(params, counts) + def vmap_apply(x, params, counts, graphdef): + model = graphdef.merge(params, counts) y = model(x) - params, counts, static = model.split(nnx.Param, Count) - return y, params, counts, static + params, counts, graphdef = model.split(nnx.Param, Count) + return y, params, counts, graphdef - y, params, counts, static = jax.vmap( + y, params, counts, graphdef = jax.vmap( vmap_apply, in_axes=(None, 0, None, None), out_axes=(0, 0, None, None) - )(x, params, counts, static) + )(x, params, counts, graphdef) # update wrapped module # uses `update` to integrate the new state - self.models.update(params, counts, static) + self.models.update(params, counts, graphdef) return y x = jnp.ones((4,)) @@ -359,7 +359,7 @@ class AnnotatedLinear(nnx.Module): model = AnnotatedLinear(4, 8, rngs=nnx.Rngs(0)) y = model(jnp.ones((2, 4))) -static, state = model.split() +graphdef, state = model.split() print(f"{state.variables['kernel'].meta=}\n{state.variables['kernel'].other_meta=}") print(f"{state.variables['bias'].meta=}\n{state.variables['bias'].other_meta=}") @@ -404,6 +404,6 @@ model = Example(in_filters=3, input_shape=(2, 6, 6, 3), rngs=nnx.Rngs(0)) -static, state = model.split() +graphdef, state = model.split() jax.tree_util.tree_map(jnp.shape, state) ``` diff --git a/flax/experimental/nnx/examples/lm1b/models_test.py b/flax/experimental/nnx/examples/lm1b/models_test.py index d8beda34ec..e66a2949b0 100644 --- a/flax/experimental/nnx/examples/lm1b/models_test.py +++ b/flax/experimental/nnx/examples/lm1b/models_test.py @@ -18,12 +18,6 @@ from pathlib import Path from typing import Any -# add project_root to import lm1b Linen model -project_root = str(Path(__file__).parents[6]) 
-sys.path.append(project_root) -from examples.lm1b.models import TransformerLM as TransformerLinen - -sys.path.pop() import dataclasses @@ -43,6 +37,13 @@ jax.config.update('jax_disable_most_optimizations', True) +# add project_root to import lm1b Linen model +project_root = str(Path(__file__).absolute().parents[5]) +sys.path.append(project_root) +from examples.lm1b.models import TransformerLM as TransformerLinen + +sys.path.pop() + @dataclasses.dataclass(unsafe_hash=True) class CompatTransformerConfig(TransformerConfig): @@ -88,13 +89,14 @@ def transfer_params( def apply_rules(names: tuple[str, ...]): return tuple(rules[name] for name in names) - def copy_var(nnx_name, linen_name): + def copy_var(nnx_name: str, linen_name: str): + nnx_path = tuple(nnx_name.split('/')) assert ( - flat_params_nnx[nnx_name].raw_value.shape + flat_params_nnx[nnx_path].value.shape == flat_params_linen[linen_name].value.shape ) - flat_params_nnx[nnx_name].raw_value = flat_params_linen[linen_name].value - assert flat_params_nnx[nnx_name].sharding == apply_rules( + flat_params_nnx[nnx_path].value = flat_params_linen[linen_name].value + assert flat_params_nnx[nnx_path].sharding == apply_rules( flat_params_linen[linen_name].names ) @@ -168,12 +170,13 @@ def transfer_cache( flat_cache_nnx = cache_nnx.flat_state() flat_cache_linen = traverse_util.flatten_dict(cache_linen, sep='/') - def copy_var(nnx_name, linen_name): + def copy_var(nnx_name: str, linen_name: str): + nnx_path = tuple(nnx_name.split('/')) assert ( - flat_cache_nnx[nnx_name].raw_value.shape + flat_cache_nnx[nnx_path].value.shape == flat_cache_linen[linen_name].shape ) - flat_cache_nnx[nnx_name].raw_value = flat_cache_linen[linen_name] + flat_cache_nnx[nnx_path].value = flat_cache_linen[linen_name] for idx in range(config.num_layers): copy_var( @@ -206,8 +209,8 @@ def test_forward_eval(self): decode=False, ) - model_nnx = TransformerLM.create_abstract(config, rngs=nnx.Rngs(0)) - params_nnx, _ = model_nnx.split(nnx.Param) + 
model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) + _, params_nnx = nnx.split(model_nnx, nnx.Param) model_linen = TransformerLinen(config) @@ -215,7 +218,7 @@ def test_forward_eval(self): params_linen = model_linen.init(random.key(0), sample_inputs)['params'] self.transfer_params(config, params_nnx, params_linen) - model_nnx.update(params_nnx) + nnx.update(model_nnx, params_nnx) model_nnx.set_attributes(deterministic=True, decode=False) output_nnx = model_nnx(sample_inputs) @@ -240,13 +243,13 @@ def test_forward_decode(self): decode=True, ) - model_nnx = TransformerLM.create_abstract(config, rngs=nnx.Rngs(0)) - for _path, m in model_nnx.modules(): + model_nnx = nnx.eval_shape(lambda: TransformerLM(config, rngs=nnx.Rngs(0))) + for _path, m in model_nnx.iter_modules(): if isinstance(m, HasCache): input_shape = (batch_size, config.max_len, config.emb_dim) m.init_cache(input_shape, dtype=config.dtype) - params_nnx, cache_nnx, _ = model_nnx.split(nnx.Param, nnx.Cache) + _, params_nnx, cache_nnx = nnx.split(model_nnx, nnx.Param, nnx.Cache) model_linen = TransformerLinen(config) @@ -262,7 +265,7 @@ def test_forward_decode(self): self.transfer_params(config, params_nnx, params_linen) self.transfer_cache(config, cache_nnx, cache_linen) - model_nnx.update(params_nnx, cache_nnx) + nnx.update(model_nnx, params_nnx, cache_nnx) model_nnx.set_attributes(deterministic=True, decode=True) outputs_nnx = [] diff --git a/flax/experimental/nnx/examples/lm1b/train.py b/flax/experimental/nnx/examples/lm1b/train.py index 7639033fb4..5bd289ac31 100644 --- a/flax/experimental/nnx/examples/lm1b/train.py +++ b/flax/experimental/nnx/examples/lm1b/train.py @@ -193,7 +193,7 @@ def train_step( def loss_fn(params): """loss function used for training.""" - module = state.graphdef.merge(params) + module = nnx.merge(state.graphdef, params) module.set_attributes(deterministic=False, decode=False) logits = module( inputs, @@ -222,13 +222,13 @@ def loss_fn(params): def 
eval_step( params: nnx.State, batch, - static: nnx.GraphDef[models.TransformerLM], + graphdef: nnx.GraphDef[models.TransformerLM], label_smoothing=0.0, ): """Calculate evaluation metrics on a batch.""" inputs = batch['inputs'] weights = jnp.where(inputs > 0, 1.0, 0.0) - module = static.merge(params) + module = nnx.merge(graphdef, params) module.set_attributes(deterministic=True, decode=False) logits = module(inputs) @@ -239,7 +239,7 @@ def predict_step( inputs, params: nnx.State, rngkey: jax.Array, - static: nnx.GraphDef[models.TransformerLM], + graphdef: nnx.GraphDef[models.TransformerLM], eos_id: int, max_decode_len: int, config: models.TransformerConfig, @@ -247,23 +247,23 @@ def predict_step( top_k: int, ): """Predict language model on a batch.""" - module = static.merge(params) + module = nnx.merge(graphdef, params) # TODO(cgarciae): check how pytorch does this. - for _path, m in module.modules(): + for _path, m in module.iter_modules(): if isinstance(m, HasCache): input_shape = (inputs.shape[0], max_decode_len, config.emb_dim) m.init_cache(input_shape, dtype=config.dtype) - cache = module.extract(nnx.Cache) + graphdef, params, cache = nnx.split(module, nnx.Param, nnx.Cache) def tokens_ids_to_logits(flat_ids, cache: nnx.State): """Token slice to logits from decoder model.""" # --> [batch * beam, 1, vocab] - module = static.merge(params, cache) + module = nnx.merge(graphdef, params, cache) module.set_attributes(deterministic=True, decode=True) logits = module(flat_ids) - cache = module.extract(nnx.Cache) + cache = nnx.state(module, nnx.Cache) # Remove singleton sequence-length dimension: # [batch, 1, vocab] --> [batch, vocab] logits = logits.squeeze(axis=1) @@ -347,7 +347,7 @@ def evaluate( def generate_prediction( *, jit_pred_step, - static: nnx.GraphDef[models.TransformerLM], + graphdef: nnx.GraphDef[models.TransformerLM], params: nnx.State, tokenized_prompts, eos_id, @@ -379,7 +379,7 @@ def generate_prediction( pred_batch, params, inference_rngs, - static, + 
graphdef, eos_id, config.max_predict_length, model_config, @@ -630,7 +630,7 @@ def constructor(config: models.TransformerConfig, key: jax.Array): with report_progress.timed('generate_text'): exemplars = generate_prediction( jit_pred_step=jit_pred_step, - static=state.graphdef, + graphdef=state.graphdef, params=state.params, tokenized_prompts=tokenized_prompts, eos_id=eos_id, diff --git a/flax/experimental/nnx/examples/lm1b/utils.py b/flax/experimental/nnx/examples/lm1b/utils.py index 70771803cf..9ba2e280f3 100644 --- a/flax/experimental/nnx/examples/lm1b/utils.py +++ b/flax/experimental/nnx/examples/lm1b/utils.py @@ -157,9 +157,9 @@ def setup_initial_state( with mesh: model = constructor(config, rng) - static, params = model.split(nnx.Param) + graphdef, params = nnx.split(model, nnx.Param) state = TrainState.create( - apply_fn=static.apply, params=params, tx=tx, graphdef=static + apply_fn=graphdef.apply, params=params, tx=tx, graphdef=graphdef ) state = jax.tree_util.tree_map(_to_array, state) state_spec = nnx.get_partition_spec(state) diff --git a/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb b/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb deleted file mode 100644 index 66e82a495f..0000000000 --- a/flax/experimental/nnx/examples/toy_examples/00_demo.ipynb +++ /dev/null @@ -1,505 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(\n", - " din=2,\n", - " dout=2\n", - ")\n", - "[[0.63114893 1.2928092 ]\n", - " [0.63114893 1.2928092 ]]\n" - ] - } - ], - "source": [ - "from flax.experimental import nnx\n", - "import jax\n", - "import jax.numpy as jnp\n", - "\n", - "\n", - "class Linear(nnx.Module):\n", - "\n", - " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", - " # static attributes\n", - " self.din = din\n", - " self.dout = dout\n", - " # variables\n", - " self.w = 
nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", - " self.b = nnx.Param(jnp.zeros((dout,)))\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.w.value + self.b.value\n", - "\n", - "\n", - "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", - "\n", - "y = linear(jnp.ones((2, 2)))\n", - "\n", - "print(linear)\n", - "print(y)" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "State({\n", - " 'w': Param(\n", - " raw_value=Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32)\n", - " ),\n", - " 'b': Param(\n", - " raw_value=Array([0., 0.], dtype=float32)\n", - " )\n", - "})\n", - "GraphDef(\n", - " type=Linear,\n", - " index=0,\n", - " attributes=('din', 'dout', 'w', 'b'),\n", - " subgraphs={},\n", - " static_fields={\n", - " 'din': 2,\n", - " 'dout': 2\n", - " },\n", - " variables={\n", - " 'w': VariableDef(\n", - " type=Param,\n", - " index=1,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " ),\n", - " 'b': VariableDef(\n", - " type=Param,\n", - " index=2,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " )\n", - " },\n", - " metadata=\n", - ")\n" - ] - } - ], - "source": [ - "graphdef, state = linear.split()\n", - "\n", - "print(state)\n", - "print(graphdef)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "State({\n", - " 'linear': {\n", - " 'w': Param(\n", - " raw_value=Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32)\n", - " ),\n", - " 'b': Param(\n", - " raw_value=Array([0., 0.], dtype=float32)\n", - " )\n", - " 
}\n", - "})" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "class Nested(nnx.Module):\n", - " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", - " self.linear = Linear(din, dout, rngs=rngs)\n", - " \n", - "module = Nested(2, 2, rngs=nnx.Rngs(0))\n", - "\n", - "state, static = module.split()\n", - "state" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(\n", - " din=2,\n", - " dout=2,\n", - " submodule=Linear(...)\n", - ")\n", - "[[0.63114893 1.2928092 ]\n", - " [0.63114893 1.2928092 ]]\n" - ] - } - ], - "source": [ - "class Linear(nnx.Module):\n", - "\n", - " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", - " self.din = din\n", - " self.dout = dout\n", - " self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", - " self.b = nnx.Param(jnp.zeros((dout,)))\n", - " # introduce a self-reference\n", - " self.submodule = self\n", - "\n", - " def __call__(self, x):\n", - " return x @ self.submodule.w.value + self.submodule.b.value\n", - "\n", - "\n", - "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", - "\n", - "y = linear(jnp.ones((2, 2)))\n", - "\n", - "print(linear)\n", - "print(y)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "State({\n", - " 'w': Param(\n", - " raw_value=Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32)\n", - " ),\n", - " 'b': Param(\n", - " raw_value=Array([0., 0.], dtype=float32)\n", - " )\n", - "})\n", - "GraphDef(\n", - " type=Linear,\n", - " index=0,\n", - " attributes=('din', 'dout', 'w', 'b', 'submodule'),\n", - " subgraphs={\n", - " 'submodule': 0\n", - " },\n", - " static_fields={\n", - " 'din': 2,\n", - " 'dout': 2\n", - " },\n", - " variables={\n", - " 'w': VariableDef(\n", - " 
type=Param,\n", - " index=1,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " ),\n", - " 'b': VariableDef(\n", - " type=Param,\n", - " index=2,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " )\n", - " },\n", - " metadata=\n", - ")\n" - ] - } - ], - "source": [ - "graphdef, state = linear.split()\n", - "\n", - "print(state)\n", - "print(graphdef)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "linear2 = graphdef.merge(state)\n", - "\n", - "linear2.submodule is linear2" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Linear(\n", - " din=2,\n", - " dout=2\n", - ")\n", - "[[0.63114893 1.2928092 ]\n", - " [0.63114893 1.2928092 ]]\n" - ] - } - ], - "source": [ - "class Linear(nnx.Module):\n", - "\n", - " def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs):\n", - " # static attributes\n", - " self.din = din\n", - " self.dout = dout\n", - " # variables\n", - " self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout)))\n", - " self.b = nnx.Param(jnp.zeros((dout,)))\n", - "\n", - " def __call__(self, x):\n", - " y = x @ self.w.value + self.b.value\n", - " self.y = nnx.Intermediate(y)\n", - " return y\n", - "\n", - "\n", - "linear = Linear(2, 2, rngs=nnx.Rngs(0))\n", - "\n", - "y = linear(jnp.ones((2, 2)))\n", - "\n", - "print(linear)\n", - "print(y)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - 
"output_type": "stream", - "text": [ - "State({\n", - " 'y': Intermediate(\n", - " raw_value=Array([[0.63114893, 1.2928092 ],\n", - " [0.63114893, 1.2928092 ]], dtype=float32)\n", - " )\n", - "})\n", - "State({\n", - " 'w': Param(\n", - " raw_value=Array([[0.31696808, 0.55285215],\n", - " [0.31418085, 0.7399571 ]], dtype=float32)\n", - " ),\n", - " 'b': Param(\n", - " raw_value=Array([0., 0.], dtype=float32)\n", - " ),\n", - " 'y': Intermediate(\n", - " raw_value=Empty\n", - " )\n", - "})\n" - ] - } - ], - "source": [ - "intermediates = linear.pop(nnx.Intermediate)\n", - "graphdef, state = linear.split()\n", - "\n", - "print(intermediates)\n", - "print(state)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "state = State({\n", - " 'bar': {\n", - " 'kernel': Array([[-0.3641057 , 0.10192434],\n", - " [-0.37005556, 0.49028906]], dtype=float32),\n", - " 'bias': Array([0., 0.], dtype=float32)\n", - " },\n", - " 'baz': {\n", - " 'bias': Array([0., 0.], dtype=float32)\n", - " }\n", - "})\n", - "static = GraphDef(\n", - " type=Foo,\n", - " index=0,\n", - " attributes=('bar', 'baz'),\n", - " subgraphs={\n", - " 'bar': GraphDef(\n", - " type=Linear,\n", - " index=1,\n", - " attributes=('kernel', 'bias', 'in_features', 'out_features', 'use_bias', 'dtype', 'param_dtype', 'precision', 'kernel_init', 'bias_init', 'dot_general'),\n", - " subgraphs={},\n", - " static_fields={\n", - " 'in_features': 2,\n", - " 'out_features': 2,\n", - " 'use_bias': True,\n", - " 'dtype': None,\n", - " 'param_dtype': ,\n", - " 'precision': None,\n", - " 'kernel_init': .init at 0x1391425f0>,\n", - " 'bias_init': ,\n", - " 'dot_general': \n", - " },\n", - " variables={\n", - " 'kernel': VariableDef(\n", - " type=Param,\n", - " index=2,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 
'remove_axis_hooks': ()\n", - " }\n", - " ),\n", - " 'bias': VariableDef(\n", - " type=Param,\n", - " index=3,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " )\n", - " },\n", - " metadata=\n", - " ),\n", - " 'baz': GraphDef(\n", - " type=Linear,\n", - " index=4,\n", - " attributes=('kernel', 'bias', 'in_features', 'out_features', 'use_bias', 'dtype', 'param_dtype', 'precision', 'kernel_init', 'bias_init', 'dot_general'),\n", - " subgraphs={},\n", - " static_fields={\n", - " 'in_features': 2,\n", - " 'out_features': 2,\n", - " 'use_bias': True,\n", - " 'dtype': None,\n", - " 'param_dtype': ,\n", - " 'precision': None,\n", - " 'kernel_init': .init at 0x1391425f0>,\n", - " 'bias_init': ,\n", - " 'dot_general': \n", - " },\n", - " variables={\n", - " 'kernel': 2,\n", - " 'bias': VariableDef(\n", - " type=Param,\n", - " index=5,\n", - " metadata={\n", - " 'get_value_hooks': (),\n", - " 'set_value_hooks': (),\n", - " 'create_value_hooks': (),\n", - " 'add_axis_hooks': (),\n", - " 'remove_axis_hooks': ()\n", - " }\n", - " )\n", - " },\n", - " metadata=\n", - " )\n", - " },\n", - " static_fields={},\n", - " variables={},\n", - " metadata=\n", - ")\n" - ] - } - ], - "source": [ - "class Foo(nnx.Module):\n", - " def __init__(self, *, rngs: nnx.Rngs) -> None:\n", - " self.bar = nnx.Linear(2, 2, rngs=rngs)\n", - " self.baz = nnx.Linear(2, 2, rngs=rngs)\n", - "\n", - " # tie the weights\n", - " self.baz.variables.kernel = self.bar.variables.kernel\n", - "\n", - "model = Foo(rngs=nnx.Rngs(0))\n", - "state, static = model.split()\n", - "\n", - "print(f'{state = }')\n", - "print(f'{static = }')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": 
".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.18" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py b/flax/experimental/nnx/examples/toy_examples/01_functional_api.py index bfb6d517aa..bd6451555e 100644 --- a/flax/experimental/nnx/examples/toy_examples/01_functional_api.py +++ b/flax/experimental/nnx/examples/toy_examples/01_functional_api.py @@ -57,8 +57,8 @@ def __call__(self, x): return x -static, params, counts = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)).split( - nnx.Param, Count +graphdef, params, counts = nnx.split( + MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)), nnx.Param, Count ) @@ -67,9 +67,9 @@ def train_step(params, counts, batch): x, y = batch def loss_fn(params): - model = static.merge(params, counts) + model = nnx.merge(graphdef, params, counts) y_pred = model(x) - new_counts = model.extract(Count) + new_counts = nnx.state(model, Count) loss = jnp.mean((y - y_pred) ** 2) return loss, new_counts @@ -83,7 +83,7 @@ def loss_fn(params): @jax.jit def test_step(params: nnx.State, counts: nnx.State, batch): x, y = batch - model = static.merge(params, counts) + model = nnx.merge(graphdef, params, counts) y_pred = model(x) loss = jnp.mean((y - y_pred) ** 2) return {'loss': loss} @@ -100,7 +100,7 @@ def test_step(params: nnx.State, counts: nnx.State, batch): if step >= total_steps - 1: break -model = static.merge(params, counts) +model = nnx.merge(graphdef, params, counts) print('times called:', model.count.value) y_pred = model(X) diff --git a/flax/experimental/nnx/examples/toy_examples/03_train_state.py b/flax/experimental/nnx/examples/toy_examples/03_train_state.py deleted file mode 100644 index e2a0495bfd..0000000000 --- a/flax/experimental/nnx/examples/toy_examples/03_train_state.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright 2024 The Flax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# %% -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import optax - -from flax.experimental import nnx - -X = np.linspace(0, 1, 100)[:, None] -Y = 0.8 * X**2 + 0.1 + np.random.normal(0, 0.1, size=X.shape) - - -def dataset(batch_size): - while True: - idx = np.random.choice(len(X), size=batch_size) - yield X[idx], Y[idx] - - -class Linear(nnx.Module): - def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): - self.w = nnx.Param(jax.random.uniform(rngs.params(), (din, dout))) - self.b = nnx.Param(jnp.zeros((dout,))) - - def __call__(self, x): - return x @ self.w.value + self.b.value - - -class Count(nnx.Variable[nnx.A]): - pass - - -class MLP(nnx.Module): - def __init__(self, din, dhidden, dout, *, rngs: nnx.Rngs): - self.count = Count(jnp.array(0)) - self.linear1 = Linear(din, dhidden, rngs=rngs) - self.linear2 = Linear(dhidden, dout, rngs=rngs) - - def __call__(self, x): - self.count.value += 1 - x = self.linear1(x) - x = jax.nn.relu(x) - x = self.linear2(x) - return x - - -static, params, counts = MLP(din=1, dhidden=32, dout=1, rngs=nnx.Rngs(0)).split( - nnx.Param, ... 
-) - -class TrainState(nnx.TrainState[MLP]): - counts: nnx.State - - -state = TrainState.create( - static, - params=params, - tx=optax.sgd(0.1), - counts=counts, -) -del params, counts - - -@jax.jit -def train_step(state: TrainState, batch): - x, y = batch - - def loss_fn(params): - y_pred, (_, updates) = state.apply(params, 'counts')(x) - counts = updates.extract(Count) - loss = jnp.mean((y - y_pred) ** 2) - return loss, counts - - grads, counts = jax.grad(loss_fn, has_aux=True)(state.params) - # sdg update - state = state.apply_gradients(grads=grads, counts=counts) - - return state - - -@jax.jit -def test_step(state: TrainState, batch): - x, y = batch - y_pred, _ = state.apply('params', 'counts')(x) - loss = jnp.mean((y - y_pred) ** 2) - return {'loss': loss} - - -total_steps = 10_000 -for step, batch in enumerate(dataset(32)): - state = train_step(state, batch) - - if step % 1000 == 0: - logs = test_step(state, (X, Y)) - print(f"step: {step}, loss: {logs['loss']}") - - if step >= total_steps - 1: - break - -model = static.merge(state.params, state.counts) -print('times called:', model.count.value) - -y_pred = model(X) - -plt.scatter(X, Y, color='blue') -plt.plot(X, y_pred, color='black') -plt.show() diff --git a/flax/experimental/nnx/examples/toy_examples/05_vae.py b/flax/experimental/nnx/examples/toy_examples/05_vae.py index 6d3e148fae..895dcd894b 100644 --- a/flax/experimental/nnx/examples/toy_examples/05_vae.py +++ b/flax/experimental/nnx/examples/toy_examples/05_vae.py @@ -14,7 +14,6 @@ # %% import typing as tp -from functools import partial import jax import jax.numpy as jnp @@ -54,8 +53,9 @@ def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): self.linear1 = nnx.Linear(din, dmid, rngs=rngs) self.linear_mean = nnx.Linear(dmid, dout, rngs=rngs) self.linear_std = nnx.Linear(dmid, dout, rngs=rngs) + self.rngs = rngs - def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: + def __call__(self, x: jax.Array) -> jax.Array: x = 
x.reshape((x.shape[0], -1)) # flatten x = self.linear1(x) x = jax.nn.relu(x) @@ -68,7 +68,7 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: 0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1) ) ) - key = rngs.noise() + key = self.rngs.noise() z = mean + std * jax.random.normal(key, mean.shape) return z @@ -101,8 +101,8 @@ def __init__( latent_size, hidden_size, int(np.prod(output_shape)), rngs=rngs ) - def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: - z = self.encoder(x, rngs=rngs) + def __call__(self, x: jax.Array) -> jax.Array: + z = self.encoder(x) logits = self.decoder(z) logits = jnp.reshape(logits, (-1, *self.output_shape)) return logits @@ -113,60 +113,48 @@ def generate(self, z): return nnx.sigmoid(logits) -static, params = VAE( +model = VAE( din=int(np.prod(image_shape)), hidden_size=256, latent_size=latent_size, output_shape=image_shape, - rngs=nnx.Rngs(0), -).split(nnx.Param) - -state = nnx.TrainState.create( - static, - params=params, - tx=optax.adam(1e-3), + rngs=nnx.Rngs(0, noise=1), ) +optimizer = nnx.Optimizer(model, optax.adam(1e-3)) -# %% -@jax.jit -def train_step(state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array): - def loss_fn(params: nnx.State): - rngs = nnx.Rngs(noise=jax.random.fold_in(key, state.step)) - logits, (_, updates) = state.apply(params)(x, rngs=rngs) - losses = updates.extract(Loss) +# %% +@nnx.jit +def train_step(model: VAE, optimizer: nnx.Optimizer, x: jax.Array): + def loss_fn(model: VAE): + logits = model(x) + losses = nnx.pop(model, Loss) kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0) reconstruction_loss = jnp.mean( optax.sigmoid_binary_cross_entropy(logits, x) ) - # jax.debug.print("kl_loss={kl_loss}", kl_loss=kl_loss) - loss = reconstruction_loss + 0.1 * kl_loss return loss - loss, grads = jax.value_and_grad(loss_fn)(state.params) - state = state.apply_gradients(grads=grads) + loss, grads = nnx.value_and_grad(loss_fn)(model) + optimizer.update(grads) - 
return state, loss + return loss -@partial(jax.jit, donate_argnums=(0,)) -def forward( - state: nnx.TrainState[VAE], x: jax.Array, key: jax.Array -) -> jax.Array: - rngs = nnx.Rngs(noise=key) - y_pred = state.apply('params')(x, rngs=rngs)[0] +@nnx.jit +def forward(model: VAE, x: jax.Array) -> jax.Array: + y_pred = model(x) return jax.nn.sigmoid(y_pred) -@jax.jit -def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: - return state.apply('params').generate(z)[0] +@nnx.jit +def sample(model: VAE, z: jax.Array) -> jax.Array: + return model.generate(z) # %% -key = jax.random.key(0) for epoch in range(epochs): losses = [] @@ -174,7 +162,7 @@ def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: idxs = np.random.randint(0, len(X_train), size=(batch_size,)) x_batch = X_train[idxs] - state, loss = train_step(state, x_batch, key) + loss = train_step(model, optimizer, x_batch) losses.append(np.asarray(loss)) print(f'Epoch {epoch} loss: {np.mean(losses)}') @@ -186,7 +174,7 @@ def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: x_sample = X_test[idxs] # get predictions -y_pred = forward(state, x_sample, key) +y_pred = forward(model, x_sample) # plot reconstruction figure = plt.figure(figsize=(3 * 5, 3 * 2)) @@ -203,7 +191,7 @@ def sample(state: nnx.TrainState[VAE], z: jax.Array) -> jax.Array: # %% # plot generative samples z_samples = np.random.normal(scale=1.5, size=(12, latent_size)) -samples = sample(state, z_samples) +samples = sample(model, z_samples) figure = plt.figure(figsize=(3 * 5, 3 * 2)) plt.title('Generative Samples') diff --git a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py index 0dee1ea798..9a2b01727c 100644 --- a/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py +++ b/flax/experimental/nnx/examples/toy_examples/06_scan_over_layers.py @@ -12,7 +12,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. -from typing import Tuple +from functools import partial import jax import jax.numpy as jnp @@ -23,13 +23,11 @@ class Block(nnx.Module): def __init__(self, dim: int, *, rngs: nnx.Rngs): self.linear = nnx.Linear(dim, dim, rngs=rngs) - self.dropout = nnx.Dropout(0.5) + self.bn = nnx.BatchNorm(dim, rngs=rngs) + self.dropout = nnx.Dropout(0.5, rngs=rngs) - def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: - x = self.linear(x) - x = self.dropout(x, rngs=rngs) - x = jax.nn.gelu(x) - return x + def __call__(self, x: jax.Array): + return jax.nn.gelu(self.dropout(self.bn(self.linear(x)))) class ScanMLP(nnx.Module): @@ -41,47 +39,28 @@ class ScanMLP(nnx.Module): def __init__(self, dim: int, *, n_layers: int, rngs: nnx.Rngs): self.n_layers = n_layers - # fork Rngs, split keys into `n_layers` - keys = rngs.fork(n_layers) - - def create_block(keys): - # create Block instance and return its split - return Block(dim, rngs=nnx.Rngs(keys)).split() - - # call vmap over create_block, passing the split `params` key - # and immediately merge to get a Block instance - self.layers = nnx.merge(*jax.vmap(create_block)(keys)) - - def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: - # fork Rngs, split keys into `n_layers` - keys = rngs.fork(self.n_layers) - # split Module to get params - static, params = self.layers.split(nnx.Param) - - def scan_fn( - x: jax.Array, inputs: Tuple[nnx.State, dict[str, nnx.RngStream]] - ) -> Tuple[jax.Array, nnx.State]: - params, keys = inputs - # merge back Module and Rngs - module = static.merge(params) - # forward pass - x = module(x, rngs=nnx.Rngs(keys)) - # split state and return - params, _ = module.split(nnx.Param) - return x, params - - # call scan passing x as the carry, and params + keys as the input - x, params = jax.lax.scan(scan_fn, x, (params, keys)) - # update layers state and return - self.layers.update(params) + + @partial(nnx.vmap, axis_size=n_layers) + def 
create_block(rngs: nnx.Rngs): + return Block(dim, rngs=rngs) + + self.layers = create_block(rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + @nnx.scan + def scan_fn(x: jax.Array, block: Block): + x = block(x) + return x, None + + x, _ = scan_fn(x, self.layers) + return x model = ScanMLP(10, n_layers=5, rngs=nnx.Rngs(0)) x = jnp.ones((3, 10)) -model.set_attributes(deterministic=False) -y = model(x, rngs=nnx.Rngs(dropout=1)) +y = model(x) -print(jax.tree_util.tree_map(jnp.shape, model.get_state())) +print(jax.tree_util.tree_map(jnp.shape, nnx.state(model))) print(y.shape) diff --git a/flax/experimental/nnx/examples/toy_examples/07_transformer.py b/flax/experimental/nnx/examples/toy_examples/07_transformer.py deleted file mode 100644 index 2af899816c..0000000000 --- a/flax/experimental/nnx/examples/toy_examples/07_transformer.py +++ /dev/null @@ -1,414 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import dataclasses -import typing as tp - -import jax -import jax.numpy as jnp -import numpy as np -from jax.sharding import PartitionSpec as P - -from flax.experimental import nnx - -ShardSpec = tp.Union[str, tp.Tuple[str, ...], None] - - -# Sharding -@dataclasses.dataclass -class Sharding: - batch: ShardSpec = 'data' - sequence: ShardSpec = None - layers: ShardSpec = None - vocab: ShardSpec = 'model' - embed: ShardSpec = None - heads: ShardSpec = 'model' - depth: ShardSpec = None - hidden: ShardSpec = 'model' - - -# Config -@dataclasses.dataclass -class Config: - # mode - decode: bool = False - # shapes - batch: int = 16 - layers: int = 2 - vocab: int = 1024 - embed: int = 64 - heads: int = 12 - depth: int = 64 - hidden: int = 256 - max_length: int = 256 - # dtypes - param_dtype: tp.Any = jnp.float32 - dtype: tp.Any = jnp.float32 - # sharding - sharding: Sharding = Sharding() - scanned: bool = False - # layer params - epsilon: float = 1e-6 - dropout_rate: float = 0.0 - rp_num_buckets: int = 32 - rp_max_distance: int = 128 - - -cfg = Config() - - -def nd_dense_init(scale, mode, distribution): - """Initializer with in_axis, out_axis set at call time.""" - - def init_fn(key, shape, dtype, in_axis, out_axis) -> jax.Array: - fn = jax.nn.initializers.variance_scaling( - scale, mode, distribution, in_axis, out_axis - ) - return fn(key, shape, dtype) - - return init_fn - - -dense_init = nd_dense_init(1.0, 'fan_in', 'truncated_normal') -embed_init = nd_dense_init(1.0, 'fan_in', 'normal') - - -def make_attention_mask( - query_input: tp.Any, - key_input: tp.Any, - pairwise_fn: tp.Callable = jnp.multiply, - dtype: tp.Any = jnp.float32, -): - mask = pairwise_fn( - jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) - ) - return jnp.expand_dims(mask, axis=-3).astype(dtype) - - -def make_causal_mask(x, dtype=jnp.float32): - idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask(idxs, idxs, 
jnp.greater_equal, dtype=dtype) - - -# padding mask -# make_attention_mask(decoder_target_tokens > 0, decoder_target_tokens > 0, dtype=dtype) -# packing mask -# make_attention_mask(decoder_segment_ids, decoder_segment_ids, jnp.equal, dtype=dtype) - - -def sine_table(features, length, min_timescale=1.0, max_timescale=10000.0): - fraction = jnp.arange(0, features, 2, dtype=jnp.float32) / features - timescale = min_timescale * (max_timescale / min_timescale) ** fraction - rotational_frequency = 1.0 / timescale - # Must use high precision einsum here, bfloat16 rounding is catastrophic. - sinusoid_inp = jnp.einsum( - 'i,j->ij', - jnp.arange(length), - rotational_frequency, - precision=jax.lax.Precision.HIGHEST, - ) - sinusoid_inp = jnp.concatenate([sinusoid_inp, sinusoid_inp], axis=-1) - return jnp.sin(sinusoid_inp), jnp.cos(sinusoid_inp) - - -def rotate_half(x): - x1, x2 = jnp.split(x, 2, axis=-1) - x = jnp.concatenate([-x2, x1], axis=-1) - return x - - -def apply_rotary_embedding(q, k, cos, sin, index=None): - """Helper function to apply Rotary Embeddings.""" - batch, qlen, qheads, d = q.shape - kbatch, klen, kheads, kd = k.shape - if index is not None: - qcos = jax.lax.broadcast_in_dim( - cos[index, :], (batch, qlen, qheads, d), (3,) - ) - qsin = jax.lax.broadcast_in_dim( - sin[index, :], (batch, qlen, qheads, d), (3,) - ) - else: - qcos = jax.lax.broadcast_in_dim( - cos[:qlen, :], (batch, qlen, qheads, d), (1, 3) - ) - qsin = jax.lax.broadcast_in_dim( - sin[:qlen, :], (batch, qlen, qheads, d), (1, 3) - ) - kcos = jax.lax.broadcast_in_dim( - cos[:klen, :], (batch, klen, kheads, d), (1, 3) - ) - ksin = jax.lax.broadcast_in_dim( - sin[:klen, :], (batch, klen, kheads, d), (1, 3) - ) - out_q = (q * qcos) + (rotate_half(q) * qsin) - out_k = (k * kcos) + (rotate_half(k) * ksin) - return out_q, out_k - - -def rms_norm(cfg, scale, x): - x = jnp.asarray(x, jnp.float32) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 
cfg.epsilon), cfg.dtype) - return y * jnp.asarray(scale, cfg.dtype) - - -def dropout(cfg: Config, x, broadcast_dims=(-2,), *, rngs: nnx.Rngs): - if cfg.dropout_rate == 0.0: - return x - broadcast_shape = list(x.shape) - for dim in broadcast_dims: - broadcast_shape[dim] = 1 - keep_rate = 1.0 - cfg.dropout_rate - key = rngs.dropout() - mask = jax.random.bernoulli(key, p=keep_rate, shape=broadcast_shape) - return jax.lax.select( - jnp.broadcast_to(mask, x.shape), x / keep_rate, jnp.zeros_like(x) - ) - - -class Attention(nnx.Module): - def __init__(self, cfg: Config, *, rngs: nnx.Rngs): - sharding = cfg.sharding - - key = rngs.params() - self.WQ = nnx.Param( - dense_init( - key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) - ), - P(sharding.embed, sharding.heads, sharding.depth), - ) - key = rngs.params() - self.WK = nnx.Param( - dense_init( - key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) - ), - P(sharding.embed, sharding.heads, sharding.depth), - ) - key = rngs.params() - self.WV = nnx.Param( - dense_init( - key, (cfg.embed, cfg.heads, cfg.depth), cfg.param_dtype, 0, (1, 2) - ), - P(sharding.embed, sharding.heads, sharding.depth), - ) - key = rngs.params() - self.WO = nnx.Param( - dense_init( - key, (cfg.heads, cfg.depth, cfg.embed), cfg.param_dtype, (0, 1), 2 - ), - P(sharding.heads, sharding.depth, sharding.embed), - ) - # cache - self.index = nnx.variable('cache', jnp.array(0, dtype=jnp.int32), P()) - self.key = nnx.variable( - 'cache', - jnp.zeros( - (cfg.batch, cfg.heads, cfg.depth, cfg.max_length), - jnp.bfloat16, - ), - P(sharding.batch, sharding.heads, sharding.depth, None), - ) - self = nnx.variable( - 'cache', - jnp.zeros( - (cfg.batch, cfg.heads, cfg.depth, cfg.max_length), - jnp.bfloat16, - ), - P(sharding.batch, sharding.heads, sharding.depth, None), - ) - - # We combine the cache and params into "vs", but it would be no harder at all - # to thread through a separate "cache" argument storing cache entries. 
- def __call__(self, cfg: Config, x_q, x_kv, mask=None, *, rngs: nnx.Rngs): - q = jnp.einsum('bse,enh->bsnh', x_q, self.WQ.astype(cfg.dtype)).astype( - jnp.float32 - ) - k = jnp.einsum('bte,enh->btnh', x_kv, self.WK.astype(cfg.dtype)).astype( - jnp.float32 - ) - v = jnp.einsum('bte,enh->btnh', x_kv, self.WV.astype(cfg.dtype)) - - index = None - if cfg.decode: - index = self.index - one_hot_indices = jax.nn.one_hot( - self.index, cfg.max_length, dtype=cfg.dtype - ) - self.key = self.key + jnp.moveaxis(k, -3, -1) * one_hot_indices - self = self + jnp.moveaxis(v, -3, -1) * one_hot_indices - k = jnp.moveaxis(self.key, -1, -3) - v = jnp.moveaxis(self, -1, -3) - cache_mask = jnp.broadcast_to( - jnp.arange(cfg.max_length) <= self.index, - (cfg.batch, 1, 1, cfg.max_length), - ) - mask = jnp.logical_and( - cache_mask if mask is None else mask, cache_mask - ).astype(cfg.dtype) - self.index = self.index + 1 - - attention_bias = 0.0 - if mask is None: # Hack in lieu of general mask routing. - mask = make_causal_mask(x, jnp.float32) - if mask is not None: - attention_bias = jax.lax.select( - mask > 0, - jnp.full(mask.shape, 0.0, cfg.dtype), - jnp.full(mask.shape, -1e10, cfg.dtype), - ) - - sin, cos = sine_table(q.shape[-1], max(q.shape[1], k.shape[1])) - q, k = apply_rotary_embedding(q, k, cos, sin, index=index) - - l = ( - jnp.einsum('bsnh,btnh->bnst', q, k) / np.sqrt(cfg.depth) + attention_bias - ) - s = jax.nn.softmax(l).astype(cfg.dtype) - s = dropout(cfg, s, rngs=rngs) - a = jnp.einsum('bnst,btnh->bsnh', s, v) - o = jnp.einsum('bsnh,nhe->bse', a, self.WO.astype(cfg.dtype)) - - return o - - -class MLP(nnx.Module): - def __init__(self, cfg: Config, *, rngs: nnx.Rngs): - sharding = cfg.sharding - self.Win1 = nnx.Param( - dense_init( - rngs.params(), - (cfg.embed, cfg.hidden), - cfg.param_dtype, - 0, - 1, - ), - P(sharding.embed, sharding.hidden), - ) - self.Win2 = nnx.Param( - dense_init( - rngs.params(), - (cfg.embed, cfg.hidden), - cfg.param_dtype, - 0, - 1, - ), - 
P(sharding.embed, sharding.hidden), - ) - self.Wout = nnx.Param( - dense_init( - rngs.params(), - (cfg.hidden, cfg.embed), - cfg.param_dtype, - 0, - 1, - ), - P(sharding.hidden, sharding.embed), - ) - - def __call__(self, cfg: Config, x, *, rngs: nnx.Rngs): - h1 = jnp.einsum('bse,eh->bsh', x, self.Win1.astype(cfg.dtype)) - h2 = jnp.einsum('bse,eh->bsh', x, self.Win2.astype(cfg.dtype)) - h = jax.nn.gelu(h1) * h2 - h = dropout(cfg, h, rngs=rngs) - o = jnp.einsum('bsh,he->bse', h, self.Wout.astype(cfg.dtype)) - return o - - -class DecoderBlock(nnx.Module): - def __init__(self, cfg: Config, *, rngs: nnx.Rngs): - sharding = cfg.sharding - self.attn = Attention(cfg, rngs=rngs) - self.mlp = MLP(cfg, rngs=rngs) - self.scale1 = nnx.Param( - jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) - ) - self.scale2 = nnx.Param( - jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) - ) - - def __call__(self, cfg: Config, input, *, rngs: nnx.Rngs): - x = rms_norm(cfg, self.scale1, input) - x = self.attn(cfg, x, x, mask=None, rngs=rngs) - x = dropout(cfg, x, rngs=rngs) - x = x + input - y = rms_norm(cfg, self.scale2, x) - y = self.mlp(cfg, y, rngs=rngs) - y = dropout(cfg, y, rngs=rngs) - return y + x - - -class Decoder(nnx.Module): - def __init__(self, cfg: Config, *, rngs: nnx.Rngs): - sharding = cfg.sharding - self.embed = nnx.Param( - embed_init( - rngs.params(), - (cfg.vocab, cfg.embed), - cfg.param_dtype, - 1, - 0, - ), - P(sharding.vocab, sharding.embed), - ) - self.unembed = nnx.Param( - dense_init(rngs.params(), (cfg.embed, cfg.vocab), jnp.float32, 0, 1), - P(sharding.embed, sharding.vocab), - ) - self.scale1 = nnx.Param( - jnp.ones((cfg.embed,), cfg.param_dtype), P(sharding.embed) - ) - - if cfg.scanned: - self.layers = nnx.merge( - *jax.vmap(lambda key: DecoderBlock(cfg, rngs=nnx.Rngs(key)).split())( - jax.random.split(rngs.params(), cfg.layers) - ) - ) - else: - self.layers = nnx.List( - DecoderBlock(cfg, rngs=rngs) for _ in range(cfg.layers) - ) - - def 
__call__(self, cfg: Config, x, *, rngs: nnx.Rngs): - # TODO: handle right-shifting for training: here or in train loop. - # TODO: handle general mask routing. - x = self.embed.astype(cfg.dtype)[x] - - if cfg.scanned: - assert isinstance(self.layers, DecoderBlock) - - static, state = self.layers.split() - rngs, rngsdef = rngs.fork() - dropout_key = jax.random.split(rngs['dropout'], cfg.layers) - - def scan_fn(x, s: tp.Tuple[jax.Array, nnx.State]): - dropout_key, state = s - rngs = rngsdef.merge({'dropout': dropout_key}) - y, (state, _) = static.apply(state)(cfg, x, rngs=rngs) - return y, state - - x, state = jax.lax.scan( - scan_fn, - x, - (dropout_key, state), - ) - self.layers.update(state) - else: - assert isinstance(self.layers, nnx.List) - for decoder_block in self.layers: - x = decoder_block(cfg, x, rngs=rngs) - - x = jnp.einsum('bse,ev->bsv', x, self.unembed) - return x diff --git a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py b/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py index 4d6e20e93c..281a290f1f 100644 --- a/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py +++ b/flax/experimental/nnx/examples/toy_examples/08_save_load_checkpoints.py @@ -39,7 +39,7 @@ def create_model(seed: int): def create_and_save(seed: int, path: str): model = create_model(seed) - state = model.get_state() + state = nnx.state(model) # Save the parameters checkpointer = orbax.PyTreeCheckpointer() checkpointer.save(f'{path}/state', state) @@ -47,12 +47,13 @@ def create_and_save(seed: int, path: str): def load_model(path: str) -> MLP: # create that model with abstract shapes - static, state = jax.eval_shape(lambda: create_model(0).split()) + model = nnx.eval_shape(lambda: create_model(0)) + state = nnx.state(model) # Load the parameters checkpointer = orbax.PyTreeCheckpointer() state = checkpointer.restore(f'{path}/state', item=state) - # Merge the parameters into the model - model = static.merge(state) + # 
update the model with the loaded state + nnx.update(model, state) return model diff --git a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py index c81db63eef..c7f5dd07f7 100644 --- a/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py +++ b/flax/experimental/nnx/examples/toy_examples/09_parameter_surgery.py @@ -46,11 +46,16 @@ def __call__(self, x): # create a filter to select all the parameters that are not part of the # backbone, i.e. the classifier parameters is_trainable = lambda path, node: ( - path.startswith('backbone') and isinstance(node, nnx.Param) + 'backbone' in path and isinstance(node, nnx.Param) ) # split the parameters into trainable and non-trainable parameters -trainable_params, non_trainable, static = model.split(is_trainable, ...) +graphdef, trainable_params, non_trainable = nnx.split(model, is_trainable, ...) -print('trainable_params =', jax.tree_util.tree_map(jax.numpy.shape, trainable_params)) -print('non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable)) +print( + 'trainable_params =', + jax.tree_util.tree_map(jax.numpy.shape, trainable_params), +) +print( + 'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable) +) diff --git a/flax/experimental/nnx/examples/toy_examples/10_quantization.py b/flax/experimental/nnx/examples/toy_examples/10_quantization.py deleted file mode 100644 index b58db649fb..0000000000 --- a/flax/experimental/nnx/examples/toy_examples/10_quantization.py +++ /dev/null @@ -1,443 +0,0 @@ -# Copyright 2024 The Flax Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# %% -import typing as tp -from functools import partial - -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt -import numpy as np -import optax -from datasets import load_dataset - -from flax.experimental import nnx - -np.random.seed(42) -image_shape: tp.Sequence[int] = (28, 28) -steps_per_epoch: int = 200 -batch_size: int = 64 -epochs: int = 20 - - -@jax.custom_vjp -def diff_round(x) -> jax.Array: - y = jnp.round(x) - return y - - -def diff_round_fwd(x): - return diff_round(x), None - - -def diff_round_bwd(_, g): - return (g,) - - -diff_round.defvjp(diff_round_fwd, diff_round_bwd) - - -@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) -def diff_clip(x, low, high) -> jax.Array: - return jnp.clip(x, low, high) - - -def diff_clip_fwd(x, low, high): - return diff_clip(x, low, high), None - - -def diff_clip_bwd(_, _1, _2, dy): - return (dy,) - - -diff_clip.defvjp(diff_clip_fwd, diff_clip_bwd) - - -# %% -def f(x): - return diff_clip(diff_round(x * 128) + 128, 0, 255) - - -df = jax.vmap(jax.grad(f)) - -x = jnp.linspace(-1.5, 1.5, 100) -dx = df(x) - -plt.plot(x, dx) - -# %% -dataset = load_dataset('mnist') -X_train = np.array(np.stack(dataset['train']['image']), dtype=np.float32) -Y_train = np.array(dataset['train']['label'], dtype=np.int32) -X_test = np.array(np.stack(dataset['test']['image']), dtype=np.float32) -Y_test = np.array(dataset['test']['label'], dtype=np.int32) -# normalize data -X_train = X_train / 255.0 -X_test = X_test / 255.0 - - -print('X_train:', X_train.shape, X_train.dtype) -print('X_test:', X_test.shape, X_test.dtype) - - 
-# %% -class MLP(nnx.Module): - def __init__(self, din: int, dmid: int, dout: int, *, rngs: nnx.Rngs): - self.linear1 = nnx.Linear(din, dmid, rngs=rngs) - self.linear2 = nnx.Linear(dmid, dout, rngs=rngs) - - def __call__(self, x: jax.Array) -> jax.Array: - x = x.reshape((x.shape[0], -1)) - x = self.linear1(x) - x = jax.nn.gelu(x) - x = self.linear2(x) - return x - - -static, params = MLP( - din=np.prod(image_shape), dmid=256, dout=10, rngs=nnx.Rngs(0) -).split(nnx.Param) - -state = nnx.TrainState.create( - static, - params=params, - tx=optax.adam(1e-3), -) - - -# %% -@jax.jit -def train_step( - state: nnx.TrainState[MLP], - inputs: jax.Array, - labels: jax.Array, -): - def loss_fn(params: nnx.State): - logits, _ = state.apply(params)(inputs) - loss = jnp.mean( - optax.softmax_cross_entropy_with_integer_labels(logits, labels) - ) - return loss - - grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(state.params) - state = state.apply_gradients(grads=grads) - - return state, loss - - -@jax.jit -def eval_step(state: nnx.TrainState[MLP], inputs: jax.Array, labels: jax.Array): - logits, _ = state.apply('params')(inputs) - loss = jnp.mean( - optax.softmax_cross_entropy_with_integer_labels(logits, labels) - ) - acc = jnp.mean(jnp.argmax(logits, axis=-1) == labels) - return {'loss': loss, 'accuracy': acc} - - -@partial(jax.jit, donate_argnums=(0,)) -def forward(state: nnx.TrainState[MLP], inputs: jax.Array) -> jax.Array: - y_pred = state.apply('params')(inputs)[0] - return jnp.argmax(y_pred, axis=-1) - - -# %% -key = jax.random.key(0) - -for epoch in range(epochs): - for step in range(steps_per_epoch): - idxs = np.random.randint(0, len(X_train), size=(batch_size,)) - x_batch = X_train[idxs] - y_batch = Y_train[idxs] - - state, loss = train_step(state, x_batch, y_batch) - - metrics = eval_step(state, X_test, Y_test) - metrics = jax.tree_util.tree_map(lambda x: x.item(), metrics) - print(f'Epoch {epoch} - {metrics}') - -# %% -# get random samples -idxs = 
np.random.randint(0, len(X_test), size=(10,)) -x_sample = X_test[idxs] -y_sample = Y_test[idxs] - -# get predictions -y_pred = forward(state, x_sample) - -# plot predictions -figure = plt.figure(figsize=(3 * 5, 3 * 2)) - -for i in range(5): - plt.subplot(2, 5, i + 1) - plt.imshow(x_sample[i].reshape(image_shape), cmap='gray') - plt.title(f'{y_pred[i]}') - -plt.show() - -model = state.graphdef.merge(state.params) -# %% -# Quantization - -A = tp.TypeVar('A') - - -class QParam(nnx.Variable[A]): - pass - - -class QHParam(nnx.Variable[A]): - pass - - -class QLinear(nnx.Module): - def __init__(self, din: int, dout: int): - self.scale = QHParam(jnp.array(0.5)) - self.zero_point = QHParam(jnp.array(0.5)) - self.qkernel = QParam(jnp.zeros((din, dout))) - self.qbias = QParam(jnp.zeros((dout,))) - - def __call__(self, x: jax.Array) -> jax.Array: - x = self.quantize(x, 8, jnp.uint8) - print(x.shape, self.qkernel.value.shape, self.qbias.value.shape) - x = jnp.dot(x, self.qkernel.value, preferred_element_type=jnp.uint16) - x = (x + self.qbias.value).astype(jnp.uint32) - x = self.dequantize(x) - return x - - def quantize(self, x: jax.Array, b: int, dtype: jnp.dtype) -> jax.Array: - return jnp.clip( - diff_round(x / self.scale.value) + self.zero_point.value, 0, 2**b - 1 - ).astype(dtype) - - def dequantize(self, x: jax.Array) -> jax.Array: - return (x - self.zero_point.value) * self.scale.value - - def optimize( - self, - pretrained: nnx.Linear, - x: jax.Array, - *, - num_steps: int = 100, - debug: bool = False, - ): - static, q_hparams, rest = self.split(QHParam, ...) 
- tx = optax.adam(1e-3) - opt_state = tx.init(q_hparams) - - print(jax.tree_util.tree_map(lambda x: x.shape, q_hparams)) - - @jax.jit - def optimization_step( - q_hparams: nnx.State, - rest: nnx.State, - opt_state: optax.OptState, - x: jax.Array, - ): - print('JITTING') - - def loss_fn(q_hparams: nnx.State): - model = static.merge(q_hparams, rest) - model.qkernel.value = model.quantize( - pretrained.kernel.value, 8, jnp.uint8 - ) - assert pretrained.bias is not None - model.qbias.value = model.quantize( - pretrained.bias.value, 16, jnp.uint16 - ) - - y_quant = model(x) - y_unquant = pretrained(x) - loss = jnp.mean((y_unquant - y_quant) ** 2) - return loss - - loss, grads = jax.value_and_grad(loss_fn)(q_hparams) - - updates, opt_state = tx.update(grads, opt_state, q_hparams) - q_hparams = optax.apply_updates(q_hparams, updates) # type: ignore - - return q_hparams, opt_state, loss - - for step in range(num_steps): - q_hparams, opt_state, loss = optimization_step( - q_hparams, rest, opt_state, x - ) - if debug and step % (num_steps / 10) == 0: - print(f'Step {step} - loss: {loss}') - - self.update(q_hparams) - - self.qkernel.value = self.quantize(pretrained.kernel.value, 8, jnp.uint8) - assert pretrained.bias.value is not None - self.qbias.value = self.quantize(pretrained.bias.value, 16, jnp.uint16) - - -def optimize2( - self, - pretrained: nnx.Linear, - X: jax.Array, -): - W = pretrained.kernel - b = pretrained.bias - assert b is not None - - # X - alpha_X = jnp.min(X) - beta_X = jnp.max(X) - s_X, z_X = generate_quantization_int8_constants(alpha=alpha_X, beta=beta_X) - X_q = quantization_int8(x=X, s=s_X, z=z_X) - X_q_dq = dequantization(x_q=X_q, s=s_X, z=z_X) - - # W - alpha_W = jnp.min(W) - beta_W = jnp.max(W) - s_W, z_W = generate_quantization_int8_constants(alpha=alpha_W, beta=beta_W) - W_q = quantization_int8(x=W, s=s_W, z=z_W) - W_q_dq = dequantization(x_q=W_q, s=s_W, z=z_W) - - # b - alpha_b = jnp.min(b) - beta_b = jnp.max(b) - s_b, z_b = 
generate_quantization_int8_constants(alpha=alpha_b, beta=beta_b) - b_q = quantization_int8(x=b, s=s_b, z=z_b) - b_q_dq = dequantization(x_q=b_q, s=s_b, z=z_b) - - # Y - Y = jnp.matmul(X, W) + b - alpha_Y = jnp.min(Y) - beta_Y = jnp.max(Y) - s_Y, z_Y = generate_quantization_int8_constants(alpha=alpha_Y, beta=beta_Y) - Y_q = quantization_int8(x=Y, s=s_Y, z=z_Y) - - Y_prime = jnp.matmul(X_q_dq, W_q_dq) + b_q_dq - Y_prime_q = quantization_int8(x=Y_prime, s=s_Y, z=z_Y) - Y_prime_q_dq = dequantization(x_q=Y_prime_q, s=s_Y, z=z_Y) - - print('Expected FP32 Y:') - print(Y) - print('Expected FP32 Y Quantized:') - print(Y_q) - - Y_q_simulated = quantization_matrix_multiplication_int8( - X_q=X_q, - W_q=W_q, - b_q=b_q, - s_X=s_X, - z_X=z_X, - s_W=s_W, - z_W=z_W, - s_b=s_b, - z_b=z_b, - s_Y=s_Y, - z_Y=z_Y, - ) - Y_simulated = dequantization(x_q=Y_q_simulated, s=s_Y, z=z_Y) - - print('Expected Quantized Y_q from Quantized Matrix Multiplication:') - print(Y_q_simulated) - print( - 'Expected Quantized Y_q from Quantized Matrix Multiplication Dequantized:' - ) - print(Y_simulated) - - # Ensure the algorithm implementation is correct - assert jnp.array_equal(Y_simulated, Y_prime_q_dq) - assert jnp.array_equal(Y_q_simulated, Y_prime_q) - - -def quantization(x, s, z, alpha_q, beta_q): - x_q = jnp.round(1 / s * x + z, decimals=0) - x_q = jnp.clip(x_q, a_min=alpha_q, a_max=beta_q) - - return x_q - - -def quantization_int8(x, s, z): - x_q = quantization(x, s, z, alpha_q=-128, beta_q=127) - x_q = x_q.astype(jnp.int8) - - return x_q - - -def dequantization(x_q, s, z): - # x_q - z might go outside the quantization range. 
- x_q = x_q.astype(jnp.int32) - x = s * (x_q - z) - x = x.astype(jnp.float32) - - return x - - -def generate_quantization_constants(alpha, beta, alpha_q, beta_q): - # Affine quantization mapping - s = (beta - alpha) / (beta_q - alpha_q) - z = int((beta * alpha_q - alpha * beta_q) / (beta - alpha)) - - return s, z - - -def generate_quantization_int8_constants(alpha, beta): - b = 8 - alpha_q = -(2 ** (b - 1)) - beta_q = 2 ** (b - 1) - 1 - - s, z = generate_quantization_constants( - alpha=alpha, beta=beta, alpha_q=alpha_q, beta_q=beta_q - ) - - return s, z - - -def quantization_matrix_multiplication_int8( - X_q, W_q, b_q, s_X, z_X, s_W, z_W, s_b, z_b, s_Y, z_Y -): - p = W_q.shape[0] - - # Y_q_simulated is FP32 - Y_q_simulated = ( - z_Y - + (s_b / s_Y * (b_q.astype(jnp.int32) - z_b)) - + ( - (s_X * s_W / s_Y) - * ( - jnp.matmul(X_q.astype(jnp.int32), W_q.astype(jnp.int32)) - - z_W * jnp.sum(X_q.astype(jnp.int32), axis=1, keepdims=True) - - z_X * jnp.sum(W_q.astype(jnp.int32), axis=0, keepdims=True) - + p * z_X * z_W - ) - ) - ) - - Y_q_simulated = jnp.round(Y_q_simulated, decimals=0) - Y_q_simulated = jnp.clip(Y_q_simulated, a_min=-128, a_max=127) - Y_q_simulated = Y_q_simulated.astype(jnp.int8) - - return Y_q_simulated - - -# %% -qlinear1 = QLinear(din=np.prod(image_shape), dout=256) -# qlinear2 = QLinear(din=256, dout=10) - -idxs = np.random.randint(0, len(X_test), size=(100,)) -x_optimize = jnp.asarray(X_test[idxs], dtype=jnp.float32) -x_optimize = x_optimize.reshape((x_optimize.shape[0], -1)) -print(x_optimize.shape) -qlinear1.optimize(model.linear1, x_optimize, num_steps=100, debug=True) - -# %% - -# %% diff --git a/flax/experimental/nnx/ideas/shape_inference.py b/flax/experimental/nnx/ideas/shape_inference.py deleted file mode 100644 index bff4df2717..0000000000 --- a/flax/experimental/nnx/ideas/shape_inference.py +++ /dev/null @@ -1,210 +0,0 @@ -# Copyright 2024 The Flax Authors. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing as tp - -import jax -import jax.numpy as jnp -from jax import random - -from flax.experimental import nnx - - -class Linear(nnx.Module): - @tp.overload - def __init__(self, *, din: int, dout: int, rngs: nnx.Rngs): - ... - - @tp.overload - def __init__(self, *, dout: int): - ... - - @tp.overload - def __init__( - self, - *, - din: tp.Optional[int] = None, - dout: int, - rngs: tp.Optional[nnx.Rngs] = None, - ): - ... - - def __init__( - self, - *, - din: tp.Optional[int] = None, - dout: int, - rngs: tp.Optional[nnx.Rngs] = None, - ): - self.dout = dout - if din is not None: - if rngs is None: - raise ValueError('rngs must be provided if din is provided') - self.init_variables(din, rngs) - - def init_variables(self, din: int, rngs: nnx.Rngs): - key = rngs.params() - self.w = nnx.Param(random.uniform(key, (din, self.dout))) - self.b = nnx.Param(jnp.zeros((self.dout,))) - - def __call__( - self, x: jax.Array, *, rngs: tp.Optional[nnx.Rngs] = None - ) -> jax.Array: - if self.is_initializing and not hasattr(self, 'w'): - if rngs is None: - raise ValueError('rngs must be provided to initialize module') - self.init_variables(x.shape[-1], rngs) - - return x @ self.w + self.b - - -class BatchNorm(nnx.Module): - @tp.overload - def __init__(self, *, mu: float = 0.95): - ... - - @tp.overload - def __init__(self, *, din: int, mu: float = 0.95, rngs: nnx.Rngs): - ... 
- - @tp.overload - def __init__( - self, - *, - din: tp.Optional[int] = None, - mu: float = 0.95, - rngs: tp.Optional[nnx.Rngs] = None, - ): - ... - - def __init__( - self, - *, - din: tp.Optional[int] = None, - mu: float = 0.95, - rngs: tp.Optional[nnx.Rngs] = None, - ): - self.mu = mu - - if din is not None: - if rngs is None: - raise ValueError('rngs must be provided if din is provided') - self.init_variables(din, rngs) - - def init_variables(self, din: int, rngs: nnx.Rngs): - self.scale = nnx.Param(jax.numpy.ones((din,))) - self.bias = nnx.Param(jax.numpy.zeros((din,))) - self.mean = nnx.BatchStat(jax.numpy.zeros((din,))) - self.var = nnx.BatchStat(jax.numpy.ones((din,))) - - def __call__( - self, x, *, train: bool, rngs: tp.Optional[nnx.Rngs] = None - ) -> jax.Array: - if self.is_initializing and not hasattr(self, 'scale'): - if rngs is None: - raise ValueError('rngs must be provided to initialize module') - self.init_variables(x.shape[-1], rngs) - - if train: - axis = tuple(range(x.ndim - 1)) - mean = jax.numpy.mean(x, axis=axis) - var = jax.numpy.var(x, axis=axis) - # ema update - self.mean = self.mu * self.mean + (1 - self.mu) * mean - self.var = self.mu * self.var + (1 - self.mu) * var - else: - mean, var = self.mean, self.var - - scale, bias = self.scale, self.bias - x = (x - mean) / jax.numpy.sqrt(var + 1e-5) * scale + bias - return x - - -class Dropout(nnx.Module): - def __init__(self, rate: float): - self.rate = rate - - def __call__(self, x: jax.Array, *, train: bool, rngs: nnx.Rngs) -> jax.Array: - if train: - mask = random.bernoulli(rngs.dropout(), (1 - self.rate), x.shape) - x = x * mask / (1 - self.rate) - return x - - -# ---------------------------- -# test Linear -# ---------------------------- -print('test Linear') - -# eager -m1 = Linear(din=32, dout=10, rngs=nnx.Rngs(params=0)) -y = m1(x=jnp.ones((1, 32))) -print(jax.tree_util.tree_map(jnp.shape, m1.get_state())) - -# lazy -m2 = Linear(dout=10) -y = m2.init(x=jnp.ones((1, 32)), 
rngs=nnx.Rngs(params=0)) -print(jax.tree_util.tree_map(jnp.shape, m2.get_state())) - -# usage -y1 = m1(x=jnp.ones((1, 32))) -y2 = m2(x=jnp.ones((1, 32))) - -# ---------------------------- -# Test scan -# ---------------------------- -print('\ntest scan') - - -class Block(nnx.Module): - def __init__( - self, - din: tp.Optional[int] = None, - dout: int = 10, - rngs: tp.Optional[nnx.Rngs] = None, - ): - self.linear = Linear(din=din, dout=dout, rngs=rngs) - self.bn = BatchNorm(din=dout if din is not None else None, rngs=rngs) - self.dropout = Dropout(0.5) - - def __call__(self, x: jax.Array, _, *, train: bool, rngs: nnx.Rngs): - x = self.linear(x, rngs=rngs) - x = self.bn(x, train=train, rngs=rngs) - x = self.dropout(x, train=train, rngs=rngs) - x = jax.nn.gelu(x) - return x, None - - -MLP = nnx.Scan( - Block, - variable_axes={nnx.Param: 0}, - variable_carry=nnx.BatchStat, - split_rngs={'params': True, 'dropout': True}, - length=5, -) - - -# eager -mlp = MLP(din=10, dout=10, rngs=nnx.Rngs(params=0)) -y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) -print(f'{y.shape=}') -print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state())) -print() - -# lazy -mlp = MLP(dout=10) -mlp.init(jnp.ones((1, 10)), None, train=False, rngs=nnx.Rngs(params=0)) -y, _ = mlp.call(jnp.ones((1, 10)), None, train=True, rngs=nnx.Rngs(dropout=1)) -print(f'{y.shape=}') -print('state =', jax.tree_util.tree_map(jnp.shape, mlp.get_state())) diff --git a/flax/experimental/nnx/nnx/compatibility.py b/flax/experimental/nnx/nnx/compatibility.py index 8d31c5f0f8..50a954e65b 100644 --- a/flax/experimental/nnx/nnx/compatibility.py +++ b/flax/experimental/nnx/nnx/compatibility.py @@ -16,6 +16,7 @@ import typing as tp from typing import Any +from flax.experimental import nnx from flax import linen from flax.experimental.nnx.nnx import variables as variableslib from flax.experimental.nnx.nnx.module import GraphDef, Module @@ -38,7 +39,7 @@ def init(self, *, rngs: 
tp.Optional[Rngs] = None) -> State: if rngs is not None: kwargs['rngs'] = rngs module = self.module_type(*self.args, **self.kwargs, **kwargs) - graphdef, state = module.split() + graphdef, state = nnx.split(module) self.graphdef = graphdef return state diff --git a/flax/experimental/nnx/nnx/filterlib.py b/flax/experimental/nnx/nnx/filterlib.py index 25bb67cfc7..b6c406f114 100644 --- a/flax/experimental/nnx/nnx/filterlib.py +++ b/flax/experimental/nnx/nnx/filterlib.py @@ -23,13 +23,19 @@ ellipsis = tp.Any Predicate = tp.Callable[[PathParts, tp.Any], bool] + FilterLiteral = tp.Union[type, str, Predicate, bool, ellipsis, None] -Filter = tp.Union[FilterLiteral, tuple[FilterLiteral, ...], list[FilterLiteral]] +Filter = tp.Union[FilterLiteral, tuple['Filter', ...], list['Filter']] + @tp.runtime_checkable class _HasTag(tp.Protocol): tag: str +@tp.runtime_checkable +class _HasType(tp.Protocol): + type: type + def to_predicate(filter: Filter) -> Predicate: if isinstance(filter, str): @@ -63,7 +69,11 @@ class OfType: type: type def __call__(self, path: PathParts, x: tp.Any): - return isinstance(x, self.type) + return ( + isinstance(x, self.type) + or isinstance(x, _HasType) + and issubclass(x.type, self.type) + ) class Any: @@ -87,7 +97,7 @@ def __call__(self, path: PathParts, x: tp.Any): class Not: - def __init__(self, collection_filter: Filter): + def __init__(self, collection_filter: Filter, /): self.predicate = to_predicate(collection_filter) def __call__(self, path: PathParts, x: tp.Any): diff --git a/flax/experimental/nnx/nnx/graph_utils.py b/flax/experimental/nnx/nnx/graph.py similarity index 64% rename from flax/experimental/nnx/nnx/graph_utils.py rename to flax/experimental/nnx/nnx/graph.py index 8d56776fe9..71dafa2592 100644 --- a/flax/experimental/nnx/nnx/graph_utils.py +++ b/flax/experimental/nnx/nnx/graph.py @@ -21,14 +21,14 @@ from abc import ABCMeta from copy import deepcopy + import jax +import numpy as np import typing_extensions as tpe from 
flax.experimental.nnx.nnx import ( errors, filterlib, - graph_utils, - ids, reprlib, tracers, ) @@ -37,8 +37,13 @@ CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.state import State, StateLeaf, is_state_leaf -from flax.experimental.nnx.nnx.variables import EMPTY, Empty, Variable +from flax.experimental.nnx.nnx.state import ( + FlatState, + State, + StateLeaf, + is_state_leaf, +) +from flax.experimental.nnx.nnx.variables import Variable, VariableState from flax.typing import PathParts, Key A = tp.TypeVar('A') @@ -63,13 +68,19 @@ tuple[State, ...], ] +NodeLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] + + +def is_node_leaf(x: tp.Any) -> tpe.TypeGuard[NodeLeaf]: + return isinstance(x, (Variable, np.ndarray, jax.Array)) + @dataclasses.dataclass class GraphUtilsContext(threading.local): node_types: dict[ type, 'NodeImpl[tp.Any, tp.Any, tp.Any]' ] = dataclasses.field(default_factory=dict) - seen_modules_repr: tp.Optional[tp.Set[ids.UUID]] = None + seen_modules_repr: set[int] | None = None CONTEXT = GraphUtilsContext() @@ -99,8 +110,11 @@ def __eq__(self, other: tp.Any) -> bool: class RefMap(tp.MutableMapping[A, B], reprlib.MappingReprMixin[A, B]): """A mapping that uses object id as the hash for the keys.""" - def __init__(self): + def __init__( + self, mapping: tp.Mapping[A, B] | tp.Iterable[tuple[A, B]] = (), / + ): self._mapping: dict[_HashById[A], B] = {} + self.update(mapping) def __getitem__(self, key: A) -> B: return self._mapping[_HashById(key)] @@ -175,9 +189,7 @@ def register_graph_node_type( def is_node(x: tp.Any) -> bool: - if isinstance(x, Variable): - return False - elif type(x) in CONTEXT.node_types: + if type(x) in CONTEXT.node_types: return True return is_pytree_node(x) @@ -239,162 +251,72 @@ def __repr__(self) -> str: return repr(self._mapping) -@dataclasses.dataclass(repr=False) -class _MappingRepr(reprlib.Representable): - mapping: tp.Mapping[Key, tp.Any] - - def __nnx_repr__(self): - yield reprlib.Object(type='', 
value_sep=': ', start='{', end='}') - - for key, value in self.mapping.items(): - yield reprlib.Attr(repr(key), value) - - -class VariableDef(reprlib.Representable): - __slots__ = ( - '_type', - '_index', - '_metadata', - ) - - def __init__( - self, - type: tp.Type[Variable[tp.Any]], - index: int, - metadata: dict[Key, tp.Any], - ): - self._type = type - self._index = index - self._metadata = metadata - - def __nnx_repr__(self): - yield reprlib.Object(type=type(self)) - - yield reprlib.Attr('type', self._type.__name__) - yield reprlib.Attr('index', self._index) - yield reprlib.Attr('metadata', _MappingRepr(self._metadata)) - - @property - def type(self): - return self._type - - @property - def index(self): - return self._index - - @property - def metadata(self): - return self._metadata +@dataclasses.dataclass(frozen=True, repr=False) +class NodeDef(tp.Generic[Node], reprlib.Representable): + type: tp.Type[Node] + index: int + attributes: tuple[Key, ...] + subgraphs: _HashableMapping[Key, tp.Union['NodeDef[tp.Any]', Index]] + static_fields: _HashableMapping[Key, tp.Any] + variables: _HashableMapping[Key, Index] + metadata: tp.Any @classmethod - def from_variable(cls, variable: Variable[tp.Any], index: int) -> VariableDef: - metadata = vars(variable).copy() - del metadata['raw_value'] - del metadata['_trace_state'] - return cls(type(variable), index, metadata) - - def to_variable(self, value: Node) -> Variable[Node]: - # we use object.__new__ to avoid calling __init__ and bypass the - # __init__ logic which should not be called twice - variables = object.__new__(self._type) - vars(variables).update( - self._metadata, raw_value=value, _trace_state=tracers.TraceState() - ) - return variables - - def __hash__(self): - return hash((self._type, self._index, tuple(self._metadata.items()))) - - def __eq__(self, other): - if not isinstance(other, VariableDef): - return False - return ( - self._type == other._type - and self._index == other._index - and self._metadata == 
other._metadata - ) - - -class GraphDef(tp.Generic[Node], reprlib.Representable): - __slots__ = ( - '_type', - '_index', - '_attributes', - '_subgraphs', - '_static_fields', - '_variables', - '_metadata', - ) - - def __init__( - self, + def create( + cls, type: tp.Type[Node], index: int, attributes: tuple[Key, ...], - subgraphs: tp.Iterable[tuple[Key, tp.Union['GraphDef[tp.Any]', int]]], + subgraphs: tp.Iterable[tuple[Key, tp.Union['GraphDef[tp.Any]', Index]]], static_fields: tp.Iterable[tuple[Key, tp.Any]], - variables: tp.Iterable[tuple[Key, VariableDef | int]], + variables: tp.Iterable[tuple[Key, Index]], metadata: tp.Any, ): - self._type: type[Node] = type - self._index = index - self._attributes = attributes - self._subgraphs = _HashableMapping(subgraphs) - self._static_fields = _HashableMapping(static_fields) - self._variables = _HashableMapping(variables) - self._metadata = metadata + return cls( + type=type, + index=index, + attributes=attributes, + subgraphs=_HashableMapping(subgraphs), + static_fields=_HashableMapping(static_fields), + variables=_HashableMapping(variables), + metadata=metadata, + ) def __nnx_repr__(self): yield reprlib.Object(type=type(self)) - yield reprlib.Attr('type', self._type.__name__) - yield reprlib.Attr('index', self._index) - yield reprlib.Attr('attributes', self._attributes) - yield reprlib.Attr('subgraphs', _MappingRepr(self._subgraphs)) - yield reprlib.Attr('static_fields', _MappingRepr(self._static_fields)) - yield reprlib.Attr('variables', _MappingRepr(self._variables)) - yield reprlib.Attr('metadata', self._metadata) - - def __hash__(self) -> int: - return hash((self._type, self._subgraphs)) - - def __eq__(self, other: tp.Any) -> bool: - if not isinstance(other, GraphDef): - return False - return self._type == other._type and self._subgraphs == other._subgraphs - - @property - def type(self) -> tp.Type[Node]: - return self._type + yield reprlib.Attr('type', self.type.__name__) + yield reprlib.Attr('index', self.index) + 
yield reprlib.Attr('attributes', self.attributes) + yield reprlib.Attr('subgraphs', reprlib.PrettyMapping(self.subgraphs)) + yield reprlib.Attr( + 'static_fields', reprlib.PrettyMapping(self.static_fields) + ) + yield reprlib.Attr('variables', reprlib.PrettyMapping(self.variables)) + yield reprlib.Attr('metadata', self.metadata) - @property - def index(self) -> int: - return self._index - @property - def attributes(self) -> tuple[str, ...]: - return self._attributes +@dataclasses.dataclass(frozen=True, repr=False) +class GraphDef(tp.Generic[Node], reprlib.Representable): + nodedef: NodeDef[Node] + index_mapping: dict[Index, Index] | None - @property - def subgraphs(self): - return self._subgraphs + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) - @property - def static_fields(self): - return self._static_fields + yield reprlib.Attr('nodedef', self.nodedef) + yield reprlib.Attr('index_mapping', self.index_mapping) - @property - def variables(self): - return self._variables + def __deepcopy__(self, memo=None): + nodedef = deepcopy(self.nodedef, memo) + index_mapping = deepcopy(self.index_mapping, memo) + return GraphDef(nodedef, index_mapping) - @property - def metadata(self) -> tp.Any: - return self._metadata + def __hash__(self): + return hash(self.nodedef) - def merge(self, state: State, /, *states: State) -> Node: - if states: - state = State.merge(state, *states) - return graph_unflatten(self, state)[0] + def __eq__(self, other): + return isinstance(other, GraphDef) and self.nodedef == other.nodedef def apply( self, state: State, *states: State @@ -404,105 +326,98 @@ def apply( def _apply( accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tuple[GraphDef[Node], State]]: - module = self.merge(state, *states) + module = merge(self, state, *states) fn = accessor(module) out = fn(*args, **kwargs) - return out, graph_flatten(module)[:2] + return out, flatten(module)[:2] return CallableProxy(_apply, accessor) # type: ignore def 
make_empty(self) -> Node: - return self.merge(State({})) - - -def _gradphdef_flatten(graphdef: GraphDef[tp.Any]): - return (), ( - graphdef._type, - graphdef._index, - graphdef._attributes, - graphdef._subgraphs, - graphdef._static_fields, - graphdef._variables, - graphdef._metadata, - ) + return merge(self, State({})) + + +def _graphdef_flatten(graphdef: GraphDef[Node]): + # refmap is opaque, we don't propagate it + static = (graphdef.nodedef, graphdef.index_mapping) + return (), static def _graphdef_unflatten( - metadata: tuple[ - tp.Type[Node], - int, - tuple[Key, ...], - tuple[tuple[Key, GraphDef[Node] | int], ...], - tuple[tuple[Key, tp.Any], ...], - tuple[tuple[Key, Variable[Empty] | int], ...], - tp.Any, - ], - _, -) -> GraphDef[Node]: - return GraphDef(*metadata) + static: tuple[NodeDef[Node], dict[Index, Index] | None], _nodes: tuple[()] +): + nodedef, index_mapping = static + return GraphDef(nodedef, index_mapping) jax.tree_util.register_pytree_node( - GraphDef, _gradphdef_flatten, _graphdef_unflatten + GraphDef, + _graphdef_flatten, + _graphdef_unflatten, ) -def graph_flatten( +def flatten( x: Node, /, -) -> tuple[GraphDef[Node], State, tp.Mapping[tp.Any, Index]]: - ref_to_index = RefMap[tp.Any, Index]() + *, + idxmap: dict[Index, tp.Any] | None = None, +) -> tuple[GraphDef[Node], State, RefMap[tp.Any, Index]]: + refmap = RefMap[tp.Any, Index]() flat_state: dict[PathParts, StateLeaf] = {} - graphdef = _graph_flatten((), ref_to_index, flat_state, x) - assert not isinstance(graphdef, int) - return graphdef, State.from_flat_path(flat_state), ref_to_index + nodedef = _graph_flatten((), refmap, flat_state, x) + assert not isinstance(nodedef, int) + if idxmap is not None: + index_to_index = compose_mapping(idxmap, refmap) + else: + index_to_index = None + graphdef = GraphDef(nodedef, index_to_index) + return graphdef, State.from_flat_path(flat_state), refmap def _graph_flatten( path: PathParts, - ref_to_index: RefMap[tp.Any, Index], + refmap: RefMap[tp.Any, 
Index], flat_state: dict[PathParts, StateLeaf], node: Node, -) -> GraphDef[Node] | int: +) -> NodeDef[Node] | int: if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}, this is a bug.') - if node in ref_to_index: - return ref_to_index[node] + if node in refmap: + return refmap[node] node_impl = get_node_impl(node) # only cache graph nodes if isinstance(node_impl, GraphNodeImpl): - index = len(ref_to_index) - ref_to_index[node] = index + index = len(refmap) + refmap[node] = index else: index = -1 - subgraphs: list[tuple[Key, tp.Union[GraphDef[Node], int]]] = [] + subgraphs: list[tuple[Key, tp.Union[NodeDef[Node], int]]] = [] static_fields: list[tuple[Key, tp.Any]] = [] - variables: list[tuple[Key, VariableDef | int]] = [] + variables: list[tuple[Key, int]] = [] values, metadata = node_impl.flatten(node) for key, value in values: if is_node(value): - graphdef = _graph_flatten((*path, key), ref_to_index, flat_state, value) - subgraphs.append((key, graphdef)) + nodedef = _graph_flatten((*path, key), refmap, flat_state, value) + subgraphs.append((key, nodedef)) elif isinstance(value, Variable): - if value in ref_to_index: - variables.append((key, ref_to_index[value])) + if value in refmap: + variables.append((key, refmap[value])) else: - flat_state[(*path, key)] = value.copy() - variable_index = ref_to_index[value] = len(ref_to_index) - variables.append( - (key, VariableDef.from_variable(value, variable_index)) - ) + flat_state[(*path, key)] = value.to_state() + variable_index = refmap[value] = len(refmap) + variables.append((key, variable_index)) elif is_state_leaf(value): flat_state[(*path, key)] = value else: static_fields.append((key, value)) - graphdef = GraphDef( + nodedef = NodeDef.create( type=node_impl.type, index=index, attributes=tuple(key for key, _ in values), @@ -511,20 +426,20 @@ def _graph_flatten( variables=variables, metadata=metadata, ) - return graphdef + return nodedef -def graph_unflatten( +def unflatten( graphdef: 
GraphDef[Node], state: State, /, *, - ref_cache: dict[Index, tp.Any] | None = None, + idxmap: dict[Index, tp.Any] | None = None, ) -> tuple[Node, dict[Index, tp.Any]]: """Unflattens a graphdef into a node with the given state. Args: - graphdef: A GraphDef instance. + graphdef: A NodeDef instance. state: A State instance. ref_cache: A mapping from indexes to existing nodes that can be reused. When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the @@ -533,20 +448,22 @@ def graph_unflatten( specified by the graphdef. """ index_to_ref: dict[Index, tp.Any] = {} - node = _graph_unflatten(graphdef, state.raw_mapping, index_to_ref, ref_cache) + node = _graph_unflatten( + graphdef.nodedef, state.raw_mapping, index_to_ref, idxmap + ) return node, index_to_ref def _graph_unflatten( - graphdef: tp.Union[GraphDef[Node], int], - state: dict[str, StateLeaf | dict[str, tp.Any]], + nodedef: tp.Union[NodeDef[Node], int], + state: dict[Key, StateLeaf | dict[Key, tp.Any]], index_to_ref: dict[Index, tp.Any], - ref_cache: dict[Index, tp.Any] | None, + idxmap: dict[Index, tp.Any] | None, ) -> Node: """Recursive helper for graph_unflatten. Args: - graphdef: A GraphDef instance or an index to a node in the cache. + nodedef: A NodeDef instance or an index to a node in the cache. state: A mapping from attribute names to variables or subgraphs. index_to_ref: A mapping from indexes to nodes that have been traversed. If a node is already in the cache, it won't be traversed again. @@ -554,132 +471,124 @@ def _graph_unflatten( When an reference is reused, ``GraphNodeImpl.clear`` is called to leave the object in an empty state and then filled by the unflatten process, as a result existing graph nodes are mutated to have the new content/topology - specified by the graphdef. + specified by the nodedef. 
""" - if isinstance(graphdef, int): - return index_to_ref[graphdef] + if isinstance(nodedef, int): + return index_to_ref[nodedef] - if not is_node_type(graphdef.type): - raise RuntimeError(f'Unsupported type: {graphdef.type}, this is a bug.') + if not is_node_type(nodedef.type): + raise RuntimeError(f'Unsupported type: {nodedef.type}, this is a bug.') - if graphdef.index in index_to_ref: - raise RuntimeError(f'GraphDef index {graphdef.index} already used.') + if nodedef.index in index_to_ref: + raise RuntimeError(f'NodeDef index {nodedef.index} already used.') - node_impl = get_node_impl_for_type(graphdef.type) + node_impl = get_node_impl_for_type(nodedef.type) def _get_children(): - children: dict[str, StateLeaf | Node] = {} + children: dict[Key, NodeLeaf | Node] = {} + + if unkown_keys := set(state) - set(nodedef.attributes): + raise ValueError(f'Unknown keys: {unkown_keys}') - for key in graphdef.attributes: - if key in graphdef.static_fields: - children[key] = graphdef.static_fields[key] + for key in nodedef.attributes: + if key in nodedef.static_fields: + children[key] = nodedef.static_fields[key] elif key not in state: # TODO(cgarcia): maybe we shouldn't support unflattening with missing keys? 
# if key is not present create an empty types - if key in graphdef.subgraphs: + if key in nodedef.subgraphs: # if the key is a subgraph we create an empty node - subgraphdef = graphdef.subgraphs[key] + subgraphdef = nodedef.subgraphs[key] if isinstance(subgraphdef, int): # subgraph exists, take it from the cache children[key] = index_to_ref[subgraphdef] else: - # create an empty node and add it to the cache + # create an empty node substate = {} - node = children[key] = _graph_unflatten( - subgraphdef, substate, index_to_ref, ref_cache + children[key] = _graph_unflatten( + subgraphdef, substate, index_to_ref, idxmap ) - elif key in graphdef.variables: - variable_def = graphdef.variables[key] - if isinstance(variable_def, int): + elif key in nodedef.variables: + variable_index = nodedef.variables[key] + if variable_index in index_to_ref: # variable exists, take it from the cache - children[key] = index_to_ref[variable_def] + children[key] = index_to_ref[variable_index] else: - # create an empty variable and add it to the cache - if ref_cache is not None and variable_def.index in ref_cache: - node = ref_cache[variable_def.index] - if type(node) != variable_def.type: - raise ValueError( - f'Expected a node of type {variable_def.type.__name__} for ' - f'index {variable_def.index}, but got a node of type ' - f'{type(node).__name__}.' - ) - assert isinstance(node, Variable) - node.copy_from_def(variable_def, EMPTY) - else: - node = variable_def.to_variable(EMPTY) - children[key] = node - index_to_ref[variable_def.index] = node + # key for a variable is missing, raise an error + raise ValueError( + f'Expected key for Variable but was not found in state: {key!r}' + ) else: raise RuntimeError(f'Unknown static field: {key!r}') else: value = state[key] - if key in graphdef.subgraphs: + if key in nodedef.subgraphs: if is_state_leaf(value): raise ValueError( - f'Expected a subgraph for {key!r}, but got a Variable.' 
+ f'Expected value of type {nodedef.subgraphs[key]} for ' + f'{key!r}, but got {value!r}' ) assert isinstance(value, dict) - subgraphdef = graphdef.subgraphs[key] + subgraphdef = nodedef.subgraphs[key] if isinstance(subgraphdef, int): node = index_to_ref[subgraphdef] else: node = children[key] = _graph_unflatten( - subgraphdef, value, index_to_ref, ref_cache + subgraphdef, value, index_to_ref, idxmap ) - elif key in graphdef.variables: - variable_def = graphdef.variables[key] - if isinstance(variable_def, int): - children[key] = index_to_ref[variable_def] + elif key in nodedef.variables: + variable_index = nodedef.variables[key] + if variable_index in index_to_ref: + children[key] = index_to_ref[variable_index] else: - if type(value) != variable_def.type: + if not isinstance(value, VariableState): raise ValueError( - f'Expected a Variable of type {variable_def.type} ' - f'for {key!r}, but got a Variable of type {type(value)}.' + f'Expected a Variable type for {key!r}, but got {type(value)}.' ) - assert isinstance(value, Variable) - if ref_cache is not None and variable_def.index in ref_cache: - variable = ref_cache[variable_def.index] - if type(variable) != variable_def.type: + if idxmap is not None and variable_index in idxmap: + variable = idxmap[variable_index] + if not isinstance(variable, Variable): raise ValueError( - f'Expected a Variable of type {variable_def.type} for ' - f'{key!r}, but got a Variable of type {type(variable)}.' + f'Expected a Variable type for {key!r}, but got {type(variable)}.' 
) - variable.copy_from(value) + variable.copy_from_state(value) else: - assert isinstance(value, Variable) - variable = value.copy() + assert isinstance(value, VariableState) + variable = value.to_variable() children[key] = variable - index_to_ref[variable_def.index] = variable + index_to_ref[variable_index] = variable elif is_state_leaf(value): + if isinstance(value, VariableState): + value = value.to_variable() children[key] = value - for new_key in set(state) - set(graphdef.attributes): - raise ValueError(f'Unknown key: {new_key!r}') + else: + raise RuntimeError return children if isinstance(node_impl, GraphNodeImpl): # we create an empty node first and add it to the index # this avoids infinite recursion when there is a reference cycle - if ref_cache is not None and graphdef.index in ref_cache: - node = ref_cache[graphdef.index] - if type(node) != graphdef.type: + if idxmap is not None and nodedef.index in idxmap: + node = idxmap[nodedef.index] + if type(node) != nodedef.type: raise ValueError( - f'Expected a node of type {graphdef.type} for index ' - f'{graphdef.index}, but got a node of type {type(node)}.' + f'Expected a node of type {nodedef.type} for index ' + f'{nodedef.index}, but got a node of type {type(node)}.' 
) - node_impl.clear(node, graphdef.metadata) + node_impl.clear(node, nodedef.metadata) else: - node = node_impl.create_empty(graphdef.metadata) - index_to_ref[graphdef.index] = node + node = node_impl.create_empty(nodedef.metadata) + index_to_ref[nodedef.index] = node children = _get_children() node_impl.init(node, tuple(children.items())) else: # if the node type does not support the creation of an empty object it means # that it cannot reference itself, so we can create its children first children = _get_children() - node = node_impl.unflatten(tuple(children.items()), graphdef.metadata) + node = node_impl.unflatten(tuple(children.items()), nodedef.metadata) return node @@ -716,10 +625,14 @@ def _graph_pop( for name, value in node_dict.items(): if is_node(value): _graph_pop( - value, id_to_index, (*path_parts, name), flat_states, predicates + node=value, + id_to_index=id_to_index, + path_parts=(*path_parts, name), + flat_states=flat_states, + predicates=predicates, ) continue - elif not is_state_leaf(value): + elif not is_node_leaf(value): continue elif id(value) in id_to_index: continue @@ -735,7 +648,7 @@ def _graph_pop( id_to_index[id(value)] = len(id_to_index) node_impl.pop_key(node, name) if isinstance(value, Variable): - value = value.copy() + value = value.to_state() state[node_path] = value break else: @@ -743,23 +656,7 @@ def _graph_pop( pass -def graph_update_dynamic( - node: tp.Any, - updates: State | tp.Sequence[State], -) -> None: - if not is_node(node): - raise ValueError(f'Unsupported type: {type(node)}') - - if isinstance(updates, State): - new_states = (updates,) - else: - new_states = updates - - for state in new_states: - _graph_update_dynamic(node, state.raw_mapping) - - -def _graph_update_dynamic(node: tp.Any, state: dict[str, tp.Any]): +def _graph_update_dynamic(node: tp.Any, state: dict[Key, tp.Any]): if not is_node(node): raise RuntimeError(f'Unsupported type: {type(node)}') @@ -786,14 +683,14 @@ def _graph_update_dynamic(node: tp.Any, 
state: dict[str, tp.Any]): if is_state_leaf(value): raise ValueError(f'Expected a subgraph for {key!r}, but got: {value!r}') _graph_update_dynamic(current_value, value) - elif isinstance(value, Variable): + elif isinstance(value, VariableState): # case 3: state leaf is being updated if not isinstance(current_value, Variable): raise ValueError( f'Trying to update a non-Variable attribute {key!r} with a Variable: ' f'{value!r}' ) - current_value.copy_from(value) + current_value.copy_from_state(value) elif is_state_leaf(value): # case 4: state field is being updated if isinstance(node_impl, PytreeNodeImpl): @@ -813,6 +710,14 @@ class _StaticModuleStatus(enum.Enum): UPDATED = enum.auto() +# TODO(cgarciae): remove once transform init are reimplemented +def update_from(node: Node, updates: Node) -> None: + graph_update_static(node, updates) + _, state = split(updates) + update(node, state) + + +# TODO(cgarciae): remove once transform init are reimplemented def graph_update_static(node: Node, updates: Node) -> None: cache: dict[int, _StaticModuleStatus] = {} _graph_update_static(node, updates, cache, _StaticModuleStatus.UPDATED, ()) @@ -903,14 +808,116 @@ def _graph_update_static( node_impl.set_key(node, name, value_updates) +@dataclasses.dataclass +class UpdateContext: + refmap: RefMap[tp.Any, Index] | None = None + idxmap: dict[Index, tp.Any] | None = None + + # define context manager to clean up refmap and idxmap + # on exit + def __enter__(self): + return self + + def __exit__(self, *args, **kwargs): + self.refmap = None + self.idxmap = None + + # define hash and eq to make this an opaque object + def __hash__(self): + return 0 + + def __eq__(self, other): + return isinstance(other, UpdateContext) + + @tp.overload + def split(self, graph_node: A, /) -> tuple[GraphDef[A], State]: + ... + + @tp.overload + def split( + self, graph_node: A, first: filterlib.Filter, / + ) -> tuple[GraphDef[A], State]: + ... 
+ + @tp.overload + def split( + self, + graph_node: A, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, + ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: + ... + + def split( + self, node: A, *filters: filterlib.Filter + ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: + if self.refmap is not None and self.idxmap is None: + raise ValueError( + "'merge' was not called in-between the first and second call to 'split'" + ) + graphdef, state, refmap = flatten(node, idxmap=self.idxmap) + + if len(filters) == 0: + states = (state,) + elif len(filters) == 1: + states = (state.split(filters[0]),) + else: + states = state.split(filters[0], filters[1], *filters[2:]) + + if self.refmap is None: + self.refmap = refmap + + return graphdef, states[0], *states[1:] + + def merge( + self, + graphdef: GraphDef[A], + state: State, + *states: State, + ) -> A: + # TODO: add docstring of example usage + if states: + state = State.merge(state, *states) + + node, self.idxmap = unflatten(graphdef, state) + return node + + def update( + self, + new_graphdef: GraphDef[A], + state: State, + /, + *states: State, + ): + if self.refmap is None: + raise ValueError('Cannot update a graphdef without refmap.') + if new_graphdef.index_mapping is None: + raise ValueError('Cannot update a graphdef without index_mapping.') + + if states: + state = State.merge(state, *states) + + index_to_ref = compose_mapping_reversed( + self.refmap, new_graphdef.index_mapping + ) + return unflatten(new_graphdef, state, idxmap=index_to_ref)[0] + + +jax.tree_util.register_static(UpdateContext) + + @tp.overload -def split(graph_node: A) -> tuple[GraphDef[A], State]: +def split(graph_node: A, /) -> tuple[GraphDef[A], State]: ... @tp.overload def split( - graph_node: A, first: filterlib.Filter, / + graph_node: A, + first: filterlib.Filter, + /, ) -> tuple[GraphDef[A], State]: ... 
@@ -927,9 +934,9 @@ def split( def split( - graph_node: A, *filters: filterlib.Filter + node: A, *filters: filterlib.Filter ) -> tuple[GraphDef[A], State, tpe.Unpack[tuple[State, ...]]]: - graphdef, state, _ = graph_flatten(graph_node) + graphdef, state, _ = flatten(node) if len(filters) == 0: states = (state,) @@ -944,63 +951,114 @@ def split( def merge( graphdef: GraphDef[A], state: State, + /, *states: State, ) -> A: - # TODO: add docstring of example usage - return graphdef.merge(state, *states) + if states: + state = State.merge(state, *states) + node, _ = unflatten(graphdef, state) + return node -def update(graph_node: A, update: Updates[A], /, *updates: Updates[A]) -> None: - updates = (update, *updates) - # find states and module_update - leaves = jax.tree_util.tree_leaves( - updates, is_leaf=lambda x: isinstance(x, (GraphDef, State)) - ) - states: list[State] = [] - module_update: tp.Optional[A] = None +def update(node, state: State, /, *states: State) -> None: + if states: + state = State.merge(state, *states) - for leaf in leaves: - if is_graph_node(leaf) or isinstance(leaf, GraphDef): - if module_update is not None: - raise ValueError( - 'Expected only one GraphDef or GraphNode in the updates' - ) + _graph_update_dynamic(node, state.raw_mapping) - if is_graph_node(leaf): - if not isinstance(leaf, type(graph_node)): - raise ValueError( - 'Expected a GraphNode of the same type as the input, ' - f'got {type(leaf).__name__} instead.' - ) - module_update = leaf - states.append(split(leaf)[1]) - elif isinstance(leaf, GraphDef): - module_update = leaf.make_empty() - else: - raise ValueError( - 'Expected a GraphDef or graph node, got' f' {type(leaf).__name__}' - ) - elif isinstance(leaf, State): - states.append(leaf) - else: - raise ValueError( - 'Expected a GraphDef, GraphNode or State, got' f' {type(leaf).__name__}' - ) - if module_update is not None: - graph_update_static(graph_node, module_update) +@tp.overload +def state(node, /) -> State: + ... 
- if states: - graph_update_dynamic(graph_node, states) + +@tp.overload +def state(node, first: filterlib.Filter, /) -> State: + ... + + +@tp.overload +def state( + node, + first: filterlib.Filter, + second: filterlib.Filter, + /, + *filters: filterlib.Filter, +) -> tuple[State, ...]: + ... + + +def state( + node, + *filters: filterlib.Filter, +) -> tp.Union[State, tuple[State, ...]]: + state = flatten(node)[1] + + if len(filters) == 0: + states = state + elif len(filters) == 1: + states = state.filter(filters[0]) + else: + states = state.filter(filters[0], filters[1], *filters[1:]) + + return states + + +def graphdef(node: tp.Any, /) -> GraphDef[tp.Any]: + graphdef, _, _ = flatten(node) + return graphdef + + +@tp.overload +def pop( + node, + filter: filterlib.Filter, + /, +) -> State: + ... + + +@tp.overload +def pop( + node, + filter: filterlib.Filter, + filter2: filterlib.Filter, + /, + *filters: filterlib.Filter, +) -> tuple[State, ...]: + ... + + +def pop(node, *filters: filterlib.Filter) -> tp.Union[State, tuple[State, ...]]: + if len(filters) == 0: + raise ValueError('Expected at least one filter') + + id_to_index: dict[int, Index] = {} + path_parts: PathParts = () + predicates = tuple(filterlib.to_predicate(filter) for filter in filters) + flat_states: tuple[FlatState, ...] 
= tuple({} for _ in predicates) + _graph_pop( + node=node, + id_to_index=id_to_index, + path_parts=path_parts, + flat_states=flat_states, + predicates=predicates, + ) + states = tuple(State.from_flat_path(flat_state) for flat_state in flat_states) + + if len(states) == 1: + return states[0] + else: + return states def clone(node: Node) -> Node: - static, state, _ = graph_flatten(node) - return static.merge(state) + graphdef, state = split(node) + return merge(graphdef, state) -def iter_nodes(node: tp.Any) -> tp.Iterator[tuple[PathParts, tp.Any]]: +def iter_nodes(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]: visited: set[int] = set() path_parts: PathParts = () yield from _iter_nodes(node, visited, path_parts) @@ -1053,7 +1111,7 @@ class Static(tp.Generic[A]): class GraphNodeIndex: """Index of a graph node in a Pytree structure.""" - index: int + index: Index jax.tree_util.register_static(GraphNodeIndex) @@ -1065,7 +1123,10 @@ def extract_graph_nodes(pytree: A, /) -> tuple[A, tuple[tp.Any, ...]]: def _maybe_extract(x): if is_graph_node(x): - index = nodes.setdefault(x, len(nodes)) + if x not in nodes: + index = nodes[x] = len(nodes) + else: + index = nodes[x] return GraphNodeIndex(index) return x @@ -1091,19 +1152,19 @@ def _maybe_insert(x): class ModuleState(reprlib.Representable): - __slots__ = ('_trace_state', '_id') + __slots__ = ('_trace_state', '_initializing') - def __init__(self): + def __init__(self, initializing: bool = False): self._trace_state = tracers.TraceState() - self._id = ids.uuid() + self._initializing = initializing @property def trace_state(self) -> tracers.TraceState: return self._trace_state @property - def id(self) -> ids.UUID: - return self._id + def initializing(self) -> bool: + return self._initializing def __nnx_repr__(self): yield reprlib.Object(type(self)) @@ -1124,6 +1185,14 @@ def _graph_node_meta_call(cls: tp.Type[G], *args, **kwargs) -> G: return node +@dataclasses.dataclass(frozen=True, repr=False) +class Array: + 
shape: tp.Tuple[int, ...] + dtype: tp.Any + + def __repr__(self): + return f'Array(shape={self.shape}, dtype={self.dtype.name})' + class GraphNode(reprlib.Representable, metaclass=GraphNodeMeta): if tp.TYPE_CHECKING: @@ -1132,7 +1201,7 @@ class GraphNode(reprlib.Representable, metaclass=GraphNodeMeta): def __init_subclass__(cls) -> None: super().__init_subclass__() - graph_utils.register_graph_node_type( + register_graph_node_type( type=cls, flatten=cls._graph_node_flatten, set_key=cls._graph_node_set_key, @@ -1157,13 +1226,10 @@ def check_valid_context(self, error_msg: str) -> None: raise errors.TraceContextError(error_msg) def __deepcopy__(self: G, memo=None) -> G: - graphdef, state, _ = graph_utils.graph_flatten(self) + graphdef, state = split(self) graphdef = deepcopy(graphdef) state = deepcopy(state) - return graphdef.merge(state) - - def __hash__(self) -> int: - return hash(self._graph_node__state.id) + return merge(graphdef, state) def __nnx_repr__(self): if CONTEXT.seen_modules_repr is None: @@ -1172,19 +1238,29 @@ def __nnx_repr__(self): else: clear_seen = False - if self._graph_node__state.id in CONTEXT.seen_modules_repr: + if id(self) in CONTEXT.seen_modules_repr: yield reprlib.Object(type=type(self), empty_repr='...') return yield reprlib.Object(type=type(self)) - CONTEXT.seen_modules_repr.add(self._graph_node__state.id) + CONTEXT.seen_modules_repr.add(id(self)) try: for name, value in vars(self).items(): - if isinstance(value, GraphNode) or ( - not isinstance(value, Variable) and not name.startswith('_') - ): - yield reprlib.Attr(name, repr(value)) + if name.startswith('_'): + continue + + def to_shape_dtype(value): + if isinstance(value, Variable): + return value.replace( + raw_value=jax.tree.map(to_shape_dtype, value.raw_value) + ) + elif isinstance(value, (np.ndarray, jax.Array)): + return Array(value.shape, value.dtype) + return value + + value = jax.tree.map(to_shape_dtype, value) + yield reprlib.Attr(name, repr(value)) finally: if clear_seen: 
CONTEXT.seen_modules_repr = None @@ -1204,9 +1280,9 @@ def _graph_node_set_key(self, key: Key, value: tp.Any): elif ( hasattr(self, key) and isinstance(variable := getattr(self, key), Variable) - and isinstance(value, Variable) + and isinstance(value, VariableState) ): - variable.copy_from(value) + variable.copy_from_state(value) else: setattr(self, key, value) diff --git a/flax/experimental/nnx/nnx/helpers.py b/flax/experimental/nnx/nnx/helpers.py index 137fe11bd2..8aeca41764 100644 --- a/flax/experimental/nnx/nnx/helpers.py +++ b/flax/experimental/nnx/nnx/helpers.py @@ -34,7 +34,7 @@ import jax.numpy as jnp import optax -from flax.experimental.nnx.nnx.graph_utils import Key +from flax.experimental.nnx.nnx.graph import Key from flax.experimental.nnx.nnx.module import GraphDef, Module from flax.experimental.nnx.nnx.proxy_caller import ApplyCaller from flax.experimental.nnx.nnx.rnglib import Rngs diff --git a/flax/experimental/nnx/nnx/module.py b/flax/experimental/nnx/nnx/module.py index cb1c54dd99..69dd37d75b 100644 --- a/flax/experimental/nnx/nnx/module.py +++ b/flax/experimental/nnx/nnx/module.py @@ -18,22 +18,18 @@ import typing as tp from functools import partial -import jax import jax.tree_util as jtu -import typing_extensions as tpe from flax.experimental.nnx.nnx import ( filterlib, - graph_utils, + graph, ) from flax.experimental.nnx.nnx import variables as variableslib -from flax.experimental.nnx.nnx.graph_utils import GraphDef, GraphNodeMeta +from flax.experimental.nnx.nnx.graph import GraphDef, GraphNode, GraphNodeMeta from flax.experimental.nnx.nnx.proxy_caller import ( - ApplyCaller, CallableProxy, DelayedAccessor, ) -from flax.experimental.nnx.nnx.rnglib import Rngs from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import Variable from flax.typing import Path, PathParts @@ -71,198 +67,8 @@ def _module_meta_call(cls: tp.Type[M], *args, **kwargs) -> M: return module -class Module(graph_utils.GraphNode, 
metaclass=ModuleMeta): - @classmethod - def init(cls: type[M], *args, **kwargs) -> tuple[GraphDef[M], State]: - return cls(*args, **kwargs).split() - - @classmethod - @property - def create_abstract(cls: type[M]) -> type[M]: - def lift_rngs(kwargs: dict[str, tp.Any]): - if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], tp.Mapping): - kwargs['rngs'] = Rngs(rngs) - return kwargs - - def _create_abstract(accessor: DelayedAccessor, *args, **kwargs): - constructor = accessor(cls) - if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], Rngs): - kwargs['rngs'] = rngs.fork() - graphdef, state = jax.eval_shape( - lambda: constructor(*args, **lift_rngs(kwargs)).split() - ) - return graphdef.merge(state) - - return CallableProxy(_create_abstract) # type: ignore - - @classmethod - def partial_init(cls: type[M], state: State, *states: State) -> type[M]: - """Creates a constuctor that initializes the Module with the given state. - - ``partial_init`` takes one or more States and returns a constructor that uses - ``jax.jit`` to initialize the Module and update its state with the given - States. Its semantically equivalent to:: - - module = MyModule(*args, **kwargs) - module.update(state, *states) - - However, thanks to dead code elimination the resulting constructor will only - initialize the subset of ``Variable``'s that were part of the given state(s). - - Example:: - - >>> import jax.numpy as jnp - >>> import jax - >>> from flax.experimental import nnx - ... - >>> bias = jax.random.normal(jax.random.key(0), (4,)) - >>> state = nnx.State({'bias': bias}) # in reality load it from a checkpoint - >>> linear = nnx.Linear.partial_init(state)(2, 4, rngs=nnx.Rngs(1)) - >>> y = linear(jnp.ones((1, 2))) - ... - >>> assert jnp.allclose(linear.bias, bias) - >>> assert y.shape == (1, 4) - - Args: - state: The State to initialize the Module with. - *states: Additional States to initialize the Module with. 
- - Returns: - A constructor that initializes the Module with the given States. - """ - states = (state, *states) - - def lift_rngs(kwargs: dict[str, tp.Any]): - if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], tp.Mapping): - kwargs['rngs'] = Rngs(rngs) - return kwargs - - def _partial_init(accessor: DelayedAccessor, *args, **kwargs): - constructor: tp.Callable[[], M] = accessor(cls) - if 'rngs' in kwargs and isinstance(rngs := kwargs['rngs'], Rngs): - kwargs['rngs'] = rngs.fork() - - def _partial_init_constructor(): - module = constructor(*args, **lift_rngs(kwargs)) - module.update(*states) - return module.split() - - graphdef: GraphDef[M] - state: State - graphdef, state = jax.jit(_partial_init_constructor)() - module = graphdef.merge(state) - return module - - return CallableProxy(_partial_init) # type: ignore - - def clone(self: M) -> M: - return graph_utils.merge(*self.split()) - - @tp.overload - def split(self: M) -> tuple[GraphDef[M], State]: - ... - - @tp.overload - def split(self: M, first: filterlib.Filter, /) -> tuple[GraphDef[M], State]: - ... - - @tp.overload - def split( - self: M, - first: filterlib.Filter, - second: filterlib.Filter, - /, - *filters: filterlib.Filter, - ) -> tuple[GraphDef[M], State, tpe.Unpack[tuple[State, ...]]]: - ... - - def split( - self: M, *filters: filterlib.Filter - ) -> tuple[GraphDef[M], State, tpe.Unpack[tuple[State, ...]]]: - return graph_utils.split(self, *filters) - - def get_state(self) -> State: - _, state = self.split() - return state - - def get_graphdef(self: M) -> GraphDef[M]: - graphdef, _ = self.split() - return graphdef - - @tp.overload - def extract(self, first: filterlib.Filter, /) -> State: - ... - - @tp.overload - def extract( - self, - first: filterlib.Filter, - second: filterlib.Filter, - /, - *filters: filterlib.Filter, - ) -> tuple[State, ...]: - ... 
- - def extract( - self, - first: filterlib.Filter, - /, - *filters: filterlib.Filter, - ) -> tp.Union[State, tuple[State, ...]]: - state = self.get_state() - - if len(filters) == 0: - states = state.extract(first) - else: - states = state.extract(first, filters[0], *filters[1:]) - - return states - - @tp.overload - def pop( - self, - filter: filterlib.Filter, - /, - ) -> State: - ... - - @tp.overload - def pop( - self, - filter: filterlib.Filter, - filter2: filterlib.Filter, - /, - *filters: filterlib.Filter, - ) -> tuple[State, ...]: - ... - - def pop( - self, *filters: filterlib.Filter - ) -> tp.Union[State, tuple[State, ...]]: - if len(filters) == 0: - raise ValueError('Expected at least one filter') - - states = graph_utils.graph_pop(self, filters) - - if len(states) == 1: - return states[0] - else: - return states - - @property - def apply(self: M) -> ApplyCaller[M]: - def _apply(accessor: DelayedAccessor, *args, **kwargs) -> tuple[tp.Any, M]: - module = self.clone() - fn = accessor(module) - out = fn(*args, **kwargs) - return out, module - - return CallableProxy(_apply) # type: ignore - - def update( - self: M, update: graph_utils.Updates[M], /, *updates: graph_utils.Updates[M] - ) -> None: - graph_utils.update(self, update, *updates) +class Module(graph.GraphNode, metaclass=ModuleMeta): + """""" def sow( self, @@ -288,8 +94,101 @@ def sow( reduced_value = reduce_fn(init_fn(), value) setattr(self, name, variable_type(reduced_value)) - def modules(self) -> tp.Iterator[tuple[PathParts, Module]]: - for path, value in graph_utils.iter_nodes(self): + @property + def init(self: M) -> M: + """Calls a method in initialization mode. + + When a method is called using ``init``, the ``is_initializing`` method + will return ``True``. This is useful to implement Modules that support + lazy initialization. + + Example:: + + >>> from flax.experimental import nnx + >>> import jax + >>> import jax.numpy as jnp + ... + >>> class Linear(nnx.Module): + ... 
def __init__(self, dout, rngs: nnx.Rngs): + ... self.dout = dout + ... self.rngs = rngs + ... + ... def __call__(self, x): + ... if self.is_initializing(): + ... din = x.shape[-1] + ... if not hasattr(self, 'w'): + ... key = self.rngs.params() + ... self.w = nnx.Param(jax.random.uniform(key, (din, self.dout))) + ... if not hasattr(self, 'b'): + ... self.b = nnx.Param(jnp.zeros((self.dout,))) + ... + ... return x @ self.w + self.b + ... + >>> linear = Linear(3, nnx.Rngs(0)) + >>> x = jnp.ones((5, 2)) + >>> y = linear.init(x) + >>> linear.w.value.shape + (2, 3) + >>> linear.b.value.shape + (3,) + >>> y.shape + (5, 3) + """ + + def _init_context(accessor: DelayedAccessor, *args, **kwargs): + for _, value in graph.iter_nodes(self): + if isinstance(value, GraphNode): + value._graph_node__state._initializing = True + + method = accessor(self) + try: + out = method(*args, **kwargs) + finally: + for _, value in graph.iter_nodes(self): + if isinstance(value, GraphNode): + value._graph_node__state._initializing = False + + return out + + return CallableProxy(_init_context) # type: ignore + + def is_initializing(self) -> bool: + """Returns whether the Module is initializing. + + ``is_initializing`` returns ``True`` if the Module is currently being run + under ``init``. + """ + + return self._graph_node__state._initializing + + def iter_modules(self) -> tp.Iterator[tuple[PathParts, Module]]: + """Iterates over all nested Modules of the current Module, including the current Module. + + ``iter_modules`` creates a generator that yields the path and the Module instance, where + the path is a tuple of strings or integers representing the path to the Module from the + root Module. + + Example:: + + >>> from flax.experimental import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5) + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) + ... + ... 
+ >>> model = Block(2, 5, rngs=nnx.Rngs(0)) + >>> for path, module in model.iter_modules(): + ... print(path, type(module).__name__) + ... + () Block + ('batch_norm',) BatchNorm + ('dropout',) Dropout + ('linear',) Linear + """ + for path, value in graph.iter_nodes(self): if isinstance(value, Module): yield path, value @@ -322,7 +221,7 @@ def set_attributes( ``Filter``'s can be used to set the attributes of specific Modules:: >>> block = Block(2, 5, rngs=nnx.Rngs(0)) - >>> block.set_attributes(nnx.Dropout, deterministic=True, use_running_average=True) + >>> block.set_attributes(nnx.Dropout, deterministic=True) >>> # Only the dropout will be modified >>> block.dropout.deterministic, block.batch_norm.use_running_average (True, False) @@ -337,7 +236,7 @@ def set_attributes( if not filters: filters = (True,) predicates = tuple(map(filterlib.to_predicate, filters)) - for path, module in self.modules(): + for path, module in self.iter_modules(): for predicate in predicates: if predicate(path, module): for name, value in attributes.items(): @@ -352,6 +251,77 @@ def set_attributes( f'Could not find at least one instance of the following attributes: {remaining_attributes}' ) + def train(self, **attributes): + """Sets the Module to training mode. + + ``train`` uses ``set_attributes`` to recursively set attributes ``deterministic=False`` + and ``use_running_average=False`` of all nested Modules that have these attributes. + It's primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` + Modules. + + Example:: + + >>> from flax.experimental import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... # initialize Dropout and BatchNorm in eval mode + ... self.dropout = nnx.Dropout(0.5, deterministic=True) + ... self.batch_norm = nnx.BatchNorm(10, use_running_average=True, rngs=rngs) + ... 
+ >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (True, True) + >>> block.train() + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (False, False) + + Args: + **attributes: additional attributes passed to ``set_attributes``. + """ + return self.set_attributes( + deterministic=False, + use_running_average=False, + **attributes, + raise_if_not_found=False, + ) + + def eval(self, **attributes): + """Sets the Module to evaluation mode. + + ``eval`` uses ``set_attributes`` to recursively set attributes ``deterministic=True`` + and ``use_running_average=True`` of all nested Modules that have these attributes. + It's primarily used to control the runtime behavior of the ``Dropout`` and ``BatchNorm`` + Modules. + + Example:: + + >>> from flax.experimental import nnx + ... + >>> class Block(nnx.Module): + ... def __init__(self, din, dout, *, rngs: nnx.Rngs): + ... self.linear = nnx.Linear(din, dout, rngs=rngs) + ... self.dropout = nnx.Dropout(0.5) + ... self.batch_norm = nnx.BatchNorm(10, rngs=rngs) + ... + >>> block = Block(2, 5, rngs=nnx.Rngs(0)) + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (False, False) + >>> block.eval() + >>> block.dropout.deterministic, block.batch_norm.use_running_average + (True, True) + + Args: + **attributes: additional attributes passed to ``set_attributes``. 
+ """ + return self.set_attributes( + deterministic=True, + use_running_average=True, + **attributes, + raise_if_not_found=False, + ) + def __init_subclass__(cls, experimental_pytree: bool = False) -> None: super().__init_subclass__() @@ -368,7 +338,7 @@ def __init_subclass__(cls, experimental_pytree: bool = False) -> None: # Pytree Definition # ------------------------- def _module_flatten(module: Module, *, with_keys: bool): - graphdef, state = module.split() + graphdef, state = graph.split(module) key_values = sorted(state.raw_mapping.items()) keys = tuple(key for key, _ in key_values) @@ -385,7 +355,7 @@ def _module_unflatten( variables: tuple[Variable[tp.Any], ...], ) -> M: paths, graphdef = paths_moduledef - return graphdef.merge(State(zip(paths, variables))) + return graph.merge(graphdef, State(zip(paths, variables))) def first_from(*args: tp.Optional[A], error_msg: str) -> A: diff --git a/flax/experimental/nnx/nnx/nn/attention.py b/flax/experimental/nnx/nnx/nn/attention.py index aad0c3772a..fa47421b0d 100644 --- a/flax/experimental/nnx/nnx/nn/attention.py +++ b/flax/experimental/nnx/nnx/nn/attention.py @@ -598,9 +598,9 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): >>> from flax.experimental import nnx >>> import jax.numpy as jnp - + ... >>> rngs = nnx.Rngs(42) - + ... >>> x = jnp.ones((1, 3)) >>> model_nnx = nnx.MultiHeadAttention( ... num_heads=2, @@ -609,10 +609,10 @@ def init_cache(self, input_shape: Shape, dtype: Dtype = jnp.float32): ... out_features=6, ... decode=True, ... rngs=rngs, - >>> ) - - >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized - + ... ) + ... + >>> # out_nnx = model_nnx(x) <-- throws an error because cache isn't initialized + ... 
>>> model_nnx.init_cache(x.shape) >>> out_nnx = model_nnx(x) """ diff --git a/flax/experimental/nnx/nnx/nn/linear.py b/flax/experimental/nnx/nnx/nn/linear.py index 49f3ac8381..990f6c6f3a 100644 --- a/flax/experimental/nnx/nnx/nn/linear.py +++ b/flax/experimental/nnx/nnx/nn/linear.py @@ -28,7 +28,6 @@ from __future__ import annotations import typing as tp -from types import MappingProxyType import jax import jax.numpy as jnp @@ -36,6 +35,7 @@ from jax import lax import opt_einsum +from flax.core.frozen_dict import FrozenDict from flax.experimental import nnx from flax.experimental.nnx.nnx import rnglib, variables from flax.experimental.nnx.nnx.module import Module, first_from @@ -109,24 +109,32 @@ class LinearGeneral(Module): Example usage:: - >>> import flax.linen as nn + >>> from flax.experimental import nnx >>> import jax, jax.numpy as jnp - - >>> # equivalent to `nn.Linear(features=4)` - >>> layer = nn.LinearGeneral(features=4) + ... + >>> # equivalent to `nnx.Linear(2, 4)` + >>> layer = nnx.LinearGeneral(2, 4, rngs=nnx.Rngs(0)) + >>> layer.kernel.value.shape + (2, 4) >>> # output features (4, 5) - >>> layer = nn.LinearGeneral(features=(4, 5)) - >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3))) - >>> jax.tree_util.tree_map(jnp.shape, params) - {'params': {'bias': (4, 5), 'kernel': (3, 4, 5)}} + >>> layer = nnx.LinearGeneral(2, (4, 5), rngs=nnx.Rngs(0)) + >>> layer.kernel.value.shape + (2, 4, 5) + >>> layer.bias.value.shape + (4, 5) >>> # apply transformation on the the second and last axes - >>> layer = nn.LinearGeneral(features=(4, 5), axis=(1, -1)) - >>> params = layer.init(jax.random.key(0), jnp.ones((1, 3, 6, 7))) - >>> jax.tree_util.tree_map(jnp.shape, params) - {'params': {'bias': (4, 5), 'kernel': (3, 7, 4, 5)}} + >>> layer = nnx.LinearGeneral((2, 3), (4, 5), axis=(1, -1), rngs=nnx.Rngs(0)) + >>> layer.kernel.value.shape + (2, 3, 4, 5) + >>> layer.bias.value.shape + (4, 5) + >>> y = layer(jnp.ones((16, 2, 3))) + >>> y.shape + (16, 4, 5) 
Attributes: - features: int or tuple with number of output features. + in_features: int or tuple with number of input features. + out_features: int or tuple with number of output features. axis: int or tuple with axes to apply the transformation on. For instance, (-2, -1) will apply the transformation to the last two axes. batch_dims: tuple with batch axes. @@ -145,7 +153,7 @@ def __init__( out_features: Size | tp.Sequence[Size], *, axis: Axis | tp.Sequence[Axis] = -1, - batch_axis: tp.Mapping[Axis, Size] = MappingProxyType({}), + batch_axis: tp.Mapping[Axis, Size] = FrozenDict({}), use_bias: bool = True, dtype: Dtype | None = None, param_dtype: Dtype = jnp.float32, @@ -160,7 +168,7 @@ def __init__( self.in_features = _canonicalize_tuple(in_features) self.out_features = _canonicalize_tuple(out_features) self.axis = _canonicalize_tuple(axis) - self.batch_axis = MappingProxyType(batch_axis) + self.batch_axis = FrozenDict(batch_axis) self.use_bias = use_bias self.dtype = dtype self.param_dtype = param_dtype @@ -365,12 +373,15 @@ class Einsum(Module): >>> from flax.experimental import nnx >>> import jax.numpy as jnp - - >>> layer = nnx.Einsum('abc,cde->abde', (3, 4, 5), (5, 6, 7), rngs=nnx.Rngs(0)) - >>> assert layer.kernel.value.shape == (5, 6, 7) - >>> assert layer.bias.value.shape == (6, 7) - >>> out = layer(jnp.ones((3, 4, 5))) - >>> assert out.shape == (3, 4, 6, 7) + ... + >>> layer = nnx.Einsum('nta,hab->nthb', (8, 2, 4), (8, 4), rngs=nnx.Rngs(0)) + >>> layer.kernel.value.shape + (8, 2, 4) + >>> layer.bias.value.shape + (8, 4) + >>> y = layer(jnp.ones((16, 11, 2))) + >>> y.shape + (16, 11, 8, 4) Attributes: einsum_str: a string to denote the einsum equation. 
The equation must diff --git a/flax/experimental/nnx/nnx/nn/normalization.py b/flax/experimental/nnx/nnx/nn/normalization.py index d93ebfd4cc..f27d6b2798 100644 --- a/flax/experimental/nnx/nnx/nn/normalization.py +++ b/flax/experimental/nnx/nnx/nn/normalization.py @@ -216,7 +216,7 @@ def __init__( self, num_features: int, *, - use_running_average: tp.Optional[bool] = None, + use_running_average: bool = False, axis: int = -1, momentum: float = 0.99, epsilon: float = 1e-5, diff --git a/flax/experimental/nnx/nnx/nn/stochastic.py b/flax/experimental/nnx/nnx/nn/stochastic.py index a4a676df7e..efd8f94f31 100644 --- a/flax/experimental/nnx/nnx/nn/stochastic.py +++ b/flax/experimental/nnx/nnx/nn/stochastic.py @@ -53,7 +53,7 @@ class Dropout(Module): rate: float broadcast_dims: Sequence[int] = () - deterministic: bool | None = None + deterministic: bool = False rng_collection: str = 'dropout' rngs: rnglib.Rngs | None = None diff --git a/flax/experimental/nnx/nnx/reprlib.py b/flax/experimental/nnx/nnx/reprlib.py index e734879fd6..5efc065ed0 100644 --- a/flax/experimental/nnx/nnx/reprlib.py +++ b/flax/experimental/nnx/nnx/reprlib.py @@ -88,19 +88,16 @@ def _repr_elem(elem: tp.Any) -> str: value = elem.value if isinstance(elem.value, str) else repr(elem.value) - if '\n' in value and not isinstance(elem.value, Representable): - value = value.replace('\n', '\n' + get_indent()) + value = value.replace('\n', '\n' + config.elem_indent) - return ( - f'{get_indent()}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}' - ) + return f'{config.elem_indent}{elem.start}{elem.key}{config.value_sep}{value}{elem.end}' with add_indent(config.elem_indent): elems = list(map(_repr_elem, iterator)) elems = ',\n'.join(elems) if elems: - elems = '\n' + elems + '\n' + get_indent() + elems = '\n' + elems + '\n' else: elems = config.empty_repr @@ -115,4 +112,14 @@ def __nnx_repr__(self): yield Object(type='', value_sep=': ', start='{', end='}') for key, value in self.items(): + yield 
Attr(repr(key), value) + +@dataclasses.dataclass(repr=False) +class PrettyMapping(Representable): + mapping: tp.Mapping + + def __nnx_repr__(self): + yield Object(type='', value_sep=': ', start='{', end='}') + + for key, value in self.mapping.items(): yield Attr(repr(key), value) \ No newline at end of file diff --git a/flax/experimental/nnx/nnx/rnglib.py b/flax/experimental/nnx/nnx/rnglib.py index 84e91a29b7..8e3802bef6 100644 --- a/flax/experimental/nnx/nnx/rnglib.py +++ b/flax/experimental/nnx/nnx/rnglib.py @@ -34,13 +34,15 @@ import jax import jax.numpy as jnp +from flax.experimental.nnx.nnx import graph +from flax.experimental.nnx.nnx.state import State from flax.experimental.nnx.nnx.variables import Variable from flax.experimental.nnx.nnx import filterlib -from flax.experimental.nnx.nnx.graph_utils import GraphNode +from flax.experimental.nnx.nnx.graph import GraphNode Counts = list[int] AxesValue = tp.Union[int, None] -Pattern = tp.Union[AxesValue, tuple[AxesValue, ...]] +SplitPattern = tp.Union[AxesValue, tuple[AxesValue, ...]] class Missing: @@ -61,8 +63,14 @@ class RngCount(RngState): class RngKey(RngState): tag: str +class RngKeyBackup(RngState): + pass + + +NotKey = filterlib.All(RngState, filterlib.Not(RngKey)) + -@dataclasses.dataclass +@dataclasses.dataclass(repr=False) class RngStream(GraphNode): def __init__( self, @@ -72,6 +80,7 @@ def __init__( ): self.key = RngKey(key, tag=tag) self.count = RngCount(count) + self.key_backups: list[RngKeyBackup] = [] def __post_init__(self): if not isinstance(self.key, jax.Array): @@ -85,7 +94,7 @@ def __call__(self) -> jax.Array: self.count.value += 1 return key - def fork(self, pattern: Pattern) -> jax.Array: + def fork(self, pattern: SplitPattern) -> jax.Array: if pattern is None: # broadcast key key = self() @@ -159,20 +168,16 @@ def __len__(self) -> int: def __contains__(self, name: tp.Any) -> bool: return name in vars(self) - def replace(self, **kwargs: tp.Union[int, jax.Array, RngStream]) -> 'Rngs': - 
rngs: dict[str, tp.Any] = vars(self).copy() - del rngs['_graph_node__state'] - rngs.update(kwargs) - return Rngs(**rngs) - def fork( self, - _default: Pattern | dict[filterlib.Filter, Pattern] | Missing = MISSING, + _default: SplitPattern + | dict[filterlib.Filter, SplitPattern] + | Missing = MISSING, /, - **patterns: Pattern, + **patterns: SplitPattern, ) -> ForkedKeys: - filter_patterns: list[tuple[filterlib.Filter, Pattern]] + filter_patterns: list[tuple[filterlib.Filter, SplitPattern]] if isinstance(_default, dict): # merge default and patterns filter_patterns = [ @@ -276,3 +281,46 @@ def _split_rng_unflatten( _split_rng_unflatten, flatten_func=functools.partial(_split_rng_flatten, with_keys=False), ) + +def fork( + state: State, + split_filter: filterlib.Filter, + split_pattern: SplitPattern, +) -> tuple[State, State]: + if split_pattern is None: + raise RuntimeError('Split pattern cannot be None, this is a bug.') + if isinstance(split_pattern, int): + num_splits = split_pattern + else: + num_splits = tuple(x if x is not None else 1 for x in split_pattern) + + not_keys, split_state, broadcast_state = state.split( + NotKey, split_filter, ... 
+ ) + broadcast_state = State.merge(not_keys, broadcast_state) + + def split_key(key: tp.Any) -> jax.Array: + if not isinstance(key, jax.Array): + raise TypeError(f'key must be a jax.Array, got {type(key)}') + + return jax.random.split(key, num_splits) + + split_state = jax.tree.map(split_key, split_state) + + return split_state, broadcast_state + +def backup_keys(node: tp.Any, /): + streams: list[RngStream] = [] + for _, stream in graph.iter_nodes(node): + if isinstance(stream, RngStream): + stream.key_backups.append(RngKeyBackup(stream.key.value)) + streams.append(stream) + return streams + + +def restore_keys(streams: list[RngStream], /): + for stream in streams: + if not stream.key_backups: + raise RuntimeError('No key backups found.') + backup = stream.key_backups.pop() + stream.key.value = backup.value \ No newline at end of file diff --git a/flax/experimental/nnx/nnx/spmd.py b/flax/experimental/nnx/nnx/spmd.py index 0554901487..20c0630173 100644 --- a/flax/experimental/nnx/nnx/spmd.py +++ b/flax/experimental/nnx/nnx/spmd.py @@ -44,7 +44,7 @@ def add_axis( axis_name = _get_partition_name(params) def _add_axis(x: tp.Any): - if isinstance(x, variables.Variable): + if isinstance(x, variables.VariableState): if isinstance(x, HasSharding) and x.sharding is not None: sharding = list(x.sharding) while len(sharding) < index: @@ -56,7 +56,7 @@ def _add_axis(x: tp.Any): return x return jax.tree_util.tree_map( - _add_axis, state, is_leaf=lambda x: isinstance(x, variables.Variable) + _add_axis, state, is_leaf=lambda x: isinstance(x, variables.VariableState) ) @@ -66,7 +66,7 @@ def remove_axis( axis_name = _get_partition_name(params) def _remove_axis(x: tp.Any): - if isinstance(x, variables.Variable): + if isinstance(x, variables.VariableState): if isinstance(x, HasSharding) and x.sharding is not None: sharding = list(x.sharding) assert sharding.pop(index) == axis_name @@ -75,7 +75,9 @@ def _remove_axis(x: tp.Any): return x return jax.tree_util.tree_map( - _remove_axis, 
state, is_leaf=lambda x: isinstance(x, variables.Variable) + _remove_axis, + state, + is_leaf=lambda x: isinstance(x, variables.VariableState), ) @@ -98,16 +100,16 @@ def _maybe_replicate(x): return None def f(x): - if isinstance(x, variables.Variable): + if isinstance(x, (variables.VariableState, variables.Variable)): if isinstance(x, HasSharding) and x.sharding: - return x.replace(raw_value=PartitionSpec(*x.sharding)) + return x.replace(PartitionSpec(*x.sharding)) else: - return x.replace(raw_value=_maybe_replicate(x.raw_value)) + return x.replace(_maybe_replicate(x.value)) return _maybe_replicate(x) - return jax.tree_map( - f, tree, is_leaf=lambda x: isinstance(x, variables.Variable) + return jax.tree_util.tree_map( + f, tree, is_leaf=lambda x: isinstance(x, variables.VariableState) ) diff --git a/flax/experimental/nnx/nnx/state.py b/flax/experimental/nnx/nnx/state.py index 905abc5cec..ffe77ea767 100644 --- a/flax/experimental/nnx/nnx/state.py +++ b/flax/experimental/nnx/nnx/state.py @@ -36,17 +36,17 @@ from flax import traverse_util from flax.experimental.nnx.nnx import filterlib, reprlib -from flax.experimental.nnx.nnx.variables import Variable +from flax.experimental.nnx.nnx.variables import VariableState from flax.typing import Key, PathParts A = tp.TypeVar('A') -StateLeaf = tp.Union[Variable[tp.Any], np.ndarray, jax.Array] +StateLeaf = tp.Union[VariableState[tp.Any], np.ndarray, jax.Array] FlatState = dict[PathParts, StateLeaf] def is_state_leaf(x: tp.Any) -> tpe.TypeGuard[StateLeaf]: - return isinstance(x, (Variable, np.ndarray, jax.Array)) + return isinstance(x, (VariableState, np.ndarray, jax.Array)) class NestedStateRepr(reprlib.Representable): @@ -66,8 +66,8 @@ class State(tp.MutableMapping[Key, tp.Any], reprlib.Representable): def __init__( self, mapping: tp.Union[ - tp.Mapping[Key, tp.Any], - tp.Iterator[tuple[Key, tp.Any]], + tp.Mapping[Key, tp.Mapping | StateLeaf], + tp.Iterator[tuple[Key, tp.Mapping | StateLeaf]], ], /, ): @@ -80,11 +80,14 @@ def 
__init__( def raw_mapping(self) -> dict[Key, dict[str, tp.Any] | tp.Any]: return self._mapping + def __contains__(self, key: Key) -> bool: + return key in self._mapping + def __getitem__(self, key: Key) -> State | StateLeaf: value = self._mapping[key] - if is_state_leaf(value): - return value - return State(value) + if isinstance(value, tp.Mapping): + return State(value) + return value def __getattr__(self, key: Key) -> State | StateLeaf: if '_mapping' not in vars(self) or key not in self._mapping: @@ -120,7 +123,9 @@ def flat_state(self) -> FlatState: return traverse_util.flatten_dict(self._mapping) # type: ignore @classmethod - def from_flat_path(cls, flat_state: FlatState, /) -> State: + def from_flat_path( + cls, flat_state: tp.Mapping[PathParts, StateLeaf], / + ) -> State: nested_state = traverse_util.unflatten_dict(flat_state) return cls(nested_state) @@ -157,7 +162,7 @@ def split( return states @tp.overload - def extract( + def filter( self, first: filterlib.Filter, /, @@ -165,7 +170,7 @@ def extract( ... @tp.overload - def extract( + def filter( self, first: filterlib.Filter, second: filterlib.Filter, @@ -174,7 +179,7 @@ def extract( ) -> tuple['State', ...]: ... - def extract( + def filter( self, first: filterlib.Filter, /, @@ -245,11 +250,13 @@ def _split_state( *filters: filterlib.Filter, ) -> tuple[State, ...]: for i, filter_ in enumerate(filters): - if filter_ is ... and i != len(filters) - 1: - raise ValueError( - 'Ellipsis `...` can only be used as the last filter, ' - f'got it at index {i}.' - ) + if filter_ in (..., True) and i != len(filters) - 1: + remaining_filters = filters[i + 1 :] + if not all(f in (..., True) for f in remaining_filters): + raise ValueError( + '`...` or `True` can only be used as the last filters, ' + f'got {filter_} it at index {i}.' 
+ ) predicates = tuple(map(filterlib.to_predicate, filters)) flat_state = state.flat_state() diff --git a/flax/experimental/nnx/nnx/training/metrics.py b/flax/experimental/nnx/nnx/training/metrics.py index f724de2bfd..9c434f0404 100644 --- a/flax/experimental/nnx/nnx/training/metrics.py +++ b/flax/experimental/nnx/nnx/training/metrics.py @@ -29,7 +29,7 @@ import jax, jax.numpy as jnp from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib, graph_utils +from flax.experimental.nnx.nnx import filterlib, graph import typing as tp @@ -39,7 +39,7 @@ class MetricState(Variable): """Wrapper class for Metric Variables.""" pass -class Metric(graph_utils.GraphNode): +class Metric(graph.GraphNode): def __init__(self): raise NotImplementedError('Must override `__init__()` method.') def reset(self): @@ -49,16 +49,22 @@ def update(self): def compute(self): raise NotImplementedError('Must override `compute()` method.') def split(self, *filters: filterlib.Filter): - return graph_utils.split(self, *filters) + return graph.split(self, *filters) + class Average(Metric): - def __init__(self): + def __init__(self, argname: str = 'values'): + self.argname = argname self.total = MetricState(jnp.array(0, dtype=jnp.float32)) self.count = MetricState(jnp.array(0, dtype=jnp.int32)) def reset(self): self.total.value = jnp.array(0, dtype=jnp.float32) self.count.value = jnp.array(0, dtype=jnp.int32) - def update(self, *, values: tp.Union[int, float, jax.Array], **_): + + def update(self, **kwargs): + if self.argname not in kwargs: + raise TypeError(f"Expected keyword argument '{self.argname}'") + values: tp.Union[int, float, jax.Array] = kwargs[self.argname] self.total.value += values if isinstance(values, (int, float)) else values.sum() self.count.value += 1 if isinstance(values, (int, float)) else values.size def compute(self): @@ -74,24 +80,24 @@ def update(self, *, logits: jax.Array, labels: jax.Array, **_): 
super().update(values=(logits.argmax(axis=-1)==labels)) class MultiMetric(Metric): - '''MultiMetric class to store multiple metrics and update them in a single call. + """MultiMetric class to store multiple metrics and update them in a single call. Example usage:: >>> import jax, jax.numpy as jnp >>> from flax.experimental import nnx - + ... >>> logits = jax.random.normal(jax.random.key(0), (5, 2)) >>> labels = jnp.array([1, 1, 0, 1, 0]) >>> logits2 = jax.random.normal(jax.random.key(1), (5, 2)) >>> labels2 = jnp.array([0, 1, 1, 1, 1]) - + ... >>> batch_loss = jnp.array([1, 2, 3, 4]) >>> batch_loss2 = jnp.array([3, 2, 1, 0]) - + ... >>> metrics = nnx.MultiMetric( ... accuracy=nnx.metrics.Accuracy(), loss=nnx.metrics.Average() - >>> ) + ... ) >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} >>> metrics.update(logits=logits, labels=labels, values=batch_loss) @@ -103,7 +109,7 @@ class MultiMetric(Metric): >>> metrics.reset() >>> metrics.compute() {'accuracy': Array(nan, dtype=float32), 'loss': Array(nan, dtype=float32)} - ''' + """ def __init__(self, **metrics): # TODO: raise error if a kwarg is passed that is in ('reset', 'update', 'compute'), since these names are reserved for methods self._metric_names = [] diff --git a/flax/experimental/nnx/nnx/training/optimizer.py b/flax/experimental/nnx/nnx/training/optimizer.py index c85909ab72..6682618123 100644 --- a/flax/experimental/nnx/nnx/training/optimizer.py +++ b/flax/experimental/nnx/nnx/training/optimizer.py @@ -27,12 +27,12 @@ # limitations under the License. 
from __future__ import annotations -from flax.experimental import nnx -from flax.experimental.nnx.nnx.variables import Variable -from flax.experimental.nnx.nnx import filterlib, graph_utils - +import jax.numpy as jnp import optax +from flax.experimental import nnx +from flax.experimental.nnx.nnx import filterlib, graph +from flax.experimental.nnx.nnx.variables import Variable #TODO: add tests and docstrings @@ -40,7 +40,7 @@ class OptState(Variable): """Wrapper class for Optimizer Variables.""" pass -class Optimizer(graph_utils.GraphNode): +class Optimizer(graph.GraphNode): """Simple train state for the common case with a single Optax optimizer. Example usage:: @@ -48,28 +48,28 @@ class Optimizer(graph_utils.GraphNode): >>> import jax, jax.numpy as jnp >>> from flax.experimental import nnx >>> import optax - + ... >>> class Model(nnx.Module): ... def __init__(self, rngs): ... self.linear1 = nnx.Linear(2, 3, rngs=rngs) ... self.linear2 = nnx.Linear(3, 4, rngs=rngs) ... def __call__(self, x): ... return self.linear2(self.linear1(x)) - + ... >>> x = jax.random.normal(jax.random.key(0), (1, 2)) >>> y = jnp.ones((1, 4)) - + ... >>> model = Model(nnx.Rngs(0)) >>> tx = optax.adam(1e-3) >>> state = nnx.Optimizer(model, tx) - - >>> loss_fn = lambda model: ((model(x)-y)**2).mean() - >>> loss_fn(state.model) - 1.7055722 + ... + >>> loss_fn = lambda model: ((model(x) - y) ** 2).mean() + >>> loss_fn(model) + Array(1.7055722, dtype=float32) >>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model) >>> state.update(grads) - >>> loss_fn(state.model) - 1.6925814 + >>> loss_fn(model) + Array(1.6925814, dtype=float32) Note that you can easily extend this class by subclassing it for storing additional data (e.g. adding metrics). @@ -83,17 +83,17 @@ class Optimizer(graph_utils.GraphNode): ... def update(self, *, grads, **updates): ... self.metrics.update(**updates) ... super().update(grads) - + ... >>> metrics = nnx.metrics.Average() >>> state = TrainState(model, tx, metrics) - + ... 
>>> grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() - 1.6925814 + Array(1.6925814, dtype=float32) >>> state.update(grads=grads, values=loss_fn(state.model)) >>> state.metrics.compute() - 1.68612 + Array(1.68612, dtype=float32) For more exotic usecases (e.g. multiple optimizers) it's probably best to fork the class and modify it. @@ -108,13 +108,13 @@ def __init__( model: nnx.Module, tx: optax.GradientTransformation, ): - self.step = OptState(0) + self.step = OptState(jnp.array(0, dtype=jnp.uint32)) self.model = model self.tx = tx - self.opt_state = tx.init(model.extract(nnx.Param)) + self.opt_state = tx.init(nnx.state(model, nnx.Param)) def split(self, *filters: filterlib.Filter): - return graph_utils.split(self, *filters) + return graph.split(self, *filters) def update(self, grads): """Updates ``step``, ``params``, ``opt_state`` and ``**kwargs`` in return value. @@ -131,7 +131,7 @@ def update(self, grads): and ``opt_state`` updated by applying ``grads``, and additional attributes replaced as specified by ``kwargs``. 
""" - params = self.model.extract(nnx.Param) + params = nnx.state(self.model, nnx.Param) updates, new_opt_state = self.tx.update( grads, self.opt_state, params @@ -140,6 +140,6 @@ def update(self, grads): assert isinstance(new_params, nnx.State) self.step.value += 1 - self.model.update(new_params) + nnx.update(self.model, new_params) self.opt_state = new_opt_state diff --git a/flax/experimental/nnx/nnx/transforms.py b/flax/experimental/nnx/nnx/transforms.py index 400a94a3d0..ad5a4cdc3d 100644 --- a/flax/experimental/nnx/nnx/transforms.py +++ b/flax/experimental/nnx/nnx/transforms.py @@ -31,16 +31,18 @@ import functools import typing as tp from abc import abstractmethod -from types import MappingProxyType -from typing import Any + +from flax.core.frozen_dict import FrozenDict import jax +import jax.core import jax.numpy as jnp import jax.stages +from jax._src.tree_util import broadcast_prefix from flax.experimental.nnx.nnx import ( filterlib, - graph_utils, + graph, rnglib, spmd, variables, @@ -65,16 +67,6 @@ Leaves = tp.List[Leaf] Index = int - -def _check_args(args: tuple[tp.Any, ...]): - """Check if Rngs is passed as a positional argument and raise an error.""" - for arg in args: - if isinstance(arg, rnglib.Rngs): - raise ValueError( - "Rngs must be passed as a keyword argument named 'rngs', not a" - ' positional argument' - ) - def _normalize_sequence( x: StrInt | tp.Iterable[StrInt] | None, / ) -> tuple[StrInt, ...]: @@ -104,7 +96,6 @@ def call(self) -> tp.Any: module = self def check_and_call(accessor: DelayedAccessor, *args, **kwargs): - _check_args(args) return self._call(accessor, *args, **kwargs) proxy = CallableProxy(check_and_call) @@ -125,6 +116,7 @@ def check_and_call(accessor: DelayedAccessor, *args, **kwargs): @dataclasses.dataclass(frozen=True) class JitStaticInputs: graphdef: GraphDef[tuple[tp.Any, ...]] + ctx: graph.UpdateContext jax.tree_util.register_static(JitStaticInputs) @@ -158,8 +150,8 @@ class JITOptions: inline: bool abstracted_axes: 
tp.Optional[tp.Any] # nnx specific - donate_object_state: bool - constrain_object_state: tp.Callable[[State], State] | None + donate_state: bool + constrain_state: tp.Callable[[State], State] | None @classmethod def from_jit_kwargs( @@ -175,20 +167,20 @@ def from_jit_kwargs( backend: tp.Optional[str], inline: bool, abstracted_axes: tp.Optional[tp.Any], - donate_object_state: bool, - constrain_object_state: bool | tp.Callable[[State], State], + donate_state: bool, + constrain_state: bool | tp.Callable[[State], State], ): _static_argnums = _normalize_sequence(static_argnums) _static_argnames = _normalize_sequence(static_argnames) _donate_argnums = _normalize_sequence(donate_argnums) _donate_argnames = _normalize_sequence(donate_argnames) - if donate_object_state: + if donate_state: _donate_argnames = (*_donate_argnames, '_nnx_jit_state') - if callable(constrain_object_state): - _constrain_object_state = constrain_object_state - elif constrain_object_state: + if callable(constrain_state): + _constrain_object_state = constrain_state + elif constrain_state: _constrain_object_state = _default_constrain_object_state else: _constrain_object_state = None @@ -205,14 +197,14 @@ def from_jit_kwargs( backend=backend, inline=inline, abstracted_axes=abstracted_axes, - donate_object_state=donate_object_state, - constrain_object_state=_constrain_object_state, + donate_state=donate_state, + constrain_state=_constrain_object_state, ) def get_jit_kwargs(self) -> dict[str, tp.Any]: kwargs = vars(self).copy() - del kwargs['donate_object_state'] - del kwargs['constrain_object_state'] + del kwargs['donate_state'] + del kwargs['constrain_state'] if kwargs['in_shardings'] is UNSPECIFIED: kwargs.pop('in_shardings') if kwargs['out_shardings'] is UNSPECIFIED: @@ -237,13 +229,12 @@ def __call__( inline: bool = False, abstracted_axes: tp.Optional[tp.Any] = None, # nnx specific - donate_object_state: bool = False, - constrain_object_state: bool | tp.Callable[[State], State] = False, - ) -> 
tp.Callable[..., 'JIT[M]']: + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, + ) -> tp.Callable[..., 'Jit[M]']: super_call = super().__call__ - def _create_jit(*args, **kwargs) -> JIT[M]: - _check_args(args) + def _create_jit(*args, **kwargs) -> Jit[M]: return super_call( module_constructor=module_constructor, in_shardings=in_shardings, @@ -258,7 +249,8 @@ def _create_jit(*args, **kwargs) -> JIT[M]: inline=inline, abstracted_axes=abstracted_axes, # nnx specific - donate_object_state=donate_object_state, + donate_state=donate_state, + constrain_state=constrain_state, # submodule args module_init_args=args, module_init_kwargs=kwargs, @@ -274,7 +266,9 @@ def __call__( _nnx_jit_static: JitStaticInputs, _nnx_jit_state: State, **kwargs: tp.Any, - ) -> tuple[tp.Any, State, JitStaticOutputs]: + ) -> tuple[ + tp.Any, State, GraphDef[tuple[tuple[tp.Any, ...], tuple[tp.Any, ...]]] + ]: ... @@ -287,38 +281,28 @@ def jitted_fn( _nnx_jit_static: JitStaticInputs, _nnx_jit_state: State, **kwargs: tp.Any, - ): + ) -> tuple[tp.Any, State, GraphDef[tuple[tp.Any, ...]]]: + ctx = _nnx_jit_static.ctx graphdef = _nnx_jit_static.graphdef state: State = _nnx_jit_state - if options.constrain_object_state is not None: - state = options.constrain_object_state(state) + if options.constrain_state is not None: + state = options.constrain_state(state) - input_graph_nodes, outer_idx_inner_ref = graph_utils.graph_unflatten( - graphdef, state - ) + input_graph_nodes = ctx.merge(graphdef, state) - (args, kwargs) = graph_utils.insert_graph_nodes( - (args, kwargs), input_graph_nodes - ) + (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) out = f(*args, **kwargs) - out, output_graph_nodes = graph_utils.extract_graph_nodes(out) + out, output_graph_nodes = graph.extract_graph_nodes(out) - graphdef, state, inner_ref_inner_idx = graph_utils.graph_flatten( - (input_graph_nodes, output_graph_nodes) - ) - outer_idx_inner_idx = 
graph_utils.compose_mapping( - outer_idx_inner_ref, inner_ref_inner_idx - ) + graphdef, state = ctx.split((input_graph_nodes, output_graph_nodes)) - if options.constrain_object_state is not None: - state = options.constrain_object_state(state) + if options.constrain_state is not None: + state = options.constrain_state(state) - output_static = JitStaticOutputs(graphdef, outer_idx_inner_idx) - out = (out, state, output_static) - return out + return out, state, graphdef return jitted_fn @@ -329,34 +313,24 @@ def jit_apply( args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], ) -> tp.Any: - (args, kwargs), input_graph_nodes = graph_utils.extract_graph_nodes( - (args, kwargs) - ) - - graphdef, state, outer_ref_outer_idx = graph_utils.graph_flatten( - input_graph_nodes - ) + ctx = graph.UpdateContext() + (args, kwargs), input_graph_nodes = graph.extract_graph_nodes((args, kwargs)) + graphdef, state = ctx.split(input_graph_nodes) - out, output_state, output_static = jitted_fn( + out, output_state, output_graphdef = jitted_fn( *args, - _nnx_jit_static=JitStaticInputs(graphdef), + _nnx_jit_static=JitStaticInputs(graphdef, ctx), _nnx_jit_state=state, **kwargs, ) - outer_idx_inner_idx = output_static.index_mapping - output_graphdef = output_static.graphdef - inner_idx_outer_ref = graph_utils.compose_mapping_reversed( - outer_ref_outer_idx, outer_idx_inner_idx - ) - (input_graph_nodes, output_graph_nodes), _ = graph_utils.graph_unflatten( - output_graphdef, output_state, ref_cache=inner_idx_outer_ref + input_graph_nodes, output_graph_nodes = ctx.update( + output_graphdef, output_state ) - out = graph_utils.insert_graph_nodes(out, output_graph_nodes) - + out = graph.insert_graph_nodes(out, output_graph_nodes) return out -class JIT(LiftedModule[M], metaclass=JITMeta): +class Jit(LiftedModule[M], metaclass=JITMeta): def __init__( self, module_constructor: tp.Callable[..., M], @@ -373,8 +347,8 @@ def __init__( inline: bool = False, abstracted_axes: tp.Optional[tp.Any] = None, # 
nnx specific - donate_object_state: bool = False, - constrain_object_state: bool | tp.Callable[[State], State] = False, + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, # submodule args module_init_args: tuple[tp.Any, ...], module_init_kwargs: dict[str, tp.Any], @@ -391,15 +365,15 @@ def __init__( backend=backend, inline=inline, abstracted_axes=abstracted_axes, - donate_object_state=donate_object_state, - constrain_object_state=constrain_object_state, + donate_state=donate_state, + constrain_state=constrain_state, ) self.accessor: tp.Optional[DelayedAccessor] = None def jit_call_module(module, *args, **kwargs): assert self.accessor is not None - f = self.accessor(module) - return f(*args, **kwargs) + method = self.accessor(module) + return method(*args, **kwargs) self.jitted_fn: JittedFn[M] = get_jitted_fn(jit_call_module, self.options) self.module_constructor = module_constructor @@ -411,7 +385,7 @@ def jit_call_module(module, *args, **kwargs): def _submodule(self) -> M: return self.jit_module - def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any: + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: self.accessor = accessor try: out = jit_apply( @@ -437,8 +411,8 @@ def jit( inline: bool = False, abstracted_axes: tp.Optional[tp.Any] = None, # nnx specific - donate_object_state: bool = False, - constrain_object_state: bool | tp.Callable[[State], State] = False, + donate_state: bool = False, + constrain_state: bool | tp.Callable[[State], State] = False, ) -> F: """ Lifted version of ``jax.jit`` that can handle Modules / graph nodes as @@ -555,9 +529,9 @@ def jit( inline: Specify whether this function should be inlined into enclosing jaxprs (rather than being represented as an application of the xla_call primitive with its own subjaxpr). Default False. - donate_object_state: Optional, bool. If True, the object state of the + donate_state: Optional, bool. 
If True, the object state of the graph node's state will be donated to the computation. Default False. - constrain_object_state: Optional, bool or callable. If True, the object + constrain_state: Optional, bool or callable. If True, the object state of the graph node's state will be constrained to the partition specified by the graph node's partition spec as computed by :func:`nnx.spmd.get_partition_spec`. If a callable, the object State will @@ -579,8 +553,8 @@ def jit( backend=backend, inline=inline, abstracted_axes=abstracted_axes, - donate_object_state=donate_object_state, - constrain_object_state=constrain_object_state, + donate_state=donate_state, + constrain_state=constrain_state, ) jitted_fn = get_jitted_fn(fun, options) @@ -625,7 +599,6 @@ def __call__( super_call = super().__call__ def _create_grad(*args, **kwargs) -> Grad[M]: - _check_args(args) return super_call( module_constructor=module_constructor, wrt=wrt, @@ -677,24 +650,25 @@ def __init__( def _submodule(self) -> M: return self.grad_module - def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> Any: + def _call(self, accessor: DelayedAccessor, *args, **kwargs) -> tp.Any: def grad_call_apply(module, *args, **kwargs): - return accessor(module)(*args, **kwargs) + method = accessor(module) + return method(*args, **kwargs) return grad_apply(self.options, grad_call_apply, (self.grad_module, *args)) def grad_apply(options: GradOptions, f, args: tuple[tp.Any, ...]): - _, input_nodes = graph_utils.extract_graph_nodes(args) + _, input_nodes = graph.extract_graph_nodes(args) _args = list(args) diff_graph_nodes: dict[int, tp.Any] = { i: arg for i, arg in enumerate(args) - if i in options.argnums and graph_utils.is_node(arg) + if i in options.argnums and graph.is_node(arg) } - _, diff_state, _ = graph_utils.split(diff_graph_nodes, options.wrt, ...) + _, diff_state, _ = graph.split(diff_graph_nodes, options.wrt, ...) 
for i in diff_graph_nodes: _args[i] = diff_state[i] @@ -717,13 +691,13 @@ def grad_fn(*args): _args = list(args) for i, graph_node in diff_graph_nodes.items(): diff_state: State = _args[i] - graph_utils.graph_update_dynamic(graph_node, diff_state) + graph.update(graph_node, diff_state) _args[i] = graph_node out = f(*_args) - out, out_nodes = graph_utils.extract_graph_nodes(out) + out, out_nodes = graph.extract_graph_nodes(out) - _, updates, _ = graph_utils.graph_flatten((input_nodes, out_nodes)) + _, updates, _ = graph.flatten((input_nodes, out_nodes)) if options.has_aux: loss, aux = out @@ -750,7 +724,7 @@ def grad_fn(*args): else: out, updates = out - graph_utils.graph_update_dynamic((input_nodes, out_nodes), updates) + graph.update((input_nodes, out_nodes), updates) return out @@ -775,6 +749,8 @@ def grad( Example:: + >>> from flax.experimental import nnx + ... >>> m = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) >>> x = jnp.ones((1, 2)) >>> y = jnp.ones((1, 3)) @@ -785,11 +761,13 @@ def grad( >>> grads = grad_fn(m, x, y) >>> jax.tree_util.tree_map(jnp.shape, grads) State({ - 'bias': Param( - raw_value=(3,) + 'bias': VariableState( + type=Param, + value=(3,) ), - 'kernel': Param( - raw_value=(2, 3) + 'kernel': VariableState( + type=Param, + value=(2, 3) ) }) @@ -871,7 +849,6 @@ def value_and_grad( @functools.wraps(f) def value_and_grad_wrapper(*args): - _check_args(args) return grad_apply(options, f, args) return value_and_grad_wrapper # type: ignore @@ -881,18 +858,21 @@ def value_and_grad_wrapper(*args): # scan # ------------------------------- - @dataclasses.dataclass class ScanOptions: - variable_axes: tp.Mapping[filterlib.Filter, int] - broadcast_rngs: filterlib.Filter - in_args_axes: tp.Any - in_kwargs_axes: tp.Any - out_axes: tp.Any - length: tp.Optional[int] + length: int | None reverse: bool - unroll: int - scan_metadata: tp.Mapping[str, tp.Any] + unroll: int | bool + _split_transpose: bool + # extended api + in_axes: tp.Any + in_axes_kwargs: tp.Any + out_axes: 
tp.Any + carry_argnum: int + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] + split_rngs: filterlib.Filter + transform_metadata: tp.Mapping[str, tp.Any] scan_output: bool @@ -901,34 +881,42 @@ def __call__( self, module_constructor: tp.Callable[..., M], *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, - out_axes: tp.Any = 0, - length: tp.Optional[int] = None, + length: int | None = None, reverse: bool = False, - unroll: int = 1, - scan_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 1, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, ) -> tp.Callable[..., 'Scan[M]']: super_call = super().__call__ def _create_scan(*args, **kwargs) -> Scan[M]: - _check_args(args) return super_call( module_constructor=module_constructor, module_init_args=args, module_init_kwargs=kwargs, - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, - out_axes=out_axes, + # base api length=length, reverse=reverse, unroll=unroll, - scan_metadata=scan_metadata, + _split_transpose=_split_transpose, + # extended api + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + # nnx specific + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, scan_output=scan_output, ) @@ -940,15 +928,19 @@ def __init__( self, module_constructor: tp.Callable[..., M], *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - 
broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, - out_axes: tp.Any = 0, - length: tp.Optional[int] = None, + length: int | None = None, reverse: bool = False, - unroll: int = 1, - scan_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 1, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, # submodule args module_init_args: tuple[tp.Any, ...], @@ -956,20 +948,36 @@ def __init__( ): self.module_constructor = module_constructor self.options = ScanOptions( - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, - out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, - scan_metadata=scan_metadata, + _split_transpose=_split_transpose, + # extended api + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + # nnx specific + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, scan_output=scan_output, ) - self.scan_module = scan_init( - self.options, module_constructor, module_init_args, module_init_kwargs - ) + # use Vmap to handle initialisation + vmapped_module = Vmap( + module_constructor, + in_axes=in_axes, + out_axes=None, + axis_name=None, + axis_size=length, + spmd_axis_name=None, + state_axes=state_axes, + split_rngs=split_rngs, + in_axes_kwargs=in_axes_kwargs, + transform_metadata=transform_metadata, + )(*module_init_args, **module_init_kwargs) + + self.scan_module = vmapped_module.vmap_module @property def _submodule(self) -> M: @@ -978,181 +986,64 @@ def _submodule(self) -> M: def 
_call( self, accessor: DelayedAccessor, *args, **kwargs ) -> tuple[tp.Any, tp.Any]: - if len(args) < 1: - raise TypeError( - f'Expected at least 1 positional arguments, got {len(args)}' - ) - _check_args(args) - carry_arg, args = args[0], args[1:] - def scan_call_apply(module, *args, **kwargs): - return accessor(module)(*args, **kwargs) + method = accessor(module) + return method(*args, **kwargs) return scan_apply( self.options, scan_call_apply, - self.scan_module, - carry_arg, - args, + (self._submodule, *args), kwargs, ) -class ScanCall(tp.Protocol, tp.Generic[C, B]): - def __call__( - self, - module: Module, - carry_arg: C, - *args: tp.Any, - **kwargs: tp.Any, - ) -> tuple[C, B] | C: - ... - - -def scan_init( - options: ScanOptions, - module_constructor: tp.Callable[..., M], - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], -) -> M: - if options.variable_axes and options.length is None: - raise ValueError('Cannot use variable_axes without specifying a length') - - _check_args(module_init_args) - - rngs = module_init_kwargs.pop('rngs', None) - - if rngs is not None and not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - split_keys = [] - - if rngs is not None: - if not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - forked_rngs = rngs.fork( - {filterlib.Not(options.broadcast_rngs): options.length} - ) - split_keys, broadcast_keys = forked_rngs.splits, forked_rngs.broadcasts - - if split_keys and options.length is None: - raise ValueError('Cannot split RNGs without specifying a length') - - else: - split_keys = None - broadcast_keys = None - - graphdef: tp.Optional[GraphDef[M]] = None - - def _init_state(split_keys, broadcast_keys): - nonlocal graphdef - - if split_keys is not None: - assert broadcast_keys is not None - module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) - - module = 
module_constructor(*module_init_args, **module_init_kwargs) - - # lift module - filters = (*options.variable_axes.keys(), ...) - - graphdef, *states = module.split(*filters) - - return tuple(states) - - if split_keys is not None or options.variable_axes: - init_out_axes = (*options.variable_axes.values(), None) - _init_state = jax.vmap( - _init_state, - in_axes=(0, None), - out_axes=init_out_axes, - axis_size=options.length, - ) - - *axes_states, carry_state = _init_state(split_keys, broadcast_keys) - graphdef = tp.cast(GraphDef[M], graphdef) - - # add additional axis name to Variable.sharding - if spmd.PARTITION_NAME in options.scan_metadata: - axes_states = [ - spmd.add_axis(state, index, options.scan_metadata) - for state, index in zip(axes_states, options.variable_axes.values()) - ] - - module = graphdef.merge(*axes_states, carry_state) - - return module - - def scan_apply( options: ScanOptions, - f: ScanCall[C, B], - module: Module, - carry_arg: C, + f: tp.Callable[..., tuple[C, B] | C], args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], ) -> tuple[C, B] | C: - rngs = kwargs.pop('rngs', None) + # extract nodes + (args, kwargs), input_graph_nodes = graph.extract_graph_nodes((args, kwargs)) + input_rng_streams = rnglib.backup_keys(input_graph_nodes) - # split module state - filters = (*options.variable_axes.keys(), ...) - graphdef, *scan_states, carry_state = module.split(*filters) + # extract carry arg + carry_arg, args = _extract_carry_arg(args, options.carry_argnum) - # transpose axes state - scan_states = tuple( - jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, axis, 0), axes_state) - for axes_state, axis in zip(scan_states, options.variable_axes.values()) + ctx = graph.UpdateContext() + # split module state + filters = (*options.state_axes.keys(), ...) 
+ graphdef, rng_state, *scan_states, carry_state = ctx.split( + input_graph_nodes, rnglib.RngState, *filters ) + # transpose axes arg - scan_args = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, axis, 0), node - ) - if axis is not None - else None, - options.in_args_axes, - args, - is_leaf=lambda x: x is None, - ) - broadcast_args = jax.tree_util.tree_map( - lambda axis, node: node if axis is None else None, - options.in_args_axes, - args, - is_leaf=lambda x: x is None, - ) - scan_kwargs = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, axis, 0), node - ) - if axis is not None - else None, - options.in_kwargs_axes, - kwargs, - is_leaf=lambda x: x is None, - ) - broadcast_kwargs = jax.tree_util.tree_map( - lambda axis, node: None if axis is not None else node, - options.in_kwargs_axes, - kwargs, - is_leaf=lambda x: x is None, + flatdef, flat_scan, flat_carry = _transpose_and_split( + (args, kwargs, scan_states), + ( + options.in_axes, + options.in_axes_kwargs, + list(options.state_axes.values()), + ), ) # infer length - lengths: tp.Set[int] = set( - x.shape[0] - for x in jax.tree_util.tree_leaves((scan_states, scan_args, scan_kwargs)) + lengths: set[int] = set( + x.shape[axis] # type: ignore + for x, axis in zip(flat_scan, flatdef.flat_axes) + if axis is not None ) if len(lengths) > 1: raise ValueError( - 'Inconsistent lengths between variable_axes states and ' + 'Inconsistent lengths between state_axes states and ' f'arguments: {lengths}' ) elif len(lengths) == 0: if options.length is None: raise ValueError( - 'Cannot infer length from variable_axes states or axes_arg, ' + 'Cannot infer length from state_axes states or axes_arg, ' 'please specify `length`' ) length = options.length @@ -1165,199 +1056,316 @@ def scan_apply( ) # split rng state - if rngs is not None: - if not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') 
- forked_rngs = rngs.fork({filterlib.Not(options.broadcast_rngs): length}) - split_keys, broadcast_keys = forked_rngs.splits, forked_rngs.broadcasts - else: - split_keys = None - broadcast_keys = None - - moduledef_out: tp.Optional[GraphDef[Module]] = None + split_keys, carry_keys = rnglib.fork( + rng_state, + options.split_rngs, + length, + ) def scan_fn( - carry: tuple[State, tp.Any], + carry: tuple[ + State, # carry_keys + State, # carry_state + tp.Any, # carry_arg + ], scan: tuple[ - dict[str, rnglib.RngStream] | None, - tuple[State, ...], - tuple[tp.Any, ...], - dict[str, tp.Any], + State, # split_keys + list[jax.Array | None], # flat_scan ], ): - nonlocal moduledef_out - carry_state, carry_arg = carry - split_keys, scan_states, scan_args, scan_kwargs = scan + carry_keys, carry_state, carry_arg = carry + split_keys, flat_scan = scan # merge args and kwargs - args = jax.tree_util.tree_map( - lambda axis, scan, broadcast: scan if axis is not None else broadcast, - options.in_args_axes, - scan_args, - broadcast_args, - is_leaf=lambda x: x is None, + args, kwargs, scan_states = _unflatten_splits( + flatdef, flat_scan, flat_carry ) - kwargs = jax.tree_util.tree_map( - lambda axis, scan, broadcast: scan if axis is not None else broadcast, - options.in_kwargs_axes, - scan_kwargs, - broadcast_kwargs, - is_leaf=lambda x: x is None, - ) - - # merge rng state - if split_keys is not None: - assert broadcast_keys is not None - kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) - # remove metadata axis name from Variable.sharding - if spmd.PARTITION_NAME in options.scan_metadata: + if spmd.PARTITION_NAME in options.transform_metadata: scan_states = [ - spmd.remove_axis(state, index, options.scan_metadata) - for state, index in zip(scan_states, options.variable_axes.values()) + spmd.remove_axis(state, index, options.transform_metadata) + for state, index in zip(scan_states, options.state_axes.values()) ] + # insert carry arg + args = _insert_carry_arg(args, 
options.carry_argnum, carry_arg) + # merge module state - module = graphdef.merge(*scan_states, carry_state) + input_graph_nodes = ctx.merge( + graphdef, *scan_states, carry_state, split_keys, carry_keys + ) + (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) - output = f(module, carry_arg, *args, **kwargs) + out = f(*args, **kwargs) if options.scan_output: - if not isinstance(output, tuple) or len(output) != 2: + if not isinstance(out, tuple) or len(out) != 2: raise ValueError( 'Expected a tuple of length 2 as the output of the scan function, ' - f'got {output}' + f'got {out}' ) - output = tp.cast(tuple[C, B], output) - carry_out, scan_out = output + out = tp.cast(tuple[C, B], out) + carry_arg_out, scan_args_out = out else: - output = tp.cast(C, output) - carry_out = output - scan_out = None + out = tp.cast(C, out) + carry_arg_out = out + scan_args_out = None + + ( + (carry_arg_out, scan_args_out), + output_graph_nodes, + ) = graph.extract_graph_nodes((carry_arg_out, scan_args_out)) # split module state - moduledef_out, *scan_states_out, carry_state_out = module.split(*filters) - carry_state_new = carry_state_out - carry_state + ( + graphdef_out, + rng_state_out, + *scan_states_out, + carry_state_out, + ) = ctx.split( + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, + ) + + not_keys_out, split_keys_out, carry_keys_out = rng_state_out.split( + rnglib.NotKey, options.split_rngs, ... 
+ ) + carry_keys_out = State.merge(not_keys_out, carry_keys_out) - # remove new carry state - carry_state_out = carry_state_out - carry_state_new + if 1 in carry_state_out: + raise ValueError( + f'Cannot add new carry state during scan, got {carry_state_out[1]}' + ) + if 0 in carry_state_out: + carry_state_out = carry_state_out[0] + assert isinstance(carry_state_out, State) + if 1 in carry_keys_out: + raise ValueError( + f'Cannot add new carry keys during scan, got {carry_keys_out[1]}' + ) + if 0 in carry_keys_out: + carry_keys_out = carry_keys_out[0] + assert isinstance(carry_keys_out, State) # add metadata axis name to Variable.sharding - if spmd.PARTITION_NAME in options.scan_metadata: + if spmd.PARTITION_NAME in options.transform_metadata: scan_states_out = [ - spmd.add_axis(state, index, options.scan_metadata) - for state, index in zip(scan_states_out, options.variable_axes.values()) + spmd.add_axis(state, index, options.transform_metadata) + for state, index in zip(scan_states_out, options.state_axes.values()) ] - full_carry_out = (carry_state_out, carry_out) - full_scan_out = (scan_states_out, carry_state_new, scan_out) + carry_out = (carry_keys_out, carry_state_out, carry_arg_out) + scan_out = (graphdef_out, scan_args_out, scan_states_out, split_keys_out) - return full_carry_out, full_scan_out + return carry_out, scan_out - carry = (carry_state, carry_arg) - scan = (split_keys, scan_states, scan_args, scan_kwargs) + carry = (carry_keys, carry_state, carry_arg) + scan = (split_keys, flat_scan) - full_carry_out, full_scan_out = jax.lax.scan( + carry_out, scan_out = jax.lax.scan( scan_fn, carry, scan, length=length, reverse=options.reverse, unroll=options.unroll, + _split_transpose=options._split_transpose, ) - carry_state, carry_out = full_carry_out - scan_states, carry_state_new, scan_out = full_scan_out - assert moduledef_out is not None - - # transpose axes state - scan_states = tuple( - jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, 0, axis), 
axes_state) - for axes_state, axis in zip(scan_states, options.variable_axes.values()) + carry_keys_out, carry_state_out, carry_arg_out = carry_out + graphdef_out, scan_args_out, scan_states_out, split_keys_out = scan_out + + scan_args_out, scan_states_out = _transpose_tree( + (scan_args_out, scan_states_out), + (options.out_axes, list(options.state_axes.values())), + axis_is_source=False, ) - # transpose axes arg - scan_out = jax.tree_util.tree_map( - lambda axis, node: jax.tree_util.tree_map( - lambda x: jnp.moveaxis(x, 0, axis), node - ), - options.out_axes, - scan_out, + + if carry_state_out: + carry_state_out = State({0: carry_state_out._mapping}) + if carry_keys_out: + carry_keys_out = State({0: carry_keys_out._mapping}) + _, output_graph_nodes = ctx.update( + graphdef_out, + *scan_states_out, + carry_state_out, + carry_keys_out, + split_keys_out, + ) + + carry_arg_out, scan_args_out = graph.insert_graph_nodes( + (carry_arg_out, scan_args_out), output_graph_nodes ) - # slice new carry state - carry_state_new = jax.tree_util.tree_map(lambda x: x[0], carry_state_new) - module.update(((*scan_states, carry_state, carry_state_new), moduledef_out)) + rnglib.restore_keys(input_rng_streams) if options.scan_output: - return carry_out, scan_out + scan_args_out = tp.cast(B, scan_args_out) + return carry_arg_out, scan_args_out else: - return carry_out + return carry_arg_out + + +@dataclasses.dataclass(frozen=True) +class FlatDef(tp.Generic[A]): + type: type[A] + treedef: jax.tree_util.PyTreeDef + flat_axes: list[int | None] + +jax.tree_util.register_static(FlatDef) + +def _transpose_tree(tree: A, axes, /, *, axis_is_source: bool) -> A: + flatdef, flat_transposes, _ = _transpose_and_split( + tree, axes, allow_none=False, axis_is_source=axis_is_source + ) + return flatdef.treedef.unflatten(flat_transposes) + + +def _transpose_and_split( + tree: A, axes, /, *, allow_none: bool = True, axis_is_source: bool = True +) -> tuple[ + FlatDef[A], + list[jax.Array | None], + 
list[tp.Any], +]: + flat_axes: list[int | None] = broadcast_prefix( + axes, tree, is_leaf=lambda x: x is None + ) + flat_tree, treedef = jax.tree.flatten(tree) + + flat_broadcasts: list[tp.Any] = [] + flat_transposes: list[jax.Array | None] = [] + + for i, (axis, node) in enumerate(zip(flat_axes, flat_tree)): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + + flat_broadcasts.append(node) + flat_transposes.append(None) + else: + if not isinstance(node, jax.Array): + raise TypeError( + f'Expected a jax.Array, got {type(node).__name__} for axis {axis}' + ) + # normalize axis + if axis < 0: + if axis < -len(node.shape): + raise ValueError( + f'Axis {axis} out of bounds for array with shape {node.shape}' + ) + axis = len(node.shape) + axis + flat_axes[i] = axis + + if axis_is_source: + node = jnp.moveaxis(node, axis, 0) + else: + node = jnp.moveaxis(node, 0, axis) + flat_broadcasts.append(None) + flat_transposes.append(node) + + flatdef = FlatDef(type(tree), treedef, flat_axes) + + return flatdef, flat_transposes, flat_broadcasts + +def _unflatten_splits( + flatdef: FlatDef[A], + flat_transposes: list[jax.Array | None], + flat_broadcasts: list[tp.Any] | None = None, + /, + *, + allow_none: bool = True, +) -> A: + flat_axes = flatdef.flat_axes + treedef = flatdef.treedef + if flat_broadcasts is None: + if allow_none: + raise ValueError('flat_broadcasts must be provided if allow_none is True') + flat_broadcasts = [None] * len(flat_axes) + + flat_tree = [] + for axis, transpose, broadcast in zip( + flat_axes, flat_transposes, flat_broadcasts + ): + if axis is None: + if not allow_none: + raise ValueError('None axis not allowed') + flat_tree.append(broadcast) + else: + if transpose is None: + raise ValueError('None transpose not allowed') + flat_tree.append(transpose) + + tree = treedef.unflatten(flat_tree) + return tree + + +def _extract_carry_arg( + args: tuple[tp.Any, ...], carry_argnum: int, / +) -> tuple[tp.Any, tuple[tp.Any, ...]]: 
+ # extract carry arg + if len(args) < carry_argnum + 1: + raise TypeError( + f'Expected at least {carry_argnum + 1} positional arguments, ' + f'got {len(args)}' + ) + + args_ = list(args) + carry_arg = args_[carry_argnum] + args_[carry_argnum] = None + args = tuple(args_) + + return carry_arg, args + + +def _insert_carry_arg( + args: tuple[tp.Any, ...], carry_argnum: int, carry_arg: tp.Any, / +) -> tuple[tp.Any, ...]: + args_ = list(args) + args_[carry_argnum] = carry_arg + args = tuple(args_) + + return args def scan( f: F, *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, - out_axes: tp.Any = 0, - length: tp.Optional[int] = None, + length: int | None = None, reverse: bool = False, - unroll: int = 1, - is_init: tp.Optional[bool] = None, - scan_metadata: tp.Mapping[str, tp.Any] = {}, + unroll: int | bool = 1, + _split_transpose: bool = False, + # extended api + in_axes: int | None | tp.Sequence[tp.Any] = 0, + in_axes_kwargs: tp.Any = 0, + out_axes: tp.Any = 0, + carry_argnum: int = 0, + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), scan_output: bool = True, ) -> F: - if is_init is None: - is_init = f.__name__ == '__init__' - options = ScanOptions( - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, - out_axes=out_axes, length=length, reverse=reverse, unroll=unroll, - scan_metadata=scan_metadata, + _split_transpose=_split_transpose, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, + out_axes=out_axes, + carry_argnum=carry_argnum, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, scan_output=scan_output, ) - if is_init: - - @functools.wraps(f) - def scan_init_wrapper(module: Module, 
*args, **kwargs): - def module_constructor(*args, **kwargs): - _check_args(args) - f(module, *args, **kwargs) - return module - - lifted_module = scan_init(options, module_constructor, args, kwargs) - module.update(lifted_module) - - wrapper = scan_init_wrapper - - else: - - @functools.wraps(f) - def scan_apply_wrapper( - module: Module, - *args, - **kwargs, - ) -> tuple[C, tp.Any]: - if len(args) < 2: - raise TypeError( - f'Expected at least 2 positional arguments, got {len(args)}' - ) - _check_args(args) - - carry_arg, args = args[0], args[1:] - return scan_apply(options, f, module, carry_arg, args, kwargs) - - wrapper = scan_apply_wrapper + @functools.wraps(f) + def scan_apply_wrapper(*args, **kwargs) -> C | tuple[C, tp.Any]: + return scan_apply(options, f, args, kwargs) - return wrapper # type: ignore + return scan_apply_wrapper # type: ignore # ------------------------------- @@ -1369,16 +1377,13 @@ class RematMeta(ModuleMeta): def __call__( self, module_constructor: tp.Callable[..., M], - # variables: lift.CollectionFilter = True, - # rngs: lift.PRNGSequenceFilter = True, prevent_cse: bool = True, - static_argnums: tp.Union[int, tuple[int, ...]] = (), - policy: tp.Optional[tp.Callable[..., bool]] = None, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, ) -> tp.Callable[..., 'Remat[M]']: super_call = super().__call__ def create_remat(*args, **kwargs) -> Remat[M]: - _check_args(args) return super_call( module_constructor=module_constructor, module_init_args=args, @@ -1394,16 +1399,16 @@ def create_remat(*args, **kwargs) -> Remat[M]: @dataclasses.dataclass class RematOptions: prevent_cse: bool - static_argnums: tp.Union[int, tuple[int, ...]] - policy: tp.Optional[tp.Callable[..., bool]] + static_argnums: int | tuple[int, ...] 
+ policy: tp.Callable[..., bool] | None def __post_init__(self): if isinstance(self.static_argnums, int): self.static_argnums = (self.static_argnums,) - # add 2 as an offset to account for state and keys + # add 1 as an offset to account for state parameter self.static_argnums = tuple( - x + 2 if x >= 0 else x for x in self.static_argnums + x + 1 if x >= 0 else x for x in self.static_argnums ) @@ -1413,8 +1418,8 @@ def __init__( *, module_constructor: tp.Callable[..., M], prevent_cse: bool = True, - static_argnums: tp.Union[int, tuple[int, ...]] = (), - policy: tp.Optional[tp.Callable[..., bool]] = None, + static_argnums: int | tuple[int, ...] = (), + policy: tp.Callable[..., bool] | None = None, # submodule args module_init_args: tuple[tp.Any, ...], module_init_kwargs: dict[str, tp.Any], @@ -1433,65 +1438,45 @@ def __init__( def _submodule(self) -> M: return self.remat_module - def _call( - self, - accessor: DelayedAccessor, - *args, - rngs: tp.Optional[rnglib.Rngs] = None, - ) -> tp.Any: - def remat_call_apply(module, *args, **kwargs): - return accessor(module)(*args, **kwargs) + def _call(self, accessor: DelayedAccessor, *args) -> tp.Any: + def remat_apply_call(module, *args): + method = accessor(module) + return method(*args) return remat_apply( self.options, - remat_call_apply, - self.remat_module, - args, - rngs, + remat_apply_call, + (self.remat_module, *args), ) -class RematCall(tp.Protocol): - def __call__(self, *args, rngs: tp.Optional[rnglib.Rngs]) -> tp.Any: - ... 
- - def remat_apply( options: RematOptions, - f: RematCall, - module: Module, + f: tp.Callable[..., tp.Any], args: tuple[tp.Any, ...], - rngs: tp.Optional[rnglib.Rngs], ): - _check_args(args) - - graphdef, state = module.split() - keys = rngs.fork() if rngs is not None else None - - def _remat_fn( - state: State, - keys: tp.Optional[dict[str, jax.Array]], - *args, - ) -> tuple[tuple[GraphDef[Module], State], tp.Any]: - kwargs = {} - if keys is not None: - kwargs['rngs'] = rnglib.Rngs(keys) + ctx = graph.UpdateContext() + args, input_nodes = graph.extract_graph_nodes(args) + graphdef, state = ctx.split(input_nodes) - module = graphdef.merge(state) - out = f(module, *args, **kwargs) + def _remat_fn(state: State, *args): + input_nodes = ctx.merge(graphdef, state) + args = graph.insert_graph_nodes(args, input_nodes) + out = f(*args) - def_and_state = module.split() - return def_and_state, out + out, output_nodes = graph.extract_graph_nodes(out) + new_graphdef, new_state = ctx.split((input_nodes, output_nodes)) + return (new_graphdef, new_state), out - def_and_state: tuple[GraphDef[Module], State] - def_and_state, out = jax.checkpoint( + (new_graphdef, new_state), out = jax.checkpoint( _remat_fn, prevent_cse=options.prevent_cse, static_argnums=options.static_argnums, policy=options.policy, - )(state, keys, *args) + )(state, *args) - module.update(def_and_state) + _, output_nodes = ctx.update(new_graphdef, new_state) + out = graph.insert_graph_nodes(out, output_nodes) return out @@ -1499,53 +1484,39 @@ def _remat_fn( def remat( f: F, *, - # variables: lift.CollectionFilter, - # rngs: lift.PRNGSequenceFilter, prevent_cse: bool = True, - static_argnums: tp.Union[int, tuple[int, ...]] = (), - policy: tp.Optional[tp.Callable[..., bool]] = None, - is_init: tp.Optional[bool] = None, + static_argnums: int | tuple[int, ...] 
= (), + policy: tp.Callable[..., bool] | None = None, ) -> F: - if is_init is None: - is_init = f.__name__ == '__init__' - options = RematOptions( - # variables=variables, - # rngs=rngs, prevent_cse=prevent_cse, static_argnums=static_argnums, policy=policy, ) - if is_init: - return f - else: - - @functools.wraps(f) - def remat_wrapper( - module: Module, *args, rngs: tp.Optional[rnglib.Rngs] = None - ): - return remat_apply(options, f, module, args, rngs) + @functools.wraps(f) + def remat_wrapper(*args): + return remat_apply(options, f, args) - return remat_wrapper # type: ignore + return remat_wrapper # type: ignore # ------------------------------- # vmap # ------------------------------- - @dataclasses.dataclass class VmapOptions: - variable_axes: tp.Mapping[filterlib.Filter, int] - broadcast_rngs: filterlib.Filter - in_args_axes: tp.Any - in_kwargs_axes: tp.Any + in_axes: int | None | tp.Sequence[tp.Any] out_axes: tp.Any + axis_name: AxisName | None axis_size: int | None - axis_name: str | None - spmd_axis_name: str | None - vmap_metadata: tp.Mapping[str, tp.Any] + spmd_axis_name: AxisName | tuple[AxisName, ...] | None + # nnx specific + state_axes: tp.Mapping[filterlib.Filter, int] + split_rngs: filterlib.Filter + in_axes_kwargs: tp.Any + transform_metadata: tp.Mapping[str, tp.Any] class VmapMeta(ModuleMeta): @@ -1553,36 +1524,38 @@ def __call__( self, module_constructor: tp.Callable[..., M], *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, + in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, + axis_name: AxisName | None = None, axis_size: int | None = None, - axis_name: str | None = None, - spmd_axis_name: str | None = None, - vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), ) -> tp.Callable[..., 'Vmap[M]']: super_call = super().__call__ - def _create_scan(*args, **kwargs) -> Scan[M]: - _check_args(args) + def _create_vmap(*args, **kwargs) -> Scan[M]: return super_call( module_constructor=module_constructor, - module_init_args=args, - module_init_kwargs=kwargs, - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, + in_axes=in_axes, out_axes=out_axes, axis_size=axis_size, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - vmap_metadata=vmap_metadata, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, + # submodule args + module_init_args=args, + module_init_kwargs=kwargs, ) - return _create_scan + return _create_vmap class Vmap(LiftedModule[M], metaclass=VmapMeta): @@ -1590,148 +1563,83 @@ def __init__( self, module_constructor: tp.Callable[..., M], *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, + in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, + axis_name: AxisName | None = None, axis_size: int | None = None, - axis_name: str | None = None, - spmd_axis_name: str | None = None, - vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), + spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), # submodule args module_init_args: tuple[tp.Any, ...], module_init_kwargs: dict[str, tp.Any], ): self.module_constructor = module_constructor self.options = VmapOptions( - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, + in_axes=in_axes, out_axes=out_axes, - axis_size=axis_size, axis_name=axis_name, + axis_size=axis_size, spmd_axis_name=spmd_axis_name, - vmap_metadata=vmap_metadata, + # nnx specific + in_axes_kwargs=in_axes_kwargs, + state_axes=state_axes, + split_rngs=split_rngs, + transform_metadata=transform_metadata, ) - self.vmap_module = vmap_init( - self.options, module_constructor, module_init_args, module_init_kwargs + + ( + (module_init_args, module_init_kwargs), + init_nodes, + ) = graph.extract_graph_nodes((module_init_args, module_init_kwargs)) + + def vmap_init(init_nodes): + (args, kwargs) = graph.insert_graph_nodes( + (module_init_args, module_init_kwargs), init_nodes + ) + return module_constructor(*args, **kwargs) + + init_options = dataclasses.replace( + self.options, + in_axes=None, + out_axes=None, ) + self.vmap_module = vmap_apply(init_options, vmap_init, (init_nodes,), {}) @property def _submodule(self) -> M: return self.vmap_module - def _call( - self, accessor: DelayedAccessor, *args, **kwargs - ) -> tuple[tp.Any, tp.Any]: - _check_args(args) - - def vmap_call_apply(module, *args, **kwargs): - return accessor(module)(*args, **kwargs) + def _call(self, accessor: DelayedAccessor, *args, **kwargs): + def vmap_apply_call(module, *args, **kwargs): + method = accessor(module) + return method(*args, **kwargs) return vmap_apply( self.options, - vmap_call_apply, - self.vmap_module, - args, + vmap_apply_call, + (self._submodule, *args), 
kwargs, ) - -class VmapCall(tp.Protocol): - def __call__( - self, - module: Module, - *args: tp.Any, - **kwargs: tp.Any, - ) -> tp.Any: - ... - - -def vmap_init( - options: VmapOptions, - module_constructor: tp.Callable[..., M], - module_init_args: tuple[tp.Any, ...], - module_init_kwargs: dict[str, tp.Any], -) -> M: - if options.variable_axes and options.axis_size is None: - raise ValueError('Cannot use variable_axes without specifying a length') - - _check_args(module_init_args) - - rngs = module_init_kwargs.pop('rngs', None) - - if rngs is not None and not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - if rngs is not None: - if not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - forked_rngs = rngs.fork( - {filterlib.Not(options.broadcast_rngs): options.axis_size} - ) - split_keys, broadcast_keys = forked_rngs.splits, forked_rngs.broadcasts - if split_keys and options.axis_size is None: - raise ValueError('Cannot split RNGs without specifying a length') - else: - split_keys = None - broadcast_keys = None - - graphdef: tp.Optional[GraphDef[M]] = None - - def _init_state(split_keys, broadcast_keys): - nonlocal graphdef - - if split_keys is not None: - assert broadcast_keys is not None - module_init_kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) - - module = module_constructor(*module_init_args, **module_init_kwargs) - - # lift module - filters = (*options.variable_axes.keys(), ...) 
- - graphdef, *states = module.split(*filters) - - return tuple(states) - - if split_keys is not None or options.variable_axes: - init_out_axes = (*options.variable_axes.values(), None) - _init_state = jax.vmap( - _init_state, - in_axes=(0, None), - out_axes=init_out_axes, - axis_size=options.axis_size, - ) - - *axes_states, carry_state = _init_state(split_keys, broadcast_keys) - graphdef = tp.cast(GraphDef[M], graphdef) - - # add additional axis name to Variable.sharding - if spmd.PARTITION_NAME in options.vmap_metadata: - axes_states = [ - spmd.add_axis(state, index, options.vmap_metadata) - for state, index in zip(axes_states, options.variable_axes.values()) - ] - - module = graphdef.merge(*axes_states, carry_state) - return module - - def vmap_apply( options: VmapOptions, - f: VmapCall, - module: Module, + f: tp.Callable[..., A], args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], -) -> tp.Any: - rngs = kwargs.pop('rngs', None) +) -> A: + (args, kwargs), input_graph_nodes = graph.extract_graph_nodes((args, kwargs)) + input_rng_streams = rnglib.backup_keys(input_graph_nodes) + ctx = graph.UpdateContext() # split module state - filters = (*options.variable_axes.keys(), ...) - graphdef, *vectorized_states, broadcast_state = module.split(*filters) + filters = (*options.state_axes.keys(), ...) 
+ graphdef, rng_state, *vectorized_states, broadcast_state = ctx.split( + input_graph_nodes, rnglib.RngState, *filters + ) # infer length axis_sizes: tp.Set[int] = set() @@ -1739,7 +1647,7 @@ def vmap_apply( lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) if axis is not None else None, - options.in_args_axes, + options.in_axes, args, is_leaf=lambda x: x is None, ) @@ -1747,7 +1655,7 @@ def vmap_apply( lambda axis, node: jax.tree_util.tree_map(lambda x: x.shape[axis], node) if axis is not None else None, - options.in_kwargs_axes, + options.in_axes_kwargs, kwargs, is_leaf=lambda x: x is None, ) @@ -1756,13 +1664,13 @@ def vmap_apply( if len(axis_sizes) > 1: raise ValueError( - 'Inconsistent lengths between variable_axes states and ' + 'Inconsistent lengths between state_axes states and ' f'arguments: {axis_sizes}' ) elif len(axis_sizes) == 0: if options.axis_size is None: raise ValueError( - 'Cannot infer length from variable_axes states or axes_arg, ' + 'Cannot infer length from state_axes states or axes_arg, ' 'please specify `length`' ) axis_size = options.axis_size @@ -1774,136 +1682,178 @@ def vmap_apply( f' inferred length {axis_size}' ) - # split rng state - if rngs is not None: - if not isinstance(rngs, rnglib.Rngs): - raise TypeError(f'Expected a Rngs, got {type(rngs).__name__}') - - forked_rngs = rngs.fork({filterlib.Not(options.broadcast_rngs): axis_size}) - split_keys, broadcast_keys = forked_rngs.splits, forked_rngs.broadcasts - else: - split_keys = None - broadcast_keys = None - - moduledef_out: tp.Optional[GraphDef[Module]] = None + split_keys, broadcast_keys = rnglib.fork( + rng_state, + options.split_rngs, + axis_size, + ) keys_axes = 0 - states_axes = list(options.variable_axes.values()) - args_axes = options.in_args_axes - kwargs_axes = options.in_kwargs_axes + states_axes = list(options.state_axes.values()) + args_axes = options.in_axes + kwargs_axes = options.in_axes_kwargs out_axes = options.out_axes + 
broadcast_state_axes = None + graphdef_out_axes = None + keys_axes_out = 0 @functools.partial( jax.vmap, in_axes=(keys_axes, states_axes, args_axes, kwargs_axes), - out_axes=(None, states_axes, out_axes), + out_axes=( + graphdef_out_axes, + broadcast_state_axes, + states_axes, + keys_axes_out, + out_axes, + ), axis_name=options.axis_name, axis_size=axis_size, spmd_axis_name=options.spmd_axis_name, ) def vmap_fn( - split_keys: dict[str, rnglib.RngStream] | None, + split_keys: State, vectorized_states: list[State], args: tuple[tp.Any, ...], kwargs: dict[str, tp.Any], ): - nonlocal moduledef_out - - # merge rng state - if split_keys is not None: - assert broadcast_keys is not None - kwargs['rngs'] = rnglib.Rngs(**split_keys, **broadcast_keys) - # remove metadata axis name from Variable.sharding - if spmd.PARTITION_NAME in options.vmap_metadata: + if spmd.PARTITION_NAME in options.transform_metadata: vectorized_states = [ - spmd.remove_axis(state, index, options.vmap_metadata) - for state, index in zip( - vectorized_states, options.variable_axes.values() - ) + spmd.remove_axis(state, index, options.transform_metadata) + for state, index in zip(vectorized_states, options.state_axes.values()) ] # merge module state - module = graphdef.merge(*vectorized_states, broadcast_state) + input_graph_nodes = ctx.merge( + graphdef, *vectorized_states, broadcast_state, split_keys, broadcast_keys + ) + + (args, kwargs) = graph.insert_graph_nodes((args, kwargs), input_graph_nodes) + + out = f(*args, **kwargs) - output = f(module, *args, **kwargs) + out, output_graph_nodes = graph.extract_graph_nodes(out) # split module state - moduledef_out, *vectorized_states_out, broadcast_state_out = module.split( - *filters + ( + graphdef_out, + rng_state_out, + *vectorized_states_out, + broadcast_state_out, + ) = ctx.split( + (input_graph_nodes, output_graph_nodes), + rnglib.RngState, + *filters, + ) + + not_keys_out, split_keys_out, broadcast_keys_out = rng_state_out.split( + rnglib.NotKey, 
options.split_rngs, ... + ) + + broadcast_state_out = State.merge( + broadcast_state_out, broadcast_keys_out, not_keys_out ) # add metadata axis name to Variable.sharding - if spmd.PARTITION_NAME in options.vmap_metadata: + if spmd.PARTITION_NAME in options.transform_metadata: vectorized_states_out = [ - spmd.add_axis(state, index, options.vmap_metadata) + spmd.add_axis(state, index, options.transform_metadata) for state, index in zip( - vectorized_states_out, options.variable_axes.values() + vectorized_states_out, options.state_axes.values() ) ] - return broadcast_state_out, vectorized_states_out, output + return ( + graphdef_out, + broadcast_state_out, + vectorized_states_out, + split_keys_out, + out, + ) - broadcast_state, vectorized_states, output = vmap_fn( - split_keys, vectorized_states, args, kwargs + ( + graphdef_out, + broadcast_state, + vectorized_states, + split_keys_out, + out, + ) = vmap_fn(split_keys, vectorized_states, args, kwargs) + + _, output_graph_nodes = ctx.update( + graphdef_out, + *vectorized_states, + broadcast_state, + split_keys_out, ) - assert moduledef_out is not None - module.update(((*vectorized_states, broadcast_state), moduledef_out)) + out = graph.insert_graph_nodes(out, output_graph_nodes) - return output + rnglib.restore_keys(input_rng_streams) + + return out def vmap( f: F, *, - variable_axes: tp.Mapping[filterlib.Filter, int] = MappingProxyType({}), - broadcast_rngs: filterlib.Filter = None, - in_args_axes: tp.Any = 0, - in_kwargs_axes: tp.Any = 0, + in_axes: int | None | tp.Sequence[tp.Any] = 0, out_axes: tp.Any = 0, + axis_name: AxisName | None = None, axis_size: int | None = None, - axis_name: str | None = None, - spmd_axis_name: str | None = None, - vmap_metadata: tp.Mapping[str, tp.Any] = MappingProxyType({}), - is_init: tp.Optional[bool] = None, + spmd_axis_name: AxisName | tuple[AxisName, ...] 
| None = None, + # nnx specific + in_axes_kwargs: tp.Any = 0, + state_axes: tp.Mapping[filterlib.Filter, int] = FrozenDict({...: 0}), + split_rngs: filterlib.Filter = ..., + transform_metadata: tp.Mapping[str, tp.Any] = FrozenDict({}), ) -> F: - if is_init is None: - is_init = f.__name__ == '__init__' - options = VmapOptions( - variable_axes=variable_axes, - broadcast_rngs=broadcast_rngs, - in_args_axes=in_args_axes, - in_kwargs_axes=in_kwargs_axes, + state_axes=state_axes, + split_rngs=split_rngs, + in_axes=in_axes, + in_axes_kwargs=in_axes_kwargs, out_axes=out_axes, axis_size=axis_size, axis_name=axis_name, spmd_axis_name=spmd_axis_name, - vmap_metadata=vmap_metadata, + transform_metadata=transform_metadata, ) - if is_init: + @functools.wraps(f) + def vmap_apply_wrapper(*args, **kwargs) -> tp.Any: + return vmap_apply(options, f, args, kwargs) + + wrapper = vmap_apply_wrapper - @functools.wraps(f) - def vmap_init_wrapper(module: Module, *args, **kwargs): - def module_constructor(*args, **kwargs): - _check_args(args) - f(module, *args, **kwargs) - return module + return wrapper # type: ignore - lifted_module = vmap_init(options, module_constructor, args, kwargs) - module.update(lifted_module) +# ------------------------------- +# eval_shape +# ------------------------------- - wrapper = vmap_init_wrapper - else: +def eval_shape( + f: tp.Callable[..., A], + *args: tp.Any, + **kwargs: tp.Any, +) -> A: + (args, kwargs), input_nodes = graph.extract_graph_nodes((args, kwargs)) + graphdef, state = graph.split(input_nodes) - @functools.wraps(f) - def vmap_apply_wrapper(module: Module, *args, **kwargs) -> tp.Any: - _check_args(args) - return vmap_apply(options, f, module, args, kwargs) + @functools.wraps(f) + def _eval_shape_fn(state: State, *args, **kwargs): + input_nodes = graph.merge(graphdef, state) + args, kwargs = graph.insert_graph_nodes((args, kwargs), input_nodes) + out = f(*args, **kwargs) + out, output_nodes = graph.extract_graph_nodes(out) + graphdef_out, 
state_out = graph.split(output_nodes) + return graphdef_out, state_out, out - wrapper = vmap_apply_wrapper + graphdef_out, state_out, out = jax.eval_shape( + _eval_shape_fn, state, *args, **kwargs + ) - return wrapper # type: ignore + output_nodes = graph.merge(graphdef_out, state_out) + out = graph.insert_graph_nodes(out, output_nodes) + return out \ No newline at end of file diff --git a/flax/experimental/nnx/nnx/variables.py b/flax/experimental/nnx/nnx/variables.py index 2d1589d3e7..4d76980e58 100644 --- a/flax/experimental/nnx/nnx/variables.py +++ b/flax/experimental/nnx/nnx/variables.py @@ -29,11 +29,9 @@ import dataclasses import functools import typing as tp -from abc import ABCMeta from functools import partial from typing import Any -import jax import jax.tree_util as jtu from flax.experimental.nnx.nnx import reprlib, tracers @@ -73,6 +71,11 @@ def __hash__(self): EMPTY = Empty() +class _Missing: + pass + +MISSING = _Missing() + @dataclasses.dataclass class VariableMetadata(tp.Generic[A]): @@ -222,6 +225,7 @@ def __init__( def __getattr__(self, name: str) -> tp.Any: ... 
else: + def __setattr__(self, name: str, value: Any) -> None: return self._setattr(name, value) @@ -233,23 +237,34 @@ def _setattr(self, name: str, value: tp.Any): object.__setattr__(self, name, value) + @classmethod + def state(cls, value: A, **metadata) -> 'VariableState[A]': + return cls(value, **metadata).to_state() + def copy_from(self, other: 'Variable[A]') -> None: - if not self.is_equivalent(other): + if type(self) is not type(other): raise ValueError( f'Cannot copy from incompatible container, ' f'expected {type(self).__name__}, got {type(other).__name__}' ) if self is other: return + trace_state = self._trace_state vars_dict = vars(self) + other_vars = vars(other).copy() + del other_vars['_trace_state'] vars_dict.clear() - vars_dict.update(vars(other)) + vars_dict.update(other_vars, _trace_state=trace_state) - def copy_from_def(self, other: 'nnx.graph_utils.VariableDef', /, value: A): - _trace_state = self._trace_state + def copy_from_state(self, variable_state: 'VariableState[A]'): + trace_state = self._trace_state variable_vars = vars(self) variable_vars.clear() - variable_vars.update(other.metadata, _trace_state=_trace_state, raw_value=value) + variable_vars.update( + variable_state.get_metadata(), + raw_value=variable_state.value, + _trace_state=trace_state, + ) @property def value(self) -> A: @@ -287,21 +302,28 @@ def __eq__(self, other: object) -> bool: return type(self) is type(other) and vars(other) == vars(self) @tp.overload - def replace(self, *, value: B, **kwargs) -> 'Variable[B]': + def replace(self, value: B, **kwargs) -> 'Variable[B]': ... @tp.overload def replace(self, **kwargs) -> 'Variable[A]': ... 
- def replace(self, **kwargs) -> 'Variable[tp.Any]': + def replace(self, value: tp.Any = MISSING, **kwargs) -> 'Variable[tp.Any]': + if value is not MISSING: + kwargs['raw_value'] = value + + # rename `value` to `raw_value` + if 'value' in kwargs: + kwargs['raw_value'] = kwargs.pop('value') + # return `value` if it is a Variable if 'raw_value' in kwargs and isinstance( value := kwargs['raw_value'], Variable ): # remove value from kwargs kwargs.pop('raw_value') - if not self.is_equivalent(value): + if type(self) is not type(value): raise ValueError( 'Cannot replace value from incompatible container, ' f'expected {type(self).__name__}, got {type(value).__name__}' @@ -321,9 +343,6 @@ def replace(self, **kwargs) -> 'Variable[tp.Any]': vars(obj).update(attributes) return obj - def is_equivalent(self, other: tp.Any) -> bool: - return type(self) is type(other) - def copy(self: 'Variable[A]') -> 'Variable[A]': obj = object.__new__(type(self)) attributes = vars(self).copy() @@ -331,23 +350,21 @@ def copy(self: 'Variable[A]') -> 'Variable[A]': vars(obj).update(attributes) return obj + def to_state(self: 'Variable[A]') -> 'VariableState[A]': + metadata = vars(self).copy() + del metadata['raw_value'] + del metadata['_trace_state'] + return VariableState(type(self), self.raw_value, **metadata) + def __nnx_repr__(self): yield reprlib.Object(type=type(self)) for name, value in vars(self).items(): - if name.endswith('_hooks') or name == "_trace_state": + if name == 'raw_value': + name = 'value' + if name.endswith('_hooks') or name == '_trace_state': continue yield reprlib.Attr(name, repr(value)) - def __init_subclass__(cls): - super().__init_subclass__() - - jtu.register_pytree_with_keys( - cls, - partial(_variable_flatten, with_keys=True), # type: ignore - partial(_variable_unflatten, cls=cls), # type: ignore - flatten_func=partial(_variable_flatten, with_keys=False), # type: ignore - ) - # hooks API if tp.TYPE_CHECKING: @@ -369,35 +386,171 @@ def on_remove_axis( raise 
NotImplementedError -def _variable_flatten(x: Variable[tp.Any], *, with_keys: bool): - attributes = vars(x).copy() - del attributes['_trace_state'] - value = attributes.pop('raw_value') - if with_keys: - node = (jtu.GetAttrKey('raw_value'), value) - else: - node = value + # operator overloads + def __jax_array__(self): + return self.value - return (node,), attributes + def __getitem__(self, key) -> tp.Any: + return self.value.__getitem__(key) + def __add__(self, other) -> A: + return self.value.__add__(other) # type: ignore -def _variable_unflatten( - metadata: tp.Mapping[str, tp.Any], - children: tp.Tuple[A], - *, - cls: type[Variable[A]], -) -> Variable[A]: - variable = object.__new__(cls) - vars(variable).update(metadata, _trace_state=tracers.TraceState(), raw_value=children[0]) - return variable + def __sub__(self, other) -> A: + return self.value.__sub__(other) # type: ignore + def __mul__(self, other) -> A: + return self.value.__mul__(other) # type: ignore -jtu.register_pytree_with_keys( - Variable, - partial(_variable_flatten, with_keys=True), # type: ignore - partial(_variable_unflatten, cls=Variable), # type: ignore - flatten_func=partial(_variable_flatten, with_keys=False), # type: ignore -) + def __matmul__(self, other) -> A: + return self.value.__matmul__(other) # type: ignore + + def __truediv__(self, other) -> A: + return self.value.__truediv__(other) # type: ignore + + def __floordiv__(self, other) -> A: + return self.value.__floordiv__(other) # type: ignore + + def __mod__(self, other) -> A: + return self.value.__mod__(other) # type: ignore + + def __divmod__(self, other) -> A: + return self.value.__divmod__(other) # type: ignore + + def __pow__(self, other) -> A: + return self.value.__pow__(other) # type: ignore + + def __lshift__(self, other) -> A: + return self.value.__lshift__(other) # type: ignore + + def __rshift__(self, other) -> A: + return self.value.__rshift__(other) # type: ignore + + def __and__(self, other) -> A: + return 
self.value.__and__(other) # type: ignore + + def __xor__(self, other) -> A: + return self.value.__xor__(other) # type: ignore + + def __or__(self, other) -> A: + return self.value.__or__(other) # type: ignore + + def __radd__(self, other) -> A: + return self.value.__radd__(other) # type: ignore + + def __rsub__(self, other) -> A: + return self.value.__rsub__(other) # type: ignore + + def __rmul__(self, other) -> A: + return self.value.__rmul__(other) # type: ignore + + def __rmatmul__(self, other) -> A: + return self.value.__rmatmul__(other) # type: ignore + + def __rtruediv__(self, other) -> A: + return self.value.__rtruediv__(other) # type: ignore + + def __rfloordiv__(self, other) -> A: + return self.value.__rfloordiv__(other) # type: ignore + + def __rmod__(self, other) -> A: + return self.value.__rmod__(other) # type: ignore + + def __rdivmod__(self, other) -> A: + return self.value.__rdivmod__(other) # type: ignore + + def __rpow__(self, other) -> A: + return self.value.__rpow__(other) # type: ignore + + def __rlshift__(self, other) -> A: + return self.value.__rlshift__(other) # type: ignore + + def __rrshift__(self, other) -> A: + return self.value.__rrshift__(other) # type: ignore + + def __rand__(self, other) -> A: + return self.value.__rand__(other) # type: ignore + + def __rxor__(self, other) -> A: + return self.value.__rxor__(other) # type: ignore + + def __ror__(self, other) -> A: + return self.value.__ror__(other) # type: ignore + + def __iadd__(self, other) -> A: + return self.value.__iadd__(other) # type: ignore + + def __isub__(self, other) -> A: + return self.value.__isub__(other) # type: ignore + + def __imul__(self, other) -> A: + return self.value.__imul__(other) # type: ignore + + def __imatmul__(self, other) -> A: + return self.value.__imatmul__(other) # type: ignore + + def __itruediv__(self, other) -> A: + return self.value.__itruediv__(other) # type: ignore + + def __ifloordiv__(self, other) -> A: + return self.value.__ifloordiv__(other) # 
type: ignore + + def __imod__(self, other) -> A: + return self.value.__imod__(other) # type: ignore + + def __ipow__(self, other) -> A: + return self.value.__ipow__(other) # type: ignore + + def __ilshift__(self, other) -> A: + return self.value.__ilshift__(other) # type: ignore + + def __irshift__(self, other) -> A: + return self.value.__irshift__(other) # type: ignore + + def __iand__(self, other) -> A: + return self.value.__iand__(other) # type: ignore + + def __ixor__(self, other) -> A: + return self.value.__ixor__(other) # type: ignore + + def __ior__(self, other) -> A: + return self.value.__ior__(other) # type: ignore + + def __neg__(self) -> A: + return self.value.__neg__() # type: ignore + + def __pos__(self) -> A: + return self.value.__pos__() # type: ignore + + def __abs__(self) -> A: + return self.value.__abs__() # type: ignore + + def __invert__(self) -> A: + return self.value.__invert__() # type: ignore + + def __complex__(self) -> A: + return self.value.__complex__() # type: ignore + + def __int__(self) -> A: + return self.value.__int__() # type: ignore + + def __float__(self) -> A: + return self.value.__float__() # type: ignore + + def __index__(self) -> A: + return self.value.__index__() # type: ignore + + def __round__(self, ndigits: int) -> A: + return self.value.__round__(ndigits) # type: ignore + + def __trunc__(self) -> A: + return self.value.__trunc__() # type: ignore + + def __floor__(self) -> A: + return self.value.__floor__() # type: ignore + + def __ceil__(self) -> A: + return self.value.__ceil__() # type: ignore class Param(Variable[A]): @@ -416,15 +569,90 @@ class Intermediate(Variable[A]): pass -class Rng(Variable[jax.Array]): - tag: str +class VariableState(tp.Generic[A], reprlib.Representable): + def __init__( + self, + type: tp.Type[Variable[A]], + value: A, + **metadata, + ): + self.type = type + self.value = value + vars(self).update(metadata) - def __init__(self, value: jax.Array, *, tag: str, **metadata: tp.Any): - 
super().__init__(value, tag=tag, **metadata) + if tp.TYPE_CHECKING: - def on_get_value(self, value: jax.Array): - self.raw_value, value = jax.random.split(value) - return value + def __getattr__(self, name: str) -> tp.Any: + ... + + def __nnx_repr__(self): + yield reprlib.Object(type=type(self)) + yield reprlib.Attr('type', self.type.__name__) + + for name, value in vars(self).items(): + if name == 'type' or name.endswith('_hooks'): + continue + yield reprlib.Attr(name, repr(value)) + + def replace(self, value: B) -> 'VariableState[B]': + return VariableState(self.type, value, **self.get_metadata()) + + def to_variable(self) -> Variable[A]: + # we use object.__new__ to avoid calling __init__ and bypass the + # __init__ logic which should not be called twice + metadata = self.get_metadata() + variables = object.__new__(self.type) + vars(variables).update( + metadata, raw_value=self.value, _trace_state=tracers.TraceState() + ) + return variables + + def get_metadata(self) -> dict[str, tp.Any]: + metadata = vars(self).copy() + del metadata['type'] + del metadata['value'] + return metadata + + def add_axis(self, axis_name: AxisName, axis_index: AxisIndex): + if not hasattr(self, 'add_axis_hooks'): + raise ValueError(f'No add_axis_hooks found for VariableState: {self}') + for hook in self.add_axis_hooks: + hook(self, axis_name, axis_index) + + def remove_axis(self, axis_name: AxisName, axis_index: AxisIndex): + if not hasattr(self, 'remove_axis_hooks'): + raise ValueError(f'No remove_axis_hooks found for VariableState: {self}') + for hook in self.remove_axis_hooks: + hook(self, axis_name, axis_index) + + +def _variable_state_flatten(x: VariableState[tp.Any], *, with_keys: bool): + metadata = tuple(x.get_metadata().items()) + if with_keys: + node = (jtu.GetAttrKey('raw_value'), x.value) + else: + node = x.value + + return (node,), (x.type, metadata) + + +def _variable_state_unflatten( + static: tuple[type[Variable[A]], tuple[tuple[str, tp.Any], ...]], + children: 
tuple[A], +) -> VariableState[A]: + return VariableState( + type=static[0], + value=children[0], + **dict(static[1]), + ) + + +jtu.register_pytree_with_keys( + VariableState, + partial(_variable_state_flatten, with_keys=True), # type: ignore + _variable_state_unflatten, # type: ignore + flatten_func=partial(_variable_state_flatten, with_keys=False), # type: ignore +) def with_metadata( diff --git a/flax/experimental/nnx/nnx/visualization.py b/flax/experimental/nnx/nnx/visualization.py new file mode 100644 index 0000000000..ec4225a642 --- /dev/null +++ b/flax/experimental/nnx/nnx/visualization.py @@ -0,0 +1,114 @@ +import dataclasses +import importlib.util +import typing as tp + +import jax + +from flax.experimental import nnx + +penzai_installed = importlib.util.find_spec('penzai') is not None +try: + from IPython import get_ipython + + in_ipython = get_ipython() is not None +except ImportError: + in_ipython = False + + +def display(*args): + """Display the given objects using a Penzai visualizer. + + If Penzai is not installed or the code is not running in IPython, ``display`` + will print the objects instead. 
+ """ + if not penzai_installed or not in_ipython: + for x in args: + print(x) + return + + from penzai import pz + + with pz.ts.active_autovisualizer.set_scoped(pz.ts.ArrayAutovisualizer()): + for x in args: + value = to_dataclass(x) + pz.ts.display(value, ignore_exceptions=True) + + +def to_dataclass(node): + seen_nodes = set() + return _treemap_to_dataclass(node, seen_nodes) + + +def _to_dataclass(x, seen_nodes: set[int]): + if nnx.graph.is_graph_node(x): + if id(x) in seen_nodes: + dc_type = _make_dataclass_obj( + type(x), + {'repeated': True}, + ) + return dc_type + seen_nodes.add(id(x)) + node_impl = nnx.graph.get_node_impl(x) + node_dict = node_impl.node_dict(x) + node_dict = { + str(key): _treemap_to_dataclass(value, seen_nodes) + for key, value in node_dict.items() + } + dc_type = _make_dataclass_obj( + type(x), + node_dict, + ) + return dc_type + elif isinstance(x, (nnx.Variable, nnx.VariableState)): + obj_vars = vars(x).copy() + if 'raw_value' in obj_vars: + obj_vars['value'] = obj_vars.pop('raw_value') + if '_trace_state' in obj_vars: + del obj_vars['_trace_state'] + for name in list(obj_vars): + if name.endswith('_hooks'): + del obj_vars[name] + obj_vars = { + key: _treemap_to_dataclass(value, seen_nodes) + for key, value in obj_vars.items() + } + dc_type = _make_dataclass_obj( + type(x), + obj_vars, + penzai_dataclass=not isinstance(x, nnx.VariableState), + ) + return dc_type + elif isinstance(x, nnx.State): + return _treemap_to_dataclass(x._mapping, seen_nodes) + return x + + +def _treemap_to_dataclass(node, seen_nodes: set[int]): + def _to_dataclass_fn(x): + return _to_dataclass(x, seen_nodes) + + return jax.tree.map( + _to_dataclass_fn, + node, + is_leaf=lambda x: isinstance(x, (nnx.VariableState, nnx.State)), + ) + + +def _make_dataclass_obj( + cls, fields: dict[str, tp.Any], penzai_dataclass: bool = True +) -> tp.Type: + from penzai import pz + + dataclass = pz.pytree_dataclass if penzai_dataclass else dataclasses.dataclass + base = pz.Layer if 
penzai_dataclass else object + + attributes = { + '__annotations__': {key: type(value) for key, value in fields.items()}, + } + + if hasattr(cls, '__call__'): + attributes['__call__'] = cls.__call__ + + dc_type = type(cls.__name__, (base,), attributes) + dc_type = dataclass(dc_type) + return dc_type(**fields) \ No newline at end of file diff --git a/flax/experimental/nnx/tests/nn/test_attention.py b/flax/experimental/nnx/tests/nn/test_attention.py index 6ffca788ce..489c786a50 100644 --- a/flax/experimental/nnx/tests/nn/test_attention.py +++ b/flax/experimental/nnx/tests/nn/test_attention.py @@ -68,18 +68,17 @@ def __call__(self, x, sow_weights=False): module.set_attributes(decode=False) _ = module(x, True) - intermediates = module.pop(nnx.Intermediate) - # assert intermediates['attention_layers/0/attention_weights'].raw_value[ - assert intermediates['attention_layers'][0]['attention_weights'].raw_value[ + intermediates = nnx.pop(module, nnx.Intermediate) + assert intermediates['attention_layers'][0]['attention_weights'].value[ 0 ].shape == (4, 8, 6, 6) assert 1 not in intermediates['attention_layers'] - assert intermediates['attention_layers'][2]['attention_weights'].raw_value[ + assert intermediates['attention_layers'][2]['attention_weights'].value[ 0 ].shape == (4, 8, 6, 6) _ = module(x) - intermediates = module.pop(nnx.Intermediate) + intermediates = nnx.pop(module, nnx.Intermediate) assert not intermediates # empty def test_autoregressive_decode_with_x64(self): diff --git a/flax/experimental/nnx/tests/nn/test_conv.py b/flax/experimental/nnx/tests/nn/test_conv.py index db5a3a4c84..1d6016eb46 100644 --- a/flax/experimental/nnx/tests/nn/test_conv.py +++ b/flax/experimental/nnx/tests/nn/test_conv.py @@ -62,20 +62,23 @@ def test_nnx_linen_equivalence( padding = (4, 2) x = jax.numpy.ones(INPUT_SHAPE) - model_nnx = nnx.Conv.create_abstract( - IN_FEATURES, - OUT_FEATURES, - kernel_size, - strides, - padding=padding, - input_dilation=input_dilation, - 
kernel_dilation=kernel_dilation, - feature_group_count=feature_group_count, - use_bias=use_bias, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - rngs=rngs, + model_nnx = nnx.eval_shape( + lambda rngs: nnx.Conv( + IN_FEATURES, + OUT_FEATURES, + kernel_size, + strides, + padding=padding, + input_dilation=input_dilation, + kernel_dilation=kernel_dilation, + feature_group_count=feature_group_count, + use_bias=use_bias, + dtype=dtype, + param_dtype=param_dtype, + precision=precision, + rngs=rngs, + ), + rngs, ) model = linen.Conv( OUT_FEATURES, diff --git a/flax/experimental/nnx/tests/nn/test_embed.py b/flax/experimental/nnx/tests/nn/test_embed.py index 92962fc734..bf761346ec 100644 --- a/flax/experimental/nnx/tests/nn/test_embed.py +++ b/flax/experimental/nnx/tests/nn/test_embed.py @@ -44,12 +44,15 @@ def test_nnx_linen_equivalence( NUM_EMBEDDINGS = num_embeddings x = jax.numpy.arange(NUM_EMBEDDINGS, dtype=input_dtype) - model_nnx = nnx.Embed.create_abstract( - NUM_EMBEDDINGS, - IN_FEATURES, - dtype=dtype, - param_dtype=param_dtype, - rngs=rngs, + model_nnx = nnx.eval_shape( + lambda rngs: nnx.Embed( + NUM_EMBEDDINGS, + IN_FEATURES, + dtype=dtype, + param_dtype=param_dtype, + rngs=rngs, + ), + rngs, ) model = linen.Embed( NUM_EMBEDDINGS, IN_FEATURES, dtype=dtype, param_dtype=param_dtype diff --git a/flax/experimental/nnx/tests/nn/test_linear.py b/flax/experimental/nnx/tests/nn/test_linear.py index 3272b80fb8..a2eee40cdf 100644 --- a/flax/experimental/nnx/tests/nn/test_linear.py +++ b/flax/experimental/nnx/tests/nn/test_linear.py @@ -65,14 +65,17 @@ def test_nnx_linear_equivalence( OUT_FEATURES = 64 x = jax.numpy.ones((1, IN_FEATURES)) - model_nnx = nnx.Linear.create_abstract( - IN_FEATURES, - OUT_FEATURES, - use_bias=use_bias, - dtype=dtype, - param_dtype=param_dtype, - precision=precision, - rngs=rngs, + model_nnx = nnx.eval_shape( + lambda rngs: nnx.Linear( + IN_FEATURES, + OUT_FEATURES, + use_bias=use_bias, + dtype=dtype, + 
param_dtype=param_dtype, + precision=precision, + rngs=rngs, + ), + rngs, ) model = linen.Dense( OUT_FEATURES, diff --git a/flax/experimental/nnx/tests/nn/test_stochastic.py b/flax/experimental/nnx/tests/nn/test_stochastic.py index 64c3c8ccca..f302a34f10 100644 --- a/flax/experimental/nnx/tests/nn/test_stochastic.py +++ b/flax/experimental/nnx/tests/nn/test_stochastic.py @@ -70,11 +70,6 @@ def test_dropout_arg_override(self): m = nnx.Dropout(rate=0.5) x = jnp.ones((1, 10)) - # no deterministic arg provided - with pytest.raises( - ValueError, match='No `deterministic` argument was provided to Dropout' - ): - m(x) # deterministic call arg provided m(x, deterministic=True) # deterministic constructor arg provided diff --git a/flax/experimental/nnx/tests/test_graph_utils.py b/flax/experimental/nnx/tests/test_graph_utils.py index 80ad57a9fe..d976a44229 100644 --- a/flax/experimental/nnx/tests/test_graph_utils.py +++ b/flax/experimental/nnx/tests/test_graph_utils.py @@ -25,21 +25,22 @@ def test_flatten(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - static, state, ref_idx = nnx.graph_utils.graph_flatten(g) + graphdef, state, refmap = nnx.graph.flatten(g) + assert refmap is not None state[0]['b'].raw_value = 2 state[3].raw_value = 4 - assert len(ref_idx) == 2 - assert a['b'] in ref_idx - assert g[3] in ref_idx + assert len(refmap) == 2 + assert a['b'] in refmap + assert g[3] in refmap def test_unflatten(self): a = nnx.Dict(a=1, b=nnx.Param(2)) g = nnx.List([a, 3, a, nnx.Param(4)]) - static, state, _ = nnx.graph_utils.graph_flatten(g) - g = static.merge(state) + graphdef, state = nnx.split(g) + g = nnx.merge(graphdef, state) assert g[0] is g[2] @@ -47,8 +48,8 @@ def test_unflatten_pytree(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - static, state, _ = nnx.graph_utils.graph_flatten(g) - g = static.merge(state) + graphdef, state = nnx.split(g) + g = nnx.merge(graphdef, state) assert g[0] is not g[2] @@ -56,33 +57,33 @@ def 
test_unflatten_empty(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) - static, state, _ = nnx.graph_utils.graph_flatten(g) - g = static.merge(nnx.State({})) + graphdef, state = nnx.split(g) - assert g[0] is g[2] - assert g[0]['b'].raw_value is nnx.EMPTY - assert g[3].raw_value is nnx.EMPTY + with pytest.raises( + ValueError, match='Expected key for Variable but was not found in state' + ): + nnx.graph.unflatten(graphdef, nnx.State({})) def test_update_dynamic(self): a = {'a': 1, 'b': nnx.Param(2)} g = [a, 3, a, nnx.Param(4)] - static, state, _ = nnx.graph_utils.graph_flatten(g) + graphdef, state = nnx.split(g) - state[0]['b'].raw_value = 3 - nnx.graph_utils.graph_update_dynamic(g, state) + state[0]['b'].value = 3 + nnx.graph.update(g, state) - assert g[0]['b'].raw_value == 3 - assert g[2]['b'].raw_value == 3 + assert g[0]['b'].value == 3 + assert g[2]['b'].value == 3 def test_update_static(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) g = nnx.List([a, 3, a, nnx.Param(4)]) - g2 = nnx.graph_utils.clone(g) + g2 = nnx.graph.clone(g) g2[0]['a'] = 5 - nnx.graph_utils.graph_update_static(g, g2) + nnx.graph.graph_update_static(g, g2) assert g[0]['a'] == 5 assert g[2]['a'] == 5 @@ -95,7 +96,7 @@ def test_update_static_inconsistent_types(self): with pytest.raises( ValueError, match='Trying to update a node with a different type' ): - nnx.graph_utils.graph_update_static(g, g2) + nnx.graph.graph_update_static(g, g2) def test_update_static_add_new(self): a = nnx.Dict({'a': 1, 'b': nnx.Param(2)}) @@ -103,7 +104,7 @@ def test_update_static_add_new(self): g = nnx.List([a, 3, a, nnx.Param(4)]) g2 = nnx.List([a, 3, a, nnx.Param(4), b]) - nnx.graph_utils.graph_update_static(g, g2) + nnx.graph.graph_update_static(g, g2) assert g[4][0] == 5 assert g[4][1] == 6 @@ -114,7 +115,7 @@ def test_update_static_add_shared_error(self): g2 = nnx.List([a, 3, a, nnx.Param(4), a]) with pytest.raises(ValueError, match='Trying to add a new node at path'): - 
nnx.graph_utils.graph_update_static(g, g2) + nnx.graph.graph_update_static(g, g2) def test_module_list(self): rngs = nnx.Rngs(0) @@ -123,24 +124,24 @@ def test_module_list(self): nnx.BatchNorm(2, rngs=rngs), ] - static, state, _ = nnx.graph_utils.graph_flatten(ls) + graphdef, state = nnx.split(ls) - assert state[0]['kernel'].raw_value.shape == (2, 2) - assert state[0]['bias'].raw_value.shape == (2,) - assert state[1]['scale'].raw_value.shape == (2,) - assert state[1]['bias'].raw_value.shape == (2,) - assert state[1]['mean'].raw_value.shape == (2,) - assert state[1]['var'].raw_value.shape == (2,) + assert state[0]['kernel'].value.shape == (2, 2) + assert state[0]['bias'].value.shape == (2,) + assert state[1]['scale'].value.shape == (2,) + assert state[1]['bias'].value.shape == (2,) + assert state[1]['mean'].value.shape == (2,) + assert state[1]['var'].value.shape == (2,) def test_shared_variables(self): v = nnx.Param(1) g = [v, v] - static, state, _ = nnx.graph_utils.graph_flatten(g) + graphdef, state = nnx.split(g) assert len(state.flat_state()) == 1 - g2 = static.merge(state) + g2 = nnx.merge(graphdef, state) assert g2[0] is g2[1] @@ -154,11 +155,11 @@ def __init__(self, *, rngs: nnx.Rngs) -> None: self.baz.kernel = self.bar.kernel node = Foo(rngs=nnx.Rngs(0)) - static, state, _ = nnx.graph_utils.graph_flatten(node) + graphdef, state = nnx.split(node) assert len(state.flat_state()) == 3 # 2 bias + 1 kernel - node2 = static.merge(state) + node2 = nnx.merge(graphdef, state) assert node2.bar.kernel is node2.baz.kernel @@ -187,7 +188,7 @@ def __call__(self, x): return self.linear_out(x) model = Encoder(rngs=nnx.Rngs(0)) - static, state = model.split() + graphdef, state = nnx.split(model) assert len(state.flat_state()) == 1 @@ -202,19 +203,19 @@ def __init__(self): self.a = nnx.Param(1) m = Foo() - static, state = m.split() + graphdef, state = nnx.split(m) assert isinstance(m.a, nnx.Param) - assert isinstance(state.a, nnx.Param) + assert issubclass(state.a.type, 
nnx.Param) assert m.a is not state.a - assert m.a.value == state.a.raw_value + assert m.a.value == state.a.value - m2 = static.merge(state) + m2 = nnx.merge(graphdef, state) assert isinstance(m2.a, nnx.Param) - assert isinstance(state.a, nnx.Param) + assert issubclass(state.a.type, nnx.Param) assert m2.a is not state.a - assert m2.a.value == state.a.raw_value + assert m2.a.value == state.a.value def test_shared_state_variables_not_shared_with_graph(self): class Foo(nnx.Module): @@ -224,26 +225,26 @@ def __init__(self): self.b = p m = Foo() - static, state = m.split() + graphdef, state = nnx.split(m) assert isinstance(m.a, nnx.Param) assert isinstance(m.b, nnx.Param) - assert isinstance(state.a, nnx.Param) + assert issubclass(state.a.type, nnx.Param) assert 'b' not in state assert m.a is not state.a assert m.b is not state.a - assert m.a.value == state.a.raw_value - assert m.b.value == state.a.raw_value + assert m.a.value == state.a.value + assert m.b.value == state.a.value - m2 = static.merge(state) + m2 = nnx.merge(graphdef, state) assert isinstance(m2.a, nnx.Param) assert isinstance(m2.b, nnx.Param) - assert isinstance(state.a, nnx.Param) + assert issubclass(state.a.type, nnx.Param) assert m2.a is not state.a assert m2.b is not state.a - assert m2.a.value == state.a.raw_value - assert m2.b.value == state.a.raw_value + assert m2.a.value == state.a.value + assert m2.b.value == state.a.value assert m2.a is m2.b def test_pytree_flatten(self): @@ -254,14 +255,14 @@ class Tree: p = Tree(1, 'a') - leaves, treedef = nnx.graph_utils._flatten_pytree(p) + leaves, treedef = nnx.graph._flatten_pytree(p) fields = dict(leaves) assert 'a' in fields assert 'b' not in fields assert fields['a'] == 1 - p2 = nnx.graph_utils._unflatten_pytree(leaves, treedef) + p2 = nnx.graph._unflatten_pytree(leaves, treedef) assert isinstance(p2, Tree) assert p2.a == 1 @@ -278,13 +279,13 @@ def __init__(self): m = Foo() - static, state = m.split() + graphdef, state = nnx.split(m) assert 'tree' in 
state assert 'a' in state.tree - assert static.subgraphs['tree'].type is nnx.graph_utils.PytreeType + assert graphdef.nodedef.subgraphs['tree'].type is nnx.graph.PytreeType - m2 = static.merge(state) + m2 = nnx.merge(graphdef, state) assert isinstance(m2.tree, Tree) assert m2.tree.a.raw_value == 1 @@ -305,30 +306,26 @@ def f(m: Foo): a = m.a b = m.b - static: nnx.graph_utils.GraphDef[Foo] - static, state, ref_out_idx_out = nnx.graph_utils.graph_flatten(m) + graphdef: nnx.graph.GraphDef[Foo] + graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) @partial(jax.jit, static_argnums=(0,)) - def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph_utils.graph_unflatten(static, state) + def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): + m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) f(m) - static, state, ref_in_idx_in = nnx.graph_utils.graph_flatten(m) - idx_out_idx_in = nnx.graph_utils.compose_mapping( - idx_out_ref_in, ref_in_idx_in - ) - static_out = nnx.graph_utils.Static((static, idx_out_idx_in)) + graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out - static_out: nnx.graph_utils.Static - state, static_out = f_pure(static, state) + static_out: nnx.graph.Static + state, static_out = f_pure(graphdef, state) idx_out_idx_in: dict[int, int] - static, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph_utils.compose_mapping_reversed( + graphdef, idx_out_idx_in = static_out.value + idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph_utils.graph_unflatten( - static, state, ref_cache=idx_in_ref_out - ) + m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -346,30 +343,26 @@ def f(m: Foo): a = m.a b = m.b - static: 
nnx.graph_utils.GraphDef[Foo] - static, state, ref_out_idx_out = nnx.graph_utils.graph_flatten(m) + graphdef: nnx.graph.GraphDef[Foo] + graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) @partial(jax.jit, static_argnums=(0,)) - def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph_utils.graph_unflatten(static, state) + def f_pure(graphdef: nnx.graph.GraphDef[Foo], state): + m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) f(m) - static, state, ref_in_idx_in = nnx.graph_utils.graph_flatten(m) - idx_out_idx_in = nnx.graph_utils.compose_mapping( - idx_out_ref_in, ref_in_idx_in - ) - static_out = nnx.graph_utils.Static((static, idx_out_idx_in)) + graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out - static_out: nnx.graph_utils.Static - state, static_out = f_pure(static, state) + static_out: nnx.graph.Static + state, static_out = f_pure(graphdef, state) idx_out_idx_in: dict[int, int] - static, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph_utils.compose_mapping_reversed( + graphdef, idx_out_idx_in = static_out.value + idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph_utils.graph_unflatten( - static, state, ref_cache=idx_in_ref_out - ) + m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) assert m2 is m assert m2.a is b assert m2.b is a @@ -384,29 +377,25 @@ def f(m: Foo): m = Foo() - static: nnx.graph_utils.GraphDef[Foo] - static, state, ref_out_idx_out = nnx.graph_utils.graph_flatten(m) + graphdef: nnx.graph.GraphDef[Foo] + graphdef, state, ref_out_idx_out = nnx.graph.flatten(m) @partial(jax.jit, static_argnums=(0,)) - def f_pure(static: nnx.graph_utils.GraphDef[Foo], state): - m, idx_out_ref_in = nnx.graph_utils.graph_unflatten(static, state) + def f_pure(graphdef: 
nnx.graph.GraphDef[Foo], state): + m, idx_out_ref_in = nnx.graph.unflatten(graphdef, state) f(m) - static, state, ref_in_idx_in = nnx.graph_utils.graph_flatten(m) - idx_out_idx_in = nnx.graph_utils.compose_mapping( - idx_out_ref_in, ref_in_idx_in - ) - static_out = nnx.graph_utils.Static((static, idx_out_idx_in)) + graphdef, state, ref_in_idx_in = nnx.graph.flatten(m) + idx_out_idx_in = nnx.graph.compose_mapping(idx_out_ref_in, ref_in_idx_in) + static_out = nnx.graph.Static((graphdef, idx_out_idx_in)) return state, static_out - static_out: nnx.graph_utils.Static - state, static_out = f_pure(static, state) + static_out: nnx.graph.Static + state, static_out = f_pure(graphdef, state) idx_out_idx_in: dict[int, int] - static, idx_out_idx_in = static_out.value - idx_in_ref_out = nnx.graph_utils.compose_mapping_reversed( + graphdef, idx_out_idx_in = static_out.value + idx_in_ref_out = nnx.graph.compose_mapping_reversed( ref_out_idx_out, idx_out_idx_in ) - m2, _ = nnx.graph_utils.graph_unflatten( - static, state, ref_cache=idx_in_ref_out - ) + m2, _ = nnx.graph.unflatten(graphdef, state, idxmap=idx_in_ref_out) assert m2 is m assert m2.ref is m2 diff --git a/flax/experimental/nnx/tests/test_helpers.py b/flax/experimental/nnx/tests/test_helpers.py index 299791e17d..785d89ee0a 100644 --- a/flax/experimental/nnx/tests/test_helpers.py +++ b/flax/experimental/nnx/tests/test_helpers.py @@ -26,7 +26,7 @@ class TestHelpers: def test_train_state(self): m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) - graphdef, params, batch_stats = m.split(nnx.Param, nnx.BatchStat) + graphdef, params, batch_stats = nnx.split(m, nnx.Param, nnx.BatchStat) state = TrainState.create( graphdef, @@ -49,7 +49,7 @@ def __call__(self, x: jax.Array, train: bool) -> jax.Array: return x module = Foo(rngs=nnx.Rngs(0)) - graphdef, params, batch_stats = module.split(nnx.Param, nnx.BatchStat) + graphdef, params, batch_stats = nnx.split(module, nnx.Param, nnx.BatchStat) state = TrainState.create( graphdef, diff 
--git a/flax/experimental/nnx/tests/test_integration.py b/flax/experimental/nnx/tests/test_integration.py index 00cf13f499..49b58af2a7 100644 --- a/flax/experimental/nnx/tests/test_integration.py +++ b/flax/experimental/nnx/tests/test_integration.py @@ -54,10 +54,11 @@ def loss_fn(model: Model): return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) - model.update( + nnx.update( + model, jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + ), ) model = Model(rngs=nnx.Rngs(0)) @@ -97,7 +98,7 @@ def __call__(self, x): @jax.jit def train_step(state: nnx.State, graphdef: nnx.GraphDef[Model], x, y): - model = graphdef.merge(state) + model = nnx.merge(graphdef, state) model.set_attributes(use_running_average=False) @nnx.grad @@ -106,16 +107,17 @@ def loss_fn(model: Model): return jnp.mean((y - y_pred) ** 2) grads = loss_fn(model) - model.update( + nnx.update( + model, jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + ), ) - return model.split() + return nnx.split(model) graphdef: nnx.GraphDef[Model] - graphdef, state = Model(rngs=nnx.Rngs(0)).split() + graphdef, state = nnx.split(Model(rngs=nnx.Rngs(0))) x = np.random.uniform(size=(4, 2)) y = np.random.uniform(size=(4, 2)) @@ -123,7 +125,7 @@ def loss_fn(model: Model): for _i in range(3): graphdef, state = train_step(state, graphdef, x, y) - model = graphdef.merge(state) + model = nnx.merge(graphdef, state) assert model.block1.linear.bias is not None assert model.block2.linear.bias is not None @@ -161,10 +163,11 @@ def loss_fn(model): # compute gradient grads: nnx.State = nnx.grad(loss_fn, wrt=nnx.Param)(model) # SGD update - model.update( + nnx.update( + model, jax.tree_util.tree_map( - lambda w, g: w - 0.1 * g, model.extract(nnx.Param), grads - ) + lambda w, g: w - 0.1 * g, nnx.state(model, nnx.Param), grads + ), ) # 
execute the training step @@ -192,14 +195,14 @@ def __call__(self, x): y = model(x) assert model.count.value == 1 - graphdef, params, counts = model.split(nnx.Param, Count) + graphdef, params, counts = nnx.split(model, nnx.Param, Count) @jax.jit def train_step(params, counts, x, y): def loss_fn(params): y_pred, (_, updates) = graphdef.apply(params, counts)(x) loss = jax.numpy.mean((y_pred - y) ** 2) - return loss, updates.extract(Count) + return loss, updates.filter(Count) # compute gradient grads, counts = jax.grad(loss_fn, has_aux=True)(params) @@ -210,7 +213,7 @@ def loss_fn(params): # execute the training step params, counts = train_step(params, counts, x, y) - model = graphdef.merge(params, counts) + model = nnx.merge(graphdef, params, counts) assert model.count.value == 2 def test_intermediates_example(self): @@ -229,7 +232,7 @@ def __call__(self, x): y = model(jnp.ones((8, 12))) - intermediates = model.pop(nnx.Intermediate) + intermediates = nnx.pop(model, nnx.Intermediate) assert 'y' in intermediates @@ -247,7 +250,7 @@ def __call__(self, x): model = Linear(12, 2, rngs=nnx.Rngs(0)) - graphdef, state = model.split() + graphdef, state = nnx.split(model) y, (_, state) = graphdef.apply(state)(jnp.ones((8, 12))) diff --git a/flax/experimental/nnx/tests/test_metrics.py b/flax/experimental/nnx/tests/test_metrics.py index b2a83847b2..2e0188ee7e 100644 --- a/flax/experimental/nnx/tests/test_metrics.py +++ b/flax/experimental/nnx/tests/test_metrics.py @@ -29,8 +29,8 @@ def test_split_merge(self): accuracy = nnx.metrics.Accuracy() accuracy.update(logits=logits, labels=labels) - static, state= accuracy.split() - accuracy = static.merge(state) + graphdef, state = accuracy.split() + accuracy = nnx.merge(graphdef, state) self.assertEqual(accuracy.compute(), 0.6) accuracy.update(logits=logits2, labels=labels2) self.assertEqual(accuracy.compute(), 0.7) diff --git a/flax/experimental/nnx/tests/test_module.py b/flax/experimental/nnx/tests/test_module.py index 
ef44281217..46f568b103 100644 --- a/flax/experimental/nnx/tests/test_module.py +++ b/flax/experimental/nnx/tests/test_module.py @@ -51,14 +51,14 @@ def f(): def test_tree_map(self): m = nnx.Dict(a=nnx.Param(1)) - static, state = m.split() + graphdef, state = nnx.split(m) state = jax.tree_util.tree_map(lambda x: x + 1, state) def test_split_2(self): m = nnx.Dict(a=nnx.Param(1)) - empty, some, static = m.split(None, ...) + graphdef, empty, some = nnx.split(m, None, ...) some = jax.tree_util.tree_map(lambda x: x + 1, some) @@ -67,12 +67,12 @@ def test_split_merge(self): @jax.jit def g(graphdef: nnx.GraphDef[nnx.Dict[int]], state: nnx.State): - m = graphdef.merge(state) + m = nnx.merge(graphdef, state) m.a = 2 - return m.split() + return nnx.split(m) - graphdef, state = g(*m.split()) - m2 = graphdef.merge(state) + graphdef, state = g(*nnx.split(m)) + m2 = nnx.merge(graphdef, state) assert m2.a == 2 @@ -109,7 +109,7 @@ def test_shared_module(self): m1 = nnx.Dict(a=nnx.Param(1), b=nnx.Param(2)) m2 = nnx.Dict(x=m1, y=m1, z=nnx.Param(3)) - m3 = nnx.merge(*m2.split()) + m3 = nnx.merge(*nnx.split(m2)) assert m3['x'] is m3['y'] assert m3['x']['a'] is m3['y']['a'] @@ -123,10 +123,10 @@ def __init__(self): m = Foo() - graphdef, state = m.split() + graphdef, state = nnx.split(m) assert len(state) == 1 - m2 = graphdef.merge(state) + m2 = nnx.merge(graphdef, state) assert m2 is m2.sub def test_deref_through_jit(self): @@ -137,15 +137,15 @@ def test_deref_through_jit(self): @jax.jit def f(graphdef: nnx.GraphDef[nnx.Dict[Any]], state: nnx.State): - m = graphdef.merge(state) + m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] - return m.split() + return nnx.split(m) - graphdef, state = f(*m.split()) - m = graphdef.merge(state) + graphdef, state = f(*nnx.split(m)) + m = nnx.merge(graphdef, state) assert m['a'][0] is m['b'] assert m['a'][1] is not m['b'] @@ -160,12 +160,12 @@ def test_cross_barrier(self): @jax.jit def g(graphdef: 
nnx.GraphDef[nnx.Dict[nnx.Param[int]]], state: nnx.State): - m = graphdef.merge(state) + m = nnx.merge(graphdef, state) m.a.value += 1 - return m.split() + return nnx.split(m) - graphdef, state = g(*m.split()) - m2 = graphdef.merge(state) + graphdef, state = g(*nnx.split(m)) + m2 = nnx.merge(graphdef, state) assert m2 is not m assert m.a.value == 1 assert m2.a.value == 2 @@ -180,23 +180,23 @@ def g(state_and_def): n += 1 m = nnx.merge(*state_and_def) m.a.value += 1 - return m.split() + return nnx.split(m) - m2 = nnx.merge(*g(m.split())) + m2 = nnx.merge(*g(nnx.split(m))) assert n == 1 assert m2 is not m assert m.a.value == 1 assert m2.a.value == 2 - g(m.split()) + g(nnx.split(m)) assert n == 1 - g(m2.split()) + g(nnx.split(m2)) assert n == 1 m2.b = nnx.Param(10) - g(m2.split()) + g(nnx.split(m2)) assert n == 2 @@ -211,7 +211,7 @@ def test_deref_number_of_fields(self): } ) - graphdef, p = m.split() + graphdef, p = nnx.split(m) assert len(p.flat_state()) == 2 assert len(jax.tree_util.tree_leaves(p)) == 2 @@ -221,7 +221,7 @@ def test_clone(self): b=nnx.Dict(c=nnx.Param(1), d=nnx.Param(2)), ) - m2 = m.clone() + m2 = nnx.clone(m) assert m is not m2 assert m2.a[0] == m2.b.c @@ -247,10 +247,10 @@ def __call__(self, x): assert y2 == 11 assert m.y.value == (3, 11) - intermediates = m.pop(nnx.Intermediate) + intermediates = nnx.pop(m, nnx.Intermediate) - assert isinstance(intermediates.y, nnx.Intermediate) - assert intermediates['y'].raw_value == (3, 11) + assert issubclass(intermediates.y.type, nnx.Intermediate) + assert intermediates['y'].value == (3, 11) assert not hasattr(m, 'y') @@ -284,32 +284,6 @@ def __call__(self, x): with pytest.raises(ValueError, match='to be of type'): m(2) - def test_update_static_state(self): - class Foo(nnx.Module): - def add_field(self): - self.a = 1 - - m1 = Foo() - m2 = Foo() - m2.add_field() - - m1.update(m2) - - assert m1.a == 1 - - def test_update_moduledef(self): - class Foo(nnx.Module): - def add_field(self): - self.a = 1 - - m1 = 
Foo() - m2 = Foo() - m2.add_field() - - m1.update(m2.get_graphdef()) - - assert m1.a == 1 - def test_update_static_state_submodules(self): class Bar(nnx.Module): def __init__(self) -> None: @@ -324,10 +298,13 @@ def __init__(self) -> None: self.b = self.a m1 = Foo() - m2 = Foo() - m2.a.add_field() + with nnx.UpdateContext() as ctx: + graphdef, state = ctx.split(m1) + m2 = ctx.merge(graphdef, state) + m2.a.add_field() + new_graphdef, state = ctx.split(m2) - m1.update(m2) + ctx.update(new_graphdef, state) assert m1.a.x == 1 assert m1.a.y == 2 @@ -347,10 +324,13 @@ def add_module(self): self.b = Bar() m1 = Foo() - m2 = Foo() + ctx = nnx.UpdateContext() + graphdef, state = ctx.split(m1) + m2 = ctx.merge(graphdef, state) m2.add_module() + new_graphdef, state = ctx.split(m2) - m1.update(m2) + ctx.update(new_graphdef, state) assert m1.a.x == 1 assert m1.b.x == 1 @@ -366,15 +346,17 @@ def __init__(self) -> None: self.b = self.a m1 = Foo() - m2 = Foo() + ctx = nnx.UpdateContext() + graphdef, state = ctx.split(m1) + m2 = ctx.merge(graphdef, state) m2.a.x = 2 - - m1.update(m2) + new_graphdef, state = ctx.split(m2) + ctx.update(new_graphdef, state) assert m1.a.x == 2 assert m1.b.x == 2 - def test_update_add_shared_error(self): + def test_update_add_shared(self): class Bar(nnx.Module): def __init__(self) -> None: self.x = 1 @@ -388,53 +370,36 @@ def add_submodule(self): self.c = self.a m1 = Foo() - m2 = Foo() - m2.add_submodule() - - assert hasattr(m2, 'c') - - with pytest.raises(ValueError, match='Trying to add a new node at path'): - m1.update(m2) - - def test_update_add_shared_error_new_first(self): - class Bar(nnx.Module): - def __init__(self) -> None: - self.x = 1 - - class Foo(nnx.Module): - def __init__(self) -> None: - self.b = Bar() - self.c = self.b - - def add_submodule(self): - self.a = self.b - - m1 = Foo() - m2 = Foo() + ctx = nnx.UpdateContext() + graphdef, state = ctx.split(m1) + m2 = ctx.merge(graphdef, state) m2.add_submodule() + new_graphdef, state = 
ctx.split(m2) + ctx.update(new_graphdef, state) - assert hasattr(m2, 'a') - - m2 = m2.clone() # clone to sort the fields - - with pytest.raises(ValueError, match='Trying to update a node at path'): - m1.update(m2) + assert hasattr(m1, 'c') def test_create_abstract(self): - linear = nnx.Linear.create_abstract(2, 3, rngs=nnx.Rngs(0)) + linear = nnx.eval_shape(lambda: nnx.Linear(2, 3, rngs=nnx.Rngs(0))) assert linear.kernel.value == jax.ShapeDtypeStruct((2, 3), jnp.float32) assert linear.bias.value == jax.ShapeDtypeStruct((3,), jnp.float32) def test_partial_init(self): linear = nnx.Linear(2, 3, rngs=nnx.Rngs(0)) - state = linear.get_state() + state = nnx.state(linear) del state['bias'] - linear2 = nnx.Linear.partial_init(state)( - 2, 3, bias_init=nnx.initializers.ones_init(), rngs=nnx.Rngs(1) - ) + @nnx.jit + def partial_init(state: nnx.State): + m = nnx.Linear( + 2, 3, bias_init=nnx.initializers.ones_init(), rngs=nnx.Rngs(1) + ) + nnx.update(m, state) + return m + + linear2 = partial_init(state) np.testing.assert_allclose(linear.kernel.value, linear2.kernel.value) np.testing.assert_allclose(linear.bias.value, 0) @@ -506,20 +471,44 @@ def __init__(self, din, dout, *, rngs: nnx.Rngs): raise_if_not_found=False, ) + def test_init(self): + class Linear(nnx.Module): + def __init__(self, dout, rngs: nnx.Rngs): + self.dout = dout + self.rngs = rngs + + def __call__(self, x): + if self.is_initializing(): + din = x.shape[-1] + if not hasattr(self, 'w'): + key = self.rngs.params() + self.w = nnx.Param(jax.random.uniform(key, (din, self.dout))) + if not hasattr(self, 'b'): + self.b = nnx.Param(jnp.zeros((self.dout,))) + return x @ self.w + self.b[None] + + linear = Linear(3, nnx.Rngs(0)) + x = jnp.ones((5, 2)) + y = linear.init(x) + assert linear.w.value.shape == (2, 3) + assert linear.b.value.shape == (3,) + assert y.shape == (5, 3) + assert not linear.is_initializing() + class TestModulePytree: def test_tree_map(self): class Foo(nnx.Module, experimental_pytree=True): def 
__init__(self): self.node = nnx.Param(1) - self.static = 1 + self.graphdef = 1 m = Foo() m = jax.tree_util.tree_map(lambda x: x + 1, m) assert m.node.value == 2 - assert m.static == 1 + assert m.graphdef == 1 class TestModuleDataclass: @@ -534,21 +523,25 @@ class Foo(nnx.Module): f: int m = Foo( - a=1, # static + a=1, # graphdef b=nnx.Variable(2), # node c=nnx.Param(3), # param d=nnx.Variable(4), # var e=nnx.BatchStat(5), # var - f=6, # static int + f=6, # graphdef int ) - graphdef, state = m.split() + graphdef, state = nnx.split(m) assert len(state) == 4 - assert state.b == nnx.Variable(2) - assert state.c == nnx.Param(3) - assert state.d == nnx.Variable(4) - assert state.e == nnx.BatchStat(5) + assert state.b.value == 2 + assert state.b.type == nnx.Variable + assert state.c.value == 3 + assert state.c.type == nnx.Param + assert state.d.value == 4 + assert state.d.type == nnx.Variable + assert state.e.value == 5 + assert state.e.type == nnx.BatchStat def test_post_init(self): @dataclasses.dataclass @@ -599,11 +592,10 @@ def __call__(self, x, *, rngs: nnx.Rngs): rngs = nnx.Rngs(0) foo = Foo(c=1.0, rngs=rngs) - graphdef, states = foo.split() + graphdef, states = nnx.split(foo) assert isinstance(states, nnx.State) - assert isinstance(states.w, nnx.Param) - # assert isinstance(states["c"], jax.Array) + assert issubclass(states.w.type, nnx.Param) y, _updates = graphdef.apply(states)(x=2.0, rngs=nnx.Rngs(e=1)) @@ -623,12 +615,12 @@ def __call__(self, x, *, rngs: nnx.Rngs): foo = Foo(c=1.0, rngs=nnx.Rngs(0)) - graphdef, state = foo.split() + graphdef, state = nnx.split(foo) assert isinstance(graphdef, nnx.GraphDef) assert isinstance(state, nnx.State) - assert isinstance(state.w, nnx.Param) - assert isinstance(state.c, nnx.Variable) + assert issubclass(state.w.type, nnx.Param) + assert issubclass(state.c.type, nnx.Variable) y, (graphdef, state) = graphdef.apply(state)(x=2.0, rngs=nnx.Rngs(e=1)) @@ -644,7 +636,7 @@ def __init__(self, *, rngs: nnx.Rngs): module = 
Foo(rngs=nnx.Rngs(0)) - modules = list(module.modules()) + modules = list(module.iter_modules()) assert len(modules) == 3 assert modules[0][0] == () @@ -661,12 +653,12 @@ def __init__(self): foo = Foo() - graphdef, state = foo.split() + graphdef, state = nnx.split(foo) assert isinstance(state, nnx.State) assert isinstance(state.a, jax.Array) - foo2 = graphdef.merge(state) + foo2 = nnx.merge(graphdef, state) assert isinstance(foo2.a, jax.Array) @@ -677,11 +669,11 @@ def __init__(self): foo = Foo() - graphdef, state = foo.split() + graphdef, state = nnx.split(foo) assert isinstance(state, nnx.State) assert isinstance(state.a, nnx.State) - foo2 = graphdef.merge(state) + foo2 = nnx.merge(graphdef, state) assert isinstance(foo2.a, nnx.State) diff --git a/flax/experimental/nnx/tests/test_optimizer.py b/flax/experimental/nnx/tests/test_optimizer.py index d7b29f1372..6890056905 100644 --- a/flax/experimental/nnx/tests/test_optimizer.py +++ b/flax/experimental/nnx/tests/test_optimizer.py @@ -41,8 +41,8 @@ def test_split_merge(self, module_cls): tx = optax.adam(1e-3) state = nnx.Optimizer(model, tx) out = state.model(x) - static, state = state.split() - state = static.merge(state) + graphdef, state = state.split() + state = nnx.merge(graphdef, state) np.testing.assert_allclose(out, state.model(x)) @parameterized.product( @@ -58,28 +58,30 @@ def test_jit(self, module_cls, jit_decorator, optimizer): state = nnx.Optimizer(model, tx) if jit_decorator == jax.jit: - model_static, model_state = state.model.split() - loss_fn = lambda static, state, x, y: ((static.merge(state)(x)-y)**2).mean() + model_static, model_state = nnx.split(state.model) + loss_fn = lambda graphdef, state, x, y: ( + (nnx.merge(graphdef, state)(x) - y) ** 2 + ).mean() initial_loss = loss_fn(model_static, model_state, x, y) - def train_step(static, state, x, y): - state = static.merge(state) - model_static, model_state = state.model.split() + def train_step(graphdef, state, x, y): + state = nnx.merge(graphdef, 
state) + model_static, model_state = nnx.split(state.model) grads = jax.grad(loss_fn, argnums=1)(model_static, model_state, x, y) state.update(grads) return state.split() - static, state = jit_decorator(train_step)(*state.split(), x, y) - state = static.merge(state) - new_loss = loss_fn(*state.model.split(), x, y) + graphdef, state = jit_decorator(train_step)(*state.split(), x, y) + state = nnx.merge(graphdef, state) + new_loss = loss_fn(*nnx.split(state.model), x, y) else: loss_fn = lambda model, x, y: ((model(x)-y)**2).mean() initial_loss = loss_fn(state.model, x, y) - def train_step(state, x, y): - grads = nnx.grad(loss_fn, wrt=nnx.Param)(state.model, x, y) - state.update(grads) + def train_step(optimizer: nnx.Optimizer, x, y): + grads = nnx.grad(loss_fn, wrt=nnx.Param)(optimizer.model, x, y) + optimizer.update(grads) jit_decorator(train_step)(state, x, y) new_loss = loss_fn(state.model, x, y) diff --git a/flax/experimental/nnx/tests/test_partitioning.py b/flax/experimental/nnx/tests/test_partitioning.py index ae8fff0b96..b2d5fdfdc5 100644 --- a/flax/experimental/nnx/tests/test_partitioning.py +++ b/flax/experimental/nnx/tests/test_partitioning.py @@ -27,19 +27,19 @@ def test_partition(self): c=100, ) - graphdef, params, rest = m.split(nnx.Param, ...) + graphdef, params, rest = nnx.split(m, nnx.Param, ...) 
assert len(params) == 2 assert len(rest) == 1 # check params - assert params['a'][0].raw_value == m.a[0].value - assert params['b'].raw_value == m.b.value + assert params['a'][0].value == m.a[0].value + assert params['b'].value == m.b.value # check rest - assert rest['a'][1].raw_value == m.a[1].value + assert rest['a'][1].value == m.a[1].value - m2 = graphdef.merge(params, rest) + m2 = nnx.merge(graphdef, params, rest) assert m2.a[0].value == m.a[0].value assert m2.a[1].value == m.a[1].value @@ -53,7 +53,7 @@ def test_complete_partitioning(self): ) # no error - m.split(nnx.Param, nnx.BatchStat, nnx.Variable) + nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable) def test_complete_partitioning_plus_ellipsis(self): m = nnx.Dict( @@ -62,7 +62,7 @@ def test_complete_partitioning_plus_ellipsis(self): ) # no error if additional ... is passed at the end - m.split(nnx.Param, nnx.BatchStat, nnx.Variable, ...) + nnx.split(m, nnx.Param, nnx.BatchStat, nnx.Variable, ...) def test_inclomplete_partition_error(self): m = nnx.Dict( @@ -73,7 +73,7 @@ def test_inclomplete_partition_error(self): with pytest.raises( ValueError, match='Non-exhaustive filters, got a non-empty remainder' ): - m.split(nnx.Param) + nnx.split(m, nnx.Param) def test_ellipsis_not_last_error(self): m = nnx.Dict( @@ -82,9 +82,9 @@ def test_ellipsis_not_last_error(self): ) with pytest.raises( - ValueError, match='Ellipsis `...` can only be used as the last filter,' + ValueError, match='`...` or `True` can only be used as the last filters' ): - m.split(..., nnx.Param) + nnx.split(m, ..., nnx.Param) def test_update_from(self): m = nnx.Dict( @@ -93,10 +93,12 @@ def test_update_from(self): c=100, ) - state = m.split()[1] + state = nnx.split( + m, + )[1] state = jax.tree_util.tree_map(lambda x: x * 2, state) - m.update(state) + nnx.update(m, state) assert m.a[0].value == 2 assert m.a[1].value == 6 @@ -110,10 +112,12 @@ def test_update_from_with_array_leaf(self): c=nnx.Variable(jax.numpy.array(100)), ) - graphdef, 
state = m.split() + graphdef, state = nnx.split( + m, + ) state = jax.tree_util.tree_map(lambda x: x * 2, state) - m.update(state) + nnx.update(m, state) assert m.a[0].value == 2 assert m.a[1].value == 6 @@ -127,13 +131,13 @@ def test_grad_example(self): c=100, ) - params = m.extract(nnx.Param) + params = nnx.state(m, nnx.Param) def loss(params): return sum(2 * p for p in jax.tree_util.tree_leaves(params)) grads = jax.grad(loss)(params) - m.update(grads) + nnx.update(m, grads) assert m.a[0].value == 2.0 assert m.a[1].value == -10 @@ -151,9 +155,9 @@ def test_get_paritition(self): # test Variables not shared assert vars(m.a)['0'] is not vars(m)['b'] - state = m.extract(nnx.Variable) - assert state['a'][0].raw_value == m.a[0].value - assert state['a'][1].raw_value == m.a[1].value - assert state['b'].raw_value == m.b.value + state = nnx.state(m, nnx.Variable) + assert state['a'][0].value == m.a[0].value + assert state['a'][1].value == m.a[1].value + assert state['b'].value == m.b.value assert state.b is not state.a[0] assert len(state.flat_state()) == 3 diff --git a/flax/experimental/nnx/tests/test_rngs.py b/flax/experimental/nnx/tests/test_rngs.py index 7aac8e7aa1..f3e7e6585a 100644 --- a/flax/experimental/nnx/tests/test_rngs.py +++ b/flax/experimental/nnx/tests/test_rngs.py @@ -51,17 +51,6 @@ def test_rng_stream(self): assert rngs.params.key.value is key0 assert not jnp.allclose(key1, key2) - def test_rng_fork(self): - key0 = jax.random.key(0) - rngs1 = nnx.Rngs(params=key0) - rngs2 = nnx.Rngs(rngs1.fork()) - - assert rngs2['params'].count.value == 0 - - key1 = rngs1.params() - key2 = rngs2.params() - - assert not jnp.allclose(key1, key2) def test_rng_trace_level_constraints(self): rngs = nnx.Rngs(0) @@ -76,16 +65,6 @@ def f(): f() - @jax.jit - def g(): - with pytest.raises( - nnx.errors.TraceContextError, - match='Cannot call RngStream from a different trace level', - ): - rngs.fork() - - g() - rngs1: Any = None @jax.jit @@ -102,87 +81,6 @@ def h(): ): 
rngs1.params() - def test_partition_merge(self): - rngs = nnx.Rngs(dropout=0) - - keys = rngs.fork() - - assert 'dropout' in keys - - rngs2 = nnx.Rngs(keys) - - key1 = rngs.dropout() - key2 = rngs2.dropout() - assert not jnp.allclose(key1, key2) - - rngs3 = nnx.Rngs(keys) - key3 = rngs3.dropout() - assert jnp.allclose(key2, key3) - - def test_fork_broadcast(self): - rngs = nnx.Rngs(params=0, dropout=1) - jax.random.key - - keys = rngs.fork() # all broadcast - - assert keys['params'].shape == () - assert keys['dropout'].shape == () - assert jnp.allclose( - keys['params'], jax.random.fold_in(jax.random.key(0), 0) - ) - assert jnp.allclose( - keys['dropout'], jax.random.fold_in(jax.random.key(1), 0) - ) - - def test_fork_split(self): - rngs = nnx.Rngs(params=0, dropout=1) - keys = rngs.fork(4) # split all - - assert keys['params'].shape == (4,) - assert keys['dropout'].shape == (4,) - - def test_fork_split_and_broadcast(self): - rngs = nnx.Rngs(params=0, dropout=1) - forked = rngs.fork(params=4, dropout=None) - - assert forked['params'].shape == (4,) - assert forked['dropout'].shape == () - - def test_fork_filters(self): - rngs = nnx.Rngs(params=0, dropout=1) - forked = rngs.fork({'params': 4}) - - assert forked['params'].shape == (4,) - assert forked['dropout'].shape == () - - def test_fork_multidimensional_split(self): - rngs = nnx.Rngs(params=0, dropout=1) - keys = rngs.fork((4, None, 3)) # split all - - assert keys['params'].shape == (4, 1, 3) - assert keys['dropout'].shape == (4, 1, 3) - - def test_fork_multidimensional_split_mixed(self): - rngs = nnx.Rngs(params=0, dropout=1) - keys = rngs.fork(params=(4, None, 3)) # split all - - assert keys['params'].shape == (4, 1, 3) - assert keys['dropout'].shape == () - - def test_rng_stream_pytree(self): - rngs = nnx.Rngs(params=0, dropout=1) - - keys = rngs.fork(dropout=4) - keys2 = jax.tree_util.tree_map(lambda x: x, keys) - - assert 'dropout' in keys.splits - assert 'params' in keys.broadcasts - - assert keys2 is not 
keys - assert set(keys.keys()) == set(keys2.keys()) - assert set(keys.splits.keys()) == set(keys2.splits.keys()) - assert set(keys.broadcasts.keys()) == set(keys2.broadcasts.keys()) - def test_jit_updates(self): class Foo(nnx.Module): def __init__(self, not_rngs): @@ -228,8 +126,8 @@ def __call__(self, x): rngs = nnx.Rngs(params=0, dropout=1) m = Foo(rngs) - _, params, dropout_keys, param_keys, rng_counts = m.split( - nnx.Param, 'dropout', 'params', nnx.RngCount + _, params, dropout_keys, param_keys, rng_counts = nnx.split( + m, nnx.Param, 'dropout', 'params', nnx.RngCount ) assert m.rngs.params.count.value == 2 @@ -253,10 +151,10 @@ def __call__(self, x): out_axes=(0, 0, None), ) def f(params, dropout_keys, param_keys, rng_counts, x): - m.update(params, dropout_keys, param_keys, rng_counts) + nnx.update(m, params, dropout_keys, param_keys, rng_counts) y = m(x) - _, params, dropout_keys, param_keys, rng_counts = m.split( - nnx.Param, 'dropout', 'params', nnx.RngCount + _, params, dropout_keys, param_keys, rng_counts = nnx.split( + m, nnx.Param, 'dropout', 'params', nnx.RngCount ) return y, params, rng_counts @@ -269,8 +167,57 @@ def f(params, dropout_keys, param_keys, rng_counts, x): x, ) - m.update(params, dropout_keys, param_keys, rng_counts) + nnx.update(m, params, dropout_keys, param_keys, rng_counts) assert y.shape == (4, 1, 3) assert m.rngs.params.count.value == 2 assert m.rngs['dropout'].count.value == 1 + + def test_state_fork_split(self): + rngs = nnx.Rngs(params=0, dropout=1) + graphdef, state = nnx.split(rngs, nnx.RngState) + split, broadcast = nnx.fork(state, ..., 4) + + assert len(jax.tree.leaves(split)) == 2 + assert len(jax.tree.leaves(broadcast)) == 2 + assert split.params.key.value.shape == (4,) + assert split.dropout.key.value.shape == (4,) + assert broadcast.params.count.value == 0 + assert broadcast.dropout.count.value == 0 + + def test_state_fork_split_and_broadcast(self): + rngs = nnx.Rngs(params=0, dropout=1) + graphdef, state = 
nnx.split(rngs, nnx.RngState) + split, broadcast = nnx.fork(state, 'params', 4) + + assert len(jax.tree.leaves(split)) == 1 + assert len(jax.tree.leaves(broadcast)) == 3 + assert split.params.key.value.shape == (4,) + assert broadcast.dropout.key.value.shape == () + assert broadcast.params.count.value == 0 + assert broadcast.dropout.count.value == 0 + + + def test_state_fork_multidimensional_split(self): + rngs = nnx.Rngs(params=0, dropout=1) + graphdef, state = nnx.split(rngs, nnx.RngState) + split, broadcast = nnx.fork(state, ..., (4, None, 3)) + + assert len(jax.tree.leaves(split)) == 2 + assert len(jax.tree.leaves(broadcast)) == 2 + assert split.params.key.value.shape == (4, 1, 3) + assert split.dropout.key.value.shape == (4, 1, 3) + assert broadcast.params.count.value == 0 + assert broadcast.dropout.count.value == 0 + + def test_state_fork_multidimensional_split_mixed(self): + rngs = nnx.Rngs(params=0, dropout=1) + graphdef, state = nnx.split(rngs, nnx.RngState) + split, broadcast = nnx.fork(state, 'params', (4, None, 3)) + + assert len(jax.tree.leaves(split)) == 1 + assert len(jax.tree.leaves(broadcast)) == 3 + assert split.params.key.value.shape == (4, 1, 3) + assert broadcast.dropout.key.value.shape == () + assert broadcast.params.count.value == 0 + assert broadcast.dropout.count.value == 0 diff --git a/flax/experimental/nnx/tests/test_spmd.py b/flax/experimental/nnx/tests/test_spmd.py index b332fa48cd..8abd5827f9 100644 --- a/flax/experimental/nnx/tests/test_spmd.py +++ b/flax/experimental/nnx/tests/test_spmd.py @@ -39,7 +39,7 @@ def __call__(self, x): @jax.jit def create_module(): - return Foo().split() + return nnx.split(Foo()) mesh = Mesh(mesh_utils.create_device_mesh((2, 2)), ('model', 'data')) @@ -49,6 +49,31 @@ def create_module(): assert m.w.shape == (8, 2) assert m.w.sharding.shard_shape(m.w.shape) == (4, 1) + def test_init_all_devices(self): + class Foo(nnx.Module): + def __init__(self): + self.w = nnx.Param( + nnx.with_partitioning( + lambda: 
jnp.ones((8, 2)), + sharding=('model', 'data'), + )() + ) + + def __call__(self, x): + return x @ self.w + + @jax.jit + def create_module(): + return nnx.split(Foo()) + + mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ('model', 'data')) + + with mesh: + m: Foo = nnx.merge(*create_module()) + + assert m.w.value.shape == (8, 2) + assert m.w.value.sharding.shard_shape(m.w.value.shape) == (8, 2) + def test_get_partition_spec(self): class Foo(nnx.Module): def __init__(self): @@ -62,7 +87,7 @@ def __init__(self): def __call__(self, x): return x @ self.w - graphdef, params = Foo().split() + graphdef, params = nnx.split(Foo()) state = nnx.TrainState.create( graphdef, params=params, @@ -70,10 +95,6 @@ def __call__(self, x): ) state_spec = nnx.get_partition_spec(state) - assert state_spec.params['w'].raw_value == PartitionSpec('row', 'col') - assert state_spec.opt_state[0].mu['w'].raw_value == PartitionSpec( - 'row', 'col' - ) - assert state_spec.opt_state[0].nu['w'].raw_value == PartitionSpec( - 'row', 'col' - ) + assert state_spec.params['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].mu['w'].value == PartitionSpec('row', 'col') + assert state_spec.opt_state[0].nu['w'].value == PartitionSpec('row', 'col') diff --git a/flax/experimental/nnx/tests/test_state.py b/flax/experimental/nnx/tests/test_state.py index 3a9d73475c..1397284168 100644 --- a/flax/experimental/nnx/tests/test_state.py +++ b/flax/experimental/nnx/tests/test_state.py @@ -19,36 +19,36 @@ class StateTest(TestCase): def test_create_state(self): - state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) + state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) - assert state['a'].raw_value == 1 - assert state['b']['c'].raw_value == 2 + assert state['a'].value == 1 + assert state['b']['c'].value == 2 def test_get_attr(self): - state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) + state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': 
nnx.Param.state(2)}}) - assert state.a.raw_value == 1 - assert state.b.c.raw_value == 2 + assert state.a.value == 1 + assert state.b.c.value == 2 def test_set_attr(self): - state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) + state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) - state.a.raw_value = 3 - state.b.c.raw_value = 4 + state.a.value = 3 + state.b.c.value = 4 - assert state['a'].raw_value == 3 - assert state['b']['c'].raw_value == 4 + assert state['a'].value == 3 + assert state['b']['c'].value == 4 def test_set_attr_variables(self): - state = nnx.State({'a': nnx.Param(1), 'b': {'c': nnx.Param(2)}}) + state = nnx.State({'a': nnx.Param.state(1), 'b': {'c': nnx.Param.state(2)}}) - state.a.raw_value = 3 - state.b.c.raw_value = 4 + state.a.value = 3 + state.b.c.value = 4 - assert isinstance(state.a, nnx.Param) - assert state.a.raw_value == 3 - assert isinstance(state.b.c, nnx.Param) - assert state.b.c.raw_value == 4 + assert issubclass(state.a.type, nnx.Param) + assert state.a.value == 3 + assert issubclass(state.b.c.type, nnx.Param) + assert state.b.c.value == 4 def test_integer_access(self): class Foo(nnx.Module): @@ -56,9 +56,9 @@ def __init__(self, *, rngs: nnx.Rngs): self.layers = [nnx.Linear(1, 2, rngs=rngs), nnx.Linear(2, 3, rngs=rngs)] module = Foo(rngs=nnx.Rngs(0)) - state = module.get_state() + state = nnx.state(module) assert module.layers[0].kernel.value.shape == (1, 2) - assert state.layers[0].kernel.raw_value.shape == (1, 2) + assert state.layers[0].kernel.value.shape == (1, 2) assert module.layers[1].kernel.value.shape == (2, 3) - assert state.layers[1].kernel.raw_value.shape == (2, 3) + assert state.layers[1].kernel.value.shape == (2, 3) diff --git a/flax/experimental/nnx/tests/test_transforms.py b/flax/experimental/nnx/tests/test_transforms.py index 709b61a33e..db6cb24288 100644 --- a/flax/experimental/nnx/tests/test_transforms.py +++ b/flax/experimental/nnx/tests/test_transforms.py @@ -110,7 +110,7 @@ def 
__call__(self, x: jax.Array) -> jax.Array: n += 1 return jnp.dot(x, self.w.value) - m = nnx.JIT(Foo)(2, 3, rngs=nnx.Rngs(0)) + m = nnx.Jit(Foo)(2, 3, rngs=nnx.Rngs(0)) y = m(jnp.ones((1, 2))) assert y.shape == (1, 3) @@ -334,7 +334,7 @@ def test_apply_shardings(self): ), ) - @partial(nnx.jit, constrain_object_state=True) + @partial(nnx.jit, constrain_state=True) def constrain_object(m): pass @@ -366,13 +366,13 @@ def f(m: nnx.Dict): assert m.a[0] is m.b assert isinstance(grads, nnx.State) - assert grads['a'][0].raw_value == 2.0 - assert isinstance(grads.a[0], nnx.Variable) - assert grads['a'][1].raw_value == 1.0 - assert isinstance(grads.a[1], nnx.Variable) + assert grads['a'][0].value == 2.0 + assert issubclass(grads.a[0].type, nnx.Variable) + assert grads['a'][1].value == 1.0 + assert issubclass(grads.a[1].type, nnx.Variable) assert len(grads.flat_state()) == 2 - m.update(grads) + nnx.update(m, grads) assert m.a[0] is m.b assert m['a'][0].value == 2.0 @@ -397,11 +397,11 @@ def f(m: nnx.Dict): grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][0].raw_value == 1.0 - assert isinstance(grads.a[0], nnx.Param) + assert grads['a'][0].value == 1.0 + assert issubclass(grads.a[0].type, nnx.Param) assert len(grads) == 2 - m.update(grads) + nnx.update(m, grads) assert m.a[0].value == 1.0 assert m.a[1].value == 20.0 @@ -425,11 +425,11 @@ def f(m: nnx.Dict): grads = f(m) assert isinstance(grads, nnx.State) - assert grads['a'][1].raw_value == 1.0 - assert isinstance(grads.a[1], nnx.BatchStat) + assert grads['a'][1].value == 1.0 + assert issubclass(grads.a[1].type, nnx.BatchStat) assert len(grads) == 1 - m.update(grads) + nnx.update(m, grads) assert m.a[0].value == 10.0 assert m.a[1].value == 1.0 @@ -447,9 +447,9 @@ def test_multiple_inputs(self): grads = grad_fn(m, x, y) assert 'kernel' in grads - assert grads.kernel.raw_value.shape == (2, 3) + assert grads.kernel.value.shape == (2, 3) assert 'bias' in grads - assert grads.bias.raw_value.shape == (3,) + assert 
grads.bias.value.shape == (3,) def test_multiple_graph_nodes(self): rngs = nnx.Rngs(0) @@ -462,13 +462,13 @@ def test_multiple_graph_nodes(self): grads_m1, grads_m2 = grad_fn(m1, m2, x, y) assert 'kernel' in grads_m1 - assert grads_m1.kernel.raw_value.shape == (2, 3) + assert grads_m1.kernel.value.shape == (2, 3) assert 'bias' in grads_m1 - assert grads_m1.bias.raw_value.shape == (3,) + assert grads_m1.bias.value.shape == (3,) assert 'kernel' in grads_m2 - assert grads_m2.kernel.raw_value.shape == (3, 3) + assert grads_m2.kernel.value.shape == (3, 3) assert 'bias' in grads_m2 - assert grads_m2.bias.raw_value.shape == (3,) + assert grads_m2.bias.value.shape == (3,) def test_multiple_graph_nodes_mix_positions(self): rngs = nnx.Rngs(0) @@ -481,17 +481,48 @@ def test_multiple_graph_nodes_mix_positions(self): grads_m1, grads_m2 = grad_fn(x, m1, y, m2) assert 'kernel' in grads_m1 - assert grads_m1.kernel.raw_value.shape == (2, 3) + assert grads_m1.kernel.value.shape == (2, 3) assert 'bias' in grads_m1 - assert grads_m1.bias.raw_value.shape == (3,) + assert grads_m1.bias.value.shape == (3,) assert 'kernel' in grads_m2 - assert grads_m2.kernel.raw_value.shape == (3, 3) + assert grads_m2.kernel.value.shape == (3, 3) assert 'bias' in grads_m2 - assert grads_m2.bias.raw_value.shape == (3,) + assert grads_m2.bias.value.shape == (3,) class TestScan: def test_basic(self): + class Block(nnx.Module): + def __init__(self, *, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + # self.node = nnx.Variable(jnp.ones((2,))) + + def __call__(self, x: jax.Array): + x = self.linear(x) + x = nnx.gelu(x) + return x + + @partial(nnx.scan, state_axes={nnx.Param: 0}, length=5) + def create_block(_, rngs: nnx.Rngs): + return None, Block(rngs=rngs) + + _, module = create_block(None, nnx.Rngs(0)) + + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + # assert module.node.value.shape == (2,) + + @partial(nnx.scan, in_axes=None, 
state_axes={nnx.Param: 0}, length=5) + def forward_block(_, block: Block, x: jax.Array): + return None, block(x) + + x = jnp.ones((1, 3)) + out, y = forward_block(None, module, x) + + assert y.shape == (5, 1, 3) + assert out is None + + def test_basic_combinator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -504,7 +535,7 @@ def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, ) @@ -533,7 +564,7 @@ def __call__(self, x: jax.Array): MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, scan_output=False, ) @@ -562,7 +593,7 @@ def __call__(self, x: jax.Array): MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, out_axes=(1, 2), ) @@ -597,7 +628,7 @@ def __call__( MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, ) @@ -632,9 +663,9 @@ def __call__( MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, - in_args_axes=(0, None), + in_axes=(None, None, 0, None), ) module = MLP(rngs=nnx.Rngs(0)) @@ -667,7 +698,7 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: return x MLP = nnx.Scan( - Block, variable_axes={nnx.Param: 0}, length=5, scan_output=False + Block, state_axes={nnx.Param: 0}, length=5, scan_output=False ) module = MLP(rngs=nnx.Rngs(0)) @@ -699,10 +730,10 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, # params is split, dropout is broadcast - broadcast_rngs=['dropout'], + split_rngs=['dropout'], scan_output=False, ) @@ -719,14 +750,12 @@ def __call__(self, x: jax.Array, *, rngs: nnx.Rngs) -> jax.Array: assert y.shape == (1, 3) def test_complex_decorator(self): - scan_over_layers = partial( - 
nnx.scan, - variable_axes={nnx.Param: 0}, - length=5, - ) - class Block(nnx.Module): - @scan_over_layers + @partial( + nnx.vmap, + state_axes={nnx.Param: 0}, + axis_size=5, + ) def __init__(self, *, rngs: nnx.Rngs): self.d = 3 self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -734,7 +763,12 @@ def __init__(self, *, rngs: nnx.Rngs): self.dropout = nnx.Dropout(0.5) self.node = nnx.Variable(jnp.ones((2,))) - @scan_over_layers + @partial( + nnx.scan, + state_axes={nnx.Param: 0}, + length=5, + carry_argnum=1, + ) def __call__( self, x: jax.Array, _, *, rngs: nnx.Rngs ) -> tp.Tuple[jax.Array, None]: @@ -779,26 +813,26 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: x = self.linear(x) # test sharding layer axes is not present inside scan - state = self.linear.get_state() - assert state.kernel.raw_value.shape == (3, 3) + state = nnx.state(self.linear) + assert state.kernel.value.shape == (3, 3) assert state.kernel.sharding == ('din', 'dout') - assert state.bias.raw_value.shape == (3,) + assert state.bias.value.shape == (3,) assert state.bias.sharding == ('dout',) return x, None MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, - scan_metadata={nnx.PARTITION_NAME: 'layers'}, + transform_metadata={nnx.PARTITION_NAME: 'layers'}, ) m = MLP(rngs=nnx.Rngs(0)) # test sharding layers axes is set - state = m.get_state() - assert state.scan_module.linear.kernel.raw_value.shape == ( + state = nnx.state(m) + assert state.scan_module.linear.kernel.value.shape == ( 5, 3, 3, @@ -808,7 +842,7 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: 'din', 'dout', ) - assert state.scan_module.linear.bias.raw_value.shape == (5, 3) + assert state.scan_module.linear.bias.value.shape == (5, 3) assert state.scan_module.linear.bias.sharding == ( 'layers', 'dout', @@ -818,14 +852,14 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: y, out = m(x, None) # test sharding axes is preserved - state = 
m.get_state() - assert state.scan_module.linear.kernel.raw_value.shape == (5, 3, 3) + state = nnx.state(m) + assert state.scan_module.linear.kernel.value.shape == (5, 3, 3) assert state.scan_module.linear.kernel.sharding == ( 'layers', 'din', 'dout', ) - assert state.scan_module.linear.bias.raw_value.shape == (5, 3) + assert state.scan_module.linear.bias.value.shape == (5, 3) assert state.scan_module.linear.bias.sharding == ( 'layers', 'dout', @@ -841,37 +875,17 @@ def __call__(self): MLP = nnx.Scan( Block, - variable_axes={nnx.Param: 0}, + state_axes={nnx.Param: 0}, length=5, ) mlp = MLP(rngs=nnx.Rngs(0)) with pytest.raises( - TypeError, match='Expected at least 1 positional argument' + TypeError, match='Expected at least 2 positional argument' ): mlp() - def test_value_error_positional_argument_type_context(self): - class Block(nnx.Module): - def __init__(self, rngs: nnx.Rngs): - self.linear = nnx.Linear(3, 3, rngs=rngs) - - def __call__(self, x: jax.Array) -> tp.Tuple[jax.Array, None]: - x = self.linear(x) - return x, None - - MLP = nnx.Scan( - Block, - variable_axes={nnx.Param: 0}, - length=5, - ) - - with pytest.raises( - ValueError, match='Rngs must be passed as a keyword argument named' - ): - MLP(nnx.Rngs(0)) - class TestRemat: def test_basic_remat(self): @@ -885,15 +899,15 @@ def test_basic_remat(self): def test_remat_decorator(self): class RematLinear(nnx.Module): - @nnx.remat - def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + @partial(nnx.remat, static_argnums=(1, 2)) + def __init__(self, din: int, dout: int, rngs: nnx.Rngs): self.linear = nnx.Linear(din, dout, rngs=rngs) @nnx.remat def __call__(self, x: jax.Array) -> jax.Array: return self.linear(x) - module = RematLinear(2, 3, rngs=nnx.Rngs(0)) + module = RematLinear(2, 3, nnx.Rngs(0)) y = module(jnp.ones((1, 2))) @@ -912,7 +926,7 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: ScanRematLinear = nnx.Scan( RematLinear, - variable_axes={nnx.Param: 0}, + 
state_axes={nnx.Param: 0}, length=5, ) @@ -928,18 +942,22 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: assert y.shape == (1, 3) def test_remat_with_scan_decorator(self): - scan = partial( - nnx.scan, - variable_axes={nnx.Param: 0}, - length=5, - ) - class ScanLinear(nnx.Module): - @scan + @partial( + nnx.vmap, + state_axes={nnx.Param: 0}, + axis_size=5, + ) def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) - @scan + @partial( + nnx.scan, + in_axes=None, + state_axes={nnx.Param: 0}, + length=5, + carry_argnum=1, + ) @nnx.remat def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: x = self.linear(x) @@ -956,6 +974,144 @@ def __call__(self, x: jax.Array, _) -> tp.Tuple[jax.Array, None]: class TestVmap: def test_basic(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.relu(x) + x = self.dropout(x) + return x + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + vectorized_create_block = nnx.vmap( + create_block, state_axes={nnx.Param: 0}, axis_size=5 + ) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = vectorized_create_block(rngs) + + assert rngs.default.count.value == 2 + assert rngs.default.key.value == initial_key + assert not jnp.allclose( + module.linear.kernel.value[0], + module.linear.kernel.value[1], + ) + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + + x = jnp.ones((5, 1, 3)) + + def forward_block(module, x): + return module(x) + + vectorized_forward_block = nnx.vmap( + forward_block, state_axes={nnx.Param: 0}, axis_size=5 + ) + + y = vectorized_forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 3 + assert rngs.default.key.value == initial_key + + y2 = 
vectorized_forward_block(module, x) + + assert not jnp.allclose(y, y2) + + def test_basic_demo(self): + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + @partial(nnx.vmap, axis_size=5) + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial(nnx.vmap, axis_size=5) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + module = create_block(rngs) + + assert rngs.default.count.value == 2 + assert module.linear.kernel.value.shape == (5, 3, 3) + assert module.linear.bias.value.shape == (5, 3) + assert not jnp.allclose( + module.linear.kernel.value[0], + module.linear.kernel.value[1], + ) + + x = jnp.ones((5, 1, 3)) + + y = forward_block(module, x) + + assert y.shape == (5, 1, 3) + assert rngs.default.count.value == 3 + + y2 = forward_block(module, x) + + # dropout is working! 
+ assert not jnp.allclose(y, y2) + + def test_replicate(self): + din = 3 + dout = 10 + + class Block(nnx.Module): + def __init__(self, rngs: nnx.Rngs): + self.linear = nnx.Linear(din, dout, rngs=rngs) + self.dropout = nnx.Dropout(0.5, deterministic=False, rngs=rngs) + + def __call__(self, x: jax.Array) -> jax.Array: + return self.dropout(nnx.relu(self.linear(x))) + + def create_block(rngs: nnx.Rngs): + return Block(rngs) + + @partial( + nnx.vmap, + state_axes={}, # replicate all state + split_rngs=True, # different rngs for each replica + ) + def forward_block(module: Block, x): + return module(x) + + rngs = nnx.Rngs(0) + initial_key = rngs.default.key.value + module = create_block(rngs) + + assert rngs.default.count.value == 2 + assert module.linear.kernel.value.shape == (din, dout) + assert module.linear.bias.value.shape == (dout,) + + x = jnp.ones((5, 1, din)) + + y = forward_block(module, x) + + assert y.shape == (5, 1, dout) + assert rngs.default.count.value == 3 + + assert not jnp.allclose(y[0], y[1]) + + y2 = forward_block(module, x) + + # dropout is working! 
+ assert not jnp.allclose(y, y2) + + assert rngs.default.key.value == initial_key + + def test_combinator(self): class Block(nnx.Module): def __init__(self, *, rngs: nnx.Rngs): self.linear = nnx.Linear(3, 3, rngs=rngs) @@ -965,7 +1121,7 @@ def __call__(self, x: jax.Array) -> jax.Array: x = nnx.gelu(x) return x - MLP = nnx.Vmap(Block, variable_axes={nnx.Param: 0}, axis_size=5) + MLP = nnx.Vmap(Block, state_axes={nnx.Param: 0}, axis_size=5) module = MLP(rngs=nnx.Rngs(0)) @@ -980,3 +1136,20 @@ def __call__(self, x: jax.Array) -> jax.Array: y = module(x) assert y.shape == (5, 1, 3) + + def test_combinator_init(self): + class Block(nnx.Module): + def __init__(self, *, graphdef: str, rngs: nnx.Rngs): + self.linear = nnx.Linear(3, 3, rngs=rngs) + self.graphdef = graphdef + + def __call__(self, x: jax.Array) -> jax.Array: + x = self.linear(x) + x = nnx.gelu(x) + return x + + MLP = nnx.Vmap(Block, state_axes={nnx.Param: 0}, axis_size=5) + + module = MLP(graphdef='hello', rngs=nnx.Rngs(0)) + + assert module.vmap_module.graphdef == 'hello' diff --git a/flax/experimental/nnx/tests/test_variable.py b/flax/experimental/nnx/tests/test_variable.py index 36f55f16c3..6048060560 100644 --- a/flax/experimental/nnx/tests/test_variable.py +++ b/flax/experimental/nnx/tests/test_variable.py @@ -15,19 +15,50 @@ import typing as tp import jax +import jax.numpy as jnp from flax.experimental import nnx A = tp.TypeVar('A') -class TestVariable: - def test_value(self): - r1 = nnx.Variable(1) - assert r1.raw_value == 1 +class TestVariableState: + def test_pytree(self): + r1 = nnx.VariableState(nnx.Param, 1) + assert r1.value == 1 r2 = jax.tree_util.tree_map(lambda x: x + 1, r1) - assert r1.raw_value == 1 - assert r2.raw_value == 2 + assert r1.value == 1 + assert r2.value == 2 assert r1 is not r2 + + def test_overloads_module(self): + class Linear(nnx.Module): + def __init__(self, din, dout, rngs: nnx.Rngs): + key = rngs() + self.w = nnx.Param(jax.random.normal(key, (din, dout))) + self.b = 
nnx.Param(jax.numpy.zeros((dout,))) + + def __call__(self, x: jax.Array): + return x @ self.w + self.b + + linear = Linear(3, 4, nnx.Rngs(0)) + x = jax.numpy.ones((3,)) + y = linear(x) + assert y.shape == (4,) + + def test_jax_array(self): + class Linear(nnx.Module): + def __init__(self, din, dout, rngs: nnx.Rngs): + key = rngs() + self.w = nnx.Param(jax.random.normal(key, (din, dout))) + self.b = nnx.Param(jax.numpy.zeros((dout,))) + + def __call__(self, x: jax.Array): + return jnp.dot(x, self.w) + self.b + + linear = Linear(3, 4, nnx.Rngs(0)) + x = jax.numpy.ones((3,)) + y = linear(x) + assert y.shape == (4,) diff --git a/flax/linen/__init__.py b/flax/linen/__init__.py index 4cd77c7136..cd29659c1c 100644 --- a/flax/linen/__init__.py +++ b/flax/linen/__init__.py @@ -31,13 +31,13 @@ # pylint: disable=g-multiple-import,useless-import-alias # re-export commonly used modules and functions -from ..core import ( +from flax.core import ( DenyList as DenyList, FrozenDict as FrozenDict, broadcast as broadcast, meta as meta, ) -from ..core.meta import ( +from flax.core.meta import ( PARTITION_NAME as PARTITION_NAME, Partitioned as Partitioned, get_partition_spec as get_partition_spec, diff --git a/flax/linen/module.py b/flax/linen/module.py index 17f00215ac..02b02a67bd 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -2580,23 +2580,8 @@ def sow( >>> model = Foo() >>> variables = model.init(jax.random.key(0), x) >>> y, state = model.apply(variables, x, mutable=['intermediates']) - >>> print(state['intermediates']) - {'h': (Array([[-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - 
[-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ], - [-1.503171 , 0.7377704 , -0.59388214, -1.0079019 ]], dtype=float32),)} + >>> jax.tree.map(jnp.shape, state['intermediates']) + {'h': ((16, 4),)} By default the values are stored in a tuple and each stored value is appended at the end. This way all intermediates can be tracked when diff --git a/pyproject.toml b/pyproject.toml index 1dd37d8429..1595e1c084 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ testing = [ "nbstripout", "black[jupyter]==23.7.0", # "pyink==23.5.0", # disabling pyink fow now + "penzai; python_version>='3.10'", ] [project.urls] @@ -166,7 +167,7 @@ exclude = [ "__init__.py", "activation.py", "partitioning.py", - "variables.py", + "flax/core/variables.py", "examples/", ] diff --git a/tests/run_all_tests.sh b/tests/run_all_tests.sh index a2b016138d..9bc9cb853a 100755 --- a/tests/run_all_tests.sh +++ b/tests/run_all_tests.sh @@ -85,7 +85,7 @@ if $RUN_DOCTEST; then pytest -n auto flax \ --doctest-modules \ --suppress-no-test-exit-code \ - --ignore=flax/experimental/nnx + --ignore=flax/experimental/nnx/examples fi # check that flax is running on editable mode @@ -111,6 +111,8 @@ if $RUN_PYTEST; then # Run battery of core FLAX API tests. echo "pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE" pytest -n auto tests $PYTEST_OPTS $PYTEST_IGNORE + # Run nnx tests + pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE # Per-example tests. # @@ -118,11 +120,21 @@ if $RUN_PYTEST; then # In pytest foo/bar/baz_test.py and baz/bleep/baz_test.py will collide and error out when # /foo/bar and /baz/bleep aren't set up as packages. 
for egd in $(find examples -maxdepth 1 -mindepth 1 -type d); do - pytest $egd + # skip if folder path contains "_" + if [[ $egd == *"_"* ]]; then + continue + fi + pytest $egd + done + + for egd in $(find flax/experimental/nnx/examples -maxdepth 1 -mindepth 1 -type d); do + # skip if folder path contains "_" or "toy_examples" + if [[ $egd == *"_"* ]] || [[ $egd == *"toy_examples"* ]]; then + continue + fi + pytest $egd done - # Run nnx tests - pytest -n auto flax/experimental/nnx/tests $PYTEST_OPTS $PYTEST_IGNORE fi if $RUN_PYTYPE; then