Skip to content

Commit

Permalink
add --convert-tensor-to-scalars pass
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderViand-Intel committed Jul 5, 2024
1 parent b09eccb commit c4d5c5f
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 0 deletions.
44 changes: 44 additions & 0 deletions lib/Transforms/TensorToScalars/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

cc_library(
name = "TensorToScalars",
srcs = ["TensorToScalars.cpp"],
hdrs = ["TensorToScalars.h"],
deps = [
":pass_inc_gen",
"@heir//lib/Conversion:Utils",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TensorDialect",
"@llvm-project//mlir:Transforms",
],
)

gentbl_cc_library(
name = "pass_inc_gen",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=TensorToScalars",
],
"TensorToScalars.h.inc",
),
(
["-gen-pass-doc"],
"TensorToScalarsPasses.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "TensorToScalars.td",
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:PassBaseTdFiles",
],
)
144 changes: 144 additions & 0 deletions lib/Transforms/TensorToScalars/TensorToScalars.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
#include "lib/Transforms/TensorToScalars/TensorToScalars.h"

#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/FuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/SCF/Transforms/Patterns.h" // from @llvm-project
#include "mlir/include/mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/OneToNTypeConversion.h" // from @llvm-project
namespace mlir {
namespace heir {

#define GEN_PASS_DEF_TENSORTOSCALARS
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc"

static std::optional<Value> buildFromElementsOp(OpBuilder &builder,
RankedTensorType resultType,
ValueRange inputs,
Location loc) {
return builder.create<tensor::FromElementsOp>(loc, resultType, inputs);
}

static std::optional<SmallVector<Value>> buildExtractOps(OpBuilder &builder,
TypeRange resultTypes,
Value input,
Location loc) {
// This conversion only operates on tensors of static shape
RankedTensorType inputType = dyn_cast<RankedTensorType>(input.getType());
if (!inputType || !inputType.hasStaticShape()) return {};

// Create extract ops in "natural" order (dimension-by-dimension)
SmallVector<Value> values;
for (auto dim : inputType.getShape()) {
for (int i = 0; i < dim; ++i) {
Value index = builder.create<arith::ConstantIndexOp>(loc, i);
Value element = builder.create<tensor::ExtractOp>(loc, input, index);
values.push_back(element);
}
}
return values;
}

class ConvertFromElementsOp
: public OneToNOpConversionPattern<tensor::FromElementsOp> {
public:
using OneToNOpConversionPattern<
tensor::FromElementsOp>::OneToNOpConversionPattern;

LogicalResult matchAndRewrite(
tensor::FromElementsOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
// This conversion only operates on tensors of static shape,
// but no check is necessary here as from_elements' shape is always static

// Replace the current op with the flattened operands.
// This should already match the "natural" order expected by this pass.
rewriter.replaceOp(op, adaptor.getFlatOperands(),
adaptor.getResultMapping());
return success();
}
};

class ConvertInsertOp : public OneToNOpConversionPattern<tensor::InsertOp> {
public:
using OneToNOpConversionPattern<tensor::InsertOp>::OneToNOpConversionPattern;

LogicalResult matchAndRewrite(
tensor::InsertOp op, OpAdaptor adaptor,
OneToNPatternRewriter &rewriter) const override {
// This conversion only operates on tensors of static shape
if (!op.getResult().getType().hasStaticShape()) return failure();

// We can only support statically known indices
// that have been constant-folded to a single arith.constant op
for (auto idx : op.getIndices()) {
if (!llvm::isa<arith::ConstantIndexOp>(idx.getDefiningOp()))
return failure();
}

// Compute the insertion offset (in dimension-by-dimension order):
int64_t multiplier = 1;
int64_t offset = 0;
for (auto [dim, idx] :
llvm::zip(op.getResult().getType().getShape(), op.getIndices())) {
offset +=
idx.getDefiningOp<arith::ConstantIndexOp>().value() * multiplier;
multiplier *= dim;
}

// get converted "tensor" operand from op (likely a unrealized_builtin_cast)
SmallVector<Value> elements = adaptor.getOperands()[1];
// replace element at offset with the "scalar" operand to be inserted
elements[offset] = adaptor.getOperands()[0].front();
// replace the current op with the converted operands.
rewriter.replaceOp(op, elements, adaptor.getResultMapping());
return success();
}
};

struct TensorToScalars : impl::TensorToScalarsBase<TensorToScalars> {
using TensorToScalarsBase::TensorToScalarsBase;

void runOnOperation() override {
MLIRContext *context = &getContext();

OneToNTypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
typeConverter.addConversion(
[](RankedTensorType tensorType,
SmallVectorImpl<Type> &types) -> std::optional<LogicalResult> {
if (!tensorType.hasStaticShape()) return failure();
int count = 1;
for (auto dim : tensorType.getShape()) {
if (dim > 0) count *= dim;
}
types = SmallVector<Type>(count, tensorType.getElementType());
return success();
});
typeConverter.addArgumentMaterialization(buildFromElementsOp);
typeConverter.addSourceMaterialization(buildFromElementsOp);
typeConverter.addTargetMaterialization(buildExtractOps);

RewritePatternSet patterns(context);
patterns.add<ConvertFromElementsOp, ConvertInsertOp>(typeConverter,
context);
scf::populateSCFStructuralOneToNTypeConversions(typeConverter, patterns);
populateFuncTypeConversionPatterns(typeConverter, patterns);

if (mlir::failed(mlir::applyPartialOneToNConversion(
getOperation(), typeConverter, std::move(patterns))))
signalPassFailure();

// Empty PatternSet = only run folders (should never fail)
RewritePatternSet emptyPatterns(context);
(void)applyPatternsAndFoldGreedily(getOperation(),
std::move(emptyPatterns));
}
};

} // namespace heir
} // namespace mlir
18 changes: 18 additions & 0 deletions lib/Transforms/TensorToScalars/TensorToScalars.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#ifndef LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_
#define LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_
#include "mlir/include/mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {

#define GEN_PASS_DECL
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc"

#define GEN_PASS_REGISTRATION
#include "lib/Transforms/TensorToScalars/TensorToScalars.h.inc"

} // namespace heir
} // namespace mlir

#endif // LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_H_
37 changes: 37 additions & 0 deletions lib/Transforms/TensorToScalars/TensorToScalars.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#ifndef LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_
#define LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_

include "mlir/Pass/PassBase.td"


def TensorToScalars : Pass<"convert-tensor-to-scalars"> {
let summary = "Effectively 'unrolls' tensors of static shape to scalars";
let description = [{
This pass will convert a static-shaped tensor type to a TypeRange
containing product(dim) copies of the element type of the tensor.
This pass currently includes two patterns:

1. It converts tensor.from_elements operations to
the corresponding scalar inputs.
2. It converts tensor.insert operations by updating the
ValueRange corresponding to the converted input and r
updating it with the scalar to be inserted.

It also applies folders greedily to simplify, e.g., extract(from_elements).

Note: The pass is designed to be run on an IR, where the only operations
with tensor typed operands are tensor "management" operations such as insert/extract,
with all other operations (e.g., arith operations) already taking (extracted) scalar inputs.
For example, an IR where elementwise operations have been converted to scalar operations via
`--convert-elementwise-to-affine`.

The pass might insert new tensor.from_elements operations or manually create the scalar ValueRange
via inserting tensor.extract operations if any operations remain that operate on tensors.
The pass currently applies irrespective of tensor size, i.e., might be very slow for large tensors.
}];
let dependentDialects = [
"tensor::TensorDialect"
];
}

#endif // LIB_TRANSFORMS_TENSORTOSCALARS_TENSORTOSCALARS_TD_
1 change: 1 addition & 0 deletions tools/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ cc_binary(
"@heir//lib/Transforms/FullLoopUnroll",
"@heir//lib/Transforms/Secretize",
"@heir//lib/Transforms/StraightLineVectorizer",
"@heir//lib/Transforms/TensorToScalars",
"@heir//lib/Transforms/UnusedMemRef",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:AffineDialect",
Expand Down
2 changes: 2 additions & 0 deletions tools/heir-opt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
#include "lib/Transforms/FullLoopUnroll/FullLoopUnroll.h"
#include "lib/Transforms/Secretize/Passes.h"
#include "lib/Transforms/StraightLineVectorizer/StraightLineVectorizer.h"
#include "lib/Transforms/TensorToScalars/TensorToScalars.h"
#include "lib/Transforms/UnusedMemRef/UnusedMemRef.h"
#include "llvm/include/llvm/Support/CommandLine.h" // from @llvm-project
#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
Expand Down Expand Up @@ -522,6 +523,7 @@ int main(int argc, char **argv) {
registerForwardStoreToLoadPasses();
registerStraightLineVectorizerPasses();
registerUnusedMemRefPasses();
registerTensorToScalarsPasses();
// Register yosys optimizer pipeline if configured.
#ifndef HEIR_NO_YOSYS
#ifndef HEIR_ABC_BINARY
Expand Down

0 comments on commit c4d5c5f

Please sign in to comment.