diff --git a/keras_hub/api/layers/__init__.py b/keras_hub/api/layers/__init__.py
index 78a26075d1..2060fb0da3 100644
--- a/keras_hub/api/layers/__init__.py
+++ b/keras_hub/api/layers/__init__.py
@@ -41,6 +41,9 @@
     DenseNetImageConverter,
 )
 from keras_hub.src.models.mit.mit_image_converter import MiTImageConverter
+from keras_hub.src.models.mobilenet.mobilenet_image_converter import (
+    MobileNetImageConverter,
+)
 from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
     PaliGemmaImageConverter,
 )
diff --git a/keras_hub/api/models/__init__.py b/keras_hub/api/models/__init__.py
index e0e8773a35..88aa733c78 100644
--- a/keras_hub/api/models/__init__.py
+++ b/keras_hub/api/models/__init__.py
@@ -211,6 +211,9 @@
 from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
     MobileNetImageClassifier,
 )
+from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import (
+    MobileNetImageClassifierPreprocessor,
+)
 from keras_hub.src.models.opt.opt_backbone import OPTBackbone
 from keras_hub.src.models.opt.opt_causal_lm import OPTCausalLM
 from keras_hub.src.models.opt.opt_causal_lm_preprocessor import (
diff --git a/keras_hub/src/models/mobilenet/mobilenet_backbone.py b/keras_hub/src/models/mobilenet/mobilenet_backbone.py
index c4fe6f3413..e40eac32b1 100644
--- a/keras_hub/src/models/mobilenet/mobilenet_backbone.py
+++ b/keras_hub/src/models/mobilenet/mobilenet_backbone.py
@@ -4,8 +4,8 @@
 from keras_hub.src.api_export import keras_hub_export
 from keras_hub.src.models.backbone import Backbone
 
-BN_EPSILON = 1e-3
-BN_MOMENTUM = 0.999
+BN_EPSILON = 1e-5
+BN_MOMENTUM = 0.9
 
 
 @keras_hub_export("keras_hub.models.MobileNetBackbone")
@@ -29,29 +29,24 @@ class MobileNetBackbone(Backbone):
     (ICCV 2019)
 
     Args:
-        stackwise_expansion: list of ints or floats, the expansion ratio for
+        stackwise_expansion: list of list of ints, the expanded filters for
+            each inverted residual block in each stack of the model.
+        stackwise_num_blocks: list of ints, number of inverted residual blocks
+            per stack.
+        stackwise_num_filters: list of list of ints, number of filters for
             each inverted residual block in the model.
-        stackwise_num_filters: list of ints, number of filters for each inverted
-            residual block in the model.
-        stackwise_kernel_size: list of ints, kernel size for each inverted
-            residual block in the model.
-        stackwise_num_strides: list of ints, stride length for each inverted
-            residual block in the model.
+        stackwise_kernel_size: list of list of ints, kernel size for each
+            inverted residual block in the model.
+        stackwise_num_strides: list of list of ints, stride length for each
+            inverted residual block in the model.
         stackwise_se_ratio: se ratio for each inverted residual block in the
             model. 0 if dont want to add Squeeze and Excite layer.
-        stackwise_activation: list of activation functions, for each inverted
-            residual block in the model.
+        stackwise_activation: list of list of activation functions, one for
+            each inverted residual block in the model.
         image_shape: optional shape tuple, defaults to (224, 224, 3).
-        depth_multiplier: float, controls the width of the network.
-            - If `depth_multiplier` < 1.0, proportionally decreases the number
-                of filters in each layer.
-            - If `depth_multiplier` > 1.0, proportionally increases the number
-                of filters in each layer.
-            - If `depth_multiplier` = 1, default number of filters from the paper
-                are used at each layer.
         input_num_filters: number of filters in first convolution layer
-        output_num_filters: specifies whether to add conv and batch_norm in the end,
-            if set to None, it will not add these layers in the end.
+        output_num_filters: specifies whether to add conv and batch_norm in the
+            end; if set to None, these layers are not added at the end.
             'None' for MobileNetV1
         input_activation: activation function to be used in the input layer
             'hard_swish' for MobileNetV3,
@@ -59,9 +54,10 @@ class MobileNetBackbone(Backbone):
         output_activation: activation function to be used in the output layer
             'hard_swish' for MobileNetV3,
             'relu6' for MobileNetV1 and MobileNetV2
-        inverted_res_block: whether to use inverted residual blocks or not,
-            'False' for MobileNetV1,
-            'True' for MobileNetV2 and MobileNetV3
+        depthwise_filters: int, number of filters in the depthwise separable
+            convolution layer.
+        squeeze_and_excite: float, squeeze and excite ratio in the depthwise
+            layer; `None` to skip squeeze and excite.
+        stackwise_padding: list of list of ints, zero padding applied before
+            the depthwise convolution in each inverted residual block.
+        last_layer_filter: int, number of filters in the `ConvBnAct` block
+            placed before the output convolution.
 
     Example:
@@ -70,16 +66,40 @@ class MobileNetBackbone(Backbone):
 
     # Randomly initialized backbone with a custom config
     model = MobileNetBackbone(
-        stackwise_expansion=[1, 4, 6],
-        stackwise_num_filters=[4, 8, 16],
-        stackwise_kernel_size=[3, 3, 5],
-        stackwise_num_strides=[2, 2, 1],
-        stackwise_se_ratio=[0.25, None, 0.25],
-        stackwise_activation=["relu", "relu6", "hard_swish"],
-        output_num_filters=1280,
-        input_activation='hard_swish',
-        output_activation='hard_swish',
-        inverted_res_block=True,
+        stackwise_expansion=[
+            [40, 56],
+            [64, 144, 144],
+            [72, 72],
+            [144, 288, 288],
+        ],
+        stackwise_num_blocks=[2, 3, 2, 3],
+        stackwise_num_filters=[
+            [16, 16],
+            [24, 24, 24],
+            [24, 24],
+            [48, 48, 48],
+        ],
+        stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
+        stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
+        stackwise_se_ratio=[
+            [None, None],
+            [0.25, 0.25, 0.25],
+            [0.3, 0.3],
+            [0.3, 0.25, 0.25],
+        ],
+        stackwise_activation=[
+            ["relu", "relu"],
+            ["hard_swish", "hard_swish", "hard_swish"],
+            ["hard_swish", "hard_swish"],
+            ["hard_swish", "hard_swish", "hard_swish"],
+        ],
+        stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
+        output_num_filters=288,
+        input_activation="hard_swish",
+        output_activation="hard_swish",
+        input_num_filters=16,
+        image_shape=(224, 224, 3),
+        depthwise_filters=8,
+        squeeze_and_excite=0.5,
+        last_layer_filter=288,
     )
     output = model(input_data)
@@ -89,17 +109,20 @@ def __init__(
     def __init__(
         self,
         stackwise_expansion,
+        stackwise_num_blocks,
         stackwise_num_filters,
         stackwise_kernel_size,
         stackwise_num_strides,
         stackwise_se_ratio,
         stackwise_activation,
+        stackwise_padding,
         output_num_filters,
-        inverted_res_block,
+        depthwise_filters,
+        last_layer_filter,
+        squeeze_and_excite=None,
         image_shape=(None, None, 3),
         input_activation="hard_swish",
         output_activation="hard_swish",
-        depth_multiplier=1.0,
         input_num_filters=16,
         **kwargs,
     ):
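Every `stackwise_*` argument is now a list of per-stack sublists, and the backbone indexes them in lockstep with `stackwise_num_blocks`, so a length mismatch only surfaces as an `IndexError` deep inside layer construction. A minimal sketch of the kind of up-front check a caller could run; the `check_stackwise_config` helper is hypothetical, not part of this change:

```python
def check_stackwise_config(stackwise_num_blocks, **stackwise_args):
    # Each per-stack sublist must have one entry per inverted residual block.
    for name, arg in stackwise_args.items():
        if len(arg) != len(stackwise_num_blocks):
            raise ValueError(
                f"`{name}` needs {len(stackwise_num_blocks)} sublists, "
                f"got {len(arg)}."
            )
        for i, (sub, n) in enumerate(zip(arg, stackwise_num_blocks)):
            if len(sub) != n:
                raise ValueError(
                    f"`{name}[{i}]` needs {n} entries, got {len(sub)}."
                )


check_stackwise_config(
    stackwise_num_blocks=[2, 3, 2, 3],
    stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
    stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
)
```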
@@ -109,13 +132,20 @@ def __init__(
         )
 
         image_input = keras.layers.Input(shape=image_shape)
-        x = image_input  # Intermediate result.
+        x = image_input
 
         input_num_filters = adjust_channels(input_num_filters)
+
+        pad_width = (
+            (0, 0),  # No padding for batch.
+            (1, 1),  # 1 pixel padding for height.
+            (1, 1),  # 1 pixel padding for width.
+            (0, 0),  # No padding for channels.
+        )
+        x = ops.pad(x, pad_width=pad_width)
         x = keras.layers.Conv2D(
             input_num_filters,
             kernel_size=3,
             strides=(2, 2),
-            padding="same",
             data_format=keras.config.image_data_format(),
             use_bias=False,
             name="input_conv",
@@ -128,66 +158,72 @@ def __init__(
         )(x)
         x = keras.layers.Activation(input_activation)(x)
 
-        for stack_index in range(len(stackwise_num_filters)):
-            filters = adjust_channels(
-                (stackwise_num_filters[stack_index]) * depth_multiplier
-            )
+        x = apply_depthwise_conv_block(
+            x, depthwise_filters, se=squeeze_and_excite, name="block_0"
+        )
 
-            if inverted_res_block:
+        for block in range(len(stackwise_num_blocks)):
+            for inverted_block in range(stackwise_num_blocks[block]):
                 x = apply_inverted_res_block(
                     x,
-                    expansion=stackwise_expansion[stack_index],
-                    filters=filters,
-                    kernel_size=stackwise_kernel_size[stack_index],
-                    stride=stackwise_num_strides[stack_index],
-                    se_ratio=(stackwise_se_ratio[stack_index]),
-                    activation=stackwise_activation[stack_index],
-                    expansion_index=stack_index,
-                )
-            else:
-                x = apply_depthwise_conv_block(
-                    x,
-                    filters=filters,
-                    kernel_size=3,
-                    stride=stackwise_num_strides[stack_index],
-                    depth_multiplier=depth_multiplier,
-                    block_id=stack_index,
+                    expansion=stackwise_expansion[block][inverted_block],
+                    filters=adjust_channels(
+                        stackwise_num_filters[block][inverted_block]
+                    ),
+                    kernel_size=stackwise_kernel_size[block][inverted_block],
+                    stride=stackwise_num_strides[block][inverted_block],
+                    se_ratio=stackwise_se_ratio[block][inverted_block],
+                    activation=stackwise_activation[block][inverted_block],
+                    padding=stackwise_padding[block][inverted_block],
+                    name=f"block_{block+1}_{inverted_block}",
                 )
 
-        if output_num_filters is not None:
-            last_conv_ch = adjust_channels(x.shape[channel_axis] * 6)
+        x = ConvBnAct(
+            x,
+            filter=adjust_channels(last_layer_filter),
+            activation="hard_swish",
+            name=f"block_{len(stackwise_num_blocks)+1}_0",
+        )
 
-            x = keras.layers.Conv2D(
-                last_conv_ch,
-                kernel_size=1,
-                padding="same",
-                data_format=keras.config.image_data_format(),
-                use_bias=False,
-                name="output_conv",
-            )(x)
+        last_conv_ch = adjust_channels(output_num_filters)
+
+        x = keras.layers.Conv2D(
+            last_conv_ch,
+            kernel_size=1,
+            data_format=keras.config.image_data_format(),
+            use_bias=False,
+            name="output_conv",
+        )(x)
+
+        # MobileNetV3 applies no normalization on the output conv; V1/V2,
+        # which use a "relu6" output activation, keep the batch norm.
+        if output_activation == "relu6":
             x = keras.layers.BatchNormalization(
                 axis=channel_axis,
                 epsilon=BN_EPSILON,
                 momentum=BN_MOMENTUM,
                 name="output_batch_norm",
             )(x)
-            x = keras.layers.Activation(output_activation)(x)
+
+        x = keras.layers.Activation(output_activation)(x)
 
         super().__init__(inputs=image_input, outputs=x, **kwargs)
 
         # === Config ===
         self.stackwise_expansion = stackwise_expansion
+        self.stackwise_num_blocks = stackwise_num_blocks
         self.stackwise_num_filters = stackwise_num_filters
         self.stackwise_kernel_size = stackwise_kernel_size
         self.stackwise_num_strides = stackwise_num_strides
         self.stackwise_se_ratio = stackwise_se_ratio
         self.stackwise_activation = stackwise_activation
-        self.depth_multiplier = depth_multiplier
+        self.stackwise_padding = stackwise_padding
         self.input_num_filters = input_num_filters
         self.output_num_filters = output_num_filters
+        self.depthwise_filters = depthwise_filters
+        self.last_layer_filter = last_layer_filter
+        self.squeeze_and_excite = squeeze_and_excite
        self.input_activation = keras.activations.get(input_activation)
         self.output_activation = keras.activations.get(output_activation)
-        self.inverted_res_block = inverted_res_block
         self.image_shape = image_shape
 
     def get_config(self):
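The explicit `ops.pad` + `padding="valid"` pattern above replaces `padding="same"`. The two are not equivalent for stride-2 convolutions on even-sized inputs: Keras/TensorFlow `"same"` pads asymmetrically (the extra row and column go on the bottom/right), while the symmetric pad matches PyTorch's `Conv2d(padding=1)`, which is what the timm checkpoints were trained with. A quick sketch of the difference (shapes agree, values generally do not), assuming `channels_last` data:

```python
import numpy as np
import keras
from keras import ops

x = np.random.rand(1, 224, 224, 3).astype("float32")

# "same" with stride 2 pads only the bottom/right edge for even inputs.
same = keras.layers.Conv2D(8, 3, strides=2, padding="same")(x)

# Symmetric 1-pixel pad followed by a "valid" conv, as in the stem above;
# this reproduces PyTorch's Conv2d(..., padding=1) alignment.
padded = ops.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))
valid = keras.layers.Conv2D(8, 3, strides=2, padding="valid")(padded)

print(same.shape, valid.shape)  # both (1, 112, 112, 8)
```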
@@ -195,22 +231,25 @@
         config.update(
             {
                 "stackwise_expansion": self.stackwise_expansion,
+                "stackwise_num_blocks": self.stackwise_num_blocks,
                 "stackwise_num_filters": self.stackwise_num_filters,
                 "stackwise_kernel_size": self.stackwise_kernel_size,
                 "stackwise_num_strides": self.stackwise_num_strides,
                 "stackwise_se_ratio": self.stackwise_se_ratio,
                 "stackwise_activation": self.stackwise_activation,
+                "stackwise_padding": self.stackwise_padding,
                 "image_shape": self.image_shape,
-                "depth_multiplier": self.depth_multiplier,
                 "input_num_filters": self.input_num_filters,
                 "output_num_filters": self.output_num_filters,
+                "depthwise_filters": self.depthwise_filters,
+                "last_layer_filter": self.last_layer_filter,
+                "squeeze_and_excite": self.squeeze_and_excite,
                 "input_activation": keras.activations.serialize(
                     activation=self.input_activation
                 ),
                 "output_activation": keras.activations.serialize(
                     activation=self.output_activation
                 ),
-                "inverted_res_block": self.inverted_res_block,
             }
         )
         return config
@@ -249,7 +288,8 @@ def apply_inverted_res_block(
     stride,
     se_ratio,
     activation,
-    expansion_index,
+    padding,
+    name=None,
 ):
     """An Inverted Residual Block.
@@ -263,9 +303,8 @@ def apply_inverted_res_block(
         se_ratio: float, ratio for bottleneck filters. Number of bottleneck
             filters = filters * se_ratio.
         activation: the activation layer to use.
-        expansion_index: integer, a unique identification if you want to use
-            expanded convolutions. If greater than 0, an additional Conv+BN
-            layer is added after the expanded convolutional layer.
+        padding: int, zero padding applied before the depthwise convolution.
+        name: string, block label.
 
     Returns:
         the updated input tensor.
@@ -275,88 +314,91 @@ def apply_inverted_res_block(
     )
     activation = keras.activations.get(activation)
     shortcut = x
-    prefix = "expanded_conv_"
     infilters = x.shape[channel_axis]
+    expanded_channels = adjust_channels(expansion)
 
-    if expansion_index > 0:
-        prefix = f"expanded_conv_{expansion_index}_"
+    x = keras.layers.Conv2D(
+        expanded_channels,
+        kernel_size=1,
+        data_format=keras.config.image_data_format(),
+        use_bias=False,
+        name=f"{name}_conv1",
+    )(x)
 
-    x = keras.layers.Conv2D(
-        adjust_channels(infilters * expansion),
-        kernel_size=1,
-        padding="same",
-        data_format=keras.config.image_data_format(),
-        use_bias=False,
-        name=prefix + "expand",
-    )(x)
-    x = keras.layers.BatchNormalization(
-        axis=channel_axis,
-        epsilon=BN_EPSILON,
-        momentum=BN_MOMENTUM,
-        name=prefix + "expand_BatchNorm",
-    )(x)
-    x = keras.layers.Activation(activation=activation)(x)
+    x = keras.layers.BatchNormalization(
+        axis=channel_axis,
+        epsilon=BN_EPSILON,
+        momentum=BN_MOMENTUM,
+        name=f"{name}_bn1",
+    )(x)
 
-    if stride == 2:
-        x = keras.layers.ZeroPadding2D(
-            padding=correct_pad_downsample(x, kernel_size),
-            name=prefix + "depthwise_pad",
-        )(x)
+    x = keras.layers.Activation(activation=activation)(x)
 
-    x = keras.layers.DepthwiseConv2D(
+    pad_width = (
+        (0, 0),  # No padding for batch.
+        (padding, padding),  # `padding` pixels for height.
+        (padding, padding),  # `padding` pixels for width.
+        (0, 0),  # No padding for channels.
+    )
+    x = ops.pad(x, pad_width=pad_width)
+
+    x = keras.layers.Conv2D(
+        expanded_channels,
         kernel_size,
         strides=stride,
-        padding="same" if stride == 1 else "valid",
+        padding="valid",
+        groups=expanded_channels,
         data_format=keras.config.image_data_format(),
         use_bias=False,
-        name=prefix + "depthwise",
+        name=f"{name}_conv2",
     )(x)
     x = keras.layers.BatchNormalization(
         axis=channel_axis,
         epsilon=BN_EPSILON,
         momentum=BN_MOMENTUM,
-        name=prefix + "depthwise_BatchNorm",
+        name=f"{name}_bn2",
     )(x)
+    x = keras.layers.Activation(activation=activation)(x)
 
     if se_ratio:
-        se_filters = adjust_channels(infilters * expansion)
+        se_filters = expanded_channels
         x = SqueezeAndExcite2D(
             input=x,
             filters=se_filters,
             bottleneck_filters=adjust_channels(se_filters * se_ratio),
             squeeze_activation="relu",
             excite_activation=keras.activations.hard_sigmoid,
+            name=f"{name}_se",
         )
 
     x = keras.layers.Conv2D(
         filters,
         kernel_size=1,
-        padding="same",
         data_format=keras.config.image_data_format(),
         use_bias=False,
-        name=prefix + "project",
+        name=f"{name}_conv3",
     )(x)
     x = keras.layers.BatchNormalization(
         axis=channel_axis,
         epsilon=BN_EPSILON,
         momentum=BN_MOMENTUM,
-        name=prefix + "project_BatchNorm",
+        name=f"{name}_bn3",
     )(x)
 
     if stride == 1 and infilters == filters:
-        x = keras.layers.Add(name=prefix + "Add")([shortcut, x])
-
+        x = keras.layers.Add(name=f"{name}_add")([shortcut, x])
     return x
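The block now uses `Conv2D(..., groups=expanded_channels)` where the old code used `DepthwiseConv2D`. With `groups` equal to the number of input channels and `filters == groups`, a grouped convolution is exactly a depthwise convolution with depth multiplier 1; only the kernel layout differs, `(kh, kw, channels, 1)` versus `(kh, kw, 1, channels)`, which keeps the Keras variable shaped like timm's `conv_dw` weight. A small equivalence check (a sketch; exact tolerance depends on backend):

```python
import numpy as np
import keras

x = np.random.rand(1, 8, 8, 4).astype("float32")

dw = keras.layers.DepthwiseConv2D(3, padding="same", use_bias=False)
grouped = keras.layers.Conv2D(4, 3, padding="same", groups=4, use_bias=False)
dw.build(x.shape)
grouped.build(x.shape)

# DepthwiseConv2D stores its kernel as (kh, kw, channels, 1); a grouped
# Conv2D with groups == channels stores (kh, kw, 1, channels).
dw_kernel = np.asarray(dw.weights[0])
grouped.weights[0].assign(np.transpose(dw_kernel, (0, 1, 3, 2)))

np.testing.assert_allclose(
    np.asarray(dw(x)), np.asarray(grouped(x)), atol=1e-5
)
```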
 
 
 def apply_depthwise_conv_block(
-    x,
-    filters,
-    kernel_size=3,
-    depth_multiplier=1,
-    stride=1,
-    block_id=1,
+    x, filters, kernel_size=3, stride=2, se=None, name=None
 ):
     """Adds a depthwise convolution block.
 
@@ -368,13 +410,6 @@ def apply_depthwise_conv_block(
         x: Input tensor of shape `(rows, cols, channels)
         filters: Integer, the dimensionality of the output space
             (i.e. the number of output filters in the pointwise convolution).
-        depth_multiplier: controls the width of the network.
-            - If `depth_multiplier` < 1.0, proportionally decreases the number
-                of filters in each layer.
-            - If `depth_multiplier` > 1.0, proportionally increases the number
-                of filters in each layer.
-            - If `depth_multiplier` = 1, default number of filters from the
-                paper are used at each layer.
         strides: An integer or tuple/list of 2 integers, specifying the strides
             of the convolution along the width and height.
             Can be a single integer to specify the same value for
@@ -391,44 +426,62 @@ def apply_depthwise_conv_block(
     channel_axis = (
         -1 if keras.config.image_data_format() == "channels_last" else 1
     )
-    if stride == 2:
-        x = keras.layers.ZeroPadding2D(
-            padding=correct_pad_downsample(x, kernel_size),
-            name="conv_pad_%d" % block_id,
-        )(x)
-
-    x = keras.layers.DepthwiseConv2D(
+    infilters = x.shape[channel_axis]
+    # Match the timm sub-block naming scheme, e.g. `block_0` -> `block_0_0`.
+    name = f"{name}_0"
+
+    pad_width = (
+        (0, 0),  # No padding for batch.
+        (1, 1),  # 1 pixel padding for height.
+        (1, 1),  # 1 pixel padding for width.
+        (0, 0),  # No padding for channels.
+    )
+    x = ops.pad(x, pad_width=pad_width)
+    x = keras.layers.Conv2D(
+        infilters,
         kernel_size,
         strides=stride,
-        padding="same" if stride == 1 else "valid",
+        padding="valid",
         data_format=keras.config.image_data_format(),
-        depth_multiplier=depth_multiplier,
+        groups=infilters,
         use_bias=False,
-        name="depthwise_%d" % block_id,
+        name=f"{name}_conv1",
     )(x)
     x = keras.layers.BatchNormalization(
         axis=channel_axis,
         epsilon=BN_EPSILON,
         momentum=BN_MOMENTUM,
-        name="depthwise_BatchNorm_%d" % block_id,
+        name=f"{name}_bn1",
     )(x)
     x = keras.layers.ReLU(6.0)(x)
 
+    if se:
+        x = SqueezeAndExcite2D(
+            input=x,
+            filters=infilters,
+            bottleneck_filters=adjust_channels(infilters * se),
+            squeeze_activation="relu",
+            excite_activation=keras.activations.hard_sigmoid,
+            name=f"{name}_se",
+        )
+
     x = keras.layers.Conv2D(
         filters,
         kernel_size=1,
-        padding="same",
         data_format=keras.config.image_data_format(),
         use_bias=False,
-        name="conv_%d" % block_id,
+        name=f"{name}_conv2",
     )(x)
     x = keras.layers.BatchNormalization(
         axis=channel_axis,
         epsilon=BN_EPSILON,
         momentum=BN_MOMENTUM,
-        name="BatchNorm_%d" % block_id,
+        name=f"{name}_bn2",
     )(x)
 
-    return keras.layers.ReLU(6.0)(x)
+    return x
 
 
 def SqueezeAndExcite2D(
@@ -437,6 +490,7 @@ def SqueezeAndExcite2D(
     input,
     filters,
     bottleneck_filters=None,
     squeeze_activation="relu",
     excite_activation="sigmoid",
+    name=None,
 ):
     """
     Description:
@@ -458,29 +512,52 @@ def SqueezeAndExcite2D(
             keras.layers.Layer) or keras.activations.Activation instance
             denoting activation to be applied after excite convolution.
             Defaults to `sigmoid`.
+        name: string, name prefix for the layers in this block.
     """
     if not bottleneck_filters:
         bottleneck_filters = filters // 4
 
     x = keras.layers.GlobalAveragePooling2D(keepdims=True)(input)
     x = keras.layers.Conv2D(
         bottleneck_filters,
         (1, 1),
         data_format=keras.config.image_data_format(),
         activation=squeeze_activation,
+        name=f"{name}_conv_reduce",
     )(x)
     x = keras.layers.Conv2D(
         filters,
         (1, 1),
         data_format=keras.config.image_data_format(),
         activation=excite_activation,
+        name=f"{name}_conv_expand",
     )(x)
 
     x = ops.multiply(x, input)
     return x
 
 
+def ConvBnAct(x, filter, activation, name=None):
+    channel_axis = (
+        -1 if keras.config.image_data_format() == "channels_last" else 1
+    )
+    x = keras.layers.Conv2D(
+        filter,
+        kernel_size=1,
+        data_format=keras.config.image_data_format(),
+        use_bias=False,
+        name=f"{name}_conv",
+    )(x)
+    x = keras.layers.BatchNormalization(
+        axis=channel_axis,
+        epsilon=BN_EPSILON,
+        momentum=BN_MOMENTUM,
+        name=f"{name}_bn",
+    )(x)
+    x = keras.layers.Activation(activation)(x)
+    return x
+
+
 def correct_pad_downsample(inputs, kernel_size):
     """Returns a tuple for zero-padding for 2D convolution with downsampling.
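For reference, the gating that `SqueezeAndExcite2D` implements: pool each channel down to a single value, squeeze it through a bottleneck, expand back, and rescale the full feature map per channel. A backend-free NumPy sketch; the dense matrices stand in for the two 1x1 convolutions:

```python
import numpy as np


def squeeze_excite(x, w_reduce, w_expand):
    # Squeeze: global average pool over the spatial dimensions.
    s = x.mean(axis=(1, 2), keepdims=True)  # (batch, 1, 1, channels)
    # Excite: bottleneck "1x1 convs" acting on the pooled vector.
    s = np.maximum(s @ w_reduce, 0.0)  # relu
    s = np.clip(s @ w_expand / 6.0 + 0.5, 0.0, 1.0)  # hard_sigmoid
    # Recalibrate: broadcast the per-channel gates over the feature map.
    return x * s


x = np.random.rand(2, 7, 7, 16)
out = squeeze_excite(x, np.random.rand(16, 4), np.random.rand(4, 16))
print(out.shape)  # (2, 7, 7, 16)
```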
["hard_swish", "hard_swish", "hard_swish"], + ["hard_swish"], + ], + "stackwise_padding": [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]], + "output_num_filters": 1024, + "input_activation": "hard_swish", + "output_activation": "hard_swish", + "input_num_filters": 16, + "image_shape": (224, 224, 3), + "depthwise_filters": 8, + "squeeze_and_excite": 0.5, + "last_layer_filter": 288, + } + self.input_data = np.ones((2, 224, 224, 3), dtype="float32") + + def test_backbone_basics(self): + self.run_vision_backbone_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + expected_output_shape=(2, 7, 7, 1024), + run_mixed_precision_check=False, + run_data_format_check=False, + ) + + @pytest.mark.large + def test_saved_model(self): + self.run_model_saving_test( + cls=MobileNetBackbone, + init_kwargs=self.init_kwargs, + input_data=self.input_data, + ) diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py index 96977bdf9f..e9cc0fc153 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier.py @@ -1,8 +1,12 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.image_classifier import ImageClassifier from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.models.mobilenet.mobilenet_image_classifier_preprocessor import ( + MobileNetImageClassifierPreprocessor, +) @keras_hub_export("keras_hub.models.MobileNetImageClassifier") class MobileNetImageClassifier(ImageClassifier): backbone_cls = MobileNetBackbone + preprocessor_cls = MobileNetImageClassifierPreprocessor diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py new file mode 100644 index 0000000000..2ad3ef1ed7 --- /dev/null +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_preprocessor.py @@ -0,0 +1,14 @@ +from keras_hub.src.api_export import keras_hub_export +from keras_hub.src.models.image_classifier_preprocessor import ( + ImageClassifierPreprocessor, +) +from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone +from keras_hub.src.models.mobilenet.mobilenet_image_converter import ( + MobileNetImageConverter, +) + + +@keras_hub_export("keras_hub.models.MobileNetImageClassifierPreprocessor") +class MobileNetImageClassifierPreprocessor(ImageClassifierPreprocessor): + backbone_cls = MobileNetBackbone + image_converter_cls = MobileNetImageConverter diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py index 57ebd65039..7997b444fd 100644 --- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py +++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py @@ -1,57 +1,93 @@ -import numpy as np -import pytest - -from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone -from keras_hub.src.models.mobilenet.mobilenet_image_classifier import ( - MobileNetImageClassifier, -) -from keras_hub.src.tests.test_case import TestCase - - -class MobileNetImageClassifierTest(TestCase): - def setUp(self): - # Setup model. 
diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py
index 57ebd65039..7997b444fd 100644
--- a/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py
+++ b/keras_hub/src/models/mobilenet/mobilenet_image_classifier_test.py
@@ -1,57 +1,93 @@
-import numpy as np
-import pytest
-
-from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
-from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
-    MobileNetImageClassifier,
-)
-from keras_hub.src.tests.test_case import TestCase
-
-
-class MobileNetImageClassifierTest(TestCase):
-    def setUp(self):
-        # Setup model.
-        self.images = np.ones((2, 224, 224, 3), dtype="float32")
-        self.labels = [0, 3]
-        self.backbone = MobileNetBackbone(
-            stackwise_expansion=[1, 4, 6],
-            stackwise_num_filters=[4, 8, 16],
-            stackwise_kernel_size=[3, 3, 5],
-            stackwise_num_strides=[2, 2, 1],
-            stackwise_se_ratio=[0.25, None, 0.25],
-            stackwise_activation=["relu", "relu", "hard_swish"],
-            output_num_filters=1280,
-            input_activation="hard_swish",
-            output_activation="hard_swish",
-            inverted_res_block=True,
-            input_num_filters=16,
-            image_shape=(224, 224, 3),
-        )
-        self.init_kwargs = {
-            "backbone": self.backbone,
-            "num_classes": 2,
-            "activation": "softmax",
-        }
-        self.train_data = (
-            self.images,
-            self.labels,
-        )
-
-    def test_classifier_basics(self):
-        pytest.skip(
-            reason="TODO: enable after preprocessor flow is figured out"
-        )
-        self.run_task_test(
-            cls=MobileNetImageClassifier,
-            init_kwargs=self.init_kwargs,
-            train_data=self.train_data,
-            expected_output_shape=(2, 2),
-        )
-
-    @pytest.mark.large
-    def test_saved_model(self):
-        self.run_model_saving_test(
-            cls=MobileNetImageClassifier,
-            init_kwargs=self.init_kwargs,
-            input_data=self.images,
-        )
+import numpy as np
+import pytest
+
+from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
+from keras_hub.src.models.mobilenet.mobilenet_image_classifier import (
+    MobileNetImageClassifier,
+)
+from keras_hub.src.tests.test_case import TestCase
+
+
+class MobileNetImageClassifierTest(TestCase):
+    def setUp(self):
+        # Setup model.
+        self.images = np.ones((2, 224, 224, 3), dtype="float32")
+        self.labels = [0, 3]
+        self.backbone = MobileNetBackbone(
+            stackwise_expansion=[
+                [40, 56],
+                [64, 144, 144],
+                [72, 72],
+                [144, 288, 288],
+            ],
+            stackwise_num_blocks=[2, 3, 2, 3],
+            stackwise_num_filters=[
+                [16, 16],
+                [24, 24, 24],
+                [24, 24],
+                [48, 48, 48],
+            ],
+            stackwise_kernel_size=[[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]],
+            stackwise_num_strides=[[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]],
+            stackwise_se_ratio=[
+                [None, None],
+                [0.25, 0.25, 0.25],
+                [0.25, 0.25],
+                [0.25, 0.25, 0.25],
+            ],
+            stackwise_activation=[
+                ["relu", "relu"],
+                ["hard_swish", "hard_swish", "hard_swish"],
+                ["hard_swish", "hard_swish"],
+                ["hard_swish", "hard_swish", "hard_swish"],
+            ],
+            stackwise_padding=[[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]],
+            output_num_filters=1024,
+            input_activation="hard_swish",
+            output_activation="hard_swish",
+            input_num_filters=16,
+            image_shape=(224, 224, 3),
+            depthwise_filters=8,
+            squeeze_and_excite=0.5,
+            last_layer_filter=288,
+        )
+        self.init_kwargs = {
+            "backbone": self.backbone,
+            "num_classes": 2,
+            "activation": "softmax",
+        }
+        self.train_data = (
+            self.images,
+            self.labels,
+        )
+
+    def test_classifier_basics(self):
+        pytest.skip(
+            reason="TODO: enable after preprocessor flow is figured out"
+        )
+        self.run_task_test(
+            cls=MobileNetImageClassifier,
+            init_kwargs=self.init_kwargs,
+            train_data=self.train_data,
+            expected_output_shape=(2, 2),
+        )
+
+    @pytest.mark.large
+    def test_smallest_preset(self):
+        # Test that our forward pass is stable!
+        image_batch = self.load_test_image()[None, ...] / 255.0
+        self.run_preset_test(
+            cls=MobileNetImageClassifier,
+            preset="mobilenetv3_small_050",
+            input_data=image_batch,
+            expected_output_shape=(1, 1000),
+            expected_labels=[85],
+        )
+
+    @pytest.mark.large
+    def test_saved_model(self):
+        self.run_model_saving_test(
+            cls=MobileNetImageClassifier,
+            init_kwargs=self.init_kwargs,
+            input_data=self.images,
+        )
diff --git a/keras_hub/src/models/mobilenet/mobilenet_image_converter.py b/keras_hub/src/models/mobilenet/mobilenet_image_converter.py
new file mode 100644
index 0000000000..da6fb0ab6a
--- /dev/null
+++ b/keras_hub/src/models/mobilenet/mobilenet_image_converter.py
@@ -0,0 +1,8 @@
+from keras_hub.src.api_export import keras_hub_export
+from keras_hub.src.layers.preprocessing.image_converter import ImageConverter
+from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
+
+
+@keras_hub_export("keras_hub.layers.MobileNetImageConverter")
+class MobileNetImageConverter(ImageConverter):
+    backbone_cls = MobileNetBackbone
diff --git a/keras_hub/src/models/mobilenet/mobilenet_presets.py b/keras_hub/src/models/mobilenet/mobilenet_presets.py
new file mode 100644
index 0000000000..172e7fdbf6
--- /dev/null
+++ b/keras_hub/src/models/mobilenet/mobilenet_presets.py
@@ -0,0 +1,15 @@
+"""MobileNet preset configurations."""
+
+backbone_presets = {
+    "mobilenetv3_small_050": {
+        "metadata": {
+            "description": (
+                "Small MobileNet V3 model pre-trained on the ImageNet 1k "
+                "dataset at a 224x224 resolution."
+            ),
+            "official_name": "MobileNet",
+            "path": "mobilenet3",
+        },
+        "kaggle_handle": "kaggle://keras/mobilenet/keras/mobilenetv3_small_050",
+    },
+}
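Once the Kaggle artifact referenced by `kaggle_handle` is uploaded, the preset should be loadable by name; until then, the `hf://timm/...` handle used in the tests works the same way. A sketch, assuming the upload has happened:

```python
import numpy as np

import keras_hub

classifier = keras_hub.models.ImageClassifier.from_preset(
    "mobilenetv3_small_050"
)
preds = classifier.predict(np.ones((1, 224, 224, 3)))
print(preds.shape)  # (1, 1000)
```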
diff --git a/keras_hub/src/utils/timm/convert_mobilenet.py b/keras_hub/src/utils/timm/convert_mobilenet.py
new file mode 100644
index 0000000000..307f4a4acc
--- /dev/null
+++ b/keras_hub/src/utils/timm/convert_mobilenet.py
@@ -0,0 +1,168 @@
+import numpy as np
+
+from keras_hub.src.models.mobilenet.mobilenet_backbone import MobileNetBackbone
+
+backbone_cls = MobileNetBackbone
+
+
+def convert_backbone_config(timm_config):
+    timm_architecture = timm_config["architecture"]
+
+    if "mobilenetv3_" in timm_architecture:
+        input_activation = "hard_swish"
+        output_activation = "hard_swish"
+    else:
+        input_activation = "relu6"
+        output_activation = "relu6"
+
+    if timm_architecture == "mobilenetv3_small_050":
+        stackwise_num_blocks = [2, 3, 2, 3]
+        stackwise_expansion = [
+            [40, 56],
+            [64, 144, 144],
+            [72, 72],
+            [144, 288, 288],
+        ]
+        stackwise_num_filters = [[16, 16], [24, 24, 24], [24, 24], [48, 48, 48]]
+        stackwise_kernel_size = [[3, 3], [5, 5, 5], [5, 5], [5, 5, 5]]
+        stackwise_num_strides = [[2, 1], [2, 1, 1], [1, 1], [2, 1, 1]]
+        stackwise_se_ratio = [
+            [None, None],
+            [0.25, 0.25, 0.25],
+            [0.25, 0.25],
+            [0.25, 0.25, 0.25],
+        ]
+        stackwise_activation = [
+            ["relu", "relu"],
+            ["hard_swish", "hard_swish", "hard_swish"],
+            ["hard_swish", "hard_swish"],
+            ["hard_swish", "hard_swish", "hard_swish"],
+        ]
+        stackwise_padding = [[1, 1], [2, 2, 2], [2, 2], [2, 2, 2]]
+        output_num_filters = 1024
+        input_num_filters = 16
+        depthwise_filters = 8
+        squeeze_and_excite = 0.5
+        last_layer_filter = 288
+    else:
+        raise ValueError(
+            f"Currently, the architecture {timm_architecture} is not supported."
+        )
+
+    return dict(
+        input_num_filters=input_num_filters,
+        input_activation=input_activation,
+        depthwise_filters=depthwise_filters,
+        squeeze_and_excite=squeeze_and_excite,
+        stackwise_num_blocks=stackwise_num_blocks,
+        stackwise_expansion=stackwise_expansion,
+        stackwise_num_filters=stackwise_num_filters,
+        stackwise_kernel_size=stackwise_kernel_size,
+        stackwise_num_strides=stackwise_num_strides,
+        stackwise_se_ratio=stackwise_se_ratio,
+        stackwise_activation=stackwise_activation,
+        stackwise_padding=stackwise_padding,
+        output_num_filters=output_num_filters,
+        output_activation=output_activation,
+        last_layer_filter=last_layer_filter,
+    )
+
+
+def convert_weights(backbone, loader, timm_config):
+    def port_conv2d(keras_layer_name, hf_weight_prefix):
+        print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}")
+        loader.port_weight(
+            backbone.get_layer(keras_layer_name).kernel,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+            hook_fn=lambda x, _: np.transpose(x, (2, 3, 1, 0)),
+        )
+
+    def port_batch_normalization(keras_layer_name, hf_weight_prefix):
+        print(f"porting weights {hf_weight_prefix} -> {keras_layer_name}")
+        loader.port_weight(
+            backbone.get_layer(keras_layer_name).gamma,
+            hf_weight_key=f"{hf_weight_prefix}.weight",
+        )
+        loader.port_weight(
+            backbone.get_layer(keras_layer_name).beta,
+            hf_weight_key=f"{hf_weight_prefix}.bias",
+        )
+        loader.port_weight(
+            backbone.get_layer(keras_layer_name).moving_mean,
+            hf_weight_key=f"{hf_weight_prefix}.running_mean",
+        )
+        loader.port_weight(
+            backbone.get_layer(keras_layer_name).moving_variance,
+            hf_weight_key=f"{hf_weight_prefix}.running_var",
+        )
+
+    # Stem
+    port_conv2d("input_conv", "conv_stem")
+    port_batch_normalization("input_batch_norm", "bn1")
+
+    # DepthWise Block (block 0)
+    hf_name = "blocks.0.0"
+    keras_name = "block_0_0"
+
+    port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_dw")
+    port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1")
+
+    port_conv2d(f"{keras_name}_se_conv_reduce", f"{hf_name}.se.conv_reduce")
+    port_conv2d(f"{keras_name}_se_conv_expand", f"{hf_name}.se.conv_expand")
+
+    port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_pw")
+    port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2")
+
+    # Stages
+    num_stacks = len(backbone.stackwise_num_blocks)
+    for block_idx in range(num_stacks):
+        for inverted_block in range(backbone.stackwise_num_blocks[block_idx]):
+            keras_name = f"block_{block_idx+1}_{inverted_block}"
+            hf_name = f"blocks.{block_idx+1}.{inverted_block}"
+
+            # Inverted Residual Block
+            port_conv2d(f"{keras_name}_conv1", f"{hf_name}.conv_pw")
+            port_batch_normalization(f"{keras_name}_bn1", f"{hf_name}.bn1")
+            port_conv2d(f"{keras_name}_conv2", f"{hf_name}.conv_dw")
+            port_batch_normalization(f"{keras_name}_bn2", f"{hf_name}.bn2")
+
+            if backbone.stackwise_se_ratio[block_idx][inverted_block]:
+                port_conv2d(
+                    f"{keras_name}_se_conv_reduce",
+                    f"{hf_name}.se.conv_reduce",
+                )
+                port_conv2d(
+                    f"{keras_name}_se_conv_expand",
+                    f"{hf_name}.se.conv_expand",
+                )
+
+            port_conv2d(f"{keras_name}_conv3", f"{hf_name}.conv_pwl")
+            port_batch_normalization(f"{keras_name}_bn3", f"{hf_name}.bn3")
+
+    # ConvBnAct Block
+    port_conv2d(f"block_{num_stacks+1}_0_conv", f"blocks.{num_stacks+1}.0.conv")
+    port_batch_normalization(
+        f"block_{num_stacks+1}_0_bn", f"blocks.{num_stacks+1}.0.bn1"
+    )
+
+    port_conv2d("output_conv", "conv_head")
+    # TODO: port "output_batch_norm" here for variants (e.g. MobileNetV2)
+    # that batch-normalize the output convolution.
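The `hook_fn` in `port_conv2d` is doing the layout conversion between frameworks: PyTorch/timm convolution weights are stored as `(out_channels, in_channels, kh, kw)`, while Keras `Conv2D` kernels are `(kh, kw, in_channels, out_channels)`. A shape-only illustration:

```python
import numpy as np

# e.g. the stem conv of mobilenetv3_small_050: 16 filters, 3 input
# channels, 3x3 kernel.
torch_weight = np.zeros((16, 3, 3, 3))  # (out, in, kh, kw)
keras_kernel = np.transpose(torch_weight, (2, 3, 1, 0))
print(keras_kernel.shape)  # (3, 3, 3, 16) == (kh, kw, in, out)
```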
"v2": + # port_batch_normalization("output_batch_norm", "bn2") + + +def convert_head(task, loader, timm_config): + prefix = "classifier." + loader.port_weight( + task.output_dense.kernel, + hf_weight_key=prefix + "weight", + hook_fn=lambda x, _: np.transpose(np.squeeze(x)), + ) + loader.port_weight( + task.output_dense.bias, + hf_weight_key=prefix + "bias", + ) diff --git a/keras_hub/src/utils/timm/convert_mobilenet_test.py b/keras_hub/src/utils/timm/convert_mobilenet_test.py new file mode 100644 index 0000000000..59c504b306 --- /dev/null +++ b/keras_hub/src/utils/timm/convert_mobilenet_test.py @@ -0,0 +1,26 @@ +import pytest +from keras import ops + +from keras_hub.src.models.backbone import Backbone +from keras_hub.src.models.image_classifier import ImageClassifier +from keras_hub.src.tests.test_case import TestCase + + +class TimmMobileNetBackboneTest(TestCase): + @pytest.mark.large + def test_convert_mobilenet_backbone(self): + model = Backbone.from_preset( + "hf://timm/mobilenetv3_small_050.lamb_in1k" + ) + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 7, 7, 1024)) + + @pytest.mark.large + def test_convert_mobilenet_classifier(self): + model = ImageClassifier.from_preset( + "hf://timm/mobilenetv3_small_050.lamb_in1k" + ) + outputs = model.predict(ops.ones((1, 224, 224, 3))) + self.assertEqual(outputs.shape, (1, 1000)) + + # TODO: compare numerics with timm model diff --git a/keras_hub/src/utils/timm/preset_loader.py b/keras_hub/src/utils/timm/preset_loader.py index 1524db8530..392f432bb1 100644 --- a/keras_hub/src/utils/timm/preset_loader.py +++ b/keras_hub/src/utils/timm/preset_loader.py @@ -4,6 +4,7 @@ from keras_hub.src.utils.preset_utils import PresetLoader from keras_hub.src.utils.preset_utils import jax_memory_cleanup from keras_hub.src.utils.timm import convert_densenet +from keras_hub.src.utils.timm import convert_mobilenet from keras_hub.src.utils.timm import convert_resnet from keras_hub.src.utils.timm import convert_vgg from keras_hub.src.utils.transformers.safetensor_utils import SafetensorLoader @@ -17,6 +18,8 @@ def __init__(self, preset, config): self.converter = convert_resnet elif "densenet" in architecture: self.converter = convert_densenet + elif "mobilenet" in architecture: + self.converter = convert_mobilenet elif "vgg" in architecture: self.converter = convert_vgg else: diff --git a/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py new file mode 100644 index 0000000000..270d18eef9 --- /dev/null +++ b/tools/checkpoint_conversion/convert_mobilenet_checkpoints.py @@ -0,0 +1,112 @@ +"""Convert mobilenet checkpoints. 
+
+python tools/checkpoint_conversion/convert_mobilenet_checkpoints.py \
+    --preset mobilenetv3_small_050 --upload_uri kaggle://alexbutcher/mobilenet/keras/mobilenetv3_small_050
+"""
+
+import os
+import shutil
+
+import keras
+import numpy as np
+import PIL
+import timm
+import torch
+from absl import app
+from absl import flags
+
+import keras_hub
+
+PRESET_MAP = {
+    "mobilenetv3_small_050": "timm/mobilenetv3_small_050.lamb_in1k",
+}
+FLAGS = flags.FLAGS
+
+
+flags.DEFINE_string(
+    "preset",
+    None,
+    "Must be a valid `MobileNet` preset from KerasHub",
+    required=True,
+)
+flags.DEFINE_string(
+    "upload_uri",
+    None,
+    'Could be "kaggle://keras/{variant}/keras/{preset}_int8"',
+    required=False,
+)
+
+
+def validate_output(keras_model, timm_model):
+    file = keras.utils.get_file(
+        origin=(
+            "https://storage.googleapis.com/keras-cv/"
+            "models/paligemma/cow_beach_1.png"
+        )
+    )
+    image = PIL.Image.open(file)
+    batch = np.array([image])
+
+    # Preprocess with Timm.
+    data_config = timm.data.resolve_model_data_config(timm_model)
+    data_config["crop_pct"] = 1.0  # Stop timm from cropping.
+    transforms = timm.data.create_transform(**data_config, is_training=False)
+    timm_preprocessed = transforms(image)
+    timm_preprocessed = keras.ops.transpose(timm_preprocessed, axes=(1, 2, 0))
+    timm_preprocessed = keras.ops.expand_dims(timm_preprocessed, 0)
+
+    # Preprocess with Keras.
+    keras_preprocessed = keras_model.preprocessor(batch)
+
+    # Call with Timm. Use the keras preprocessed image so we can keep modeling
+    # and preprocessing comparisons independent.
+    timm_batch = keras.ops.transpose(keras_preprocessed, axes=(0, 3, 1, 2))
+    timm_batch = torch.from_numpy(np.array(timm_batch))
+    timm_outputs = timm_model(timm_batch).detach().numpy()
+    timm_label = np.argmax(timm_outputs[0])
+
+    # Call with Keras.
+    keras_outputs = keras_model.predict(batch)
+    keras_label = np.argmax(keras_outputs[0])
+
+    print("🔶 Keras output:", keras_outputs[0, :10])
+    print("🔶 TIMM output:", timm_outputs[0, :10])
+    print("🔶 Keras label:", keras_label)
+    print("🔶 TIMM label:", timm_label)
+    modeling_diff = np.mean(np.abs(keras_outputs - timm_outputs))
+    print("🔶 Modeling difference:", modeling_diff)
+    preprocessing_diff = np.mean(np.abs(keras_preprocessed - timm_preprocessed))
+    print("🔶 Preprocessing difference:", preprocessing_diff)
+
+
+def main(_):
+    preset = FLAGS.preset
+    if os.path.exists(preset):
+        shutil.rmtree(preset)
+    os.makedirs(preset)
+
+    timm_name = PRESET_MAP[preset]
+
+    timm_model = timm.create_model(timm_name, pretrained=True)
+    timm_model = timm_model.eval()
+    print("✅ Loaded TIMM model.")
+
+    keras_model = keras_hub.models.ImageClassifier.from_preset(
+        "hf://" + timm_name,
+    )
+    print("✅ Loaded KerasHub model.")
+
+    keras_model.save_to_preset(f"./{preset}")
+    print(f"🏁 Preset saved to ./{preset}")
+
+    validate_output(keras_model, timm_model)
+
+    upload_uri = FLAGS.upload_uri
+    if upload_uri:
+        keras_hub.upload_preset(uri=upload_uri, preset=f"./{preset}")
+        print(f"🏁 Preset uploaded to {upload_uri}")
+
+
+if __name__ == "__main__":
+    flags.mark_flag_as_required("preset")
+    app.run(main)
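The converter test above leaves the numerics comparison as a TODO. A sketch of what that check could look like, under stated assumptions: the HF handle stays valid, the task's preprocessor can be detached by setting it to `None` so both models see identical inputs, and a float32 tolerance of about 1e-4 is appropriate for a converted model:

```python
import numpy as np
import timm
import torch

import keras_hub

keras_model = keras_hub.models.ImageClassifier.from_preset(
    "hf://timm/mobilenetv3_small_050.lamb_in1k"
)
# Bypass KerasHub preprocessing so both models get the same raw tensor.
keras_model.preprocessor = None
timm_model = timm.create_model(
    "mobilenetv3_small_050.lamb_in1k", pretrained=True
).eval()

x = np.random.rand(1, 224, 224, 3).astype("float32")
keras_out = keras_model.predict(x)
with torch.no_grad():
    timm_out = timm_model(torch.from_numpy(x.transpose(0, 3, 1, 2))).numpy()

np.testing.assert_allclose(keras_out, timm_out, atol=1e-4)
```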