[Flux] Port Flux Core Model (#1864)
* starter commit - ported time embeddings to keras ops (sketched below, after this list)

* add mlpembedder

* add RMS Norm re-implementation

* add qknorm reimplementation

* add rope, scaled dot product attention and self attention

* modulation layer

* fix typing

* add double stream block

* adjustments to doublestreamblock

* add single stream layer

* update layers and add flux core model

* functions to layers

* refactor layer usage

* refactor layer usage

* position math args in call()

* name arguments

* fix arg name

* start adding conversion script utils

* change reshape into rearrange

* add rest of weight conversion and remove redundant shape extraction

* fix mlpembedder arg

* remove redundant args

* fix params. to self.

* add license

* add einops

* fix default arg

* expand docstrings

* tanh to gelu

* refactor weight conversion into tools

* update weight conversion

* add stand-in presets until weights are uploaded

* set float32 to t.dtype in timestep embedding

* update more float32s into dynamic types

* dtype

* dtype

* enable float16 mode

* update conversion script to not require flux repo

* add build() methods to avoid running dummy input through model

* update build call

* fix build calls

* style

* change dummy call into build() call

* reference einops issue

* address docstring comments in flux layers

* address docstring comments in flux maths

* remove numpy

* add docstrings for flux model

* qkv bias -> use_bias

* docstring updates

* remove type hints

* all img->image, txt->text

* functional subclassing model

* shape fixes

* format

* self.hidden_size -> self.dim

* einops rearrange

* remove build method

* ops to rearrange

* remove build

* rearrange -> symbolic_rearrange

* turn timesteps and guidance into inputs

* basic preprocessor flow

* refactor layer names in conversion script

* add backbone tests

* raise not implemented on encode, encode_text, etc. methods

* styling

* fix shape hack with a cleaner alternative

* remove unused attributes, fix tests

* change list into tuple for the expected shape

* address comments

* save model on conversion
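
The time-embedding port mentioned in the first commit (and revisited in the later dtype commits) is the sinusoidal timestep embedding used by Flux. Below is a minimal sketch of such a port in `keras.ops`; the function name and the `max_period`/`time_factor` parameters follow the reference Flux implementation, but this is an illustration of the technique, not the committed code:

import math

from keras import ops


def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    # Scale continuous timesteps up, as the reference Flux code does.
    t = time_factor * t
    half = dim // 2
    # Geometrically spaced frequencies, cast to t.dtype rather than
    # hard-coded float32 (cf. the "set float32 to t.dtype" commit above).
    freqs = ops.exp(
        -math.log(max_period) * ops.cast(ops.arange(half), t.dtype) / half
    )
    args = ops.expand_dims(t, -1) * ops.expand_dims(freqs, 0)
    # Concatenate cosine and sine components into a (batch, dim)
    # embedding, assuming `dim` is even.
    return ops.concatenate([ops.cos(args), ops.sin(args)], axis=-1)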
DavidLandup0 authored Nov 13, 2024
1 parent 893a4db commit 0756fb4
Showing 14 changed files with 1,586 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras_hub/api/layers/__init__.py
@@ -14,6 +14,7 @@
from keras_hub.src.layers.modeling.reversible_embedding import (
    ReversibleEmbedding,
)
from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization
from keras_hub.src.layers.modeling.rotary_embedding import RotaryEmbedding
from keras_hub.src.layers.modeling.sine_position_encoding import (
    SinePositionEncoding,
)
5 changes: 5 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -162,6 +162,11 @@
)
from keras_hub.src.models.falcon.falcon_tokenizer import FalconTokenizer
from keras_hub.src.models.feature_pyramid_backbone import FeaturePyramidBackbone
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.flux.flux_text_to_image import FluxTextToImage
from keras_hub.src.models.flux.flux_text_to_image_preprocessor import (
    FluxTextToImagePreprocessor,
)
from keras_hub.src.models.gemma.gemma_backbone import GemmaBackbone
from keras_hub.src.models.gemma.gemma_causal_lm import GemmaCausalLM
from keras_hub.src.models.gemma.gemma_causal_lm_preprocessor import (
    GemmaCausalLMPreprocessor,
)
34 changes: 34 additions & 0 deletions keras_hub/src/layers/modeling/rms_normalization.py
@@ -0,0 +1,34 @@
import keras
from keras import ops

from keras_hub.src.api_export import keras_hub_export


@keras_hub_export("keras_hub.layers.RMSNormalization")
class RMSNormalization(keras.layers.Layer):
    """Root Mean Square (RMS) Normalization layer.

    This layer normalizes the input tensor based on its RMS value and
    applies a learned scaling factor.

    Args:
        input_dim: int. The dimensionality of the input tensor.
    """

    def __init__(self, input_dim):
        super().__init__()
        self.scale = self.add_weight(
            name="scale", shape=(input_dim,), initializer="ones"
        )

    def call(self, x):
        """Applies RMS normalization to the input tensor.

        Args:
            x: KerasTensor. Input tensor of shape `(batch_size, input_dim)`.

        Returns:
            KerasTensor: The RMS-normalized tensor of the same shape
            `(batch_size, input_dim)`, scaled by the learned `scale`
            parameter.
        """
        x = ops.cast(x, float)
        rrms = ops.rsqrt(
            ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6
        )
        return (x * rrms) * self.scale
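
For reference, a minimal usage sketch of the new layer (the shapes here are arbitrary; the normalization runs over the last axis, so the `(input_dim,)` scale broadcasts across any leading dimensions):

import keras

from keras_hub.src.layers.modeling.rms_normalization import RMSNormalization

layer = RMSNormalization(input_dim=64)
x = keras.random.normal((2, 10, 64))
y = layer(x)  # same shape as x: (2, 10, 64), normalized over the last axis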
5 changes: 5 additions & 0 deletions keras_hub/src/models/flux/__init__.py
@@ -0,0 +1,5 @@
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.flux.flux_presets import presets
from keras_hub.src.utils.preset_utils import register_presets

register_presets(presets, FluxBackbone)
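
Registering the presets wires the backbone into the standard `from_preset` flow. A sketch of what loading would look like — the preset name below is a placeholder, since this commit only adds stand-in presets and the weights were not yet uploaded:

import keras_hub

# "flux_schnell" is a hypothetical preset handle, not a name confirmed by
# this commit; see keras_hub/src/models/flux/flux_presets.py for the
# actual registered presets.
backbone = keras_hub.models.FluxBackbone.from_preset("flux_schnell")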
73 changes: 73 additions & 0 deletions keras_hub/src/models/flux/flux_backbone_test.py
@@ -0,0 +1,73 @@
import pytest
from keras import ops

from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder
from keras_hub.src.models.flux.flux_model import FluxBackbone
from keras_hub.src.models.vae.vae_backbone import VAEBackbone
from keras_hub.src.tests.test_case import TestCase


class FluxBackboneTest(TestCase):
    def setUp(self):
        vae = VAEBackbone(
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            # Use `mode` to generate a deterministic output.
            sampler_method="mode",
            name="vae",
        )
        clip_l = CLIPTextEncoder(
            20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l"
        )
        self.init_kwargs = {
            "input_channels": 256,
            "hidden_size": 1024,
            "mlp_ratio": 2.0,
            "num_heads": 8,
            "depth": 4,
            "depth_single_blocks": 8,
            "axes_dim": [16, 56, 56],
            "theta": 10_000,
            "use_bias": True,
            "guidance_embed": True,
            "image_shape": (32, 256),
            "text_shape": (32, 256),
            "image_ids_shape": (32, 3),
            "text_ids_shape": (32, 3),
            "y_shape": (256,),
        }

        self.pipeline_models = {
            "vae": vae,
            "clip_l": clip_l,
        }

        self.input_data = {
            "image": ops.ones((1, 32, 256)),
            "image_ids": ops.ones((1, 32, 3)),
            "text": ops.ones((1, 32, 256)),
            "text_ids": ops.ones((1, 32, 3)),
            "y": ops.ones((1, 256)),
            "timesteps": ops.ones((1,)),
            "guidance": ops.ones((1,)),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(1, 32, 256),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
(Diffs for the remaining 9 of the 14 changed files are not shown.)