Releases: google/flax
Version 0.6.5
What's Changed
- Internal change by @copybara-service in #2826
- Cache .mypy_cache on CI by @cgarciae in #2819
- improve nn.map_variables docstrings by @cgarciae in #2818
- Render non-array variables in tabulate by @cgarciae in #2769
- Improve nn.map_variables example by @cgarciae in #2831
- Introduce relaxed naming policy behind a feature flag. by @levskaya in #2833
- Bumps minimal JAX version from 0.3.16 to 0.3.24 and cleans up some code. by @copybara-service in #2840
- import initializers module directly by @chiamp in #2830
- Re-box axis metadata in DenseGeneral's init functions by @IvyZX in #2843
- Update jax dependency by @jheek in #2827
- replaced zeros and ones initializers with zeros_init and ones_init builder functions by @chiamp in #2815
- Fixing the punctuation in the README by @arcAman07 in #2811
- Implement lazy init by @jheek in #2816
- improve nn.checkpoint docs by @cgarciae in #2837
- Flax 0.6.5 point release by @levskaya in #2856
Full Changelog: v0.6.4...v0.6.5
Version 0.6.4
What's Changed
- Improve and migrate model_surgery guide by @IvyZX in #2668
- Add deprecation warnings and docstring pointers to flax.training.lr_schedule by @IvyZX in #2702
- Introduce flax.io.NotFoundError to remove tensorflow dependency by @wookayin in #2697
- wrap Module properties by @cgarciae in #2541
- Update the checkpointing guide with Orbax migration code examples by @IvyZX in #2688
- Include mypy in run_all_tests by @zaxtax in #2670
- Add serialization for Partial functions. Completes #2433 by @zaxtax in #2557
- Cleanup DenseGeneral by @cgarciae in #2701
- Adds wave2vec2 link by @marcvanzee in #2718
- Remove some filterwarnings by @marcvanzee in #2715
- Adds standardize initializer by @marcvanzee in #2717
- Bump action/setup-python Python version to 3.9 by @marcvanzee in #2726
- Expose module.merge_param in RTD and fix docstring by @marcvanzee in #2680
- Revert to using public github runner pool while internal issues are fixed. by @levskaya in #2729
- Re-enable faster github CI runners by @levskaya in #2732
- Monitor checkpoint load/save durations. by @copybara-service in #2629
- Add Flax Dropout guide by @8bitmp3 in #2675
- Update readthedocs build to python 3.8 by @IvyZX in #2746
- Update Flax BatchNorm guide title by @8bitmp3 in #2747
- Simplify _AxisRules by @cgarciae in #2734
- Re-added serialization error checking for dicts by @chiamp in #2691
- Guide for Flax partitioning (auto SPMD) API by @IvyZX in #2730
- Fix LogicalNames type by @cgarciae in #2759
- removed duplicate flip section in docs by @chiamp in #2757
- updated tabulate documentation by @chiamp in #2756
- add unbind by @cgarciae in #2674
- Remove incorrect SPMD section in linen.rst by @cgarciae in #2766
- Adds some more comments to scope.py by @marcvanzee in #2743
- Fix GRU typo by @jheek in #2785
- Add metrics for async checkpoint (beacon and total write duration). by @copybara-service in #2786
- Fix indentation by @8bitmp3 in #2787
- Link CleanRL's PPO Implementation by @vwxyzjn in #2780
- Flax docs restructuring by @8bitmp3 in #2788
- Add metadata helper transform for adding axis metadata. by @levskaya in #2783
- Update Flax mission, fix some links by @8bitmp3 in #2797
- Updated getting started doc by @chiamp in #2698
- Added builder functions for zeros and ones initializers by @chiamp in #2790
- Restructure Flax Examples, add sections by @8bitmp3 in #2801
- Update how to build flax docs by @8bitmp3 in #2740
- Updating the doc by @arcAman07 in #2796
- added zeros_init and ones_init builder to initializers list by @chiamp in #2806
- Implement perturb no-op behaviour by @cgarciae in #2728
- Fix links in Flax Optax upgrading and dataset processing by @8bitmp3 in #2798
- improve Module.init docs by @cgarciae in #2792
- replaced zeros and ones initializers with zeros_init and ones_init builder functions by @chiamp in #2807
- Restructure Flax guides, add sections, update copyright year by @8bitmp3 in #2800
- Added glossary doc by @chiamp in #2805
- Add a comment about an error. by @nouiz in #2810
- Add transpose_kernel argument to ConvTranspose. by @mathisgerdes in #2578
- rename ConvLSTM to ConvLSTMCell by @cgarciae in #2767
- Add missing fields from repr by @cgarciae in #2803
- Allow specifying method as a string by @cgarciae in #2809
- Release 0.6.4 by @IvyZX in #2820
New Contributors
- @wookayin made their first contribution in #2697
- @vwxyzjn made their first contribution in #2780
- @arcAman07 made their first contribution in #2796
- @nouiz made their first contribution in #2810
- @mathisgerdes made their first contribution in #2578
Full Changelog: v0.6.3...v0.6.4
Version 0.6.3
What's Changed
- Add gfile api shim to remove tensorflow dependency for basic IO. by @chiamp in #2586
- Remove Mypy type errors by @zaxtax in #2594
- Attempt to fix pytype issue. by @levskaya in #2628
- example/imagenet: use absolute path to locate the Flax root dir by @yhtang in #2630
- Move Flax - The Sharp Bits ToC location by @8bitmp3 in #2633
- Speed up Github Actions CI by @levskaya in #2635
- Update requirements by @marcvanzee in #2652
- Switch to Orbax for Flax single-checkpoint support under the hood. by @copybara-service in #2637
- Added path discrepancy details in serialization errors by @chiamp in #2632
- Fix get_type_hints again by @cgarciae in #2654
- BatchNorm guide by @cgarciae in #2536
- Generalize pool to handle multiple batch dimensions by @cgarciae in #2591
- Improve transfer learning guide by @cgarciae in #2595
- update docstrings and error messages in traverse_util.py. by @copybara-service in #2666
- Use a different rng key for each batch element in DenseGeneral init by @j-towns in #2665
- Added check for Mac M1 chip, when conducting serialization test. by @chiamp in #2657
- Added explicit warning if flax.io is using default Python I/O operations by @chiamp in #2625
- Check file or directory before removing checkpoints. by @IvyZX in #2676
- updated extracting_intermediates in flax docs by @chiamp in #2685
- updated model_surgery in flax_docs by @chiamp in #2687
- updated getting_started in flax docs by @chiamp in #2684
- Updated flax docs. by @chiamp in #2667
- updated flax_basics in flax docs by @chiamp in #2686
- Update python version support by @cgarciae in #2682
- add mypy.ini placeholder. by @copybara-service in #2693
- Release 0.6.3 by @IvyZX in #2705
New Contributors
Full Changelog: v0.6.2...v0.6.3
Version 0.6.2
What's Changed
- Refactor out dataclass transform that allows parent and name to be moved to the end of the argument list into a more general "kw_only_dataclasses" module. by @copybara-service in #2468
- Don't create reference cycles among Modules. by @levskaya in #2499
- adds perturb to docs by @cgarciae in #2511
- Add pre-commit hook to remove trailing white spaces by @cgarciae in #2513
- Add extracting gradients section to capture intermediates guide by @cgarciae in #2515
- Make is_fully_replicated and is_fully_addressable properties rather than methods. by @copybara-service in #2516
- Adds CallSetupUnboundModuleError by @cgarciae in #2496
- Adding documentation to Dropout around rng use by @zaxtax in #2492
- no-op when double wrapping with struct.dataclass by @cgarciae in #2505
- Add epub and pdf RTD formats by @cgarciae in #2517
- Update landing page example by @cgarciae in #2366
- Use gfile.remove for files because it doesn't work on GCS files. by @IvyZX in #2518
- Add save_checkpoint_multiprocess to api reference. by @IvyZX in #2519
- Add Tensorstore back to required dependencies. by @IvyZX in #2520
- Delete flax.optim.rst by @ppwwyyxx in #2522
- Added an error for when we call init, init_with_output and apply on Module class. by @chiamp in #2529
- Adding link to Bayesian Inference example that uses Flax by @zaxtax in #2521
- Added IncorrectPostInitOverrideError to capture incorrect post init overrides. by @copybara-service in #2535
- Fix to use is_initializing for init-detection. by @yotarok in #2486
- add rng_collection argument to Dropout by @cgarciae in #2540
- Cancel tests if other jobs fail by @cgarciae in #2507
- Update Guides link in Flax README by @8bitmp3 in #2544
- Pin jupytext version in requirements.txt by @IvyZX in #2545
- Fix flax.linen.stochastic.Dropout by @dslisleedh in #2510
- Transfer Learning Guide by @cgarciae in #2394
- Update Flax Contributing.md by @8bitmp3 in #2546
- cap dynamic scale to float32 max by @jheek in #2553
- Remove optional import of jax.experimental.array for older JAX versions. by @copybara-service in #2552
- Update examples.rst with denoising-diffusion-flax by @yiyixuxu in #2487
- return None if no _parent_ref is set by @cgarciae in #2548
- Add a documentation page on checkpointing by @IvyZX in #2530
- Lint Flax Contributing guide by @8bitmp3 in #2560
- Remove extra metadata in Checkpointing guide by @8bitmp3 in #2559
- Update getting_started.ipynb and getting_started.md by @chiamp in #2563
- Remove unused svn dependency by @8bitmp3 in #2574
- Fix pytype check in checkpoints.py by @IvyZX in #2592
- Add new 🔪 Flax - The Sharp Bits 🔪 Dropout and randomness by @8bitmp3 in #2593
- Fixes launch_gce.sh with imagenet example. by @andsteing in #2598
- Added test to check for Variable warning. by @chiamp in #2610
- Release version 0.6.2 by @IvyZX in #2613
New Contributors
- @zaxtax made their first contribution in #2492
- @ppwwyyxx made their first contribution in #2522
- @chiamp made their first contribution in #2529
- @yotarok made their first contribution in #2486
- @yiyixuxu made their first contribution in #2487
Full Changelog: v0.6.1...v0.6.2
Version 0.6.1
What's Changed
- Updates examples/{imagenet,wmt} requirements. by @andsteing in #2405
- Bump rich dependency version by @yklcs in #2407
- Adds axis_name and axis_index_groups to LayerNorm and GroupNorm. by @copybara-service in #2402
- Plumb spmd_axis_name through transforms.vmap through to JAX vmap by @copybara-service in #2398
- Support multiple inputs in flax lifted vjp/custom_vjp by @copybara-service in #2399
- Explicit reexport initializers from jax by @lkhphuc in #2409
- Improve tabulate by @cgarciae in #2316
- Add path_aware_map function by @cgarciae in #2371
- PIL does not accept DeviceArray, so needed to use numpy. by @villeh1 in #2427
- Move examples to RTD by @cgarciae in #2367
- Simplify dynamic context by @cgarciae in #2388
- Remove pytype generic workaround by @jheek in #2446
- Add static_argnums to nn.checkpoint by @cgarciae in #2457
- ignore tf deprecation warning. by @copybara-service in #2466
- Fix Managing Parameters and State docs by @cgarciae in #2473
- Use gfile.listdir instead of gfile.glob in checkpointing by @IvyZX in #2470
- Create test matrix to speedup tests by @cgarciae in #2458
- Fix Conv docstrings by @cgarciae in #2425
- Use proper scikit-learn dependency by @cgarciae in #2465
- Improve attribute error msg for unbounded modules by @cgarciae in #2440
- Adding "count_include_pad" argument to flax.linen.pooling.avg_pool by @dslisleedh in #2451
- Add perturb() to allow capturing intermediate gradients by @IvyZX in #2476
- fix DynamicScale docstring by @cgarciae in #2491
- test against python 3.8 and 3.9 by @cgarciae in #2490
- Update version to 0.6.1 by @cgarciae in #2494
- Adoption cache should use WeakValueDictionary. by @levskaya in #2495
- FLIP: General metadata by @jheek in #2435
New Contributors
- @yklcs made their first contribution in #2407
- @villeh1 made their first contribution in #2427
- @dslisleedh made their first contribution in #2451
Full Changelog: v0.6.0...v0.6.1
Version 0.6.0
What's Changed
- Add on_commit_callback to put the responsibility of renaming the directories on the users of the serialization library. This will also fix the GCS atomic rename issue where the users can write a success file when the commit is successful and check the existence of that file before deserialization. by @copybara-service in #2328
- RTD Redesign by @cgarciae in #2177
- Fix stale URLs to read the docs site. by @levskaya in #2338
- Replace all jax.tree_* calls with jax.tree_util.tree_* by @levskaya in #2325
- Further fix the singular-leaf checkpointing, and add tests. by @copybara-service in #2336
- Forward unroll argument in scan_with_axes by @sanchit-gandhi in #2339
- Document repo analytics by @cgarciae in #2317
- Improve flax basics by @cgarciae in #2291
- Split build into multiple jobs by @cgarciae in #2277
- Fix type annotation for step in training.checkpoints by @Chuxiaof in #2343
- Add test for writing and restoring empty checkpoints. by @copybara-service in #2345
- Allow all processes to checkpoint when not using GDA by @copybara-service in #2350
- Make importing tensorstore optional and move related type hints to comments. by @copybara-service in #2348
- Use jax.named_scope for name stack rather than named_call. by @copybara-service in #2349
- Internal change by @copybara-service in #2356
- Fix sphinx CI errors by @cgarciae in #2361
- Forward path to rewound Scope by @jheek in #2360
- Make link a link on the getting started by @Davidnet in #2340
- Fix colab & github links by @cgarciae in #2363
- Fix ConvTranspose with circular padding by @cgarciae in #2364
- Add some docstrings to the old flax.training.common_utils module. by @levskaya in #2373
- Correct state variable name by @Jeevesh8 in #2369
- Allow linen's Conv layer to operate on arbitrary-rank inputs. by @copybara-service in #2308
- Copies dynamic_scale.py from optim/ to training/. by @copybara-service in #2375
- Add option auto_flush to flax.metrics.tensorboard.SummaryWriter by @copybara-service in #2376
- updated supported transforms in lifting docs. by @levskaya in #2374
- Test docstrings with autodoc on CI by @cgarciae in #2372
- skip remat test that fails with autodiff by @mattjj in #2389
- Removes flax.optim.dynamic_scale. by @copybara-service in #2314
- Plumb spmd_axis_name from vmap_with_axes through to JAX vmap by @copybara-service in #2390
- fixed math expressions by @banda-larga in #2392
New Contributors
- @sanchit-gandhi made their first contribution in #2339
- @Chuxiaof made their first contribution in #2343
- @Davidnet made their first contribution in #2340
- @Jeevesh8 made their first contribution in #2369
- @banda-larga made their first contribution in #2392
Full Changelog: v0.5.3...v0.6.0
Version 0.5.3
What's Changed
- Adds .pre-commit-config.yaml by @copybara-service in #2212
- Fix missing passthrough of nn.scan unroll arg by @jheek in #2213
- Test Notebooks on CI by @cgarciae in #2166
- Bump numpy from 1.21.4 to 1.22.0 in examples by @marcvanzee in #2228
- Add nn.switch by @cgarciae in #2205
- Fix notebooks by @cgarciae in #2231
- Add launch section with colab button by @cgarciae in #2235
- Enabling the dollarmath extension of MyST to correctly render math expressions by @WaterKnight1998 in #2238
- Update codediff to use sphinx-design tabs by @cgarciae in #2204
- Fix tests by @cgarciae in #2253
- Add single-host async save to save_checkpoint. by @IvyZX in #2233
- Add a method for detecting the use of "init" functions. by @levskaya in #2234
- Small fix in MNIST example by @marcvanzee in #2258
- Fix typos in the doc of flax.linen.Module.bind by @nalzok in #2269
- Add colab button to flax_basics by @cgarciae in #2276
- Fix type annotations by @cgarciae in #2281
- Exclude pseudo-fields of dataclass by @YouJiacheng in #2199
- Fix variable aliasing in put_variable by @jheek in #2296
- Update reference to tree_map to avoid deprecation warning. by @copybara-service in #2298
- Fix nondeterministic bug arising from sharing logic during module adoption. by @copybara-service in #2302
- fix ppo example typo by @fuyw in #2306
- Forward axis_size to jax.vmap by @jheek in #2310
- cleanup: replace deprecated jax.tree_map with jax.tree_util.tree_map by @copybara-service in #2311
- Add GlobalDeviceArray/multihost checkpoint support to Flax. by @copybara-service in #2287
- 0.5.3 update version & changelog by @IvyZX in #2330
- Replace use of id() with global counter-based id. by @levskaya in #2313
New Contributors
- @WaterKnight1998 made their first contribution in #2238
- @nalzok made their first contribution in #2269
- @YouJiacheng made their first contribution in #2199
- @fuyw made their first contribution in #2306
Full Changelog: v0.5.2...v0.5.3
Version 0.5.2
What's Changed
- Flax Basics docs: Add missing @jax.jit to mse by @rsokl in #2181
- add missing colon in example code by @PWhiddy in #2188
- New-sphinx-theme by @cgarciae in #2171
- Add missing PyYAML dependency by @cgarciae in #2193
- Improve module docs by @cgarciae in #2167
- Changed optimizer to optax by @berndbohnet in #1916
- Show repository button by @PhilipVinc in #2206
- Updates filterwarning in pytest.ini by @marcvanzee in #2209
- v0.5.2 by @cgarciae in #2203
New Contributors
- @rsokl made their first contribution in #2181
- @PWhiddy made their first contribution in #2188
- @berndbohnet made their first contribution in #1916
Full Changelog: v0.5.1...v0.5.2
Version 0.5.1
What's Changed
- Adds flax import to summary.py by @marcvanzee in #2138
- Add options for fallback behavior. by @copybara-service in #2130
- Upgrade to modern python idioms using pyupgrade. by @levskaya in #2132
- Update download_dataset_metadata.sh by @mattiasmar in #1801
- Mark correct minimum jax version requirement by @PhilipVinc in #2136
- Edited contributing.md by @IvyZX in #2151
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/imagenet by @dependabot in #2143
- Bump tensorflow from 2.8.0 to 2.8.1 in /examples/wmt by @dependabot in #2142
- Add typehint to Module.scope by @cgarciae in #2106
- Correcting Mistakes In Flip Docs by @saiteja13427 in #2140
- Add CAUSAL padding for 1D convolution. by @copybara-service in #2141
- Calculate cumulative number of issues and PRs by @cgarciae in #2154
- Improve setup instructions in contributing guide by @cgarciae in #2155
- Forward unroll argument in lifted scan by @jheek in #2158
- Improve tabulate by @cgarciae in #2162
- Remove unused variable from nlp_seq example by @marcvanzee in #2163
- Allow nn.cond, nn.while to act on bound methods. by @levskaya in #2172
- 0.5.1 by @cgarciae in #2180
- Update normalization.py by @yechengxi in #2182
New Contributors
- @mattiasmar made their first contribution in #1801
- @PhilipVinc made their first contribution in #2136
- @IvyZX made their first contribution in #2151
- @saiteja13427 made their first contribution in #2140
- @yechengxi made their first contribution in #2182
Full Changelog: v0.5.0...v0.5.1
Version 0.5.0
New features:
- Added flax.jax_utils.pad_shard_unpad() by @lucasb-eyer
- Implemented default dtype FLIP. This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32. This is especially useful for dealing with complex numbers because the standard Modules will no longer truncate complex numbers to their real component by default. Instead, the complex dtype is preserved by default.
Bug fixes:
- Fix support for JAX's experimental_name_stack.
Breaking changes:
- In rare cases the dtype of a layer can change due to default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.