Skip to content

Commit

Permalink
Merge pull request #28 from raoulritter/main
Browse files Browse the repository at this point in the history
[Model Support] FLUX.1-dev
  • Loading branch information
arda-argmax authored Sep 9, 2024
2 parents c8083cc + 5cdda4d commit bfbdd0e
Show file tree
Hide file tree
Showing 9 changed files with 71 additions and 15 deletions.
6 changes: 0 additions & 6 deletions .flake8

This file was deleted.

1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ __pycache__/
# Distribution / packaging
.Python
build/
.build/
develop-eggs/
dist/
downloads/
Expand Down
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@ pip install -e .
<summary> Click to expand </summary>


[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint. Once you accept the terms, sign in with your Hugging Face hub READ token as below:
[Stable Diffusion 3](https://huggingface.co/stabilityai/stable-diffusion-3-medium) requires users to accept the terms before downloading the checkpoint.

[FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) also requires users to accept the terms before downloading the checkpoint.

Once you accept the terms, sign in with your Hugging Face hub READ token as below:
> [!IMPORTANT]
> If using a fine-grained token, it is also necessary to [edit permissions](https://huggingface.co/settings/tokens) to allow `Read access to contents of all public gated repos you can access`
Expand Down Expand Up @@ -89,6 +93,8 @@ Some notable optional arguments for:

Please refer to the help menu for all available arguments: `diffusionkit-cli -h`.

Note: When using `FLUX.1-dev`, verify you've accepted the [FLUX.1-dev licence](https://huggingface.co/black-forest-labs/FLUX.1-dev) and have allowed gated access on your [HuggingFace token](https://huggingface.co/settings/tokens)

### Code ###

For Stable Diffusion 3:
Expand All @@ -109,7 +115,7 @@ For FLUX:
from diffusionkit.mlx import FluxPipeline
pipeline = FluxPipeline(
shift=1.0,
model_version="argmaxinc/mlx-FLUX.1-schnell",
model_version="argmaxinc/mlx-FLUX.1-schnell", # model_version="argmaxinc/mlx-FLUX.1-dev" for FLUX.1-dev
low_memory_mode=True,
a16=True,
w16=True,
Expand All @@ -120,7 +126,7 @@ Finally, to generate the image, use the `generate_image()` function:
```python
HEIGHT = 512
WIDTH = 512
NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3
NUM_STEPS = 4 # 4 for FLUX.1-schnell, 50 for SD3 and FLUX.1-dev
CFG_WEIGHT = 0. # for FLUX.1-schnell, 5. for SD3

image, _ = pipeline.generate_image(
Expand Down
6 changes: 5 additions & 1 deletion python/src/diffusionkit/mlx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,14 @@
"sd3-8b-unreleased": "models/sd3_8b_beta.safetensors", # unreleased
"argmaxinc/mlx-FLUX.1-schnell": "argmaxinc/mlx-FLUX.1-schnell",
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "argmaxinc/mlx-FLUX.1-schnell-4bit-quantized",
"argmaxinc/mlx-FLUX.1-dev": "argmaxinc/mlx-FLUX.1-dev",
}

T5_MAX_LENGTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"argmaxinc/mlx-FLUX.1-schnell": 256,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 256,
"argmaxinc/mlx-FLUX.1-dev": 512,
}


Expand Down Expand Up @@ -653,7 +655,9 @@ def encode_text(
text,
(negative_text if cfg_weight > 1 else None),
)
padded_tokens_t5 = mx.zeros((1, 256)).astype(tokens_t5.dtype)
padded_tokens_t5 = mx.zeros((1, T5_MAX_LENGTH[self.model_version])).astype(
tokens_t5.dtype
)
padded_tokens_t5[:, : tokens_t5.shape[1]] = tokens_t5[
[0], :
] # Ignore negative text
Expand Down
18 changes: 18 additions & 0 deletions python/src/diffusionkit/mlx/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def hidden_size(self) -> int:

low_memory_mode: bool = True

guidance_embed: bool = False


SD3_8b = MMDiTConfig(depth_multimodal=38, num_heads=3, upcast_multimodal_blocks=[35])

Expand All @@ -90,6 +92,22 @@ def hidden_size(self) -> int:
dtype=mx.bfloat16,
)

FLUX_DEV = MMDiTConfig(
num_heads=24,
depth_multimodal=19,
depth_unified=38,
parallel_mlp_for_unified_blocks=True,
hidden_size_override=3072,
patchify_via_reshape=True,
pos_embed_type=PositionalEncoding.PreSDPARope,
rope_axes_dim=(16, 56, 56),
pooled_text_embed_dim=768, # CLIP-L/14 only
use_qk_norm=True,
float16_dtype=mx.bfloat16,
guidance_embed=True,
dtype=mx.bfloat16,
)


@dataclass
class AutoencoderConfig:
Expand Down
24 changes: 23 additions & 1 deletion python/src/diffusionkit/mlx/mmdit.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,13 @@ def __init__(self, config: MMDiTConfig):
super().__init__()
self.config = config

if config.guidance_embed:
self.guidance_in = MLPEmbedder(
in_dim=config.frequency_embed_dim, hidden_dim=config.hidden_size
)
else:
self.guidance_in = nn.Identity()

# Input adapters and embeddings
self.x_embedder = LatentImageAdapter(config)

Expand Down Expand Up @@ -209,6 +216,9 @@ def __call__(
else:
positional_encodings = None

if self.config.guidance_embed:
timestep = self.guidance_in(self.t_embedder(timestep))

# MultiModalTransformer layers
if self.config.depth_multimodal > 0:
for bidx, block in enumerate(self.multimodal_transformer_blocks):
Expand Down Expand Up @@ -236,7 +246,6 @@ def __call__(
:, token_level_text_embeddings.shape[1] :, ...
]

# Final layer
latent_image_embeddings = self.final_layer(
latent_image_embeddings,
timestep,
Expand Down Expand Up @@ -933,6 +942,19 @@ def apply(q_or_k: mx.array, rope: mx.array) -> mx.array:
)


class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.SiLU(),
nn.Linear(hidden_dim, hidden_dim),
)

def __call__(self, x):
return self.mlp(x)


def affine_transform(
x: mx.array,
shift: mx.array,
Expand Down
10 changes: 9 additions & 1 deletion python/src/diffusionkit/mlx/model_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": "flux-schnell-4bit-quantized.safetensors",
"vae": "ae.safetensors",
},
"argmaxinc/mlx-FLUX.1-dev": {
"argmaxinc/mlx-FLUX.1-dev": "flux1-dev.safetensors",
"vae": "ae.safetensors",
},
}
_DEFAULT_MODEL = "argmaxinc/stable-diffusion"
_MODELS = {
Expand Down Expand Up @@ -75,6 +79,10 @@
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
"argmaxinc/mlx-FLUX.1-dev": {
"vae_encoder": "encoder.",
"vae_decoder": "decoder.",
},
}

_FLOAT16 = mx.bfloat16
Expand Down Expand Up @@ -704,7 +712,7 @@ def load_flux(
hf_hub_download(key, "config.json")
weights = mx.load(flux_weights_ckpt)

if model_key == "argmaxinc/mlx-FLUX.1-schnell":
if model_key in ["argmaxinc/mlx-FLUX.1-schnell", "argmaxinc/mlx-FLUX.1-dev"]:
weights = flux_state_dict_adjustments(
weights,
prefix="",
Expand Down
5 changes: 4 additions & 1 deletion python/src/diffusionkit/mlx/scripts/generate_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,21 @@
"sd3-8b-unreleased": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
WIDTH = {
"argmaxinc/mlx-stable-diffusion-3-medium": 512,
"sd3-8b-unreleased": 1024,
"argmaxinc/mlx-FLUX.1-schnell": 512,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 512,
"argmaxinc/mlx-FLUX.1-dev": 512,
}
SHIFT = {
"argmaxinc/mlx-stable-diffusion-3-medium": 3.0,
"sd3-8b-unreleased": 3.0,
"argmaxinc/mlx-FLUX.1-schnell": 1.0,
"argmaxinc/mlx-FLUX.1-schnell-4bit-quantized": 1.0,
"argmaxinc/mlx-FLUX.1-dev": 1.0,
}


Expand Down Expand Up @@ -111,7 +114,7 @@ def cli():
args.a16 = True

if "FLUX" in args.model_version and args.cfg > 0.0:
logger.warning("Disabling CFG for FLUX.1-schnell model.")
logger.warning(f"Disabling CFG for {args.model_version} model.")
args.cfg = 0.0

if args.benchmark_mode:
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from setuptools import find_packages, setup
from setuptools.command.install import install

VERSION = "0.3.5"
VERSION = "0.4.0"


class VersionInstallCommand(install):
Expand All @@ -29,7 +29,7 @@ def run(self):
"argmaxtools>=0.1.13",
"torch",
"safetensors",
"mlx>=0.16.3",
"mlx>=0.17.1",
"jaxtyping",
"transformers",
"pillow",
Expand Down

0 comments on commit bfbdd0e

Please sign in to comment.