For `seqax` we write checkpoints of JAX PyTrees in a simple format, documented here.
The zarr of a PyTree is a zarr Group with the following elements:
- for each `path, array` in the flattened PyTree, the zarr Group contains `array` as a child array, with path equal to `jax.tree_util.keystr(path)`
- additionally, there is a zarr attribute with name `write_completed` and value `True`.
The zarr of a PyTree may be written to disk with any compression and chunk size settings.
We use `zarr` to support parallel writers from different hosts in a fully-sharded training setup. (Parallel writers in this scenario must choose a chunk size that divides the data size per host, so as to avoid zarr race conditions during writing.) Readers of the checkpoint format do not need to be aware that it was written in parallel, as this is hidden by the zarr abstraction.
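The chunk-size constraint can be checked with a little arithmetic. The sketch below is illustrative only and assumes a simplified setup that the text does not specify: an array sharded along one axis into equal contiguous blocks, one block per host. The function name `host_chunk_ranges` is hypothetical.

```python
def host_chunk_ranges(global_rows, num_hosts, chunk_rows):
    """Return, per host, the range of chunk indices that host writes.

    Raises ValueError if some chunk would straddle two hosts, which is the
    case where two writers would touch the same chunk.
    """
    if global_rows % num_hosts != 0:
        raise ValueError("rows must split evenly across hosts")
    rows_per_host = global_rows // num_hosts
    if rows_per_host % chunk_rows != 0:
        raise ValueError("chunk size must divide the per-host data size")
    chunks_per_host = rows_per_host // chunk_rows
    return [
        range(host * chunks_per_host, (host + 1) * chunks_per_host)
        for host in range(num_hosts)
    ]
```

With 1024 rows on 4 hosts, each host holds 256 rows. A chunk size of 64 gives every host 4 chunks of its own, whereas a chunk size of 96 is rejected because a chunk would span the boundary between two hosts.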
We use the `write_completed` attribute to allow parallel writers to support a "two phase commit" protocol: all writers write their data chunks, then wait for a global barrier, then the "leader" writer sets the `write_completed` attribute. This protects readers from reading partially-written checkpoints.
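The ordering of the protocol can be illustrated with a small simulation. This is a sketch under loud assumptions: threads stand in for hosts, `threading.Barrier` stands in for the global barrier, and a plain dict stands in for the zarr Group. None of the names below come from `seqax` or zarr.

```python
import threading

LEADER = 0  # assumption: writer 0 acts as the "leader"


def run_two_phase_commit(num_writers=4):
    """Simulate the protocol; returns the store and what the leader observed."""
    store = {"chunks": {}, "attrs": {}}  # stand-in for a zarr Group
    barrier = threading.Barrier(num_writers)
    chunks_seen_by_leader = []

    def writer(rank):
        # Phase 1: every writer writes only its own chunk.
        store["chunks"][rank] = f"data from writer {rank}"
        # Global barrier: nobody proceeds until all chunks are written.
        barrier.wait()
        # Phase 2: only the leader publishes the completion marker.
        if rank == LEADER:
            chunks_seen_by_leader.append(len(store["chunks"]))
            store["attrs"]["write_completed"] = True

    threads = [threading.Thread(target=writer, args=(r,)) for r in range(num_writers)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return store, chunks_seen_by_leader


def is_readable(store):
    """Reader-side check: trust the data only if the marker is present."""
    return store["attrs"].get("write_completed") is True
```

Because the marker is written after the barrier, a reader that sees `write_completed` can assume every writer finished its chunks; a reader that does not see it should treat the checkpoint as incomplete.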