diff --git a/demo/inference_for_rtdetr.ipynb b/demo/inference_for_rtdetr.ipynb
new file mode 100644
index 000000000..21d4e32bc
--- /dev/null
+++ b/demo/inference_for_rtdetr.ipynb
@@ -0,0 +1,409 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/obss/sahi/blob/main/demo/inference_for_rtdetr.ipynb)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 0. Preparation"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Install the latest versions of SAHI and ultralytics:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install -U torch sahi ultralytics"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "os.getcwd()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Import required modules:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# arrange a detection model for test\n",
+    "from sahi.utils.rtdetr import (\n",
+    "    download_rtdetrl_model\n",
+    ")\n",
+    "\n",
+    "from sahi import AutoDetectionModel\n",
+    "from sahi.utils.cv import read_image\n",
+    "from sahi.utils.file import download_from_url\n",
+    "from sahi.predict import get_prediction, get_sliced_prediction, predict\n",
+    "from IPython.display import Image"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Download an RT-DETR model and two test images:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# download RT-DETR-L model to 'models/rtdetr-l.pt'\n",
+    "rtdetr_model_path = \"models/rtdetr-l.pt\"\n",
+    "download_rtdetrl_model(rtdetr_model_path)\n",
+    "\n",
+    "# download test images into demo_data folder\n",
+    "download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/small-vehicles1.jpeg', 'demo_data/small-vehicles1.jpeg')\n",
+    "download_from_url('https://raw.githubusercontent.com/obss/sahi/main/demo/demo_data/terrain2.png', 'demo_data/terrain2.png')"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 1. Standard Inference with an RT-DETR Model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Instantiate a detection model by defining model weight path and other parameters:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "detection_model = AutoDetectionModel.from_pretrained(\n",
+    "    model_type='rtdetr',\n",
+    "    model_path=rtdetr_model_path,\n",
+    "    confidence_threshold=0.3,\n",
+    "    device=\"cpu\", # or 'cuda:0'\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Perform prediction by feeding the get_prediction function with an image path and a DetectionModel instance:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = get_prediction(\"demo_data/small-vehicles1.jpeg\", detection_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Or perform prediction by feeding the get_prediction function with a numpy image and a DetectionModel instance:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = get_prediction(read_image(\"demo_data/small-vehicles1.jpeg\"), detection_model)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Visualize predicted bounding boxes over the original image:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.export_visuals(export_dir=\"demo_data/\")\n",
+    "\n",
+    "Image(\"demo_data/prediction_visual.png\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. Sliced Inference with an RT-DETR Model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- To perform sliced prediction, we need to specify slice parameters. In this example, we will perform prediction over slices of 256x256 with an overlap ratio of 0.2:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result = get_sliced_prediction(\n",
+    "    \"demo_data/small-vehicles1.jpeg\",\n",
+    "    detection_model,\n",
+    "    slice_height = 256,\n",
+    "    slice_width = 256,\n",
+    "    overlap_height_ratio = 0.2,\n",
+    "    overlap_width_ratio = 0.2\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Visualize predicted bounding boxes over the original image:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.export_visuals(export_dir=\"demo_data/\")\n",
+    "\n",
+    "Image(\"demo_data/prediction_visual.png\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 3. Prediction Result"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Predictions are returned as [sahi.prediction.PredictionResult](sahi/prediction.py); you can access the object prediction list as:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "object_prediction_list = result.object_prediction_list"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "object_prediction_list[0]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- ObjectPredictions can be converted to [COCO annotation](https://cocodataset.org/#format-data) format:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.to_coco_annotations()[:3]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- ObjectPredictions can be converted to [COCO prediction](https://github.com/i008/COCO-dataset-explorer) format:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.to_coco_predictions(image_id=1)[:3]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- ObjectPredictions can be converted to [imantics](https://github.com/jsbroks/imantics) annotation format:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.to_imantics_annotations()[:3]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- ObjectPredictions can be converted to [fiftyone](https://github.com/voxel51/fiftyone) detection format:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "result.to_fiftyone_detections()[:3]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 4. Batch Prediction"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Set model and directory parameters:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "model_type = \"rtdetr\"\n",
+    "model_path = rtdetr_model_path\n",
+    "model_device = \"cpu\" # or 'cuda:0'\n",
+    "model_confidence_threshold = 0.4\n",
+    "\n",
+    "slice_height = 256\n",
+    "slice_width = 256\n",
+    "overlap_height_ratio = 0.2\n",
+    "overlap_width_ratio = 0.2\n",
+    "\n",
+    "source_image_dir = \"demo_data/\""
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "- Perform sliced inference on the given folder:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "predict(\n",
+    "    model_type=model_type,\n",
+    "    model_path=model_path,\n",
+    "    model_device=model_device,\n",
+    "    model_confidence_threshold=model_confidence_threshold,\n",
+    "    source=source_image_dir,\n",
+    "    slice_height=slice_height,\n",
+    "    slice_width=slice_width,\n",
+    "    overlap_height_ratio=overlap_height_ratio,\n",
+    "    overlap_width_ratio=overlap_width_ratio,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "test",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.9"
+  },
+  "vscode": {
+   "interpreter": {
+    "hash": "244b47d5824a96a4079632e50977464d968e13d2c337f65c905f8da81a0b4f95"
+   }
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/sahi/auto_model.py b/sahi/auto_model.py
index 32cbea4d6..49dea644f 100644
--- a/sahi/auto_model.py
+++ b/sahi/auto_model.py
@@ -5,6 +5,7 @@
 MODEL_TYPE_TO_MODEL_CLASS_NAME = {
     "yolov8openvino": "Yolov8OpenvinoDetectionModel",
     "yolov8": "Yolov8DetectionModel",
+    "rtdetr": "RTDetrDetectionModel",
     "mmdet": "MmdetDetectionModel",
     "yolov5": "Yolov5DetectionModel",
     "detectron2": "Detectron2DetectionModel",
diff --git a/sahi/models/rtdetr.py b/sahi/models/rtdetr.py
new file mode 100644
index 000000000..d704eee59
--- /dev/null
+++ b/sahi/models/rtdetr.py
@@ -0,0 +1,29 @@
+# OBSS SAHI Tool
+# Code written by AnNT, 2023.
+
+import logging
+
+from sahi.models.yolov8 import Yolov8DetectionModel
+from sahi.utils.import_utils import check_requirements
+
+logger = logging.getLogger(__name__)
+
+
+class RTDetrDetectionModel(Yolov8DetectionModel):
+    def check_dependencies(self) -> None:
+        check_requirements(["ultralytics"])
+
+    def load_model(self):
+        """
+        Detection model is initialized and set to self.model.
+        """
+
+        from ultralytics import RTDETR
+
+        try:
+            model = RTDETR(self.model_path)
+            model.to(self.device)
+
+            self.set_model(model)
+        except Exception as e:
+            raise TypeError("model_path is not a valid rtdetr model path: ", e)
diff --git a/sahi/models/yolov8.py b/sahi/models/yolov8.py
index 3ae40bda7..8b1509286 100644
--- a/sahi/models/yolov8.py
+++ b/sahi/models/yolov8.py
@@ -58,6 +58,7 @@ def perform_inference(self, image: np.ndarray):
         # Confirm model is loaded
         if self.model is None:
             raise ValueError("Model is not loaded, load it by calling .load_model()")
+
         if self.image_size is not None:  # ADDED IMAGE SIZE OPTION FOR YOLOV8 MODELS:
             prediction_result = self.model(
                 image[:, :, ::-1], imgsz=self.image_size, verbose=False, device=self.device
@@ -66,6 +67,7 @@
             prediction_result = self.model(
                 image[:, :, ::-1], verbose=False, device=self.device
             )  # YOLOv8 expects numpy arrays to have BGR
+
         prediction_result = [
             result.boxes.data[result.boxes.data[:, 4] >= self.confidence_threshold] for result in prediction_result
         ]
diff --git a/sahi/utils/rtdetr.py b/sahi/utils/rtdetr.py
new file mode 100644
index 000000000..6e207af78
--- /dev/null
+++ b/sahi/utils/rtdetr.py
@@ -0,0 +1,38 @@
+import urllib.request
+from os import path
+from pathlib import Path
+from typing import Optional
+
+
+class RTDETRTestConstants:
+    RTDETRL_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/rtdetr-l.pt"
+    RTDETRL_MODEL_PATH = "tests/data/models/rtdetr/rtdetr-l.pt"
+
+    RTDETRX_MODEL_URL = "https://github.com/ultralytics/assets/releases/download/v0.0.0/rtdetr-x.pt"
+    RTDETRX_MODEL_PATH = "tests/data/models/rtdetr/rtdetr-x.pt"
+
+
+def download_rtdetrl_model(destination_path: Optional[str] = None):
+    if destination_path is None:
+        destination_path = RTDETRTestConstants.RTDETRL_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            RTDETRTestConstants.RTDETRL_MODEL_URL,
+            destination_path,
+        )
+
+
+def download_rtdetrx_model(destination_path: Optional[str] = None):
+    if destination_path is None:
+        destination_path = RTDETRTestConstants.RTDETRX_MODEL_PATH
+
+    Path(destination_path).parent.mkdir(parents=True, exist_ok=True)
+
+    if not path.exists(destination_path):
+        urllib.request.urlretrieve(
+            RTDETRTestConstants.RTDETRX_MODEL_URL,
+            destination_path,
+        )
diff --git a/tests/data/models/mmdet/_base_/models/faster-rcnn_r50_fpn.py b/tests/data/models/mmdet/_base_/models/faster-rcnn_r50_fpn.py
index 69f20233a..c5d8dc802 100644
--- a/tests/data/models/mmdet/_base_/models/faster-rcnn_r50_fpn.py
+++ b/tests/data/models/mmdet/_base_/models/faster-rcnn_r50_fpn.py
@@ -84,7 +84,7 @@
     ),
     test_cfg=dict(
         rpn=dict(nms_pre=1000, max_per_img=1000, nms=dict(type="nms", iou_threshold=0.7), min_bbox_size=0),
-        rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100)
+        rcnn=dict(score_thr=0.05, nms=dict(type="nms", iou_threshold=0.5), max_per_img=100),
         # soft-nms is also supported for rcnn testing
         # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05)
     ),
diff --git a/tests/test_rtdetr.py b/tests/test_rtdetr.py
new file mode 100644
index 000000000..444670e49
--- /dev/null
+++ b/tests/test_rtdetr.py
@@ -0,0 +1,142 @@
+# OBSS SAHI Tool
+# Code written by Fatih C Akyon (2020), Devrim Çavuşoğlu (2024).
+
+import unittest
+
+from sahi.utils.cv import read_image
+from sahi.utils.rtdetr import RTDETRTestConstants, download_rtdetrl_model
+
+MODEL_DEVICE = "cpu"
+CONFIDENCE_THRESHOLD = 0.3
+IMAGE_SIZE = 640
+
+
+class TestRTDetrDetectionModel(unittest.TestCase):
+    def test_load_model(self):
+        from sahi.models.rtdetr import RTDetrDetectionModel
+
+        download_rtdetrl_model()
+
+        rtdetr_detection_model = RTDetrDetectionModel(
+            model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            device=MODEL_DEVICE,
+            category_remapping=None,
+            load_at_init=True,
+        )
+
+        self.assertNotEqual(rtdetr_detection_model.model, None)
+
+    def test_set_model(self):
+        from ultralytics import RTDETR
+
+        from sahi.models.rtdetr import RTDetrDetectionModel
+
+        download_rtdetrl_model()
+
+        rtdetr_model = RTDETR(RTDETRTestConstants.RTDETRL_MODEL_PATH)
+
+        rtdetr_detection_model = RTDetrDetectionModel(
+            model=rtdetr_model,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            device=MODEL_DEVICE,
+            category_remapping=None,
+            load_at_init=True,
+        )
+
+        self.assertNotEqual(rtdetr_detection_model.model, None)
+
+    def test_perform_inference(self):
+        from sahi.models.rtdetr import RTDetrDetectionModel
+
+        # init model
+        download_rtdetrl_model()
+
+        rtdetr_detection_model = RTDetrDetectionModel(
+            model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            device=MODEL_DEVICE,
+            category_remapping=None,
+            load_at_init=True,
+            image_size=IMAGE_SIZE,
+        )
+
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+        image = read_image(image_path)
+
+        # perform inference
+        rtdetr_detection_model.perform_inference(image)
+        original_predictions = rtdetr_detection_model.original_predictions
+
+        boxes = original_predictions
+
+        # find box of first car detection with conf greater than 0.5
+        for box in boxes[0]:
+            if box[5].item() == 2:  # if category car
+                if box[4].item() > 0.5:
+                    break
+
+        # compare
+        desired_bbox = [321, 322, 384, 362]
+        predicted_bbox = list(map(round, box[:4].tolist()))
+        margin = 2
+        for ind, point in enumerate(predicted_bbox):
+            assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
+        self.assertEqual(len(rtdetr_detection_model.category_names), 80)
+        for box in boxes[0]:
+            self.assertGreaterEqual(box[4].item(), CONFIDENCE_THRESHOLD)
+
+    def test_convert_original_predictions(self):
+        from sahi.models.rtdetr import RTDetrDetectionModel
+
+        # init model
+        download_rtdetrl_model()
+
+        rtdetr_detection_model = RTDetrDetectionModel(
+            model_path=RTDETRTestConstants.RTDETRL_MODEL_PATH,
+            confidence_threshold=CONFIDENCE_THRESHOLD,
+            device=MODEL_DEVICE,
+            category_remapping=None,
+            load_at_init=True,
+            image_size=IMAGE_SIZE,
+        )
+
+        # prepare image
+        image_path = "tests/data/small-vehicles1.jpeg"
+        image = read_image(image_path)
+
+        # get raw predictions for reference
+        original_results = rtdetr_detection_model.model.predict(image_path, conf=CONFIDENCE_THRESHOLD)[0].boxes
+        num_results = len(original_results)
+
+        # perform inference
+        rtdetr_detection_model.perform_inference(image)
+
+        # convert predictions to ObjectPrediction list
+        rtdetr_detection_model.convert_original_predictions()
+        object_prediction_list = rtdetr_detection_model.object_prediction_list
+
+        # compare
+        self.assertEqual(len(object_prediction_list), num_results)
+
+        # loop through predictions and check that they are equal
+        for i in range(num_results):
+            desired_bbox = [
+                original_results[i].xyxy[0][0],
+                original_results[i].xyxy[0][1],
+                original_results[i].xywh[0][2],
+                original_results[i].xywh[0][3],
+            ]
+            desired_cat_id = int(original_results[i].cls[0])
+            self.assertEqual(object_prediction_list[i].category.id, desired_cat_id)
+            predicted_bbox = object_prediction_list[i].bbox.to_xywh()
+            margin = 2
+            for ind, point in enumerate(predicted_bbox):
+                assert point < desired_bbox[ind] + margin and point > desired_bbox[ind] - margin
+        for object_prediction in object_prediction_list:
+            self.assertGreaterEqual(object_prediction.score.value, CONFIDENCE_THRESHOLD)
+
+
+if __name__ == "__main__":
+    unittest.main()
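
Taken together, the pieces above compose end-to-end: download_rtdetrl_model fetches the weights, the new "rtdetr" key in MODEL_TYPE_TO_MODEL_CLASS_NAME dispatches AutoDetectionModel to RTDetrDetectionModel, and sliced prediction reuses the Yolov8DetectionModel inference path. A minimal usage sketch follows; paths, threshold, and slice sizes mirror the demo notebook and are illustrative, not prescriptive:

    from sahi import AutoDetectionModel
    from sahi.predict import get_sliced_prediction
    from sahi.utils.rtdetr import download_rtdetrl_model

    # fetch RT-DETR-L weights (see sahi/utils/rtdetr.py above)
    download_rtdetrl_model("models/rtdetr-l.pt")

    # "rtdetr" resolves to RTDetrDetectionModel via MODEL_TYPE_TO_MODEL_CLASS_NAME
    detection_model = AutoDetectionModel.from_pretrained(
        model_type="rtdetr",
        model_path="models/rtdetr-l.pt",
        confidence_threshold=0.3,
        device="cpu",  # or 'cuda:0'
    )

    # sliced inference; RTDetrDetectionModel inherits perform_inference from
    # Yolov8DetectionModel, so image_size and confidence filtering behave the same
    result = get_sliced_prediction(
        "demo_data/small-vehicles1.jpeg",
        detection_model,
        slice_height=256,
        slice_width=256,
        overlap_height_ratio=0.2,
        overlap_width_ratio=0.2,
    )
    result.export_visuals(export_dir="demo_data/")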