Skip to content

Commit

Permalink
remove opencv nms
Browse files Browse the repository at this point in the history
  • Loading branch information
cospectrum committed Jan 4, 2025
1 parent 427c186 commit e116f46
Showing 1 changed file with 35 additions and 16 deletions.
51 changes: 35 additions & 16 deletions src/microwink/seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ def postprocess(
likely = x[:, 4 : 4 + NUM_CLASSES].max(axis=1) > conf_threshold
x = x[likely]

boxes = x[:, :4]
scores = x[:, 4 : 4 + NUM_CLASSES].max(axis=1)
boxes = x[:, :4]
boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h)
keep = self.nms(
boxes,
scores,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
)
N = len(keep)
Expand All @@ -199,7 +199,6 @@ def postprocess(
masks_in = masks_in[keep]

ih, iw = img_size
boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h)
masks = self.postprocess_masks(protos, masks_in, boxes, (ih, iw))

assert masks.shape == (N, ih, iw)
Expand Down Expand Up @@ -293,21 +292,41 @@ def nms(
boxes: np.ndarray,
scores: np.ndarray,
*,
conf_threshold: float,
iou_threshold: float,
) -> list[int]:
from cv2.dnn import NMSBoxes

sorted_indices = np.argsort(scores)[::-1]
N = len(boxes)
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
keep = NMSBoxes(
boxes, # type: ignore
scores, # type: ignore
conf_threshold,
iou_threshold,
)
return list(keep)
assert sorted_indices.shape == (N,)

keep_boxes = []
while sorted_indices.size > 0:
box_id = int(sorted_indices[0])
ious = SegModel._compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
keep_indices = np.where(ious < iou_threshold)[0]
sorted_indices = sorted_indices[keep_indices + 1]

keep_boxes.append(box_id)
return keep_boxes

@staticmethod
def _compute_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
assert box.shape == (4,)
assert boxes.shape == (len(boxes), 4)
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])

intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area

iou = intersection_area / union_area
return iou

@staticmethod
def with_border(
Expand All @@ -318,11 +337,11 @@ def with_border(
right: int,
color: tuple[int, int, int],
) -> np.ndarray:
from cv2 import BORDER_CONSTANT, copyMakeBorder
import cv2

assert img.ndim == 3
return copyMakeBorder(
img, top, bottom, left, right, BORDER_CONSTANT, value=color
return cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
)


Expand Down

0 comments on commit e116f46

Please sign in to comment.