diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 07f93e566..80a10d494 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -106,11 +106,11 @@ jobs: run: > pip install pycocotools==2.0.7 - - name: Install ultralytics(8.0.207) + - name: Install ultralytics(8.3.50) run: > - pip install ultralytics==8.0.207 + pip install ultralytics==8.3.50 - - name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms + - name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms run: | python -m unittest diff --git a/.github/workflows/package_testing.yml b/.github/workflows/package_testing.yml index 8ae96db5d..ac5433e4f 100644 --- a/.github/workflows/package_testing.yml +++ b/.github/workflows/package_testing.yml @@ -84,15 +84,15 @@ jobs: run: > pip install pycocotools==2.0.7 - - name: Install ultralytics(8.0.207) + - name: Install ultralytics(8.3.50) run: > - pip install ultralytics==8.0.207 + pip install ultralytics==8.3.50 - name: Install latest SAHI package run: > pip install --upgrade --force-reinstall sahi - - name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms + - name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms run: | python -m unittest diff --git a/README.md b/README.md index 6c40e3cef..131bc5ad5 100644 --- a/README.md +++ b/README.md @@ -76,11 +76,13 @@ Object detection and instance segmentation are by far the most important applica - [Slicing operation notebook](demo/slicing.ipynb) -- `YOLOX` + `SAHI` demo: sahi-yolox (RECOMMENDED) +- `YOLOX` + `SAHI` demo: sahi-yolox + +- `YOLO11` + `SAHI` walkthrough: sahi-yolov8 (NEW) - `RT-DETR` + `SAHI` walkthrough: sahi-rtdetr (NEW) -- `YOLOv8` + `SAHI` walkthrough: sahi-yolov8 +- `YOLOv8` + `SAHI` walkthrough: sahi-yolov8 - `DeepSparse` + `SAHI` walkthrough: sahi-deepsparse @@ -141,7 +143,7 @@ pip install yolov5==7.0.13 - Install your desired detection framework (ultralytics): ```console -pip install ultralytics==8.0.207 +pip install ultralytics==8.3.50 ``` - Install your desired detection framework (mmdet): @@ -228,9 +230,9 @@ If you use this package in your work, please cite it as: ##
Contributing
-`sahi` library currently supports all [YOLOv5 models](https://github.com/ultralytics/yolov5/releases), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks. +`sahi` library currently supports all [Ultralytics (YOLOv8/v10/v11/RTDETR) models](https://github.com/ultralytics/ultralytics), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks. -All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/mmdet.py#L18) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference. +All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/mmdet.py#L91) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference. Before opening a PR: diff --git a/demo/inference_for_yolov8.ipynb b/demo/inference_for_ultralytics.ipynb similarity index 99% rename from demo/inference_for_yolov8.ipynb rename to demo/inference_for_ultralytics.ipynb index fdcba4940..1d3976cc1 100644 --- a/demo/inference_for_yolov8.ipynb +++ b/demo/inference_for_ultralytics.ipynb @@ -55,8 +55,9 @@ "outputs": [], "source": [ "# arrange an instance segmentation model for test\n", - "from sahi.utils.yolov8 import (\n", - " download_yolov8s_model, download_yolov8s_seg_model\n", + "from sahi.utils.ultralytics import (\n", + " download_yolo11n_model, download_yolo11n_seg_model,\n", + " # download_yolov8n_model, download_yolov8n_seg_model\n", ")\n", "\n", "from sahi import AutoDetectionModel\n", @@ -80,9 +81,10 @@ "metadata": {}, "outputs": [], "source": [ - "# download YOLOV5S6 model to 'models/yolov5s6.pt'\n", - "yolov8_model_path = \"models/yolov8s.pt\"\n", - "download_yolov8s_model(yolov8_model_path)\n", + "yolo11n_model_path = \"models/yolov11n.pt\"\n", + "download_yolo11n_model(yolo11n_model_path)\n", + "# yolov8n_model_path = \"models/yolov8n.pt\"\n", + "# download_yolov8n_model(yolov8n_model_path)\n", "\n", "# download test images into demo_data folder\n", "download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg', 'demo_data/small-vehicles1.jpeg')\n", @@ -94,7 +96,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 1. Standard Inference with a YOLOv8 Model" + "## 1. Standard Inference with a YOLOv8/YOLO11 Model" ] }, { @@ -111,8 +113,8 @@ "outputs": [], "source": [ "detection_model = AutoDetectionModel.from_pretrained(\n", - " model_type='yolov8',\n", - " model_path=yolov8_model_path,\n", + " model_type='yolo11', # or 'yolov8'\n", + " model_path=yolo11n_model_path,\n", " confidence_threshold=0.3,\n", " device=\"cpu\", # or 'cuda:0'\n", ")" @@ -185,7 +187,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## 2. Sliced Inference with a YOLOv8 Model" + "## 2. Sliced Inference with a YOLOv8/YOLO11 Model" ] }, { @@ -466,8 +468,8 @@ "metadata": {}, "outputs": [], "source": [ - "model_type = \"yolov8\"\n", - "model_path = yolov8_model_path\n", + "model_type = \"yolo11\"\n", + "model_path = yolo11n_model_path\n", "model_device = \"cpu\" # or 'cuda:0'\n", "model_confidence_threshold = 0.4\n", "\n", @@ -599,9 +601,10 @@ "metadata": {}, "outputs": [], "source": [ - "#download YOLOV8S model to 'models/yolov8s.pt'\n", - "yolov8_seg_model_path = \"models/yolov8s-seg.pt\"\n", - "download_yolov8s_seg_model(yolov8_seg_model_path)" + "yolo11n_seg_model_path = \"models/yolov11n-seg.pt\"\n", + "download_yolo11n_seg_model(yolo11n_seg_model_path)\n", + "# yolov8n_seg_model_path = \"models/yolov8n-seg.pt\"\n", + "# download_yolov8n_seg_model(yolov8n_seg_model_path)\n" ] }, { @@ -611,8 +614,8 @@ "outputs": [], "source": [ "detection_model_seg = AutoDetectionModel.from_pretrained(\n", - " model_type='yolov8',\n", - " model_path=yolov8_seg_model_path,\n", + " model_type='yolo11', # or 'yolov8'\n", + " model_path=yolo11n_seg_model_path,\n", " confidence_threshold=0.3,\n", " device=\"cpu\", # or 'cuda:0'\n", ")" diff --git a/sahi/__init__.py b/sahi/__init__.py index 301ca944e..bed884633 100644 --- a/sahi/__init__.py +++ b/sahi/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.11.19" +__version__ = "0.11.20" from sahi.annotation import BoundingBox, Category, Mask from sahi.auto_model import AutoDetectionModel diff --git a/sahi/annotation.py b/sahi/annotation.py index 78318699e..5d7bc639c 100644 --- a/sahi/annotation.py +++ b/sahi/annotation.py @@ -130,8 +130,8 @@ class Mask: @classmethod def from_float_mask( cls, - mask, - full_shape=None, + mask: np.ndarray, + full_shape: List[int], mask_threshold: float = 0.5, shift_amount: list = [0, 0], ): @@ -144,7 +144,7 @@ def from_float_mask( shift_amount: List To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] - full_shape: List + full_shape: List[int] Size of the full image after shifting, should be in the form of [height, width] """ bool_mask = mask > mask_threshold @@ -156,8 +156,8 @@ def from_float_mask( def __init__( self, - segmentation, - full_shape=None, + segmentation: List[List[float]], + full_shape: List[int], shift_amount: list = [0, 0], ): """ @@ -170,9 +170,9 @@ def __init__( [x1, y1, x2, y2, x3, y3, ...], ... ] - full_shape: List + full_shape: List[int] Size of the full image, should be in the form of [height, width] - shift_amount: List + shift_amount: List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] """ @@ -195,17 +195,17 @@ def __init__( @classmethod def from_bool_mask( cls, - bool_mask=None, - full_shape=None, + bool_mask: np.ndarray, + full_shape: List[int], shift_amount: list = [0, 0], ): """ Args: bool_mask: np.ndarray with bool elements 2D mask of object, should have a shape of height*width - full_shape: List + full_shape: List[int] Size of the full image, should be in the form of [height, width] - shift_amount: List + shift_amount: List[int] To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] """ @@ -420,7 +420,7 @@ def from_coco_annotation_dict( @classmethod def from_shapely_annotation( cls, - annotation, + annotation: ShapelyAnnotation, full_shape: List[int], category_id: Optional[int] = None, category_name: Optional[str] = None, @@ -441,12 +441,9 @@ def from_shapely_annotation( To shift the box and mask predictions from sliced image to full sized image, should be in the form of [shift_x, shift_y] """ - bool_mask = get_bool_mask_from_coco_segmentation( - annotation.to_coco_segmentation(), width=full_shape[1], height=full_shape[0] - ) return cls( category_id=category_id, - bool_mask=bool_mask, + segmentation=annotation.to_coco_segmentation(), category_name=category_name, shift_amount=shift_amount, full_shape=full_shape, @@ -512,6 +509,8 @@ def __init__( raise ValueError("category_id must be an integer") if (bbox is None) and (segmentation is None): raise ValueError("you must provide a bbox or segmentation") + + self.mask: Mask | None = None if segmentation is not None: self.mask = Mask( segmentation=segmentation, @@ -524,8 +523,6 @@ def __init__( bbox = bbox_from_segmentation else: raise ValueError("Invalid segmentation mask.") - else: - self.mask = None # if bbox is a numpy object, convert it to python List[float] if type(bbox).__module__ == "numpy": @@ -552,13 +549,13 @@ def __init__( self.merged = None - def to_coco_annotation(self): + def to_coco_annotation(self) -> CocoAnnotation: """ Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation. """ if self.mask: coco_annotation = CocoAnnotation.from_coco_segmentation( - segmentation=self.mask.segmentation(), + segmentation=self.mask.segmentation, category_id=self.category.id, category_name=self.category.name, ) @@ -570,13 +567,13 @@ def to_coco_annotation(self): ) return coco_annotation - def to_coco_prediction(self): + def to_coco_prediction(self) -> CocoPrediction: """ Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation. """ if self.mask: coco_prediction = CocoPrediction.from_coco_segmentation( - segmentation=self.mask.segmentation(), + segmentation=self.mask.segmentation, category_id=self.category.id, category_name=self.category.name, score=1, @@ -590,13 +587,13 @@ def to_coco_prediction(self): ) return coco_prediction - def to_shapely_annotation(self): + def to_shapely_annotation(self) -> ShapelyAnnotation: """ Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation. """ if self.mask: shapely_annotation = ShapelyAnnotation.from_coco_segmentation( - segmentation=self.mask.segmentation(), + segmentation=self.mask.segmentation, ) else: shapely_annotation = ShapelyAnnotation.from_coco_bbox( diff --git a/sahi/auto_model.py b/sahi/auto_model.py index a60f2d9d9..e66ab680b 100644 --- a/sahi/auto_model.py +++ b/sahi/auto_model.py @@ -3,7 +3,7 @@ from sahi.utils.file import import_model_class MODEL_TYPE_TO_MODEL_CLASS_NAME = { - "yolov8": "Yolov8DetectionModel", + "ultralytics": "UltralyticsDetectionModel", "rtdetr": "RTDetrDetectionModel", "mmdet": "MmdetDetectionModel", "yolov5": "Yolov5DetectionModel", @@ -14,6 +14,8 @@ "yolov8onnx": "Yolov8OnnxDetectionModel", } +ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"] + class AutoDetectionModel: @staticmethod @@ -60,7 +62,8 @@ def from_pretrained( Raises: ImportError: If given {model_type} framework is not installed """ - + if model_type in ULTRALYTICS_MODEL_NAMES: + model_type = "ultralytics" model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type] DetectionModel = import_model_class(model_type, model_class_name) diff --git a/sahi/models/__init__.py b/sahi/models/__init__.py index 4f2a5710b..c85fa5da8 100644 --- a/sahi/models/__init__.py +++ b/sahi/models/__init__.py @@ -1 +1 @@ -from . import base, detectron2, huggingface, mmdet, torchvision, yolov5, yolov8onnx +from . import base, detectron2, huggingface, mmdet, torchvision, ultralytics, yolov5, yolov8onnx diff --git a/sahi/models/base.py b/sahi/models/base.py index d1cd227d0..a434728dd 100644 --- a/sahi/models/base.py +++ b/sahi/models/base.py @@ -1,10 +1,11 @@ # OBSS SAHI Tool # Code written by Fatih C Akyon, 2020. -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional import numpy as np +from sahi.prediction import ObjectPrediction from sahi.utils.import_utils import is_available from sahi.utils.torch import select_device as select_torch_device @@ -173,7 +174,7 @@ def convert_original_predictions( self._apply_category_remapping() @property - def object_prediction_list(self): + def object_prediction_list(self) -> List[ObjectPrediction]: return self._object_prediction_list_per_image[0] @property diff --git a/sahi/models/rtdetr.py b/sahi/models/rtdetr.py index d704eee59..91bc3250a 100644 --- a/sahi/models/rtdetr.py +++ b/sahi/models/rtdetr.py @@ -3,13 +3,13 @@ import logging -from sahi.models.yolov8 import Yolov8DetectionModel +from sahi.models.ultralytics import UltralyticsDetectionModel from sahi.utils.import_utils import check_requirements logger = logging.getLogger(__name__) -class RTDetrDetectionModel(Yolov8DetectionModel): +class RTDetrDetectionModel(UltralyticsDetectionModel): def check_dependencies(self) -> None: check_requirements(["ultralytics"]) diff --git a/sahi/models/yolov8.py b/sahi/models/ultralytics.py similarity index 71% rename from sahi/models/yolov8.py rename to sahi/models/ultralytics.py index b8311329d..dd421f5cc 100644 --- a/sahi/models/yolov8.py +++ b/sahi/models/ultralytics.py @@ -2,7 +2,7 @@ # Code written by AnNT, 2023. import logging -from typing import Any, Dict, List, Optional +from typing import Any, List, Optional import cv2 import numpy as np @@ -13,11 +13,11 @@ from sahi.models.base import DetectionModel from sahi.prediction import ObjectPrediction from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list -from sahi.utils.cv import get_coco_segmentation_from_bool_mask +from sahi.utils.cv import get_coco_segmentation_from_bool_mask, get_coco_segmentation_from_obb_points from sahi.utils.import_utils import check_requirements -class Yolov8DetectionModel(DetectionModel): +class UltralyticsDetectionModel(DetectionModel): def check_dependencies(self) -> None: check_requirements(["ultralytics"]) @@ -33,14 +33,14 @@ def load_model(self): model.to(self.device) self.set_model(model) except Exception as e: - raise TypeError("model_path is not a valid yolov8 model path: ", e) + raise TypeError("model_path is not a valid Ultralytics model path: ", e) def set_model(self, model: Any): """ - Sets the underlying YOLOv8 model. + Sets the underlying Ultralytics model. Args: model: Any - A YOLOv8 model + A Ultralytics model """ self.model = model @@ -52,11 +52,9 @@ def set_model(self, model: Any): def perform_inference(self, image: np.ndarray): """ Prediction is performed using self.model and the prediction result is set to self._original_predictions. - If predictions have masks, each prediction is a tuple like (boxes, masks). Args: image: np.ndarray A numpy array that contains the image to be predicted. 3 channel image should be in RGB order. - """ from ultralytics.engine.results import Masks @@ -86,8 +84,27 @@ def perform_inference(self, image: np.ndarray): ) for result in prediction_result ] - - else: # If model doesn't do segmentation then no need to check masks + elif self.is_obb: + # For OBB task, get OBB points in xyxyxyxy format + prediction_result = [ + ( + # Get OBB data: xyxy, conf, cls + torch.cat( + [ + result.obb.xyxy, # box coordinates + result.obb.conf.unsqueeze(-1), # confidence scores + result.obb.cls.unsqueeze(-1), # class ids + ], + dim=1, + ) + if result.obb is not None + else torch.empty((0, 6), device=self.model.device), + # Get OBB points in (N, 4, 2) format + result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device), + ) + for result in prediction_result + ] + else: # If model doesn't do segmentation or OBB then no need to check masks # We do not filter results again as confidence threshold is already applied above prediction_result = [result.boxes.data for result in prediction_result] @@ -112,6 +129,13 @@ def has_mask(self): """ return self.model.overrides["task"] == "segment" + @property + def is_obb(self): + """ + Returns if model output contains oriented bounding boxes + """ + return self.model.overrides["task"] == "obb" + def _create_object_prediction_list_from_original_predictions( self, shift_amount_list: Optional[List[List[int]]] = [[0, 0]], @@ -142,13 +166,13 @@ def _create_object_prediction_list_from_original_predictions( full_shape = None if full_shape_list is None else full_shape_list[image_ind] object_prediction_list = [] - # Extract boxes and optional masks - if self.has_mask: + # Extract boxes and optional masks/obb + if self.has_mask or self.is_obb: boxes = image_predictions[0].cpu().detach().numpy() - masks = image_predictions[1].cpu().detach().numpy() + masks_or_points = image_predictions[1].cpu().detach().numpy() else: boxes = image_predictions.data.cpu().detach().numpy() - masks = None + masks_or_points = None # Process each prediction for pred_ind, prediction in enumerate(boxes): @@ -171,14 +195,20 @@ def _create_object_prediction_list_from_original_predictions( logger.warning(f"ignoring invalid prediction with bbox: {bbox}") continue - # Get segmentation if available + # Get segmentation or OBB points segmentation = None - if masks is not None: - bool_mask = masks[pred_ind] - orig_width = self._original_shape[1] - orig_height = self._original_shape[0] - bool_mask = cv2.resize(bool_mask.astype(np.uint8), (orig_width, orig_height)) - segmentation = get_coco_segmentation_from_bool_mask(bool_mask) + if masks_or_points is not None: + if self.has_mask: + bool_mask = masks_or_points[pred_ind] + # Resize mask to original image size + bool_mask = cv2.resize( + bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0]) + ) + segmentation = get_coco_segmentation_from_bool_mask(bool_mask) + else: # is_obb + obb_points = masks_or_points[pred_ind] # Get OBB points for this prediction + segmentation = get_coco_segmentation_from_obb_points(obb_points) + if len(segmentation) == 0: continue @@ -190,7 +220,7 @@ def _create_object_prediction_list_from_original_predictions( segmentation=segmentation, category_name=category_name, shift_amount=shift_amount, - full_shape=full_shape, + full_shape=self._original_shape[:2] if full_shape is None else full_shape, # (height, width) ) object_prediction_list.append(object_prediction) diff --git a/sahi/prediction.py b/sahi/prediction.py index e8ca14260..5fca9648a 100644 --- a/sahi/prediction.py +++ b/sahi/prediction.py @@ -8,7 +8,7 @@ from PIL import Image from sahi.annotation import ObjectAnnotation -from sahi.utils.coco import CocoAnnotation, CocoPrediction +from sahi.utils.coco import CocoPrediction from sahi.utils.cv import read_image_as_pil, visualize_object_predictions from sahi.utils.file import Path diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py index c2006eb2b..693097aa9 100644 --- a/sahi/utils/cv.py +++ b/sahi/utils/cv.py @@ -167,7 +167,7 @@ def read_large_image(image_path: str): return image0, use_cv2 -def read_image(image_path: str): +def read_image(image_path: str) -> np.ndarray: """ Loads image as a numpy array from the given path. @@ -184,7 +184,7 @@ def read_image(image_path: str): return image -def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False): +def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False) -> Image.Image: """ Loads an image as PIL.Image.Image. @@ -688,6 +688,30 @@ def get_bbox_from_coco_segmentation(coco_segmentation): return [xmin, ymin, xmax, ymax] +def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[float]]: + """ + Convert OBB (Oriented Bounding Box) points to COCO polygon format. + + Args: + obb_points: np.ndarray + OBB points tensor from ultralytics.engine.results.OBB + Shape: (4, 2) containing 4 points with (x,y) coordinates each + Returns: + List[List[float]]: Polygon points in COCO format + [[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1], [...], ...] + """ + # Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format + points = obb_points.reshape(-1).tolist() + + # Create polygon from points and close it by repeating first point + polygons = [] + # Add first point to end to close polygon + closed_polygon = points + [points[0], points[1]] + polygons.append(closed_polygon) + + return polygons + + def normalize_numpy_image(image: np.ndarray): """ Normalizes numpy image diff --git a/sahi/utils/ultralytics.py b/sahi/utils/ultralytics.py new file mode 100644 index 000000000..272668ca6 --- /dev/null +++ b/sahi/utils/ultralytics.py @@ -0,0 +1,121 @@ +import os +from pathlib import Path +from typing import Optional + +import requests +from tqdm import tqdm + +YOLOV8N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt" +YOLOV8N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n-seg.pt" +YOLO11N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt" +YOLO11N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-seg.pt" +YOLO11N_OBB_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-obb.pt" + + +class UltralyticsTestConstants: + YOLOV8N_MODEL_PATH = "tests/data/models/yolov8n.pt" + YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8n-seg.pt" + YOLO11N_MODEL_PATH = "tests/data/models/yolo11n.pt" + YOLO11N_SEG_MODEL_PATH = "tests/data/models/yolo11n-seg.pt" + YOLO11N_OBB_MODEL_PATH = "tests/data/models/yolo11n-obb.pt" + + +def download_file(url: str, save_path: str, chunk_size: int = 8192) -> None: + """ + Downloads a file from a given URL to the specified path. + + Args: + url: URL to download the file from + save_path: Path where the file will be saved + chunk_size: Size of chunks for downloading + """ + response = requests.get(url, stream=True) + total_size = int(response.headers.get("content-length", 0)) + + # Ensure directory exists + os.makedirs(os.path.dirname(save_path), exist_ok=True) + + with open(save_path, "wb") as f, tqdm( + desc=os.path.basename(save_path), + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as pbar: + for data in response.iter_content(chunk_size=chunk_size): + size = f.write(data) + pbar.update(size) + + +def download_yolov8n_model(destination_path: Optional[str] = None) -> str: + """Downloads YOLOv8n model if not already downloaded.""" + if destination_path is None: + destination_path = UltralyticsTestConstants.YOLOV8N_MODEL_PATH + + if not os.path.exists(destination_path): + download_file(YOLOV8N_WEIGHTS_URL, destination_path) + return destination_path + + +def download_yolov8n_seg_model(destination_path: Optional[str] = None) -> str: + """Downloads YOLOv8n-seg model if not already downloaded.""" + if destination_path is None: + destination_path = UltralyticsTestConstants.YOLOV8N_SEG_MODEL_PATH + + if not os.path.exists(destination_path): + download_file(YOLOV8N_SEG_WEIGHTS_URL, destination_path) + return destination_path + + +def download_yolo11n_model(destination_path: Optional[str] = None) -> str: + """Downloads YOLO11n model if not already downloaded.""" + if destination_path is None: + destination_path = UltralyticsTestConstants.YOLO11N_MODEL_PATH + + if not os.path.exists(destination_path): + download_file(YOLO11N_WEIGHTS_URL, destination_path) + return destination_path + + +def download_yolo11n_seg_model(destination_path: Optional[str] = None) -> str: + """Downloads YOLO11n-seg model if not already downloaded.""" + if destination_path is None: + destination_path = UltralyticsTestConstants.YOLO11N_SEG_MODEL_PATH + + if not os.path.exists(destination_path): + download_file(YOLO11N_SEG_WEIGHTS_URL, destination_path) + return destination_path + + +def download_yolo11n_obb_model(destination_path: Optional[str] = None) -> str: + """Downloads YOLO11n-obb model if not already downloaded.""" + if destination_path is None: + destination_path = UltralyticsTestConstants.YOLO11N_OBB_MODEL_PATH + + if not os.path.exists(destination_path): + download_file(YOLO11N_OBB_WEIGHTS_URL, destination_path) + return destination_path + + +def download_model_weights(model_path: str) -> str: + """ + Downloads model weights based on the model path. + + Args: + model_path: Path or name of the model + Returns: + Path to the downloaded weights file + """ + model_name = Path(model_path).stem + if model_name == "yolov8n": + return download_yolov8n_model() + elif model_name == "yolov8n-seg": + return download_yolov8n_seg_model() + elif model_name == "yolo11n": + return download_yolo11n_model() + elif model_name == "yolo11n-seg": + return download_yolo11n_seg_model() + elif model_name == "yolo11n-obb": + return download_yolo11n_obb_model() + else: + raise ValueError(f"Unknown model: {model_name}") diff --git a/sahi/utils/yolov8.py b/sahi/utils/yolov8.py deleted file mode 100644 index 8bfb556d4..000000000 --- a/sahi/utils/yolov8.py +++ /dev/null @@ -1,166 +0,0 @@ -import urllib.request -from os import path -from pathlib import Path -from typing import Optional - - -class Yolov8TestConstants: - YOLOV8N_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt" - YOLOV8N_MODEL_PATH = "tests/data/models/yolov8/yolov8n.pt" - - YOLOV8S_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt" - YOLOV8S_MODEL_PATH = "tests/data/models/yolov8/yolov8s.pt" - - YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt" - YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8m.pt" - - YOLOV8L_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l.pt" - YOLOV8L_MODEL_PATH = "tests/data/models/yolov8/yolov8l.pt" - - YOLOV8X_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt" - YOLOV8X_MODEL_PATH = "tests/data/models/yolov8/yolov8x.pt" - - YOLOV8N_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-seg.pt" - YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8n-seg.pt" - - YOLOV8S_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-seg.pt" - YOLOV8S_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8s-seg.pt" - - YOLOV8M_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-seg.pt" - YOLOV8M_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8m-seg.pt" - - YOLOV8L_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-seg.pt" - YOLOV8L_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8l-seg.pt" - - YOLOV8X_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt" - YOLOV8X_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8x-seg.pt" - - -def download_yolov8n_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8N_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8N_MODEL_URL, - destination_path, - ) - - -def download_yolov8s_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8S_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8S_MODEL_URL, - destination_path, - ) - - -def download_yolov8m_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8M_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8M_MODEL_URL, - destination_path, - ) - - -def download_yolov8l_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8L_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8L_MODEL_URL, - destination_path, - ) - - -def download_yolov8x_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8X_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8X_MODEL_URL, - destination_path, - ) - - -def download_yolov8n_seg_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8N_SEG_MODEL_URL, - destination_path, - ) - - -def download_yolov8s_seg_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8S_SEG_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8S_SEG_MODEL_URL, - destination_path, - ) - - -def download_yolov8m_seg_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8M_SEG_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8M_SEG_MODEL_URL, - destination_path, - ) - - -def download_yolov8l_seg_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8L_SEG_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8L_SEG_MODEL_URL, - destination_path, - ) - - -def download_yolov8x_seg_model(destination_path: Optional[str] = None): - if destination_path is None: - destination_path = Yolov8TestConstants.YOLOV8X_SEG_MODEL_PATH - - Path(destination_path).parent.mkdir(parents=True, exist_ok=True) - - if not path.exists(destination_path): - urllib.request.urlretrieve( - Yolov8TestConstants.YOLOV8X_SEG_MODEL_URL, - destination_path, - ) diff --git a/sahi/utils/yolov8onnx.py b/sahi/utils/yolov8onnx.py index ec137470e..9b20da403 100644 --- a/sahi/utils/yolov8onnx.py +++ b/sahi/utils/yolov8onnx.py @@ -3,7 +3,7 @@ import numpy as np -from sahi.utils.yolov8 import download_yolov8n_model +from sahi.utils.ultralytics import download_yolov8n_model class Yolov8ONNXTestConstants: diff --git a/scripts/utils.py b/scripts/utils.py index 5b2576736..a0b70c300 100644 --- a/scripts/utils.py +++ b/scripts/utils.py @@ -2,6 +2,8 @@ import shutil import sys +import click + def shell(command, exit_status=0): """ @@ -30,9 +32,15 @@ def validate_and_exit(expected_out_status=0, **kwargs): fail_count = 0 for component, exit_status in kwargs.items(): if exit_status != expected_out_status: - print(f"{component} failed.") + click.secho(f"{component} failed.", fg="red") fail_count += 1 + print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure") + click.secho("\nTo fix formatting issues:", fg="yellow") + click.secho("1. Install development dependencies:", fg="cyan") + click.secho(' pip install -e ."[dev]"', fg="green") + click.secho("\n2. Run code formatting:", fg="cyan") + click.secho(" python -m scripts.run_code_style format", fg="green") sys.exit(1) diff --git a/tests/test_predict.py b/tests/test_predict.py index eb108039a..6c6f8f164 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -8,6 +8,7 @@ import numpy as np from sahi.utils.cv import read_image +from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model MODEL_DEVICE = "cpu" CONFIDENCE_THRESHOLD = 0.5 @@ -294,9 +295,118 @@ def test_get_sliced_prediction_yolov5(self): num_car += 1 self.assertEqual(num_car, 11) + def test_get_prediction_yolo11(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + from sahi.predict import get_prediction + + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + image = read_image(image_path) + + # get full sized prediction + prediction_result = get_prediction( + image=image, detection_model=yolo11_detection_model, shift_amount=[0, 0], full_shape=None, postprocess=None + ) + object_prediction_list = prediction_result.object_prediction_list + + # compare + self.assertGreater(len(object_prediction_list), 0) + num_person = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "person": + num_person += 1 + self.assertEqual(num_person, 0) + num_truck = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "truck": + num_truck += 1 + self.assertEqual(num_truck, 0) + num_car = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "car": + num_car += 1 + self.assertGreater(num_car, 0) + + def test_get_sliced_prediction_yolo11(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + from sahi.predict import get_sliced_prediction + + # init model + download_yolo11n_model() + + yolo11_detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=False, + image_size=IMAGE_SIZE, + ) + yolo11_detection_model.load_model() + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + + slice_height = 512 + slice_width = 512 + overlap_height_ratio = 0.1 + overlap_width_ratio = 0.2 + postprocess_type = "GREEDYNMM" + match_metric = "IOS" + match_threshold = 0.5 + class_agnostic = True + + # get sliced prediction + prediction_result = get_sliced_prediction( + image=image_path, + detection_model=yolo11_detection_model, + slice_height=slice_height, + slice_width=slice_width, + overlap_height_ratio=overlap_height_ratio, + overlap_width_ratio=overlap_width_ratio, + perform_standard_pred=False, + postprocess_type=postprocess_type, + postprocess_match_threshold=match_threshold, + postprocess_match_metric=match_metric, + postprocess_class_agnostic=class_agnostic, + ) + object_prediction_list = prediction_result.object_prediction_list + + # compare + self.assertGreater(len(object_prediction_list), 0) + num_person = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "person": + num_person += 1 + self.assertEqual(num_person, 0) + num_truck = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "truck": + num_truck += 1 + self.assertEqual(num_truck, 0) + num_car = 0 + for object_prediction in object_prediction_list: + if object_prediction.category.name == "car": + num_car += 1 + self.assertGreater(num_car, 0) + def test_coco_json_prediction(self): from sahi.predict import predict from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_yolox_tiny_model + from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model # init model @@ -382,6 +492,45 @@ def test_coco_json_prediction(self): verbose=1, ) + # init model + download_yolo11n_model() + + # prepare paths + dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json" + source = "tests/data/coco_utils/" + project_dir = "tests/data/predict_result" + + # get sliced prediction + if os.path.isdir(project_dir): + shutil.rmtree(project_dir, ignore_errors=True) + predict( + model_type="ultralytics", + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + model_config_path=None, + model_confidence_threshold=CONFIDENCE_THRESHOLD, + model_device=MODEL_DEVICE, + model_category_mapping=None, + model_category_remapping=None, + source=source, + no_sliced_prediction=False, + no_standard_prediction=True, + slice_height=512, + slice_width=512, + overlap_height_ratio=0.2, + overlap_width_ratio=0.2, + postprocess_type=postprocess_type, + postprocess_match_metric=match_metric, + postprocess_match_threshold=match_threshold, + postprocess_class_agnostic=class_agnostic, + novisual=True, + export_pickle=False, + export_crop=False, + dataset_json_path=dataset_json_path, + project=project_dir, + name="exp", + verbose=1, + ) + def test_video_prediction(self): from os import path diff --git a/tests/test_ultralyticsmodel.py b/tests/test_ultralyticsmodel.py new file mode 100644 index 000000000..cdb626d25 --- /dev/null +++ b/tests/test_ultralyticsmodel.py @@ -0,0 +1,221 @@ +# OBSS SAHI Tool +# Code written by AnNT, 2024. + +import unittest + +from sahi.utils.cv import read_image +from sahi.utils.file import download_from_url +from sahi.utils.ultralytics import ( + UltralyticsTestConstants, + download_yolo11n_model, + download_yolo11n_obb_model, + download_yolo11n_seg_model, + download_yolov8n_model, + download_yolov8n_seg_model, +) + +MODEL_DEVICE = "cpu" +CONFIDENCE_THRESHOLD = 0.3 +IMAGE_SIZE = 640 + + +class TestUltralyticsDetectionModel(unittest.TestCase): + def test_load_yolov8_model(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + download_yolov8n_model() + + detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLOV8N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + ) + + self.assertNotEqual(detection_model.model, None) + self.assertTrue(hasattr(detection_model.model, "task")) + self.assertEqual(detection_model.model.task, "detect") + + def test_load_yolo11_model(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + download_yolo11n_model() + + detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + ) + + self.assertNotEqual(detection_model.model, None) + self.assertTrue(hasattr(detection_model.model, "task")) + self.assertEqual(detection_model.model.task, "detect") + + def test_perform_inference_yolov8(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + # init model + download_yolov8n_model() + + detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLOV8N_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + image_size=IMAGE_SIZE, + ) + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + image = read_image(image_path) + + # perform inference + detection_model.perform_inference(image) + original_predictions = detection_model.original_predictions + + boxes = original_predictions[0].data + + # find box of first car detection with conf greater than 0.5 + for box in boxes: + if box[5].item() == 2: # if category car + if box[4].item() > 0.5: + break + + # compare + desired_bbox = [448, 309, 497, 342] + predicted_bbox = list(map(int, box[:4].tolist())) + margin = 2 + for ind, point in enumerate(predicted_bbox): + assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin + self.assertEqual(len(detection_model.category_names), 80) + for box in boxes: + self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD) + + def test_perform_inference_yolo11(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + # init model + detection_model = UltralyticsDetectionModel( + model_path="yolo11n.pt", + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + image_size=IMAGE_SIZE, + ) + + # prepare image + image_path = "tests/data/small-vehicles1.jpeg" + image = read_image(image_path) + + # perform inference + detection_model.perform_inference(image) + original_predictions = detection_model.original_predictions + + boxes = original_predictions[0].data + + # verify predictions + self.assertEqual(len(detection_model.category_names), 80) + for box in boxes: + self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD) + + def test_yolo11_segmentation(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + # init model + download_yolo11n_seg_model() + + detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_SEG_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + image_size=IMAGE_SIZE, + ) + + # Verify model properties + self.assertTrue(detection_model.has_mask) + self.assertEqual(detection_model.model.task, "segment") + + # prepare image and run inference + image_path = "tests/data/small-vehicles1.jpeg" + image = read_image(image_path) + detection_model.perform_inference(image) + + # Verify segmentation output + original_predictions = detection_model.original_predictions + boxes = original_predictions[0][0] # Boxes + masks = original_predictions[0][1] # Masks + + self.assertGreater(len(boxes), 0) + self.assertEqual(masks.shape[0], len(boxes)) # One mask per box + self.assertEqual(len(masks.shape), 3) # (num_predictions, height, width) + + def test_yolo11_obb(self): + from sahi.models.ultralytics import UltralyticsDetectionModel + + # init model + download_yolo11n_obb_model() + + detection_model = UltralyticsDetectionModel( + model_path=UltralyticsTestConstants.YOLO11N_OBB_MODEL_PATH, + confidence_threshold=CONFIDENCE_THRESHOLD, + device=MODEL_DEVICE, + category_remapping=None, + load_at_init=True, + image_size=640, + ) + + # Verify model task + self.assertTrue(detection_model.is_obb) + self.assertEqual(detection_model.model.task, "obb") + + # prepare image and run inference + image_url = "https://ultralytics.com/images/boats.jpg" + image_path = "tests/data/boats.jpg" + download_from_url(image_url, to_path=image_path) + image = read_image(image_path) + detection_model.perform_inference(image) + + # Verify OBB predictions + original_predictions = detection_model.original_predictions + boxes = original_predictions[0][0] # Original box data + obb_points = original_predictions[0][1] # OBB points in xyxyxyxy format + + self.assertGreater(len(boxes), 0) + # Check box format: x1,y1,x2,y2,conf,cls + self.assertEqual(boxes.shape[1], 6) + # Check OBB points format + self.assertEqual(obb_points.shape[1:], (4, 2)) # (N, 4, 2) format + + # Convert predictions and verify + detection_model.convert_original_predictions() + object_prediction_list = detection_model.object_prediction_list + + # Verify converted predictions + self.assertEqual(len(object_prediction_list), len(boxes)) + for object_prediction in object_prediction_list: + # Verify confidence threshold + self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD) + + coco_segmentation = object_prediction.mask.segmentation + # Verify segmentation exists (converted from OBB) + self.assertIsNotNone(coco_segmentation) + # Verify segmentation is a list of points + self.assertTrue(isinstance(coco_segmentation, list)) + self.assertGreater(len(coco_segmentation), 0) + # Verify each segment is a valid closed polygon + for segment in coco_segmentation: + self.assertEqual(len(segment), 10) # 4 points + 1 closing point (x,y coordinates) + # Verify polygon is closed (first point equals last point) + self.assertEqual(segment[0], segment[-2]) # x coordinate + self.assertEqual(segment[1], segment[-1]) # y coordinate + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_yolov8model.py b/tests/test_yolov8model.py deleted file mode 100644 index f578113d7..000000000 --- a/tests/test_yolov8model.py +++ /dev/null @@ -1,236 +0,0 @@ -# OBSS SAHI Tool -# Code written by Fatih C Akyon, 2020. - -import unittest - -import numpy as np - -from sahi.utils.cv import read_image -from sahi.utils.yolov8 import Yolov8TestConstants, download_yolov8n_model, download_yolov8n_seg_model - -MODEL_DEVICE = "cpu" -CONFIDENCE_THRESHOLD = 0.3 -IMAGE_SIZE = 640 - - -class TestYolov8DetectionModel(unittest.TestCase): - def test_load_model(self): - from sahi.models.yolov8 import Yolov8DetectionModel - - download_yolov8n_model() - - yolov8_detection_model = Yolov8DetectionModel( - model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - ) - - self.assertNotEqual(yolov8_detection_model.model, None) - - def test_set_model(self): - from ultralytics import YOLO - - from sahi.models.yolov8 import Yolov8DetectionModel - - download_yolov8n_model() - - yolo_model = YOLO(Yolov8TestConstants.YOLOV8N_MODEL_PATH) - - yolov8_detection_model = Yolov8DetectionModel( - model=yolo_model, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - ) - - self.assertNotEqual(yolov8_detection_model.model, None) - - def test_perform_inference(self): - from sahi.models.yolov8 import Yolov8DetectionModel - - # init model - download_yolov8n_model() - - yolov8_detection_model = Yolov8DetectionModel( - model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - image_size=IMAGE_SIZE, - ) - - # prepare image - image_path = "tests/data/small-vehicles1.jpeg" - image = read_image(image_path) - - # perform inference - yolov8_detection_model.perform_inference(image) - original_predictions = yolov8_detection_model.original_predictions - - boxes = original_predictions[0].data - - # find box of first car detection with conf greater than 0.5 - for box in boxes: - if box[5].item() == 2: # if category car - if box[4].item() > 0.5: - break - - # compare - desired_bbox = [448, 309, 497, 342] - predicted_bbox = list(map(int, box[:4].tolist())) - margin = 2 - for ind, point in enumerate(predicted_bbox): - assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin - self.assertEqual(len(yolov8_detection_model.category_names), 80) - for box in boxes: - self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD) - - def test_convert_original_predictions(self): - from sahi.models.yolov8 import Yolov8DetectionModel - - # init model - download_yolov8n_model() - - yolov8_detection_model = Yolov8DetectionModel( - model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - image_size=IMAGE_SIZE, - ) - - # prepare image - image_path = "tests/data/small-vehicles1.jpeg" - image = read_image(image_path) - - # get raw predictions for reference - original_results = yolov8_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes - num_results = len(original_results) - - # perform inference - yolov8_detection_model.perform_inference(image) - - # convert predictions to ObjectPrediction list - yolov8_detection_model.convert_original_predictions() - object_prediction_list = yolov8_detection_model.object_prediction_list - - # compare - self.assertEqual(len(object_prediction_list), num_results) - - # loop through predictions and check that they are equal - for i in range(num_results): - desired_bbox = [ - original_results[i].xyxy[0][0], - original_results[i].xyxy[0][1], - original_results[i].xywh[0][2], - original_results[i].xywh[0][3], - ] - desired_cat_id = int(original_results[i].cls[0]) - self.assertEqual(object_prediction_list[i].category.id, desired_cat_id) - predicted_bbox = object_prediction_list[i].bbox.to_xywh() - margin = 2 - for ind, point in enumerate(predicted_bbox): - assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin - for object_prediction in object_prediction_list: - self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD) - - def test_perform_inference_with_mask_output(self): - from sahi.models.yolov8 import Yolov8DetectionModel - - # init model - download_yolov8n_seg_model() - - yolov8_detection_model = Yolov8DetectionModel( - model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - image_size=IMAGE_SIZE, - ) - # prepare image - image_path = "tests/data/small-vehicles1.jpeg" - image = read_image(image_path) - - # perform inference - yolov8_detection_model.perform_inference(image) - original_predictions = yolov8_detection_model.original_predictions - boxes = original_predictions[0][0] - masks = original_predictions[0][1] - - # find box of first car detection with conf greater than 0.5 - for box in boxes: - if box[5].item() == 2: # if category car - if box[4].item() > 0.5: - break - - # compare - desired_bbox = [320, 323, 380, 365] - predicted_bbox = list(map(int, box[:4].tolist())) - margin = 3 - for ind, point in enumerate(predicted_bbox): - assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin - self.assertEqual(len(yolov8_detection_model.category_names), 80) - for box in boxes: - self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD) - self.assertEqual(masks.shape, (12, 352, 640)) - self.assertEqual(masks.shape[0], len(boxes)) - - def test_convert_original_predictions_with_mask_output(self): - from sahi.models.yolov8 import Yolov8DetectionModel - - # init model - download_yolov8n_seg_model() - - yolov8_detection_model = Yolov8DetectionModel( - model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH, - confidence_threshold=CONFIDENCE_THRESHOLD, - device=MODEL_DEVICE, - category_remapping=None, - load_at_init=True, - image_size=IMAGE_SIZE, - ) - - # prepare image - image_path = "tests/data/small-vehicles1.jpeg" - image = read_image(image_path) - - # get raw predictions for reference - original_results = yolov8_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes - num_results = len(original_results) - - # perform inference - yolov8_detection_model.perform_inference(image) - - # convert predictions to ObjectPrediction list - yolov8_detection_model.convert_original_predictions(full_shape=(image.shape[0], image.shape[1])) - object_prediction_list = yolov8_detection_model.object_prediction_list - - # compare - self.assertEqual(len(object_prediction_list), num_results) - - # loop through predictions and check that they are equal - for i in range(num_results): - desired_bbox = [ - original_results[i].xyxy[0][0], - original_results[i].xyxy[0][1], - original_results[i].xywh[0][2], - original_results[i].xywh[0][3], - ] - desired_cat_id = int(original_results[i].cls[0]) - self.assertEqual(object_prediction_list[i].category.id, desired_cat_id) - predicted_bbox = object_prediction_list[i].bbox.to_xywh() - margin = 20 # Margin high because for some reason some original predictions are really poor - for ind, point in enumerate(predicted_bbox): - assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin - for object_prediction in object_prediction_list: - self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD) - - -if __name__ == "__main__": - unittest.main()