From 4784b77e40a17ab8e83939c8e088257efcf72d52 Mon Sep 17 00:00:00 2001
From: fatih akyon <34196005+fcakyon@users.noreply.github.com>
Date: Mon, 16 Dec 2024 20:06:12 +0300
Subject: [PATCH] add yolo11 and ultralytics obb task support (#1109)
---
.github/workflows/ci.yml | 6 +-
.github/workflows/package_testing.yml | 6 +-
README.md | 12 +-
....ipynb => inference_for_ultralytics.ipynb} | 35 +--
sahi/__init__.py | 2 +-
sahi/annotation.py | 45 ++--
sahi/auto_model.py | 7 +-
sahi/models/__init__.py | 2 +-
sahi/models/base.py | 5 +-
sahi/models/rtdetr.py | 4 +-
sahi/models/{yolov8.py => ultralytics.py} | 74 ++++--
sahi/prediction.py | 2 +-
sahi/utils/cv.py | 28 ++-
sahi/utils/ultralytics.py | 121 +++++++++
sahi/utils/yolov8.py | 166 ------------
sahi/utils/yolov8onnx.py | 2 +-
scripts/utils.py | 10 +-
tests/test_predict.py | 149 +++++++++++
tests/test_ultralyticsmodel.py | 221 ++++++++++++++++
tests/test_yolov8model.py | 236 ------------------
20 files changed, 645 insertions(+), 488 deletions(-)
rename demo/{inference_for_yolov8.ipynb => inference_for_ultralytics.ipynb} (99%)
rename sahi/models/{yolov8.py => ultralytics.py} (71%)
create mode 100644 sahi/utils/ultralytics.py
delete mode 100644 sahi/utils/yolov8.py
create mode 100644 tests/test_ultralyticsmodel.py
delete mode 100644 tests/test_yolov8model.py
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 07f93e566..80a10d494 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -106,11 +106,11 @@ jobs:
run: >
pip install pycocotools==2.0.7
- - name: Install ultralytics(8.0.207)
+ - name: Install ultralytics(8.3.50)
run: >
- pip install ultralytics==8.0.207
+ pip install ultralytics==8.3.50
- - name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
+ - name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms
run: |
python -m unittest
diff --git a/.github/workflows/package_testing.yml b/.github/workflows/package_testing.yml
index 8ae96db5d..ac5433e4f 100644
--- a/.github/workflows/package_testing.yml
+++ b/.github/workflows/package_testing.yml
@@ -84,15 +84,15 @@ jobs:
run: >
pip install pycocotools==2.0.7
- - name: Install ultralytics(8.0.207)
+ - name: Install ultralytics(8.3.50)
run: >
- pip install ultralytics==8.0.207
+ pip install ultralytics==8.3.50
- name: Install latest SAHI package
run: >
pip install --upgrade --force-reinstall sahi
- - name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
+ - name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms
run: |
python -m unittest
diff --git a/README.md b/README.md
index 6c40e3cef..131bc5ad5 100644
--- a/README.md
+++ b/README.md
@@ -76,11 +76,13 @@ Object detection and instance segmentation are by far the most important applica
- [Slicing operation notebook](demo/slicing.ipynb)
-- `YOLOX` + `SAHI` demo: (RECOMMENDED)
+- `YOLOX` + `SAHI` demo:
+
+- `YOLO11` + `SAHI` walkthrough: (NEW)
- `RT-DETR` + `SAHI` walkthrough: (NEW)
-- `YOLOv8` + `SAHI` walkthrough:
+- `YOLOv8` + `SAHI` walkthrough:
- `DeepSparse` + `SAHI` walkthrough:
@@ -141,7 +143,7 @@ pip install yolov5==7.0.13
- Install your desired detection framework (ultralytics):
```console
-pip install ultralytics==8.0.207
+pip install ultralytics==8.3.50
```
- Install your desired detection framework (mmdet):
@@ -228,9 +230,9 @@ If you use this package in your work, please cite it as:
##
Contributing
-`sahi` library currently supports all [YOLOv5 models](https://github.com/ultralytics/yolov5/releases), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks.
+`sahi` library currently supports all [Ultralytics (YOLOv8/v10/v11/RTDETR) models](https://github.com/ultralytics/ultralytics), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks.
-All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/mmdet.py#L18) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference.
+All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/mmdet.py#L91) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference.
Before opening a PR:
diff --git a/demo/inference_for_yolov8.ipynb b/demo/inference_for_ultralytics.ipynb
similarity index 99%
rename from demo/inference_for_yolov8.ipynb
rename to demo/inference_for_ultralytics.ipynb
index fdcba4940..1d3976cc1 100644
--- a/demo/inference_for_yolov8.ipynb
+++ b/demo/inference_for_ultralytics.ipynb
@@ -55,8 +55,9 @@
"outputs": [],
"source": [
"# arrange an instance segmentation model for test\n",
- "from sahi.utils.yolov8 import (\n",
- " download_yolov8s_model, download_yolov8s_seg_model\n",
+ "from sahi.utils.ultralytics import (\n",
+ " download_yolo11n_model, download_yolo11n_seg_model,\n",
+ " # download_yolov8n_model, download_yolov8n_seg_model\n",
")\n",
"\n",
"from sahi import AutoDetectionModel\n",
@@ -80,9 +81,10 @@
"metadata": {},
"outputs": [],
"source": [
- "# download YOLOV5S6 model to 'models/yolov5s6.pt'\n",
- "yolov8_model_path = \"models/yolov8s.pt\"\n",
- "download_yolov8s_model(yolov8_model_path)\n",
+ "yolo11n_model_path = \"models/yolov11n.pt\"\n",
+ "download_yolo11n_model(yolo11n_model_path)\n",
+ "# yolov8n_model_path = \"models/yolov8n.pt\"\n",
+ "# download_yolov8n_model(yolov8n_model_path)\n",
"\n",
"# download test images into demo_data folder\n",
"download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg', 'demo_data/small-vehicles1.jpeg')\n",
@@ -94,7 +96,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## 1. Standard Inference with a YOLOv8 Model"
+ "## 1. Standard Inference with a YOLOv8/YOLO11 Model"
]
},
{
@@ -111,8 +113,8 @@
"outputs": [],
"source": [
"detection_model = AutoDetectionModel.from_pretrained(\n",
- " model_type='yolov8',\n",
- " model_path=yolov8_model_path,\n",
+ " model_type='yolo11', # or 'yolov8'\n",
+ " model_path=yolo11n_model_path,\n",
" confidence_threshold=0.3,\n",
" device=\"cpu\", # or 'cuda:0'\n",
")"
@@ -185,7 +187,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
- "## 2. Sliced Inference with a YOLOv8 Model"
+ "## 2. Sliced Inference with a YOLOv8/YOLO11 Model"
]
},
{
@@ -466,8 +468,8 @@
"metadata": {},
"outputs": [],
"source": [
- "model_type = \"yolov8\"\n",
- "model_path = yolov8_model_path\n",
+ "model_type = \"yolo11\"\n",
+ "model_path = yolo11n_model_path\n",
"model_device = \"cpu\" # or 'cuda:0'\n",
"model_confidence_threshold = 0.4\n",
"\n",
@@ -599,9 +601,10 @@
"metadata": {},
"outputs": [],
"source": [
- "#download YOLOV8S model to 'models/yolov8s.pt'\n",
- "yolov8_seg_model_path = \"models/yolov8s-seg.pt\"\n",
- "download_yolov8s_seg_model(yolov8_seg_model_path)"
+ "yolo11n_seg_model_path = \"models/yolov11n-seg.pt\"\n",
+ "download_yolo11n_seg_model(yolo11n_seg_model_path)\n",
+ "# yolov8n_seg_model_path = \"models/yolov8n-seg.pt\"\n",
+ "# download_yolov8n_seg_model(yolov8n_seg_model_path)\n"
]
},
{
@@ -611,8 +614,8 @@
"outputs": [],
"source": [
"detection_model_seg = AutoDetectionModel.from_pretrained(\n",
- " model_type='yolov8',\n",
- " model_path=yolov8_seg_model_path,\n",
+ " model_type='yolo11', # or 'yolov8'\n",
+ " model_path=yolo11n_seg_model_path,\n",
" confidence_threshold=0.3,\n",
" device=\"cpu\", # or 'cuda:0'\n",
")"
diff --git a/sahi/__init__.py b/sahi/__init__.py
index 301ca944e..bed884633 100644
--- a/sahi/__init__.py
+++ b/sahi/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.11.19"
+__version__ = "0.11.20"
from sahi.annotation import BoundingBox, Category, Mask
from sahi.auto_model import AutoDetectionModel
diff --git a/sahi/annotation.py b/sahi/annotation.py
index 78318699e..5d7bc639c 100644
--- a/sahi/annotation.py
+++ b/sahi/annotation.py
@@ -130,8 +130,8 @@ class Mask:
@classmethod
def from_float_mask(
cls,
- mask,
- full_shape=None,
+ mask: np.ndarray,
+ full_shape: List[int],
mask_threshold: float = 0.5,
shift_amount: list = [0, 0],
):
@@ -144,7 +144,7 @@ def from_float_mask(
shift_amount: List
To shift the box and mask predictions from sliced image
to full sized image, should be in the form of [shift_x, shift_y]
- full_shape: List
+ full_shape: List[int]
Size of the full image after shifting, should be in the form of [height, width]
"""
bool_mask = mask > mask_threshold
@@ -156,8 +156,8 @@ def from_float_mask(
def __init__(
self,
- segmentation,
- full_shape=None,
+ segmentation: List[List[float]],
+ full_shape: List[int],
shift_amount: list = [0, 0],
):
"""
@@ -170,9 +170,9 @@ def __init__(
[x1, y1, x2, y2, x3, y3, ...],
...
]
- full_shape: List
+ full_shape: List[int]
Size of the full image, should be in the form of [height, width]
- shift_amount: List
+ shift_amount: List[int]
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
@@ -195,17 +195,17 @@ def __init__(
@classmethod
def from_bool_mask(
cls,
- bool_mask=None,
- full_shape=None,
+ bool_mask: np.ndarray,
+ full_shape: List[int],
shift_amount: list = [0, 0],
):
"""
Args:
bool_mask: np.ndarray with bool elements
2D mask of object, should have a shape of height*width
- full_shape: List
+ full_shape: List[int]
Size of the full image, should be in the form of [height, width]
- shift_amount: List
+ shift_amount: List[int]
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
@@ -420,7 +420,7 @@ def from_coco_annotation_dict(
@classmethod
def from_shapely_annotation(
cls,
- annotation,
+ annotation: ShapelyAnnotation,
full_shape: List[int],
category_id: Optional[int] = None,
category_name: Optional[str] = None,
@@ -441,12 +441,9 @@ def from_shapely_annotation(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
- bool_mask = get_bool_mask_from_coco_segmentation(
- annotation.to_coco_segmentation(), width=full_shape[1], height=full_shape[0]
- )
return cls(
category_id=category_id,
- bool_mask=bool_mask,
+ segmentation=annotation.to_coco_segmentation(),
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
@@ -512,6 +509,8 @@ def __init__(
raise ValueError("category_id must be an integer")
if (bbox is None) and (segmentation is None):
raise ValueError("you must provide a bbox or segmentation")
+
+ self.mask: Mask | None = None
if segmentation is not None:
self.mask = Mask(
segmentation=segmentation,
@@ -524,8 +523,6 @@ def __init__(
bbox = bbox_from_segmentation
else:
raise ValueError("Invalid segmentation mask.")
- else:
- self.mask = None
# if bbox is a numpy object, convert it to python List[float]
if type(bbox).__module__ == "numpy":
@@ -552,13 +549,13 @@ def __init__(
self.merged = None
- def to_coco_annotation(self):
+ def to_coco_annotation(self) -> CocoAnnotation:
"""
Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation.
"""
if self.mask:
coco_annotation = CocoAnnotation.from_coco_segmentation(
- segmentation=self.mask.segmentation(),
+ segmentation=self.mask.segmentation,
category_id=self.category.id,
category_name=self.category.name,
)
@@ -570,13 +567,13 @@ def to_coco_annotation(self):
)
return coco_annotation
- def to_coco_prediction(self):
+ def to_coco_prediction(self) -> CocoPrediction:
"""
Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.
"""
if self.mask:
coco_prediction = CocoPrediction.from_coco_segmentation(
- segmentation=self.mask.segmentation(),
+ segmentation=self.mask.segmentation,
category_id=self.category.id,
category_name=self.category.name,
score=1,
@@ -590,13 +587,13 @@ def to_coco_prediction(self):
)
return coco_prediction
- def to_shapely_annotation(self):
+ def to_shapely_annotation(self) -> ShapelyAnnotation:
"""
Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation.
"""
if self.mask:
shapely_annotation = ShapelyAnnotation.from_coco_segmentation(
- segmentation=self.mask.segmentation(),
+ segmentation=self.mask.segmentation,
)
else:
shapely_annotation = ShapelyAnnotation.from_coco_bbox(
diff --git a/sahi/auto_model.py b/sahi/auto_model.py
index a60f2d9d9..e66ab680b 100644
--- a/sahi/auto_model.py
+++ b/sahi/auto_model.py
@@ -3,7 +3,7 @@
from sahi.utils.file import import_model_class
MODEL_TYPE_TO_MODEL_CLASS_NAME = {
- "yolov8": "Yolov8DetectionModel",
+ "ultralytics": "UltralyticsDetectionModel",
"rtdetr": "RTDetrDetectionModel",
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
@@ -14,6 +14,8 @@
"yolov8onnx": "Yolov8OnnxDetectionModel",
}
+ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]
+
class AutoDetectionModel:
@staticmethod
@@ -60,7 +62,8 @@ def from_pretrained(
Raises:
ImportError: If given {model_type} framework is not installed
"""
-
+ if model_type in ULTRALYTICS_MODEL_NAMES:
+ model_type = "ultralytics"
model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
DetectionModel = import_model_class(model_type, model_class_name)
diff --git a/sahi/models/__init__.py b/sahi/models/__init__.py
index 4f2a5710b..c85fa5da8 100644
--- a/sahi/models/__init__.py
+++ b/sahi/models/__init__.py
@@ -1 +1 @@
-from . import base, detectron2, huggingface, mmdet, torchvision, yolov5, yolov8onnx
+from . import base, detectron2, huggingface, mmdet, torchvision, ultralytics, yolov5, yolov8onnx
diff --git a/sahi/models/base.py b/sahi/models/base.py
index d1cd227d0..a434728dd 100644
--- a/sahi/models/base.py
+++ b/sahi/models/base.py
@@ -1,10 +1,11 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional
import numpy as np
+from sahi.prediction import ObjectPrediction
from sahi.utils.import_utils import is_available
from sahi.utils.torch import select_device as select_torch_device
@@ -173,7 +174,7 @@ def convert_original_predictions(
self._apply_category_remapping()
@property
- def object_prediction_list(self):
+ def object_prediction_list(self) -> List[ObjectPrediction]:
return self._object_prediction_list_per_image[0]
@property
diff --git a/sahi/models/rtdetr.py b/sahi/models/rtdetr.py
index d704eee59..91bc3250a 100644
--- a/sahi/models/rtdetr.py
+++ b/sahi/models/rtdetr.py
@@ -3,13 +3,13 @@
import logging
-from sahi.models.yolov8 import Yolov8DetectionModel
+from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.utils.import_utils import check_requirements
logger = logging.getLogger(__name__)
-class RTDetrDetectionModel(Yolov8DetectionModel):
+class RTDetrDetectionModel(UltralyticsDetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])
diff --git a/sahi/models/yolov8.py b/sahi/models/ultralytics.py
similarity index 71%
rename from sahi/models/yolov8.py
rename to sahi/models/ultralytics.py
index b8311329d..dd421f5cc 100644
--- a/sahi/models/yolov8.py
+++ b/sahi/models/ultralytics.py
@@ -2,7 +2,7 @@
# Code written by AnNT, 2023.
import logging
-from typing import Any, Dict, List, Optional
+from typing import Any, List, Optional
import cv2
import numpy as np
@@ -13,11 +13,11 @@
from sahi.models.base import DetectionModel
from sahi.prediction import ObjectPrediction
from sahi.utils.compatibility import fix_full_shape_list, fix_shift_amount_list
-from sahi.utils.cv import get_coco_segmentation_from_bool_mask
+from sahi.utils.cv import get_coco_segmentation_from_bool_mask, get_coco_segmentation_from_obb_points
from sahi.utils.import_utils import check_requirements
-class Yolov8DetectionModel(DetectionModel):
+class UltralyticsDetectionModel(DetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])
@@ -33,14 +33,14 @@ def load_model(self):
model.to(self.device)
self.set_model(model)
except Exception as e:
- raise TypeError("model_path is not a valid yolov8 model path: ", e)
+ raise TypeError("model_path is not a valid Ultralytics model path: ", e)
def set_model(self, model: Any):
"""
- Sets the underlying YOLOv8 model.
+ Sets the underlying Ultralytics model.
Args:
model: Any
- A YOLOv8 model
+ A Ultralytics model
"""
self.model = model
@@ -52,11 +52,9 @@ def set_model(self, model: Any):
def perform_inference(self, image: np.ndarray):
"""
Prediction is performed using self.model and the prediction result is set to self._original_predictions.
- If predictions have masks, each prediction is a tuple like (boxes, masks).
Args:
image: np.ndarray
A numpy array that contains the image to be predicted. 3 channel image should be in RGB order.
-
"""
from ultralytics.engine.results import Masks
@@ -86,8 +84,27 @@ def perform_inference(self, image: np.ndarray):
)
for result in prediction_result
]
-
- else: # If model doesn't do segmentation then no need to check masks
+ elif self.is_obb:
+ # For OBB task, get OBB points in xyxyxyxy format
+ prediction_result = [
+ (
+ # Get OBB data: xyxy, conf, cls
+ torch.cat(
+ [
+ result.obb.xyxy, # box coordinates
+ result.obb.conf.unsqueeze(-1), # confidence scores
+ result.obb.cls.unsqueeze(-1), # class ids
+ ],
+ dim=1,
+ )
+ if result.obb is not None
+ else torch.empty((0, 6), device=self.model.device),
+ # Get OBB points in (N, 4, 2) format
+ result.obb.xyxyxyxy if result.obb is not None else torch.empty((0, 4, 2), device=self.model.device),
+ )
+ for result in prediction_result
+ ]
+ else: # If model doesn't do segmentation or OBB then no need to check masks
# We do not filter results again as confidence threshold is already applied above
prediction_result = [result.boxes.data for result in prediction_result]
@@ -112,6 +129,13 @@ def has_mask(self):
"""
return self.model.overrides["task"] == "segment"
+ @property
+ def is_obb(self):
+ """
+ Returns if model output contains oriented bounding boxes
+ """
+ return self.model.overrides["task"] == "obb"
+
def _create_object_prediction_list_from_original_predictions(
self,
shift_amount_list: Optional[List[List[int]]] = [[0, 0]],
@@ -142,13 +166,13 @@ def _create_object_prediction_list_from_original_predictions(
full_shape = None if full_shape_list is None else full_shape_list[image_ind]
object_prediction_list = []
- # Extract boxes and optional masks
- if self.has_mask:
+ # Extract boxes and optional masks/obb
+ if self.has_mask or self.is_obb:
boxes = image_predictions[0].cpu().detach().numpy()
- masks = image_predictions[1].cpu().detach().numpy()
+ masks_or_points = image_predictions[1].cpu().detach().numpy()
else:
boxes = image_predictions.data.cpu().detach().numpy()
- masks = None
+ masks_or_points = None
# Process each prediction
for pred_ind, prediction in enumerate(boxes):
@@ -171,14 +195,20 @@ def _create_object_prediction_list_from_original_predictions(
logger.warning(f"ignoring invalid prediction with bbox: {bbox}")
continue
- # Get segmentation if available
+ # Get segmentation or OBB points
segmentation = None
- if masks is not None:
- bool_mask = masks[pred_ind]
- orig_width = self._original_shape[1]
- orig_height = self._original_shape[0]
- bool_mask = cv2.resize(bool_mask.astype(np.uint8), (orig_width, orig_height))
- segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
+ if masks_or_points is not None:
+ if self.has_mask:
+ bool_mask = masks_or_points[pred_ind]
+ # Resize mask to original image size
+ bool_mask = cv2.resize(
+ bool_mask.astype(np.uint8), (self._original_shape[1], self._original_shape[0])
+ )
+ segmentation = get_coco_segmentation_from_bool_mask(bool_mask)
+ else: # is_obb
+ obb_points = masks_or_points[pred_ind] # Get OBB points for this prediction
+ segmentation = get_coco_segmentation_from_obb_points(obb_points)
+
if len(segmentation) == 0:
continue
@@ -190,7 +220,7 @@ def _create_object_prediction_list_from_original_predictions(
segmentation=segmentation,
category_name=category_name,
shift_amount=shift_amount,
- full_shape=full_shape,
+ full_shape=self._original_shape[:2] if full_shape is None else full_shape, # (height, width)
)
object_prediction_list.append(object_prediction)
diff --git a/sahi/prediction.py b/sahi/prediction.py
index e8ca14260..5fca9648a 100644
--- a/sahi/prediction.py
+++ b/sahi/prediction.py
@@ -8,7 +8,7 @@
from PIL import Image
from sahi.annotation import ObjectAnnotation
-from sahi.utils.coco import CocoAnnotation, CocoPrediction
+from sahi.utils.coco import CocoPrediction
from sahi.utils.cv import read_image_as_pil, visualize_object_predictions
from sahi.utils.file import Path
diff --git a/sahi/utils/cv.py b/sahi/utils/cv.py
index c2006eb2b..693097aa9 100644
--- a/sahi/utils/cv.py
+++ b/sahi/utils/cv.py
@@ -167,7 +167,7 @@ def read_large_image(image_path: str):
return image0, use_cv2
-def read_image(image_path: str):
+def read_image(image_path: str) -> np.ndarray:
"""
Loads image as a numpy array from the given path.
@@ -184,7 +184,7 @@ def read_image(image_path: str):
return image
-def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False):
+def read_image_as_pil(image: Union[Image.Image, str, np.ndarray], exif_fix: bool = False) -> Image.Image:
"""
Loads an image as PIL.Image.Image.
@@ -688,6 +688,30 @@ def get_bbox_from_coco_segmentation(coco_segmentation):
return [xmin, ymin, xmax, ymax]
+def get_coco_segmentation_from_obb_points(obb_points: np.ndarray) -> List[List[float]]:
+ """
+ Convert OBB (Oriented Bounding Box) points to COCO polygon format.
+
+ Args:
+ obb_points: np.ndarray
+ OBB points tensor from ultralytics.engine.results.OBB
+ Shape: (4, 2) containing 4 points with (x,y) coordinates each
+ Returns:
+ List[List[float]]: Polygon points in COCO format
+ [[x1, y1, x2, y2, x3, y3, x4, y4, x1, y1], [...], ...]
+ """
+ # Convert from (4,2) to [x1,y1,x2,y2,x3,y3,x4,y4] format
+ points = obb_points.reshape(-1).tolist()
+
+ # Create polygon from points and close it by repeating first point
+ polygons = []
+ # Add first point to end to close polygon
+ closed_polygon = points + [points[0], points[1]]
+ polygons.append(closed_polygon)
+
+ return polygons
+
+
def normalize_numpy_image(image: np.ndarray):
"""
Normalizes numpy image
diff --git a/sahi/utils/ultralytics.py b/sahi/utils/ultralytics.py
new file mode 100644
index 000000000..272668ca6
--- /dev/null
+++ b/sahi/utils/ultralytics.py
@@ -0,0 +1,121 @@
+import os
+from pathlib import Path
+from typing import Optional
+
+import requests
+from tqdm import tqdm
+
+YOLOV8N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n.pt"
+YOLOV8N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolov8n-seg.pt"
+YOLO11N_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n.pt"
+YOLO11N_SEG_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-seg.pt"
+YOLO11N_OBB_WEIGHTS_URL = "https://github.com/ultralytics/assets/releases/download/v8.3.0/yolo11n-obb.pt"
+
+
+class UltralyticsTestConstants:
+ YOLOV8N_MODEL_PATH = "tests/data/models/yolov8n.pt"
+ YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8n-seg.pt"
+ YOLO11N_MODEL_PATH = "tests/data/models/yolo11n.pt"
+ YOLO11N_SEG_MODEL_PATH = "tests/data/models/yolo11n-seg.pt"
+ YOLO11N_OBB_MODEL_PATH = "tests/data/models/yolo11n-obb.pt"
+
+
+def download_file(url: str, save_path: str, chunk_size: int = 8192) -> None:
+ """
+ Downloads a file from a given URL to the specified path.
+
+ Args:
+ url: URL to download the file from
+ save_path: Path where the file will be saved
+ chunk_size: Size of chunks for downloading
+ """
+ response = requests.get(url, stream=True)
+ total_size = int(response.headers.get("content-length", 0))
+
+ # Ensure directory exists
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+
+ with open(save_path, "wb") as f, tqdm(
+ desc=os.path.basename(save_path),
+ total=total_size,
+ unit="B",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as pbar:
+ for data in response.iter_content(chunk_size=chunk_size):
+ size = f.write(data)
+ pbar.update(size)
+
+
+def download_yolov8n_model(destination_path: Optional[str] = None) -> str:
+ """Downloads YOLOv8n model if not already downloaded."""
+ if destination_path is None:
+ destination_path = UltralyticsTestConstants.YOLOV8N_MODEL_PATH
+
+ if not os.path.exists(destination_path):
+ download_file(YOLOV8N_WEIGHTS_URL, destination_path)
+ return destination_path
+
+
+def download_yolov8n_seg_model(destination_path: Optional[str] = None) -> str:
+ """Downloads YOLOv8n-seg model if not already downloaded."""
+ if destination_path is None:
+ destination_path = UltralyticsTestConstants.YOLOV8N_SEG_MODEL_PATH
+
+ if not os.path.exists(destination_path):
+ download_file(YOLOV8N_SEG_WEIGHTS_URL, destination_path)
+ return destination_path
+
+
+def download_yolo11n_model(destination_path: Optional[str] = None) -> str:
+ """Downloads YOLO11n model if not already downloaded."""
+ if destination_path is None:
+ destination_path = UltralyticsTestConstants.YOLO11N_MODEL_PATH
+
+ if not os.path.exists(destination_path):
+ download_file(YOLO11N_WEIGHTS_URL, destination_path)
+ return destination_path
+
+
+def download_yolo11n_seg_model(destination_path: Optional[str] = None) -> str:
+ """Downloads YOLO11n-seg model if not already downloaded."""
+ if destination_path is None:
+ destination_path = UltralyticsTestConstants.YOLO11N_SEG_MODEL_PATH
+
+ if not os.path.exists(destination_path):
+ download_file(YOLO11N_SEG_WEIGHTS_URL, destination_path)
+ return destination_path
+
+
+def download_yolo11n_obb_model(destination_path: Optional[str] = None) -> str:
+ """Downloads YOLO11n-obb model if not already downloaded."""
+ if destination_path is None:
+ destination_path = UltralyticsTestConstants.YOLO11N_OBB_MODEL_PATH
+
+ if not os.path.exists(destination_path):
+ download_file(YOLO11N_OBB_WEIGHTS_URL, destination_path)
+ return destination_path
+
+
+def download_model_weights(model_path: str) -> str:
+ """
+ Downloads model weights based on the model path.
+
+ Args:
+ model_path: Path or name of the model
+ Returns:
+ Path to the downloaded weights file
+ """
+ model_name = Path(model_path).stem
+ if model_name == "yolov8n":
+ return download_yolov8n_model()
+ elif model_name == "yolov8n-seg":
+ return download_yolov8n_seg_model()
+ elif model_name == "yolo11n":
+ return download_yolo11n_model()
+ elif model_name == "yolo11n-seg":
+ return download_yolo11n_seg_model()
+ elif model_name == "yolo11n-obb":
+ return download_yolo11n_obb_model()
+ else:
+ raise ValueError(f"Unknown model: {model_name}")
diff --git a/sahi/utils/yolov8.py b/sahi/utils/yolov8.py
deleted file mode 100644
index 8bfb556d4..000000000
--- a/sahi/utils/yolov8.py
+++ /dev/null
@@ -1,166 +0,0 @@
-import urllib.request
-from os import path
-from pathlib import Path
-from typing import Optional
-
-
-class Yolov8TestConstants:
- YOLOV8N_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n.pt"
- YOLOV8N_MODEL_PATH = "tests/data/models/yolov8/yolov8n.pt"
-
- YOLOV8S_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s.pt"
- YOLOV8S_MODEL_PATH = "tests/data/models/yolov8/yolov8s.pt"
-
- YOLOV8M_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m.pt"
- YOLOV8M_MODEL_PATH = "tests/data/models/yolov8/yolov8m.pt"
-
- YOLOV8L_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l.pt"
- YOLOV8L_MODEL_PATH = "tests/data/models/yolov8/yolov8l.pt"
-
- YOLOV8X_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x.pt"
- YOLOV8X_MODEL_PATH = "tests/data/models/yolov8/yolov8x.pt"
-
- YOLOV8N_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8n-seg.pt"
- YOLOV8N_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8n-seg.pt"
-
- YOLOV8S_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8s-seg.pt"
- YOLOV8S_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8s-seg.pt"
-
- YOLOV8M_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8m-seg.pt"
- YOLOV8M_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8m-seg.pt"
-
- YOLOV8L_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8l-seg.pt"
- YOLOV8L_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8l-seg.pt"
-
- YOLOV8X_SEG_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/yolov8x-seg.pt"
- YOLOV8X_SEG_MODEL_PATH = "tests/data/models/yolov8/yolov8x-seg.pt"
-
-
-def download_yolov8n_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8N_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8N_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8s_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8S_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8S_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8m_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8M_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8M_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8l_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8L_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8L_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8x_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8X_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8X_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8n_seg_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8N_SEG_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8s_seg_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8S_SEG_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8S_SEG_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8m_seg_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8M_SEG_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8M_SEG_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8l_seg_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8L_SEG_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8L_SEG_MODEL_URL,
- destination_path,
- )
-
-
-def download_yolov8x_seg_model(destination_path: Optional[str] = None):
- if destination_path is None:
- destination_path = Yolov8TestConstants.YOLOV8X_SEG_MODEL_PATH
-
- Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
-
- if not path.exists(destination_path):
- urllib.request.urlretrieve(
- Yolov8TestConstants.YOLOV8X_SEG_MODEL_URL,
- destination_path,
- )
diff --git a/sahi/utils/yolov8onnx.py b/sahi/utils/yolov8onnx.py
index ec137470e..9b20da403 100644
--- a/sahi/utils/yolov8onnx.py
+++ b/sahi/utils/yolov8onnx.py
@@ -3,7 +3,7 @@
import numpy as np
-from sahi.utils.yolov8 import download_yolov8n_model
+from sahi.utils.ultralytics import download_yolov8n_model
class Yolov8ONNXTestConstants:
diff --git a/scripts/utils.py b/scripts/utils.py
index 5b2576736..a0b70c300 100644
--- a/scripts/utils.py
+++ b/scripts/utils.py
@@ -2,6 +2,8 @@
import shutil
import sys
+import click
+
def shell(command, exit_status=0):
"""
@@ -30,9 +32,15 @@ def validate_and_exit(expected_out_status=0, **kwargs):
fail_count = 0
for component, exit_status in kwargs.items():
if exit_status != expected_out_status:
- print(f"{component} failed.")
+ click.secho(f"{component} failed.", fg="red")
fail_count += 1
+
print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
+ click.secho("\nTo fix formatting issues:", fg="yellow")
+ click.secho("1. Install development dependencies:", fg="cyan")
+ click.secho(' pip install -e ."[dev]"', fg="green")
+ click.secho("\n2. Run code formatting:", fg="cyan")
+ click.secho(" python -m scripts.run_code_style format", fg="green")
sys.exit(1)
diff --git a/tests/test_predict.py b/tests/test_predict.py
index eb108039a..6c6f8f164 100644
--- a/tests/test_predict.py
+++ b/tests/test_predict.py
@@ -8,6 +8,7 @@
import numpy as np
from sahi.utils.cv import read_image
+from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model
MODEL_DEVICE = "cpu"
CONFIDENCE_THRESHOLD = 0.5
@@ -294,9 +295,118 @@ def test_get_sliced_prediction_yolov5(self):
num_car += 1
self.assertEqual(num_car, 11)
+ def test_get_prediction_yolo11(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+ from sahi.predict import get_prediction
+
+ # init model
+ download_yolo11n_model()
+
+ yolo11_detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=False,
+ image_size=IMAGE_SIZE,
+ )
+ yolo11_detection_model.load_model()
+
+ # prepare image
+ image_path = "tests/data/small-vehicles1.jpeg"
+ image = read_image(image_path)
+
+ # get full sized prediction
+ prediction_result = get_prediction(
+ image=image, detection_model=yolo11_detection_model, shift_amount=[0, 0], full_shape=None, postprocess=None
+ )
+ object_prediction_list = prediction_result.object_prediction_list
+
+ # compare
+ self.assertGreater(len(object_prediction_list), 0)
+ num_person = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "person":
+ num_person += 1
+ self.assertEqual(num_person, 0)
+ num_truck = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "truck":
+ num_truck += 1
+ self.assertEqual(num_truck, 0)
+ num_car = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "car":
+ num_car += 1
+ self.assertGreater(num_car, 0)
+
+ def test_get_sliced_prediction_yolo11(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+ from sahi.predict import get_sliced_prediction
+
+ # init model
+ download_yolo11n_model()
+
+ yolo11_detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=False,
+ image_size=IMAGE_SIZE,
+ )
+ yolo11_detection_model.load_model()
+
+ # prepare image
+ image_path = "tests/data/small-vehicles1.jpeg"
+
+ slice_height = 512
+ slice_width = 512
+ overlap_height_ratio = 0.1
+ overlap_width_ratio = 0.2
+ postprocess_type = "GREEDYNMM"
+ match_metric = "IOS"
+ match_threshold = 0.5
+ class_agnostic = True
+
+ # get sliced prediction
+ prediction_result = get_sliced_prediction(
+ image=image_path,
+ detection_model=yolo11_detection_model,
+ slice_height=slice_height,
+ slice_width=slice_width,
+ overlap_height_ratio=overlap_height_ratio,
+ overlap_width_ratio=overlap_width_ratio,
+ perform_standard_pred=False,
+ postprocess_type=postprocess_type,
+ postprocess_match_threshold=match_threshold,
+ postprocess_match_metric=match_metric,
+ postprocess_class_agnostic=class_agnostic,
+ )
+ object_prediction_list = prediction_result.object_prediction_list
+
+ # compare
+ self.assertGreater(len(object_prediction_list), 0)
+ num_person = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "person":
+ num_person += 1
+ self.assertEqual(num_person, 0)
+ num_truck = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "truck":
+ num_truck += 1
+ self.assertEqual(num_truck, 0)
+ num_car = 0
+ for object_prediction in object_prediction_list:
+ if object_prediction.category.name == "car":
+ num_car += 1
+ self.assertGreater(num_car, 0)
+
def test_coco_json_prediction(self):
from sahi.predict import predict
from sahi.utils.mmdet import MmdetTestConstants, download_mmdet_yolox_tiny_model
+ from sahi.utils.ultralytics import UltralyticsTestConstants, download_yolo11n_model
from sahi.utils.yolov5 import Yolov5TestConstants, download_yolov5n_model
# init model
@@ -382,6 +492,45 @@ def test_coco_json_prediction(self):
verbose=1,
)
+ # init model
+ download_yolo11n_model()
+
+ # prepare paths
+ dataset_json_path = "tests/data/coco_utils/terrain_all_coco.json"
+ source = "tests/data/coco_utils/"
+ project_dir = "tests/data/predict_result"
+
+ # get sliced prediction
+ if os.path.isdir(project_dir):
+ shutil.rmtree(project_dir, ignore_errors=True)
+ predict(
+ model_type="ultralytics",
+ model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
+ model_config_path=None,
+ model_confidence_threshold=CONFIDENCE_THRESHOLD,
+ model_device=MODEL_DEVICE,
+ model_category_mapping=None,
+ model_category_remapping=None,
+ source=source,
+ no_sliced_prediction=False,
+ no_standard_prediction=True,
+ slice_height=512,
+ slice_width=512,
+ overlap_height_ratio=0.2,
+ overlap_width_ratio=0.2,
+ postprocess_type=postprocess_type,
+ postprocess_match_metric=match_metric,
+ postprocess_match_threshold=match_threshold,
+ postprocess_class_agnostic=class_agnostic,
+ novisual=True,
+ export_pickle=False,
+ export_crop=False,
+ dataset_json_path=dataset_json_path,
+ project=project_dir,
+ name="exp",
+ verbose=1,
+ )
+
def test_video_prediction(self):
from os import path
diff --git a/tests/test_ultralyticsmodel.py b/tests/test_ultralyticsmodel.py
new file mode 100644
index 000000000..cdb626d25
--- /dev/null
+++ b/tests/test_ultralyticsmodel.py
@@ -0,0 +1,221 @@
+# OBSS SAHI Tool
+# Code written by AnNT, 2024.
+
+import unittest
+
+from sahi.utils.cv import read_image
+from sahi.utils.file import download_from_url
+from sahi.utils.ultralytics import (
+ UltralyticsTestConstants,
+ download_yolo11n_model,
+ download_yolo11n_obb_model,
+ download_yolo11n_seg_model,
+ download_yolov8n_model,
+ download_yolov8n_seg_model,
+)
+
+MODEL_DEVICE = "cpu"
+CONFIDENCE_THRESHOLD = 0.3
+IMAGE_SIZE = 640
+
+
+class TestUltralyticsDetectionModel(unittest.TestCase):
+ def test_load_yolov8_model(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ download_yolov8n_model()
+
+ detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLOV8N_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ )
+
+ self.assertNotEqual(detection_model.model, None)
+ self.assertTrue(hasattr(detection_model.model, "task"))
+ self.assertEqual(detection_model.model.task, "detect")
+
+ def test_load_yolo11_model(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ download_yolo11n_model()
+
+ detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLO11N_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ )
+
+ self.assertNotEqual(detection_model.model, None)
+ self.assertTrue(hasattr(detection_model.model, "task"))
+ self.assertEqual(detection_model.model.task, "detect")
+
+ def test_perform_inference_yolov8(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ # init model
+ download_yolov8n_model()
+
+ detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLOV8N_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ image_size=IMAGE_SIZE,
+ )
+
+ # prepare image
+ image_path = "tests/data/small-vehicles1.jpeg"
+ image = read_image(image_path)
+
+ # perform inference
+ detection_model.perform_inference(image)
+ original_predictions = detection_model.original_predictions
+
+ boxes = original_predictions[0].data
+
+ # find box of first car detection with conf greater than 0.5
+ for box in boxes:
+ if box[5].item() == 2: # if category car
+ if box[4].item() > 0.5:
+ break
+
+ # compare
+ desired_bbox = [448, 309, 497, 342]
+ predicted_bbox = list(map(int, box[:4].tolist()))
+ margin = 2
+ for ind, point in enumerate(predicted_bbox):
+ assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
+ self.assertEqual(len(detection_model.category_names), 80)
+ for box in boxes:
+ self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)
+
+ def test_perform_inference_yolo11(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ # init model
+ detection_model = UltralyticsDetectionModel(
+ model_path="yolo11n.pt",
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ image_size=IMAGE_SIZE,
+ )
+
+ # prepare image
+ image_path = "tests/data/small-vehicles1.jpeg"
+ image = read_image(image_path)
+
+ # perform inference
+ detection_model.perform_inference(image)
+ original_predictions = detection_model.original_predictions
+
+ boxes = original_predictions[0].data
+
+ # verify predictions
+ self.assertEqual(len(detection_model.category_names), 80)
+ for box in boxes:
+ self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)
+
+ def test_yolo11_segmentation(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ # init model
+ download_yolo11n_seg_model()
+
+ detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLO11N_SEG_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ image_size=IMAGE_SIZE,
+ )
+
+ # Verify model properties
+ self.assertTrue(detection_model.has_mask)
+ self.assertEqual(detection_model.model.task, "segment")
+
+ # prepare image and run inference
+ image_path = "tests/data/small-vehicles1.jpeg"
+ image = read_image(image_path)
+ detection_model.perform_inference(image)
+
+ # Verify segmentation output
+ original_predictions = detection_model.original_predictions
+ boxes = original_predictions[0][0] # Boxes
+ masks = original_predictions[0][1] # Masks
+
+ self.assertGreater(len(boxes), 0)
+ self.assertEqual(masks.shape[0], len(boxes)) # One mask per box
+ self.assertEqual(len(masks.shape), 3) # (num_predictions, height, width)
+
+ def test_yolo11_obb(self):
+ from sahi.models.ultralytics import UltralyticsDetectionModel
+
+ # init model
+ download_yolo11n_obb_model()
+
+ detection_model = UltralyticsDetectionModel(
+ model_path=UltralyticsTestConstants.YOLO11N_OBB_MODEL_PATH,
+ confidence_threshold=CONFIDENCE_THRESHOLD,
+ device=MODEL_DEVICE,
+ category_remapping=None,
+ load_at_init=True,
+ image_size=640,
+ )
+
+ # Verify model task
+ self.assertTrue(detection_model.is_obb)
+ self.assertEqual(detection_model.model.task, "obb")
+
+ # prepare image and run inference
+ image_url = "https://ultralytics.com/images/boats.jpg"
+ image_path = "tests/data/boats.jpg"
+ download_from_url(image_url, to_path=image_path)
+ image = read_image(image_path)
+ detection_model.perform_inference(image)
+
+ # Verify OBB predictions
+ original_predictions = detection_model.original_predictions
+ boxes = original_predictions[0][0] # Original box data
+ obb_points = original_predictions[0][1] # OBB points in xyxyxyxy format
+
+ self.assertGreater(len(boxes), 0)
+ # Check box format: x1,y1,x2,y2,conf,cls
+ self.assertEqual(boxes.shape[1], 6)
+ # Check OBB points format
+ self.assertEqual(obb_points.shape[1:], (4, 2)) # (N, 4, 2) format
+
+ # Convert predictions and verify
+ detection_model.convert_original_predictions()
+ object_prediction_list = detection_model.object_prediction_list
+
+ # Verify converted predictions
+ self.assertEqual(len(object_prediction_list), len(boxes))
+ for object_prediction in object_prediction_list:
+ # Verify confidence threshold
+ self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)
+
+ coco_segmentation = object_prediction.mask.segmentation
+ # Verify segmentation exists (converted from OBB)
+ self.assertIsNotNone(coco_segmentation)
+ # Verify segmentation is a list of points
+ self.assertTrue(isinstance(coco_segmentation, list))
+ self.assertGreater(len(coco_segmentation), 0)
+ # Verify each segment is a valid closed polygon
+ for segment in coco_segmentation:
+ self.assertEqual(len(segment), 10) # 4 points + 1 closing point (x,y coordinates)
+ # Verify polygon is closed (first point equals last point)
+ self.assertEqual(segment[0], segment[-2]) # x coordinate
+ self.assertEqual(segment[1], segment[-1]) # y coordinate
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tests/test_yolov8model.py b/tests/test_yolov8model.py
deleted file mode 100644
index f578113d7..000000000
--- a/tests/test_yolov8model.py
+++ /dev/null
@@ -1,236 +0,0 @@
-# OBSS SAHI Tool
-# Code written by Fatih C Akyon, 2020.
-
-import unittest
-
-import numpy as np
-
-from sahi.utils.cv import read_image
-from sahi.utils.yolov8 import Yolov8TestConstants, download_yolov8n_model, download_yolov8n_seg_model
-
-MODEL_DEVICE = "cpu"
-CONFIDENCE_THRESHOLD = 0.3
-IMAGE_SIZE = 640
-
-
-class TestYolov8DetectionModel(unittest.TestCase):
- def test_load_model(self):
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- download_yolov8n_model()
-
- yolov8_detection_model = Yolov8DetectionModel(
- model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- )
-
- self.assertNotEqual(yolov8_detection_model.model, None)
-
- def test_set_model(self):
- from ultralytics import YOLO
-
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- download_yolov8n_model()
-
- yolo_model = YOLO(Yolov8TestConstants.YOLOV8N_MODEL_PATH)
-
- yolov8_detection_model = Yolov8DetectionModel(
- model=yolo_model,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- )
-
- self.assertNotEqual(yolov8_detection_model.model, None)
-
- def test_perform_inference(self):
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- # init model
- download_yolov8n_model()
-
- yolov8_detection_model = Yolov8DetectionModel(
- model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- image_size=IMAGE_SIZE,
- )
-
- # prepare image
- image_path = "tests/data/small-vehicles1.jpeg"
- image = read_image(image_path)
-
- # perform inference
- yolov8_detection_model.perform_inference(image)
- original_predictions = yolov8_detection_model.original_predictions
-
- boxes = original_predictions[0].data
-
- # find box of first car detection with conf greater than 0.5
- for box in boxes:
- if box[5].item() == 2: # if category car
- if box[4].item() > 0.5:
- break
-
- # compare
- desired_bbox = [448, 309, 497, 342]
- predicted_bbox = list(map(int, box[:4].tolist()))
- margin = 2
- for ind, point in enumerate(predicted_bbox):
- assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
- self.assertEqual(len(yolov8_detection_model.category_names), 80)
- for box in boxes:
- self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)
-
- def test_convert_original_predictions(self):
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- # init model
- download_yolov8n_model()
-
- yolov8_detection_model = Yolov8DetectionModel(
- model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- image_size=IMAGE_SIZE,
- )
-
- # prepare image
- image_path = "tests/data/small-vehicles1.jpeg"
- image = read_image(image_path)
-
- # get raw predictions for reference
- original_results = yolov8_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes
- num_results = len(original_results)
-
- # perform inference
- yolov8_detection_model.perform_inference(image)
-
- # convert predictions to ObjectPrediction list
- yolov8_detection_model.convert_original_predictions()
- object_prediction_list = yolov8_detection_model.object_prediction_list
-
- # compare
- self.assertEqual(len(object_prediction_list), num_results)
-
- # loop through predictions and check that they are equal
- for i in range(num_results):
- desired_bbox = [
- original_results[i].xyxy[0][0],
- original_results[i].xyxy[0][1],
- original_results[i].xywh[0][2],
- original_results[i].xywh[0][3],
- ]
- desired_cat_id = int(original_results[i].cls[0])
- self.assertEqual(object_prediction_list[i].category.id, desired_cat_id)
- predicted_bbox = object_prediction_list[i].bbox.to_xywh()
- margin = 2
- for ind, point in enumerate(predicted_bbox):
- assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
- for object_prediction in object_prediction_list:
- self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)
-
- def test_perform_inference_with_mask_output(self):
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- # init model
- download_yolov8n_seg_model()
-
- yolov8_detection_model = Yolov8DetectionModel(
- model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- image_size=IMAGE_SIZE,
- )
- # prepare image
- image_path = "tests/data/small-vehicles1.jpeg"
- image = read_image(image_path)
-
- # perform inference
- yolov8_detection_model.perform_inference(image)
- original_predictions = yolov8_detection_model.original_predictions
- boxes = original_predictions[0][0]
- masks = original_predictions[0][1]
-
- # find box of first car detection with conf greater than 0.5
- for box in boxes:
- if box[5].item() == 2: # if category car
- if box[4].item() > 0.5:
- break
-
- # compare
- desired_bbox = [320, 323, 380, 365]
- predicted_bbox = list(map(int, box[:4].tolist()))
- margin = 3
- for ind, point in enumerate(predicted_bbox):
- assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
- self.assertEqual(len(yolov8_detection_model.category_names), 80)
- for box in boxes:
- self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)
- self.assertEqual(masks.shape, (12, 352, 640))
- self.assertEqual(masks.shape[0], len(boxes))
-
- def test_convert_original_predictions_with_mask_output(self):
- from sahi.models.yolov8 import Yolov8DetectionModel
-
- # init model
- download_yolov8n_seg_model()
-
- yolov8_detection_model = Yolov8DetectionModel(
- model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH,
- confidence_threshold=CONFIDENCE_THRESHOLD,
- device=MODEL_DEVICE,
- category_remapping=None,
- load_at_init=True,
- image_size=IMAGE_SIZE,
- )
-
- # prepare image
- image_path = "tests/data/small-vehicles1.jpeg"
- image = read_image(image_path)
-
- # get raw predictions for reference
- original_results = yolov8_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes
- num_results = len(original_results)
-
- # perform inference
- yolov8_detection_model.perform_inference(image)
-
- # convert predictions to ObjectPrediction list
- yolov8_detection_model.convert_original_predictions(full_shape=(image.shape[0], image.shape[1]))
- object_prediction_list = yolov8_detection_model.object_prediction_list
-
- # compare
- self.assertEqual(len(object_prediction_list), num_results)
-
- # loop through predictions and check that they are equal
- for i in range(num_results):
- desired_bbox = [
- original_results[i].xyxy[0][0],
- original_results[i].xyxy[0][1],
- original_results[i].xywh[0][2],
- original_results[i].xywh[0][3],
- ]
- desired_cat_id = int(original_results[i].cls[0])
- self.assertEqual(object_prediction_list[i].category.id, desired_cat_id)
- predicted_bbox = object_prediction_list[i].bbox.to_xywh()
- margin = 20 # Margin high because for some reason some original predictions are really poor
- for ind, point in enumerate(predicted_bbox):
- assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
- for object_prediction in object_prediction_list:
- self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)
-
-
-if __name__ == "__main__":
- unittest.main()