Skip to content

Commit

Permalink
add pattern matching
Browse files Browse the repository at this point in the history
  • Loading branch information
Aleexo committed Mar 7, 2024
1 parent cd20887 commit f992216
Show file tree
Hide file tree
Showing 3 changed files with 380 additions and 0 deletions.
1 change: 1 addition & 0 deletions supervision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from supervision.detection.annotate import BoxAnnotator
from supervision.detection.core import Detections
from supervision.detection.line_zone import LineZone, LineZoneAnnotator
from supervision.detection.match_pattern import MatchPattern
from supervision.detection.tools.csv_sink import CSVSink
from supervision.detection.tools.inference_slicer import InferenceSlicer
from supervision.detection.tools.json_sink import JSONSink
Expand Down
244 changes: 244 additions & 0 deletions supervision/detection/match_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
import itertools
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union

import numpy as np

from supervision.detection.core import Detections
from supervision.geometry.core import Position


class _Constraint:
"""
A constraint is a rule that a pattern must follow. It is defined
by a function and its arguments. The arguments are strings that specify an object
of the pattern and one of its fields.
For example, this constraint tests that the objects A and B of your pattern have
the same class:
```python
_Constraint(lambda x, y: x == y, "A.class_id", "B.class_id")
```
!!! tip
You can use a value instead of a function as the criteria. It will check that
the arguments are all equal to this value. For instance, this constraint tests
that the object A of your pattern has a class_id equal to 1.
```python
_Constraint(1, "A.class_id")
```
This works with any number of arguments, so you can check several objects at
once:
```python
_Constraint(1, "A.class_id", "B.class_id")
```
"""

def __init__(
self, criteria: Union[Callable[..., bool], Any], arguments: List[str]
) -> None:
"""
Args:
criteria (Callable): A function that takes N arguments and returns a
boolean. Criteria can also be any value, in which case the constraint
checks that every argument is equal to this value.
*arguments (str): A list of N strings that will be given as arguments for
the criteria. The arguments should look like "name.field". The name of
the object can be any name that doesn't contain a dot (`.`). The field
should be one of the following:
- `xyxy`, `mask`, `class_id`, `confidence`, or `tracker_id`
- one of the `Position` enum strings
- a field from the `data` attribute of your detections
"""
validate_arguments(arguments)
self.arguments = arguments
if callable(criteria):
self.criteria = criteria
else:
self.criteria = lambda *args: all(equality(arg, criteria) for arg in args)


def validate_arguments(arguments: List[str]) -> None:
for argument in arguments:
if argument.count(".") != 1:
raise ValueError(
f"Constraint argument should be `name.field`, got: '{argument}'"
)


def equality(arg1, arg2):
if isinstance(arg1, np.ndarray) or isinstance(arg2, np.ndarray):
return (arg1 == arg2).all()
return arg1 == arg2


class MatchPattern:
"""
A pattern is a set of constraints that apply to detections. You can think of
patterns as regex for detections. `MatchPattern` will return all matches that
satisfy all the constraints.
A pattern is described as named boxes organized according to rules. Each rule is
given as a constraint. For instance "BoxA and BoxB should have the same class",
"Boxes A and B should overlap", etc. The constraints are functions that apply to
fields from the detections (the `class_id`, the `xyxy` coordinates, etc.).
For example, if you want to search for a cat and a dog that have the same center
point you can use the following pattern:
```python
import cv2
import supervision as sv
from ultralytics import YOLO
image = cv2.imread(<SOURCE_IMAGE_PATH>)
model = YOLO('yolov8s.pt')
pattern = sv.MatchPattern(
[
(lambda class_id: class_id == 0, ["Cat.class_id"]), # class_id for cat is 0
(1, ["Dog.class_id"]), # class_id for dog is 1
(
lambda dog_center, cat_center: dog_center == cat_center),
["Dog.CENTER", "Cat.CENTER"]
),
]
)
result = model(image)[0]
detections = sv.Detections.from_ultralytics(result)
matches = pattern.match(detections)
```
This will return all the matches that satisfy the constraints. The result is a list
of `Detections`. A field `match_name` is added to the Detections.data to keep
track of the names in your pattern.
```python
first_match = matches[0]
first_match["match_name"] # ["Cat", "Dog"]
```
"""

def __init__(
self,
constraints: List[Tuple[Union[Callable[..., bool], Any], List[str]]],
):
"""
Args:
constraints (List[Tuple[Union[Callable[..., bool], Any], List[str]]]):
A list of constraints. Each constraint contains a criterion and a list
of arguments:
- `criteria` is a function that returns a boolean value. See
`_Constraint` for more information.
- arguments is a list of strings. Each argument is composed of
`name.field`. The field should be one of the following:
- `xyxy`, `mask`, `class_id`, `confidence`, or `tracker_id`
- one of the `Position` enum strings
- a field from the `data` attribute of your detections
"""
self._constraints: List[_Constraint] = []
for constraint in constraints:
criteria, arguments = constraint
self.add_constraint(criteria, arguments)

def add_constraint(
self, criteria: Union[Callable[..., bool], Any], arguments: List[str]
) -> None:
"""
Adds a constraint to the matching pattern.
Args:
criteria: A function that returns a boolean value or any value you want to
match with the `arguments`. See `_Constraint` for more details.
arguments: A list of strings. See `_Constraint` for more details.
"""
self._constraints.append(_Constraint(criteria, arguments))

def match(self, detections: Detections) -> List[Detections]:
"""
Matches the pattern of the constraints to the detections.
Args:
detections (Detections): Detections to match the pattern with.
Returns:
List[Detections]: List of detections that match the constraints. A specific
field `match_name` is added to the matches to keep track of the names
specified in the pattern arguments.
"""
combinations = self._generate_combinations(len(detections))

names = self._get_names_from_constraints()
index = 0
while index < len(combinations):
combination = dict(zip(names, combinations[index]))
template_kwargs = {
name: detections[int(box_index)]
for name, box_index in combination.items()
}
for constraint in self._constraints:
criteria_args = [
_get_argument(template_kwargs, detections, arg)
for arg in constraint.arguments
]
if not constraint.criteria(*criteria_args):
incompatible_boxes = {
arg_name: combination[arg_name]
for arg_name in self._get_names_from_arguments(
constraint.arguments
)
}
filter_bool = np.ones(len(combinations), dtype=bool)
for name, values in incompatible_boxes.items():
filter_bool &= combinations[name] == values
combinations = combinations[~filter_bool]
break
else:
index += 1

results: List[Detections] = []
for valid_combination in combinations:
indexes = list(valid_combination)
matching_boxes: Detections = detections[indexes] # type: ignore
matching_boxes["match_name"] = names
results.append(matching_boxes)

return results

def _get_names_from_constraints(self) -> List[str]:
"""
Returns the object names used in the constraints.
"""
arguments = [
arg for constraint in self._constraints for arg in constraint.arguments
]
return self._get_names_from_arguments(arguments)

def _get_names_from_arguments(self, arguments: Iterable[str]) -> List[str]:
"""
Returns the object names used in the arguments. Sorted and unique.
"""
return sorted(
list({arg.split(".")[0] if "." in arg else arg for arg in arguments})
)

def _generate_combinations(self, num_detections) -> np.ndarray:
"""
Generates all the possible combinations for the pattern matching.
Returns an array of shape (N, M) where N is the number of combinations and M is
the number of objects in the pattern. Each row corresponds to the set of indexes
from detections.
"""
names = self._get_names_from_constraints()
return np.fromiter(
itertools.permutations(range(num_detections), len(names)),
np.dtype([(name, int) for name in names]),
)


def _get_argument(kwargs: Dict[str, Any], detections: Detections, argument: str) -> Any:
name, subfield = argument.split(".")
if subfield in ["xyxy", "mask", "class_id", "confidence", "tracker_id"]:
return getattr(kwargs[name], subfield)[0]
if subfield in Position.list():
return kwargs[name].get_anchors_coordinates(Position[subfield])[0]
if subfield in detections.data:
return kwargs[name][subfield][0]
raise ValueError(f"Unknown field '{subfield}' for object '{name}'")
135 changes: 135 additions & 0 deletions test/detection/test_match_pattern.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import numpy as np
import pytest

from supervision import Detections, MatchPattern


@pytest.mark.parametrize(
"constraints",
[
(
[
(1, ["Box1.class_id"]),
(0.1, ["Box1.confidence"]),
([0, 0, 15, 15], ["Box2.xyxy"]),
]
), # Test constraints with values
(
[
(lambda id: id == 1, ["Box1.class_id"]),
(lambda score: score == 0.1, ["Box1.confidence"]),
(lambda xyxy: xyxy[3] == 15, ["Box2.xyxy"]),
]
), # Test constraints with functions
(
[
(lambda id: id == 1, ["Box1.class_id"]),
(lambda xyxy1, xyxy2: xyxy1[0] == xyxy2[0], ["Box1.xyxy", "Box2.xyxy"]),
]
), # Test constraints with multiple arguments
],
)
def test_match_pattern(constraints):
detections = Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
[0, 0, 15, 15],
[5, 5, 20, 20],
]
),
confidence=np.array([0.1, 0.2, 0.3]),
class_id=np.array([1, 2, 3]),
)

expected_result = [
Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
[0, 0, 15, 15],
]
),
confidence=np.array([0.1, 0.2]),
class_id=np.array([1, 2]),
data={"match_name": np.array(["Box1", "Box2"])},
)
]

matches = MatchPattern(constraints).match(detections)

assert matches == expected_result


def test_match_pattern_with_2_results():
detections = Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
[0, 0, 15, 15],
[5, 5, 20, 20],
]
),
confidence=np.array([0.1, 0.2, 0.3]),
class_id=np.array([1, 2, 3]),
)

expected_result = [
Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
]
),
confidence=np.array([0.1]),
class_id=np.array([1]),
data={"match_name": np.array(["Box1"])},
),
Detections(
xyxy=np.array(
[
[0, 0, 15, 15],
]
),
confidence=np.array([0.2]),
class_id=np.array([2]),
data={"match_name": np.array(["Box1"])},
),
]

matches = MatchPattern([[lambda xyxy: xyxy[0] == 0, ["Box1.xyxy"]]]).match(
detections
)

assert matches == expected_result


def test_add_constraint():
detections = Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
[0, 0, 15, 15],
[5, 5, 20, 20],
]
),
confidence=np.array([0.1, 0.2, 0.3]),
class_id=np.array([1, 2, 3]),
)

expected_result = [
Detections(
xyxy=np.array(
[
[0, 0, 10, 10],
]
),
confidence=np.array([0.1]),
class_id=np.array([1]),
data={"match_name": np.array(["Box1"])},
)
]
pattern = MatchPattern([])
pattern.add_constraint(lambda id: id == 1, ["Box1.class_id"])
matches = pattern.match(detections)
assert matches == expected_result

0 comments on commit f992216

Please sign in to comment.