Port Faster R-CNN to Keras3 #2458

Merged: 46 commits, Aug 20, 2024
Commits
c37ae23
Base structure for faster rcnn till rpn head
sineeli Jun 10, 2024
973dd6a
Add export for Faster RCNN
sineeli Jun 10, 2024
70c7f24
add init file
sineeli Jun 10, 2024
de67b89
initialize faster rcnn at model level
sineeli Jun 10, 2024
aaebe30
code fix for roi align
sineeli Jun 12, 2024
0707858
Forward Pass code for Faster R-CNN
sineeli Jun 12, 2024
cff3b8e
Faster RCNN Base code for Keras3(Draft-1)
sineeli Jun 25, 2024
4f511e9
Add local batch size
sineeli Jun 25, 2024
0eef933
Add parameters to RPN Head
sineeli Jul 2, 2024
75c64ca
Make FPN more customizable with parameters and remove redundant code
sineeli Jul 2, 2024
6267a4b
Compute output shape for ROI Generator
sineeli Jul 2, 2024
1931f02
Faster RCNN functional model with required import corrections
sineeli Jul 2, 2024
58dc7f9
add clip boxes to forward pass
sineeli Jul 8, 2024
7c65348
add prediction decoder and use "yxyx" as default internal bounding bo…
sineeli Jul 11, 2024
676fcf1
feature pyramid correction
sineeli Jul 16, 2024
dcea19f
change ops.divide to ops.divide_no_nan
sineeli Jul 29, 2024
2179157
use from_logits=True for Non Max suppression
sineeli Jul 29, 2024
a002c49
include box conversions for both rois and ground truth boxes
sineeli Jul 29, 2024
5953f0a
Change number of detections in decoder
sineeli Jul 29, 2024
91f21fa
Use categoricalcrossentropy to avoid -1 class error + added get_confi…
sineeli Jul 30, 2024
abf0b44
add basic test cases + linting
sineeli Jul 30, 2024
d2b78e0
Add seed generator for sampling in RPN label encoding and ROI samplin…
sineeli Jul 30, 2024
a397a6c
Use only spatial dimension for ops.nn.avg_pool + use ops.convert_to_t…
sineeli Jul 30, 2024
e336d69
Convert list to tensor using keras ops
sineeli Jul 30, 2024
ecd0dad
Remove seed number from seed generator
sineeli Jul 31, 2024
c91ac27
Remove print and add proper comments
sineeli Aug 5, 2024
ba86502
- Use stddev(0.01) as per paper across RPN and R-CNN Heads
sineeli Aug 8, 2024
4979a99
- Fixes slice for multi backend
sineeli Aug 8, 2024
357a14a
- Add compute metrics method
sineeli Aug 9, 2024
ef27533
Correct test cases and add missing args
sineeli Aug 12, 2024
f37d799
Fix lint issues
sineeli Aug 13, 2024
36d4e10
- Fix lint and remove hard coded params to make it user friendly.
sineeli Aug 13, 2024
5060382
- Generate ROI's while decoding for predictions
sineeli Aug 14, 2024
02d24b0
- Add faster rcnn to build method
sineeli Aug 14, 2024
c0556d8
- Test only for Keras3
sineeli Aug 14, 2024
879028f
- Correct test case
sineeli Aug 15, 2024
c77d03c
- Correct the test cases decorator to skip for Keras2
sineeli Aug 16, 2024
10b9e76
- Skip Legacy test cases
sineeli Aug 16, 2024
e1d89e7
- Remove unnecessary import in legacy code to fix lint
sineeli Aug 16, 2024
58178c6
- Correct pytest complexity
sineeli Aug 16, 2024
1c6125b
- Fix Image Shape to 512, 512 default which will not break other test…
sineeli Aug 16, 2024
df56fa6
- Lower image sizes for test cases
sineeli Aug 19, 2024
6b03271
- fix keras to 3.3.3 version
sineeli Aug 20, 2024
8608516
- Generate api
sineeli Aug 20, 2024
d1f05af
- Lint fix
sineeli Aug 20, 2024
8360e5b
- Increase the atol, rtol for YOLOv8 Detector forward pass
sineeli Aug 20, 2024
1 change: 1 addition & 0 deletions .github/workflows/actions.yml
@@ -94,6 +94,7 @@ jobs:
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
--durations 0
5 changes: 5 additions & 0 deletions .kokoro/github/ubuntu/gpu/build.sh
@@ -36,20 +36,23 @@ then
pip install -r requirements-tensorflow-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3

elif [ "$KERAS_BACKEND" == "jax" ]
then
echo "JAX backend detected."
pip install -r requirements-jax-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3

elif [ "$KERAS_BACKEND" == "torch" ]
then
echo "PyTorch backend detected."
pip install -r requirements-torch-cuda.txt --progress-bar off --timeout 1000
pip install keras-nlp-nightly --no-deps
pip install tensorflow-text~=2.16.0
pip install keras~=3.3.3
fi

pip install --no-deps -e "." --progress-bar off
@@ -67,6 +70,7 @@ then
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
@@ -82,6 +86,7 @@ else
keras_cv/src/models/classification \
keras_cv/src/models/object_detection/retinanet \
keras_cv/src/models/object_detection/yolo_v8 \
keras_cv/src/models/object_detection/faster_rcnn \
keras_cv/src/models/object_detection_3d \
keras_cv/src/models/segmentation \
keras_cv/src/models/feature_extractor/clip \
4 changes: 4 additions & 0 deletions keras_cv/api/models/__init__.py
@@ -5,6 +5,7 @@
"""

from keras_cv.api.models import classification
from keras_cv.api.models import faster_rcnn
from keras_cv.api.models import feature_extractor
from keras_cv.api.models import object_detection
from keras_cv.api.models import retinanet
@@ -205,6 +206,9 @@
from keras_cv.src.models.classification.image_classifier import ImageClassifier
from keras_cv.src.models.classification.video_classifier import VideoClassifier
from keras_cv.src.models.feature_extractor.clip.clip_model import CLIP
from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_backbone import (
YOLOV8Backbone,
11 changes: 11 additions & 0 deletions keras_cv/api/models/faster_rcnn/__init__.py
@@ -0,0 +1,11 @@
"""DO NOT EDIT.

This file was autogenerated. Do not edit it by hand,
since your modifications would be overwritten.
"""

from keras_cv.src.models.object_detection.faster_rcnn.feature_pyramid import (
FeaturePyramid,
)
from keras_cv.src.models.object_detection.faster_rcnn.rcnn_head import RCNNHead
from keras_cv.src.models.object_detection.faster_rcnn.rpn_head import RPNHead
3 changes: 3 additions & 0 deletions keras_cv/api/models/object_detection/__init__.py
@@ -4,6 +4,9 @@
since your modifications would be overwritten.
"""

from keras_cv.src.models.object_detection.faster_rcnn.faster_rcnn import (
FasterRCNN,
)
from keras_cv.src.models.object_detection.retinanet.retinanet import RetinaNet
from keras_cv.src.models.object_detection.yolo_v8.yolo_v8_detector import (
YOLOV8Detector,
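
For context, a minimal usage sketch of the newly exported `FasterRCNN`. Only the export path `keras_cv.models.FasterRCNN` comes from this diff; the constructor arguments and the ResNet50 preset below are assumptions based on other keras-cv detectors such as `RetinaNet`, so check the PR's `faster_rcnn.py` for the authoritative signature.

```python
# Usage sketch only: argument names are assumed from other keras-cv
# detectors (e.g. RetinaNet) and may differ from the final API.
import numpy as np
import keras_cv

model = keras_cv.models.FasterRCNN(
    backbone=keras_cv.models.ResNet50Backbone.from_preset("resnet50_imagenet"),
    num_classes=20,
    bounding_box_format="xywh",
)

# Forward pass on a dummy batch; 512x512 matches the default image shape
# referenced in the commit history above.
images = np.random.uniform(size=(1, 512, 512, 3)).astype("float32")
outputs = model(images)
```
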
2 changes: 1 addition & 1 deletion keras_cv/src/bounding_box/utils.py
@@ -141,7 +141,7 @@ def _clip_boxes(boxes, box_format, image_shape):

if isinstance(image_shape, list) or isinstance(image_shape, tuple):
height, width, _ = image_shape
max_length = [height, width, height, width]
max_length = ops.stack([height, width, height, width], axis=-1)
else:
image_shape = ops.cast(image_shape, dtype=boxes.dtype)
height = image_shape[0]
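
The `_clip_boxes` change above replaces a Python list with `ops.stack` so the per-coordinate maxima broadcast correctly on every backend, including when height and width arrive as tensors. A hedged, self-contained sketch of that clipping logic follows; the function name and dtype handling are illustrative, not the exact keras-cv internals.

```python
from keras import ops

def clip_boxes_yxyx(boxes, image_shape):
    """Clip boxes of shape (..., [y1, x1, y2, x2]) to the image extent."""
    height, width, _ = image_shape
    # ops.stack works for both Python ints and backend tensors.
    max_length = ops.stack([height, width, height, width], axis=-1)
    max_length = ops.cast(max_length, boxes.dtype)
    return ops.maximum(ops.minimum(boxes, max_length), 0.0)

boxes = ops.convert_to_tensor([[[-10.0, 5.0, 600.0, 300.0]]])
print(clip_boxes_yxyx(boxes, (512, 512, 3)))  # -> [[[0., 5., 512., 300.]]]
```
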
51 changes: 33 additions & 18 deletions keras_cv/src/layers/object_detection/roi_align.py
@@ -69,9 +69,7 @@ def _feature_bilinear_interpolation(features, kernel_y, kernel_x):
features,
[batch_size * num_boxes, output_size * 2, output_size * 2, num_filters],
)
features = ops.nn.average_pool(
features, [1, 2, 2, 1], [1, 2, 2, 1], "VALID"
)
features = ops.nn.average_pool(features, (2, 2), (2, 2), "VALID")
features = ops.reshape(
features, [batch_size, num_boxes, output_size, output_size, num_filters]
)
@@ -242,6 +240,11 @@ def multilevel_crop_and_resize(
for i in range(len(feature_widths) - 1):
level_dim_offsets.append(level_dim_offsets[i] + level_dim_sizes[i])
batch_dim_size = level_dim_offsets[-1] + level_dim_sizes[-1]

level_dim_offsets = ops.convert_to_tensor(level_dim_offsets)
feature_widths = ops.convert_to_tensor(feature_widths)
feature_heights = ops.convert_to_tensor(feature_heights)

level_dim_offsets = (
ops.ones_like(level_dim_offsets, dtype="int32") * level_dim_offsets
)
@@ -259,7 +262,9 @@
# following the FPN paper to divide by 224.
levels = ops.cast(
ops.floor_divide(
ops.log(ops.divide(areas_sqrt, 224.0)),
ops.log(
ops.divide_no_nan(areas_sqrt, ops.convert_to_tensor(224.0))
),
ops.log(2.0),
)
+ 4.0,
@@ -292,12 +297,18 @@
ops.concatenate(
[
ops.expand_dims(
[[ops.cast(max_feature_height, "float32")]] / level_strides
ops.convert_to_tensor(
[[ops.cast(max_feature_height, "float32")]]
)
/ level_strides
- 1,
axis=-1,
),
ops.expand_dims(
[[ops.cast(max_feature_width, "float32")]] / level_strides
ops.convert_to_tensor(
[[ops.cast(max_feature_width, "float32")]]
)
/ level_strides
- 1,
axis=-1,
),
@@ -357,7 +368,7 @@ def multilevel_crop_and_resize(
# TODO(tanzhenyu): replace tf.gather with tf.gather_nd and try to get
# similar performance.
features_per_box = ops.reshape(
ops.take(features_r2, indices),
ops.take(features_r2, indices, axis=0),
[
batch_size,
num_boxes,
@@ -378,7 +389,7 @@ def multilevel_crop_and_resize(
# performance as this is mostly a duplicate of
# https://github.com/tensorflow/models/blob/master/official/legacy/detection/ops/spatial_transform_ops.py#L324
@keras.utils.register_keras_serializable(package="keras_cv")
class _ROIAligner(keras.layers.Layer):
class ROIAligner(keras.layers.Layer):
"""Performs ROIAlign for the second stage processing."""

def __init__(
@@ -397,13 +408,11 @@ def __init__(
sample_offset: A `float` in [0, 1] of the subpixel sample offset.
**kwargs: Additional keyword arguments passed to Layer.
"""
# assert_tf_keras("keras_cv.layers._ROIAligner")
self._config_dict = {
"bounding_box_format": bounding_box_format,
"crop_size": target_size,
"sample_offset": sample_offset,
}
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.target_size = target_size
self.sample_offset = sample_offset
self.built = True

def call(
self,
@@ -427,16 +436,22 @@ def call(
"""
boxes = bounding_box.convert_format(
boxes,
source=self._config_dict["bounding_box_format"],
source=self.bounding_box_format,
target="yxyx",
)
roi_features = multilevel_crop_and_resize(
features,
boxes,
output_size=self._config_dict["crop_size"],
sample_offset=self._config_dict["sample_offset"],
output_size=self.target_size,
sample_offset=self.sample_offset,
)
return roi_features

def get_config(self):
return self._config_dict
config = super().get_config()
config["bounding_box_format"] = self.bounding_box_format
config["target_size"] = self.target_size
config["sample_offset"] = self.sample_offset

def compute_output_shape(self, input_shape):
return (None, None, self.target_size, self.target_size, 256)
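
Most of the `roi_align.py` changes are backend-portability fixes: explicit `ops.convert_to_tensor` calls, `ops.divide_no_nan` as a guard against invalid divisions, and `axis=0` on `ops.take`. The level-assignment rule they touch is the FPN heuristic `level = floor(log2(sqrt(area) / 224)) + 4`, clamped to the pyramid range. Below is a small numeric sketch of that rule, not the keras-cv implementation; the pyramid bounds are illustrative.

```python
from keras import ops

def fpn_level(boxes_yxyx, min_level=2, max_level=5):
    """Assign each box to a feature-pyramid level from its area."""
    heights = boxes_yxyx[..., 2] - boxes_yxyx[..., 0]
    widths = boxes_yxyx[..., 3] - boxes_yxyx[..., 1]
    areas_sqrt = ops.sqrt(ops.maximum(heights * widths, 0.0))
    # floor(log2(sqrt(area) / 224)) + 4, written with the same ops as the diff.
    levels = ops.floor_divide(
        ops.log(ops.divide_no_nan(areas_sqrt, ops.convert_to_tensor(224.0))),
        ops.log(2.0),
    ) + 4.0
    return ops.clip(levels, min_level, max_level)

boxes = ops.convert_to_tensor(
    [[0.0, 0.0, 224.0, 224.0],   # 224x224 box -> level 4
     [0.0, 0.0, 56.0, 56.0]]     # 56x56 box   -> level 2
)
print(fpn_level(boxes))
```
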
8 changes: 7 additions & 1 deletion keras_cv/src/layers/object_detection/roi_generator.py
@@ -68,6 +68,7 @@ class ROIGenerator(keras.layers.Layer):
applying NMS in inference mode. When RPN is run on multiple
feature maps / levels (as in FPN) this number is per
feature map / level.
nms_from_logits: bool. If True, the input scores are treated as logits; if False, as confidence scores.

Example:
```python
@@ -90,6 +91,7 @@ def __init__(
nms_score_threshold_test: float = 0.0,
nms_iou_threshold_test: float = 0.7,
post_nms_topk_test: int = 1000,
nms_from_logits: bool = False,
**kwargs,
):
super().__init__(**kwargs)
@@ -102,6 +104,7 @@ def __init__(
self.nms_score_threshold_test = nms_score_threshold_test
self.nms_iou_threshold_test = nms_iou_threshold_test
self.post_nms_topk_test = post_nms_topk_test
self.nms_from_logits = nms_from_logits
self.built = True

def call(
@@ -158,7 +161,7 @@ def per_level_gen(boxes, scores):
# TODO(tanzhenyu): consider supporting soft / batched nms for accl
boxes = NonMaxSuppression(
bounding_box_format=self.bounding_box_format,
from_logits=False,
from_logits=self.nms_from_logits,
iou_threshold=nms_iou_threshold,
confidence_threshold=nms_score_threshold,
max_detections=level_post_nms_topk,
@@ -191,6 +194,9 @@ def per_level_gen(boxes, scores):

return rois, roi_scores

def compute_output_shape(self, input_shape):
return (None, None, 4), (None, None, 1)

def get_config(self):
config = {
"bounding_box_format": self.bounding_box_format,
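
The new `nms_from_logits` flag lets the RPN's raw objectness outputs flow into NMS without an explicit sigmoid, matching the "use from_logits=True for Non Max suppression" commit. A hedged construction sketch follows; the dict-per-level call convention and the shapes are taken from the layer's existing docstring example and are illustrative only.

```python
from keras import random
from keras_cv.src.layers.object_detection.roi_generator import ROIGenerator

roi_generator = ROIGenerator(
    bounding_box_format="yxyx",
    nms_from_logits=True,  # RPN objectness scores are raw logits
)

# One feature level with 100 proposals per image (illustrative shapes).
multilevel_boxes = {2: random.normal((1, 100, 4))}
multilevel_scores = {2: random.normal((1, 100))}
rois, roi_scores = roi_generator(
    multilevel_boxes, multilevel_scores, training=False
)
```
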
44 changes: 24 additions & 20 deletions keras_cv/src/layers/object_detection/roi_sampler.py
@@ -22,7 +22,7 @@


@keras.utils.register_keras_serializable(package="keras_cv")
class _ROISampler(keras.layers.Layer):
class ROISampler(keras.layers.Layer):
"""
Sample ROIs for loss related calculation.

@@ -41,9 +41,10 @@ class _ROISampler(keras.layers.Layer):
if its range is [0, num_classes).

Args:
bounding_box_format: The format of bounding boxes to generate. Refer
roi_bounding_box_format: The format of roi bounding boxes. Refer
[to the keras.io docs](https://keras.io/api/keras_cv/bounding_box/formats/)
for more details on supported bounding box formats.
gt_bounding_box_format: The format of ground truth bounding boxes.
roi_matcher: a `BoxMatcher` object that matches proposals with ground
truth boxes. The positive match must be 1 and negative match must be -1.
Such assumption is not being validated here.
@@ -59,7 +60,8 @@ class _ROISampler(keras.layers.Layer):

def __init__(
self,
bounding_box_format: str,
roi_bounding_box_format: str,
gt_bounding_box_format: str,
roi_matcher: box_matcher.BoxMatcher,
positive_fraction: float = 0.25,
background_class: int = 0,
@@ -68,12 +70,14 @@ def __init__(
**kwargs,
):
super().__init__(**kwargs)
self.bounding_box_format = bounding_box_format
self.roi_bounding_box_format = roi_bounding_box_format
self.gt_bounding_box_format = gt_bounding_box_format
self.roi_matcher = roi_matcher
self.positive_fraction = positive_fraction
self.background_class = background_class
self.num_sampled_rois = num_sampled_rois
self.append_gt_boxes = append_gt_boxes
self.seed_generator = keras.random.SeedGenerator()
self.built = True
# for debugging.
self._positives = keras.metrics.Mean()
@@ -97,6 +101,12 @@ def call(
sampled_gt_classes: [batch_size, num_sampled_rois, 1]
sampled_class_weights: [batch_size, num_sampled_rois, 1]
"""
rois = bounding_box.convert_format(
rois, source=self.roi_bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.gt_bounding_box_format, target="yxyx"
)
if self.append_gt_boxes:
# num_rois += num_gt
rois = ops.concatenate([rois, gt_boxes], axis=1)
@@ -110,12 +120,6 @@ def call(
"num_rois must be less than `num_sampled_rois` "
f"({self.num_sampled_rois}), got {num_rois}"
)
rois = bounding_box.convert_format(
rois, source=self.bounding_box_format, target="yxyx"
)
gt_boxes = bounding_box.convert_format(
gt_boxes, source=self.bounding_box_format, target="yxyx"
)
# [batch_size, num_rois, num_gt]
similarity_mat = iou.compute_iou(
rois, gt_boxes, bounding_box_format="yxyx", use_masking=True
@@ -171,6 +175,7 @@ def call(
negative_matches,
self.num_sampled_rois,
self.positive_fraction,
seed=self.seed_generator,
)
# [batch_size, num_sampled_rois] in the range of [0, num_rois)
sampled_indicators, sampled_indices = ops.top_k(
@@ -204,16 +209,15 @@ def call(
)

def get_config(self):
config = {
"bounding_box_format": self.bounding_box_format,
"positive_fraction": self.positive_fraction,
"background_class": self.background_class,
"num_sampled_rois": self.num_sampled_rois,
"append_gt_boxes": self.append_gt_boxes,
"roi_matcher": self.roi_matcher.get_config(),
}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super().get_config()
config["roi_bounding_box_format"] = self.roi_bounding_box_format
config["gt_bounding_box_format"] = self.gt_bounding_box_format
config["positive_fraction"] = self.positive_fraction
config["background_class"] = self.background_class
config["num_sampled_rois"] = self.num_sampled_rois
config["append_gt_boxes"] = self.append_gt_boxes
config["roi_matcher"] = self.roi_matcher.get_config()
return config

@classmethod
def from_config(cls, config, custom_objects=None):
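
With the split `roi_bounding_box_format` / `gt_bounding_box_format` arguments, RPN proposals and user-supplied ground truth no longer need to share a format; both are converted to "yxyx" before IoU matching and sampling. A hedged construction sketch follows; the matcher thresholds, match values, and ROI count are illustrative choices, not prescribed by this diff.

```python
from keras_cv.src.layers.object_detection.box_matcher import BoxMatcher
from keras_cv.src.layers.object_detection.roi_sampler import ROISampler

# Illustrative IoU bands: below 0.0 -> ignore (-2), [0.0, 0.5) -> negative (-1),
# 0.5 and above -> positive (1), matching the sampler's 1/-1 convention.
matcher = BoxMatcher(thresholds=[0.0, 0.5], match_values=[-2, -1, 1])

roi_sampler = ROISampler(
    roi_bounding_box_format="yxyx",  # format of the RPN proposals
    gt_bounding_box_format="xywh",   # format of the ground-truth boxes
    roi_matcher=matcher,
    num_sampled_rois=512,
)
# Sampled boxes, classes, and weights come back aligned per ROI, e.g.:
# sampled_rois, sampled_gt_boxes, ..., weights = roi_sampler(rois, gt_boxes, gt_classes)
```
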