Skip to content

Commit

Permalink
Avoid TimeDistributed layers to fix for keras 3.3.3 and Acknowledge r…
Browse files Browse the repository at this point in the history
…andomness for test
  • Loading branch information
laxmareddyp committed Sep 19, 2024
1 parent 621cdc1 commit ea8e6c7
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 29 deletions.
60 changes: 32 additions & 28 deletions keras_cv/src/models/object_detection/mask_rcnn/mask_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,49 +49,53 @@ def __init__(
self.num_deconv_filters = num_deconv_filters
self.layers = []
for num_filters in stackwise_num_conv_filters:
conv = keras.layers.TimeDistributed(
keras.layers.Conv2D(
filters=num_filters,
kernel_size=3,
padding="same",
)
)
batchnorm = keras.layers.TimeDistributed(
keras.layers.BatchNormalization()
conv = keras.layers.Conv2D(
filters=num_filters,
kernel_size=3,
padding="same",
)
batchnorm = keras.layers.BatchNormalization()
activation = keras.layers.Activation("relu")
self.layers.extend([conv, batchnorm, activation])

self.deconv = keras.layers.TimeDistributed(
keras.layers.Conv2DTranspose(
num_deconv_filters,
kernel_size=2,
strides=2,
activation="relu",
padding="valid",
)
self.deconv = keras.layers.Conv2DTranspose(
num_deconv_filters,
kernel_size=2,
strides=2,
activation="relu",
padding="valid",
)
# we do not use a final sigmoid activation, since we use
# from_logits=True during training
self.segmentation_mask_output = keras.layers.TimeDistributed(
keras.layers.Conv2D(
num_classes + 1,
kernel_size=1,
strides=1,
activation="linear",
)
self.segmentation_mask_output = keras.layers.Conv2D(
num_classes + 1,
kernel_size=1,
strides=1,
activation="linear",
)

def call(self, feature_map, training=False):
x = feature_map
# reshape batch and ROI axes into one axis to obtain a suitable
# shape for conv layers
num_rois = keras.ops.shape(feature_map)[1]
x = keras.ops.reshape(feature_map, (-1, *feature_map.shape[2:]))
for layer in self.layers:
x = layer(x, training=training)
x = self.deconv(x)
mask = self.segmentation_mask_output(x)
return mask
segmentation_mask = self.segmentation_mask_output(x)
segmentation_mask = keras.ops.reshape(
segmentation_mask, (-1, num_rois, *segmentation_mask.shape[1:])
)
return segmentation_mask

def build(self, input_shape):
intermediate_shape = input_shape
if input_shape[0] is None or input_shape[1] is None:
intermediate_shape = (None, *input_shape[2:])
else:
intermediate_shape = (
input_shape[0] * input_shape[1],
*input_shape[2:],
)
for idx, num_filters in enumerate(self.stackwise_num_conv_filters):
self.layers[idx * 3].build(intermediate_shape)
intermediate_shape = tuple(intermediate_shape[:-1]) + (num_filters,)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ def test_roi_sampler(self, mask_value):
)
# the sampled mask is only set to 1 if the ground truth
# mask indicates object 2
sampled_index = ops.where(sampled_gt_classes[0, :, 0] == 10)[0][0]
self.assertAllClose(
sampled_gt_masks[0, 0],
sampled_gt_masks[0, sampled_index],
(mask_value == 2) * np.ones((14, 14)),
)

Expand Down

0 comments on commit ea8e6c7

Please sign in to comment.