Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an LWE noise propagation model #275

Draft
wants to merge 20 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions include/Analysis/NoisePropagation/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# NoisePropagationAnalysis analysis pass
package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

exports_files(
[
"NoisePropagationAnalysis.h",
"Variance.h",
],
)
41 changes: 41 additions & 0 deletions include/Analysis/NoisePropagation/NoisePropagationAnalysis.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_
#define INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_

#include "include/Analysis/NoisePropagation/Variance.h"
#include "mlir/include/mlir/Analysis/DataFlow/SparseAnalysis.h" // from @llvm-project
#include "mlir/include/mlir/IR/Operation.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {

/// This lattice element represents the noise distribution of an SSA value.
class VarianceLattice : public dataflow::Lattice<Variance> {
public:
using Lattice::Lattice;
};

/// Noise propagation analysis determines a noise bound for SSA values,
/// represented by the variance of a symmetric Gaussian distribution. This
/// analysis propagates noise across operations that implement
/// `NoisePropagationInterface`, but does not support propagation for SSA
/// values that represent loop bounds or induction variables. It can be viewed
/// as a simplified port of IntegerRangeAnalysis.
class NoisePropagationAnalysis
: public dataflow::SparseForwardDataFlowAnalysis<VarianceLattice> {
public:
using SparseForwardDataFlowAnalysis::SparseForwardDataFlowAnalysis;

void setToEntryState(VarianceLattice *lattice) override {
// At an entry point, we have no information about the noise.
propagateIfChanged(lattice, lattice->join(Variance::uninitialized()));
}

void visitOperation(Operation *op, ArrayRef<const VarianceLattice *> operands,
ArrayRef<VarianceLattice *> results) override;
};

} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_NOISEPROPAGATION_NOISEPROPAGATIONANALYSIS_H_
104 changes: 104 additions & 0 deletions include/Analysis/NoisePropagation/Variance.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#ifndef INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_
#define INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <optional>

#include "llvm/include/llvm/Support/raw_ostream.h" // from @llvm-project
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project

namespace mlir {
namespace heir {

enum VarianceType {
// A min value for the lattice, discarable when joined with anything else.
UNINITIALIZED,
// A known value for the lattice, when noise can be inferred.
SET,
// A max value for the lattice, when noise cannot be inferred and a bootstrap
// must be forced.
UNBOUNDED
};

/// A class representing an optional variance of a noise distribution.
class Variance {
public:
static Variance uninitialized() {
return Variance(VarianceType::UNINITIALIZED, std::nullopt);
}
static Variance unbounded() {
return Variance(VarianceType::UNBOUNDED, std::nullopt);
}
static Variance of(int64_t value) {
return Variance(VarianceType::SET, value);
}

/// Create an integer value range lattice value.
/// The default constructor must be equivalent to the "entry state" of the
/// lattice, i.e., an uninitialized noise variance.
Variance(VarianceType varianceType = VarianceType::UNINITIALIZED,
std::optional<int64_t> value = std::nullopt)
: varianceType(varianceType), value(value) {}

bool isKnown() const { return varianceType == VarianceType::SET; }

bool isInitialized() const {
return varianceType != VarianceType::UNINITIALIZED;
}

bool isBounded() const { return varianceType != VarianceType::UNBOUNDED; }

const int64_t &getValue() const {
assert(isKnown());
return *value;
}

bool operator==(const Variance &rhs) const {
return varianceType == rhs.varianceType && value == rhs.value;
}

static Variance join(const Variance &lhs, const Variance &rhs) {
// Uninitialized variances correspond to values that are not secret,
// which may be the inputs to an encryption operation.
if (lhs.varianceType == VarianceType::UNINITIALIZED) {
return rhs;
}
if (rhs.varianceType == VarianceType::UNINITIALIZED) {
return lhs;
}

// Unbounded represents a pessimistic worst case, and so it must be
// preserved no matter the other operand.
if (lhs.varianceType == VarianceType::UNBOUNDED) {
return lhs;
}
if (rhs.varianceType == VarianceType::UNBOUNDED) {
return rhs;
}

assert(lhs.varianceType == VarianceType::SET &&
rhs.varianceType == VarianceType::SET);
return Variance::of(std::max(lhs.getValue(), rhs.getValue()));
}

void print(llvm::raw_ostream &os) const { os << value; }

std::string toString() const;

friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os,
const Variance &variance);

friend Diagnostic &operator<<(Diagnostic &diagnostic,
const Variance &variance);

private:
VarianceType varianceType;
std::optional<int64_t> value;
};

} // namespace heir
} // namespace mlir

#endif // INCLUDE_ANALYSIS_NOISEPROPAGATION_VARIANCE_H_
1 change: 1 addition & 0 deletions include/Dialect/CGGI/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//include/Interfaces:NoiseInterfacesTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
"@llvm-project//mlir:SideEffectInterfacesTdFiles",
Expand Down
12 changes: 6 additions & 6 deletions include/Dialect/CGGI/IR/CGGIAttributes.td
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@ def CGGI_CGGIParams : AttrDef<CGGI_Dialect, "CGGIParams"> {
// to lwe dialect?
let parameters = (ins
"::mlir::heir::lwe::RLWEParamsAttr": $rlweParams,
"unsigned": $bsk_noise_variance,
"unsigned": $bsk_gadget_base_log,
"unsigned": $bsk_gadget_num_levels,
"unsigned": $ksk_noise_variance,
"unsigned": $ksk_gadget_base_log,
"unsigned": $ksk_gadget_num_levels
"int64_t": $bsk_noise_variance,
"int64_t": $bsk_gadget_base_log,
"int64_t": $bsk_gadget_num_levels,
"int64_t": $ksk_noise_variance,
"int64_t": $ksk_gadget_base_log,
"int64_t": $ksk_gadget_num_levels
);

let assemblyFormat = "`<` struct(params) `>`";
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/CGGI/IR/CGGIOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "include/Dialect/CGGI/IR/CGGIDialect.h"
#include "include/Dialect/LWE/IR/LWETypes.h"
#include "include/Interfaces/NoiseInterfaces.h"
#include "mlir/include/mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

Expand Down
4 changes: 3 additions & 1 deletion include/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ include "include/Dialect/CGGI/IR/CGGIDialect.td"

include "include/Dialect/Polynomial/IR/PolynomialAttributes.td"
include "include/Dialect/LWE/IR/LWETypes.td"
include "include/Interfaces/NoiseInterfaces.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

class CGGI_Op<string mnemonic, list<Trait> traits = []> :
Op<CGGI_Dialect, mnemonic, traits> {
Op<CGGI_Dialect, mnemonic, traits # [
DeclareOpInterfaceMethods<NoisePropagationInterface>]> {
let assemblyFormat = [{
`(` operands `)` attr-dict `:` `(` qualified(type(operands)) `)` `->` qualified(type(results))
}];
Expand Down
1 change: 1 addition & 0 deletions include/Dialect/LWE/IR/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ td_library(
# include from the heir-root to enable fully-qualified include-paths
includes = ["../../../.."],
deps = [
"@heir//include/Interfaces:NoiseInterfacesTdFiles",
"@llvm-project//mlir:BuiltinDialectTdFiles",
"@llvm-project//mlir:InferTypeOpInterfaceTdFiles",
"@llvm-project//mlir:OpBaseTdFiles",
Expand Down
2 changes: 2 additions & 0 deletions include/Dialect/LWE/IR/LWEDialect.h
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
#ifndef HEIR_INCLUDE_DIALECT_LWE_IR_LWEDIALECT_H_
#define HEIR_INCLUDE_DIALECT_LWE_IR_LWEDIALECT_H_

#include "include/Interfaces/NoiseInterfaces.h"
#include "mlir/include/mlir/IR/Builders.h" // from @llvm-project
#include "mlir/include/mlir/IR/Dialect.h" // from @llvm-project
#include "mlir/include/mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project

// Generated headers (block clang-format from messing up order)
#include "include/Dialect/LWE/IR/LWEDialect.h.inc"
Expand Down
23 changes: 21 additions & 2 deletions include/Dialect/LWE/IR/LWEOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

include "include/Dialect/LWE/IR/LWEDialect.td"
include "include/Dialect/LWE/IR/LWETypes.td"
include "include/Interfaces/NoiseInterfaces.td"
include "mlir/IR/BuiltinAttributeInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/CommonAttrConstraints.td"

class LWE_Op<string mnemonic, list<Trait> traits = []> :
Op<LWE_Dialect, mnemonic, traits> {
Expand Down Expand Up @@ -40,7 +41,10 @@ def LWE_EncodeOp : LWE_Op<"encode", [Pure]> {
let hasVerifier = 1;
}

def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> {
def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [
Pure,
DeclareOpInterfaceMethods<NoisePropagationInterface>
]> {
let summary = "Create a trivial encryption of a plaintext.";

let arguments = (ins
Expand All @@ -59,4 +63,19 @@ def LWE_TrivialEncryptOp: LWE_Op<"trivial_encrypt", [Pure]> {
let hasVerifier = 1;
}

def LWE_AddOp : LWE_Op<"add", [
Pure,
SameOperandsAndResultType,
DeclareOpInterfaceMethods<NoisePropagationInterface>
]> {
let arguments = (ins
LWECiphertext:$lhs,
LWECiphertext:$rhs
);
let results = (outs LWECiphertext:$output);
let assemblyFormat = [{
operands attr-dict `:` qualified(type($output))
}];
}

#endif // HEIR_INCLUDE_DIALECT_LWE_IR_LWEOPS_TD_
40 changes: 40 additions & 0 deletions include/Interfaces/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# HEIR project-wide interfaces
load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library")

package(
default_applicable_licenses = ["@heir//:license"],
default_visibility = ["//visibility:public"],
)

td_library(
name = "NoiseInterfacesTdFiles",
srcs = ["NoiseInterfaces.td"],
deps = [
"@llvm-project//mlir:OpBaseTdFiles",
],
)

gentbl_cc_library(
name = "NoiseInterfacesIncGen",
tbl_outs = [
(
["-gen-op-interface-decls"],
"NoiseInterfaces.h.inc",
),
(
["-gen-op-interface-defs"],
"NoiseInterfaces.cpp.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "NoiseInterfaces.td",
deps = [
":NoiseInterfacesTdFiles",
],
)

exports_files(
[
"NoiseInterfaces.h",
],
)
19 changes: 19 additions & 0 deletions include/Interfaces/NoiseInterfaces.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#ifndef INCLUDE_INTERFACES_NOISEINTERFACES_H_
#define INCLUDE_INTERFACES_NOISEINTERFACES_H_

#include "include/Analysis/NoisePropagation/Variance.h"
#include "mlir/include/mlir/IR/OpDefinition.h" // trom @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project

namespace mlir {
namespace heir {

// Variance is a type defined by NoisePropagationAnalysis
using SetNoiseFn = function_ref<void(Value, Variance)>;

} // namespace heir
} // namespace mlir

#include "include/Interfaces/NoiseInterfaces.h.inc"

#endif // INCLUDE_INTERFACES_NOISEINTERFACES_H_
49 changes: 49 additions & 0 deletions include/Interfaces/NoiseInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
#ifndef INCLUDE_INTERFACES_NOISEINTERFACES_TD_
#define INCLUDE_INTERFACES_NOISEINTERFACES_TD_

include "mlir/IR/OpBase.td"

def NoisePropagationInterface : OpInterface<"NoisePropagationInterface"> {
let description = [{
Declares that an operation produces results with noise, and provides an
interface for passes to compute bounds on the noise in the results
from the input noises.

Here "noise" is defined as the (perhaps upper-bounded) variance of a
Gaussian distribution centered at zero.
}];
let cppNamespace = "::mlir::heir";

let methods = [
InterfaceMethod<[{
Infers the noise distribution of the result of this op given the
distributions of its inputs.

All noise distributions are assumed to be Gaussian centered at zero, and
so the inputs and results are represented by their variances.

For each result value or block argument (that isn't a branch argument,
since the dataflow analysis handles those case), the method should call
`setValueNoise` with that `Value` as an argument. When `setValueNoise`
is not called for some value, the analysis will raise an error.

`argNoises` contains one `int64_t` for each operand to the op in ODS
order. Operands that don't have a prior noise associated with them
will have this value set to zero.
}],
"void", "inferResultNoise", (ins
"::llvm::ArrayRef<Variance>":$argNoises,
"::mlir::heir::SetNoiseFn":$setValueNoise)

>,
InterfaceMethod<[{
Returns true if the noise in the result op is independent of the noise in
its inputs. This is suitable for ops like bootstrap and initial
encryption.
}],
"bool", "hasArgumentIndependentResultNoise", (ins)
>];
}


#endif // INCLUDE_INTERFACES_NOISEINTERFACES_TD_
Loading
Loading