Skip to content

Commit

Permalink
Merge branch 'obss:main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
gguzzy authored Dec 21, 2024
2 parents 69e4be5 + c9469f5 commit b508bce
Show file tree
Hide file tree
Showing 22 changed files with 646 additions and 490 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -106,11 +106,11 @@ jobs:
run: >
pip install pycocotools==2.0.7
- name: Install ultralytics(8.0.207)
- name: Install ultralytics(8.3.50)
run: >
pip install ultralytics==8.0.207
pip install ultralytics==8.3.50
- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
- name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms
run: |
python -m unittest
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/package_testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ jobs:
run: >
pip install pycocotools==2.0.7
- name: Install ultralytics(8.0.207)
- name: Install ultralytics(8.3.50)
run: >
pip install ultralytics==8.0.207
pip install ultralytics==8.3.50
- name: Install latest SAHI package
run: >
pip install --upgrade --force-reinstall sahi
- name: Unittest for SAHI+YOLOV5/MMDET/Detectron2 on all platforms
- name: Unittest for SAHI+YOLO11/RTDETR/MMDET/HuggingFace/Torchvision on all platforms
run: |
python -m unittest
Expand Down
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,13 @@ Object detection and instance segmentation are by far the most important applica

- [Slicing operation notebook](demo/slicing.ipynb)

- `YOLOX` + `SAHI` demo: <a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img src="https://raw.githubusercontent.com/obss/sahi/main/resources/hf_spaces_badge.svg" alt="sahi-yolox"></a> (RECOMMENDED)
- `YOLOX` + `SAHI` demo: <a href="https://huggingface.co/spaces/fcakyon/sahi-yolox"><img src="https://raw.githubusercontent.com/obss/sahi/main/resources/hf_spaces_badge.svg" alt="sahi-yolox"></a>

- `YOLO11` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolov8"></a> (NEW)

- `RT-DETR` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_rtdetr.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-rtdetr"></a> (NEW)

- `YOLOv8` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_yolov8.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolov8"></a>
- `YOLOv8` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_ultralytics.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-yolov8"></a>

- `DeepSparse` + `SAHI` walkthrough: <a href="https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_sparse_yolov5.ipynb"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="sahi-deepsparse"></a>

Expand Down Expand Up @@ -141,7 +143,7 @@ pip install yolov5==7.0.13
- Install your desired detection framework (ultralytics):

```console
pip install ultralytics==8.0.207
pip install ultralytics==8.3.50
```

- Install your desired detection framework (mmdet):
Expand Down Expand Up @@ -228,9 +230,9 @@ If you use this package in your work, please cite it as:

## <div align="center">Contributing</div>

`sahi` library currently supports all [YOLOv5 models](https://github.com/ultralytics/yolov5/releases), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks.
`sahi` library currently supports all [Ultralytics (YOLOv8/v10/v11/RTDETR) models](https://github.com/ultralytics/ultralytics), [MMDetection models](https://github.com/open-mmlab/mmdetection/blob/master/docs/en/model_zoo.md), [Detectron2 models](https://github.com/facebookresearch/detectron2/blob/main/MODEL_ZOO.md), and [HuggingFace object detection models](https://huggingface.co/models?pipeline_tag=object-detection&sort=downloads). Moreover, it is easy to add new frameworks.

All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/mmdet.py#L18) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference.
All you need to do is, create a new .py file under [sahi/models/](https://github.com/obss/sahi/tree/main/sahi/models) folder and create a new class in that .py file that implements [DetectionModel class](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/base.py#L12). You can take the [MMDetection wrapper](https://github.com/obss/sahi/blob/aaeb57c39780a5a32c4de2848e54df9a874df58b/sahi/models/mmdet.py#L91) or [YOLOv5 wrapper](https://github.com/obss/sahi/blob/7e48bdb6afda26f977b763abdd7d8c9c170636bd/sahi/models/yolov5.py#L17) as a reference.

Before opening a PR:

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@
"outputs": [],
"source": [
"# arrange an instance segmentation model for test\n",
"from sahi.utils.yolov8 import (\n",
" download_yolov8s_model, download_yolov8s_seg_model\n",
"from sahi.utils.ultralytics import (\n",
" download_yolo11n_model, download_yolo11n_seg_model,\n",
" # download_yolov8n_model, download_yolov8n_seg_model\n",
")\n",
"\n",
"from sahi import AutoDetectionModel\n",
Expand All @@ -80,9 +81,10 @@
"metadata": {},
"outputs": [],
"source": [
"# download YOLOV5S6 model to 'models/yolov5s6.pt'\n",
"yolov8_model_path = \"models/yolov8s.pt\"\n",
"download_yolov8s_model(yolov8_model_path)\n",
"yolo11n_model_path = \"models/yolov11n.pt\"\n",
"download_yolo11n_model(yolo11n_model_path)\n",
"# yolov8n_model_path = \"models/yolov8n.pt\"\n",
"# download_yolov8n_model(yolov8n_model_path)\n",
"\n",
"# download test images into demo_data folder\n",
"download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg', 'demo_data/small-vehicles1.jpeg')\n",
Expand All @@ -94,7 +96,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Standard Inference with a YOLOv8 Model"
"## 1. Standard Inference with a YOLOv8/YOLO11 Model"
]
},
{
Expand All @@ -111,8 +113,8 @@
"outputs": [],
"source": [
"detection_model = AutoDetectionModel.from_pretrained(\n",
" model_type='yolov8',\n",
" model_path=yolov8_model_path,\n",
" model_type='yolo11', # or 'yolov8'\n",
" model_path=yolo11n_model_path,\n",
" confidence_threshold=0.3,\n",
" device=\"cpu\", # or 'cuda:0'\n",
")"
Expand Down Expand Up @@ -185,7 +187,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Sliced Inference with a YOLOv8 Model"
"## 2. Sliced Inference with a YOLOv8/YOLO11 Model"
]
},
{
Expand Down Expand Up @@ -466,8 +468,8 @@
"metadata": {},
"outputs": [],
"source": [
"model_type = \"yolov8\"\n",
"model_path = yolov8_model_path\n",
"model_type = \"yolo11\"\n",
"model_path = yolo11n_model_path\n",
"model_device = \"cpu\" # or 'cuda:0'\n",
"model_confidence_threshold = 0.4\n",
"\n",
Expand Down Expand Up @@ -599,9 +601,10 @@
"metadata": {},
"outputs": [],
"source": [
"#download YOLOV8S model to 'models/yolov8s.pt'\n",
"yolov8_seg_model_path = \"models/yolov8s-seg.pt\"\n",
"download_yolov8s_seg_model(yolov8_seg_model_path)"
"yolo11n_seg_model_path = \"models/yolov11n-seg.pt\"\n",
"download_yolo11n_seg_model(yolo11n_seg_model_path)\n",
"# yolov8n_seg_model_path = \"models/yolov8n-seg.pt\"\n",
"# download_yolov8n_seg_model(yolov8n_seg_model_path)\n"
]
},
{
Expand All @@ -611,8 +614,8 @@
"outputs": [],
"source": [
"detection_model_seg = AutoDetectionModel.from_pretrained(\n",
" model_type='yolov8',\n",
" model_path=yolov8_seg_model_path,\n",
" model_type='yolo11', # or 'yolov8'\n",
" model_path=yolo11n_seg_model_path,\n",
" confidence_threshold=0.3,\n",
" device=\"cpu\", # or 'cuda:0'\n",
")"
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,3 @@ fire
terminaltables
requests
click
numpy<2.0.0
2 changes: 1 addition & 1 deletion sahi/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.11.19"
__version__ = "0.11.20"

from sahi.annotation import BoundingBox, Category, Mask
from sahi.auto_model import AutoDetectionModel
Expand Down
45 changes: 21 additions & 24 deletions sahi/annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ class Mask:
@classmethod
def from_float_mask(
cls,
mask,
full_shape=None,
mask: np.ndarray,
full_shape: List[int],
mask_threshold: float = 0.5,
shift_amount: list = [0, 0],
):
Expand All @@ -144,7 +144,7 @@ def from_float_mask(
shift_amount: List
To shift the box and mask predictions from sliced image
to full sized image, should be in the form of [shift_x, shift_y]
full_shape: List
full_shape: List[int]
Size of the full image after shifting, should be in the form of [height, width]
"""
bool_mask = mask > mask_threshold
Expand All @@ -156,8 +156,8 @@ def from_float_mask(

def __init__(
self,
segmentation,
full_shape=None,
segmentation: List[List[float]],
full_shape: List[int],
shift_amount: list = [0, 0],
):
"""
Expand All @@ -170,9 +170,9 @@ def __init__(
[x1, y1, x2, y2, x3, y3, ...],
...
]
full_shape: List
full_shape: List[int]
Size of the full image, should be in the form of [height, width]
shift_amount: List
shift_amount: List[int]
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
Expand All @@ -195,17 +195,17 @@ def __init__(
@classmethod
def from_bool_mask(
cls,
bool_mask=None,
full_shape=None,
bool_mask: np.ndarray,
full_shape: List[int],
shift_amount: list = [0, 0],
):
"""
Args:
bool_mask: np.ndarray with bool elements
2D mask of object, should have a shape of height*width
full_shape: List
full_shape: List[int]
Size of the full image, should be in the form of [height, width]
shift_amount: List
shift_amount: List[int]
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
Expand Down Expand Up @@ -420,7 +420,7 @@ def from_coco_annotation_dict(
@classmethod
def from_shapely_annotation(
cls,
annotation,
annotation: ShapelyAnnotation,
full_shape: List[int],
category_id: Optional[int] = None,
category_name: Optional[str] = None,
Expand All @@ -441,12 +441,9 @@ def from_shapely_annotation(
To shift the box and mask predictions from sliced image to full
sized image, should be in the form of [shift_x, shift_y]
"""
bool_mask = get_bool_mask_from_coco_segmentation(
annotation.to_coco_segmentation(), width=full_shape[1], height=full_shape[0]
)
return cls(
category_id=category_id,
bool_mask=bool_mask,
segmentation=annotation.to_coco_segmentation(),
category_name=category_name,
shift_amount=shift_amount,
full_shape=full_shape,
Expand Down Expand Up @@ -512,6 +509,8 @@ def __init__(
raise ValueError("category_id must be an integer")
if (bbox is None) and (segmentation is None):
raise ValueError("you must provide a bbox or segmentation")

self.mask: Mask | None = None
if segmentation is not None:
self.mask = Mask(
segmentation=segmentation,
Expand All @@ -524,8 +523,6 @@ def __init__(
bbox = bbox_from_segmentation
else:
raise ValueError("Invalid segmentation mask.")
else:
self.mask = None

# if bbox is a numpy object, convert it to python List[float]
if type(bbox).__module__ == "numpy":
Expand All @@ -552,13 +549,13 @@ def __init__(

self.merged = None

def to_coco_annotation(self):
def to_coco_annotation(self) -> CocoAnnotation:
"""
Returns sahi.utils.coco.CocoAnnotation representation of ObjectAnnotation.
"""
if self.mask:
coco_annotation = CocoAnnotation.from_coco_segmentation(
segmentation=self.mask.segmentation(),
segmentation=self.mask.segmentation,
category_id=self.category.id,
category_name=self.category.name,
)
Expand All @@ -570,13 +567,13 @@ def to_coco_annotation(self):
)
return coco_annotation

def to_coco_prediction(self):
def to_coco_prediction(self) -> CocoPrediction:
"""
Returns sahi.utils.coco.CocoPrediction representation of ObjectAnnotation.
"""
if self.mask:
coco_prediction = CocoPrediction.from_coco_segmentation(
segmentation=self.mask.segmentation(),
segmentation=self.mask.segmentation,
category_id=self.category.id,
category_name=self.category.name,
score=1,
Expand All @@ -590,13 +587,13 @@ def to_coco_prediction(self):
)
return coco_prediction

def to_shapely_annotation(self):
def to_shapely_annotation(self) -> ShapelyAnnotation:
"""
Returns sahi.utils.shapely.ShapelyAnnotation representation of ObjectAnnotation.
"""
if self.mask:
shapely_annotation = ShapelyAnnotation.from_coco_segmentation(
segmentation=self.mask.segmentation(),
segmentation=self.mask.segmentation,
)
else:
shapely_annotation = ShapelyAnnotation.from_coco_bbox(
Expand Down
7 changes: 5 additions & 2 deletions sahi/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sahi.utils.file import import_model_class

MODEL_TYPE_TO_MODEL_CLASS_NAME = {
"yolov8": "Yolov8DetectionModel",
"ultralytics": "UltralyticsDetectionModel",
"rtdetr": "RTDetrDetectionModel",
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
Expand All @@ -14,6 +14,8 @@
"yolov8onnx": "Yolov8OnnxDetectionModel",
}

ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]


class AutoDetectionModel:
@staticmethod
Expand Down Expand Up @@ -60,7 +62,8 @@ def from_pretrained(
Raises:
ImportError: If given {model_type} framework is not installed
"""

if model_type in ULTRALYTICS_MODEL_NAMES:
model_type = "ultralytics"
model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
DetectionModel = import_model_class(model_type, model_class_name)

Expand Down
2 changes: 1 addition & 1 deletion sahi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import base, detectron2, huggingface, mmdet, torchvision, yolov5, yolov8onnx
from . import base, detectron2, huggingface, mmdet, torchvision, ultralytics, yolov5, yolov8onnx
5 changes: 3 additions & 2 deletions sahi/models/base.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# OBSS SAHI Tool
# Code written by Fatih C Akyon, 2020.

from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional

import numpy as np

from sahi.prediction import ObjectPrediction
from sahi.utils.import_utils import is_available
from sahi.utils.torch import select_device as select_torch_device

Expand Down Expand Up @@ -173,7 +174,7 @@ def convert_original_predictions(
self._apply_category_remapping()

@property
def object_prediction_list(self):
def object_prediction_list(self) -> List[ObjectPrediction]:
return self._object_prediction_list_per_image[0]

@property
Expand Down
4 changes: 2 additions & 2 deletions sahi/models/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import logging

from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)


class RTDetrDetectionModel(Yolov8DetectionModel):
class RTDetrDetectionModel(UltralyticsDetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])

Expand Down
Loading

0 comments on commit b508bce

Please sign in to comment.