Skip to content

Commit

Permalink
[Tosa] : Add support for negative indices in index.tensor and index.T…
Browse files Browse the repository at this point in the history
…ensor_hacked_twin for TorchToTosa lowering. (llvm#3790)

1. Negative indices for tensor indexing is handled by wrapping around
the index values by checking their values at run time. Without the fix,
there was a runtime error.
2. Added a lit test to lock down the behavior.
3. Updated the `xfails_set` for `fx_importer_tosa` config to lockdown
the behavior with e2e test as well.

"THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS
IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED
TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY."
  • Loading branch information
sahas3 authored Oct 25, 2024
1 parent 54d9e24 commit 2b01f8b
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 38 deletions.
83 changes: 48 additions & 35 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4093,6 +4093,25 @@ LogicalResult ConvertAtenOp<AtenIndexPutHackedTwinOp>::matchAndRewrite(
return success();
}

Value wrapNegativeIndices(Value index, int maxIndex, Operation *op,
ConversionPatternRewriter &rewriter) {

auto zeroValue = tosa::getConstTensor<int32_t>(rewriter, op, 0, {}).value();
auto maxIndexValue =
tosa::getConstTensor<int32_t>(rewriter, op, maxIndex, {}).value();

auto indexType = dyn_cast<RankedTensorType>(index.getType());

auto wrappedIndicesOp = tosa::CreateOpAndInfer<tosa::AddOp>(
rewriter, op->getLoc(), indexType, maxIndexValue, index);
auto boolType = indexType.clone(rewriter.getIntegerType(1));
auto isNegativeIndices = tosa::CreateOpAndInfer<tosa::GreaterOp>(
rewriter, op->getLoc(), boolType, zeroValue, index);
return tosa::CreateOpAndInfer<tosa::SelectOp>(rewriter, op->getLoc(),
indexType, isNegativeIndices,
wrappedIndicesOp, index);
}

template <>
LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
AtenIndexTensorHackedTwinOp op, OpAdaptor adaptor,
Expand Down Expand Up @@ -4124,6 +4143,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(

auto outType = getTypeConverter()->convertType(op.getType());

Operation *indicesTf;

// Support for multiple indexes
if (indexTensors.size() > 1) {
// t[i, i]
Expand Down Expand Up @@ -4157,6 +4178,8 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
index);
}

index = wrapNegativeIndices(index, inputTensorType.getShape()[i], op,
rewriter);
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indiceShapeOneDim;
for (auto shape : indexShape) {
Expand Down Expand Up @@ -4299,57 +4322,47 @@ LogicalResult ConvertAtenOp<AtenIndexTensorHackedTwinOp>::matchAndRewrite(
auto indicesShapeConcat = indexesShape[0];
uint64_t lastDim = indexesRank[0];
indicesShapeConcat.push_back(indicesTfConcatTensors.size());
auto indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
indicesTf = tosa::CreateOpAndInfer<tosa::ConcatOp>(
rewriter, op->getLoc(),
GetTypeFromTensorShape(indicesShapeConcat, rewriter.getIntegerType(32)),
indicesTfConcatTensors, lastDim);

if (!indicesTf) {
return rewriter.notifyMatchFailure(
op, "Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());
} else {

if (!result) {
return rewriter.notifyMatchFailure(
op, "Convert GatherNdOp fail for index tensor.");
// Single index
auto index = indexTensors[0];
auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape();
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)),
index);
}
rewriter.replaceOp(op, {result.value()});

return success();
}
index =
wrapNegativeIndices(index, inputTensorType.getShape()[0], op, rewriter);

// Support for multiple index
auto index = indexTensors[0];
auto indexType = dyn_cast<RankedTensorType>(index.getType());
auto indexShape = indexType.getShape();
// index i64 to i32 for tosa compatible
if (indexType.getElementType() != rewriter.getIntegerType(32)) {
index = rewriter.create<tosa::CastOp>(
op->getLoc(),
RankedTensorType::get(indexShape, rewriter.getIntegerType(32)), index);
}

// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {
indicesShape.push_back(shape);
// Expand last dim of index to tf indices [2,3] -> [2,3,1]
SmallVector<int64_t> indicesShape;
for (auto shape : indexShape) {
indicesShape.push_back(shape);
}
indicesShape.push_back(1);
indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
rewriter.getDenseI64ArrayAttr(indicesShape));
}
indicesShape.push_back(1);
auto indicesTf = tosa::CreateOpAndInfer<tosa::ReshapeOp>(
rewriter, op->getLoc(),
RankedTensorType::get(indicesShape, rewriter.getIntegerType(32)), index,
rewriter.getDenseI64ArrayAttr(indicesShape));

if (!indicesTf) {
return rewriter.notifyMatchFailure(op,
"Convert TorchIndex To TfIndices fail.");
}
// do the tf gathernp algorithm with tf style indices as input.
auto result = tosa::convertGatherNdOp(rewriter, op, outType, input,
indicesTf.getResult());
indicesTf->getResult(0));

if (!result) {
return rewriter.notifyMatchFailure(
Expand Down
4 changes: 1 addition & 3 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1698,15 +1698,13 @@
"ArangeStartOutModule_basic",
"ScatterSrcStaticModule_basic",
# Runtime op verification: Out of bounds access
"IndexTensorNegativeIndexModule_basic",
"ReduceAllDimEmpty_basic",
}

FX_IMPORTER_TOSA_CRASHING_SET = {
"ScatterSrcModule_basic",
"ScatterSrcStaticModule_basic",
"HBC_basic",
"IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_scales_recompute_bilinear",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
Expand Down Expand Up @@ -2162,6 +2160,7 @@
"HardswishRandomModule_basic",
"HardtanhBackward_basic",
"IndexTensorMultiIndexStaticModule_basic",
"IndexTensorNegativeIndexModule_basic",
"IndexTensorStaticModule_basic",
"IscloseStaticModuleTrue_basic",
"IscloseStaticModule_basic",
Expand Down Expand Up @@ -3635,7 +3634,6 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexPutImplIndexWithNoneModule_basic",
"IndexSelectRank0IdxModule_basic",
"IndexTensorNegativeIndexModule_basic",
"InterpolateDynamicModule_sizes_bilinear",
"InterpolateDynamicModule_sizes_nearest",
"InterpolateStaticModule_scales_bilinear_align_corners",
Expand Down
32 changes: 32 additions & 0 deletions test/Conversion/TorchToTosa/basic.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2131,3 +2131,35 @@ func.func @torch.aten.diag_embed$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !t
%0 = torch.aten.diag_embed %arg0, %int0, %int-2, %int-1 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,4],f32>
return %0 : !torch.vtensor<[2,3,4,4],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.index.Tensor_hacked_twin(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,4,2],si64>,
// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,4,2],si64> -> tensor<2x4x2xi64>
// CHECK: %[[VAL_1:.*]] = torch.prim.ListConstruct %[[ARG1]] : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[],si64> -> tensor<i64>
// CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<i64>) -> tensor<i32>
// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2> : tensor<i32>}> : () -> tensor<i32>
// CHECK: %[[VAL_6:.*]] = tosa.add %[[VAL_5]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[VAL_7:.*]] = tosa.greater %[[VAL_4]], %[[VAL_3]] : (tensor<i32>, tensor<i32>) -> tensor<i1>
// CHECK: %[[VAL_8:.*]] = tosa.select %[[VAL_7]], %[[VAL_6]], %[[VAL_3]] : (tensor<i1>, tensor<i32>, tensor<i32>) -> tensor<i32>
// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]] {new_shape = array<i64: 1>} : (tensor<i32>) -> tensor<1xi32>
// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_0]] {new_shape = array<i64: 1, 2, 8>} : (tensor<2x4x2xi64>) -> tensor<1x2x8xi64>
// CHECK: %[[VAL_11:.*]] = tosa.reshape %[[VAL_9]] {new_shape = array<i64: 1, 1>} : (tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32>
// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_11]], %[[VAL_12]] {shift = 0 : i8} : (tensor<1x1xi32>, tensor<1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_14:.*]] = tosa.reduce_sum %[[VAL_13]] {axis = 1 : i32} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_15:.*]] = tosa.reshape %[[VAL_14]] {new_shape = array<i64: 1, 1>} : (tensor<1x1xi32>) -> tensor<1x1xi32>
// CHECK: %[[VAL_16:.*]] = tosa.gather %[[VAL_10]], %[[VAL_15]] : (tensor<1x2x8xi64>, tensor<1x1xi32>) -> tensor<1x1x8xi64>
// CHECK: %[[VAL_17:.*]] = tosa.reshape %[[VAL_16]] {new_shape = array<i64: 4, 2>} : (tensor<1x1x8xi64>) -> tensor<4x2xi64>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[VAL_17]] : tensor<4x2xi64> -> !torch.vtensor<[4,2],si64>
// CHECK: return %[[RESULT]] : !torch.vtensor<[4,2],si64>

func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si64>, %arg1: !torch.vtensor<[],si64>) -> !torch.vtensor<[4,2],si64> {
%0 = torch.prim.ListConstruct %arg1 : (!torch.vtensor<[],si64>) -> !torch.list<vtensor>
%1 = torch.aten.index.Tensor_hacked_twin %arg0, %0 : !torch.vtensor<[2,4,2],si64>, !torch.list<vtensor> -> !torch.vtensor<[4,2],si64>
return %1 : !torch.vtensor<[4,2],si64>
}

0 comments on commit 2b01f8b

Please sign in to comment.