diff --git a/README.md b/README.md
index 21109323d..76244ccda 100644
--- a/README.md
+++ b/README.md
@@ -44,7 +44,7 @@ Object detection and instance segmentation are by far the most important applica
 ## Quick Start Examples
 
-[📜 List of publications that cite SAHI (currently 100+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)
+[📜 List of publications that cite SAHI (currently 150+)](https://scholar.google.com/scholar?hl=en&as_sdt=2005&sciodt=0,5&cites=14065474760484865747&scipsc=&q=&scisbd=1)
 
 [🏆 List of competition winners that used SAHI](https://github.com/obss/sahi/discussions/688)
 
@@ -52,10 +52,12 @@ Object detection and instance segmentation are by far the most important applica
 
 - [Introduction to SAHI](https://medium.com/codable/sahi-a-vision-library-for-performing-sliced-inference-on-large-images-small-objects-c8b086af3b80)
 
-- [Official paper](https://ieeexplore.ieee.org/document/9897990) (ICIP 2022 oral) (NEW)
+- [Official paper](https://ieeexplore.ieee.org/document/9897990) (ICIP 2022 oral)
 
 - [Pretrained weights and ICIP 2022 paper files](https://github.com/fcakyon/small-object-detection-benchmark)
 
+- [Visualizing and Evaluating SAHI predictions with FiftyOne](https://voxel51.com/blog/how-to-detect-small-objects/) (2024) (NEW)
+
 - ['Exploring SAHI' Research Article from 'learnopencv.com'](https://learnopencv.com/slicing-aided-hyper-inference/) (2023) (NEW)
 
 - ['VIDEO TUTORIAL: Slicing Aided Hyper Inference for Small Object Detection - SAHI'](https://www.youtube.com/watch?v=UuOjJKxn-M8&t=270s) (2023) (NEW)
 
@@ -82,9 +84,13 @@ Object detection and instance segmentation are by far the most important applica
 
 - `Detectron2` + `SAHI` walkthrough: sahi-detectron2
 
+- `TorchVision` + `SAHI` walkthrough: sahi-torchvision
+
 - `HuggingFace` + `SAHI` walkthrough: sahi-huggingface (NEW)
 
-- `TorchVision` + `SAHI` walkthrough: sahi-torchvision (NEW)
+- `DeepSparse` + `SAHI` walkthrough: sahi-deepsparse (NEW)
+
+- `SuperGradients/YOLONAS` + `SAHI`: sahi-yolonas (NEW)
 
 sahi-yolox
 
@@ -129,6 +135,12 @@ conda install pytorch=1.13.1 torchvision=0.14.1 pytorch-cuda=11.7 -c pytorch -c
 pip install yolov5==7.0.13
 ```
 
+- Install your desired detection framework (ultralytics):
+
+```console
+pip install ultralytics==8.0.207
+```
+
 - Install your desired detection framework (mmdet):
 
 ```console
@@ -148,6 +160,12 @@ pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu113
 pip install transformers timm
 ```
 
+- Install your desired detection framework (super-gradients):
+
+```console
+pip install super-gradients==3.3.1
+```
+
 ### Framework Agnostic Sliced/Standard Prediction
 
@@ -261,6 +279,20 @@ python -m scripts.run_code_style format
 Alzbeta Tureckova
 
+So Uchida
+
+Yonghye Kwon
+
+Neville
+
+Janne Mäyrä
+
+Christoffer Edlund
+
+Ilker Manap
+
+Nguyễn Thế An
+
 Wei Ji
 
 Aynur Susuz
 
@@ -269,5 +301,13 @@ python -m scripts.run_code_style format
 
 Lakshay Mehra
 
+Karl-Joan Alesma
+
+Jacob Marks
+
+William Lung
+
+Amogh Dhaliwal
+
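The README hunks above extend the install matrix (ultralytics and super-gradients join yolov5, mmdet, detectron2 and huggingface). Whichever backend is installed, the framework-agnostic entry point stays the same. A minimal sketch following SAHI's documented quick-start flow; the weight file and image path are placeholder values:

```python
from sahi import AutoDetectionModel
from sahi.predict import get_sliced_prediction

# "yolov5s.pt" and the demo image are placeholders; model_type should match
# whichever framework was installed above (yolov5, ultralytics, mmdet, ...)
detection_model = AutoDetectionModel.from_pretrained(
    model_type="yolov5",
    model_path="yolov5s.pt",
    confidence_threshold=0.4,
    device="cpu",
)

result = get_sliced_prediction(
    "demo_data/small-vehicles1.jpeg",
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
)
print(len(result.object_prediction_list))
```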
diff --git a/sahi/postprocess/combine.py b/sahi/postprocess/combine.py
index 3a28c9c72..3ae151f3c 100644
--- a/sahi/postprocess/combine.py
+++ b/sahi/postprocess/combine.py
@@ -217,19 +217,12 @@ def greedy_nmm(
     # according to their confidence scores
     order = scores.argsort()
 
-    # initialise an empty list for
-    # filtered prediction boxes
-    keep = []
-
     while len(order) > 0:
         # extract the index of the
         # prediction with highest score
         # we call this prediction S
         idx = order[-1]
 
-        # push S in filtered predictions list
-        keep.append(idx.tolist())
-
         # remove S from P
         order = order[:-1]
 
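The hunk above removes a write-only `keep` list: nothing read it, so dropping `keep.append(...)` changes no behavior. The surviving loop shape (pop the highest-scoring prediction S, match the rest against it) is the core of greedy non-maximum merging. The toy sketch below re-implements that grouping step for illustration only; it is not the library's implementation, and `greedy_group_by_iou` is a made-up name:

```python
import numpy as np

def greedy_group_by_iou(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float = 0.5):
    """Toy greedy grouping: repeatedly take the highest-scoring box and
    group every remaining box whose IoU with it exceeds the threshold."""
    order = scores.argsort()  # ascending, as in the hunk above
    groups = {}
    while len(order) > 0:
        idx = int(order[-1])  # highest-scoring remaining box ("S")
        order = order[:-1]    # remove S from the pool
        rest = boxes[order]
        # intersection area between S and every remaining box
        xx1 = np.maximum(boxes[idx, 0], rest[:, 0])
        yy1 = np.maximum(boxes[idx, 1], rest[:, 1])
        xx2 = np.minimum(boxes[idx, 2], rest[:, 2])
        yy2 = np.minimum(boxes[idx, 3], rest[:, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        area_s = (boxes[idx, 2] - boxes[idx, 0]) * (boxes[idx, 3] - boxes[idx, 1])
        area_rest = (rest[:, 2] - rest[:, 0]) * (rest[:, 3] - rest[:, 1])
        iou = inter / (area_s + area_rest - inter)
        groups[idx] = [int(i) for i in order[iou > iou_threshold]]
        order = order[iou <= iou_threshold]  # keep only unmatched boxes
    return groups

boxes = np.array([[0, 0, 10, 10], [1, 1, 11, 11], [50, 50, 60, 60]], dtype=float)
scores = np.array([0.9, 0.8, 0.7])
print(greedy_group_by_iou(boxes, scores))  # {0: [1], 2: []}
```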
diff --git a/sahi/predict.py b/sahi/predict.py
index 699cb6f67..e7ddb84ff 100644
--- a/sahi/predict.py
+++ b/sahi/predict.py
@@ -164,7 +164,7 @@ def get_sliced_prediction(
             detection accuracy. Default: True.
         postprocess_type: str
             Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
-            Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
         postprocess_match_metric: str
             Metric to be used during object prediction matching after sliced prediction.
             'IOU' for intersection over union, 'IOS' for intersection over smaller area.
@@ -231,7 +231,7 @@ def get_sliced_prediction(
     # create prediction input
     num_group = int(num_slices / num_batch)
     if verbose == 1 or verbose == 2:
-        tqdm.write(f"Performing prediction on {num_slices} number of slices.")
+        tqdm.write(f"Performing prediction on {num_slices} slices.")
     object_prediction_list = []
     # perform sliced prediction
     for group_ind in range(num_group):
@@ -416,7 +416,7 @@ def predict(
             Default to ``0.2``.
         postprocess_type: str
             Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
-            Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GRREDYNMM'.
+            Options are 'NMM', 'GREEDYNMM', 'LSNMS' or 'NMS'. Default is 'GREEDYNMM'.
         postprocess_match_metric: str
             Metric to be used during object prediction matching after sliced prediction.
             'IOU' for intersection over union, 'IOS' for intersection over smaller area.
@@ -781,7 +781,7 @@ def predict_fiftyone(
             Default to ``0.2``.
         postprocess_type: str
             Type of the postprocess to be used after sliced inference while merging/eliminating predictions.
-            Options are 'NMM', 'GRREDYNMM' or 'NMS'. Default is 'GRREDYNMM'.
+            Options are 'NMM', 'GREEDYNMM' or 'NMS'. Default is 'GREEDYNMM'.
         postprocess_match_metric: str
             Metric to be used during object prediction matching after sliced prediction.
             'IOU' for intersection over union, 'IOS' for intersection over smaller area.
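With the 'GRREDYNMM' typos corrected, 'GREEDYNMM' is the spelling callers should pass. A hedged sketch of how these postprocess knobs are forwarded, reusing the hypothetical `detection_model` from the earlier sketch:

```python
result = get_sliced_prediction(
    "demo_data/small-vehicles1.jpeg",  # placeholder image path
    detection_model,
    slice_height=512,
    slice_width=512,
    overlap_height_ratio=0.2,
    overlap_width_ratio=0.2,
    postprocess_type="GREEDYNMM",      # the spelling the docstrings now agree on
    postprocess_match_metric="IOS",    # or "IOU", per the docstrings above
    postprocess_match_threshold=0.5,
    verbose=1,  # emits the reworded "Performing prediction on N slices." message
)
```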
diff --git a/sahi/utils/coco.py b/sahi/utils/coco.py
index 1777555aa..e84f2f3d3 100644
--- a/sahi/utils/coco.py
+++ b/sahi/utils/coco.py
@@ -130,7 +130,7 @@ def from_coco_annotation_dict(cls, annotation_dict: Dict, category_name: Optiona
             annotation_dict: dict
                 COCO formatted annotation dict (with fields "bbox", "segmentation", "category_id")
         """
-        if annotation_dict.__contains__("segmentation") and not isinstance(annotation_dict["segmentation"], list):
+        if annotation_dict.__contains__("segmentation") and isinstance(annotation_dict["segmentation"], dict):
            has_rle_segmentation = True
            logger.warning(
                f"Segmentation annotation for id {annotation_dict['id']} is skipped since RLE segmentation format is not supported."
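The rewritten condition makes the intent explicit: in COCO JSON, polygon segmentations are lists of coordinate lists while RLE segmentations are dicts carrying "size" and "counts", so `isinstance(..., dict)` targets exactly the unsupported case that the old `not isinstance(..., list)` only caught indirectly. A small self-contained illustration, with made-up annotation values:

```python
# polygon segmentation (supported): a list of point lists
polygon_ann = {
    "id": 1,
    "bbox": [10, 10, 40, 30],
    "category_id": 0,
    "segmentation": [[10, 10, 50, 10, 50, 40, 10, 40]],
}

# RLE segmentation (skipped with a warning): a dict with "size" and "counts"
rle_ann = {
    "id": 2,
    "bbox": [10, 10, 40, 30],
    "category_id": 0,
    "segmentation": {"size": [480, 640], "counts": "dummy-rle-string"},
}

def is_rle(annotation_dict: dict) -> bool:
    # mirrors the updated check above
    return "segmentation" in annotation_dict and isinstance(annotation_dict["segmentation"], dict)

assert not is_rle(polygon_ann)
assert is_rle(rle_ann)
```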
diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py
index 06434a086..269dc254c 100644
--- a/sahi/utils/cv.py
+++ b/sahi/utils/cv.py
@@ -21,7 +21,6 @@
 
 
 class Colors:
-    # color palette
     def __init__(self):
         hex = (
             "FF3838",
@@ -45,16 +44,38 @@ def __init__(self):
             "FF95C8",
             "FF37C7",
         )
-        self.palette = [self.hex2rgb("#" + c) for c in hex]
+        self.palette = [self.hex_to_rgb("#" + c) for c in hex]
         self.n = len(self.palette)
 
-    def __call__(self, i, bgr=False):
-        c = self.palette[int(i) % self.n]
-        return (c[2], c[1], c[0]) if bgr else c
+    def __call__(self, ind, bgr: bool = False):
+        """
+        Convert an index to a color code.
+
+        Args:
+            ind (int): The index to convert.
+            bgr (bool, optional): Whether to return the color code in BGR format. Defaults to False.
+
+        Returns:
+            tuple: The color code in RGB or BGR format, depending on the value of `bgr`.
+        """
+        color_codes = self.palette[int(ind) % self.n]
+        return (color_codes[2], color_codes[1], color_codes[0]) if bgr else color_codes
 
     @staticmethod
-    def hex2rgb(h):  # rgb order
-        return tuple(int(h[1 + i : 1 + i + 2], 16) for i in (0, 2, 4))
+    def hex_to_rgb(hex_code):
+        """
+        Converts a hexadecimal color code to RGB format.
+
+        Args:
+            hex_code (str): The hexadecimal color code to convert.
+
+        Returns:
+            tuple: A tuple representing the RGB values in the order (R, G, B).
+        """
+        rgb = []
+        for i in (0, 2, 4):
+            rgb.append(int(hex_code[1 + i : 1 + i + 2], 16))
+        return tuple(rgb)
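The renamed helpers behave exactly as before; working through the first visible palette entry by hand:

```python
colors = Colors()

print(colors.hex_to_rgb("#FF3838"))  # (255, 56, 56)
print(colors(0))                     # first palette entry, RGB order
print(colors(0, bgr=True))           # (56, 56, 255), handy for OpenCV drawing
print(colors(10_000))                # any index is valid: it wraps modulo self.n
```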
""" elapsed_time = time.time() # deepcopy image so that original is not altered @@ -354,21 +431,21 @@ def visualize_prediction( image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0) # add bboxes to image if present - for i in range(len(boxes)): + for box_indice in range(len(boxes)): # deepcopy boxso that original is not altered - box = copy.deepcopy(boxes[i]) - class_ = classes[i] + box = copy.deepcopy(boxes[box_indice]) + class_ = classes[box_indice] # set color if colors is not None: color = colors(class_) # set bbox points - p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3])) + point1, point2 = [int(box[0]), int(box[1])], [int(box[2]), int(box[3])] # visualize boxes cv2.rectangle( image, - p1, - p2, + point1, + point2, color=color, thickness=rect_th, ) @@ -376,15 +453,17 @@ def visualize_prediction( if not hide_labels: # arange bounding box text location label = f"{class_}" - w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0] # label width, height - outside = p1[1] - h - 3 >= 0 # label fits outside box - p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3 + box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[ + 0 + ] # label width, height + outside = point1[1] - box_height - 3 >= 0 # label fits outside box + point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3 # add bounding box text - cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA) # filled + cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA) # filled cv2.putText( image, label, - (p1[0], p1[1] - 2 if outside else p1[1] + h + 2), + (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2), 0, text_size, (255, 255, 255), @@ -417,7 +496,8 @@ def visualize_object_predictions( """ Visualizes prediction category names, bounding boxes over the source image and exports it to output folder. 
@@ -328,6 +389,22 @@ def visualize_prediction(
     """
     Visualizes prediction classes, bounding boxes over the source image
     and exports it to output folder.
+
+    Args:
+        image (np.ndarray): The source image.
+        boxes (List[List]): List of bounding boxes coordinates.
+        classes (List[str]): List of class labels corresponding to each bounding box.
+        masks (Optional[List[np.ndarray]], optional): List of masks corresponding to each bounding box. Defaults to None.
+        rect_th (float, optional): Thickness of the bounding box rectangle. Defaults to None.
+        text_size (float, optional): Size of the text for class labels. Defaults to None.
+        text_th (float, optional): Thickness of the text for class labels. Defaults to None.
+        color (tuple, optional): Color of the bounding box and text. Defaults to None.
+        hide_labels (bool, optional): Whether to hide the class labels. Defaults to False.
+        output_dir (Optional[str], optional): Output directory to save the visualization. Defaults to None.
+        file_name (Optional[str], optional): File name for the saved visualization. Defaults to "prediction_visual".
+
+    Returns:
+        dict: A dictionary containing the visualized image and the elapsed time for the visualization process.
     """
     elapsed_time = time.time()
     # deepcopy image so that original is not altered
@@ -354,21 +431,21 @@ def visualize_prediction(
         image = cv2.addWeighted(image, 1, rgb_mask, 0.6, 0)
 
     # add bboxes to image if present
-    for i in range(len(boxes)):
+    for box_index in range(len(boxes)):
         # deepcopy boxso that original is not altered
-        box = copy.deepcopy(boxes[i])
-        class_ = classes[i]
+        box = copy.deepcopy(boxes[box_index])
+        class_ = classes[box_index]
 
         # set color
         if colors is not None:
             color = colors(class_)
         # set bbox points
-        p1, p2 = (int(box[0]), int(box[1])), (int(box[2]), int(box[3]))
+        point1, point2 = [int(box[0]), int(box[1])], [int(box[2]), int(box[3])]
         # visualize boxes
         cv2.rectangle(
             image,
-            p1,
-            p2,
+            point1,
+            point2,
             color=color,
             thickness=rect_th,
         )
@@ -376,15 +453,17 @@ def visualize_prediction(
         if not hide_labels:
             # arange bounding box text location
             label = f"{class_}"
-            w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]  # label width, height
-            outside = p1[1] - h - 3 >= 0  # label fits outside box
-            p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+            box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
+                0
+            ]  # label width, height
+            outside = point1[1] - box_height - 3 >= 0  # label fits outside box
+            point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
             # add bounding box text
-            cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
+            cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA)  # filled
             cv2.putText(
                 image,
                 label,
-                (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+                (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
                 0,
                 text_size,
                 (255, 255, 255),
@@ -417,7 +496,8 @@ def visualize_object_predictions(
     """
     Visualizes prediction category names, bounding boxes over the source image
     and exports it to output folder.
-    Arguments:
+
+    Args:
         object_prediction_list: a list of prediction.ObjectPrediction
         rect_th: rectangle thickness
         text_size: size of the category name over box
@@ -472,12 +552,12 @@ def visualize_object_predictions(
         if colors is not None:
             color = colors(object_prediction.category.id)
         # set bbox points
-        p1, p2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
+        point1, point2 = (int(bbox[0]), int(bbox[1])), (int(bbox[2]), int(bbox[3]))
         # visualize boxes
         cv2.rectangle(
             image,
-            p1,
-            p2,
+            point1,
+            point2,
             color=color,
             thickness=rect_th,
         )
@@ -489,15 +569,17 @@ def visualize_object_predictions(
         if not hide_conf:
             label += f" {score:.2f}"
 
-        w, h = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]  # label width, height
-        outside = p1[1] - h - 3 >= 0  # label fits outside box
-        p2 = p1[0] + w, p1[1] - h - 3 if outside else p1[1] + h + 3
+        box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[
+            0
+        ]  # label width, height
+        outside = point1[1] - box_height - 3 >= 0  # label fits outside box
+        point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
         # add bounding box text
-        cv2.rectangle(image, p1, p2, color, -1, cv2.LINE_AA)  # filled
+        cv2.rectangle(image, point1, point2, color, -1, cv2.LINE_AA)  # filled
         cv2.putText(
             image,
             label,
-            (p1[0], p1[1] - 2 if outside else p1[1] + h + 2),
+            (point1[0], point1[1] - 2 if outside else point1[1] + box_height + 2),
             0,
             text_size,
             (255, 255, 255),
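Beyond the `p1`/`w`-style renames, the label-placement logic in both visualizers is unchanged: the filled label background goes above the detection box when it fits, otherwise below. Isolated, the computation looks like this (values are arbitrary):

```python
import cv2

label = "car 0.87"
text_size, text_th = 0.6, 1
box_width, box_height = cv2.getTextSize(label, 0, fontScale=text_size, thickness=text_th)[0]

point1 = (40, 8)  # detection corner too close to the top edge
outside = point1[1] - box_height - 3 >= 0  # False: no room above
point2 = point1[0] + box_width, point1[1] - box_height - 3 if outside else point1[1] + box_height + 3
print(outside, point2)  # the label background is drawn below the box top instead
```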
""" exif = image.getexif() orientation = exif.get(0x0112, 1) # default 1 diff --git a/sahi/utils/file.py b/sahi/utils/file.py index bf2dd1621..be0696120 100644 --- a/sahi/utils/file.py +++ b/sahi/utils/file.py @@ -166,6 +166,16 @@ def get_base_filename(path: str): def get_file_extension(path: str): + """ + Get the file extension from a given file path. + + Args: + path (str): The file path. + + Returns: + str: The file extension. + + """ filename, file_extension = os.path.splitext(path) return file_extension @@ -214,20 +224,49 @@ def import_model_class(model_type, class_name): return class_ -def increment_path(path, exist_ok=True, sep=""): - # Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc. +def increment_path(path: str, exist_ok: bool = True, sep: str = "") -> str: + """ + Increment path, i.e. runs/exp --> runs/exp{sep}0, runs/exp{sep}1 etc. + + Args: + path: str + The base path to increment. + exist_ok: bool + If True, return the path as is if it already exists. If False, increment the path. + sep: str + The separator to use between the base path and the increment number. + + Returns: + str: The incremented path. + + Example: + >>> increment_path("runs/exp", sep="_") + 'runs/exp_0' + >>> increment_path("runs/exp_0", sep="_") + 'runs/exp_1' + """ path = Path(path) # os-agnostic if (path.exists() and exist_ok) or (not path.exists()): return str(path) else: dirs = glob.glob(f"{path}{sep}*") # similar paths matches = [re.search(rf"%s{sep}(\d+)" % path.stem, d) for d in dirs] - i = [int(m.groups()[0]) for m in matches if m] # indices - n = max(i) + 1 if i else 2 # increment number + indices = [int(m.groups()[0]) for m in matches if m] # indices + n = max(indices) + 1 if indices else 2 # increment number return f"{path}{sep}{n}" # update path def download_from_url(from_url: str, to_path: str): + """ + Downloads a file from the given URL and saves it to the specified path. + + Args: + from_url (str): The URL of the file to download. + to_path (str): The path where the downloaded file should be saved. + + Returns: + None + """ Path(to_path).parent.mkdir(parents=True, exist_ok=True) if not os.path.exists(to_path): @@ -238,7 +277,12 @@ def download_from_url(from_url: str, to_path: str): def is_colab(): + """ + Check if the current environment is a Google Colab instance. + + Returns: + bool: True if the environment is a Google Colab instance, False otherwise. + """ import sys - # Is environment a Google Colab instance? return "google.colab" in sys.modules