diff --git a/sahi/predict.py b/sahi/predict.py index e846f057..41fb9905 100644 --- a/sahi/predict.py +++ b/sahi/predict.py @@ -53,9 +53,11 @@ logger = logging.getLogger(__name__) + def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id): return [ - obj_pred for obj_pred in object_prediction_list + obj_pred + for obj_pred in object_prediction_list if obj_pred.category.name not in (exclude_classes_by_name or []) and obj_pred.category.id not in (exclude_classes_by_id or []) ] @@ -87,6 +89,12 @@ def get_prediction( verbose: int 0: no print (default) 1: print prediction duration + exclude_classes_by_name: Optional[List[str]] + None: if no classes are excluded + List[str]: set of classes to exclude using its/their class label name/s + exclude_classes_by_id: Optional[List[int]] + None: if no classes are excluded + List[str]: set of classes to exclude using one or more IDs Returns: A dict with fields: @@ -111,11 +119,7 @@ def get_prediction( full_shape=full_shape, ) object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list - object_prediction_list = filter_predictions( - object_prediction_list, - exclude_classes_by_name, - exclude_classes_by_id - ) + object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id) # postprocess matching predictions if postprocess is not None: @@ -204,6 +208,12 @@ def get_sliced_prediction( Prefix for the exported slices. Defaults to None. slice_dir: str Directory to save the slices. Defaults to None. + exclude_classes_by_name: Optional[List[str]] + None: if no classes are excluded + List[str]: set of classes to exclude using its/their class label name/s + exclude_classes_by_id: Optional[List[int]] + None: if no classes are excluded + List[str]: set of classes to exclude using one or more IDs Returns: A Dict with fields: @@ -485,6 +495,12 @@ def predict( If True, returns a dict with 'export_dir' field. force_postprocess_type: bool If True, auto postprocess check will e disabled + exclude_classes_by_name: Optional[List[str]] + None: if no classes are excluded + List[str]: set of classes to exclude using its/their class label name/s + exclude_classes_by_id: Optional[List[int]] + None: if no classes are excluded + List[str]: set of classes to exclude using one or more IDs """ # assert prediction type if no_standard_prediction and no_sliced_prediction: @@ -831,6 +847,12 @@ def predict_fiftyone( verbose: int 0: no print 1: print slice/prediction durations, number of slices, model loading/file exporting durations + exclude_classes_by_name: Optional[List[str]] + None: if no classes are excluded + List[str]: set of classes to exclude using its/their class label name/s + exclude_classes_by_id: Optional[List[int]] + None: if no classes are excluded + List[str]: set of classes to exclude using one or more IDs """ check_requirements(["fiftyone"]) @@ -944,4 +966,4 @@ def predict_fiftyone( # Show samples with most false positives session.view = eval_view.sort_by("eval_fp", reverse=True) while 1: - time.sleep(3) \ No newline at end of file + time.sleep(3) diff --git a/tests/test_exclude_classes.py b/tests/test_exclude_classes.py index 17512970..6ae696d3 100644 --- a/tests/test_exclude_classes.py +++ b/tests/test_exclude_classes.py @@ -1,7 +1,7 @@ +from sahi import AutoDetectionModel +from sahi.predict import get_prediction, get_sliced_prediction, predict from sahi.utils.file import download_from_url from sahi.utils.yolov8 import download_yolov8s_model -from sahi import AutoDetectionModel -from sahi.predict import get_sliced_prediction, get_prediction, predict # 1. Download the YOLOv8 model weights yolov8_model_path = "models/yolov8s.pt" @@ -13,16 +13,16 @@ "demo_data/small-vehicles1.jpeg", ) download_from_url( - "https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png", + "https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png", "demo_data/terrain2.png", ) # 3. Load the YOLOv8 detection model detection_model = AutoDetectionModel.from_pretrained( - model_type="yolov8", # Model type (YOLOv8 in this case) - model_path=yolov8_model_path, # Path to model weights - confidence_threshold=0.5, # Confidence threshold for predictions - device="cpu", # Use "cuda" for GPU inference + model_type="yolov8", # Model type (YOLOv8 in this case) + model_path=yolov8_model_path, # Path to model weights + confidence_threshold=0.5, # Confidence threshold for predictions + device="cpu", # Use "cuda" for GPU inference ) # 4. Define the classes to exclude @@ -33,11 +33,11 @@ result = get_prediction( image="demo_data/small-vehicles1.jpeg", detection_model=detection_model, - shift_amount=[0, 0], # No shift applied - full_shape=None, # Full image shape is not provided - postprocess=None, # Postprocess disabled - verbose=1, # Enable verbose output - exclude_classes_by_name=exclude_classes_by_name # Exclude 'car' + shift_amount=[0, 0], # No shift applied + full_shape=None, # Full image shape is not provided + postprocess=None, # Postprocess disabled + verbose=1, # Enable verbose output + exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car' ) print("\nFiltered Results from `get_prediction` (First 5 Predictions):") @@ -49,11 +49,11 @@ result = get_sliced_prediction( image="demo_data/small-vehicles1.jpeg", detection_model=detection_model, - slice_height=256, # Slice height - slice_width=256, # Slice width - overlap_height_ratio=0.2, # Overlap height ratio - overlap_width_ratio=0.2, # Overlap width ratio - verbose=1, # Enable verbose output + slice_height=256, # Slice height + slice_width=256, # Slice width + overlap_height_ratio=0.2, # Overlap height ratio + overlap_width_ratio=0.2, # Overlap width ratio + verbose=1, # Enable verbose output ) print("\nNon-Filtered Results from `get_sliced_prediction` (First 5 Predictions):") for obj in result.object_prediction_list[:5]: @@ -68,7 +68,7 @@ overlap_height_ratio=0.2, overlap_width_ratio=0.2, verbose=1, - exclude_classes_by_name=exclude_classes_by_name # Exclude 'car' + exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car' ) print("\nFiltered Results from `get_sliced_prediction` (First 5 Predictions):") for obj in result.object_prediction_list[:5]: @@ -79,9 +79,9 @@ predict( detection_model=detection_model, source="demo_data/small-vehicles1.jpeg", # Single image source - project="runs/test_predict", # Output project directory - name="exclude_test", # Run name - verbose=1, # Enable verbose output - exclude_classes_by_name=exclude_classes_by_name # Exclude 'car' + project="runs/test_predict", # Output project directory + name="exclude_test", # Run name + verbose=1, # Enable verbose output + exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car' ) print("\nFiltered results from `predict` saved in 'runs/test_predict/exclude_test'")