Skip to content

Commit

Permalink
fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
YHallouard committed Nov 3, 2024
1 parent 7ba853d commit 3b2b60c
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 54 deletions.
13 changes: 5 additions & 8 deletions supervision/detection/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1323,19 +1323,17 @@ def with_nms(
return self[indices]

def with_soft_nms(
self, threshold: float = 0.5, class_agnostic: bool = False, sigma: float = 0.5
self, sigma: float = 0.5, class_agnostic: bool = False
) -> Detections:
"""
Perform soft non-maximum suppression on the current set of object detections.
Args:
threshold (float): The intersection-over-union threshold
to use for non-maximum suppression. Defaults to 0.5.
sigma (float): The sigma value to use for the soft non-maximum suppression
algorithm. Defaults to 0.5.
class_agnostic (bool): Whether to perform class-agnostic
non-maximum suppression. If True, the class_id of each detection
will be ignored. Defaults to False.
sigma (float): The sigma value to use for the soft non-maximum suppression
algorithm. Defaults to 0.5.
Returns:
Detections: A new Detections object containing the subset of detections
Expand Down Expand Up @@ -1370,13 +1368,12 @@ def with_soft_nms(
soft_confidences = mask_soft_non_max_suppression(
predictions=predictions,
masks=self.mask,
iou_threshold=threshold,
sigma=sigma,
)
self.confidence = soft_confidences
else:
indices, soft_confidences = box_soft_non_max_suppression(
predictions=predictions, iou_threshold=threshold, sigma=sigma
soft_confidences = box_soft_non_max_suppression(
predictions=predictions, sigma=sigma
)
self.confidence = soft_confidences

Expand Down
34 changes: 13 additions & 21 deletions supervision/detection/overlap_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def resize_masks(masks: np.ndarray, max_dimension: int = 640) -> np.ndarray:


def __prepare_data_for_mask_nms(
iou_threshold: float,
mask_dimension: int,
masks: np.ndarray,
predictions: np.ndarray,
Expand All @@ -48,8 +47,6 @@ def __prepare_data_for_mask_nms(
Get IOUs from mask. Prepare the data for non-max suppression.
Args:
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
mask_dimension (int): The dimension to which the masks should be
resized before computing IOU values.
masks (np.ndarray): A 3D array of binary masks corresponding to the predictions.
Expand All @@ -68,10 +65,6 @@ def __prepare_data_for_mask_nms(
AssertionError: If `iou_threshold` is not within the closed range from
`0` to `1`.
"""
assert 0 <= iou_threshold <= 1, (
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
rows, columns = predictions.shape

if columns == 5:
Expand Down Expand Up @@ -117,8 +110,12 @@ def mask_non_max_suppression(
AssertionError: If `iou_threshold` is not within the closed
range from `0` to `1`.
"""
assert 0 <= iou_threshold <= 1, (
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
_, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
iou_threshold, mask_dimension, masks, predictions
mask_dimension, masks, predictions
)

keep = np.ones(rows, dtype=bool)
Expand All @@ -133,7 +130,6 @@ def mask_non_max_suppression(
def mask_soft_non_max_suppression(
predictions: np.ndarray,
masks: np.ndarray,
iou_threshold: float = 0.5,
mask_dimension: int = 640,
sigma: float = 0.5,
) -> np.ndarray:
Expand All @@ -160,7 +156,7 @@ def mask_soft_non_max_suppression(
0 < sigma < 1
), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
predictions, categories, ious, rows, sort_index = __prepare_data_for_mask_nms(
iou_threshold, mask_dimension, masks, predictions
mask_dimension, masks, predictions
)

not_this_row = np.ones(rows)
Expand All @@ -175,14 +171,12 @@ def mask_soft_non_max_suppression(


def __prepare_data_for_box_nsm(
iou_threshold: float, predictions: np.ndarray
predictions: np.ndarray,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, int, np.ndarray]:
"""
Prepare the data for non-max suppression.
Args:
iou_threshold (float): The intersection-over-union threshold
to use for non-maximum suppression.
predictions (np.ndarray): An array of object detection predictions in the
format of `(x_min, y_min, x_max, y_max, score)` or
`(x_min, y_min, x_max, y_max, score, class)`. Shape: `(N, 5)` or `(N, 6)`,
Expand All @@ -198,10 +192,6 @@ def __prepare_data_for_box_nsm(
"""
assert 0 <= iou_threshold <= 1, (
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
rows, columns = predictions.shape

# add column #5 - category filled with zeros for agnostic nms
Expand Down Expand Up @@ -240,9 +230,11 @@ def box_non_max_suppression(
AssertionError: If `iou_threshold` is not within the
closed range from `0` to `1`.
"""
_, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(
iou_threshold, predictions
assert 0 <= iou_threshold <= 1, (
"Value of `iou_threshold` must be in the closed range from 0 to 1, "
f"{iou_threshold} given."
)
_, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(predictions)

keep = np.ones(rows, dtype=bool)
for index, (iou, category) in enumerate(zip(ious, categories)):
Expand All @@ -258,7 +250,7 @@ def box_non_max_suppression(


def box_soft_non_max_suppression(
predictions: np.ndarray, iou_threshold: float = 0.5, sigma: float = 0.5
predictions: np.ndarray, sigma: float = 0.5
) -> np.ndarray:
"""
Perform Soft Non-Maximum Suppression (Soft-NMS) on object detection predictions.
Expand All @@ -283,7 +275,7 @@ def box_soft_non_max_suppression(
0 < sigma < 1
), f"Value of `sigma` must be greater than 0 and less than 1, {sigma} given."
predictions, categories, ious, rows, sort_index = __prepare_data_for_box_nsm(
iou_threshold, predictions
predictions
)

not_this_row = np.ones(rows)
Expand Down
28 changes: 3 additions & 25 deletions test/detection/test_overlap_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,25 +246,22 @@ def test_box_non_max_suppression(


@pytest.mark.parametrize(
"predictions, iou_threshold, sigma, expected_result, exception",
"predictions, sigma, expected_result, exception",
[
(
np.empty(shape=(0, 5)),
0.5,
0.1,
np.array([]),
DoesNotRaise(),
), # single box with no category
(
np.array([[10.0, 10.0, 40.0, 40.0, 0.8]]),
0.5,
0.8,
np.array([0.8]),
DoesNotRaise(),
), # single box with no category
(
np.array([[10.0, 10.0, 40.0, 40.0, 0.8, 0]]),
0.5,
0.9,
np.array([0.8]),
DoesNotRaise(),
Expand All @@ -276,7 +273,6 @@ def test_box_non_max_suppression(
[15.0, 15.0, 40.0, 40.0, 0.9],
]
),
0.5,
0.2,
np.array([0.07176137, 0.9]),
DoesNotRaise(),
Expand All @@ -288,7 +284,6 @@ def test_box_non_max_suppression(
[15.0, 15.0, 40.0, 40.0, 0.9, 1],
]
),
0.5,
0.3,
np.array([0.8, 0.9]),
DoesNotRaise(),
Expand All @@ -300,7 +295,6 @@ def test_box_non_max_suppression(
[15.0, 15.0, 40.0, 40.0, 0.9, 0],
]
),
0.5,
0.9,
np.array([0.46814354, 0.9]),
DoesNotRaise(),
Expand All @@ -313,7 +307,6 @@ def test_box_non_max_suppression(
[10.0, 10.0, 40.0, 50.0, 0.85],
]
),
0.5,
0.7,
np.array([0.42648529, 0.9, 0.53109062]),
DoesNotRaise(),
Expand All @@ -327,7 +320,6 @@ def test_box_non_max_suppression(
]
),
0.5,
0.5,
np.array([0.8, 0.9, 0.85]),
DoesNotRaise(),
), # three boxes with same category
Expand All @@ -339,7 +331,6 @@ def test_box_non_max_suppression(
[10.0, 10.0, 40.0, 50.0, 0.85, 1],
]
),
0.5,
0.9,
np.array([0.55491779, 0.9, 0.85]),
DoesNotRaise(),
Expand All @@ -348,15 +339,12 @@ def test_box_non_max_suppression(
)
def test_box_soft_non_max_suppression(
predictions: np.ndarray,
iou_threshold: float,
sigma: float,
expected_result: Optional[np.ndarray],
exception: Exception,
) -> None:
with exception:
result = box_soft_non_max_suppression(
predictions=predictions, iou_threshold=iou_threshold, sigma=sigma
)
result = box_soft_non_max_suppression(predictions=predictions, sigma=sigma)
np.testing.assert_almost_equal(result, expected_result, decimal=5)


Expand Down Expand Up @@ -567,12 +555,11 @@ def test_mask_non_max_suppression(


@pytest.mark.parametrize(
"predictions, masks, iou_threshold, sigma, expected_result, exception",
"predictions, masks, sigma, expected_result, exception",
[
(
np.empty((0, 6)),
np.empty((0, 5, 5)),
0.5,
0.1,
np.array([]),
DoesNotRaise(),
Expand All @@ -590,7 +577,6 @@ def test_mask_non_max_suppression(
]
]
),
0.5,
0.2,
np.array([0.8]),
DoesNotRaise(),
Expand All @@ -608,7 +594,6 @@ def test_mask_non_max_suppression(
]
]
),
0.5,
0.99,
np.array([0.8]),
DoesNotRaise(),
Expand All @@ -633,7 +618,6 @@ def test_mask_non_max_suppression(
],
]
),
0.5,
0.8,
np.array([0.8, 0.9]),
DoesNotRaise(),
Expand All @@ -658,7 +642,6 @@ def test_mask_non_max_suppression(
],
]
),
0.4,
0.6,
np.array([0.3831756, 0.9]),
DoesNotRaise(),
Expand All @@ -683,7 +666,6 @@ def test_mask_non_max_suppression(
],
]
),
0.5,
0.9,
np.array([0.8, 0.9]),
DoesNotRaise(),
Expand Down Expand Up @@ -721,7 +703,6 @@ def test_mask_non_max_suppression(
],
]
),
0.5,
0.3,
np.array([0.02853919, 0.85, 0.9]),
DoesNotRaise(),
Expand Down Expand Up @@ -759,7 +740,6 @@ def test_mask_non_max_suppression(
],
]
),
0.5,
0.1,
np.array([0.8, 0.85, 0.9]),
DoesNotRaise(),
Expand All @@ -769,7 +749,6 @@ def test_mask_non_max_suppression(
def test_mask_soft_non_max_suppression(
predictions: np.ndarray,
masks: np.ndarray,
iou_threshold: float,
sigma: float,
expected_result: Optional[np.ndarray],
exception: Exception,
Expand All @@ -778,7 +757,6 @@ def test_mask_soft_non_max_suppression(
result = mask_soft_non_max_suppression(
predictions=predictions,
masks=masks,
iou_threshold=iou_threshold,
sigma=sigma,
)
np.testing.assert_almost_equal(result, expected_result, decimal=6)

0 comments on commit 3b2b60c

Please sign in to comment.