Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support yolo11 #1108

Closed
wants to merge 8 commits into from
7 changes: 5 additions & 2 deletions sahi/auto_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from sahi.utils.file import import_model_class

MODEL_TYPE_TO_MODEL_CLASS_NAME = {
"yolov8": "Yolov8DetectionModel",
"ultralytics": "UltralyticsDetectionModel",
"rtdetr": "RTDetrDetectionModel",
"mmdet": "MmdetDetectionModel",
"yolov5": "Yolov5DetectionModel",
Expand All @@ -14,6 +14,8 @@
"yolov8onnx": "Yolov8OnnxDetectionModel",
}

ULTRALYTICS_MODEL_NAMES = ["yolov8", "yolov11", "yolo11", "ultralytics"]


class AutoDetectionModel:
@staticmethod
Expand Down Expand Up @@ -60,7 +62,8 @@ def from_pretrained(
Raises:
ImportError: If given {model_type} framework is not installed
"""

if model_type in ULTRALYTICS_MODEL_NAMES:
model_type = "ultralytics"
model_class_name = MODEL_TYPE_TO_MODEL_CLASS_NAME[model_type]
DetectionModel = import_model_class(model_type, model_class_name)

Expand Down
2 changes: 1 addition & 1 deletion sahi/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from . import base, detectron2, huggingface, mmdet, torchvision, yolov5, yolov8onnx
from . import base, detectron2, huggingface, mmdet, torchvision, ultralytics, yolov5, yolov8onnx
4 changes: 2 additions & 2 deletions sahi/models/rtdetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@

import logging

from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel
from sahi.utils.import_utils import check_requirements

logger = logging.getLogger(__name__)


class RTDetrDetectionModel(Yolov8DetectionModel):
class RTDetrDetectionModel(UltralyticsDetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])

Expand Down
10 changes: 5 additions & 5 deletions sahi/models/yolov8.py → sahi/models/ultralytics.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Code written by AnNT, 2023.

import logging
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

import cv2
import numpy as np
Expand All @@ -17,7 +17,7 @@
from sahi.utils.import_utils import check_requirements


class Yolov8DetectionModel(DetectionModel):
class UltralyticsDetectionModel(DetectionModel):
def check_dependencies(self) -> None:
check_requirements(["ultralytics"])

Expand All @@ -33,14 +33,14 @@ def load_model(self):
model.to(self.device)
self.set_model(model)
except Exception as e:
raise TypeError("model_path is not a valid yolov8 model path: ", e)
raise TypeError("model_path is not a valid Ultralytics model path: ", e)

def set_model(self, model: Any):
"""
Sets the underlying YOLOv8 model.
Sets the underlying Ultralytics model.
Args:
model: Any
A YOLOv8 model
A Ultralytics model
"""

self.model = model
Expand Down
2 changes: 1 addition & 1 deletion sahi/prediction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from PIL import Image

from sahi.annotation import ObjectAnnotation
from sahi.utils.coco import CocoAnnotation, CocoPrediction
from sahi.utils.coco import CocoPrediction
from sahi.utils.cv import read_image_as_pil, visualize_object_predictions
from sahi.utils.file import Path

Expand Down
10 changes: 9 additions & 1 deletion scripts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import shutil
import sys

import click


def shell(command, exit_status=0):
"""
Expand Down Expand Up @@ -30,9 +32,15 @@ def validate_and_exit(expected_out_status=0, **kwargs):
fail_count = 0
for component, exit_status in kwargs.items():
if exit_status != expected_out_status:
print(f"{component} failed.")
click.secho(f"{component} failed.", fg="red")
fail_count += 1

print_console_centered(f"{len(kwargs)-fail_count} success, {fail_count} failure")
click.secho("\nTo fix formatting issues:", fg="yellow")
click.secho("1. Install development dependencies:", fg="cyan")
click.secho(' pip install -e ."[dev]"', fg="green")
click.secho("\n2. Run code formatting:", fg="cyan")
click.secho(" python -m scripts.run_code_style format", fg="green")
sys.exit(1)


Expand Down
26 changes: 12 additions & 14 deletions tests/test_yolov8model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@

import unittest

import numpy as np

from sahi.utils.cv import read_image
from sahi.utils.yolov8 import Yolov8TestConstants, download_yolov8n_model, download_yolov8n_seg_model

Expand All @@ -15,11 +13,11 @@

class TestYolov8DetectionModel(unittest.TestCase):
def test_load_model(self):
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand All @@ -32,13 +30,13 @@ def test_load_model(self):
def test_set_model(self):
from ultralytics import YOLO

from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

download_yolov8n_model()

yolo_model = YOLO(Yolov8TestConstants.YOLOV8N_MODEL_PATH)

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model=yolo_model,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand All @@ -49,12 +47,12 @@ def test_set_model(self):
self.assertNotEqual(yolov8_detection_model.model, None)

def test_perform_inference(self):
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

# init model
download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand Down Expand Up @@ -90,12 +88,12 @@ def test_perform_inference(self):
self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)

def test_convert_original_predictions(self):
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

# init model
download_yolov8n_model()

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand Down Expand Up @@ -140,12 +138,12 @@ def test_convert_original_predictions(self):
self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)

def test_perform_inference_with_mask_output(self):
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

# init model
download_yolov8n_seg_model()

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand Down Expand Up @@ -182,12 +180,12 @@ def test_perform_inference_with_mask_output(self):
self.assertEqual(masks.shape[0], len(boxes))

def test_convert_original_predictions_with_mask_output(self):
from sahi.models.yolov8 import Yolov8DetectionModel
from sahi.models.ultralytics import UltralyticsDetectionModel

# init model
download_yolov8n_seg_model()

yolov8_detection_model = Yolov8DetectionModel(
yolov8_detection_model = UltralyticsDetectionModel(
model_path=Yolov8TestConstants.YOLOV8N_SEG_MODEL_PATH,
confidence_threshold=CONFIDENCE_THRESHOLD,
device=MODEL_DEVICE,
Expand Down
Loading