diff --git a/.kokoro/github/ubuntu/gpu/build.sh b/.kokoro/github/ubuntu/gpu/build.sh index 879f106558..00c442edc8 100644 --- a/.kokoro/github/ubuntu/gpu/build.sh +++ b/.kokoro/github/ubuntu/gpu/build.sh @@ -74,7 +74,8 @@ then keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ - keras_cv/src/models/stable_diffusion + keras_cv/src/models/stable_diffusion \ + keras_cv/src/models/segmentation/yolo_v8_segmentation else pytest --cache-clear --check_gpu --run_large --durations 0 \ keras_cv/src/bounding_box \ @@ -90,5 +91,6 @@ else keras_cv/src/models/object_detection_3d \ keras_cv/src/models/segmentation \ keras_cv/src/models/feature_extractor/clip \ - keras_cv/src/models/stable_diffusion -fi \ No newline at end of file + keras_cv/src/models/stable_diffusion \ + keras_cv/src/models/segmentation/yolo_v8_segmentation +fi diff --git a/keras_cv/api/models/__init__.py b/keras_cv/api/models/__init__.py index 54be7764b8..6276b685f3 100644 --- a/keras_cv/api/models/__init__.py +++ b/keras_cv/api/models/__init__.py @@ -257,6 +257,9 @@ from keras_cv.src.models.segmentation.segment_anything.sam_transformer import ( TwoWayTransformer, ) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_segmentation import ( + YOLOV8Segmentation, +) from keras_cv.src.models.stable_diffusion.stable_diffusion import ( StableDiffusion, ) diff --git a/keras_cv/api/models/segmentation/__init__.py b/keras_cv/api/models/segmentation/__init__.py index 9f5276304b..f111e7c9c0 100644 --- a/keras_cv/api/models/segmentation/__init__.py +++ b/keras_cv/api/models/segmentation/__init__.py @@ -12,3 +12,6 @@ from keras_cv.src.models.segmentation.segment_anything.sam import ( SegmentAnythingModel, ) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_segmentation import ( + YOLOV8Segmentation, +) diff --git a/keras_cv/src/layers/object_detection/non_max_suppression.py b/keras_cv/src/layers/object_detection/non_max_suppression.py index 45993258e4..fc657e09f5 100644 --- a/keras_cv/src/layers/object_detection/non_max_suppression.py +++ b/keras_cv/src/layers/object_detection/non_max_suppression.py @@ -159,6 +159,7 @@ def call( image_shape=image_shape, ) bounding_boxes = { + "idx": idx, "boxes": box_prediction, "confidence": confidence_prediction, "classes": ops.argmax(class_prediction, axis=-1), diff --git a/keras_cv/src/models/__init__.py b/keras_cv/src/models/__init__.py index ebe22b7709..56c711d58e 100644 --- a/keras_cv/src/models/__init__.py +++ b/keras_cv/src/models/__init__.py @@ -222,6 +222,7 @@ from keras_cv.src.models.segmentation import SAMPromptEncoder from keras_cv.src.models.segmentation import SegmentAnythingModel from keras_cv.src.models.segmentation import TwoWayTransformer +from keras_cv.src.models.segmentation import YOLOV8Segmentation from keras_cv.src.models.segmentation.segformer.segformer_aliases import ( SegFormer, ) diff --git a/keras_cv/src/models/segmentation/__init__.py b/keras_cv/src/models/segmentation/__init__.py index cec25eb010..b94278e6cb 100644 --- a/keras_cv/src/models/segmentation/__init__.py +++ b/keras_cv/src/models/segmentation/__init__.py @@ -21,3 +21,6 @@ SegmentAnythingModel, ) from keras_cv.src.models.segmentation.segment_anything import TwoWayTransformer +from keras_cv.src.models.segmentation.yolo_v8_segmentation import ( + YOLOV8Segmentation, +) diff --git a/keras_cv/src/models/segmentation/yolo_v8_segmentation/__init__.py b/keras_cv/src/models/segmentation/yolo_v8_segmentation/__init__.py new file mode 100644 index 0000000000..4659dd7878 --- /dev/null +++ b/keras_cv/src/models/segmentation/yolo_v8_segmentation/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .yolo_v8_segmentation import YOLOV8Segmentation diff --git a/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_backbone.py b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_backbone.py new file mode 100644 index 0000000000..5a6af54566 --- /dev/null +++ b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_backbone.py @@ -0,0 +1,378 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops + +BATCH_NORM_EPSILON = 1e-3 +BATCH_NORM_MOMENTUM = 0.97 +BOX_REGRESSION_CHANNELS = 64 + + +def apply_conv_bn( + inputs, + output_channel, + kernel_size=1, + strides=1, + activation="swish", + name="conv_bn", +): + if kernel_size > 1: + inputs = keras.layers.ZeroPadding2D( + padding=kernel_size // 2, name=f"{name}_pad" + )(inputs) + + x = keras.layers.Conv2D( + filters=output_channel, + kernel_size=kernel_size, + strides=strides, + padding="valid", + use_bias=False, + name=f"{name}_conv", + )(inputs) + x = keras.layers.BatchNormalization( + momentum=BATCH_NORM_MOMENTUM, + epsilon=BATCH_NORM_EPSILON, + name=f"{name}_bn", + )(x) + x = keras.layers.Activation(activation, name=name)(x) + return x + + +# TODO(ianstenbit): Remove this method once we're using CSPDarkNet backbone +# Calls to it should instead call the CSP block from the DarkNet implementation. +def apply_csp_block( + inputs, + channels=-1, + depth=2, + shortcut=True, + expansion=0.5, + activation="swish", + name="csp_block", +): + channel_axis = -1 + channels = channels if channels > 0 else inputs.shape[channel_axis] + hidden_channels = int(channels * expansion) + + pre = apply_conv_bn( + inputs, + hidden_channels * 2, + kernel_size=1, + activation=activation, + name=f"{name}_pre", + ) + short, deep = ops.split(pre, 2, axis=channel_axis) + + out = [short, deep] + for id in range(depth): + deep = apply_conv_bn( + deep, + hidden_channels, + kernel_size=3, + activation=activation, + name=f"{name}_pre_{id}_1", + ) + deep = apply_conv_bn( + deep, + hidden_channels, + kernel_size=3, + activation=activation, + name=f"{name}_pre_{id}_2", + ) + deep = (out[-1] + deep) if shortcut else deep + out.append(deep) + out = ops.concatenate(out, axis=channel_axis) + out = apply_conv_bn( + out, + channels, + kernel_size=1, + activation=activation, + name=f"{name}_output", + ) + return out + + +def get_anchors( + image_shape, + strides=[8, 16, 32], + base_anchors=[0.5, 0.5], +): + """Gets anchor points for YOLOV8. + + YOLOV8 uses anchor points representing the center of proposed boxes, and + matches ground truth boxes to anchors based on center points. + + Args: + image_shape: tuple or list of two integers representing the height and + width of input images, respectively. + strides: tuple of list of integers, the size of the strides across the + image size that should be used to create anchors. + base_anchors: tuple or list of two integers representing the offset from + (0,0) to start creating the center of anchor boxes, relative to the + stride. For example, using the default (0.5, 0.5) creates the first + anchor box for each stride such that its center is half of a stride + from the edge of the image. + + Returns: + A tuple of anchor centerpoints and anchor strides. Multiplying the + two together will yield the centerpoints in absolute x,y format. + + """ + base_anchors = ops.array(base_anchors, dtype="float32") + + all_anchors = [] + all_strides = [] + for stride in strides: + hh_centers = ops.arange(0, image_shape[0], stride) + ww_centers = ops.arange(0, image_shape[1], stride) + ww_grid, hh_grid = ops.meshgrid(ww_centers, hh_centers) + grid = ops.cast( + ops.reshape(ops.stack([hh_grid, ww_grid], 2), [-1, 1, 2]), + "float32", + ) + anchors = ( + ops.expand_dims( + base_anchors * ops.array([stride, stride], "float32"), 0 + ) + + grid + ) + anchors = ops.reshape(anchors, [-1, 2]) + all_anchors.append(anchors) + all_strides.append(ops.repeat(stride, anchors.shape[0])) + + all_anchors = ops.cast(ops.concatenate(all_anchors, axis=0), "float32") + all_strides = ops.cast(ops.concatenate(all_strides, axis=0), "float32") + + all_anchors = all_anchors / all_strides[:, None] + + # Swap the x and y coordinates of the anchors. + all_anchors = ops.concatenate( + [all_anchors[:, 1, None], all_anchors[:, 0, None]], axis=-1 + ) + return all_anchors, all_strides + + +def apply_path_aggregation_fpn(features, depth=3, name="fpn"): + """Applies the Feature Pyramid Network (FPN) to the outputs of a backbone. + + Args: + features: list of tensors representing the P3, P4, and P5 outputs of the + backbone. + depth: integer, the depth of the CSP blocks used in the FPN. + name: string, a prefix for names of layers used by the FPN. + + Returns: + A list of three tensors whose shapes are the same as the three inputs, + but which are dependent on each of the three inputs to combine the high + resolution of the P3 inputs with the strong feature representations of + the P5 inputs. + + """ + p3, p4, p5 = features + + # Upsample P5 and concatenate with P4, then apply a CSPBlock. + p5_upsampled = ops.repeat(ops.repeat(p5, 2, axis=1), 2, axis=2) + p4p5 = ops.concatenate([p5_upsampled, p4], axis=-1) + p4p5 = apply_csp_block( + p4p5, + channels=p4.shape[-1], + depth=depth, + shortcut=False, + activation="swish", + name=f"{name}_p4p5", + ) + + # Upsample P4P5 and concatenate with P3, then apply a CSPBlock. + p4p5_upsampled = ops.repeat(ops.repeat(p4p5, 2, axis=1), 2, axis=2) + p3p4p5 = ops.concatenate([p4p5_upsampled, p3], axis=-1) + p3p4p5 = apply_csp_block( + p3p4p5, + channels=p3.shape[-1], + depth=depth, + shortcut=False, + activation="swish", + name=f"{name}_p3p4p5", + ) + + # Downsample P3P4P5, concatenate with P4P5, and apply a CSP Block. + p3p4p5_d1 = apply_conv_bn( + p3p4p5, + p3p4p5.shape[-1], + kernel_size=3, + strides=2, + activation="swish", + name=f"{name}_p3p4p5_downsample1", + ) + p3p4p5_d1 = ops.concatenate([p3p4p5_d1, p4p5], axis=-1) + p3p4p5_d1 = apply_csp_block( + p3p4p5_d1, + channels=p4p5.shape[-1], + shortcut=False, + activation="swish", + name=f"{name}_p3p4p5_downsample1_block", + ) + + # Downsample the resulting P3P4P5 again, concatenate with P5, and apply + # another CSP Block. + p3p4p5_d2 = apply_conv_bn( + p3p4p5_d1, + p3p4p5_d1.shape[-1], + kernel_size=3, + strides=2, + activation="swish", + name=f"{name}_p3p4p5_downsample2", + ) + p3p4p5_d2 = ops.concatenate([p3p4p5_d2, p5], axis=-1) + p3p4p5_d2 = apply_csp_block( + p3p4p5_d2, + channels=p5.shape[-1], + shortcut=False, + activation="swish", + name=f"{name}_p3p4p5_downsample2_block", + ) + + return [p3p4p5, p3p4p5_d1, p3p4p5_d2] + + +def decode_regression_to_boxes(preds): + """Decodes the results of the YOLOV8Detector forward-pass into boxes. + + Returns left / top / right / bottom predictions with respect to anchor + points. + + Each coordinate is encoded with 16 predicted values. Those predictions are + softmaxed and multiplied by [0..15] to make predictions. The resulting + predictions are relative to the stride of an anchor box (and correspondingly + relative to the scale of the feature map from which the predictions came). + """ + preds_bbox = keras.layers.Reshape((-1, 4, BOX_REGRESSION_CHANNELS // 4))( + preds + ) + preds_bbox = ops.nn.softmax(preds_bbox, axis=-1) * ops.arange( + BOX_REGRESSION_CHANNELS // 4, dtype="float32" + ) + return ops.sum(preds_bbox, axis=-1) + + +def dist2bbox(distance, anchor_points): + """Decodes distance predictions into xyxy boxes. + + Input left / top / right / bottom predictions are transformed into xyxy box + predictions based on anchor points. + + The resulting xyxy predictions must be scaled by the stride of their + corresponding anchor points to yield an absolute xyxy box. + """ + left_top, right_bottom = ops.split(distance, 2, axis=-1) + x1y1 = anchor_points - left_top + x2y2 = anchor_points + right_bottom + return ops.concatenate((x1y1, x2y2), axis=-1) # xyxy bbox + + +def apply_yolo_v8_head( + inputs, + num_classes, + name="yolo_v8_head", +): + """Applies a YOLOV8 head. + + Makes box and class predictions based on the output of a feature pyramid + network. + + Args: + inputs: list of tensors output by the Feature Pyramid Network, should + have the same shape as the P3, P4, and P5 outputs of the backbone. + num_classes: integer, the number of classes that a bounding box could + possibly be assigned to. + name: string, a prefix for names of layers used by the head. + + Returns: A dictionary with two entries. The "boxes" entry contains box + regression predictions, while the "classes" entry contains class + predictions. + """ + # 64 is the default number of channels, as 16 components are used to predict + # each of the 4 offsets for corner points of a bounding box with respect + # to the center point. In cases where the input has much higher resolution + # (e.g. the P3 input has >256 channels), we use additional channels for + # the intermediate conv layers. This is only true for very large backbones. + box_channels = max(BOX_REGRESSION_CHANNELS, inputs[0].shape[-1] // 4) + + # We use at least num_classes channels for intermediate conv layer for class + # predictions. In most cases, the P3 input has many more channels than the + # number of classes, so we preserve those channels until the final layer. + class_channels = max(num_classes, inputs[0].shape[-1]) + + # We compute box and class predictions for each of the feature maps from + # the FPN and then combine them. + outputs = [] + for id, feature in enumerate(inputs): + cur_name = f"{name}_{id+1}" + + box_predictions = apply_conv_bn( + feature, + box_channels, + kernel_size=3, + activation="swish", + name=f"{cur_name}_box_1", + ) + box_predictions = apply_conv_bn( + box_predictions, + box_channels, + kernel_size=3, + activation="swish", + name=f"{cur_name}_box_2", + ) + box_predictions = keras.layers.Conv2D( + filters=BOX_REGRESSION_CHANNELS, + kernel_size=1, + name=f"{cur_name}_box_3_conv", + )(box_predictions) + + class_predictions = apply_conv_bn( + feature, + class_channels, + kernel_size=3, + activation="swish", + name=f"{cur_name}_class_1", + ) + class_predictions = apply_conv_bn( + class_predictions, + class_channels, + kernel_size=3, + activation="swish", + name=f"{cur_name}_class_2", + ) + class_predictions = keras.layers.Conv2D( + filters=num_classes, + kernel_size=1, + name=f"{cur_name}_class_3_conv", + )(class_predictions) + class_predictions = keras.layers.Activation( + "sigmoid", name=f"{cur_name}_classifier" + )(class_predictions) + + out = ops.concatenate([box_predictions, class_predictions], axis=-1) + out = keras.layers.Reshape( + [-1, out.shape[-1]], name=f"{cur_name}_output_reshape" + )(out) + outputs.append(out) + + outputs = ops.concatenate(outputs, axis=1) + outputs = keras.layers.Activation( + "linear", dtype="float32", name="box_outputs" + )(outputs) + + return { + "boxes": outputs[:, :, :BOX_REGRESSION_CHANNELS], + "classes": outputs[:, :, BOX_REGRESSION_CHANNELS:], + } diff --git a/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_label_encoder.py b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_label_encoder.py new file mode 100644 index 0000000000..4eebd27535 --- /dev/null +++ b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_label_encoder.py @@ -0,0 +1,271 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorflow as tf + +from keras_cv.src import bounding_box +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.bounding_box.iou import compute_ciou + + +def is_anchor_center_within_box(anchors, gt_bboxes): + return ops.all( + ops.logical_and( + gt_bboxes[:, :, None, :2] < anchors, + gt_bboxes[:, :, None, 2:] > anchors, + ), + axis=-1, + ) + + +class YOLOV8LabelEncoder(keras.layers.Layer): + """ + Encodes ground truth boxes to target boxes and class labels for training a + YOLOV8 model. This uses the TOOD Task Aligned Assigner approach. + See https://arxiv.org/abs/2108.07755 for more info, as well as a reference + implementation at https://github.com/fcjian/TOOD/blob/master/ + mmdet/core/bbox/assigners/task_aligned_assigner.py + + Args: + num_classes: integer, the number of classes in the training dataset + max_anchor_matches: optional integer, the maximum number of anchors to + match with any given ground truth box. For example, when the default + 10 is used, the 10 candidate anchor points with the highest + alignment score are matched with a ground truth box. If less than 10 + candidate anchors exist, all candidates will be matched to the box. + alpha: float, a parameter to control the influence of class predictions + on the alignment score of an anchor box. This is the alpha parameter + in equation 9 of https://arxiv.org/pdf/2108.07755.pdf. + beta: float, a parameter to control the influence of box IOUs on the + alignment score of an anchor box. This is the beta parameter in + equation 9 of https://arxiv.org/pdf/2108.07755.pdf. + epsilon: float, a small number used for numerical stability in division + (to avoid diving by zero), and used as a threshold to eliminate very + small matches based on alignment scores of approximately zero. + """ + + def __init__( + self, + num_classes, + max_anchor_matches=10, + alpha=0.5, + beta=6.0, + epsilon=1e-9, + **kwargs, + ): + super().__init__(**kwargs) + self.max_anchor_matches = max_anchor_matches + self.num_classes = num_classes + self.alpha = alpha + self.beta = beta + self.epsilon = epsilon + + def assign( + self, scores, decode_bboxes, anchors, gt_labels, gt_bboxes, gt_mask + ): + """Assigns ground-truth boxes to anchors. + + Uses the task-aligned assignment strategy for matching ground truth + and anchor boxes based on prediction scores and IoU. + """ + num_anchors = anchors.shape[0] + + # Box scores are the predicted scores for each anchor, ground truth box + # pair. Only the predicted score for the class of the GT box is included + # Shape: (B, num_gt_boxes, num_anchors) (after transpose) + bbox_scores = ops.take_along_axis( + scores, + ops.cast(ops.maximum(gt_labels[:, None, :], 0), "int32"), + axis=-1, + ) + bbox_scores = ops.transpose(bbox_scores, (0, 2, 1)) + + # Overlaps are the IoUs of each predicted box and each GT box. + # Shape: (B, num_gt_boxes, num_anchors) + overlaps = compute_ciou( + ops.expand_dims(gt_bboxes, axis=2), + ops.expand_dims(decode_bboxes, axis=1), + bounding_box_format="xyxy", + ) + + # Alignment metrics are a combination of box scores and overlaps, per + # the task-aligned-assignment formula. + # Metrics are forced to 0 for boxes which have been masked in the GT + # input (e.g. due to padding) + alignment_metrics = ops.power(bbox_scores, self.alpha) * ops.power( + overlaps, self.beta + ) + alignment_metrics = ops.where(gt_mask, alignment_metrics, 0) + + # Only anchors which are inside of relevant GT boxes are considered + # for assignment. + # This is a boolean tensor of shape (B, num_gt_boxes, num_anchors) + matching_anchors_in_gt_boxes = is_anchor_center_within_box( + anchors, gt_bboxes + ) + alignment_metrics = ops.where( + matching_anchors_in_gt_boxes, alignment_metrics, 0 + ) + + # The top-k highest alignment metrics are used to select K candidate + # anchors for each GT box. + candidate_metrics, candidate_idxs = ops.top_k( + alignment_metrics, self.max_anchor_matches + ) + candidate_idxs = ops.where(candidate_metrics > 0, candidate_idxs, -1) + + # We now compute a dense grid of anchors and GT boxes. This is useful + # for picking a GT box when an anchor matches to 2, as well as returning + # to a dense format for a mask of which anchors have been matched. + anchors_matched_gt_box = ops.zeros_like(overlaps) + for k in range(self.max_anchor_matches): + anchors_matched_gt_box += ops.one_hot( + candidate_idxs[:, :, k], num_anchors + ) + + # We zero-out the overlap for anchor, GT box pairs which don't match. + overlaps *= anchors_matched_gt_box + # In cases where one anchor matches to 2 GT boxes, we pick the GT box + # with the highest overlap as a max. + gt_box_matches_per_anchor = ops.argmax(overlaps, axis=1) + gt_box_matches_per_anchor_mask = ops.max(overlaps, axis=1) > 0 + # TODO(ianstenbit): Once ops.take_along_axis supports -1 in Torch, + # replace gt_box_matches_per_anchor with + # ops.where( + # ops.max(overlaps, axis=1) > 0, ops.argmax(overlaps, axis=1), -1 + # ) + # and get rid of the manual masking + gt_box_matches_per_anchor = ops.cast(gt_box_matches_per_anchor, "int32") + + # We select the GT boxes and labels that correspond to anchor matches. + bbox_labels = ops.take_along_axis( + gt_bboxes, gt_box_matches_per_anchor[:, :, None], axis=1 + ) + bbox_labels = ops.where( + gt_box_matches_per_anchor_mask[:, :, None], bbox_labels, -1 + ) + class_labels = ops.take_along_axis( + gt_labels, gt_box_matches_per_anchor, axis=1 + ) + class_labels = ops.where( + gt_box_matches_per_anchor_mask, class_labels, -1 + ) + + class_labels = ops.one_hot( + ops.cast(class_labels, "int32"), self.num_classes + ) + + # Finally, we normalize an anchor's class labels based on the relative + # strength of the anchors match with the corresponding GT box. + alignment_metrics *= anchors_matched_gt_box + max_alignment_per_gt_box = ops.max( + alignment_metrics, axis=-1, keepdims=True + ) + max_overlap_per_gt_box = ops.max(overlaps, axis=-1, keepdims=True) + + normalized_alignment_metrics = ops.max( + alignment_metrics + * max_overlap_per_gt_box + / (max_alignment_per_gt_box + self.epsilon), + axis=-2, + ) + class_labels *= normalized_alignment_metrics[:, :, None] + + # On TF backend, the final "4" becomes a dynamic shape so we include + # this to force it to a static shape of 4. This does not actually + # reshape the Tensor. + bbox_labels = ops.reshape(bbox_labels, (-1, num_anchors, 4)) + return ( + ops.stop_gradient(bbox_labels), + ops.stop_gradient(class_labels), + ops.stop_gradient( + # ops.cast(gt_box_matches_per_anchor > -1, "float32") + ops.cast(gt_box_matches_per_anchor > 0, "float32") + ), + ) + + def call( + self, scores, decode_bboxes, anchors, gt_labels, gt_bboxes, gt_mask + ): + """Computes target boxes and classes for anchors. + + Args: + scores: a Float Tensor of shape (batch_size, num_anchors, + num_classes) representing predicted class scores for each + anchor. + decode_bboxes: a Float Tensor of shape (batch_size, num_anchors, 4) + representing predicted boxes for each anchor. + anchors: a Float Tensor of shape (batch_size, num_anchors, 2) + representing the xy coordinates of the center of each anchor. + gt_labels: a Float Tensor of shape (batch_size, num_gt_boxes) + representing the classes of ground truth boxes. + gt_bboxes: a Float Tensor of shape (batch_size, num_gt_boxes, 4) + representing the ground truth bounding boxes in xyxy format. + gt_mask: A Boolean Tensor of shape (batch_size, num_gt_boxes) + representing whether a box in `gt_bboxes` is a real box or a + non-box that exists due to padding. + + Returns: + A tuple of the following: + - A Float Tensor of shape (batch_size, num_anchors, 4) + representing box targets for the model. + - A Float Tensor of shape (batch_size, num_anchors, num_classes) + representing class targets for the model. + - A Boolean Tensor of shape (batch_size, num_anchors) + representing whether each anchor was a match with a ground + truth box. Anchors that didn't match with a ground truth + box should be excluded from both class and box losses. + """ + if isinstance(gt_bboxes, tf.RaggedTensor): + dense_bounding_boxes = bounding_box.to_dense( + {"boxes": gt_bboxes, "classes": gt_labels}, + ) + gt_bboxes = dense_bounding_boxes["boxes"] + gt_labels = dense_bounding_boxes["classes"] + + if isinstance(gt_mask, tf.RaggedTensor): + gt_mask = gt_mask.to_tensor() + + max_num_boxes = ops.shape(gt_bboxes)[1] + + # If there are no GT boxes in the batch, we short-circuit and return + # empty targets to avoid NaNs. + return ops.cond( + ops.array(max_num_boxes > 0), + lambda: self.assign( + scores, decode_bboxes, anchors, gt_labels, gt_bboxes, gt_mask + ), + lambda: ( + ops.zeros_like(decode_bboxes), + ops.zeros_like(scores), + ops.zeros_like(scores[..., 0]), + ), + ) + + def count_params(self): + # The label encoder has no weights, so we short-circuit the weight + # counting to avoid having to `build` this layer unnecessarily. + return 0 + + def get_config(self): + config = { + "max_anchor_matches": self.max_anchor_matches, + "num_classes": self.num_classes, + "alpha": self.alpha, + "beta": self.beta, + "epsilon": self.epsilon, + } + base_config = super().get_config() + return dict(list(base_config.items()) + list(config.items())) diff --git a/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py new file mode 100644 index 0000000000..dcf625df0a --- /dev/null +++ b/keras_cv/src/models/segmentation/yolo_v8_segmentation/yolo_v8_segmentation.py @@ -0,0 +1,832 @@ +# Copyright 2023 The KerasCV Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import copy +import warnings + +from keras.layers import Activation +from keras.layers import Concatenate +from keras.layers import Conv2D +from keras.layers import Input +from keras.layers import Reshape +from keras.layers import UpSampling2D +from keras.losses import BinaryCrossentropy + +from keras_cv.src import bounding_box +from keras_cv.src.api_export import keras_cv_export +from keras_cv.src.backend import keras +from keras_cv.src.backend import ops +from keras_cv.src.layers import NonMaxSuppression +from keras_cv.src.losses.ciou_loss import CIoULoss +from keras_cv.src.models.backbones.backbone_presets import backbone_presets +from keras_cv.src.models.backbones.backbone_presets import ( + backbone_presets_with_weights, +) +from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector_presets import ( # noqa: E501 + yolo_v8_detector_presets, +) +from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_layers import ( + apply_conv_bn, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_backbone import ( # noqa: E501 + apply_path_aggregation_fpn, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_backbone import ( # noqa: E501 + apply_yolo_v8_head, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_backbone import ( # noqa: E501 + decode_regression_to_boxes, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_backbone import ( # noqa: E501 + dist2bbox, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_backbone import ( # noqa: E501 + get_anchors, +) +from keras_cv.src.models.segmentation.yolo_v8_segmentation.yolo_v8_label_encoder import ( # noqa: E501 + YOLOV8LabelEncoder, +) +from keras_cv.src.models.task import Task +from keras_cv.src.utils.python_utils import classproperty +from keras_cv.src.utils.train import get_feature_extractor + + +def build_mask_prototypes(x, dimension, num_prototypes, name="prototypes"): + """Builds mask prototype network. + + The outputs of this module are linearly combined with the regressed mask + coefficients to produce the predicted masks. This is an implementation of + the module proposed in YOLACT https://arxiv.org/abs/1904.02689. + + Args: + x: tensor, representing the output of a low backone featuremap i.e. P3. + dimension: integer, inner number of channels used for mask prototypes. + num_prototypes: integer, number of mask prototypes to build predictions. + name: string, a prefix for names of layers used by the prototypes. + + Returns: + Tensor whose resolution is double than the inputted tensor. + This tensor is used as a base to build linear combinations of masks. + """ + x = apply_conv_bn(x, dimension, 3, name=f"{name}_0") + x = apply_conv_bn(x, dimension, 3, name=f"{name}_1") + x = apply_conv_bn(x, dimension, 3, name=f"{name}_2") + x = UpSampling2D((2, 2), "channels_last", "bilinear", name=f"{name}_3")(x) + x = apply_conv_bn(x, dimension, 3, name=f"{name}_4") + x = Conv2D(num_prototypes, 1, padding="same", name=f"{name}_5")(x) + x = Activation("relu", name=name)(x) + return x + + +def build_branch_mask_coefficients(x, dimension, num_prototypes, branch_arg): + """Builds mask coefficients of a single branch. + + Args: + x: tensor, representing the outputs of a single branch of FPN i.e. P3. + dimension: integer, inner number of channels used for mask coefficients. + num_prototypes: integer, number of mask prototypes to build predictions. + branch_arg: integer, representing the branch number. This is used to + build the name of the tensors. + + Returns: + Tensor representing the coefficients used to regress the outputted masks + of a single branch. + """ + name = f"branch_{branch_arg}_mask_coefficients" + x = apply_conv_bn(x, dimension, 3, name=f"{name}_0") + x = apply_conv_bn(x, dimension, 3, name=f"{name}_1") + x = Conv2D(num_prototypes, 1, name=f"{name}_2")(x) + x = Activation("tanh", name=f"{name}_3")(x) + x = Reshape((-1, num_prototypes), name=f"{name}_4")(x) + return x + + +def build_mask_coefficients(branches, num_prototypes, dimension): + """Builds all mask coefficients. + + This coefficients represent the linear terms used to combine the masks. + + Args: + branches: list of tensors, representing the outputs of a backbone model. + num_prototypes: integer, number of mask prototypes to build predictions. + dimension: integer, inner number of channels used for mask coefficients. + + Returns: + Tensor representing the linear coefficients for regressing masks. + """ + coefficients = [] + for branch_arg, branch in enumerate(branches): + branch_coefficients = build_branch_mask_coefficients( + branch, dimension, num_prototypes, branch_arg + ) + coefficients.append(branch_coefficients) + return Concatenate(axis=1, name="coefficients")(coefficients) + + +def combine_linearly_prototypes(coefficients, prototypes): + """Linearly combines prototypes masks using the predicted coefficients. + + This applies equation 1 of YOLACT https://arxiv.org/abs/1904.02689. + + Args: + coefficients: tensor representing the linear coefficients of the + prototypes masks. + prototypes: tensor representing a base of masks that can be + linearly combined to produce predicted masks. + + Returns: + Tensor representing all the predicted masks. + """ + masks = ops.sigmoid(ops.einsum("bnm,bhwm->bnhw", coefficients, prototypes)) + return masks + + +def build_segmentation_head( + branches, prototype_dimension, num_prototypes, coefficient_dimension +): + """Builds a YOLACT https://arxiv.org/abs/1904.02689 segmentation head. + + The proposed segmentation head of YOLACT https://arxiv.org/abs/1904.02689 + predicts prototype masks, their linear coefficients, and combines them to + build the predicted masks. + + Args: + branches: list of tensors, representing the outputs of a backbone model. + prototype_dimension: integer, inner number of channels used for mask + prototypes. + num_prototypes: integer, number of mask prototypes to build predictions. + coefficient_dimension: integer, inner number of channels used for + predicting the mask coefficients. + + Returns: + Tensor representing all the predicted masks. + """ + prototypes = build_mask_prototypes( + branches[0], prototype_dimension, num_prototypes + ) + coefficients = build_mask_coefficients( + branches, num_prototypes, coefficient_dimension + ) + masks = combine_linearly_prototypes(coefficients, prototypes) + return masks + + +def split_masks(masks, num_classes): + """Splits single channel segmentation mask into different class channels. + + Args: + masks: tensor representing ground truth masks using a single + channel consisting of integers representing the pixel class. + num_classes: integer, total number of classes in the dataset. + + Returns: + tensor representing each class mask in a different channel. + """ + splitted_masks = [] + for class_arg in range(1, num_classes + 1): + splitted_masks.append(masks == class_arg) + splitted_masks = ops.concatenate(splitted_masks, axis=-1) + splitted_masks = ops.cast(splitted_masks, float) + return splitted_masks + + +def repeat_masks(masks, class_labels, num_classes): + """Repeats ground truth masks. + + Each ground truth mask channel is gathered using the assigned class label. + This is used to build a tensor with the same shape as the predicted masks + in order to compute the mask loss. + + Args: + masks: tensor representing ground truth masks using a single + channel consisting of integers representing the pixel class. + class_labels: tensor, with the assigned class labels in each anchor box. + The class labels are in a one-hot encoding vector form. + num_classes: integer, total number of classes in the dataset. + + Returns: + tensor representing each class mask in a different channel. + """ + class_args = ops.argmax(class_labels, axis=-1) + batch_shape = class_args.shape[0] + class_args = ops.reshape(class_args, (batch_shape, 1, 1, -1)) + splitted_masks = split_masks(masks, num_classes) + repeated_masks = ops.take_along_axis(splitted_masks, class_args, axis=-1) + return repeated_masks + + +def build_target_masks(true_masks, true_scores, H_mask, W_mask, num_classes): + """Build target masks by resizing and repeating ground truth masks. + + Resizes ground truth masks to the predicted tensor mask shape, and repeats + masks using the largest true score value. + + Args: + true_masks: tensor representing the ground truth masks. + true_scores: tensor with the class scores assigned by the label encoder. + num_classes: integer indicating the total number of classes. + + Returns: + Tensor with resized and repeated target masks. + """ + true_masks = ops.image.resize(true_masks, (H_mask, W_mask), "nearest") + true_masks = repeat_masks(true_masks, true_scores, num_classes) + true_masks = ops.moveaxis(true_masks, 3, 1) + return true_masks + + +def compute_box_areas(boxes): + """Computes area for bounding boxes + + Args: + boxes: (N, 4) or (batch_size, N, 4) float tensor, either batched + or unbatched boxes. + + Returns: + a float Tensor of [N] or [batch_size, N] + """ + y_min, x_min, y_max, x_max = ops.split(boxes[..., :4], 4, axis=-1) + box_areas = ops.squeeze((y_max - y_min) * (x_max - x_min), axis=-1) + return box_areas + + +def normalize_box_areas(box_areas, H, W): + """Normalizes box areas by dividing by the total image area. + + Args: + boxes: tensor of shape (B, N, 4) with bounding boxes in xyxy format. + H: integer indicating the mask height. + W: integer indicating the mask width. + + Returns: + Tensor of shape (B, N, 4). + """ + return box_areas / (H * W) + + +def get_backbone_pyramid_layer_names(backbone, level_names): + """Gets actual layer names from the provided pyramid levels inside backbone. + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + + Returns: + List of layer strings indicating the layer names of each level. + """ + layer_names = [] + for level_name in level_names: + layer_names.append(backbone.pyramid_level_inputs[level_name]) + return layer_names + + +def build_feature_extractor(backbone, level_names): + """Builds feature extractor directly from the level names + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + + Returns: + Keras Model with level names as outputs. + """ + layer_names = get_backbone_pyramid_layer_names(backbone, level_names) + extractor = get_feature_extractor(backbone, layer_names, level_names) + return extractor + + +def extend_branches(inputs, extractor, FPN_depth): + """Extends extractor model with a feature pyramid network. + + Args: + inputs: tensor, with image input. + extractor: Keras Model with level names as outputs. + FPN_depth: integer representing the feature pyramid depth. + + Returns: + List of extended branch tensors. + """ + features = list(extractor(inputs).values()) + branches = apply_path_aggregation_fpn(features, FPN_depth, name="pa_fpn") + return branches + + +def extend_backbone(backbone, level_names, trainable, FPN_depth): + """Extends backbone levels with a feature pyramid network. + + Args: + backbone: Keras backbone model with the field "pyramid_level_inputs". + level_names: list of strings indicating the level names. + trainable: boolean indicating if backbone should be optimized. + FPN_depth: integer representing the feature pyramid depth. + + Return: + Tuple with input image tensor, and list of extended branch tensors. + """ + feature_extractor = build_feature_extractor(backbone, level_names) + feature_extractor.trainable = trainable + inputs = Input(feature_extractor.input_shape[1:]) + branches = extend_branches(inputs, feature_extractor, FPN_depth) + return inputs, branches + + +def add_no_op_for_pretty_print(x, name): + """Wrap tensor with dummy operation to change tensor name. + + # Args: + x: tensor. + name: string name given to the tensor. + + Return: + Tensor with new wrapped name. + """ + return Concatenate(axis=1, name=name)([x]) + + +def unpack_input(data): + """Unpacks standard keras-cv data dictionary into inputs and outputs. + + Args: + data: Dictionary with the standard key-value pairs of keras-cv + + Returns: + Tuple containing inputs and outputs. + """ + classes = data["bounding_boxes"]["classes"] + boxes = data["bounding_boxes"]["boxes"] + segmentation_masks = data["segmentation_masks"] + y = { + "classes": classes, + "boxes": boxes, + "segmentation_masks": segmentation_masks, + } + return data["images"], y + + +def boxes_to_masks(boxes, H_mask, W_mask): + """Build mask with True values inside the bounding box and False elsewhere. + + Args: + boxes: tensor of shape (N, 4) with bounding boxes in xyxy format. + H_mask: integer indicating the height of the mask. + W_mask: integer indicating the width of the mask. + + Returns: + A mask of the specified shape with True values inside bounding box. + """ + x_min, y_min, x_max, y_max = ops.split(boxes, 4, 1) + + y_range = ops.arange(H_mask) + x_range = ops.arange(W_mask) + y_indices, x_indices = ops.meshgrid(y_range, x_range, indexing="ij") + + y_indices = ops.expand_dims(y_indices, 0) + x_indices = ops.expand_dims(x_indices, 0) + + x_min = ops.expand_dims(x_min, axis=1) + y_min = ops.expand_dims(y_min, axis=1) + x_max = ops.expand_dims(x_max, axis=1) + y_max = ops.expand_dims(y_max, axis=1) + + in_x_min_to_x_max = ops.logical_and(x_indices >= x_min, x_indices < x_max) + in_y_min_to_y_max = ops.logical_and(y_indices >= y_min, y_indices < y_max) + masks = ops.logical_and(in_x_min_to_x_max, in_y_min_to_y_max) + return masks + + +def batch_boxes_to_masks(boxes, H_mask, W_mask): + """Converts boxes to masks over the batch dimension. + + Args: + boxes: tensor of shape (B, N, 4) with bounding boxes in xyxy format. + H_mask: integer indicating the height of the mask. + W_mask: integer indicating the width of the mask. + + Returns: + Batch of masks with True values inside the bounding box. + """ + batch_size = boxes.shape[0] + crop_masks = [] + for batch_arg in range(batch_size): + boxes_sample = ops.cast(boxes[batch_arg], "int32") + crop_mask = boxes_to_masks(boxes_sample, H_mask, W_mask) + crop_masks.append(crop_mask[None]) + crop_masks = ops.concatenate(crop_masks) + crop_masks = ops.cast(crop_masks, "float32") + return crop_masks + + +def build_mask_weights(weight, boxes, H_mask, W_mask): + """Build mask sample weights used to scale the loss at every batch. + + To balance the loss of masks with different shapes, YOLACT assigns a weight + to each mask that is inversely proportional to its area. + + Args: + weight: float, weight multiplied to the mask loss. + boxes: tensor of shape (B, N, 4) with bounding boxes in xyxy format. + H_image: integer indicating the inputted image height. + W_image: integer indicating the inputted image width. + H_mask: integer indicating the predicted mask height. + W_mask: integer indicating the predicted mask width. + + Returns: + Tensor of shape [B, num_anchors, 1, 1] containing the mask weights. + """ + box_areas = compute_box_areas(boxes) + box_areas = normalize_box_areas(box_areas, H_mask, W_mask) + weights = ops.divide_no_nan(weight, box_areas) + weights = weights / (H_mask * W_mask) + return weights[..., None, None] + + +@keras_cv_export( + [ + "keras_cv.models.YOLOV8Segmentation", + "keras_cv.models.segmentation.YOLOV8Segmentation", + ] +) +class YOLOV8Segmentation(Task): + """Implements the YOLOV8 instance segmentation model. + + Args: + backbone: `keras.Model`, must implement the `pyramid_level_inputs` + property with keys "P3", "P4", and "P5" and layer names as values. + A sensible backbone to use is the `keras_cv.models.YOLOV8Backbone`. + num_classes: integer, the number of classes in your dataset excluding + the background class. Classes should be represented by integers in + the range [0, num_classes). + bounding_box_format: string, the format of bounding boxes of input + dataset. + fpn_depth: integer, a specification of the depth of the CSP blocks in + the Feature Pyramid Network. This is usually 1, 2, or 3, depending + on the size of your YOLOV8Detector model. We recommend using 3 for + "yolo_v8_l_backbone" and "yolo_v8_xl_backbone". Defaults to 2. + label_encoder: (Optional) A `YOLOV8LabelEncoder` that is + responsible for transforming input boxes into trainable labels for + YOLOV8Detector. If not provided, a default is provided. + prediction_decoder: (Optional) A `keras.layers.Layer` that is + responsible for transforming YOLOV8 predictions into usable + bounding boxes. If not provided, a default is provided. The + default `prediction_decoder` layer is a + `keras_cv.layers.MultiClassNonMaxSuppression` layer, which uses + a Non-Max Suppression for box pruning. + prototype_dimension: integer, inner number of channels used for mask + prototypes. Defaults to 256. + num_prototypes: integer, number of mask prototypes to build predictions. + Defaults to 32. + coefficient_dimension: integer, inner number of channels used for + predicting the mask coefficients. Defaults to 32 + trainable_backbone: boolean indicating if the provided backbone should + be trained as well. Defaults to False. + + Example: + ```python + images = tf.ones(shape=(1, 512, 512, 3)) + + model = keras_cv.models.YOLOV8Segmentation( + num_classes=20, + bounding_box_format="xywh", + backbone=keras_cv.models.YOLOV8Backbone.from_preset( + "yolo_v8_m_backbone_coco" + ), + fpn_depth=2 + ) + + # Evaluate model without box decoding and NMS + model(images) + + # Prediction with box decoding and NMS + model.predict(images) + ``` + """ + + def __init__( + self, + backbone, + num_classes, + bounding_box_format, + fpn_depth=2, + label_encoder=None, + prediction_decoder=None, + prototype_dimension=256, + num_prototypes=32, + coefficient_dimension=32, + trainable_backbone=False, + **kwargs, + ): + level_names = ["P3", "P4", "P5"] + images, branches = extend_backbone( + backbone, level_names, trainable_backbone, fpn_depth + ) + masks = build_segmentation_head( + branches, prototype_dimension, num_prototypes, coefficient_dimension + ) + detection_head = apply_yolo_v8_head(branches, num_classes) + boxes, classes = detection_head["boxes"], detection_head["classes"] + boxes = add_no_op_for_pretty_print(boxes, "box") + masks = add_no_op_for_pretty_print(masks, "masks") + classes = add_no_op_for_pretty_print(classes, "class") + outputs = {"boxes": boxes, "classes": classes, "masks": masks} + super().__init__(inputs=images, outputs=outputs, **kwargs) + + self.bounding_box_format = bounding_box_format + self._prediction_decoder = prediction_decoder or NonMaxSuppression( + bounding_box_format=bounding_box_format, + from_logits=False, + confidence_threshold=0.2, + iou_threshold=0.7, + ) + self.backbone = backbone + self.fpn_depth = fpn_depth + self.num_classes = num_classes + self.label_encoder = label_encoder or YOLOV8LabelEncoder( + num_classes=num_classes + ) + self.prototype_dimension = prototype_dimension + self.num_prototypes = num_prototypes + self.coefficient_dimension = coefficient_dimension + self.trainable_backbone = trainable_backbone + + def compile( + self, + box_loss, + classification_loss, + segmentation_loss, + box_loss_weight=7.5, + classification_loss_weight=0.5, + segmentation_loss_weight=6.125, + metrics=None, + **kwargs, + ): + """Compiles the YOLOV8Segmentation. + + `compile()` mirrors the standard Keras `compile()` method, but has one + key distinction -- two losses must be provided: `box_loss` and + `classification_loss`. + + Args: + box_loss: a Keras loss to use for box offset regression. A + preconfigured loss is given when the string "ciou" is passed. + classification_loss: a Keras loss to use for box classification. A + preconfigured loss is provided when the string + "binary_crossentropy" is passed. + segmentation_loss:a Keras loss for segmentation. + box_loss_weight: (optional) float, a scaling factor for the box + loss. Defaults to 7.5. + classification_loss_weight: (optional) float, a scaling factor for + the classification loss. Defaults to 0.5. + segmentation_loss_weight: (optional) float, a scaling factor for + the classification loss. Defaults to 6.125. + kwargs: most other `keras.Model.compile()` arguments are supported + and propagated to the `keras.Model` class. + """ + if metrics is not None: + raise ValueError("User metrics not yet supported for YOLOV8") + + if isinstance(box_loss, str): + if box_loss == "ciou": + box_loss = CIoULoss(bounding_box_format="xyxy", reduction="sum") + elif box_loss == "iou": + warnings.warn( + "YOLOV8 recommends using CIoU loss, but was configured to " + "use standard IoU. Consider using `box_loss='ciou'` " + "instead." + ) + else: + raise ValueError( + f"Invalid box loss for YOLOV8Detector: {box_loss}. Box " + "loss should be a keras.Loss or the string 'ciou'." + ) + if isinstance(classification_loss, str): + if classification_loss == "binary_crossentropy": + classification_loss = BinaryCrossentropy(reduction="sum") + else: + raise ValueError( + "Invalid classification loss for YOLOV8Detector: " + f"{classification_loss}. Classification loss should be a " + "keras.Loss or the string 'binary_crossentropy'." + ) + + if isinstance(segmentation_loss, str): + if segmentation_loss == "binary_crossentropy": + segmentation_loss = BinaryCrossentropy(reduction="sum") + else: + raise ValueError( + "Invalid segmentation loss for YOLOV8Detector: " + f"{classification_loss}. Classification loss should be a " + "keras.Loss or the string 'binary_crossentropy'." + ) + + self.box_loss = box_loss + self.classification_loss = classification_loss + self.segmentation_loss = segmentation_loss + self.box_loss_weight = box_loss_weight + self.classification_loss_weight = classification_loss_weight + self.segmentation_loss_weight = segmentation_loss_weight + + losses = { + "box": self.box_loss, + "class": self.classification_loss, + "masks": self.segmentation_loss, + } + + super().compile(loss=losses, **kwargs) + + def train_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().train_step(*args, (x, y)) + + def test_step(self, *args): + data = args[-1] + args = args[:-1] + x, y = unpack_input(data) + return super().test_step(*args, (x, y)) + + def compute_loss(self, x, y, y_pred, sample_weight=None, **kwargs): + box_pred, cls_pred = y_pred["boxes"], y_pred["classes"] + + pred_boxes = decode_regression_to_boxes(box_pred) + pred_scores = cls_pred + + anchor_points, stride_tensor = get_anchors(image_shape=x.shape[1:]) + stride_tensor = ops.expand_dims(stride_tensor, axis=-1) + + gt_labels = y["classes"] + + mask_gt = ops.all(y["boxes"] > -1.0, axis=-1, keepdims=True) + gt_bboxes = bounding_box.convert_format( + y["boxes"], + source=self.bounding_box_format, + target="xyxy", + images=x, + ) + + pred_bboxes = dist2bbox(pred_boxes, anchor_points) + + target_bboxes, target_scores, fg_mask = self.label_encoder( + pred_scores, + ops.cast(pred_bboxes * stride_tensor, gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + + target_bboxes /= stride_tensor + target_scores_sum = ops.maximum(ops.sum(target_scores), 1) + + box_weight = ops.expand_dims( + ops.sum(target_scores, axis=-1) * fg_mask, + axis=-1, + ) + + true_masks = y["segmentation_masks"] + pred_masks = y_pred["masks"] + batch_size, _, H_mask, W_mask = pred_masks.shape + true_masks = build_target_masks( + true_masks, target_scores, H_mask, W_mask, self.num_classes + ) + + crop_masks = batch_boxes_to_masks(target_bboxes, H_mask, W_mask) + H_image, W_image = x.shape[1:3] + mask_weights = build_mask_weights( + self.segmentation_loss_weight, target_bboxes, H_mask, W_mask + ) + + y_true = { + "box": target_bboxes * fg_mask[..., None], + "class": target_scores, + "masks": true_masks * crop_masks * fg_mask[..., None, None], + } + y_pred = { + "box": pred_bboxes * fg_mask[..., None], + "class": pred_scores, + "masks": pred_masks * crop_masks * fg_mask[..., None, None], + } + sample_weights = { + "box": self.box_loss_weight * box_weight / target_scores_sum, + "class": self.classification_loss_weight / target_scores_sum, + "masks": mask_weights, + } + + return super().compute_loss( + x=x, y=y_true, y_pred=y_pred, sample_weight=sample_weights, **kwargs + ) + + def decode_predictions(self, pred, images): + boxes = pred["boxes"] + scores = pred["classes"] + boxes = decode_regression_to_boxes(boxes) + + anchor_points, stride_tensor = get_anchors(image_shape=images.shape[1:]) + stride_tensor = ops.expand_dims(stride_tensor, axis=-1) + + box_preds = dist2bbox(boxes, anchor_points) * stride_tensor + box_preds = bounding_box.convert_format( + box_preds, + source="xyxy", + target=self.bounding_box_format, + images=images, + ) + + return self.prediction_decoder(box_preds, scores) + + def predict_step(self, *args): + outputs = super().predict_step(*args) + decoded_outputs = self.decode_predictions(outputs, args[-1]) + selected_args = decoded_outputs["idx"][..., None, None] + masks = outputs["masks"] + masks = ops.take_along_axis(masks, selected_args, axis=1) + is_valid_output = decoded_outputs["confidence"] > -1 + masks = ops.where(is_valid_output[..., None, None], masks, -1) + decoded_outputs["masks"] = masks + return decoded_outputs + + @property + def prediction_decoder(self): + return self._prediction_decoder + + @prediction_decoder.setter + def prediction_decoder(self, prediction_decoder): + if prediction_decoder.bounding_box_format != self.bounding_box_format: + raise ValueError( + "Expected `prediction_decoder` and YOLOV8Detector to " + "use the same `bounding_box_format`, but got " + "`prediction_decoder.bounding_box_format=" + f"{prediction_decoder.bounding_box_format}`, and " + "`self.bounding_box_format=" + f"{self.bounding_box_format}`." + ) + self._prediction_decoder = prediction_decoder + self.make_predict_function(force=True) + self.make_train_function(force=True) + self.make_test_function(force=True) + + def get_config(self): + return { + "backbone": keras.saving.serialize_keras_object(self.backbone), + "num_classes": self.num_classes, + "bounding_box_format": self.bounding_box_format, + "fpn_depth": self.fpn_depth, + "label_encoder": keras.saving.serialize_keras_object( + self.label_encoder + ), + "prediction_decoder": keras.saving.serialize_keras_object( + self._prediction_decoder + ), + "prototype_dimension": self.prototype_dimension, + "num_prototypes": self.num_prototypes, + "coefficient_dimension": self.coefficient_dimension, + "trainable_backbone": self.trainable_backbone, + } + + @classmethod + def from_config(cls, config): + config["backbone"] = keras.saving.deserialize_keras_object( + config["backbone"] + ) + label_encoder = config.get("label_encoder") + if label_encoder is not None and isinstance(label_encoder, dict): + config["label_encoder"] = keras.saving.deserialize_keras_object( + label_encoder + ) + prediction_decoder = config.get("prediction_decoder") + if prediction_decoder is not None and isinstance( + prediction_decoder, dict + ): + config["prediction_decoder"] = ( + keras.saving.deserialize_keras_object(prediction_decoder) + ) + return cls(**config) + + @classproperty + def presets(cls): + """Dictionary of preset names and configurations.""" + return copy.deepcopy({**backbone_presets, **yolo_v8_detector_presets}) + + @classproperty + def presets_with_weights(cls): + """Dictionary of preset names and configurations that include + weights.""" + return copy.deepcopy( + {**backbone_presets_with_weights, **yolo_v8_detector_presets} + ) + + @classproperty + def backbone_presets(cls): + """Dictionary of preset names and configurations of compatible + backbones.""" + return copy.deepcopy(backbone_presets)