diff --git a/crates/autodiff/src/lib.rs b/crates/autodiff/src/lib.rs index 1d2a9eb..cfde52c 100644 --- a/crates/autodiff/src/lib.rs +++ b/crates/autodiff/src/lib.rs @@ -240,7 +240,17 @@ impl Autodiff<'_> { }, &Expr::Binary { op, left, right } => match op { // boring cases - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr { + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq + | Binop::IAdd => self.code.push(Instr { var, expr: Expr::Binary { op, left, right }, }), diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index 478123e..bceba08 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -210,6 +210,17 @@ pub enum Binop { Iff, Xor, + // `Fin` -> `Fin` -> `Bool` + INeq, + ILt, + ILeq, + IEq, + IGt, + IGeq, + + // `Fin` -> `Fin` -> `Fin` + IAdd, + // `F64` -> `F64` -> `Bool` Neq, Lt, diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 185c629..7aa6fc8 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -236,6 +236,15 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> { Binop::Iff => Val::Bool(x.bool() == y.bool()), Binop::Xor => Val::Bool(x.bool() != y.bool()), + Binop::INeq => Val::Bool(x.fin() != y.fin()), + Binop::ILt => Val::Bool(x.fin() < y.fin()), + Binop::ILeq => Val::Bool(x.fin() <= y.fin()), + Binop::IEq => Val::Bool(x.fin() == y.fin()), + Binop::IGt => Val::Bool(x.fin() > y.fin()), + Binop::IGeq => Val::Bool(x.fin() >= y.fin()), + + Binop::IAdd => Val::Fin(x.fin() + y.fin()), + Binop::Neq => Val::Bool(x.f64() != y.f64()), Binop::Lt => Val::Bool(x.f64() < y.f64()), Binop::Leq => Val::Bool(x.f64() <= y.f64()), diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index 40c3e35..ba1a901 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -630,6 +630,13 @@ impl<'a> Transpose<'a> { | Binop::Or | Binop::Iff | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq + | Binop::IAdd | Binop::Neq | Binop::Lt | Binop::Leq @@ -704,7 +711,17 @@ impl<'a> Transpose<'a> { } _ => { let (a, b) = match op { - Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right), + Binop::And + | Binop::Or + | Binop::Iff + | Binop::Xor + | Binop::INeq + | Binop::ILt + | Binop::ILeq + | Binop::IEq + | Binop::IGt + | Binop::IGeq + | Binop::IAdd => (left, right), Binop::Neq | Binop::Lt | Binop::Leq diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 4107cba..5b088f6 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -782,6 +782,13 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { Binop::Or => self.wasm.instruction(&Instruction::I32Or), Binop::Iff => self.wasm.instruction(&Instruction::I32Eq), Binop::Xor => self.wasm.instruction(&Instruction::I32Xor), + Binop::INeq => self.wasm.instruction(&Instruction::I32Ne), + Binop::ILt => self.wasm.instruction(&Instruction::I32LtU), + Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU), + Binop::IEq => self.wasm.instruction(&Instruction::I32Eq), + Binop::IGt => self.wasm.instruction(&Instruction::I32GtU), + Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU), + Binop::IAdd => self.wasm.instruction(&Instruction::I32Add), Binop::Neq => self.wasm.instruction(&Instruction::F64Ne), Binop::Lt => self.wasm.instruction(&Instruction::F64Lt), Binop::Leq => self.wasm.instruction(&Instruction::F64Le), diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 34623c8..aa58bd0 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1318,6 +1318,98 @@ impl Block { self.instr(f, t, expr) } + /// Return the variable ID for a new "index not equal" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::INeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index less than" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::ILt, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index less than or equal" instruction on `left` and + /// `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::ILeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index equal" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IEq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index greater than" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IGt, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index greater than or equal" instruction on `left` and + /// `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize { + let t = id::ty(f.ty_bool()); + let expr = rose::Expr::Binary { + op: rose::Binop::IGeq, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, t, expr) + } + + /// Return the variable ID for a new "index add" instruction on `left` and `right`. + /// + /// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type. + pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize { + let expr = rose::Expr::Binary { + op: rose::Binop::IAdd, + left: id::var(left), + right: id::var(right), + }; + self.instr(f, id::ty(t), expr) + } + /// Return the variable ID for a new "not equal" instruction on `left` and `right`. /// /// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type. diff --git a/crates/web/src/pprint.rs b/crates/web/src/pprint.rs index df38690..c616cbc 100644 --- a/crates/web/src/pprint.rs +++ b/crates/web/src/pprint.rs @@ -171,6 +171,13 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> { Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?, Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?, Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?, + Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?, + Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?, + Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, + Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?, + Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?, + Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?, + Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?, Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?, Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?, Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?, diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 99b02e3..3c0a135 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -90,7 +90,7 @@ type Zero = typeof zeroSymbol; export type Tan = Zero | Var; /** An abstract natural number, which can be used to index into a vector. */ -type Nat = number | symbol; +export type Nat = number | symbol; /** The portion of an abstract vector that can be directly indexed. */ interface VecIndex { @@ -954,6 +954,56 @@ export const xor = (p: Bool, q: Bool): Bool => { return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q))); }; +/** Return an abstract boolean for if `i` is not equal to `j`. */ +export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is less than `j`. */ +export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is less than or equal to `j`. */ +export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is equal to `j`. */ +export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is greater than `j`. */ +export const igt = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return an abstract boolean for if `i` is greater than or equal to `j`. */ +export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j))); +}; + +/** Return the abstract index `i` plus the abstract index `y`. */ +export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => { + const ctx = getCtx(); + const t = tyId(ctx, ty); + const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j)); + return idVal(ctx, t, k) as Nat; +}; + /** Return an abstract value selecting between `then` and `els` via `cond`. */ export const select = ( cond: Bool, diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 4528f64..77c59d7 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -14,7 +14,14 @@ import { floor, fn, gt, + iadd, + ieq, iff, + igeq, + igt, + ileq, + ilt, + ineq, interp, jvp, mul, @@ -942,4 +949,107 @@ describe("valid", () => { const h = await compile(g); expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 }); }); + + test("index comparison", async () => { + const f = fn( + [2, 2], + { neq: Bool, lt: Bool, leq: Bool, eq: Bool, gt: Bool, geq: Bool }, + (i, j) => ({ + neq: ineq(2, i, j), + lt: ilt(2, i, j), + leq: ileq(2, i, j), + eq: ieq(2, i, j), + gt: igt(2, i, j), + geq: igeq(2, i, j), + }), + ); + + let g = interp(f); + expect(g(0, 0)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + expect(g(0, 1)).toEqual({ + neq: true, + lt: true, + leq: true, + eq: false, + gt: false, + geq: false, + }); + expect(g(1, 0)).toEqual({ + neq: true, + lt: false, + leq: false, + eq: false, + gt: true, + geq: true, + }); + expect(g(1, 1)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + + g = await compile(f); + expect(g(0, 0)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + expect(g(0, 1)).toEqual({ + neq: true, + lt: true, + leq: true, + eq: false, + gt: false, + geq: false, + }); + expect(g(1, 0)).toEqual({ + neq: true, + lt: false, + leq: false, + eq: false, + gt: true, + geq: true, + }); + expect(g(1, 1)).toEqual({ + neq: false, + lt: false, + leq: true, + eq: true, + gt: false, + geq: true, + }); + }); + + test("index addition", async () => { + const f = fn([3, 3], 3, (i, j) => iadd(3, i, j)); + + let g = interp(f); + expect(g(0, 0)).toBe(0); + expect(g(0, 1)).toBe(1); + expect(g(0, 2)).toBe(2); + expect(g(1, 0)).toBe(1); + expect(g(1, 1)).toBe(2); + expect(g(2, 0)).toBe(2); + + g = await compile(f); + expect(g(0, 0)).toBe(0); + expect(g(0, 1)).toBe(1); + expect(g(0, 2)).toBe(2); + expect(g(1, 0)).toBe(1); + expect(g(1, 1)).toBe(2); + expect(g(2, 0)).toBe(2); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 589e750..e803e2d 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -3,6 +3,7 @@ export { Bools, Dual, Fn, + Nat, Nats, Null, Nulls, @@ -25,7 +26,14 @@ export { fn, geq, gt, + iadd, + ieq, iff, + igeq, + igt, + ileq, + ilt, + ineq, interp, jvp, leq,