Skip to content
This repository has been archived by the owner on Mar 12, 2024. It is now read-only.

Adding changes for Multiple Input Compatibility with the code along with using Mask threshold as config variable. Also added easier implementation of BITMASK format in the code. #563

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions d2/configs/detr_segm_256_6_6_torchvision.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ MODEL:
PIXEL_MEAN: [123.675, 116.280, 103.530]
PIXEL_STD: [58.395, 57.120, 57.375]
MASK_ON: True
MASK_THRESHOLD: 0.5
RESNETS:
DEPTH: 50
STRIDE_IN_1X1: False
Expand Down
10 changes: 6 additions & 4 deletions d2/detr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__(self, cfg):

self.num_classes = cfg.MODEL.DETR.NUM_CLASSES
self.mask_on = cfg.MODEL.MASK_ON
self.input_format = cfg.INPUT.MASK_FORMAT
hidden_dim = cfg.MODEL.DETR.HIDDEN_DIM
num_queries = cfg.MODEL.DETR.NUM_OBJECT_QUERIES
# Transformer parameters:
Expand Down Expand Up @@ -150,8 +151,8 @@ def __init__(self, cfg):
)
self.criterion.to(self.device)

pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(3, 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(3, 1, 1)
pixel_mean = torch.Tensor(cfg.MODEL.PIXEL_MEAN).to(self.device).view(len(cfg.MODEL.PIXEL_MEAN), 1, 1)
pixel_std = torch.Tensor(cfg.MODEL.PIXEL_STD).to(self.device).view(len(cfg.MODEL.PIXEL_STD), 1, 1)
self.normalizer = lambda x: (x - pixel_mean) / pixel_std
self.to(self.device)

Expand Down Expand Up @@ -210,7 +211,8 @@ def prepare_targets(self, targets):
new_targets.append({"labels": gt_classes, "boxes": gt_boxes})
if self.mask_on and hasattr(targets_per_image, 'gt_masks'):
gt_masks = targets_per_image.gt_masks
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
if self.input_format != "BITMASK" or self.input_format != "bitmask":
gt_masks = convert_coco_poly_to_mask(gt_masks.polygons, h, w)
new_targets[-1].update({'masks': gt_masks})
return new_targets

Expand Down Expand Up @@ -242,7 +244,7 @@ def inference(self, box_cls, box_pred, mask_pred, image_sizes):
result.pred_boxes.scale(scale_x=image_size[1], scale_y=image_size[0])
if self.mask_on:
mask = F.interpolate(mask_pred[i].unsqueeze(0), size=image_size, mode='bilinear', align_corners=False)
mask = mask[0].sigmoid() > 0.5
mask = mask[0].sigmoid() > cfg.MODEL.MASK_THRESHOLD
B, N, H, W = mask_pred.shape
mask = BitMasks(mask.cpu()).crop_and_resize(result.pred_boxes.tensor.cpu(), 32)
result.pred_masks = mask.unsqueeze(1).to(mask_pred[0].device)
Expand Down