Skip to content

Commit

Permalink
improve segmentation visualization (#812)
Browse files Browse the repository at this point in the history
  • Loading branch information
fcakyon authored Jan 15, 2023
1 parent 33e2150 commit 0463bce
Showing 1 changed file with 29 additions and 17 deletions.
46 changes: 29 additions & 17 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,17 @@ def visualize_prediction(
text_th = text_th or max(rect_th - 1, 1)
# set text_size for category names
text_size = text_size or rect_th / 3
# add bbox and mask to image if present

# add masks to image if present
if masks is not None:
for mask in masks:
# deepcopy mask so that original is not altered
mask = copy.deepcopy(mask)
# draw mask
rgb_mask = apply_color_mask(np.squeeze(mask), color)
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

# add bboxes to image if present
for i in range(len(boxes)):
# deepcopy boxso that original is not altered
box = copy.deepcopy(boxes[i])
Expand All @@ -350,13 +360,6 @@ def visualize_prediction(
# set color
if colors is not None:
color = colors(class_)
# visualize masks if present
if masks is not None:
# deepcopy mask so that original is not altered
mask = copy.deepcopy(masks[i])
# draw mask
rgb_mask = apply_color_mask(np.squeeze(mask), color)
image = cv2.addWeighted(image, 1, rgb_mask, 0.7, 0)
# set bbox points
p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# visualize boxes
Expand Down Expand Up @@ -427,12 +430,28 @@ def visualize_object_predictions(
else:
colors = None
# set rect_th for boxes
rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.001), 1)
rect_th = rect_th or max(round(sum(image.shape) / 2 * 0.003), 2)
# set text_th for category names
text_th = text_th or max(rect_th - 1, 1)
# set text_size for category names
text_size = text_size or rect_th / 3
# add bbox and mask to image if present

# add masks to image if present
for object_prediction in object_prediction_list:
# deepcopy object_prediction_list so that original is not altered
object_prediction = object_prediction.deepcopy()
# visualize masks if present
if object_prediction.mask is not None:
# deepcopy mask so that original is not altered
mask = object_prediction.mask.bool_mask
# set color
if colors is not None:
color = colors(object_prediction.category.id)
# draw mask
rgb_mask = apply_color_mask(mask, color)
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)

# add bboxes to image if present
for object_prediction in object_prediction_list:
# deepcopy object_prediction_list so that original is not altered
object_prediction = object_prediction.deepcopy()
Expand All @@ -444,13 +463,6 @@ def visualize_object_predictions(
# set color
if colors is not None:
color = colors(object_prediction.category.id)
# visualize masks if present
if object_prediction.mask is not None:
# deepcopy mask so that original is not altered
mask = object_prediction.mask.bool_mask
# draw mask
rgb_mask = apply_color_mask(mask, color)
image = cv2.addWeighted(image, 1, rgb_mask, 0.4, 0)
# set bbox points
p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
# visualize boxes
Expand Down

0 comments on commit 0463bce

Please sign in to comment.