Skip to content

Commit

Permalink
Merge pull request #9 from cospectrum/dev
Browse files Browse the repository at this point in the history
remove OpenCV NMS
  • Loading branch information
cospectrum authored Jan 4, 2025
2 parents edacb84 + e116f46 commit c50dbed
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 19 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,7 @@ wheels/
.venv
venv

generated
output.png

draw.py
boxes.txt
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ from PIL import Image

seg_model = SegModel.from_path("./models/seg_model.onnx")

img = Image.open("./input.png").convert("RGB")
img = Image.open("./assets/data/us_card.png").convert("RGB")
cards = seg_model.apply(img)

for card in cards:
print(f"score={card.score}, box={card.box}")
img = draw_mask(img, card.mask > 0.5)
img = draw_box(img, card.box)
img = draw_mask(img, card.mask > 0.5)
img.save("./output.png")
```

## License
Apache-2.0
30 changes: 30 additions & 0 deletions generate_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from microwink import SegModel
from microwink.common import draw_mask, draw_box

from PIL import Image
from pathlib import Path
from tests.utils import round_box

MODEL_PATH = "./models/seg_model.onnx"
SAVE_TO = Path("./generated")


def main() -> None:
SAVE_TO.mkdir(exist_ok=True)
seg_model = SegModel.from_path(MODEL_PATH)

for img_path in Path("./assets/data/").iterdir():
img = Image.open(img_path).convert("RGB")
cards = seg_model.apply(img)

print(img_path.name)
for card in cards:
img = draw_box(img, card.box)
img = draw_mask(img, card.mask > 0.5)
print(round_box(card.box))
print()
img.save(SAVE_TO / img_path.name)


if __name__ == "__main__":
main()
51 changes: 35 additions & 16 deletions src/microwink/seg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,12 @@ def postprocess(
likely = x[:, 4 : 4 + NUM_CLASSES].max(axis=1) > conf_threshold
x = x[likely]

boxes = x[:, :4]
scores = x[:, 4 : 4 + NUM_CLASSES].max(axis=1)
boxes = x[:, :4]
boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h)
keep = self.nms(
boxes,
scores,
conf_threshold=conf_threshold,
iou_threshold=iou_threshold,
)
N = len(keep)
Expand All @@ -199,7 +199,6 @@ def postprocess(
masks_in = masks_in[keep]

ih, iw = img_size
boxes = self.postprocess_boxes(boxes, img_size, ratio, pad_w=pad_w, pad_h=pad_h)
masks = self.postprocess_masks(protos, masks_in, boxes, (ih, iw))

assert masks.shape == (N, ih, iw)
Expand Down Expand Up @@ -293,21 +292,41 @@ def nms(
boxes: np.ndarray,
scores: np.ndarray,
*,
conf_threshold: float,
iou_threshold: float,
) -> list[int]:
from cv2.dnn import NMSBoxes

sorted_indices = np.argsort(scores)[::-1]
N = len(boxes)
assert boxes.shape == (N, 4)
assert scores.shape == (N,)
keep = NMSBoxes(
boxes, # type: ignore
scores, # type: ignore
conf_threshold,
iou_threshold,
)
return list(keep)
assert sorted_indices.shape == (N,)

keep_boxes = []
while sorted_indices.size > 0:
box_id = int(sorted_indices[0])
ious = SegModel._compute_iou(boxes[box_id, :], boxes[sorted_indices[1:], :])
keep_indices = np.where(ious < iou_threshold)[0]
sorted_indices = sorted_indices[keep_indices + 1]

keep_boxes.append(box_id)
return keep_boxes

@staticmethod
def _compute_iou(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
assert box.shape == (4,)
assert boxes.shape == (len(boxes), 4)
xmin = np.maximum(box[0], boxes[:, 0])
ymin = np.maximum(box[1], boxes[:, 1])
xmax = np.minimum(box[2], boxes[:, 2])
ymax = np.minimum(box[3], boxes[:, 3])

intersection_area = np.maximum(0, xmax - xmin) * np.maximum(0, ymax - ymin)

box_area = (box[2] - box[0]) * (box[3] - box[1])
boxes_area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
union_area = box_area + boxes_area - intersection_area

iou = intersection_area / union_area
return iou

@staticmethod
def with_border(
Expand All @@ -318,11 +337,11 @@ def with_border(
right: int,
color: tuple[int, int, int],
) -> np.ndarray:
from cv2 import BORDER_CONSTANT, copyMakeBorder
import cv2

assert img.ndim == 3
return copyMakeBorder(
img, top, bottom, left, right, BORDER_CONSTANT, value=color
return cv2.copyMakeBorder(
img, top, bottom, left, right, cv2.BORDER_CONSTANT, value=color
)


Expand Down
27 changes: 27 additions & 0 deletions tests/test_readme.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from pathlib import Path


Script = str


def test_readme() -> None:
readme_path = Path("./README.md")
start_tag = "```python"
end_tag = "```"
scripts = parse_readme_code(readme_path, start_tag, end_tag)
for script in scripts:
print("\n# executing the following script")
print(script)
print("\n# stdout...")
exec(script)


def parse_readme_code(path: Path, start_tag: str, end_tag) -> list[Script]:
assert path.exists()
text = path.read_text()
_, *sections = text.split(start_tag)
scripts = []
for section in sections:
script, *_ = section.split(end_tag)
scripts.append(script.strip())
return scripts

0 comments on commit c50dbed

Please sign in to comment.