Quick question on implementation. #31
Hi guys, does this library offer any performance improvement over the native `safetensors` NumPy save and load functions? If so, shouldn't it be merged directly into the `safetensors` repo?
Hi @miguelamendez, the idea behind `safetensors` is that it's another format for storing tensors, as opposed to `pickle`, which is not safe. This doesn't mean that `jax.numpy.save`, `jax.numpy.savez`, or any other existing serialization format is not recommended.

The idea behind `safejax` is to use `safetensors` to serialize the whole tree. It decomposes the tree into a flat Python dictionary by joining the nested keys, so that the original tree is flattened. This is done because `safetensors` doesn't support storing complex structures/trees, e.g. `FrozenDict` in `flax`. More information on the latter is at huggingface/safetensors#138.

For example, the following doesn't work with plain `safetensors`:

```python
import jax.numpy as jnp
from safetensors.jax import save_file

# Nested dict (a tree): plain safetensors expects a flat {str: array} mapping
tensors = {"dense0": {"weight": jnp.ones((1, 10)), "bias": jnp.ones((10,))}}
save_file(tensors, "tensors.safetensors")
```

While using `safejax` it does:

```python
import jax.numpy as jnp
from safejax import serialize

# The same nested dict; safejax flattens it before handing it to safetensors
tensors = {"dense0": {"weight": jnp.ones((1, 10)), "bias": jnp.ones((10,))}}
serialize(tensors, filename="tensors.safetensors")
```

More information on why … For extra context, …
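The key-joining idea described above can be sketched in plain Python and NumPy. This is a minimal illustration under stated assumptions, not `safejax`'s actual implementation: the helper names `flatten_tree` and `unflatten_tree` are hypothetical, the `.` separator is an assumption, and any `Mapping` is treated as a subtree (which should cover `flax`'s `FrozenDict`, since it implements the `Mapping` interface).

```python
from collections.abc import Mapping
from typing import Any, Dict

import numpy as np

# Hypothetical separator for joining nested keys (an assumption, not necessarily safejax's choice).
SEP = "."


def flatten_tree(tree: Mapping, parent_key: str = "", sep: str = SEP) -> Dict[str, np.ndarray]:
    """Flatten nested mappings into a single {joined_key: array} dict."""
    flat: Dict[str, np.ndarray] = {}
    for key, value in tree.items():
        if sep in key:
            # A key containing the separator could not be split back unambiguously.
            raise ValueError(f"key {key!r} contains the separator {sep!r}")
        joined = f"{parent_key}{sep}{key}" if parent_key else key
        if isinstance(value, Mapping):
            # Recurse into subtrees, carrying the joined prefix along.
            flat.update(flatten_tree(value, joined, sep))
        else:
            # Leaves become arrays stored under the fully joined key.
            flat[joined] = np.asarray(value)
    return flat


def unflatten_tree(flat: Mapping, sep: str = SEP) -> Dict[str, Any]:
    """Rebuild nested dicts from joined keys (the inverse of flatten_tree)."""
    tree: Dict[str, Any] = {}
    for joined, value in flat.items():
        *parents, leaf = joined.split(sep)
        node = tree
        for part in parents:
            # Create intermediate levels on demand.
            node = node.setdefault(part, {})
        node[leaf] = value
    return tree


params = {"dense0": {"weight": np.ones((1, 10)), "bias": np.ones((10,))}}
flat = flatten_tree(params)
print(sorted(flat))  # ['dense0.bias', 'dense0.weight']
restored = unflatten_tree(flat)
print(restored["dense0"]["weight"].shape)  # (1, 10)
```

The flat dict holds only string keys and arrays, which is the form `safetensors.numpy.save_file(flat, "tensors.safetensors")` accepts; after `safetensors.numpy.load_file`, `unflatten_tree` restores the nesting. This round trip recovers plain dicts rather than the original container types, and rejecting keys that already contain the separator keeps the flattening reversible.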