yolov7 segmentation onnx #66

Open
yxl502 opened this issue Aug 11, 2023 · 1 comment

yxl502 commented Aug 11, 2023

Inference with a yolov5-7.0 segmentation ONNX model:

import torch
import cv2
import numpy as np
from copy import deepcopy
import onnxruntime as ort
from utils.general import non_max_suppression, scale_boxes
from utils.augmentations import letterbox
from utils.segment.general import masks2segments, process_mask, scale_image
from utils.plots import colors

model_path = "D:/PycharmProjects/yolov5-7.0/yolov5s-seg.onnx"
img_path = "./data/images/zidane.jpg"

# Data preprocessing

img = cv2.imread(img_path)
print('ori_img:', img.shape)
ori_img = deepcopy(img)
# auto=False pads to exactly 640x640, which a fixed-shape ONNX export expects;
# auto=True (minimum-rectangle padding) only works if the model was exported with dynamic axes
img = letterbox(img, new_shape=(640, 640), stride=32, auto=False)[0]  # padded resize
img = np.ascontiguousarray(img.transpose((2, 0, 1))[::-1])  # HWC to CHW, BGR to RGB, contiguous

# ONNX Runtime takes NumPy input, so the array stays on the CPU here
# instead of being converted to a torch tensor
img = img.astype(np.float32)  # uint8 to fp32
img = img / 255  # 0 - 255 to 0.0 - 1.0
if len(img.shape) == 3:
    img = img[None]  # expand for batch dim

# Load the model

sess = ort.InferenceSession(model_path, providers=['CUDAExecutionProvider']) # 'CPUExecutionProvider'
input_name = sess.get_inputs()[0].name
output_name = [output.name for output in sess.get_outputs()]

# Inference

outputs = sess.run(output_name, {input_name: img})

pred, proto = outputs[0:2]
pred = torch.from_numpy(pred).to("cuda:0")
proto = torch.from_numpy(proto).to("cuda:0")

# Non-maximum suppression (NMS)

pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45, max_det=1000, nm=32)

# Draw results

print('resize_img:', img.shape)
det = pred[0]
masks = process_mask(proto[0], det[:, 6:], det[:, :4], img.shape[2:], upsample=True) # HWC

# Draw masks

img = torch.from_numpy(img).to("cuda:0")
if len(masks) == 0:
    # No detections: show the original image unchanged
    dst_img = ori_img.copy()
else:
    # Use a new name so the imported colors() helper is not shadowed
    mask_colors = [colors(x, True) for x in det[:, 5]]
    mask_colors = torch.tensor(mask_colors, device="cuda:0", dtype=torch.float32) / 255.0
    mask_colors = mask_colors[:, None, None]  # shape(n,1,1,3)
    masks = masks.unsqueeze(3)  # shape(n,h,w,1)
    masks_color = masks * (mask_colors * 0.5)  # shape(n,h,w,3)

    inv_alph_masks = (1 - masks * 0.5).cumprod(0)  # shape(n,h,w,1)
    mcs = (masks_color * inv_alph_masks).sum(0) * 2  # mask color summand shape(n,h,w,3)
    dst_img = img[0].flip(dims=[0])  # flip channel (RGB to BGR)
    dst_img = dst_img.permute(1, 2, 0).contiguous()  # shape(h,w,3)
    dst_img = dst_img * inv_alph_masks[-1] + mcs
    im_mask = (dst_img * 255).byte().cpu().numpy()
    dst_img = scale_image(dst_img.shape, im_mask, ori_img.shape)

# Boxes, labels, and scores

det[:, :4] = scale_boxes(img.shape[2:], det[:, :4], ori_img.shape).round()  # rescale boxes to ori_img size
det_cpu = pred[0].cpu().numpy()
for i, item in enumerate(det_cpu):
    # Draw the box
    box_int = item[0:4].astype(np.int32).tolist()  # plain Python ints for OpenCV
    cv2.rectangle(dst_img, (box_int[0], box_int[1]), (box_int[2], box_int[3]), color=(0, 255, 0), thickness=1)
    # Draw the label
    label = item[5]
    score = item[4]
    org = (min(box_int[0], box_int[2]), min(box_int[1], box_int[3]) - 8)
    text = '{}|{:.2f}'.format(int(label), score)
    cv2.putText(dst_img, text, org=org, fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.5, color=(0, 255, 0),
                thickness=1)

cv2.imshow('result', dst_img)
cv2.waitKey()
cv2.imwrite('out.jpg', dst_img)
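
The `letterbox` and `scale_boxes` calls in the script are inverse steps: the first resizes and pads the image to the network size, and the second maps box coordinates back to the original image. The sketch below illustrates that geometry in plain Python, assuming centred padding and one uniform scale factor. The helper names `letterbox_params` and `unletterbox_box` are hypothetical and this is not the repository's implementation; the 720x1280 input size is only an illustrative value.

```python
def letterbox_params(orig_hw, new_hw=(640, 640)):
    """Return (gain, pad_w, pad_h) for an aspect-preserving resize with centred padding."""
    h0, w0 = orig_hw
    h1, w1 = new_hw
    gain = min(h1 / h0, w1 / w0)          # one scale factor for both axes
    pad_w = (w1 - round(w0 * gain)) / 2   # padding added on the left (and right)
    pad_h = (h1 - round(h0 * gain)) / 2   # padding added on the top (and bottom)
    return gain, pad_w, pad_h


def unletterbox_box(box, orig_hw, new_hw=(640, 640)):
    """Map an (x1, y1, x2, y2) box from letterboxed coordinates back to the original image."""
    gain, pad_w, pad_h = letterbox_params(orig_hw, new_hw)
    h0, w0 = orig_hw
    x1, y1, x2, y2 = box
    # Undo the padding first, then the scaling, then clip to the image bounds
    xs = [min(max((x - pad_w) / gain, 0.0), float(w0)) for x in (x1, x2)]
    ys = [min(max((y - pad_h) / gain, 0.0), float(h0)) for y in (y1, y2)]
    return (xs[0], ys[0], xs[1], ys[1])


print(letterbox_params((720, 1280)))                          # (0.5, 0.0, 140.0)
print(unletterbox_box((100, 200, 300, 400), (720, 1280)))     # (200.0, 120.0, 600.0, 520.0)
```

If the boxes drawn on the output image look shifted or squashed, comparing them against a calculation like this is a quick way to tell whether the shape passed to `scale_boxes` matches the shape the model actually received.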

@jinjidexiaowang123

Hello, may I ask if the YOLOv7 segmentation ONNX operations, specifically the imports 'from utils.general import non_max_suppression, scale_boxes' and 'from utils.segment.general import masks2segments, process_mask, scale_image', are intended for general use in this repository? Or do they require some modifications to be usable? The error messages 'cannot find reference to 'scale_boxes' in 'general.py'' and 'cannot find reference to 'scale_image' in 'general.py'' suggest that these references are missing in the 'general.py' file. Thank you.
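
One way to make a script tolerate helper functions that were renamed between releases is a fallback import. The sketch below assumes the older name of `scale_boxes` is `scale_coords`, which is what earlier YOLOv5 releases called the box-rescaling helper. Whether the repository in this issue uses that name has not been confirmed here, so check its `utils/general.py` and compare the function signatures before relying on it. The helper `import_first` is hypothetical.

```python
import importlib


def import_first(module_name, candidates):
    """Return the first attribute in `candidates` that exists in the named module."""
    module = importlib.import_module(module_name)
    for name in candidates:
        if hasattr(module, name):
            return getattr(module, name)
    raise ImportError(f"none of {candidates} found in {module_name}")


# In the inference script this might look like the following
# (the older name is an assumption, as noted above):
# scale_boxes = import_first("utils.general", ["scale_boxes", "scale_coords"])

# Self-contained demonstration against the standard library:
sqrt = import_first("math", ["square_root", "sqrt"])
print(sqrt(9.0))
```

If none of the candidate names exist, the helper raises `ImportError` with the list of names it tried, which is easier to diagnose than an unresolved reference reported by the IDE.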
