[torch] Add OnnxToTorch lowering for Onnx.ImageDecoder op #3478

Open · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h
@@ -36,6 +36,10 @@ Value createConstantIntList(OpBinder binder,
ConversionPatternRewriter &rewriter,
ArrayRef<int64_t> cstInput);

Value createConstantFloatList(OpBinder binder,
ConversionPatternRewriter &rewriter,
ArrayRef<double> cstInput);

Torch::ValueTensorType getQTorchTypeFromTorchIntType(Type ty);

template <typename T>
77 changes: 77 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -2731,4 +2731,81 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, tensorListResultType, input);
return success();
});
patterns.onOp(
"ImageDecoder", 20,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Value encodedImage;
Torch::ValueTensorType resultType;
std::string pixelFormat;
if (binder.tensorOperand(encodedImage) ||
binder.tensorResultType(resultType) ||
binder.customOpNameStringAttr(pixelFormat, "pixel_format", "RGB"))
return failure();

Value zero = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(0));
Value one = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(1));
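// Torch scalar type codes: 0 selects uint8 (used for the final cast
// below), 6 selects float (used as the dtype of the scaling tensor).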
Value floatType = rewriter.create<Torch::ConstantIntOp>(
binder.getLoc(), rewriter.getI64IntegerAttr(6));
Value none = rewriter.create<Torch::ConstantNoneOp>(binder.getLoc());
Value cstFalse =
rewriter.create<Torch::ConstantBoolOp>(binder.getLoc(), false);

auto encodedImageTy =
cast<Torch::ValueTensorType>(encodedImage.getType());
auto encodedImageShape = encodedImageTy.getSizes();

Value decodedImage;
if (pixelFormat == "BGR") {
// Flip the encoded image tensor along the last dimension.
Value axisToFlip = createConstantIntList(binder, rewriter, {2});
decodedImage = rewriter.create<Torch::AtenFlipOp>(
binder.getLoc(), resultType, encodedImage, axisToFlip);
} else if (pixelFormat == "RGB") {
// Do nothing, as this is already the default mode.
decodedImage = encodedImage;
} else if (pixelFormat == "Grayscale") {
if (encodedImageShape.size() != 3 || encodedImageShape[2] != 3)
return rewriter.notifyMatchFailure(
binder.op, "An input image of shape (H,W,3) is required "
"for pixel_format='Grayscale'");

// This scaling list is created based on ITU-R Rec. 601-7.
// These scaling factors are used by torchvision as well.
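// i.e., Y = 0.2989 * R + 0.5870 * G + 0.1140 * B.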
Value scalingList = createConstantFloatList(binder, rewriter,
{0.2989, 0.5870, 0.1140});
auto scalingListTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>{3}, rewriter.getF64Type());
scalingList = rewriter.create<Torch::AtenTensorOp>(
binder.getLoc(), scalingListTy, scalingList, floatType, none,
cstFalse);

// Unsqueeze the list of scaling factors.
auto unsqueezeResultTy = rewriter.getType<Torch::ValueTensorType>(
ArrayRef<int64_t>{3, 1}, rewriter.getF64Type());
scalingList = rewriter.create<Torch::AtenUnsqueezeOp>(
binder.getLoc(), unsqueezeResultTy, scalingList, one);

// The input encoded image has shape (H,W,3) and the scaling list has
// shape (3,1). A matmul produces a tensor of shape (H,W,1), which,
// after squeezing at dim=2, would be equivalent to unbinding the
// channels of the image, multiplying each channel by its corresponding
// scaling factor, and adding the resulting tensors. We do not squeeze
// the tensor, in order to preserve the resultType.
decodedImage = rewriter.create<Torch::AtenMatmulOp>(
binder.getLoc(), resultType, encodedImage, scalingList);
} else {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported value for pixel_format");
}

// Cast the decoded image to uint8 type
rewriter.replaceOpWithNewOp<Torch::AtenToDtypeOp>(
binder.op, resultType, decodedImage, /*uInt8Type=*/zero,
/*non_blocking=*/cstFalse, /*copy=*/cstFalse,
/*memory_format=*/none);

return success();
});
}
14 changes: 14 additions & 0 deletions lib/Conversion/TorchOnnxToTorch/Utils.cpp
@@ -28,6 +28,20 @@ Value mlir::torch::onnx_c::createConstantIntList(
cstValue);
}

Value mlir::torch::onnx_c::createConstantFloatList(
OpBinder binder, ConversionPatternRewriter &rewriter,
ArrayRef<double> cstInput) {
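// Materialize each double as a torch.constant.float, then collect the
// constants into a single !torch.list<float>.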
SmallVector<Value> cstValue;
for (double i : cstInput) {
cstValue.push_back(rewriter.create<Torch::ConstantFloatOp>(
binder.getLoc(), rewriter.getF64FloatAttr(i)));
}
return rewriter.create<Torch::PrimListConstructOp>(
binder.getLoc(),
Torch::ListType::get(Torch::FloatType::get(binder.op->getContext())),
cstValue);
}

Torch::ValueTensorType
mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) {
Torch::ValueTensorType tty = dyn_cast<Torch::ValueTensorType>(ty);
53 changes: 53 additions & 0 deletions test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -1543,3 +1543,56 @@ func.func @test_optional_get_element_tensor(%arg0: !torch.vtensor<[4],f32>) -> !
%0 = torch.operator "onnx.OptionalGetElement"(%arg0) : (!torch.vtensor<[4],f32>) -> !torch.vtensor<[4],f32>
return %0 : !torch.vtensor<[4],f32>
}

// -----

// CHECK-LABEL: func.func @test_image_decoder_decode_jpeg_bgr
func.func @test_image_decoder_decode_jpeg_bgr(%arg0: !torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,3],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[FLOATTYPE:.*]] = torch.constant.int 6
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[INT2_0:.*]] = torch.constant.int 2
// CHECK: %[[PRIMLIST:.*]] = torch.prim.ListConstruct %[[INT2_0]] : (!torch.int) -> !torch.list<int>
// CHECK: %[[FLIP:.*]] = torch.aten.flip %arg0, %[[PRIMLIST]] : !torch.vtensor<[32,32,3],ui8>, !torch.list<int> -> !torch.vtensor<[32,32,3],ui8>
// CHECK: %[[CAST:.*]] = torch.aten.to.dtype %[[FLIP]], %[[INT0_0]], %[[FALSEVAL]], %[[FALSEVAL]], %[[NONEVAL]] : !torch.vtensor<[32,32,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,32,3],ui8>
// CHECK: return %[[CAST]] : !torch.vtensor<[32,32,3],ui8>
%0 = torch.operator "onnx.ImageDecoder"(%arg0) {torch.onnx.pixel_format = "BGR"} : (!torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,3],ui8>
return %0 : !torch.vtensor<[32,32,3],ui8>
Comment on lines +1560 to +1561
Member:

The implementation assumes, for simplicity, that the image in the respective format has already been loaded and converted to an appropriate tensor representation, and therefore implements different op semantics than the original ONNX definition.

This op takes an encoded stream of bytes (e.g. !torch.vtensor<[1058],ui8>) and decodes it. This PR changes the op to take different inputs (an already decoded image, e.g. !torch.vtensor<[32,32,3],ui8>) and perform a different computation.

Here's an imported test case from the ONNX test suite using similar inputs: https://github.com/nod-ai/SHARK-TestSuite/blob/main/iree_tests/onnx/node/generated/test_image_decoder_decode_jpeg_bgr/model.mlir

module {
  func.func @test_image_decoder_decode_jpeg_bgr(%arg0: !torch.vtensor<[1058],ui8>) -> !torch.vtensor<[32,32,3],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
    %none = torch.constant.none
    %0 = torch.operator "onnx.ImageDecoder"(%arg0) {torch.onnx.pixel_format = "BGR"} : (!torch.vtensor<[1058],ui8>) -> !torch.vtensor<[32,32,3],ui8> 
    return %0 : !torch.vtensor<[32,32,3],ui8>
  }
}

Are there any other cases in torch-mlir where an op definition is changed like this? For this to work at all, input ONNX models and/or the ONNX importer would need to be changed to use this different op. I'm deeply skeptical about checking in code like this that uses the same name as the original op but with an entirely different implementation; that's a recipe for confusion and maintenance costs later on.

Contributor Author:

Hi @ScottTodd, I can totally understand your concern, but I am extremely limited in the number of ways to overcome this issue of loading the image tensor, and I am very open to any tips you might have for this too.

However, all the steps that I follow after taking the input are logically correct and the code in the PR is closely modelled after the onnx reference implementation.

Member:

> and the code in the PR is closely modelled after the onnx reference implementation.

Which reference implementation are you looking at? The one I see is https://github.com/onnx/onnx/blob/main/onnx/reference/ops/op_image_decoder.py and that is calling

img = PIL.Image.open(io.BytesIO(encoded.tobytes()))

that's not something we can hand-wave away - it's a large chunk of code bundled into a complicated library, incompatible with this style of compiler / code generator.

Changing the op definition but using the same name does not count as "supporting" an op. An incorrect implementation is worse than no implementation. We could lower via a custom op somehow to backends that want to use their own implementation, but adding this style of lowering would prevent that.
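
For concreteness, here is a minimal sketch of that decode path (simplified from the linked op_image_decoder.py; the function name and exact conversion calls are illustrative, not the verbatim upstream code):

import io

import numpy as np
import PIL.Image

def reference_image_decoder(encoded: np.ndarray, pixel_format: str = "RGB") -> np.ndarray:
    # Decode the raw compressed byte stream (JPEG, PNG, ...) with PIL.
    img = PIL.Image.open(io.BytesIO(encoded.tobytes()))
    if pixel_format == "Grayscale":
        # PIL's "L" mode applies the Rec. 601 luma weights during conversion.
        decoded = np.array(img.convert("L"))
        return decoded[:, :, np.newaxis]  # (H, W) -> (H, W, 1)
    decoded = np.array(img.convert("RGB"))
    if pixel_format == "BGR":
        decoded = decoded[:, :, ::-1]  # reverse the channel axis
    return decoded

The actual decompression happens inside PIL, before any tensor math; that is exactly the part a lowering of this style cannot express.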

Contributor Author:

> that's not something we can hand-wave away - it's a large chunk of code bundled into a complicated library, incompatible with this style of compiler / code generator.

Exactly! But claiming support for the op appears to be a priority, and the approach in this PR is the only way I can see at the moment that gets anywhere close to that. I have no issue if we decide to close this one as not feasible, as I agree with your points. But as I said, the use of PIL (and hence the large amount of bundled code) is an extremely limiting factor for replicating this through compiler codegen.

So the decision on whether this PR is reasonable is yours.

}

// -----

// CHECK-LABEL: func.func @test_image_decoder_decode_rgb
func.func @test_image_decoder_decode_rgb(%arg0: !torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,3],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[FLOATTYPE:.*]] = torch.constant.int 6
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[CAST:.*]] = torch.aten.to.dtype %arg0, %[[INT0_0]], %[[FALSEVAL]], %[[FALSEVAL]], %[[NONEVAL]] : !torch.vtensor<[32,32,3],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,32,3],ui8>
// CHECK: return %[[CAST]] : !torch.vtensor<[32,32,3],ui8>
%0 = torch.operator "onnx.ImageDecoder"(%arg0) {torch.onnx.pixel_format = "RGB"} : (!torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,3],ui8>
return %0 : !torch.vtensor<[32,32,3],ui8>
}

// -----

// CHECK-LABEL: func.func @test_image_decoder_decode_grayscale
func.func @test_image_decoder_decode_grayscale(%arg0: !torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,1],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
// CHECK: %[[INT0_0:.*]] = torch.constant.int 0
// CHECK: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK: %[[FLOATTYPE:.*]] = torch.constant.int 6
// CHECK: %[[NONEVAL:.*]] = torch.constant.none
// CHECK: %[[FALSEVAL:.*]] = torch.constant.bool false
// CHECK: %[[RSCALE:.*]] = torch.constant.float 2.989000e-01
// CHECK: %[[GSCALE:.*]] = torch.constant.float 5.870000e-01
// CHECK: %[[BSCALE:.*]] = torch.constant.float 1.140000e-01
// CHECK: %[[SCALELIST:.*]] = torch.prim.ListConstruct %[[RSCALE]], %[[GSCALE]], %[[BSCALE]] : (!torch.float, !torch.float, !torch.float) -> !torch.list<float>
// CHECK: %[[SCALETENSOR:.*]] = torch.aten.tensor %[[SCALELIST]], %[[FLOATTYPE]], %[[NONEVAL]], %[[FALSEVAL]] : !torch.list<float>, !torch.int, !torch.none, !torch.bool -> !torch.vtensor<[3],f64>
// CHECK: %[[UNSQUEEZE:.*]] = torch.aten.unsqueeze %[[SCALETENSOR]], %[[INT1_0]] : !torch.vtensor<[3],f64>, !torch.int -> !torch.vtensor<[3,1],f64>
// CHECK: %[[MATMUL:.*]] = torch.aten.matmul %arg0, %[[UNSQUEEZE]] : !torch.vtensor<[32,32,3],ui8>, !torch.vtensor<[3,1],f64> -> !torch.vtensor<[32,32,1],ui8>
// CHECK: %[[CAST:.*]] = torch.aten.to.dtype %[[MATMUL]], %[[INT0_0]], %[[FALSEVAL]], %[[FALSEVAL]], %[[NONEVAL]] : !torch.vtensor<[32,32,1],ui8>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[32,32,1],ui8>
// CHECK: return %[[CAST]] : !torch.vtensor<[32,32,1],ui8>
%0 = torch.operator "onnx.ImageDecoder"(%arg0) {torch.onnx.pixel_format = "Grayscale"} : (!torch.vtensor<[32,32,3],ui8>) -> !torch.vtensor<[32,32,1],ui8>
return %0 : !torch.vtensor<[32,32,1],ui8>
}