diff --git a/supervision/detection/tools/polygon_zone.py b/supervision/detection/tools/polygon_zone.py index 00b746aaa..713b1ec5e 100644 --- a/supervision/detection/tools/polygon_zone.py +++ b/supervision/detection/tools/polygon_zone.py @@ -89,6 +89,45 @@ def trigger(self, detections: Detections) -> npt.NDArray[np.bool_]: self.current_count = int(np.sum(is_in_zone)) return is_in_zone.astype(bool) + def trigger_overlap( + self, detections: Detections, overlap_threshold: float + ) -> npt.NDArray[np.bool_]: + """ + Determines if the detections are within the polygon zone. + + Parameters: + detections (Detections): The detections + to be checked against the polygon zone + + overlap_threshold (float): threshold of overlap + + Returns: + np.ndarray: A boolean numpy array indicating + if each detection overlaps the polygon zone + """ + + clipped_xyxy = clip_boxes( + xyxy=detections.xyxy, resolution_wh=self.frame_resolution_wh + ) + clipped_detections = replace(detections, xyxy=clipped_xyxy) + overlap_results: npt.NDArray[np.bool_] = np.zeros( + len(clipped_detections), dtype=bool + ) + + masks = [] + for i, (x1, y1, x2, y2) in enumerate(detections.xyxy.astype(np.int32)): + bbox_mask = np.zeros_like(self.mask, dtype=bool) + + bbox_mask[y1:y2, x1:x2] = True + masks.append(bbox_mask) + overlap_ratio = ( + np.count_nonzero(np.logical_and(bbox_mask, self.mask)) + / detections.area[i] + ) + overlap_results[i] = overlap_ratio >= overlap_threshold + + return overlap_results + class PolygonZoneAnnotator: """ diff --git a/test/detection/test_polygonzone.py b/test/detection/test_polygonzone.py index 1a86a45b4..9d5b4896b 100644 --- a/test/detection/test_polygonzone.py +++ b/test/detection/test_polygonzone.py @@ -92,3 +92,85 @@ def test_polygon_zone_trigger( with exception: in_zone = polygon_zone.trigger(detections) assert np.all(in_zone == expected_results) + + +@pytest.mark.parametrize( + "detections, polygon_zone, overlap_threshold, expected_results, exception", + [ + # Test cases for trigger_overlap function + ( + DETECTIONS, + sv.PolygonZone( + POLYGON, + FRAME_RESOLUTION, + triggering_anchors=( + sv.Position.TOP_LEFT, + sv.Position.TOP_RIGHT, + sv.Position.BOTTOM_LEFT, + sv.Position.BOTTOM_RIGHT, + ), + ), + 0.25, # Overlap threshold + np.array( + [False, False, True, True, True, True, True, False, False], dtype=bool + ), + DoesNotRaise(), + ), # Case with specific overlap threshold + ( + DETECTIONS, + sv.PolygonZone( + POLYGON, + FRAME_RESOLUTION, + ), + 0.5, # Overlap threshold + np.array( + [False, False, False, True, True, True, False, False, False], dtype=bool + ), + DoesNotRaise(), + ), # Another overlap threshold + ( + DETECTIONS, + sv.PolygonZone( + POLYGON, + FRAME_RESOLUTION, + ), + 0.1, # Lower overlap threshold to catch more detections + np.array( + [False, False, True, True, True, True, True, False, False], dtype=bool + ), + DoesNotRaise(), + ), # Lower threshold test + ( + DETECTIONS, + sv.PolygonZone( + POLYGON, + FRAME_RESOLUTION, + ), + 0, # Lower overlap threshold to catch more detections + np.array( + [True, True, True, True, True, True, True, True, True], dtype=bool + ), + DoesNotRaise(), + ), # Lower threshold test + ( + sv.Detections.empty(), + sv.PolygonZone( + POLYGON, + FRAME_RESOLUTION, + ), + 0, + np.array([], dtype=bool), + DoesNotRaise(), + ), # Test empty detections + ], +) +def test_polygon_zone_overlap( + detections: sv.Detections, + polygon_zone: sv.PolygonZone, + overlap_threshold: float, + expected_results: np.ndarray, + exception: Exception, +) -> None: + with exception: + overlaps = polygon_zone.trigger_overlap(detections, overlap_threshold) + assert np.all(overlaps == expected_results)