From 1976635fd1f6807f1d066dfb5c7f8c16ec7c9b6a Mon Sep 17 00:00:00 2001
From: Wenxin Yang <595026238@qq.com>
Date: Thu, 26 Dec 2024 22:21:24 +0800
Subject: [PATCH 1/3] predict.py

---
 sahi/predict.py | 51 ++++++++++++++++++++++++++++++++-----------------
 1 file changed, 33 insertions(+), 18 deletions(-)

diff --git a/sahi/predict.py b/sahi/predict.py
index 65f46bdf..52d1134a 100644
--- a/sahi/predict.py
+++ b/sahi/predict.py
@@ -4,6 +4,7 @@
 import logging
 import os
 import time
+import math
 from typing import List, Optional
 
 from sahi.utils.import_utils import is_available
@@ -87,10 +88,11 @@ def get_prediction(
     durations_in_seconds = dict()
 
     # read image as pil
-    image_as_pil = read_image_as_pil(image)
+    # image_as_pil = read_image_as_pil(image)
 
     # get prediction
     time_start = time.time()
-    detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
+    # detection_model.perform_inference(np.ascontiguousarray(image_as_pil))
+    detection_model.perform_inference(image)
     time_end = time.time() - time_start
     durations_in_seconds["prediction"] = time_end
@@ -101,12 +103,10 @@ def get_prediction(
         shift_amount=shift_amount,
         full_shape=full_shape,
     )
-    object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
-
+    object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list_per_image
     # postprocess matching predictions
     if postprocess is not None:
         object_prediction_list = postprocess(object_prediction_list)
-
     time_end = time.time() - time_start
     durations_in_seconds["postprocess"] = time_end
 
@@ -139,6 +139,7 @@ def get_sliced_prediction(
     auto_slice_resolution: bool = True,
     slice_export_prefix: str = None,
     slice_dir: str = None,
+    num_batch: int = 1,
 ) -> PredictionResult:
     """
     Function for slice image + get predicion for each slice + combine predictions in full image.
@@ -198,8 +199,8 @@ def get_sliced_prediction(
     # for profiling
     durations_in_seconds = dict()
-    # currently only 1 batch supported
-    num_batch = 1
+    # # currently only 1 batch supported
+    # num_batch = 1
 
     # create slices from full image
     time_start = time.time()
     slice_image_result = slice_image(
@@ -233,7 +234,8 @@ def get_sliced_prediction(
     )
 
     # create prediction input
-    num_group = int(num_slices / num_batch)
+    # num_group = int(num_slices / num_batch)
+    num_group = math.ceil(num_slices / num_batch)
     if verbose == 1 or verbose == 2:
         tqdm.write(f"Performing prediction on {num_slices} slices.")
     object_prediction_list = []
@@ -243,22 +245,31 @@ def get_sliced_prediction(
         image_list = []
         shift_amount_list = []
         for image_ind in range(num_batch):
-            image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
+            if (group_ind * num_batch + image_ind) >= num_slices:
+                break
+            # image_list.append(slice_image_result.images[group_ind * num_batch + image_ind])
+            img_slice = slice_image_result.images[group_ind * num_batch + image_ind]
+            img_slice = img_slice[:, :, ::-1]
+            image_list.append(img_slice)
             shift_amount_list.append(slice_image_result.starting_pixels[group_ind * num_batch + image_ind])
         # perform batch prediction
+        num_full = len(image_list)
         prediction_result = get_prediction(
-            image=image_list[0],
+            image=image_list,
             detection_model=detection_model,
-            shift_amount=shift_amount_list[0],
-            full_shape=[
+            shift_amount=shift_amount_list,
+            full_shape=[[
                 slice_image_result.original_image_height,
                 slice_image_result.original_image_width,
-            ],
+            ]] * num_full,
         )
+
         # convert sliced predictions to full predictions
-        for object_prediction in prediction_result.object_prediction_list:
-            if object_prediction:  # if not empty
-                object_prediction_list.append(object_prediction.get_shifted_object_prediction())
+        for object_prediction_per in prediction_result.object_prediction_list:
+
+            if len(object_prediction_per) != 0:  # if not empty
+                for object_prediction in object_prediction_per:
+                    object_prediction_list.append(object_prediction.get_shifted_object_prediction())
 
         # merge matching predictions during sliced prediction
         if merge_buffer_length is not None and len(object_prediction_list) > merge_buffer_length:
@@ -267,7 +278,7 @@ def get_sliced_prediction(
     # perform standard prediction
     if num_slices > 1 and perform_standard_pred:
         prediction_result = get_prediction(
-            image=image,
+            image=[np.array(image)],
             detection_model=detection_model,
             shift_amount=[0, 0],
             full_shape=[
@@ -276,7 +287,9 @@ def get_sliced_prediction(
             ],
             postprocess=None,
         )
-        object_prediction_list.extend(prediction_result.object_prediction_list)
+        if len(prediction_result.object_prediction_list) != 0:
+            for _prediction_result in prediction_result.object_prediction_list:
+                object_prediction_list.extend(_prediction_result)
 
     # merge matching predictions
     if len(object_prediction_list) > 1:
@@ -377,6 +390,7 @@ def predict(
     verbose: int = 1,
     return_dict: bool = False,
     force_postprocess_type: bool = False,
+    num_batch: int = 1,
     **kwargs,
 ):
     """
@@ -569,6 +583,7 @@ def predict(
                 postprocess_match_threshold=postprocess_match_threshold,
                 postprocess_class_agnostic=postprocess_class_agnostic,
                 verbose=1 if verbose else 0,
+                num_batch=num_batch,
             )
             object_prediction_list = prediction_result.object_prediction_list
             durations_in_seconds["slice"] += prediction_result.durations_in_seconds["slice"]
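[Editor's note: illustration only, not part of the patch series.] PATCH 1/3 replaces the hard-coded single-slice path with ceil-based grouping: num_group = math.ceil(num_slices / num_batch) groups are formed, and the inner loop breaks early once the flat index runs past num_slices, so the tail batch may hold fewer than num_batch slices. A minimal, self-contained sketch of that grouping, where slices is a stand-in for slice_image_result.images:

    import math

    slices = list(range(10))  # stand-in for slice_image_result.images (10 slices)
    num_batch = 4
    num_group = math.ceil(len(slices) / num_batch)  # 3 groups: sizes 4, 4, 2

    for group_ind in range(num_group):
        batch = []
        for image_ind in range(num_batch):
            flat_ind = group_ind * num_batch + image_ind
            if flat_ind >= len(slices):  # tail group is shorter than num_batch
                break
            batch.append(slices[flat_ind])
        print(group_ind, batch)

With int() instead of math.ceil(), the tail slices of an uneven split would simply be dropped, which is the bug the patch fixes.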
From f5fec83c2a9c184b72dfc43efca4dac43bde6a8d Mon Sep 17 00:00:00 2001
From: Wenxin Yang <595026238@qq.com>
Date: Thu, 26 Dec 2024 22:22:03 +0800
Subject: [PATCH 2/3] prediction.py

---
 sahi/prediction.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/sahi/prediction.py b/sahi/prediction.py
index 5fca9648..d6db0edd 100644
--- a/sahi/prediction.py
+++ b/sahi/prediction.py
@@ -164,8 +164,13 @@ def __init__(
         image: Union[Image.Image, str, np.ndarray],
         durations_in_seconds: Optional[Dict] = None,
     ):
-        self.image: Image.Image = read_image_as_pil(image)
-        self.image_width, self.image_height = self.image.size
+
+        if isinstance(image, list):
+            self.image = image
+            self.image_height, self.image_width = self.image[0].shape[:2]
+        else:
+            self.image: Image.Image = read_image_as_pil(image)
+            self.image_width, self.image_height = self.image.size
         self.object_prediction_list: List[ObjectPrediction] = object_prediction_list
         self.durations_in_seconds = durations_in_seconds
 

From 42fd93b57aa6ec266a481350cb3f364f59084aff Mon Sep 17 00:00:00 2001
From: Wenxin Yang <595026238@qq.com>
Date: Thu, 26 Dec 2024 22:22:47 +0800
Subject: [PATCH 3/3] ultralytics.py

---
 sahi/models/ultralytics.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/sahi/models/ultralytics.py b/sahi/models/ultralytics.py
index dd421f5c..a9434b80 100644
--- a/sahi/models/ultralytics.py
+++ b/sahi/models/ultralytics.py
@@ -67,9 +67,11 @@ def perform_inference(self, image: np.ndarray):
         if self.image_size is not None:
             kwargs = {"imgsz": self.image_size, **kwargs}
 
-        prediction_result = self.model(image[:, :, ::-1], **kwargs)  # YOLOv8 expects numpy arrays to have BGR
-
+        if isinstance(image, list):
+            prediction_result = self.model(image, **kwargs)  # YOLOv8 expects numpy arrays to have BGR
+        else:
+            prediction_result = self.model(image[:, :, ::-1], **kwargs)
 
         if self.has_mask:
             if not prediction_result[0].masks:
                 prediction_result[0].masks = Masks(
@@ -109,7 +111,10 @@ def perform_inference(self, image: np.ndarray):
             prediction_result = [result.boxes.data for result in prediction_result]
 
         self._original_predictions = prediction_result
-        self._original_shape = image.shape
+        if isinstance(image, list):
+            self._original_shape = image[0].shape
+        else:
+            self._original_shape = image.shape
 
     @property
     def category_names(self):
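[Editor's note: illustration only, not part of the patch series.] PATCH 2/3 and 3/3 both branch on whether image is a list of numpy arrays (the new batched path) or a single path/PIL image (the original path), and the two representations use opposite shape conventions: numpy's .shape is (height, width, channels) while PIL's .size is (width, height). That is why the batched branch in PredictionResult must unpack height before width from shape[:2]. A quick self-contained check:

    import numpy as np
    from PIL import Image

    arr = np.zeros((480, 640, 3), dtype=np.uint8)  # 480 rows (height), 640 cols (width)
    pil = Image.fromarray(arr)

    height, width = arr.shape[:2]   # numpy order: (H, W)
    assert (height, width) == (480, 640)
    assert pil.size == (640, 480)   # PIL order: (W, H)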
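[Editor's note: illustration only, not part of the patch series.] With all three patches applied, callers opt into batching through the new num_batch argument; the rest of the public API is unchanged. A usage sketch, assuming an Ultralytics checkpoint at yolov8n.pt and a test image at demo.jpg (both placeholder paths):

    from sahi import AutoDetectionModel
    from sahi.predict import get_sliced_prediction

    detection_model = AutoDetectionModel.from_pretrained(
        model_type="ultralytics",  # routes to sahi/models/ultralytics.py
        model_path="yolov8n.pt",   # placeholder checkpoint path
        confidence_threshold=0.4,
    )

    result = get_sliced_prediction(
        "demo.jpg",                # placeholder image path
        detection_model,
        slice_height=512,
        slice_width=512,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
        num_batch=8,               # new in PATCH 1/3: slices per forward pass
    )
    print(len(result.object_prediction_list), "detections")

Larger num_batch values trade GPU memory for fewer model invocations; num_batch=1 reproduces the old one-slice-at-a-time behavior.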