Skip to content

Commit

Permalink
Fix stream.tensor.constant for complex<f32> crash (iree-org#14120)
Browse files Browse the repository at this point in the history
Getting bitwidth crashes when splatting constant<f32> values.
  • Loading branch information
rsuderman authored and nhasabni committed Aug 24, 2023
1 parent e24148b commit 78cc39c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 3 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ iree_compiler_cc_library(
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:ArithUtils",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:InferTypeOpInterface",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ iree_cc_library(
LLVMSupport
MLIRArithDialect
MLIRArithUtils
MLIRComplexDialect
MLIRFuncDialect
MLIRIR
MLIRInferTypeOpInterface
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ def Stream_Dialect : Dialect {
}];

let dependentDialects = [
"mlir::complex::ComplexDialect",
"IREE::Util::UtilDialect",
];

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "iree/compiler/Dialect/Util/IR/UtilDialect.h"
#include "llvm/Support/SourceMgr.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Attributes.h"
Expand Down Expand Up @@ -96,6 +97,7 @@ struct StripResourceConversionCastPattern
StreamDialect::StreamDialect(MLIRContext *context)
: Dialect(getDialectNamespace(), context, TypeID::get<StreamDialect>()) {
context->loadDialect<IREE::Util::UtilDialect>();
context->loadDialect<mlir::complex::ComplexDialect>();

registerAttributes();
registerTypes();
Expand Down
16 changes: 13 additions & 3 deletions compiler/src/iree/compiler/Dialect/Stream/IR/StreamOpFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
Expand Down Expand Up @@ -1081,9 +1082,18 @@ struct TensorConstantToSplat : public OpRewritePattern<TensorConstantOp> {
"only constant splat attrs can be converted to splat ops");
}

auto splatElementAttr = splatAttr.getSplatValue<TypedAttr>();
auto splatValue = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
Value splatValue;
if (isa<ComplexType>(getElementTypeOrSelf(splatAttr.getType()))) {
auto splatElementAttr = splatAttr.getSplatValue<ArrayAttr>();
splatValue = rewriter.create<complex::ConstantOp>(
constantOp.getLoc(), getElementTypeOrSelf(splatAttr.getType()),
cast<ArrayAttr>(splatElementAttr));
} else {
auto splatElementAttr = splatAttr.getSplatValue<TypedAttr>();
splatValue = rewriter.create<arith::ConstantOp>(
constantOp.getLoc(), splatElementAttr.getType(), splatElementAttr);
}

auto resultType = IREE::Stream::ResourceType::get(constantOp.getContext());
auto resultSize = rewriter.createOrFold<IREE::Stream::TensorSizeOfOp>(
constantOp.getLoc(), rewriter.getIndexType(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ func.func @TensorConstantToSplat() -> !stream.resource<constant> {

// -----

// CHECK-LABEL: @TensorComplexConstantToSplat
func.func @TensorComplexConstantToSplat() -> !stream.resource<constant> {
// CHECK-DAG: %[[CST:.+]] = complex.constant [2.000000e+00 : f32, 3.000000e+00 : f32] : complex<f32>
// CHECK-DAG: %[[SIZE:.+]] = stream.tensor.sizeof tensor<2x2xcomplex<f32>> : index
// CHECK: = stream.tensor.splat %[[CST]] : complex<f32> -> tensor<2x2xcomplex<f32>> in !stream.resource<*>{%[[SIZE]]}
%cst = stream.tensor.constant : tensor<2x2xcomplex<f32>> in !stream.resource<constant> = dense<(2.000000e+00,3.000000e+00)> : tensor<2x2xcomplex<f32>>
return %cst : !stream.resource<constant>
}

// -----

// CHECK-LABEL: @NarrowSplatPatternI32ToI8
func.func @NarrowSplatPatternI32ToI8() -> !stream.resource<*> {
%c100 = arith.constant 100 : index
Expand Down

0 comments on commit 78cc39c

Please sign in to comment.