Skip to content

Commit

Permalink
Implement index addition
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Jan 15, 2024
1 parent b2e5d8e commit e3c8633
Show file tree
Hide file tree
Showing 10 changed files with 60 additions and 8 deletions.
3 changes: 2 additions & 1 deletion crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,8 @@ impl Autodiff<'_> {
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq => self.code.push(Instr {
| Binop::IGeq
| Binop::IAdd => self.code.push(Instr {
var,
expr: Expr::Binary { op, left, right },
}),
Expand Down
3 changes: 3 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,9 @@ pub enum Binop {
IGt,
IGeq,

// `Fin` -> `Fin` -> `Fin`
IAdd,

// `F64` -> `F64` -> `Bool`
Neq,
Lt,
Expand Down
2 changes: 2 additions & 0 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,8 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
Binop::IGt => Val::Bool(x.fin() > y.fin()),
Binop::IGeq => Val::Bool(x.fin() >= y.fin()),

Binop::IAdd => Val::Fin(x.fin() + y.fin()),

Binop::Neq => Val::Bool(x.f64() != y.f64()),
Binop::Lt => Val::Bool(x.f64() < y.f64()),
Binop::Leq => Val::Bool(x.f64() <= y.f64()),
Expand Down
4 changes: 3 additions & 1 deletion crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ impl<'a> Transpose<'a> {
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd
| Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down Expand Up @@ -719,7 +720,8 @@ impl<'a> Transpose<'a> {
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq => (left, right),
| Binop::IGeq
| Binop::IAdd => (left, right),
Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down
1 change: 1 addition & 0 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,7 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
Binop::IEq => self.wasm.instruction(&Instruction::I32Eq),
Binop::IGt => self.wasm.instruction(&Instruction::I32GtU),
Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU),
Binop::IAdd => self.wasm.instruction(&Instruction::I32Add),
Binop::Neq => self.wasm.instruction(&Instruction::F64Ne),
Binop::Lt => self.wasm.instruction(&Instruction::F64Lt),
Binop::Leq => self.wasm.instruction(&Instruction::F64Le),
Expand Down
24 changes: 18 additions & 6 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,7 +1318,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer not equal" instruction on `left` and `right`.
/// Return the variable ID for a new "index not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
Expand All @@ -1331,7 +1331,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer less than" instruction on `left` and `right`.
/// Return the variable ID for a new "index less than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
Expand All @@ -1344,7 +1344,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer less than or equal" instruction on `left` and
/// Return the variable ID for a new "index less than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
Expand All @@ -1358,7 +1358,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer equal" instruction on `left` and `right`.
/// Return the variable ID for a new "index equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
Expand All @@ -1371,7 +1371,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer greater than" instruction on `left` and `right`.
/// Return the variable ID for a new "index greater than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
Expand All @@ -1384,7 +1384,7 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer greater than or equal" instruction on `left` and
/// Return the variable ID for a new "index greater than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
Expand All @@ -1398,6 +1398,18 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index add" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize {
let expr = rose::Expr::Binary {
op: rose::Binop::IAdd,
left: id::var(left),
right: id::var(right),
};
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new "not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
Expand Down
1 change: 1 addition & 0 deletions crates/web/src/pprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?,
Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?,
Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?,
Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?,
Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,14 @@ export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => {
return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return the abstract index `i` plus the abstract index `y`. */
export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => {
const ctx = getCtx();
const t = tyId(ctx, ty);
const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j));
return idVal(ctx, t, k) as Nat;
};

/** Return an abstract value selecting between `then` and `els` via `cond`. */
export const select = <const T>(
cond: Bool,
Expand Down
21 changes: 21 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
floor,
fn,
gt,
iadd,
ieq,
iff,
igeq,
Expand Down Expand Up @@ -1031,4 +1032,24 @@ describe("valid", () => {
geq: true,
});
});

test("index addition", async () => {
const f = fn([3, 3], 3, (i, j) => iadd(3, i, j));

let g = interp(f);
expect(g(0, 0)).toBe(0);
expect(g(0, 1)).toBe(1);
expect(g(0, 2)).toBe(2);
expect(g(1, 0)).toBe(1);
expect(g(1, 1)).toBe(2);
expect(g(2, 0)).toBe(2);

g = await compile(f);
expect(g(0, 0)).toBe(0);
expect(g(0, 1)).toBe(1);
expect(g(0, 2)).toBe(2);
expect(g(1, 0)).toBe(1);
expect(g(1, 1)).toBe(2);
expect(g(2, 0)).toBe(2);
});
});
1 change: 1 addition & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ export {
fn,
geq,
gt,
iadd,
ieq,
iff,
igeq,
Expand Down

0 comments on commit e3c8633

Please sign in to comment.