Skip to content

Commit

Permalink
Add clipUtils (#2473)
Browse files Browse the repository at this point in the history
* Add clipUtils

* Update vit_layers.py

* Update vit_layers.py

* Update MMDit_block.py

* Update clip_utils.py

* import fix

* Update vit_layers.py

* import layer fix

* import layer fix

* fix imports

* import changes

* fix failing tests

* nit

* nit

* fix pylint errors

* add Copyright

* nit

---------

Co-authored-by: Divyashree Sreepathihalli <[email protected]>
  • Loading branch information
sachinprasadhs and divyashreepathihalli authored Aug 13, 2024
1 parent a452c60 commit e9c0727
Show file tree
Hide file tree
Showing 9 changed files with 365 additions and 58 deletions.
42 changes: 25 additions & 17 deletions keras_cv/src/layers/vit_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@
import math

import tensorflow as tf
from keras import layers
from keras import ops

from keras_cv.api_export import keras_cv_export
from keras_cv.src.api_export import keras_cv_export
from keras_cv.src.backend import keras
from keras_cv.src.backend import ops


@keras_cv_export("keras_cv.layers.PatchingAndEmbedding")
class PatchingAndEmbedding(layers.Layer):
class PatchingAndEmbedding(keras.layers.Layer):
"""
Layer to patchify images, prepend a class token, positionally embed and
create a projection of patches for Vision Transformers
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(self, project_dim, patch_size, padding="VALID", **kwargs):
f"Padding must be either 'SAME' or 'VALID', but {padding} was "
"passed."
)
self.projection = layers.Conv2D(
self.projection = keras.layers.Conv2D(
filters=self.project_dim,
kernel_size=self.patch_size,
strides=self.patch_size,
Expand All @@ -88,7 +88,7 @@ def build(self, input_shape):
* input_shape[2]
// self.patch_size
)
self.position_embedding = layers.Embedding(
self.position_embedding = keras.layers.Embedding(
input_dim=self.num_patches + 1, output_dim=self.project_dim
)

Expand Down Expand Up @@ -225,32 +225,40 @@ def get_config(self):


@keras_cv_export("keras_cv.layers.Unpatching")
class Unpatching(layers.Layer):
class Unpatching(keras.layers.Layer):
"""
Layer to unpatchify image data.
This layer expects patches sorted by column and reorganizes the patches such that they will each be positioned as a
2D shape with some number of channels.
This layer expects patches sorted by column and reorganizes the patches
such that they will each be positioned as a 2D shape with some
number of channels.
Any necessary padding or truncation will be applied to reach the target shape.
Any necessary padding or truncation will be applied to reach the target
shape.
Args:
target_shape: The target image shape after unpatching, of form [height, width]
target_shape: The target image shape after unpatching,
of form [height, width]
"""

def __init__(self, target_shape):
self.target_shape = target_shape

def call(self, patches):
"""
Reconstructs an unpatched image from the sequence of column sequence patches.
Reconstructs an unpatched image from the sequence of column sequence
patches.
If there are insufficient patches to construct the image of requested dimensions, additional zero-patches will
be appended. If excessive patches are provided, unnecessary patches will be truncated from the end.
If there are insufficient patches to construct the image of requested
dimensions, additional zero-patches will
be appended. If excessive patches are provided, unnecessary patches
will be truncated from the end.
Args:
patches: Patches of images in column sequence (i.e. each patch is vertically oriented relative to the
previous patch). Expected shape of [batch_size, patch_num, patch_height, patch_width, channels].
patches: Patches of images in column sequence (i.e. each patch
is vertically oriented relative to the
previous patch). Expected shape of [batch_size, patch_num,
patch_height, patch_width, channels].
Returns:
Unpatched image: Image reconstructed from the patches,
Expand Down Expand Up @@ -284,4 +292,4 @@ def call(self, patches):
else:
corrected_patches = patches[:, :required_patches]

return ops.split(corrected_patches, patches_per_column, axis=1)
return ops.split(corrected_patches, patches_per_column, axis=1)
22 changes: 18 additions & 4 deletions keras_cv/src/models/stable_diffusion_v3/MMDit.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
from keras_cv.backend import keras
from keras_cv.layers.vit_layers import PatchingAndEmbedding
from keras_cv.models.stable_diffusion.v3 import embedding
from keras_cv.models.stable_diffusion.v3.MMDiT_block import MMDiTBlock
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.src.backend import keras
from keras_cv.src.layers.vit_layers import PatchingAndEmbedding
from keras_cv.src.models.stable_diffusion_v3 import embedding
from keras_cv.src.models.stable_diffusion_v3.MMDit_block import MMDiTBlock


class MMDiT(keras.layers.Layer):
Expand Down
21 changes: 18 additions & 3 deletions keras_cv/src/models/stable_diffusion_v3/MMDit_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
import keras
from keras_cv.backend import ops
# Copyright 2024 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_cv.src.backend import keras
from keras_cv.src.backend import ops


class MMDiTSelfAttention(keras.layers.Layer):
Expand All @@ -17,7 +31,8 @@ def __init__(
self.cdense = keras.layers.Dense(key_dim)

if normalization_mode == "rms_normalization":
# TODO(varuns1997): Re-Implement RMSNormalization for Keras 2 Compatibility
# TODO(varuns1997): Re-Implement RMSNormalization
# for Keras 2 Compatibility
self.query_normalization = keras.layers.LayerNormalization(
rms_scaling=True
)
Expand Down
13 changes: 13 additions & 0 deletions keras_cv/src/models/stable_diffusion_v3/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2023 The KerasCV Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit e9c0727

Please sign in to comment.