
FLUX-Flax

A JAX port of FLUX.1 models using flax.nnx.

Important

The current codebase stays close to the original PyTorch implementation, with minimal modifications. It works as expected but may not be the most efficient implementation. I plan to release an updated version soon that better follows JAX conventions and best practices.


Status

Only tested on GPU so far.

There is currently no quantization support and no torch-style CPU offloading.

PRs are welcome.

Local installation

git clone https://github.com/lkwq007/flux-flax.git
cd flux-flax
mamba create -p ./env python=3.10
mamba activate ./env
pip install -r requirements.txt
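Because the code has only been tested on GPU, it may help to confirm after installation that JAX actually picked up a GPU backend. A minimal check, assuming a standard JAX install (it uses only `jax.default_backend()` and `jax.devices()`):

```python
import jax

# Platform JAX will use by default, e.g. "gpu", "tpu", or "cpu".
backend = jax.default_backend()
devices = jax.devices()

print(f"default backend: {backend}")
for d in devices:
    print(f"  {d.platform}: {d}")

if backend == "cpu":
    print("warning: no GPU visible to JAX; this project has only been tested on GPU")
```

If this reports `cpu`, the installed jaxlib most likely lacks GPU support, and the JAX installation guide is the place to look.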

Usage

For interactive sampling, run:

python main.py --name <name>

Or, to generate a single sample, run the following (not recommended, since JIT compilation takes time and is repeated on every run):

python main.py --name <name> \
  --height <height> --width <width> --nonloop \
  --prompt "<prompt>"
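The JIT cost comes from how `jax.jit` works: a function is traced and compiled the first time it is called with a given combination of input shapes and dtypes, and the compiled result is cached for later calls in the same process. A long-running interactive session therefore pays this cost once per shape, while a one-shot `--nonloop` run pays it on every invocation (absent a persistent compilation cache). The toy function below counts its own traces to make this visible. `denoise_step` is a hypothetical stand-in, not a function from this repo, and the idea that a new `--height`/`--width` changes input shapes and triggers recompilation is an assumption about `main.py`.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def denoise_step(x, t):
    # Python side effects run only while JAX traces the function,
    # so this counter records how many times a (re)compile was triggered.
    global trace_count
    trace_count += 1
    return x * (1.0 - t)


x_small = jnp.ones((1, 64, 64))
denoise_step(x_small, 0.5)   # first call: trace + compile
denoise_step(x_small, 0.25)  # same shapes/dtypes: reuses the cached executable
print(trace_count)  # 1

x_large = jnp.ones((1, 128, 128))
denoise_step(x_large, 0.5)   # new input shape: traced and compiled again
print(trace_count)  # 2
```

Python floats such as `t` are traced as weakly typed scalars, so changing their value does not by itself trigger a recompile; changing array shapes does.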