Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: More tosa-to-boolean-tfhe fixes #748

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions lib/Conversion/CombToCGGI/CombToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/include/mlir/IR/OpDefinition.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeRange.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Types.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
Expand Down Expand Up @@ -65,21 +66,18 @@ Value buildSelectTruthTable(Location loc, OpBuilder &b, Value t, Value f,
selectFalse);
}

// equivalentMultiBitAndMemRefchecks whether the candidateMultiBit integer type
// is equivalent to the candidateMemRef type.
// Return true if the candidateMemRef is a memref of single bits with
// size equal to the number of bits of the candidateMultiBit.
bool equivalentMultiBitAndMemRef(Type candidateMultiBit, Type candidateMemRef) {
if (auto multiBitTy = dyn_cast<IntegerType>(candidateMultiBit)) {
if (auto memrefTy = dyn_cast<MemRefType>(candidateMemRef)) {
auto eltTy = dyn_cast<IntegerType>(memrefTy.getElementType());
if (eltTy && multiBitTy.getWidth() ==
memrefTy.getNumElements() * eltTy.getWidth()) {
return true;
}
}
// equivalentMultiBitAndMemRef checks if a the types hold the same number of
// bits.
bool equivalentMultiBitAndMemRef(Type lhsType, Type rhsType) {
int lhsBits = getElementTypeOrSelf(lhsType).getIntOrFloatBitWidth();
int rhsBits = getElementTypeOrSelf(rhsType).getIntOrFloatBitWidth();
if (auto lhsTensorTy = dyn_cast<ShapedType>(lhsType)) {
lhsBits *= lhsTensorTy.getNumElements();
}
if (auto rhsTensorTy = dyn_cast<ShapedType>(rhsType)) {
rhsBits *= rhsTensorTy.getNumElements();
}
return false;
return lhsBits == rhsBits;
}

Operation *convertWriteOpInterface(Operation *op, SmallVector<Value> indices,
Expand Down Expand Up @@ -533,8 +531,7 @@ struct ConvertSecretCastOp : public OpConversionPattern<secret::CastOp> {
auto outputTy =
cast<secret::SecretType>(op.getOutput().getType()).getValueType();

if (equivalentMultiBitAndMemRef(inputTy, outputTy) ||
equivalentMultiBitAndMemRef(outputTy, inputTy)) {
if (equivalentMultiBitAndMemRef(inputTy, outputTy)) {
rewriter.replaceOp(op, adaptor.getInput());
return success();
}
Expand Down
35 changes: 35 additions & 0 deletions lib/Dialect/Secret/IR/SecretPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,41 @@ void genericAbsorbConstants(secret::GenericOp genericOp,
});
}

void genericAbsorbDealloc(secret::GenericOp genericOp,
mlir::IRRewriter &rewriter) {
// Check the generic's returned memrefs. If their only use outside the generic
// is a dealloc, then move the dealloc inside the generic body.
for (auto result : genericOp.getResults()) {
if (auto memrefTy = dyn_cast<MemRefType>(
cast<secret::SecretType>(result.getType()).getValueType())) {
if (!result.hasOneUse()) {
continue;
}
// Ensure that the single user is a secret.generic.
auto &memrefUse = *result.getUses().begin();
auto genericUseOp = dyn_cast<secret::GenericOp>(memrefUse.getOwner());
if (!genericUseOp) {
continue;
}
auto blockArg =
genericUseOp.getBody()->getArgument(memrefUse.getOperandNumber());
auto blockArgUser = *blockArg.getUsers().begin();
if (!blockArg.hasOneUse() || !isa<memref::DeallocOp>(blockArgUser)) {
continue;
}
LLVM_DEBUG(llvm::dbgs()
<< "Found dealloc op to absorb into generic:" << blockArgUser);
rewriter.setInsertionPoint(genericOp.getYieldOp());
IRMapping mp;
mp.map(blockArg,
genericOp.getYieldOp()->getOperand(result.getResultNumber()));
rewriter.clone(*blockArgUser, mp);
SmallVector<Value> remainingResults;
rewriter.eraseOp(blockArgUser);
}
}
}

LogicalResult extractGenericBody(secret::GenericOp genericOp,
mlir::IRRewriter &rewriter) {
auto module = genericOp->getParentOfType<ModuleOp>();
Expand Down
4 changes: 4 additions & 0 deletions lib/Dialect/Secret/IR/SecretPatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,10 @@ struct HoistPlaintextOps : public OpRewritePattern<GenericOp> {
void genericAbsorbConstants(secret::GenericOp genericOp,
mlir::IRRewriter &rewriter);

// Absorbs any memref deallocations into the generic body.
void genericAbsorbDealloc(secret::GenericOp genericOp,
mlir::IRRewriter &rewriter);

// Extract the body of a secret.generic into a function and replace the generic
// body with a call to the created function.
LogicalResult extractGenericBody(secret::GenericOp genericOp,
Expand Down
15 changes: 15 additions & 0 deletions lib/Dialect/Secret/Transforms/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ cc_library(
":ExtractGenericBody",
":ForgetSecrets",
":GenericAbsorbConstants",
":GenericAbsorbDealloc",
":MergeAdjacentGenerics",
":pass_inc_gen",
"@heir//lib/Dialect/Secret/IR:Dialect",
Expand Down Expand Up @@ -89,6 +90,20 @@ cc_library(
],
)

cc_library(
name = "GenericAbsorbDealloc",
srcs = ["GenericAbsorbDealloc.cpp"],
hdrs = [
"GenericAbsorbDealloc.h",
],
deps = [
":pass_inc_gen",
"@heir//lib/Dialect/Secret/IR:SecretPatterns",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
],
)

cc_library(
name = "MergeAdjacentGenerics",
srcs = ["MergeAdjacentGenerics.cpp"],
Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Secret/Transforms/GenericAbsorbConstants.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#ifndef HEIR_LIB_DIALECT_SECRET_TRANSFORMS_CONSTANTPASSTHROUGHGENERIC_H_
#define HEIR_LIB_DIALECT_SECRET_TRANSFORMS_CONSTANTPASSTHROUGHGENERIC_H_
#ifndef HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBCONSTANTS_H_
#define HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBCONSTANTS_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

Expand All @@ -14,4 +14,4 @@ namespace secret {
} // namespace heir
} // namespace mlir

#endif // HEIR_LIB_DIALECT_SECRET_TRANSFORMS_CONSTANTPASSTHROUGHGENERIC_H_
#endif // HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBCONSTANTS_H_
31 changes: 31 additions & 0 deletions lib/Dialect/Secret/Transforms/GenericAbsorbDealloc.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#include "lib/Dialect/Secret/Transforms/GenericAbsorbDealloc.h"

#include "lib/Dialect/Secret/IR/SecretOps.h"
#include "lib/Dialect/Secret/IR/SecretPatterns.h"
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace secret {

#define GEN_PASS_DEF_SECRETGENERICABSORBDEALLOC
#include "lib/Dialect/Secret/Transforms/Passes.h.inc"

struct GenericAbsorbDealloc
: impl::SecretGenericAbsorbDeallocBase<GenericAbsorbDealloc> {
using SecretGenericAbsorbDeallocBase::SecretGenericAbsorbDeallocBase;

void runOnOperation() override {
mlir::IRRewriter builder(&getContext());

getOperation()->walk([&](secret::GenericOp op) {
genericAbsorbDealloc(op, builder);
return WalkResult::advance();
});
}
};

} // namespace secret
} // namespace heir
} // namespace mlir
17 changes: 17 additions & 0 deletions lib/Dialect/Secret/Transforms/GenericAbsorbDealloc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#ifndef HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBDEALLOC_H_
#define HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBDEALLOC_H_

#include "mlir/include/mlir/Pass/Pass.h" // from @llvm-project

namespace mlir {
namespace heir {
namespace secret {

#define GEN_PASS_DECL_SECRETGENERICABSORBDEALLOC
#include "lib/Dialect/Secret/Transforms/Passes.h.inc"

} // namespace secret
} // namespace heir
} // namespace mlir

#endif // HEIR_LIB_DIALECT_SECRET_TRANSFORMS_GENERICABSORBDEALLOC_H_
1 change: 1 addition & 0 deletions lib/Dialect/Secret/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "lib/Dialect/Secret/Transforms/ExtractGenericBody.h"
#include "lib/Dialect/Secret/Transforms/ForgetSecrets.h"
#include "lib/Dialect/Secret/Transforms/GenericAbsorbConstants.h"
#include "lib/Dialect/Secret/Transforms/GenericAbsorbDealloc.h"
#include "lib/Dialect/Secret/Transforms/MergeAdjacentGenerics.h"

namespace mlir {
Expand Down
9 changes: 9 additions & 0 deletions lib/Dialect/Secret/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ def SecretGenericAbsorbConstants : Pass<"secret-generic-absorb-constants"> {
let dependentDialects = ["mlir::heir::secret::SecretDialect"];
}

def SecretGenericAbsorbDealloc : Pass<"secret-generic-absorb-dealloc"> {
let summary = "Copy deallocs of internal memrefs into a secret.generic body";
let description = [{
For each memref used only in the body of a `secret.generic` op, add it's
dealloc of the memref into the `generic` body.
}];
let dependentDialects = ["mlir::heir::secret::SecretDialect"];
}

def SecretExtractGenericBody : Pass<"secret-extract-generic-body"> {
let summary = "Extract the bodies of all generic ops into functions";
let description = [{
Expand Down
84 changes: 67 additions & 17 deletions lib/Transforms/YosysOptimizer/YosysOptimizer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ read_verilog -sv {0};
hierarchy -check -top \{1};
proc; memory; stat;
techmap -map {2}/techmap.v; stat;
opt_expr; opt; opt_clean -purge; stat;
splitnets -ports \{1} %n;
flatten; opt_expr; opt; opt_clean -purge;
rename -hide */w:*; rename -enumerate */w:*;
Expand All @@ -95,13 +96,16 @@ stat;
// $3: yosys runfiles path
// $4: abc fast option -fast
constexpr std::string_view kYosysBooleanTemplate = R"(
read_verilog {0};
read_verilog -sv {0};
hierarchy -check -top \{1};
proc; memory; stat;
techmap -map {3}/techmap.v; opt; stat;
opt_expr; opt; opt_clean -purge; stat;
splitnets -ports \{1} %n;
flatten; opt_expr; opt; opt_clean -purge;
rename -hide */w:*; rename -enumerate */w:*;
abc -exe {2} -g AND,NAND,OR,NOR,XOR,XNOR {4};
opt_clean -purge; stat;
rename -hide */c:*; rename -enumerate */c:*;
hierarchy -generate * o:Y i:*; opt; opt_clean -purge;
clean;
stat;
Expand Down Expand Up @@ -446,6 +450,7 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {

// Translate Yosys result back to MLIR and insert into the func
LLVM_DEBUG(Yosys::run_pass("dump;"));
Yosys::log_streams.clear();
std::stringstream cellOrder;
Yosys::log_streams.push_back(&cellOrder);
Yosys::run_pass("torder -stop * P*;");
Expand Down Expand Up @@ -528,6 +533,8 @@ LogicalResult YosysOptimizer::runOnGenericOp(secret::GenericOp op) {

// Optimize the body of a secret.generic op.
void YosysOptimizer::runOnOperation() {
LLVM_DEBUG({ llvm::dbgs() << "Running Yosys optimizer\n"; });

Yosys::yosys_setup();
auto *ctx = &getContext();
auto *op = getOperation();
Expand All @@ -541,16 +548,6 @@ void YosysOptimizer::runOnOperation() {

// Cleanup after unrollAndMergeGenerics
mlir::RewritePatternSet cleanupPatterns(ctx);
// We lift loads/stores into their own generics if possible, to avoid putting
// the entire memref in the verilog module. Some loads would be hoistable but
// they depend on arithmetic of index accessors that are otherwise secret.
// Hence we need the HoistPlaintextOps provided by
// populateGenericCanonicalizers in addition to special patterns that lift
// loads and stores into their own generics.
cleanupPatterns.add<secret::HoistOpBeforeGeneric>(
ctx, std::vector<std::string>{"memref.load", "affine.load"});
cleanupPatterns.add<secret::HoistOpAfterGeneric>(
ctx, std::vector<std::string>{"memref.store", "affine.store"});
secret::populateGenericCanonicalizers(cleanupPatterns, ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(cleanupPatterns)))) {
signalPassFailure();
Expand All @@ -572,19 +569,72 @@ void YosysOptimizer::runOnOperation() {
return;
}

mlir::IRRewriter builder(&getContext());
op->walk([&](secret::GenericOp op) {
// Now pass through any constants used after capturing the ambient scope.
// This way Yosys can optimize constants away instead of treating them as
// variables to the optimized body. We also absorb any memref deallocations
// into the generic body when the memref is only used internally within the
// generic body.
genericAbsorbDealloc(op, builder);
});

// Remove unused values after absorbing the deallocs.
mlir::RewritePatternSet unusedYieldPatterns(ctx);
unusedYieldPatterns.add<secret::RemoveUnusedYieldedValues>(ctx);
if (failed(
applyPatternsAndFoldGreedily(op, std::move(unusedYieldPatterns)))) {
signalPassFailure();
getOperation()->emitError()
<< "Failed to merge generic ops before yosys optimizer";
return;
}

// Extract generics into function calls.
auto result = op->walk([&](secret::GenericOp op) {
genericAbsorbConstants(op, builder);

auto isTrivial = op.getBody()->walk([&](Operation *body) {
if (isa<arith::ArithDialect>(body->getDialect()) &&
!isa<arith::ConstantOp>(body)) {
return WalkResult::interrupt();
}
return WalkResult::advance();
});
if (isTrivial.wasInterrupted()) {
if (failed(extractGenericBody(op, builder))) {
return WalkResult::interrupt();
}
}

return WalkResult::advance();
});

if (result.wasInterrupted()) {
signalPassFailure();
}

// Merge generics after the function bodies are extracted.
mlir::RewritePatternSet mergePatterns(ctx);
mergePatterns.add<secret::MergeAdjacentGenerics>(ctx);
if (failed(applyPatternsAndFoldGreedily(op, std::move(mergePatterns)))) {
signalPassFailure();
getOperation()->emitError()
<< "Failed to merge generic ops before yosys optimizer";
return;
}

// Ensure that each generic only has one return value
LLVM_DEBUG({
llvm::dbgs() << "IR after cleanup in preparation for yosys optimizer\n";
getOperation()->dump();
});

mlir::IRRewriter builder(&getContext());
auto result = op->walk([&](secret::GenericOp op) {
result = op->walk([&](secret::GenericOp op) {
// Now pass through any constants used after capturing the ambient scope.
// This
// way Yosys can optimize constants away instead of treating them as
// This way Yosys can optimize constants away instead of treating them as
// variables to the optimized body.
genericAbsorbConstants(op, builder);

if (failed(runOnGenericOp(op))) {
return WalkResult::interrupt();
}
Expand Down
29 changes: 29 additions & 0 deletions tests/secret/generic_absorb_dealloc.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// RUN: heir-opt --secret-generic-absorb-dealloc %s | FileCheck %s

// CHECK-LABEL: test_absorb_dealloc
// CHECK-SAME: %[[Y:.*]]: !secret.secret<memref<1xi32>>) {
func.func @test_absorb_dealloc(%memref : !secret.secret<memref<1xi32>>) {
// CHECK: %[[C0:.*]] = arith.constant 7
%C7 = arith.constant 7 : i32
// CHECK: %[[Z:.*]] = secret.generic ins(%[[Y]], %[[C0]] : !secret.secret<memref<1xi32>>, i32)
%Z:2 = secret.generic ins(%memref, %C7 : !secret.secret<memref<1xi32>>, i32) {
// CHECK-NEXT: ^[[bb0:.*]](%[[y:.*]]: memref<1xi32>, %[[c0:.*]]: i32):
^bb0(%y: memref<1xi32>, %c0 : i32):
// CHECK: %[[d:.*]] = memref.alloc()
// CHECK: memref.dealloc %[[d]]
// CHECK: secret.yield
affine.store %c0, %y[0] : memref<1xi32>
%internal = memref.alloc() : memref<1xi32>
affine.store %c0, %internal[0] : memref<1xi32>
secret.yield %y, %internal : memref<1xi32>, memref<1xi32>
} -> (!secret.secret<memref<1xi32>>, !secret.secret<memref<1xi32>>)
// CHECK: secret.generic ins(%[[Z1:.*]] : !secret.secret<memref<1xi32>>)
secret.generic ins(%Z#1 : !secret.secret<memref<1xi32>>) {
// CHECK-NEXT: ^[[bb0:.*]](%[[y:.*]]: memref<1xi32>):
^bb0(%z1: memref<1xi32>):
// CHECK-NEXT: secret.yield
memref.dealloc %z1 : memref<1xi32>
secret.yield
}
func.return
}
Loading
Loading