Skip to content

Commit

Permalink
squashme: wip
Browse files Browse the repository at this point in the history
  • Loading branch information
j2kun committed Nov 30, 2023
1 parent c46d28d commit 8b47076
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 11 deletions.
15 changes: 12 additions & 3 deletions include/Analysis/NoisePropagation/Variance.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,15 @@
namespace mlir {
namespace heir {

enum VarianceType {
UNSET, // A min value for the lattice, i.e., discarable when joined with
// anything else.
KNOWN, // A known value for the lattice, i.e., when noise can be inferred.
INDEPENDENT, // A known value for the lattice, independent of
MAX // A max value for the lattice, i.e., when noise cannot be inferred and a
// bootstrap must be forced.
};

/// A class representing an optional variance of a noise distribution.
class Variance {
public:
Expand All @@ -31,10 +40,10 @@ class Variance {

/// This method represents how to choose a noise from one of two possible
/// branches, when either could be possible. In the case of FHE, we must
/// assume the worse case, so take the max.
/// assume the worst case. If either is unknown, assume unknown, otherwise
/// take the max.
static Variance join(const Variance &lhs, const Variance &rhs) {
if (!lhs.isKnown()) return rhs;
if (!rhs.isKnown()) return lhs;
if (!lhs.isKnown() || !rhs.isKnown()) return Variance::unknown();
return Variance{std::max(lhs.getValue(), rhs.getValue())};
}

Expand Down
12 changes: 7 additions & 5 deletions lib/Analysis/NoisePropagation/NoisePropagationAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ void NoisePropagationAnalysis::visitOperation(
})) {
LLVM_DEBUG(llvm::dbgs()
<< "Op " << noisePropagationOp->getName()
<< "with argument-dependent noise propagation encountered input "
"with unknown noise. Marking result noise as unknown.\n");
<< " with argument-dependent noise propagation encountered input"
" with unknown noise. Marking result noise as unknown.\n");
return setAllToEntryStates(results);
}

Expand All @@ -45,6 +45,8 @@ void NoisePropagationAnalysis::visitOperation(
Variance oldRange = lattice->getValue();
ChangeResult changed = lattice->join(Variance{variance});

// FIXME: does this even make sense as a lattice??
//
// If the result is yielded, then the best we can do is check to see if the
// op producing this value has argument-independent noise. If so, we can
// propagate that noise. Otherwise, we must assume the worst case scenario
Expand All @@ -58,15 +60,15 @@ void NoisePropagationAnalysis::visitOperation(
// determine where in the codebase one should look for stuff related to
// this method.
if (isYieldedResult && oldRange.isKnown() &&
!(lattice->getValue() == oldRange) &&
!(lattice.getValue() == oldRange) &&
!noisePropagationOp.hasArgumentIndependentResultNoise()) {
LLVM_DEBUG(
llvm::dbgs()
<< "Non-constant noise-propagating op passed to a region "
"terminator. Assuming loop result and marking noise unknown\n");
changed |= lattice->join(Variance(std::nullopt));
changed |= lattice.join(Variance::unknown());
}
propagateIfChanged(lattice, changed);
propagateIfChanged(&lattice, changed);
};

noisePropagationOp.inferResultNoise(argRanges, joinCallback);
Expand Down
1 change: 0 additions & 1 deletion lib/Dialect/CGGI/Transforms/SetDefaultParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
#include "include/Dialect/CGGI/IR/CGGIOps.h"
#include "include/Dialect/LWE/IR/LWEAttributes.h"
#include "llvm/include/llvm/ADT/TypeSwitch.h" // from @llvm-project
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project

namespace mlir {
Expand Down
8 changes: 6 additions & 2 deletions lib/Transforms/ValidateNoise/ValidateNoise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Interfaces/NoiseInterfaces.h"
#include "llvm/include/llvm/Support/Debug.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project// from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/DeadCodeAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h" // from @llvm-projectject
#include "mlir/include/mlir/Analysis/DataFlowFramework.h" // from @llvm-project
#include "mlir/include/mlir/IR/Visitors.h" // from @llvm-project
Expand Down Expand Up @@ -36,8 +37,11 @@ struct ValidateNoise : impl::ValidateNoiseBase<ValidateNoise> {
auto *module = getOperation();

DataFlowSolver solver;
// The dataflow solver needs DeadCodeAnalysis to run the other analyses
// The dataflow solver needs DeadCodeAnalysis and SparseConstantPropagation
// to run pretty much any data flow analysis, see
// https://discourse.llvm.org/t/mlir-dead-code-analysis/67568/8
solver.load<dataflow::DeadCodeAnalysis>();
solver.load<dataflow::SparseConstantPropagation>();
solver.load<NoisePropagationAnalysis>();
if (failed(solver.initializeAndRun(module))) {
getOperation()->emitOpError() << "Failed to run the analysis.\n";
Expand Down
39 changes: 39 additions & 0 deletions tests/validate_noise/validate_noise_errors.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// RUN: heir-opt --split-input-file --cggi-set-default-parameters --lwe-set-default-parameters --validate-noise --verify-diagnostics %s

// TODO(https://github.com/google/heir/issues/296): use lwe.encrypt with
// realistic initial noise.

// #encoding = #lwe.bit_field_encoding<cleartext_start=30, cleartext_bitwidth=3>
// #poly = #polynomial.polynomial<1 + x**1024>
// !plaintext = !lwe.lwe_plaintext<encoding = #encoding>
// !ciphertext = !lwe.lwe_ciphertext<encoding = #encoding>
//
// func.func @test_cant_add_unknown_value(%arg0 : !ciphertext) -> !ciphertext {
// // expected-error@below {{uses SSA value with unknown noise variance}}
// %1 = lwe.add %arg0, %arg0 : !ciphertext
// return %1 : !ciphertext
// }
//
// // -----

#encoding = #lwe.bit_field_encoding<cleartext_start=30, cleartext_bitwidth=3>
#poly = #polynomial.polynomial<1 + x**1024>
!plaintext = !lwe.lwe_plaintext<encoding = #encoding>
!ciphertext = !lwe.lwe_ciphertext<encoding = #encoding>

func.func @unknown_value_from_loop_result() -> !ciphertext {
%0 = arith.constant 0 : i1
%2 = lwe.encode %0 { encoding = #encoding }: i1 to !plaintext
%3 = lwe.trivial_encrypt %2 : !plaintext to !ciphertext

%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c5 = arith.constant 5 : index

%5 = scf.for %arg1 = %c1 to %c5 step %c1 iter_args(%iter_arg = %3) -> !ciphertext {
// expected-error@below {{uses SSA value with unknown noise variance}}
%6 = lwe.add %iter_arg, %iter_arg : !ciphertext
scf.yield %6 : !ciphertext
}
return %5 : !ciphertext
}

0 comments on commit 8b47076

Please sign in to comment.