Skip to content

Commit

Permalink
fix for using bgr image in inference instead of rgb (#1022)
Browse files Browse the repository at this point in the history
Co-authored-by: fatih cagatay akyon <[email protected]>
  • Loading branch information
bilkosem and fcakyon authored May 20, 2024
1 parent 45904e4 commit 065f7e7
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 5 deletions.
2 changes: 0 additions & 2 deletions sahi/models/mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


try:

check_requirements(["torch", "mmdet", "mmcv", "mmengine"])

from mmdet.apis.det_inferencer import DetInferencer
Expand Down Expand Up @@ -104,7 +103,6 @@ def __init__(
image_size: int = None,
scope: str = "mmdet",
):

if not IMPORT_MMDET_V3:
raise ImportError("Failed to import `DetInferencer`. Please confirm you have installed 'mmdet>=3.0.0'")

Expand Down
2 changes: 1 addition & 1 deletion sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ def predict(
export_format=visual_export_format,
)
if not novisual and source_is_video: # export video
output_video_writer.write(result["image"])
output_video_writer.write(cv2.cvtColor(result["image"], cv2.COLOR_RGB2BGR))

# render video inference
if view_video:
Expand Down
4 changes: 2 additions & 2 deletions sahi/utils/cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ def read_video_frame(video_capture, frame_skip_interval):
if not ret:
print("\n=========================== Video Ended ===========================")
break
yield Image.fromarray(frame)
yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

else:
while video_capture.isOpened:
Expand All @@ -349,7 +349,7 @@ def read_video_frame(video_capture, frame_skip_interval):
if not ret:
print("\n=========================== Video Ended ===========================")
break
yield Image.fromarray(frame)
yield Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))

if export_visual:
# get video properties and create VideoWriter object
Expand Down

0 comments on commit 065f7e7

Please sign in to comment.