-
Notifications
You must be signed in to change notification settings - Fork 104
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from oarriaga/refactor_boxes
Refactor boxes
- Loading branch information
Showing
12 changed files
with
395 additions
and
139 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import numpy as np | ||
from paz.backend.boxes import compute_ious, to_corner_form | ||
|
||
|
||
def match(boxes, prior_boxes, iou_threshold=0.5):
    """Assigns a ground truth box (a row of `boxes`) to every prior box.

    Every prior box is paired with the ground truth box with which it has
    the highest intersection over union (IOU). In addition, the prior box
    that best overlaps each ground truth box is forced to be matched with
    that ground truth box. Prior boxes whose IOU is below `iou_threshold`
    are marked as negative by setting their class argument to 0.

    # Arguments
        boxes: Numpy array of shape `(num_ground_truth_boxes, 4 + 1)`,
            where the first four coordinates correspond to box
            coordinates and the last coordinate is the class argument.
            These boxes should be the ground truth boxes.
        prior_boxes: Numpy array of shape `(num_prior_boxes, 4)`,
            where the four coordinates are in center form coordinates.
        iou_threshold: Float between [0, 1]. Intersection over union
            used to determine which box is considered a positive box.

    # Returns
        Numpy array of shape `(num_prior_boxes, 4 + 1)`,
            where the first four coordinates correspond to point form
            box coordinates and the last coordinate is the class argument.
    """
    corner_form_priors = to_corner_form(np.float32(prior_boxes))
    ious = compute_ious(boxes, corner_form_priors)

    # axis 0 runs over ground truth boxes, axis 1 over prior boxes
    best_prior_per_box = np.argmax(ious, axis=1)
    best_box_per_prior = np.argmax(ious, axis=0)
    best_iou_per_prior = np.max(ious, axis=0)

    # the best prior of every ground truth box gets an IOU of 2, a value
    # above any threshold in [0, 1], so it is never marked as negative
    best_iou_per_prior[best_prior_per_box] = 2
    # when several ground truth boxes share their best prior box the last
    # ground truth box is the one that is kept
    for box_arg, prior_arg in enumerate(best_prior_per_box):
        best_box_per_prior[prior_arg] = box_arg

    # indexing with an integer array copies, `boxes` is left untouched
    matches = boxes[best_box_per_prior]
    is_negative = best_iou_per_prior < iou_threshold
    matches[is_negative, 4] = 0
    return matches
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
|
||
from paz import processors as pr | ||
from paz.abstract import SequentialProcessor | ||
from processors import MatchBoxes | ||
|
||
|
||
class AugmentImage(SequentialProcessor):
    """Color augmentation pipeline for RGB images.

    Applies, in this order, random contrast, random brightness, random
    saturation and random hue processors.
    """
    def __init__(self):
        super(AugmentImage, self).__init__()
        color_processors = (
            pr.RandomContrast(),
            pr.RandomBrightness(),
            pr.RandomSaturation(0.7),
            pr.RandomHue(),
        )
        for processor in color_processors:
            self.add(processor)
|
||
|
||
class PreprocessImage(SequentialProcessor):
    """Preprocesses an RGB image by resizing it to the given ``shape`` and
    casting it to float. If a ``mean`` is given it is subtracted from the
    image; if ``mean`` is ``None`` the image gets normalized instead.

    # Arguments
        shape: List of two Ints.
        mean: List of three Ints indicating the per-channel mean to be
            subtracted, or ``None`` to normalize the image.
    """
    def __init__(self, shape, mean=pr.BGR_IMAGENET_MEAN):
        super(PreprocessImage, self).__init__()
        self.add(pr.ResizeImage(shape))
        self.add(pr.CastImage(float))
        # exactly one of the two scaling processors closes the pipeline
        scale_image = (
            pr.SubtractMeanImage(mean) if mean is not None
            else pr.NormalizeImage())
        self.add(scale_image)
|
||
|
||
class AugmentBoxes(SequentialProcessor):
    """Data augmentation pipeline that involves bounding boxes.

    Boxes are moved to image coordinates, then expansion, random sample
    cropping and random left-right flipping are applied, and finally the
    boxes are moved back to normalized coordinates.

    # Arguments
        mean: List of three elements used to fill empty image spaces.
    """
    def __init__(self, mean=pr.BGR_IMAGENET_MEAN):
        super(AugmentBoxes, self).__init__()
        box_processors = (
            pr.ToImageBoxCoordinates(),
            pr.Expand(mean=mean),
            pr.RandomSampleCrop(),
            pr.RandomFlipBoxesLeftRight(),
            pr.ToNormalizedBoxCoordinates(),
        )
        for processor in box_processors:
            self.add(processor)
|
||
|
||
class PreprocessBoxes(SequentialProcessor):
    """Preprocesses bounding boxes by matching them with the prior boxes,
    encoding them and converting their class to a one-hot vector.

    # Arguments
        num_classes: Int.
        prior_boxes: Numpy array of shape ``[num_boxes, 4]`` containing
            prior/default bounding boxes.
        IOU: Float. Intersection over union used to match boxes.
        variances: List of floats indicating the variances used for
            encoding bounding boxes.
    """
    def __init__(self, num_classes, prior_boxes, IOU, variances):
        super(PreprocessBoxes, self).__init__()
        match_boxes = MatchBoxes(prior_boxes, IOU)
        encode_boxes = pr.EncodeBoxes(prior_boxes, variances)
        class_to_one_hot = pr.BoxClassToOneHotVector(num_classes)
        for processor in (match_boxes, encode_boxes, class_to_one_hot):
            self.add(processor)
|
||
|
||
class AugmentDetection(SequentialProcessor):
    """Augment boxes and images for object detection.

    # Arguments
        prior_boxes: Numpy array of shape ``[num_boxes, 4]`` containing
            prior/default bounding boxes.
        split: Flag from ``paz.processors.TRAIN``, ``paz.processors.VAL``
            or ``paz.processors.TEST``. Image and box augmentations are
            only added to the pipeline when the flag is
            ``paz.processors.TRAIN``.
        num_classes: Int.
        size: Int. Image size.
        mean: List of three elements indicating the per channel mean.
        IOU: Float. Intersection over union used to match boxes.
        variances: List of floats indicating the variances used for
            encoding bounding boxes. If ``None`` the values
            ``[0.1, 0.1, 0.2, 0.2]`` are used.
    """
    def __init__(self, prior_boxes, split=pr.TRAIN, num_classes=21, size=300,
                 mean=pr.BGR_IMAGENET_MEAN, IOU=.5, variances=None):
        super(AugmentDetection, self).__init__()
        # a list default would be created once and shared by every instance;
        # a new list is built per instance instead
        if variances is None:
            variances = [0.1, 0.1, 0.2, 0.2]

        # image processors
        self.augment_image = AugmentImage()
        self.augment_image.add(pr.ConvertColorSpace(pr.RGB2BGR))
        self.preprocess_image = PreprocessImage((size, size), mean)

        # box processors
        # NOTE(review): `mean` is not forwarded to AugmentBoxes, which falls
        # back to its own default fill value -- confirm this is intended.
        self.augment_boxes = AugmentBoxes()
        args = (num_classes, prior_boxes, IOU, variances)
        self.preprocess_boxes = PreprocessBoxes(*args)

        # pipeline: index 0 carries the image and index 1 the boxes
        self.add(pr.UnpackDictionary(['image', 'boxes']))
        self.add(pr.ControlMap(pr.LoadImage(), [0], [0]))
        if split == pr.TRAIN:
            self.add(pr.ControlMap(self.augment_image, [0], [0]))
            self.add(pr.ControlMap(self.augment_boxes, [0, 1], [0, 1]))
        self.add(pr.ControlMap(self.preprocess_image, [0], [0]))
        self.add(pr.ControlMap(self.preprocess_boxes, [1], [1]))
        self.add(pr.SequenceWrapper(
            {0: {'image': [size, size, 3]}},
            {1: {'boxes': [len(prior_boxes), 4 + num_classes]}}))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from paz.abstract import Processor | ||
from boxes import match | ||
|
||
|
||
class MatchBoxes(Processor):
    """Processor that matches prior boxes with ground truth boxes.

    # Arguments
        prior_boxes: Numpy array of shape (num_boxes, 4).
        iou: Float in [0, 1]. Intersection over union above which prior
            boxes will be considered positive. A positive box is a box
            with a class different than `background`.
    """
    def __init__(self, prior_boxes, iou=.5):
        self.prior_boxes, self.iou = prior_boxes, iou
        super(MatchBoxes, self).__init__()

    def call(self, boxes):
        # delegates to `match` using the stored prior boxes and threshold
        return match(boxes, self.prior_boxes, self.iou)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.