Added Segmentation / Export End2End SegModels #16
levipereira committed Nov 6, 2024
1 parent 98d4521 commit c293d1f
Showing 5 changed files with 1,110 additions and 17 deletions.
46 changes: 44 additions & 2 deletions README.md
@@ -1,4 +1,4 @@
# YOLOv9 QAT for TensorRT
# YOLOv9 QAT for TensorRT Detection / Segmentation

This repository contains an implementation of YOLOv9 with Quantization-Aware Training (QAT), specifically designed for deployment on platforms utilizing TensorRT for hardware-accelerated inference. <br>
This implementation aims to provide an efficient, low-latency version of YOLOv9 for real-time detection applications.<br>
@@ -15,7 +15,8 @@
For those who are not familiar with QAT, I highly recommend watching this video:<br> [Quantization explained with PyTorch - Post-Training Quantization, Quantization-Aware Training](https://www.youtube.com/watch?v=0VdNflU08yA)

**Important**<br>
Currently, quantization is only available for object detection models. However, since quantization primarily affects the backbone of the YOLOv9 model, and the backbone is shared across all YOLOv9 variants, quantization is effectively prepared for all YOLOv9-based models, whether they are used for detection or segmentation. Quantization support for segmentation models has not yet been released, as it requires evaluation criteria and validation of quantization for the final layers of the model. <br>
Evaluation of the segmentation model using TensorRT is currently under development. Once I have more time available, I will complete and release this work.

🌟 We still have plenty of nodes to improve Q/DQ, and we rely on the community's contribution to enhance this project, benefiting us all. Let's collaborate and make it even better! 🚀

## Release Highlights
@@ -35,6 +36,7 @@

### Evaluation Results

## Detection
#### Activation SiLU

| Eval Model | AP | AP50 | Precision | Recall |
@@ -66,6 +68,14 @@
| **INT8 (TensorRT)** vs **Origin (Pytorch)** | -0.002 | -0.005 | +0.004 | -0.003 |

## Segmentation
| Model  | Box P | Box R | Box mAP50 | Box mAP50-95 | Mask P | Mask R | Mask mAP50 | Mask mAP50-95 |
|--------|-------|-------|-----------|--------------|--------|--------|------------|---------------|
| Origin | 0.729 | 0.632 | 0.691     | 0.521        | 0.717  | 0.611  | 0.657      | 0.423         |
| PTQ    | 0.729 | 0.626 | 0.688     | 0.520        | 0.717  | 0.604  | 0.654      | 0.421         |
| QAT    | 0.725 | 0.631 | 0.689     | 0.521        | 0.714  | 0.609  | 0.655      | 0.421         |


## Latency/Throughput Report - TensorRT

@@ -530,3 +540,35 @@
D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
Total Host Walltime: 10.0286 s
Total GPU Compute Time: 10.0269 s
```


# Segmentation

## FP16
### Batch Size 8

```bash
=== Performance summary ===
Throughput: 124.055 qps
Latency: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
Enqueue Time: min = 0.00219727 ms, max = 0.0200653 ms, mean = 0.00271174 ms, median = 0.00256348 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00317383 ms, percentile(99%) = 0.00466919 ms
H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
GPU Compute Time: min = 8.00354 ms, max = 8.18585 ms, mean = 8.05924 ms, median = 8.05072 ms, percentile(90%) = 8.11499 ms, percentile(95%) = 8.1438 ms, percentile(99%) = 8.17456 ms
D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
Total Host Walltime: 3.01478 s
Total GPU Compute Time: 3.01415 s
```

## INT8 / FP16
### Batch Size 8
```bash
=== Performance summary ===
Throughput: 223.63 qps
Latency: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
Enqueue Time: min = 0.00219727 ms, max = 0.00854492 ms, mean = 0.00258152 ms, median = 0.00244141 ms, percentile(90%) = 0.00292969 ms, percentile(95%) = 0.00305176 ms, percentile(99%) = 0.00439453 ms
H2D Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
GPU Compute Time: min = 4.45544 ms, max = 4.71553 ms, mean = 4.47007 ms, median = 4.46777 ms, percentile(90%) = 4.47284 ms, percentile(95%) = 4.47388 ms, percentile(99%) = 4.47693 ms
D2H Latency: min = 0 ms, max = 0 ms, mean = 0 ms, median = 0 ms, percentile(90%) = 0 ms, percentile(95%) = 0 ms, percentile(99%) = 0 ms
Total Host Walltime: 3.00944 s
Total GPU Compute Time: 3.00836 s
```
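
For reference, performance summaries like the two above come from trtexec. A representative invocation is sketched below; the engine file name is hypothetical, the input is assumed to be a 640×640 tensor named `images` (per the export code), and `--noDataTransfers` is consistent with the zeroed H2D/D2H rows:

```bash
# Sketch: profile a prebuilt segmentation engine at batch size 8.
# Engine name is hypothetical; precision (FP16/INT8) is baked in at build time.
trtexec --loadEngine=yolov9-c-seg-end2end.engine \
        --shapes=images:8x3x640x640 \
        --noDataTransfers --useCudaGraph
```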
71 changes: 56 additions & 15 deletions export_qat.py
@@ -26,7 +26,8 @@
if platform.system() != 'Windows':
    ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative

from models.experimental import attempt_load, End2End
from models.experimental_trt import End2End_TRT
from models.experimental import attempt_load
from models.yolo import ClassificationModel, Detect, DDetect, DualDetect, DualDDetect, DetectionModel, SegmentationModel
from utils.dataloaders import LoadImages
from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
@@ -175,12 +176,22 @@
        remove_redundant_qdq_model(model_onnx, f)
        model_onnx = onnx.load(f)
    return f, model_onnx




@try_export
def export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, labels, prefix=colorstr('ONNX END2END:')):
    if not isinstance(model, DetectionModel) or isinstance(model, SegmentationModel):
def export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, labels, mask_resolution, pooler_scale, sampling_ratio, prefix=colorstr('ONNX END2END:')):
    if not isinstance(model, DetectionModel) and not isinstance(model, SegmentationModel):
        raise RuntimeError("Model not supported. Only Detection and Segmentation models can be exported with End2End functionality.")

    is_det_model = True
    if isinstance(model, SegmentationModel):
        is_det_model = False

    env_is_det_model = os.getenv("MODEL_DET")
    if env_is_det_model == "0":
        is_det_model = False

    # YOLO ONNX export
    check_requirements('onnx')
    import onnx
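
As a usage note for the override above: `MODEL_DET=0` forces `is_det_model = False` (segmentation-style outputs) regardless of the model class. A hypothetical invocation, assuming the repository's standard `--weights`/`--include` options:

```bash
# Sketch: force segmentation-style End2End export via the MODEL_DET override.
# The weights file name is hypothetical; MODEL_DET is read inside export_onnx_end2end as shown above.
MODEL_DET=0 python export_qat.py --weights yolov9-c-seg.pt --include onnx_end2end
```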
@@ -195,6 +206,14 @@
    LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
    f = os.path.splitext(file)[0] + "-end2end.onnx"
    batch_size = 'batch'
    d = {
        'stride': int(max(model.stride)),
        'names': model.names,
        'model type': 'Detection' if is_det_model else 'Segmentation',
        'TRT Compatibility': '8.6 or above' if class_agnostic else '8.5 or above',
        'TRT Plugins': 'EfficientNMS_TRT' if is_det_model else 'EfficientNMSX_TRT, ROIAlign'
    }

    dynamic_axes = {'images': {0: 'batch', 2: 'height', 3: 'width'}}  # variable length axes
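
Because only the batch, height, and width axes are dynamic, building a TensorRT engine from the exported ONNX needs an explicit optimization profile. A minimal sketch, assuming hypothetical file names and a 640×640 input; for segmentation models the `EfficientNMSX_TRT`/`ROIAlign` plugins named in the metadata must be available to TensorRT:

```bash
# Sketch: build a dynamic-batch FP16 engine from the end2end ONNX (file names are hypothetical).
trtexec --onnx=yolov9-c-seg-end2end.onnx \
        --minShapes=images:1x3x640x640 \
        --optShapes=images:8x3x640x640 \
        --maxShapes=images:16x3x640x640 \
        --fp16 --saveEngine=yolov9-c-seg-end2end.engine
```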

@@ -204,30 +223,40 @@
        'det_scores': {0: 'batch'},
        'det_classes': {0: 'batch'},
    }
    dynamic_axes.update(output_axes)
    model = End2End(model, topk_all, iou_thres, conf_thres, None, device, labels)
    if is_det_model:
        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
        shapes = [batch_size, 1,
                  batch_size, topk_all, 4,
                  batch_size, topk_all,
                  batch_size, topk_all]
    else:
        output_axes['det_masks'] = {0: 'batch'}
        output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes', 'det_masks']
        shapes = [batch_size, 1,
                  batch_size, topk_all, 4,
                  batch_size, topk_all,
                  batch_size, topk_all,
                  batch_size, topk_all, mask_resolution * mask_resolution]

    output_names = ['num_dets', 'det_boxes', 'det_scores', 'det_classes']
    shapes = [batch_size, 1, batch_size, topk_all, 4,
              batch_size, topk_all, batch_size, topk_all]

    dynamic_axes.update(output_axes)
    model = End2End_TRT(model, class_agnostic, topk_all, iou_thres, conf_thres, mask_resolution, pooler_scale, sampling_ratio, None, device, labels, is_det_model)
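
To make the segmentation outputs concrete, here is a small sketch of the resulting tensor shapes, using illustrative values (batch 8, the `run()` defaults `topk_all=100` and `mask_resolution=56`):

```python
# Sketch: end2end output shapes for a segmentation export (illustrative values).
batch, topk, mask_res = 8, 100, 56
expected = {
    "num_dets":    (batch, 1),
    "det_boxes":   (batch, topk, 4),
    "det_scores":  (batch, topk),
    "det_classes": (batch, topk),
    "det_masks":   (batch, topk, mask_res * mask_res),  # 56 * 56 = 3136 values per mask
}
```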


    if is_model_qat:
        warnings.filterwarnings("ignore")
        LOGGER.info(f'{prefix} Model QAT Detected ...')
        quant_nn.TensorQuantizer.use_fb_fake_quant = True
        model.eval()
        quantize.initialize()
        quantize.replace_custom_module_forward(model)

        with torch.no_grad():
            torch.onnx.export(model,
                              im,
                              f,
                              verbose=False,
                              export_params=True,  # store the trained parameter weights inside the model file
                              opset_version=13,
                              opset_version=16,
                              do_constant_folding=True,  # whether to execute constant folding for optimization
                              input_names=['images'],
                              output_names=output_names,
@@ -239,7 +268,7 @@
                              f,
                              verbose=False,
                              export_params=True,  # store the trained parameter weights inside the model file
                              opset_version=12,
                              opset_version=16,
                              do_constant_folding=True,  # whether to execute constant folding for optimization
                              input_names=['images'],
                              output_names=output_names,
@@ -248,6 +277,10 @@
    # Checks
    model_onnx = onnx.load(f)  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model
    for k, v in d.items():
        meta = model_onnx.metadata_props.add()
        meta.key, meta.value = k, str(v)

    for i in model_onnx.graph.output:
        for j in i.type.tensor_type.shape.dim:
            j.dim_param = str(shapes.pop(0))
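
For context, the metadata written in the loop above is stored in the ONNX file itself and can be read back through `onnx`'s standard `metadata_props`; a minimal sketch with a hypothetical file name:

```python
# Sketch: inspect the metadata written by export_onnx_end2end.
import onnx

m = onnx.load("yolov9-c-seg-end2end.onnx")  # hypothetical file name
meta = {p.key: p.value for p in m.metadata_props}
print(meta.get("model type"))         # 'Detection' or 'Segmentation'
print(meta.get("TRT Plugins"))        # e.g. 'EfficientNMSX_TRT, ROIAlign'
print(meta.get("TRT Compatibility"))  # minimum TensorRT version
```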
@@ -586,6 +619,7 @@ def run(
        batch_size=1,  # batch size
        device='cpu',  # cuda device, i.e. 0 or 0,1,2,3 or cpu
        include=('torchscript', 'onnx'),  # include formats
        class_agnostic=False,  # class-agnostic NMS (single class)
        half=False,  # FP16 half-precision export
        inplace=False,  # set YOLO Detect() inplace=True
        keras=False,  # use Keras
@@ -602,6 +636,9 @@
        topk_all=100,  # TF.js NMS: topk for all classes to keep
        iou_thres=0.45,  # TF.js NMS: IoU threshold
        conf_thres=0.25,  # TF.js NMS: confidence threshold
        mask_resolution=56,  # ONNX END2END (segmentation): pooled mask resolution
        pooler_scale=0.25,  # ONNX END2END (segmentation): ROIAlign spatial scale
        sampling_ratio=0,  # ONNX END2END (segmentation): ROIAlign sampling ratio
):
    t = time.time()
    include = [x.lower() for x in include]  # to lowercase
@@ -655,7 +692,7 @@
        f[2], _ = export_onnx(model, im, file, opset, dynamic, simplify)
    if onnx_end2end:
        labels = model.names
        f[2], _ = export_onnx_end2end(model, im, file, simplify, topk_all, iou_thres, conf_thres, device, len(labels))
        f[2], _ = export_onnx_end2end(model, im, file, class_agnostic, simplify, topk_all, iou_thres, conf_thres, device, len(labels), mask_resolution, pooler_scale, sampling_ratio)
    if xml:  # OpenVINO
        f[3], _ = export_openvino(file, metadata, half)
    if coreml:  # CoreML
@@ -731,6 +768,10 @@
    parser.add_argument('--topk-all', type=int, default=100, help='ONNX END2END/TF.js NMS: topk for all classes to keep')
    parser.add_argument('--iou-thres', type=float, default=0.45, help='ONNX END2END/TF.js NMS: IoU threshold')
    parser.add_argument('--conf-thres', type=float, default=0.25, help='ONNX END2END/TF.js NMS: confidence threshold')
    parser.add_argument('--class-agnostic', action='store_true', help='ONNX END2END: class-agnostic NMS (single class)')
    parser.add_argument('--mask-resolution', type=int, default=160, help='ONNX END2END: pooled mask output resolution')
    parser.add_argument('--pooler-scale', type=float, default=0.25, help='ONNX END2END: multiplicative factor used to scale the ROI coordinates (ROIAlign spatial_scale)')
    parser.add_argument('--sampling-ratio', type=int, default=0, help='ONNX END2END: number of sampling points in the ROIAlign interpolation (non-negative integer)')
    parser.add_argument(
        '--include',
        nargs='+',
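
Putting the new flags together, a hypothetical end-to-end segmentation export could look like the sketch below; the weights path and the `onnx_end2end` include target are assumptions based on the options shown above, and `--class-agnostic` implies TensorRT 8.6 or above per the exported metadata:

```bash
# Sketch: QAT segmentation model -> End2End ONNX with NMS and ROIAlign mask head.
# Values mirror the argparse defaults above; the weights path is hypothetical.
python export_qat.py --weights yolov9-c-seg-qat.pt --include onnx_end2end \
    --topk-all 100 --iou-thres 0.45 --conf-thres 0.25 \
    --mask-resolution 160 --pooler-scale 0.25 --sampling-ratio 0 \
    --class-agnostic
```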
