Skip to content

Commit

Permalink
Merge pull request #756 from j2kun:autohog-importer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 653639002
  • Loading branch information
Copybara-Service committed Jul 18, 2024
2 parents 0310089 + 2a12b63 commit 195121f
Show file tree
Hide file tree
Showing 18 changed files with 1,511 additions and 15 deletions.
42 changes: 41 additions & 1 deletion lib/Dialect/CGGI/IR/CGGIOps.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "lib/Dialect/CGGI/IR/CGGIOps.h"

#include <cstdint>

#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "mlir/include/mlir/IR/Diagnostics.h" // from @llvm-project
Expand Down Expand Up @@ -30,7 +32,7 @@ LogicalResult LutLinCombOp::verify() {
auto encoding = dyn_cast<lwe::BitFieldEncodingAttr>(type.getEncoding());

if (encoding) {
unsigned maxCoeff = (1 << encoding.getCleartextBitwidth()) - 1;
int64_t maxCoeff = (1 << encoding.getCleartextBitwidth()) - 1;
for (auto c : getCoefficients()) {
if (c > maxCoeff) {
InFlightDiagnostic diag =
Expand All @@ -55,6 +57,44 @@ LogicalResult LutLinCombOp::verify() {
return success();
}

LogicalResult MultiLutLinCombOp::verify() {
if (getInputs().size() != getCoefficients().size())
return emitOpError("number of coefficients must match number of inputs");
if (getOutputs().size() != getLookupTables().size())
return emitOpError("number of outputs must match number of LUTs");

lwe::LWECiphertextType type =
cast<lwe::LWECiphertextType>(getOutputs().front().getType());
auto encoding = dyn_cast<lwe::BitFieldEncodingAttr>(type.getEncoding());

if (encoding) {
int64_t maxCoeff = (1 << encoding.getCleartextBitwidth()) - 1;
for (auto c : getCoefficients()) {
if (c > maxCoeff) {
InFlightDiagnostic diag =
emitOpError("coefficient pushes error bits into message space");
diag.attachNote() << "coefficient is " << c;
diag.attachNote() << "largest allowable coefficient is " << maxCoeff;
return diag;
}
}

for (int64_t lut : getLookupTables()) {
APInt apintLut = APInt(64, lut);
if (apintLut.getActiveBits() > maxCoeff + 1) {
InFlightDiagnostic diag =
emitOpError("LUT is larger than available cleartext bit width");
diag.attachNote() << "LUT has " << apintLut.getActiveBits()
<< " active bits";
diag.attachNote() << "max LUT size is " << maxCoeff + 1 << " bits";
return diag;
}
}
}

return success();
}

} // namespace cggi
} // namespace heir
} // namespace mlir
53 changes: 53 additions & 0 deletions lib/Dialect/CGGI/IR/CGGIOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -143,4 +143,57 @@ def CGGI_LutLinCombOp : CGGI_Op<"lut_lincomb", [
let hasVerifier = 1;
}

def CGGI_MultiLutLinCombOp : CGGI_Op<"multi_lut_lincomb", [
Pure,
Commutative,
ElementwiseMappable,
Scalarizable
]> {
let summary = "A multi-output version of lut_lincomb with one LUT per output.";
let description = [{
An op representing multiple lookup tables applied to a shared input, which
is prepared via a static linear combination. This is equivalent to
`cggi.lut_lincomb`, but where the linear combination is given to multiple
lookup tables, each producing a separate output.

This can be achieved by a special implementation of blind rotate in the CGGI
scheme. See [AutoHoG](https://ieeexplore.ieee.org/document/10413195).

Example:

```mlir
#encoding = #lwe.bit_field_encoding<cleartext_start=30, cleartext_bitwidth=3>
#params = #lwe.lwe_params<cmod=7917, dimension=4>
!ciphertext = !lwe.lwe_ciphertext<encoding = #encoding, lwe_params = #params>

%4 = cggi.multi_lut_lincomb %c0, %c1, %c2, %c3 {
coefficients = array<i32: 1, 2, 3, 2>,
lookup_tables = array<index: 68, 70, 4, 8>
} : (!ciphertext, !ciphertext, !ciphertext, !ciphertext) -> (!ciphertext, !ciphertext, !ciphertext, !ciphertext)
```

Represents applying the following LUTs. Performance-wise, this is
comparable to applying a single LUT to a linear combination.

```
x = (1 * c0 + 2 * c1 + 3 * c2 + 2 * c3)
return (
(68 >> x) & 1,
(70 >> x) & 1,
(4 >> x) & 1,
(8 >> x) & 1
)
```
}];

let arguments = (ins
Variadic<LWECiphertext>:$inputs,
DenseI32ArrayAttr:$coefficients,
DenseI32ArrayAttr:$lookup_tables
);
let results = (outs Variadic<LWECiphertext>:$outputs);
let assemblyFormat = "operands attr-dict `:` functional-type($inputs, $outputs)" ;
let hasVerifier = 1;
}

#endif // HEIR_LIB_DIALECT_CGGI_IR_CGGIOPS_TD_
27 changes: 14 additions & 13 deletions lib/Graph/Graph.h
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
#ifndef LIB_GRAPH_GRAPH_H_
#define LIB_GRAPH_GRAPH_H_

#include <algorithm>
#include <cstdint>
#include <map>
#include <set>
#include <vector>

#include "llvm/include/llvm/ADT/ArrayRef.h" // from @llvm-project
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
#include "llvm/include/llvm/ADT/SmallSet.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project

namespace mlir {
Expand All @@ -25,7 +26,7 @@ class Graph {
// if either the source or target is not a previously inserted vertex, and
// returns true otherwise. The graph is unchanged if false is returned.
bool addEdge(V source, V target) {
if (!vertices.contains(source) || !vertices.contains(target)) {
if (vertices.count(source) == 0 || vertices.count(target) == 0) {
return false;
}
outEdges[source].insert(target);
Expand All @@ -35,15 +36,15 @@ class Graph {

// Returns true iff the given vertex has previously been added to the graph
// using `AddVertex`.
bool contains(V vertex) { return vertices.contains(vertex); }
bool contains(V vertex) { return vertices.count(vertex) > 0; }

bool empty() { return vertices.empty(); }

const llvm::SmallSet<V, 4>& getVertices() { return vertices; }
const std::set<V>& getVertices() { return vertices; }

// Returns the edges that point out of the given vertex.
std::vector<V> edgesOutOf(V vertex) {
if (vertices.contains(vertex)) {
if (vertices.count(vertex)) {
std::vector<V> result(outEdges[vertex].begin(), outEdges[vertex].end());
// Note: The vertices are sorted to ensure determinism in the output.
std::sort(result.begin(), result.end());
Expand All @@ -54,7 +55,7 @@ class Graph {

// Returns the edges that point into the given vertex.
std::vector<V> edgesInto(V vertex) {
if (vertices.contains(vertex)) {
if (vertices.count(vertex)) {
std::vector<V> result(inEdges[vertex].begin(), inEdges[vertex].end());
// Note: The vertices are sorted to ensure determinism in the output.
std::sort(result.begin(), result.end());
Expand All @@ -70,7 +71,7 @@ class Graph {

// Kahn's algorithm
std::vector<V> active;
llvm::DenseMap<V, int64_t> edgeCount;
std::map<V, int64_t> edgeCount;
for (const V& vertex : vertices) {
edgeCount[vertex] = edgesInto(vertex).size();
if (edgeCount.at(vertex) == 0) {
Expand Down Expand Up @@ -118,7 +119,7 @@ class Graph {
}
auto topoOrder = result.value();
std::reverse(topoOrder.begin(), topoOrder.end());
llvm::DenseMap<V, int> levels;
std::map<V, int> levels;

// Assign levels to the nodes:
// Traverse through the reversed topologically sorted nodes
Expand Down Expand Up @@ -148,9 +149,9 @@ class Graph {
}

private:
llvm::SmallSet<V, 4> vertices;
llvm::DenseMap<V, llvm::SmallSet<V, 4>> outEdges;
llvm::DenseMap<V, llvm::SmallSet<V, 4>> inEdges;
std::set<V> vertices;
std::map<V, std::set<V>> outEdges;
std::map<V, std::set<V>> inEdges;
};

} // namespace graph
Expand Down
Loading

0 comments on commit 195121f

Please sign in to comment.