diff --git a/README.md b/README.md
index 21109323d..76244ccda 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Object detection and instance segmentation are by far the most important applica
## Quick Start Examples
-[📜 List of publications that cite SAHI (currently 100+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)
+[📜 List of publications that cite SAHI (currently 150+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)
[🏆 List of competition winners that used SAHI](https://github.com/obss/sahi/discussions/688)
@@ -52,10 +52,12 @@ Object detection and instance segmentation are by far the most important applica
- [Introduction to SAHI](https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80)
-- [Official paper](https://ieeexplore.ieee.org/document/9897990) (ICIP 2022 oral) (NEW)
+- [Official paper](https://ieeexplore.ieee.org/document/9897990) (ICIP 2022 oral)
- [Pretrained weights and ICIP 2022 paper files](https://github.com/fcakyon/small-object-detection-benchmark)
+- [Visualizing and Evaluating SAHI predictions with FiftyOne](https://voxel51.com/blog/how-to-detect-small-objects/) (2024) (NEW)
+
- ['Exploring SAHI' Research Article from 'learnopencv.com'](https://learnopencv.com/slicing-aided-hyper-inference/) (2023) (NEW)
- ['VIDEO TUTORIAL: Slicing Aided Hyper Inference for Small Object Detection - SAHI'](https://www.youtube.com/watch?v=UuOjJKxn-M8&t=270s) (2023) (NEW)
@@ -82,9 +84,13 @@ Object detection and instance segmentation are by far the most important applica
- `Detectron2` + `SAHI` walkthrough:
+- `TorchVision` + `SAHI` walkthrough:
+
- `HuggingFace` + `SAHI` walkthrough: (NEW)
-- `TorchVision` + `SAHI` walkthrough: (NEW)
+- `DeepSparse` + `SAHI` walkthrough: (NEW)
+
+- `SuperGradients/YOLONAS` + `SAHI` walkthrough: (NEW)
@@ -129,6 +135,12 @@ conda install pytorch=1.13.1 torchvision=0.14.1 pytorch-cuda=11.7 -c pytorch -c
pip install yolov5==7.0.13
```
+- Install your desired detection framework (ultralytics):
+
+```console
+pip install ultralytics==8.0.207
+```
+
- Install your desired detection framework (mmdet):
```console
@@ -148,6 +160,12 @@ pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113
pip install transformers timm
```
+- Install your desired detection framework (super-gradients):
+
+```console
+pip install super-gradients==3.3.1
+```
+
### Framework Agnostic Sliced/Standard Prediction
@@ -261,6 +279,20 @@ python -m scripts.run_code_style format
Alzbeta Tureckova
+So Uchida
+
+Yonghye Kwon
+
+Neville
+
+Janne Mäyrä
+
+Christoffer Edlund
+
+Ilker Manap
+
+Nguyễn Thế An
+
Wei Ji
Aynur Susuz
@@ -269,5 +301,13 @@ python -m scripts.run_code_style format
Lakshay Mehra
+Karl-Joan Alesma
+
+Jacob Marks
+
+William Lung
+
+Amogh Dhaliwal
+
diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py
index 3a28c9c72..3ae151f3c 100644
--- a/sahi/postprocess/combine.py
+++ b/sahi/postprocess/combine.py
@@ -217,19 +217,12 @@ def greedy_nmm(
# according to their confidence scores
order = scores.argsort()
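+    # argsort is ascending, so the highest-scoring
+    # prediction sits at the end of `order`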
- # initialise an empty list for
- # filtered prediction boxes
- keep = []
-
while len(order) > 0:
# extract the index of the
# prediction with highest score
# we call this prediction S
idx = order[-1]
- # push S in filtered predictions list
- keep.append(idx.tolist())
-
# remove S from P
order = order[:-1]
diff --git a/sahi/predict.py b/sahi/predict.py
index 699cb6f67..e7ddb84ff 100644
--- a/sahi/predict.py
+++ b/sahi/predict.py
@@ -164,7 +164,7 @@ def get_sliced_prediction(
detection accuracy. Default: True.
postprocess_type: str
Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
- Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'.
+ Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
postprocess_match_metric: str
Metric to be used during object prediction matching after sliced prediction.
'IOU' for intersection over union, 'IOS' for intersection over smaller area.
@@ -231,7 +231,7 @@ def get_sliced_prediction(
# create prediction input
num_group = int(num_slices / num_batch)
if verbose == 1 or verbose == 2:
- tqdm.write(f"Performing prediction on {num_slices} number of slices.")
+ tqdm.write(f"Performing prediction on {num_slices} slices.")
object_prediction_list = []
# perform sliced prediction
for group_ind in range(num_group):
@@ -416,7 +416,7 @@ def predict(
Default to ``0.2``.
postprocess_type: str
Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
- Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GRREDYNMM'.
+ Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GREEDYNMM'.
postprocess_match_metric: str
Metric to be used during object prediction matching after sliced prediction.
'IOU' for intersection over union, 'IOS' for intersection over smaller area.
@@ -781,7 +781,7 @@ def predict_fiftyone(
Default to ``0.2``.
postprocess_type: str
Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
- Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'.
+ Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
postprocess_match_metric: str
Metric to be used during object prediction matching after sliced prediction.
'IOU' for intersection over union, 'IOS' for intersection over smaller area.
diff --git a/sahi/utils/coco.py b/sahi/utils/coco.py
index 1777555aa..e84f2f3d3 100644
--- a/sahi/utils/coco.py
+++ b/sahi/utils/coco.py
@@ -130,7 +130,7 @@ def from_coco_annotation_dict(cls, annotation_dict: Dict, category_name: Optiona
annotation_dict: dict
COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
"""
- if annotation_dict.__contains__("segmentation") and not isinstance(annotation_dict["segmentation"], list):
+    if "segmentation" in annotation_dict and isinstance(annotation_dict["segmentation"], dict):
has_rle_segmentation = True
logger.warning(
f"Segmentation annotation for id {annotation_dict['id']} is skipped since RLE segmentation format is not supported."
diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py
index 06434a086..269dc254c 100644
--- a/sahi/utils/cv.py
+++ b/sahi/utils/cv.py
@@ -21,7 +21,6 @@
class Colors:
- # color palette
def __init__(self):
hex = (
"FF3838",
@@ -45,16 +44,38 @@ def __init__(self):
"FF95C8",
"FF37C7",
)
- self.palette = [self.hex2rgb("#" + c) for c in hex]
+ self.palette = [self.hex_to_rgb("#" + c) for c in hex]
self.n = len(self.palette)
- def __call__(self, i, bgr=False):
- c = self.palette[int(i) % self.n]
- return (c[2], c[1], c[0]) if bgr else c
+ def __call__(self, ind, bgr: bool = False):
+ """
+ Convert an index to a color code.
+
+ Args:
+ ind (int): The index to convert.
+ bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.
+
+ Returns:
+ tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
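+
+        Example:
+            >>> colors = Colors()
+            >>> colors(0)
+            (255, 56, 56)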
+ """
+ color_codes = self.palette[int(ind) % self.n]
+ return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes
@staticmethod
- def hex2rgb(h): # rgb order
- return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
+ def hex_to_rgb(hex_code):
+ """
+ Converts a hexadecimal color code to RGB format.
+
+ Args:
+ hex_code (str): The hexadecimal color code to convert.
+
+ Returns:
+ tuple: A tuple representing the RGB values in the order (R, G, B).
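+
+        Example:
+            >>> Colors.hex_to_rgb("#FF3838")
+            (255, 56, 56)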
+ """
+ rgb = []
+ for i in (0, 2, 4):
+ rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
+ return tuple(rgb)
def crop_object_predictions(
@@ -65,23 +86,25 @@ def crop_object_predictions(
export_format: str = "png",
):
"""
- Crops bounding boxes over the source image and exports it to output folder.
- Arguments:
- object_predictions: a list of prediction.ObjectPrediction
- output_dir: directory for resulting visualization to be exported
- file_name: exported file will be saved as: output_dir+file_name+".png"
- export_format: can be specified as 'jpg' or 'png'
+    Crops the bounding box regions from the source image and exports them to the output folder.
+
+ Args:
+ image (np.ndarray): The source image to crop bounding boxes from.
+        object_prediction_list (List[ObjectPrediction]): A list of ObjectPrediction objects to crop.
+ output_dir (str): The directory where the resulting visualizations will be exported. Defaults to an empty string.
+ file_name (str): The name of the exported file. The exported file will be saved as `output_dir + file_name + ".png"`. Defaults to "prediction_visual".
+ export_format (str): The format of the exported file. Can be specified as 'jpg' or 'png'. Defaults to "png".
"""
# create output folder if not present
Path(output_dir).mkdir(parents=True, exist_ok=True)
# add bbox and mask to image if present
for ind, object_prediction in enumerate(object_prediction_list):
- # deepcopy object_prediction_list so that original is not altered
+ # deepcopy object_prediction_list so that the original is not altered
object_prediction = object_prediction.deepcopy()
bbox = object_prediction.bbox.to_xyxy()
category_id = object_prediction.category.id
# crop detections
- # deepcopy crops so that original is not altered
+ # deepcopy crops so that the original is not altered
cropped_img = copy.deepcopy(
image[
int(bbox[1]) : int(bbox[3]),
@@ -98,7 +121,12 @@ def crop_object_predictions(
def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False):
"""
- Reads image from path and saves as given extension.
+ Reads an image from the given path and saves it with the specified extension.
+
+ Args:
+ read_path (str): The path to the image file.
+ extension (str, optional): The desired file extension for the saved image. Defaults to "jpg".
+ grayscale (bool, optional): Whether to convert the image to grayscale. Defaults to False.
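+
+    Example (illustrative; the path is hypothetical):
+        >>> convert_image_to("data/sample.png", extension="jpg")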
"""
image = cv2.imread(read_path)
pre, ext = os.path.splitext(read_path)
@@ -110,6 +138,17 @@ def convert_image_to(read_path, extension: str = "jpg", grayscale: bool = False)
def read_large_image(image_path: str):
+ """
+ Reads a large image from the specified image path.
+
+ Args:
+ image_path (str): The path to the image file.
+
+ Returns:
+ tuple: A tuple containing the image data and a flag indicating whether cv2 was used to read the image.
+ The image data is a numpy array representing the image in RGB format.
+ The flag is True if cv2 was used, False otherwise.
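+
+    Example (illustrative; the path is hypothetical):
+        >>> image, use_cv2 = read_large_image("data/large_image.tif")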
+ """
use_cv2 = True
# read image, cv2 fails on large files
try:
@@ -130,7 +169,13 @@ def read_large_image(image_path: str):
def read_image(image_path: str):
"""
- Loads image as numpy array from given path.
+ Loads image as a numpy array from the given path.
+
+ Args:
+ image_path (str): The path to the image file.
+
+ Returns:
+ numpy.ndarray: The loaded image as a numpy array.
"""
# read image
image = cv2.imread(image_path)
@@ -144,7 +189,12 @@ def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool
Loads an image as PIL.Image.Image.
Args:
- image : Can be image path or url (str), numpy image (np.ndarray) or PIL.Image
+ image (Union[Image.Image, str, np.ndarray]): The image to be loaded. It can be an image path or URL (str),
+ a numpy image (np.ndarray), or a PIL.Image object.
+ exif_fix (bool, optional): Whether to apply an EXIF fix to the image. Defaults to False.
+
+ Returns:
+ PIL.Image.Image: The loaded image as a PIL.Image object.
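+
+    Example (illustrative; the file path is hypothetical):
+        >>> pil_image = read_image_as_pil("data/sample.jpg")
+        >>> pil_image = read_image_as_pil(np.zeros((10, 10, 3), dtype=np.uint8))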
"""
# https://stackoverflow.com/questions/56174099/how-to-load-images-larger-than-max-image-pixels-with-pil
Image.MAX_IMAGE_PIXELS = None
@@ -184,7 +234,11 @@ def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool
def select_random_color():
"""
- Selects random color.
+ Selects a random color from a predefined list of colors.
+
+ Returns:
+ list: A list representing the RGB values of the selected color.
+
"""
colors = [
[0, 255, 0],
@@ -205,6 +259,13 @@ def select_random_color():
def apply_color_mask(image: np.ndarray, color: tuple):
"""
Applies color mask to given input image.
+
+ Args:
+ image (np.ndarray): The input image to apply the color mask to.
+ color (tuple): The RGB color tuple to use for the mask.
+
+ Returns:
+ np.ndarray: The resulting image with the applied color mask.
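+
+    Example (illustrative):
+        >>> colored_mask = apply_color_mask(np.zeros((10, 10), dtype=np.uint8), (255, 0, 0))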
"""
r = np.zeros_like(image).astype(np.uint8)
g = np.zeros_like(image).astype(np.uint8)
@@ -328,6 +389,22 @@ def visualize_prediction(
"""
Visualizes prediction classes, bounding boxes over the source image
and exports it to output folder.
+
+ Args:
+ image (np.ndarray): The source image.
+        boxes (List[List]): List of bounding box coordinates.
+ classes (List[str]): List of class labels corresponding to each bounding box.
+ masks (Optional[List[np.ndarray]], optional): List of masks corresponding to each bounding box. Defaults to None.
+ rect_th (float, optional): Thickness of the bounding box rectangle. Defaults to None.
+ text_size (float, optional): Size of the text for class labels. Defaults to None.
+ text_th (float, optional): Thickness of the text for class labels. Defaults to None.
+ color (tuple, optional): Color of the bounding box and text. Defaults to None.
+ hide_labels (bool, optional): Whether to hide the class labels. Defaults to False.
+ output_dir (Optional[str], optional): Output directory to save the visualization. Defaults to None.
+ file_name (Optional[str], optional): File name for the saved visualization. Defaults to "prediction_visual".
+
+ Returns:
+ dict: A dictionary containing the visualized image and the elapsed time for the visualization process.
"""
elapsed_time = time.time()
# deepcopy image so that original is not altered
@@ -354,21 +431,21 @@ def visualize_prediction(
image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
# add bboxes to image if present
-    for i in range(len(boxes)):
+    for box_index in range(len(boxes)):
        # deepcopy box so that original is not altered
-        box = copy.deepcopy(boxes[i])
-        class_ = classes[i]
+        box = copy.deepcopy(boxes[box_index])
+        class_ = classes[box_index]
# set color
if colors is not None:
color = colors(class_)
# set bbox points
- p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+            point1, point2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
# visualize boxes
cv2.rectangle(
image,
- p1,
- p2,
+ point1,
+ point2,
color=color,
thickness=rect_th,
)
@@ -376,15 +453,17 @@ def visualize_prediction(
if not hide_labels:
            # arrange bounding box text location
label = f"{class_}"
- w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height
- outside = p1[1] - h - 3 >= 0 # label fits outside box
- p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+            label_width, label_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
+                0
+            ]  # label width, height
+            outside = point1[1] - label_height - 3 >= 0  # label fits outside box
+            point2 = point1[0] + label_width, point1[1] - label_height - 3 if outside else point1[1] + label_height + 3
# add bounding box text
- cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
+ cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA) # filled
cv2.putText(
image,
label,
- (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+            (point1[0], point1[1] - 2 if outside else point1[1] + label_height + 2),
0,
text_size,
(255, 255, 255),
@@ -417,7 +496,8 @@ def visualize_object_predictions(
"""
Visualizes prediction category names, bounding boxes over the source image
and exports it to output folder.
- Arguments:
+
+ Args:
object_prediction_list: a list of prediction.ObjectPrediction
rect_th: rectangle thickness
text_size: size of the category name over box
@@ -472,12 +552,12 @@ def visualize_object_predictions(
if colors is not None:
color = colors(object_prediction.category.id)
# set bbox points
- p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
+ point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
# visualize boxes
cv2.rectangle(
image,
- p1,
- p2,
+ point1,
+ point2,
color=color,
thickness=rect_th,
)
@@ -489,15 +569,17 @@ def visualize_object_predictions(
if not hide_conf:
label += f" {score:.2f}"
- w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height
- outside = p1[1] - h - 3 >= 0 # label fits outside box
- p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+            label_width, label_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
+                0
+            ]  # label width, height
+            outside = point1[1] - label_height - 3 >= 0  # label fits outside box
+            point2 = point1[0] + label_width, point1[1] - label_height - 3 if outside else point1[1] + label_height + 3
# add bounding box text
- cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled
+ cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA) # filled
cv2.putText(
image,
label,
- (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+            (point1[0], point1[1] - 2 if outside else point1[1] + label_height + 2),
0,
text_size,
(255, 255, 255),
@@ -541,9 +623,17 @@ def get_coco_segmentation_from_bool_mask(bool_mask):
return coco_segmentation
-def get_bool_mask_from_coco_segmentation(coco_segmentation, width, height):
+def get_bool_mask_from_coco_segmentation(coco_segmentation: List[List[float]], width: int, height: int) -> np.ndarray:
"""
Convert coco segmentation to 2D boolean mask of given height and width
+
+    Args:
+        coco_segmentation (List[List[float]]): List of point lists representing the COCO polygon segmentation.
+        width (int): Width of the boolean mask.
+        height (int): Height of the boolean mask.
+
+    Returns:
+        np.ndarray: 2D boolean mask of shape (height, width).
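+
+    Example:
+        >>> bool_mask = get_bool_mask_from_coco_segmentation([[0, 0, 50, 0, 50, 50, 0, 50]], width=100, height=100)
+        >>> bool_mask.shape
+        (100, 100)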
"""
size = [height, width]
points = [np.array(point).reshape(-1, 2).round().astype(int) for point in coco_segmentation]
@@ -553,9 +643,15 @@ def get_bool_mask_from_coco_segmentation(coco_segmentation, width, height):
return bool_mask
-def get_bbox_from_bool_mask(bool_mask):
+def get_bbox_from_bool_mask(bool_mask: np.ndarray) -> Optional[List[int]]:
"""
- Generate voc bbox ([xmin, ymin, xmax, ymax]) from given bool_mask (2D np.ndarray)
+ Generate VOC bounding box [xmin, ymin, xmax, ymax] from given boolean mask.
+
+ Args:
+ bool_mask (np.ndarray): 2D boolean mask.
+
+ Returns:
+ Optional[List[int]]: VOC bounding box [xmin, ymin, xmax, ymax] or None if no bounding box is found.
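+
+    Example (illustrative; assumes inclusive max pixel indices):
+        >>> mask = np.zeros((100, 100), dtype=bool)
+        >>> mask[10:20, 30:40] = True
+        >>> get_bbox_from_bool_mask(mask)  # -> [30, 10, 39, 19]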
"""
rows = np.any(bool_mask, axis=1)
cols = np.any(bool_mask, axis=0)
@@ -596,12 +692,16 @@ def ipython_display(image: np.ndarray):
IPython.display.display(i)
-def exif_transpose(image: Image.Image):
+def exif_transpose(image: Image.Image) -> Image.Image:
"""
Transpose a PIL image accordingly if it has an EXIF Orientation tag.
Inplace version of https://github.com/python-pillow/Pillow/blob/master/src/PIL/ImageOps.py exif_transpose()
- :param image: The image to transpose.
- :return: An image.
+
+ Args:
+ image (Image.Image): The image to transpose.
+
+ Returns:
+ Image.Image: The transposed image.
"""
exif = image.getexif()
orientation = exif.get(0x0112, 1) # default 1
diff --git a/sahi/utils/file.py b/sahi/utils/file.py
index bf2dd1621..be0696120 100644
--- a/sahi/utils/file.py
+++ b/sahi/utils/file.py
@@ -166,6 +166,16 @@ def get_base_filename(path: str):
def get_file_extension(path: str):
+ """
+ Get the file extension from a given file path.
+
+ Args:
+ path (str): The file path.
+
+ Returns:
+ str: The file extension.
+
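+    Example:
+        >>> get_file_extension("images/photo.jpg")
+        '.jpg'
+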
+ """
filename, file_extension = os.path.splitext(path)
return file_extension
@@ -214,20 +224,49 @@ def import_model_class(model_type, class_name):
return class_
-def increment_path(path, exist_ok=True, sep=""):
- # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
+def increment_path(path: str, exist_ok: bool = True, sep: str = "") -> str:
+ """
+ Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc.
+
+ Args:
+ path: str
+ The base path to increment.
+    exist_ok: bool
+        If True, return the path unchanged even if it already exists. If False, return an
+        incremented path when the given path exists. A non-existent path is always returned as is.
+ sep: str
+ The separator to use between the base path and the increment number.
+
+ Returns:
+ str: The incremented path.
+
+ Example:
+ >>> increment_path("runs/exp", sep="_")
+ 'runs/exp_0'
+ >>> increment_path("runs/exp_0", sep="_")
+ 'runs/exp_1'
+ """
path = Path(path) # os-agnostic
if (path.exists() and exist_ok) or (not path.exists()):
return str(path)
else:
dirs = glob.glob(f"{path}{sep}*") # similar paths
matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs]
- i = [int(m.groups()[0]) for m in matches if m] # indices
- n = max(i) + 1 if i else 2 # increment number
+ indices = [int(m.groups()[0]) for m in matches if m] # indices
+ n = max(indices) + 1 if indices else 2 # increment number
return f"{path}{sep}{n}" # update path
def download_from_url(from_url: str, to_path: str):
+ """
+ Downloads a file from the given URL and saves it to the specified path.
+
+ Args:
+ from_url (str): The URL of the file to download.
+ to_path (str): The path where the downloaded file should be saved.
+
+ Returns:
+ None
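+
+    Example (illustrative; the URL and path are hypothetical):
+        >>> download_from_url("https://example.com/model.pt", "models/model.pt")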
+ """
Path(to_path).parent.mkdir(parents=True, exist_ok=True)
if not os.path.exists(to_path):
@@ -238,7 +277,12 @@ def download_from_url(from_url: str, to_path: str):
def is_colab():
+ """
+ Check if the current environment is a Google Colab instance.
+
+ Returns:
+ bool: True if the environment is a Google Colab instance, False otherwise.
+ """
import sys
- # Is environment a Google Colab instance?
return "google.colab" in sys.modules