[ONNX] Add OnnxToTorch lowering for NonMaxSuppression op #3501
Hi Vivek,
I think this op needs quite a bit of case-checking, since we cannot reasonably support all of the optional arguments.
The lit tests are numerous and seem redundant to me. Is there any reason for adding so many of them?
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
    binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
rewriter.replaceOpWithNewOp<Torch::TorchvisionNmsOp>(
    binder.op, resultType, boxes, scores, iouThreshold);
The result type of onnx.NonMaxSuppression doesn't match TorchvisionNmsOp. It looks like nms has output shape [n], where n is the number of selected indices out of N (which I assume corresponds to onnx's spatial_dimension). The onnx op has output shape [n, 3] (the three is the tuple batch_index, channel_index, box_index -- or I assume spatial_index by their docs), so we might need to concatenate the torchvision result with a torch.zeros([n, 2]), since this conversion will only work when the batches and channels are both 1.
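The zero-padding idea can be illustrated in plain Python (the `selected` values here are hypothetical stand-ins for a torchvision nms result, valid only when batch and class counts are both 1):

```python
# Sketch of padding a flat [n] index result into ONNX's [n, 3] layout,
# assuming num_batches == num_classes == 1.
selected = [5, 2, 9]  # hypothetical indices of the kept boxes, shape [n]

# Prepend batch_index = 0 and class_index = 0 to each selected box index,
# mirroring a concat of zeros([n, 2]) with the nms result.
onnx_layout = [[0, 0, box_idx] for box_idx in selected]

print(onnx_layout)  # [[0, 0, 5], [0, 0, 2], [0, 0, 9]]
```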
I'm wondering if these two ops are sufficiently different to consider something else entirely. I know there is a batched_nms, but it's in a completely different format. For onnx boxes with shape [B, N], we would need to flatten it to shape [B*N], and make a torchvision idxs tensor which is arange(0, B) with each value repeated N times, giving shape [B*N]. Then, in the onnx scores tensor of shape [B, C, N], extract each channel one at a time and perform batched_nms. Here is some pseudocode:
flattened_boxes = boxes.flatten(0, 1)  # [B*N, 4]; batched_nms expects one box per row
idxs = arange(0, B).unsqueeze(1).repeat(1, N).flatten()  # [B*N]
# or maybe easier:
idxs = floor(arange(0, B, 1/N))
onnx_result = empty(result_shape)
count = 0
for c in range(C):
    channel_scores = scores[:, c, :]
    flattened_scores = channel_scores.flatten()  # shape = [B*N]
    result = torchvision.batched_nms(flattened_boxes, flattened_scores, idxs, iou_threshold)
    # result contains indices into [B*N] corresponding to the selected boxes
    for res_index in result:
        batch_idx = idxs[res_index]
        channel_idx = c
        box_idx = res_index % N  # position of the box within its batch
        onnx_result[count, 0] = batch_idx
        onnx_result[count, 1] = channel_idx
        onnx_result[count, 2] = box_idx
        count += 1
return onnx_result
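The index bookkeeping in the pseudocode can be sanity-checked without torch. A small sketch with hypothetical sizes, confirming the two idxs constructions agree and that the flattened index round-trips back to (batch_idx, box_idx):

```python
# Pure-Python check of the flattened-index arithmetic (hypothetical sizes).
B, N = 3, 4  # B batches, N boxes per batch

# arange(0, B).unsqueeze(1).repeat(1, N).flatten() repeats each batch id N times:
idxs = [b for b in range(B) for _ in range(N)]

# floor(arange(0, B, 1/N)) produces the same sequence:
idxs_alt = [i // N for i in range(B * N)]
assert idxs == idxs_alt

# Recovering (batch_idx, box_idx) from a flattened index round-trips:
for b in range(B):
    for n in range(N):
        flat = b * N + n
        assert idxs[flat] == b   # batch_idx = idxs[res_index]
        assert flat % N == n     # box_idx = res_index % N

print(idxs)  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```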
Frankly, I'm just guessing how these ops line up since the documentation is a bit limited. It might be good to compare something like the python pseudocode implementation above with the onnx runtime result and see if this is even a viable approach.
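For that comparison, a minimal pure-Python reference of greedy single-batch, single-class NMS might be useful as a baseline. This is only a sketch under the assumption that boxes arrive as ordered corner pairs [y1, x1, y2, x2] (one of the onnx.NonMaxSuppression box conventions); it is not the onnxruntime implementation:

```python
def iou(a, b):
    # a, b: [y1, x1, y2, x2] with y1 <= y2 and x1 <= x2 (assumed pre-normalized)
    inter_h = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_w = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_h * inter_w
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0

def nms(boxes, scores, iou_threshold):
    # Greedy NMS: repeatedly keep the highest-scoring box and drop any
    # remaining box whose IoU with an already-kept box exceeds the threshold.
    order = sorted(range(len(boxes)), key=lambda i: scores[i], reverse=True)
    keep = []
    for i in order:
        if all(iou(boxes[i], boxes[j]) <= iou_threshold for j in keep):
            keep.append(i)
    return keep

boxes = [[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]]
scores = [0.9, 0.8, 0.7]
print(nms(boxes, scores, 0.5))  # [0, 2]: box 1 overlaps box 0 too much
```

Running this next to onnxruntime's NonMaxSuppression on the same inputs would show whether the batched_nms mapping above is viable.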
nod-ai/SHARK-Turbine#650