Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add soft Non-Max suppression #1624

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions docs/detection/double_detection_filter.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,25 @@ comments: true

:::supervision.detection.overlap_filter.box_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.box_soft_non_max_suppression">box_soft_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_filter.box_soft_non_max_suppression


<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.mask_non_max_suppression">mask_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_filter.mask_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.mask_soft_non_max_suppression">mask_soft_non_max_suppression</a></h2>
</div>

:::supervision.detection.overlap_filter.mask_soft_non_max_suppression

<div class="md-typeset">
<h2><a href="#supervision.detection.overlap_filter.box_non_max_merge">box_non_max_merge</a></h2>
</div>
Expand Down
59 changes: 59 additions & 0 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@
from supervision.detection.overlap_filter import (
box_non_max_merge,
box_non_max_suppression,
box_soft_non_max_suppression,
mask_non_max_suppression,
mask_soft_non_max_suppression,
)
from supervision.detection.tools.transformers import (
process_transformers_detection_result,
Expand Down Expand Up @@ -1320,6 +1322,63 @@ def with_nms(

return self[indices]

def with_soft_nms(
    self, sigma: float = 0.5, class_agnostic: bool = False
) -> "Detections":
    """
    Perform soft non-maximum suppression on the current set of object detections.

    No detection is removed by this method: every detection is kept and only
    its confidence is replaced by the score returned from
    `mask_soft_non_max_suppression` (when masks are present) or
    `box_soft_non_max_suppression` (otherwise).

    Args:
        sigma (float): The sigma value to use for the soft non-maximum suppression
            algorithm. Defaults to 0.5.
        class_agnostic (bool): Whether to perform class-agnostic
            non-maximum suppression. If True, the class_id of each detection
            will be ignored. Defaults to False.

    Returns:
        Detections: A shallow copy of the current object carrying the updated
            confidence scores. The receiver itself is left untouched. An empty
            set of detections is returned as-is.

    Raises:
        AssertionError: If `confidence` is None, or if `class_id` is None
            while `class_agnostic` is False.
    """
    import copy

    if len(self) == 0:
        return self

    assert (
        self.confidence is not None
    ), "Detections confidence must be given for NMS to be executed."

    # Columns are (x_min, y_min, x_max, y_max, score[, class]); the class
    # column is only appended for class-aware suppression.
    columns = [self.xyxy, self.confidence.reshape(-1, 1)]
    if not class_agnostic:
        assert self.class_id is not None, (
            "Detections class_id must be given for NMS to be executed. If you"
            " intended to perform class agnostic NMS set class_agnostic=True."
        )
        columns.append(self.class_id.reshape(-1, 1))
    predictions = np.hstack(columns)

    if self.mask is not None:
        soft_confidences = mask_soft_non_max_suppression(
            predictions=predictions,
            masks=self.mask,
            sigma=sigma,
        )
    else:
        soft_confidences = box_soft_non_max_suppression(
            predictions=predictions, sigma=sigma
        )

    # Rebind the confidence on a copy instead of on `self`, so the caller's
    # original scores survive the call. The copy is shallow: every other
    # attribute is shared with `self`, and nothing is mutated in place.
    result = copy.copy(self)
    result.confidence = soft_confidences
    return result

def with_nmm(
self, threshold: float = 0.5, class_agnostic: bool = False
) -> Detections:
Expand Down
188 changes: 165 additions & 23 deletions supervision/detection/overlap_filter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from enum import Enum
from typing import List, Union
from typing import List, Tuple, Union

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -38,6 +38,48 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:
return resized_masks


def __prepare_data_for_mask_nms(
    mask_dimension: int,
    masks: np.ndarray,
    predictions: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
    """
    Sort predictions by score and compute pairwise mask IOUs for mask-based NMS.

    Args:
        mask_dimension (int): The dimension to which the masks should be
            resized before computing IOU values.
        masks (np.ndarray): A 3D array of binary masks corresponding to the
            predictions. Shape: `(N, H, W)`, where N is the number of
            predictions, and H, W are the dimensions of each mask.
        predictions (np.ndarray): An array of object detection predictions in the
            format of `(x_min, y_min, x_max, y_max, score)` or
            `(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or
            `(N, 6)`, where N is the number of predictions.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: A tuple
            containing the predictions sorted by descending score, their
            category column, the pairwise IOUs of the resized (sorted) masks,
            the number of rows, and the indices that sort the input by
            descending score.
    """
    num_rows, num_columns = predictions.shape

    # Class-agnostic input has no category column; append an all-zero one so
    # that every prediction falls into the same category.
    if num_columns == 5:
        predictions = np.column_stack((predictions, np.zeros(num_rows)))

    # Highest score first.
    descending_order = predictions[:, 4].argsort()[::-1]
    sorted_predictions = predictions[descending_order]

    # IOUs are computed on down-sized masks, ordered like the predictions.
    resized = resize_masks(masks[descending_order], mask_dimension)
    pairwise_ious = mask_iou_batch(resized, resized)

    return (
        sorted_predictions,
        sorted_predictions[:, 5],
        pairwise_ious,
        num_rows,
        descending_order,
    )


def mask_non_max_suppression(
predictions: np.ndarray,
masks: np.ndarray,
Expand Down Expand Up @@ -72,17 +114,9 @@ def mask_non_max_suppression(
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
rows, columns = predictions.shape

if columns == 5:
predictions = np.c_[predictions, np.zeros(rows)]

sort_index = predictions[:, 4].argsort()[::-1]
predictions = predictions[sort_index]
masks = masks[sort_index]
masks_resized = resize_masks(masks, mask_dimension)
ious = mask_iou_batch(masks_resized, masks_resized)
categories = predictions[:, 5]
_, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
mask_dimension, masks, predictions
)

keep = np.ones(rows, dtype=bool)
for i in range(rows):
Expand All @@ -93,31 +127,71 @@ def mask_non_max_suppression(
return keep[sort_index.argsort()]


def mask_soft_non_max_suppression(
    predictions: np.ndarray,
    masks: np.ndarray,
    mask_dimension: int = 640,
    sigma: float = 0.5,
) -> np.ndarray:
    """
    Perform Soft Non-Maximum Suppression (Soft-NMS) on segmentation predictions.

    Args:
        predictions (np.ndarray): An array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`.
        masks (np.ndarray): The masks corresponding to the predictions, one
            per row of `predictions`.
        mask_dimension (int): The dimension to which the masks should be
            resized before computing IOU values. Defaults to 640.
        sigma (float): The sigma value to use for soft non-maximum suppression.

    Returns:
        np.ndarray: An array containing the updated confidence scores, in the
            same order as the input `predictions`.

    Raises:
        AssertionError: If `sigma` is not within the open range from `0` to `1`.
    """
    assert (
        0 < sigma < 1
    ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
    predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
        mask_dimension, masks, predictions
    )

    # `not_this_row` accumulates zeros: once row i has been visited it (and
    # every row before it) is excluded from all further score decay.
    not_this_row = np.ones(rows)
    for i in range(rows):
        not_this_row[i] = 0
        # 1.0 for not-yet-visited rows of the same category as row i, else 0.0.
        condition = (categories[i] == categories) * not_this_row
        # Gaussian decay; a zero condition yields exp(0) == 1, i.e. no change.
        predictions[:, 4] = predictions[:, 4] * np.exp(
            -(ious[i] ** 2) / sigma * condition
        )

    # Undo the score sort so the result lines up with the caller's order.
    return predictions[sort_index.argsort(), 4]


def __prepare_data_for_box_nsm(
predictions: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
"""
Prepare the data for non-max suppression.

Args:
predictions (np.ndarray): An array of object detection predictions in the
format of `(x_min, y_min, x_max, y_max, score)` or
`(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`,
where N is the number of predictions.

Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]: A tuple containing
the predictions, categories, IOUs, number of rows, and the sorted indices

Raises:
AssertionError: If `iou_threshold` is not within the closed range from `0`
to `1`.


"""
rows, columns = predictions.shape

# add column #5 - category filled with zeros for agnostic nms
Expand All @@ -127,14 +201,42 @@ def box_non_max_suppression(
# sort predictions column #4 - score
sort_index = np.flip(predictions[:, 4].argsort())
predictions = predictions[sort_index]

boxes = predictions[:, :4]
categories = predictions[:, 5]
ious = box_iou_batch(boxes, boxes)
ious = ious - np.eye(rows)

keep = np.ones(rows, dtype=bool)
return predictions, categories, ious, rows, sort_index


def box_non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5
) -> np.ndarray:
"""
Perform Non-Maximum Suppression (NMS) on object detection predictions.

Args:
predictions (np.ndarray): An array of object detection predictions in
the format of `(x_min, y_min, x_max, y_max, score)`
or `(x_min, y_min, x_max, y_max, score, class)`.
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.

Returns:
np.ndarray: A boolean array indicating which predictions to keep after n
on-maximum suppression.

Raises:
AssertionError: If `iou_threshold` is not within the
closed range from `0` to `1`.
"""
assert 0 <= iou_threshold <= 1, (
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
_, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions)

keep = np.ones(rows, dtype=bool)
for index, (iou, category) in enumerate(zip(ious, categories)):
if not keep[index]:
continue
Expand All @@ -147,6 +249,46 @@ def box_non_max_suppression(
return keep[sort_index.argsort()]


def box_soft_non_max_suppression(
    predictions: np.ndarray, sigma: float = 0.5
) -> np.ndarray:
    """
    Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions.

    Args:
        predictions (np.ndarray): An array of object detection predictions in
            the format of `(x_min, y_min, x_max, y_max, score)`
            or `(x_min, y_min, x_max, y_max, score, class)`.
        sigma (float): The sigma value to use for soft non-maximum suppression.

    Returns:
        np.ndarray: An array containing the updated confidence scores, in the
            same order as the input `predictions`.

    Raises:
        AssertionError: If `sigma` is not within the open range from `0` to `1`.
    """
    assert (
        0 < sigma < 1
    ), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."

    prepared = __prepare_data_for_box_nsm(predictions)
    sorted_predictions, categories, ious, rows, sort_index = prepared

    # Rows already visited are zeroed here and never decayed afterwards.
    pending = np.ones(rows)
    for row, (iou_row, category) in enumerate(zip(ious, categories)):
        pending[row] = 0
        # 1.0 for pending rows sharing this row's category, 0.0 otherwise.
        same_category_pending = (category == categories) * pending
        decay = np.exp(-(iou_row**2) / sigma * same_category_pending)
        sorted_predictions[:, 4] = sorted_predictions[:, 4] * decay

    # Map the scores back from score-sorted order to the caller's order.
    original_order = sort_index.argsort()
    return sorted_predictions[original_order, 4]


def group_overlapping_boxes(
predictions: npt.NDArray[np.float64], iou_threshold: float = 0.5
) -> List[List[int]]:
Expand Down
Loading