Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Inference slicer batching #1239

Open
wants to merge 5 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 109 additions & 24 deletions supervision/detection/tools/inference_slicer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Callable, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union

import numpy as np

Expand All @@ -13,6 +13,7 @@
SupervisionWarnings,
warn_deprecated,
)
from supervision.utils.iterables import create_batches


def move_detections(
Expand Down Expand Up @@ -71,9 +72,14 @@ class InferenceSlicer:
overlap_filter (Union[OverlapFilter, str]): Strategy for
filtering or merging overlapping detections in slices.
iou_threshold (float): Intersection over Union (IoU) threshold
used when filtering by overlap.
used for non-max suppression.
callback (Callable): A function that performs inference on a given image
slice and returns detections.
slice and returns detections. Should accept `np.ndarray` if
`batch_size` is `1` (default) and `List[np.ndarray]` otherwise.
See examples for more details.
batch_size (int): Number of slices passed to `callback` in one call.
Defaults to 1. For values greater than 1, `callback` must accept a list
of images and return a list of `Detections`. Higher values use more
memory but may be faster.
thread_workers (int): Number of threads for parallel execution.

Note:
Expand All @@ -85,12 +91,16 @@ class InferenceSlicer:

def __init__(
self,
callback: Callable[[np.ndarray], Detections],
callback: Union[
Callable[[np.ndarray], Detections],
Callable[[List[np.ndarray]], List[Detections]],
],
slice_wh: Tuple[int, int] = (320, 320),
overlap_ratio_wh: Optional[Tuple[float, float]] = (0.2, 0.2),
overlap_wh: Optional[Tuple[int, int]] = None,
overlap_filter: Union[OverlapFilter, str] = OverlapFilter.NON_MAX_SUPPRESSION,
iou_threshold: float = 0.5,
batch_size: int = 1,
thread_workers: int = 1,
):
if overlap_ratio_wh is not None:
Expand All @@ -108,8 +118,14 @@ def __init__(
self.iou_threshold = iou_threshold
self.overlap_filter = OverlapFilter.from_value(overlap_filter)
self.callback = callback
self.batch_size = batch_size
self.thread_workers = thread_workers

if self.batch_size < 1:
raise ValueError("batch_size should be greater than 0")
if self.thread_workers < 1:
raise ValueError("thread_workers should be greater than 0.")

def __call__(self, image: np.ndarray) -> Detections:
"""
Performs slicing-based inference on the provided image using the specified
Expand All @@ -133,9 +149,22 @@ def __call__(self, image: np.ndarray) -> Detections:
image = cv2.imread(SOURCE_IMAGE_PATH)
model = YOLO(...)

def callback(image_slice: np.ndarray) -> sv.Detections:
result = model(image_slice)[0]
return sv.Detections.from_ultralytics(result)
# Option 1: Single slice
def callback(image_slice: np.ndarray) -> sv.Detections:
result = model(image_slice)[0]
detections = sv.Detections.from_ultralytics(result)
return detections

slicer = sv.InferenceSlicer(callback=callback)
detections = slicer(image)


# Option 2: Batch slices (Faster, but uses more memory)
def callback(slices: List[np.ndarray]) -> List[sv.Detections]:
results = model(slices)
detections_list = [
sv.Detections.from_ultralytics(result) for result in results]
return detections_list

slicer = sv.InferenceSlicer(
callback=callback,
Expand All @@ -153,13 +182,36 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
overlap_ratio_wh=self.overlap_ratio_wh,
overlap_wh=self.overlap_wh,
)
batched_offsets_generator = create_batches(offsets, self.batch_size)

if self.thread_workers == 1:
for offset_batch in batched_offsets_generator:
if self.batch_size == 1:
result = self._callback_image_single(image, offset_batch[0])
detections_list.append(result)
else:
results = self._callback_image_batch(image, offset_batch)
detections_list.extend(results)

with ThreadPoolExecutor(max_workers=self.thread_workers) as executor:
futures = [
executor.submit(self._run_callback, image, offset) for offset in offsets
]
for future in as_completed(futures):
detections_list.append(future.result())
else:
with ThreadPoolExecutor(max_workers=self.thread_workers) as executor:
futures = []
for offset_batch in batched_offsets_generator:
if self.batch_size == 1:
future = executor.submit(
self._callback_image_single, image, offset_batch[0]
)
else:
future = executor.submit(
self._callback_image_batch, image, offset_batch
)
futures.append(future)

for future in as_completed(futures):
if self.batch_size == 1:
detections_list.append(future.result())
else:
detections_list.extend(future.result())

merged = Detections.merge(detections_list=detections_list)
if self.overlap_filter == OverlapFilter.NONE:
Expand All @@ -175,27 +227,60 @@ def callback(image_slice: np.ndarray) -> sv.Detections:
)
return merged

def _run_callback(self, image, offset) -> Detections:
def _callback_image_single(
    self, image: np.ndarray, offset: np.ndarray
) -> Detections:
    """
    Run the callback on a single slice of the image.

    Args:
        image (np.ndarray): The input image on which inference needs to run
        offset (np.ndarray): An array of shape `(4,)` containing coordinates
            for the slice.

    Returns:
        Detections: The callback's detections for the slice, after being
            passed through `move_detections` with `offset[:2]`.

    Raises:
        ValueError: If the callback does not return a `Detections` object.
    """
    # Internal invariant: offsets are sliced (`offset[:2]`) below.
    assert isinstance(offset, np.ndarray)

    image_slice = crop_image(image=image, xyxy=offset)
    detections = self.callback(image_slice)

    # Validate the callback's result *before* using it, so a batch-style
    # callback used with batch_size=1 fails with a clear message.
    if not isinstance(detections, Detections):
        raise ValueError(
            "Callback should return a single Detections object when "
            f"batch_size is 1. Instead it returned: {type(detections)}"
        )

    # Apply the offset exactly once. `resolution_wh` is (width, height) of the
    # full image; NOTE(review): presumably required by `move_detections` for
    # detections that carry masks - confirm against its implementation.
    resolution_wh = (image.shape[1], image.shape[0])
    detections = move_detections(
        detections=detections, offset=offset[:2], resolution_wh=resolution_wh
    )
    return detections

def _callback_image_batch(
    self, image: np.ndarray, offsets_batch: List[np.ndarray]
) -> List[Detections]:
    """
    Run the callback on a batch of slices of the image.

    Args:
        image (np.ndarray): The input image on which inference needs to run
        offsets_batch (List[np.ndarray]): List of N arrays of shape `(4,)`,
            containing coordinates of the slices.

    Returns:
        List[Detections]: Detections found in each slice, in the same order
            as `offsets_batch`, each passed through `move_detections` with
            the matching `offset[:2]`.

    Raises:
        ValueError: If the callback does not return a list, or returns a list
            whose length differs from the number of slices in the batch.
    """
    # Internal invariant: this method is only fed the batches produced from
    # the slice offsets inside this class.
    assert isinstance(offsets_batch, list)

    slices = [crop_image(image=image, xyxy=offset) for offset in offsets_batch]
    detections_in_slices = self.callback(slices)
    if not isinstance(detections_in_slices, list):
        raise ValueError(
            "Callback should return a list of Detections objects when "
            "batch_size is greater than 1. "
            f"Instead it returned: {type(detections_in_slices)}"
        )

    # `zip` stops at the shorter input, so a callback that returns too few
    # (or too many) results would otherwise silently lose or misalign slices.
    if len(detections_in_slices) != len(offsets_batch):
        raise ValueError(
            "Callback should return one Detections object per slice. "
            f"Expected {len(offsets_batch)}, got {len(detections_in_slices)}."
        )

    # `resolution_wh` is (width, height) of the full image, forwarded the same
    # way the single-slice path does. NOTE(review): presumably required by
    # `move_detections` for detections that carry masks - confirm.
    resolution_wh = (image.shape[1], image.shape[0])
    detections_with_offset = [
        move_detections(
            detections=detections, offset=offset[:2], resolution_wh=resolution_wh
        )
        for detections, offset in zip(detections_in_slices, offsets_batch)
    ]

    return detections_with_offset

@staticmethod
def _generate_offset(
resolution_wh: Tuple[int, int],
Expand Down