Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Jul 4, 2024
1 parent 7eec686 commit ab94aea
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 148 deletions.
107 changes: 90 additions & 17 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2926,23 +2926,96 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
/*layout=*/cstNone, /*requires_grad=*/cstFalse);
return success();
});
patterns.onOp("NonMaxSuppression", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
SmallVector<Value> operands;
if (binder.tensorOperandsList(operands) ||
binder.tensorResultType(resultType))
return failure();
patterns.onOp(
"NonMaxSuppression", 10,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
SmallVector<Value> operands;
int64_t centerPointBox;
if (binder.tensorOperandsList(operands) ||
binder.s64IntegerAttr(centerPointBox, "center_point_box", 0) ||
binder.tensorResultType(resultType))
return failure();

// TODO: Add support for handling max_output_boxes_per_class
// and score_threshold arg.
Type scalarTy = rewriter.getType<Torch::FloatType>();
Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), scalarTy, operands[3]);
// TODO: Add support for non-zero center_point_box value.
if (centerPointBox != 0)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected center_point_box "
"attribute value to be 0");

rewriter.replaceOpWithNewOp<Torch::TorchvisionNmsOp>(
binder.op, resultType, /*boxes=*/operands[0],
/*scores=*/operands[1], /*iou_threshold=*/iouThreshold);
return success();
});
// TODO: Add support for optional arguments to be absent.
if (operands.size() != 5)
return rewriter.notifyMatchFailure(
binder.op, "unimplemented: expected all 5 args to be present");

// Squeeze the boxes and scores tensor.
// In Onnx, the shape of boxes is [BxNx4] while the
// torchvision expects it to be of shape [Nx4]. Similarly, for
// the scores tensor shape in Onnx is [BxCxN] while the
// torchvision expects it to be of shape [N].
Value boxes = operands[0], scores = operands[1];
FailureOr<Value> squeezedBoxes = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, boxes);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze boxes tensor");

FailureOr<Value> squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, scores);
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");
squeezedScores = Torch::squeezeTensor(
rewriter, binder.op, binder.getLoc(), 0, squeezedScores.value());
if (failed(squeezedBoxes))
return rewriter.notifyMatchFailure(binder.op,
"failed to squeeze scores tensor");

boxes = squeezedBoxes.value();
scores = squeezedScores.value();

// TODO: Add support for handling max_output_boxes_per_class
// and score_threshold arg.
// If num_boxes (N) > max_output_boxes_per_class then the op can't be
// lowered since the torchvision::nms op doesn't have support for
// handling the max_output_boxes_per_class arg.
Value maxOutputBoxesPerClass = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::IntType>(), operands[2]);
Value cstZero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value numBoxes = rewriter.create<Torch::AtenSizeIntOp>(binder.getLoc(),
boxes, cstZero);
Value boxesCond = rewriter.create<Torch::AtenGtIntOp>(
binder.getLoc(), numBoxes, maxOutputBoxesPerClass);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), boxesCond,
rewriter.getStringAttr("unimplemented: number of boxes should be "
"<= max_output_boxes_per_class"));

// If score_threshold < max(scores) then the op can't be lowered since
// the torchvision::nms op doesn't have support for handling the
// score_threshold arg.
Value scoreThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[4]);
Value maxScores = rewriter.create<Torch::AtenMaxOp>(
binder.getLoc(),
Torch::ValueTensorType::get(binder.op->getContext(), {},
rewriter.getF32Type()),
scores);
maxScores = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), maxScores);

Value scoresCond = rewriter.create<Torch::AtenGtFloatOp>(
binder.getLoc(), maxScores, scoreThreshold);
rewriter.create<Torch::RuntimeAssertOp>(
binder.getLoc(), scoresCond,
rewriter.getStringAttr(
"unimplemented: score_threshold should be >= max(scores)"));

Value iouThreshold = rewriter.create<Torch::AtenItemOp>(
binder.getLoc(), rewriter.getType<Torch::FloatType>(), operands[3]);
rewriter.replaceOpWithNewOp<Torch::TorchvisionNmsOp>(
binder.op, resultType, boxes, scores, iouThreshold);
return success();
});
}
Loading

0 comments on commit ab94aea

Please sign in to comment.