-
Notifications
You must be signed in to change notification settings - Fork 243
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* starter commit - ported time embeddings to keras ops * add mlpembedder * add RMS Norm re-implementation * add qknorm reimplementation * add rope, scaled dot product attention and self attention * modulation layer * fix typing * add double stream block * adjustments to doublestreamblock * add signle stream layer@ * update layers and add flux core model * functions to layers * refactor layer usage * refactor layer usage * position math args in call() * name arguments * fix arg name * start adding conversion script utils * change reshape into rearrange * add rest of weight conversion and remove redundant shape extraction * fix mlpembedder arg * remove redundant args * fix params. to self. * add license * add einops * fix default arg * expand docstrings * tanh to gelu * refactor weight conversion into tools * update weight conversion * add stand-in presets until weights are uploaded * set float32 to t.dtype in timestep embedding * update more float32s into dynamic types * dtype * dtype * enable float16 mode * update conversion script to not require flux repo * add build() methods to avoid running dummy input through model * update build call * fix build calls * style * change dummy call into build() call * reference einops issue * address docstring comments in flux layers * address docstring comments in flux maths * remove numpy * add docstrings for flux model * qkv bias -> use_bias * docstring updates * remove type hints * all img->image, txt->text * functional subclassing model * shape fixes * format * self.hidden_size -> self.dim * einops rearrange * remove build method * ops to rearrange * remove build * rearrange -> symbolic_rearrange * turn timesteps and guidance into inputs * basic preprocessor flow * refactor layer names in conversion script * add backbone tests * raise not implemented on encode, encode_text, etc. 
methods * styling * fix shape hack with a cleaner alternative * remove unused attributes, fix tests * change list into tuple for the expected shape * address comments * save model on conversion
- Loading branch information
1 parent
893a4db
commit 0756fb4
Showing
14 changed files
with
1,586 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import keras | ||
from keras import ops | ||
|
||
from keras_hub.src.api_export import keras_hub_export | ||
|
||
|
||
@keras_hub_export("keras_hub.layers.RMSNormalization")
class RMSNormalization(keras.layers.Layer):
    """Root Mean Square (RMS) Normalization layer.

    This layer normalizes the input tensor based on its RMS value and
    applies a learned per-feature scaling factor.

    Args:
        input_dim: int. The dimensionality of the input tensor.
        **kwargs: additional keyword arguments forwarded to
            `keras.layers.Layer` (e.g. `name`, `dtype`).
    """

    def __init__(self, input_dim, **kwargs):
        # Forward **kwargs so callers can set `name`, `dtype`, etc.
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.scale = self.add_weight(
            name="scale", shape=(input_dim,), initializer="ones"
        )

    def call(self, x):
        """Applies RMS normalization to the input tensor.

        Args:
            x: KerasTensor. Input tensor of shape `(batch_size, input_dim)`.

        Returns:
            KerasTensor: The RMS-normalized tensor of the same shape,
            scaled by the learned `scale` parameter.
        """
        # Cast to float32 for a numerically stable reduction. The previous
        # `ops.cast(x, float)` standardized the Python builtin `float` to
        # float64, silently upcasting activations and breaking float16 /
        # mixed-precision execution.
        x = ops.cast(x, "float32")
        # Reciprocal RMS over the feature axis; 1e-6 guards against
        # division by zero for all-zero rows.
        rrms = ops.rsqrt(ops.mean(ops.square(x), axis=-1, keepdims=True) + 1e-6)
        return (x * rrms) * self.scale

    def get_config(self):
        # Required so the layer survives model serialization round-trips.
        config = super().get_config()
        config.update({"input_dim": self.input_dim})
        return config
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from keras_hub.src.models.flux.flux_model import FluxBackbone | ||
from keras_hub.src.models.flux.flux_presets import presets | ||
from keras_hub.src.utils.preset_utils import register_presets | ||
|
||
# Attach the Flux preset configurations to `FluxBackbone` at import time so
# presets are discoverable through the standard preset-loading machinery.
register_presets(presets, FluxBackbone)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
import pytest | ||
from keras import ops | ||
|
||
from keras_hub.src.models.clip.clip_text_encoder import CLIPTextEncoder | ||
from keras_hub.src.models.flux.flux_model import FluxBackbone | ||
from keras_hub.src.models.vae.vae_backbone import VAEBackbone | ||
from keras_hub.src.tests.test_case import TestCase | ||
|
||
|
||
class FluxBackboneTest(TestCase):
    """Unit tests for `FluxBackbone` construction, forward pass, and saving."""

    def setUp(self):
        # Deliberately tiny VAE and CLIP encoders so the fixtures stay cheap.
        vae = VAEBackbone(
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            [32, 32, 32, 32],
            [1, 1, 1, 1],
            # Use `mode` to generate a deterministic output.
            sampler_method="mode",
            name="vae",
        )
        clip_l = CLIPTextEncoder(
            20, 32, 32, 2, 2, 64, "quick_gelu", -2, name="clip_l"
        )
        self.init_kwargs = {
            "input_channels": 256,
            "hidden_size": 1024,
            "mlp_ratio": 2.0,
            "num_heads": 8,
            "depth": 4,
            "depth_single_blocks": 8,
            "axes_dim": [16, 56, 56],
            "theta": 10_000,
            "use_bias": True,
            "guidance_embed": True,
            "image_shape": (32, 256),
            "text_shape": (32, 256),
            "image_ids_shape": (32, 3),
            "text_ids_shape": (32, 3),
            "y_shape": (256,),
        }

        self.pipeline_models = {
            "vae": vae,
            "clip_l": clip_l,
        }

        self.input_data = {
            "image": ops.ones((1, 32, 256)),
            "image_ids": ops.ones((1, 32, 3)),
            "text": ops.ones((1, 32, 256)),
            "text_ids": ops.ones((1, 32, 3)),
            "y": ops.ones((1, 256)),
            # `(1)` is just the int 1, not a tuple — spell the shape as
            # `(1,)` instead of relying on `ops.ones` tolerating ints.
            "timesteps": ops.ones((1,)),
            "guidance": ops.ones((1,)),
        }

    def test_backbone_basics(self):
        self.run_backbone_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
            expected_output_shape=(1, 32, 256),
            run_mixed_precision_check=False,
            run_quantization_check=False,
        )

    @pytest.mark.large
    def test_saved_model(self):
        self.run_model_saving_test(
            cls=FluxBackbone,
            init_kwargs=self.init_kwargs,
            input_data=self.input_data,
        )
Oops, something went wrong.