From ea8e6c740a4006b50c5a48f8c0b78d414e65af7b Mon Sep 17 00:00:00 2001 From: Laxma Reddy Patlolla Date: Thu, 19 Sep 2024 10:12:36 -0700 Subject: [PATCH] Avoid TimeDistributed layers to fix for keras 3.3.3 and Acknowledge randomness for test --- .../object_detection/mask_rcnn/mask_head.py | 60 ++++++++++--------- .../mask_rcnn/roi_sampler_test.py | 3 +- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py index c40541e558..9aec892611 100644 --- a/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py +++ b/keras_cv/src/models/object_detection/mask_rcnn/mask_head.py @@ -49,49 +49,53 @@ def __init__( self.num_deconv_filters = num_deconv_filters self.layers = [] for num_filters in stackwise_num_conv_filters: - conv = keras.layers.TimeDistributed( - keras.layers.Conv2D( - filters=num_filters, - kernel_size=3, - padding="same", - ) - ) - batchnorm = keras.layers.TimeDistributed( - keras.layers.BatchNormalization() + conv = keras.layers.Conv2D( + filters=num_filters, + kernel_size=3, + padding="same", ) + batchnorm = keras.layers.BatchNormalization() activation = keras.layers.Activation("relu") self.layers.extend([conv, batchnorm, activation]) - self.deconv = keras.layers.TimeDistributed( - keras.layers.Conv2DTranspose( - num_deconv_filters, - kernel_size=2, - strides=2, - activation="relu", - padding="valid", - ) + self.deconv = keras.layers.Conv2DTranspose( + num_deconv_filters, + kernel_size=2, + strides=2, + activation="relu", + padding="valid", ) # we do not use a final sigmoid activation, since we use # from_logits=True during training - self.segmentation_mask_output = keras.layers.TimeDistributed( - keras.layers.Conv2D( - num_classes + 1, - kernel_size=1, - strides=1, - activation="linear", - ) + self.segmentation_mask_output = keras.layers.Conv2D( + num_classes + 1, + kernel_size=1, + strides=1, + activation="linear", ) def call(self, feature_map, training=False): - x = feature_map + # reshape batch and ROI axes into one axis to obtain a suitable + # shape for conv layers + num_rois = keras.ops.shape(feature_map)[1] + x = keras.ops.reshape(feature_map, (-1, *feature_map.shape[2:])) for layer in self.layers: x = layer(x, training=training) x = self.deconv(x) - mask = self.segmentation_mask_output(x) - return mask + segmentation_mask = self.segmentation_mask_output(x) + segmentation_mask = keras.ops.reshape( + segmentation_mask, (-1, num_rois, *segmentation_mask.shape[1:]) + ) + return segmentation_mask def build(self, input_shape): - intermediate_shape = input_shape + if input_shape[0] is None or input_shape[1] is None: + intermediate_shape = (None, *input_shape[2:]) + else: + intermediate_shape = ( + input_shape[0] * input_shape[1], + *input_shape[2:], + ) for idx, num_filters in enumerate(self.stackwise_num_conv_filters): self.layers[idx * 3].build(intermediate_shape) intermediate_shape = tuple(intermediate_shape[:-1]) + (num_filters,) diff --git a/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py index 7a97f018e1..d382926855 100644 --- a/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py +++ b/keras_cv/src/models/object_detection/mask_rcnn/roi_sampler_test.py @@ -76,8 +76,9 @@ def test_roi_sampler(self, mask_value): ) # the sampled mask is only set to 1 if the ground truth # mask indicates object 2 + sampled_index = ops.where(sampled_gt_classes[0, :, 0] == 10)[0][0] self.assertAllClose( - sampled_gt_masks[0, 0], + sampled_gt_masks[0, sampled_index], (mask_value == 2) * np.ones((14, 14)), )