Skip to content

Commit

Permalink
Fixed formatting and new parameter definitions
Browse files Browse the repository at this point in the history
  • Loading branch information
gguzzy committed Dec 21, 2024
1 parent f4ce179 commit 7f93026
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 29 deletions.
36 changes: 29 additions & 7 deletions sahi/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,11 @@

logger = logging.getLogger(__name__)


def filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id):
return [
obj_pred for obj_pred in object_prediction_list
obj_pred
for obj_pred in object_prediction_list
if obj_pred.category.name not in (exclude_classes_by_name or [])
and obj_pred.category.id not in (exclude_classes_by_id or [])
]
Expand Down Expand Up @@ -87,6 +89,12 @@ def get_prediction(
verbose: int
0: no print (default)
1: print prediction duration
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
Returns:
A dict with fields:
Expand All @@ -111,11 +119,7 @@ def get_prediction(
full_shape=full_shape,
)
object_prediction_list: List[ObjectPrediction] = detection_model.object_prediction_list
object_prediction_list = filter_predictions(
object_prediction_list,
exclude_classes_by_name,
exclude_classes_by_id
)
object_prediction_list = filter_predictions(object_prediction_list, exclude_classes_by_name, exclude_classes_by_id)

# postprocess matching predictions
if postprocess is not None:
Expand Down Expand Up @@ -204,6 +208,12 @@ def get_sliced_prediction(
Prefix for the exported slices. Defaults to None.
slice_dir: str
Directory to save the slices. Defaults to None.
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
Returns:
A Dict with fields:
Expand Down Expand Up @@ -485,6 +495,12 @@ def predict(
If True, returns a dict with 'export_dir' field.
force_postprocess_type: bool
If True, auto postprocess check will e disabled
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
"""
# assert prediction type
if no_standard_prediction and no_sliced_prediction:
Expand Down Expand Up @@ -831,6 +847,12 @@ def predict_fiftyone(
verbose: int
0: no print
1: print slice/prediction durations, number of slices, model loading/file exporting durations
exclude_classes_by_name: Optional[List[str]]
None: if no classes are excluded
List[str]: set of classes to exclude using its/their class label name/s
exclude_classes_by_id: Optional[List[int]]
None: if no classes are excluded
List[str]: set of classes to exclude using one or more IDs
"""
check_requirements(["fiftyone"])

Expand Down Expand Up @@ -944,4 +966,4 @@ def predict_fiftyone(
# Show samples with most false positives
session.view = eval_view.sort_by("eval_fp", reverse=True)
while 1:
time.sleep(3)
time.sleep(3)
44 changes: 22 additions & 22 deletions tests/test_exclude_classes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sahi import AutoDetectionModel
from sahi.predict import get_prediction, get_sliced_prediction, predict
from sahi.utils.file import download_from_url
from sahi.utils.yolov8 import download_yolov8s_model
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction, get_prediction, predict

# 1. Download the YOLOv8 model weights
yolov8_model_path = "models/yolov8s.pt"
Expand All @@ -13,16 +13,16 @@
"demo_data/small-vehicles1.jpeg",
)
download_from_url(
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png",
"https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png",
"demo_data/terrain2.png",
)

# 3. Load the YOLOv8 detection model
detection_model = AutoDetectionModel.from_pretrained(
model_type="yolov8", # Model type (YOLOv8 in this case)
model_path=yolov8_model_path, # Path to model weights
confidence_threshold=0.5, # Confidence threshold for predictions
device="cpu", # Use "cuda" for GPU inference
model_type="yolov8", # Model type (YOLOv8 in this case)
model_path=yolov8_model_path, # Path to model weights
confidence_threshold=0.5, # Confidence threshold for predictions
device="cpu", # Use "cuda" for GPU inference
)

# 4. Define the classes to exclude
Expand All @@ -33,11 +33,11 @@
result = get_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
shift_amount=[0, 0], # No shift applied
full_shape=None, # Full image shape is not provided
postprocess=None, # Postprocess disabled
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
shift_amount=[0, 0], # No shift applied
full_shape=None, # Full image shape is not provided
postprocess=None, # Postprocess disabled
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)

print("\nFiltered Results from `get_prediction` (First 5 Predictions):")
Expand All @@ -49,11 +49,11 @@
result = get_sliced_prediction(
image="demo_data/small-vehicles1.jpeg",
detection_model=detection_model,
slice_height=256, # Slice height
slice_width=256, # Slice width
overlap_height_ratio=0.2, # Overlap height ratio
overlap_width_ratio=0.2, # Overlap width ratio
verbose=1, # Enable verbose output
slice_height=256, # Slice height
slice_width=256, # Slice width
overlap_height_ratio=0.2, # Overlap height ratio
overlap_width_ratio=0.2, # Overlap width ratio
verbose=1, # Enable verbose output
)
print("\nNon-Filtered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
Expand All @@ -68,7 +68,7 @@
overlap_height_ratio=0.2,
overlap_width_ratio=0.2,
verbose=1,
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)
print("\nFiltered Results from `get_sliced_prediction` (First 5 Predictions):")
for obj in result.object_prediction_list[:5]:
Expand All @@ -79,9 +79,9 @@
predict(
detection_model=detection_model,
source="demo_data/small-vehicles1.jpeg", # Single image source
project="runs/test_predict", # Output project directory
name="exclude_test", # Run name
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name # Exclude 'car'
project="runs/test_predict", # Output project directory
name="exclude_test", # Run name
verbose=1, # Enable verbose output
exclude_classes_by_name=exclude_classes_by_name, # Exclude 'car'
)
print("\nFiltered results from `predict` saved in 'runs/test_predict/exclude_test'")

0 comments on commit 7f93026

Please sign in to comment.