-
Notifications
You must be signed in to change notification settings - Fork 47
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add
--convert-tensor-to-scalars
pass
- Loading branch information
1 parent
b93157c
commit f077a30
Showing
6 changed files
with
244 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "TensorToScalars", | ||
srcs = ["TensorToScalars.cpp"], | ||
hdrs = ["TensorToScalars.h"], | ||
deps = [ | ||
":pass_inc_gen", | ||
"@heir//lib/Conversion:Utils", | ||
"@llvm-project//mlir:IR", | ||
"@llvm-project//mlir:Pass", | ||
"@llvm-project//mlir:Support", | ||
"@llvm-project//mlir:TensorDialect", | ||
"@llvm-project//mlir:Transforms", | ||
], | ||
) | ||
|
||
gentbl_cc_library( | ||
name = "pass_inc_gen", | ||
tbl_outs = [ | ||
( | ||
[ | ||
"-gen-pass-decls", | ||
"-name=TensorToScalars", | ||
], | ||
"TensorToScalars.h.inc", | ||
), | ||
( | ||
["-gen-pass-doc"], | ||
"TensorToScalarsPasses.md", | ||
), | ||
], | ||
tblgen = "@llvm-project//mlir:mlir-tblgen", | ||
td_file = "TensorToScalars.td", | ||
deps = [ | ||
"@llvm-project//mlir:OpBaseTdFiles", | ||
"@llvm-project//mlir:PassBaseTdFiles", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
#include "lib/Transforms/TensorToScalars/TensorToScalars.h" | ||
|
||
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h" // from @llvm-project | ||
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project | ||
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project | ||
#include "mlir/include/mlir/Transforms/OneToNTypeConversion.h" // from @llvm-project | ||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DEF_TENSORTOSCALARS | ||
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc" | ||
|
||
static std::optional<Value> buildFromElementsOp(OpBuilder &builder, | ||
RankedTensorType resultType, | ||
ValueRange inputs, | ||
Location loc) { | ||
return builder.create<tensor::FromElementsOp>(loc, resultType, inputs); | ||
} | ||
|
||
static std::optional<SmallVector<Value>> buildExtractOps(OpBuilder &builder, | ||
TypeRange resultTypes, | ||
Value input, | ||
Location loc) { | ||
// This conversion only operates on tensors of static shape | ||
RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType()); | ||
if (!inputType || !inputType.hasStaticShape()) return {}; | ||
|
||
// Create extract ops in "natural" order (dimension-by-dimension) | ||
SmallVector<Value> values; | ||
for (auto dim : inputType.getShape()) { | ||
for (int i = 0; i < dim; ++i) { | ||
Value index = builder.create<arith::ConstantIndexOp>(loc, i); | ||
Value element = builder.create<tensor::ExtractOp>(loc, input, index); | ||
values.push_back(element); | ||
} | ||
} | ||
return values; | ||
} | ||
|
||
class ConvertFromElementsOp | ||
: public OneToNOpConversionPattern<tensor::FromElementsOp> { | ||
public: | ||
using OneToNOpConversionPattern< | ||
tensor::FromElementsOp>::OneToNOpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
tensor::FromElementsOp op, OpAdaptor adaptor, | ||
OneToNPatternRewriter &rewriter) const override { | ||
// This conversion only operates on tensors of static shape, | ||
// but no check is necessary here as from_elements' shape is always static | ||
|
||
// Replace the current op with the flattened operands. | ||
// This should already match the "natural" order expected by this pass. | ||
rewriter.replaceOp(op, adaptor.getFlatOperands(), | ||
adaptor.getResultMapping()); | ||
return success(); | ||
} | ||
}; | ||
|
||
class ConvertInsertOp : public OneToNOpConversionPattern<tensor::InsertOp> { | ||
public: | ||
using OneToNOpConversionPattern<tensor::InsertOp>::OneToNOpConversionPattern; | ||
|
||
LogicalResult matchAndRewrite( | ||
tensor::InsertOp op, OpAdaptor adaptor, | ||
OneToNPatternRewriter &rewriter) const override { | ||
// This conversion only operates on tensors of static shape | ||
if (!op.getResult().getType().hasStaticShape()) return failure(); | ||
|
||
// We can only support statically known indices | ||
// that have been constant-folded to a single arith.constant op | ||
for (auto idx : op.getIndices()) { | ||
if (!llvm::isa<arith::ConstantIndexOp>(idx.getDefiningOp())) | ||
return failure(); | ||
} | ||
|
||
// Compute the insertion offset (in dimension-by-dimension order): | ||
int64_t multiplier = 1; | ||
int64_t offset = 0; | ||
for (auto [dim, idx] : | ||
llvm::zip(op.getResult().getType().getShape(), op.getIndices())) { | ||
offset += | ||
idx.getDefiningOp<arith::ConstantIndexOp>().value() * multiplier; | ||
multiplier *= dim; | ||
} | ||
|
||
// get converted "tensor" operand from op (likely a unrealized_builtin_cast) | ||
SmallVector<Value> elements = adaptor.getOperands()[1]; | ||
// replace element at offset with the "scalar" operand to be inserted | ||
elements[offset] = adaptor.getOperands()[0].front(); | ||
// replace the current op with the converted operands. | ||
rewriter.replaceOp(op, elements, adaptor.getResultMapping()); | ||
return success(); | ||
} | ||
}; | ||
|
||
struct TensorToScalars : impl::TensorToScalarsBase<TensorToScalars> { | ||
using TensorToScalarsBase::TensorToScalarsBase; | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
|
||
OneToNTypeConverter typeConverter; | ||
typeConverter.addConversion([](Type type) { return type; }); | ||
typeConverter.addConversion( | ||
[](RankedTensorType tensorType, | ||
SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> { | ||
if (!tensorType.hasStaticShape()) return failure(); | ||
int count = 1; | ||
for (auto dim : tensorType.getShape()) { | ||
if (dim > 0) count *= dim; | ||
} | ||
types = SmallVector<Type>(count, tensorType.getElementType()); | ||
return success(); | ||
}); | ||
typeConverter.addArgumentMaterialization(buildFromElementsOp); | ||
typeConverter.addSourceMaterialization(buildFromElementsOp); | ||
typeConverter.addTargetMaterialization(buildExtractOps); | ||
|
||
RewritePatternSet patterns(context); | ||
patterns.add<ConvertFromElementsOp, ConvertInsertOp>(typeConverter, | ||
context); | ||
scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns); | ||
populateFuncTypeConversionPatterns(typeConverter, patterns); | ||
|
||
if (mlir::failed(mlir::applyPartialOneToNConversion( | ||
getOperation(), typeConverter, std::move(patterns)))) | ||
signalPassFailure(); | ||
|
||
// Empty PatternSet = only run folders (should never fail) | ||
RewritePatternSet emptyPatterns(context); | ||
(void)applyPatternsAndFoldGreedily(getOperation(), | ||
std::move(emptyPatterns)); | ||
} | ||
}; | ||
|
||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
#ifndef LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_ | ||
#define LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_ | ||
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project | ||
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
|
||
#define GEN_PASS_DECL | ||
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc" | ||
|
||
#define GEN_PASS_REGISTRATION | ||
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc" | ||
|
||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#ifndef LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_ | ||
#define LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_ | ||
|
||
include "mlir/Pass/PassBase.td" | ||
|
||
|
||
def TensorToScalars : Pass<"convert-tensor-to-scalars"> { | ||
let summary = "Effectively 'unrolls' tensors of static shape to scalars"; | ||
let description = [{ | ||
This pass will convert a static-shaped tensor type to a TypeRange | ||
containing product(dim) copies of the element type of the tensor. | ||
This pass currently includes two patterns: | ||
1. It converts tensor.from_elements operations to | ||
the corresponding scalar inputs. | ||
2. It converts tensor.insert operations by updating the | ||
ValueRange corresponding to the converted input and r | ||
updating it with the scalar to be inserted. | ||
It also applies folders greedily to simplify, e.g., extract(from_elements). | ||
|
||
Note: The pass is designed to be run on an IR, where the only operations | ||
with tensor typed operands are tensor "management" operations such as insert/extract, | ||
with all other operations (e.g., arith operations) already taking (extracted) scalar inputs. | ||
For example, an IR where elementwise operations have been converted to scalar operations via | ||
`--convert-elementwise-to-affine`. | ||
|
||
The pass might insert new tensor.from_elements operations or manually create the scalar ValueRange | ||
via inserting tensor.extract operations if any operations remain that operate on tensors. | ||
The pass currently applies irrespective of tensor size, i.e., might be very slow for large tensors. | ||
}]; | ||
let dependentDialects = [ | ||
"tensor::TensorDialect" | ||
]; | ||
} | ||
|
||
#endif // LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters