[ONNX] Add OnnxToTorch lowering for NonMaxSuppression op #3501

Open
vivekkhandelwal1 wants to merge 2 commits into main

Conversation


@vivekkhandelwal1 vivekkhandelwal1 commented Jun 27, 2024

Signed-Off By: Vivek Khandelwal [email protected]

nod-ai/SHARK-Turbine#650


@zjgarvey zjgarvey left a comment


Hi Vivek,

I think this op needs quite a bit of case-checking, since we cannot reasonably support all of the optional arguments.
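Roughly, the cases I have in mind, written as a Python sketch purely for illustration (the checks come from the ONNX NonMaxSuppression spec; the actual guards would live in the C++ pattern, and the helper name is made up):

def fits_direct_torchvision_nms(center_point_box, num_batches, num_classes,
                                has_score_threshold):
    # center_point_box=1 means boxes are (x_center, y_center, width, height);
    # torchvision.nms expects corner coordinates, so that case needs a conversion.
    if center_point_box != 0:
        return False
    # torchvision.nms has no batch or class dimension, so only B == C == 1
    # maps onto it directly.
    if num_batches != 1 or num_classes != 1:
        return False
    # score_threshold has no torchvision.nms equivalent; it would require
    # pre-filtering the scores (and remapping the returned indices).
    if has_score_threshold:
        return False
    return True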

The lit tests are numerous and seem redundant to me. Is there any reason for adding so many of them?

Review threads (outdated, resolved):
lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
// Extract iou_threshold as a scalar float and lower the op directly to torchvision.nms.
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
    binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
rewriter.replaceOpWithNewOp<Torch::TorchvisionNmsOp>(
    binder.op, resultType, boxes, scores, iouThreshold);

The result type of onnx.NonMaxSuppression doesn't match TorchvisionNmsOp. torchvision's nms has output shape [n], where n is the number of selected indices out of N (which I assume corresponds to onnx's spatial_dimension). The onnx op has output shape [n, 3], where the 3 corresponds to the tuple (batch_index, channel_index, box_index -- or spatial_index, I assume, by their docs). So we might need to concatenate a torch.zeros([n, 2]) in front of the torchvision result, and even then this conversion will only work when the number of batches and channels are both 1.
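A minimal sketch of that special case, assuming the number of batches and channels are both 1 (the helper name is just for illustration):

import torch
from torchvision.ops import nms

def nms_to_onnx_result(boxes, scores, iou_threshold):
    # boxes: [N, 4], scores: [N]; returns [n, 3] rows of
    # (batch_index, channel_index, box_index), valid only when B == C == 1.
    keep = nms(boxes, scores, iou_threshold)                   # shape [n], int64
    zeros = torch.zeros((keep.numel(), 2), dtype=torch.int64)  # batch/channel columns
    return torch.cat([zeros, keep.unsqueeze(1)], dim=1)        # shape [n, 3]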

I'm wondering if these two ops are sufficiently different that we should consider something else entirely. I know there is a batched_nms, but it expects a completely different format. For onnx boxes with shape [B, N, 4], we would need to flatten them to shape [B*N, 4], and build a torchvision idxs tensor of shape [B*N] in which each batch index from arange(0, B) is repeated N times. Then, from the onnx scores tensor of shape [B, C, N], extract each channel one at a time and perform batched_nms. Here is a rough Python sketch:

import torch
from torchvision.ops import batched_nms

def onnx_style_nms(boxes, scores, iou_threshold):
    # boxes: [B, N, 4], scores: [B, C, N]; returns [n, 3] rows of
    # (batch_index, channel_index, box_index), mirroring onnx.NonMaxSuppression.
    B, N, _ = boxes.shape
    C = scores.shape[1]
    flattened_boxes = boxes.reshape(B * N, 4)
    # batch index of every flattened box: [0]*N + [1]*N + ... + [B-1]*N
    idxs = torch.arange(B).unsqueeze(1).repeat(1, N).flatten()
    # or maybe easier: idxs = torch.arange(B).repeat_interleave(N)
    selected = []
    for c in range(C):
        channel_scores = scores[:, c, :]
        flattened_scores = channel_scores.flatten()  # shape [B*N]
        result = batched_nms(flattened_boxes, flattened_scores, idxs, iou_threshold)
        # result contains indices into [B*N] corresponding to the selected boxes
        for res_index in result.tolist():
            batch_idx = res_index // N   # same as idxs[res_index]
            box_idx = res_index % N
            selected.append([batch_idx, c, box_idx])
    return torch.tensor(selected, dtype=torch.int64).reshape(-1, 3)

Frankly, I'm just guessing how these ops line up, since the documentation is a bit limited. It might be good to compare something like the Python sketch above with the ONNX Runtime result and see whether this is even a viable approach.
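For example, something along these lines could do that comparison. It's only a sketch built with onnx.helper and onnxruntime: the shapes and thresholds are arbitrary example values, and onnx_style_nms refers to the sketch above.

import numpy as np
from onnx import helper, TensorProto
import onnxruntime as ort
import torch

B, C, N = 2, 3, 8
# Well-formed corner boxes: min corner first, then min corner + size.
lo = np.random.rand(B, N, 2).astype(np.float32)
size = np.random.rand(B, N, 2).astype(np.float32)
boxes = np.concatenate([lo, lo + size], axis=-1)   # [B, N, 4]
scores = np.random.rand(B, C, N).astype(np.float32)

# One-node model: NonMaxSuppression(boxes, scores, max_out, iou_threshold).
node = helper.make_node(
    "NonMaxSuppression",
    inputs=["boxes", "scores", "max_out", "iou_threshold"],
    outputs=["selected"],
)
graph = helper.make_graph(
    [node], "nms_check",
    inputs=[
        helper.make_tensor_value_info("boxes", TensorProto.FLOAT, [B, N, 4]),
        helper.make_tensor_value_info("scores", TensorProto.FLOAT, [B, C, N]),
    ],
    outputs=[helper.make_tensor_value_info("selected", TensorProto.INT64, [None, 3])],
    initializer=[
        # max_out = N so the per-class limit doesn't kick in (the sketch above ignores it).
        helper.make_tensor("max_out", TensorProto.INT64, [1], [N]),
        helper.make_tensor("iou_threshold", TensorProto.FLOAT, [1], [0.5]),
    ],
)
model = helper.make_model(graph)

sess = ort.InferenceSession(model.SerializeToString(), providers=["CPUExecutionProvider"])
(ort_selected,) = sess.run(None, {"boxes": boxes, "scores": scores})

# Row order may differ between the two, so compare as sets of (batch, channel, box) triples.
ref_selected = onnx_style_nms(torch.from_numpy(boxes), torch.from_numpy(scores), 0.5)
print(sorted(map(tuple, ort_selected.tolist())))
print(sorted(map(tuple, ref_selected.tolist())))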
