Skip to content

Commit

Permalink
rk opt for yolov8 base@rk_c86cd359
Browse files Browse the repository at this point in the history
Signed-off-by: Randall Zhuo <[email protected]>
  • Loading branch information
Randall Zhuo committed Dec 19, 2023
1 parent c9be1f3 commit 5b7ddd8
Show file tree
Hide file tree
Showing 12 changed files with 761 additions and 10 deletions.
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
#### Get model optimized for RKNN

Exports detection/segment model with optimization for RKNN, please refer here [RKOPT_README.md](RKOPT_README.md). Optimization for exporting model does not affect the training stage

关于如何导出适配 RKNPU 分割/检测 模型,请参考 [RKOPT_README.zh-CN.md](RKOPT_README.zh-CN.md),该优化只在导出模型时生效,训练代码按照原仓库的指引即可。

---

<div align="center">
<p>
<a href="https://ultralytics.com/yolov8" target="_blank">
Expand Down
6 changes: 6 additions & 0 deletions README.zh-CN.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
#### 导出适配 RKNPU 的模型

关于如何导出适配 RKNPU 分割/检测 模型,请参考 [RKOPT_README_CN.md](RKOPT_README_CN.md), 该优化只在导出模型时生效,训练代码按照原仓库的指引即可。

---

<div align="center">
<p>
<a href="https://ultralytics.com/yolov8" target="_blank">
Expand Down
39 changes: 39 additions & 0 deletions RKOPT_README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# RKNN optimization for exporting model

## Source
Base on https://github.com/ultralytics/ultralytics with commit id as c9be1f3cce89778f79fb462797b8ca0300e3813d




## What different
With inference result values unchanged, the following optimizations were applied:
- Change output node, remove post-process from the model. (post-process block in model is unfriendly for quantization)
- Remove dfl structure at the end of the model. (which slowdown the inference speed on NPU device)
- Add a score-sum output branch to speedup post-process.

All the removed operation will be done on CPU. (the CPU post-process could be found in **RKNN_Model_Zoo**)




## Export ONNX model

After meeting the environment requirements specified in "./requirements.txt," execute the following command to export the model (support detect/segment model):

```
# Adjust the model file path in "./ultralytics/cfg/default.yaml" (default is yolov8n.pt). If you trained your own model, please provide the corresponding path.
# For example, filled with yolov8n.pt for detection model.
# Filling with yolov8n-seg.pt for segmentation model.
export PYTHONPATH=./
python ./ultralytics/engine/exporter.py
# Upon completion, the ".onnx" model will be generated. If the original model is "yolov8n.pt," the generated model will be "yolov8n.onnx"
```



## Convert to RKNN model, Python demo, C demo

Please refer to https://github.com/airockchip/rknn_model_zoo.
45 changes: 45 additions & 0 deletions RKOPT_README.zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# 导出 RKNPU 适配模型说明

## Source

​ 本仓库基于 https://github.com/ultralytics/ultralytics 仓库的 c9be1f3cce89778f79fb462797b8ca0300e3813d commit 进行修改,验证.



## 模型差异

在基于不影响输出结果, 不需要重新训练模型的条件下, 有以下改动:

- 修改输出结构, 移除后处理结构. (后处理结果对于量化不友好)

- dfl 结构在 NPU 处理上性能不佳,移至模型外部的后处理阶段,此操作大部分情况下可提升推理性能。


- 模型输出分支新增置信度的总和,用于后处理阶段加速阈值筛选。


以上移除的操作, 均需要在外部使用CPU进行相应的处理. (对应的后处理代码可以在 **RKNN_Model_Zoo** 中找到)



## 导出onnx模型

在满足 ./requirements.txt 的环境要求后,执行以下语句导出模型

```
# 调整 ./ultralytics/cfg/default.yaml 中 model 文件路径,默认为 yolov8n.pt,若自己训练模型,请调接至对应的路径。支持检测、分割模型。
# 如填入 yolov8n.pt 导出检测模型
# 如填入 yolov8-seg.pt 导出分割模型
export PYTHONPATH=./
python ./ultralytics/engine/exporter.py
# 执行完毕后,会生成 ONNX 模型. 假如原始模型为 yolov8n.pt,则生成 yolov8n.onnx 模型。
```



## 转RKNN模型、Python demo、C demo

请参考 https://github.com/airockchip/rknn_model_zoo

4 changes: 2 additions & 2 deletions ultralytics/cfg/default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ task: detect # (str) YOLO task, i.e. detect, segment, classify, pose
mode: train # (str) YOLO mode, i.e. train, val, predict, export, track, benchmark

# Train settings -------------------------------------------------------------------------------------------------------
model: # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
model: yolov8m-seg.pt # (str, optional) path to model file, i.e. yolov8n.pt, yolov8n.yaml
data: # (str, optional) path to data file, i.e. coco128.yaml
epochs: 100 # (int) number of epochs to train for
patience: 50 # (int) epochs to wait for no observable improvement for early stopping of training
Expand Down Expand Up @@ -68,7 +68,7 @@ retina_masks: False # (bool) use high-resolution segmentation masks
boxes: True # (bool) Show boxes in segmentation predictions

# Export settings ------------------------------------------------------------------------------------------------------
format: torchscript # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
format: rknn # (str) format to export to, choices at https://docs.ultralytics.com/modes/export/#export-formats
keras: False # (bool) use Kera=s
optimize: False # (bool) TorchScript: optimize for mobile
int8: False # (bool) CoreML/TF INT8 quantization
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/data/augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,7 +369,7 @@ def apply_bboxes(self, bboxes, M):
# Create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1)), dtype=bboxes.dtype).reshape(4, n).T
return np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

def apply_segments(self, segments, M):
"""
Expand Down
2 changes: 1 addition & 1 deletion ultralytics/data/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Ultralytics YOLO 🚀, AGPL-3.0 license
import cv2

import glob
import math
Expand All @@ -9,7 +10,6 @@
from pathlib import Path
from typing import Optional

import cv2
import numpy as np
import psutil
from torch.utils.data import Dataset
Expand Down
37 changes: 34 additions & 3 deletions ultralytics/engine/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
from copy import deepcopy
from datetime import datetime
from pathlib import Path
import cv2

import torch

Expand Down Expand Up @@ -88,8 +89,10 @@ def export_formats():
['TensorFlow Lite', 'tflite', '.tflite', True, False],
['TensorFlow Edge TPU', 'edgetpu', '_edgetpu.tflite', True, False],
['TensorFlow.js', 'tfjs', '_web_model', True, False],
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
['ncnn', 'ncnn', '_ncnn_model', True, True], ]
['PaddlePaddle', 'paddle', '_paddle_model', True, True],
['ncnn', 'ncnn', '_ncnn_model', True, True],
['RKNN', 'rknn', '_rknnopt.torchscript', True, False],
]
return pandas.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])


Expand Down Expand Up @@ -157,7 +160,8 @@ def __call__(self, model=None):
flags = [x == format for x in fmts]
if sum(flags) != 1:
raise ValueError(f"Invalid export format='{format}'. Valid formats are {fmts}")
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn = flags # export booleans
jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, rknn = flags # export booleans


# Load PyTorch model
self.device = select_device('cpu' if self.args.device is None else self.args.device)
Expand Down Expand Up @@ -262,6 +266,8 @@ def __call__(self, model=None):
f[10], _ = self.export_paddle()
if ncnn: # ncnn
f[11], _ = self.export_ncnn()
if rknn:
f[12], _ = self.export_rknn()

# Finish
f = [str(x) for x in f if x] # filter out '' and None
Expand Down Expand Up @@ -297,6 +303,31 @@ def export_torchscript(self, prefix=colorstr('TorchScript:')):
ts.save(str(f), _extra_files=extra_files)
return f, None

@try_export
def export_rknn(self, prefix=colorstr('RKNN:')):
"""YOLOv8 RKNN model export."""
LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')

# ts = torch.jit.trace(self.model, self.im, strict=False)
# f = str(self.file).replace(self.file.suffix, f'_rknnopt.torchscript')
# torch.jit.save(ts, str(f))

f = str(self.file).replace(self.file.suffix, f'.onnx')
opset_version = self.args.opset or get_latest_opset()
torch.onnx.export(
self.model,
self.im[0:1,:,:,:],
f,
verbose=False,
opset_version=12,
do_constant_folding=True, # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
input_names=['images'])

LOGGER.info(f'\n{prefix} feed {f} to RKNN-Toolkit or RKNN-Toolkit2 to generate RKNN model.\n'
'Refer https://github.com/airockchip/rknn_model_zoo/tree/main/models/CV/object_detection/yolo')
return f, None


@try_export
def export_onnx(self, prefix=colorstr('ONNX:')):
"""YOLOv8 ONNX export."""
Expand Down
4 changes: 3 additions & 1 deletion ultralytics/nn/autobackend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def __init__(self,
super().__init__()
w = str(weights[0] if isinstance(weights, list) else weights)
nn_module = isinstance(weights, torch.nn.Module)
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton = \
pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs, paddle, ncnn, triton, rknn = \
self._model_type(w)
fp16 &= pt or jit or onnx or xml or engine or nn_module or triton # FP16
nhwc = coreml or saved_model or pb or tflite or edgetpu # BHWC formats (vs torch BCWH)
Expand Down Expand Up @@ -385,6 +385,8 @@ def forward(self, im, augment=False, visualize=False):
mat_out = self.pyncnn.Mat()
ex.extract(output_name, mat_out)
y.append(np.array(mat_out)[None])
elif getattr(self, 'rknn', False):
assert "for inference, please refer to https://github.com/airockchip/rknn_model_zoo/"
elif self.triton: # NVIDIA Triton Inference Server
y = self.model(im)
else: # TensorFlow (SavedModel, GraphDef, Lite, Edge TPU)
Expand Down
Loading

0 comments on commit 5b7ddd8

Please sign in to comment.