Skip to content

Commit

Permalink
Implement index comparison ops
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Jan 15, 2024
1 parent c4bf6df commit b2e5d8e
Show file tree
Hide file tree
Showing 10 changed files with 270 additions and 2 deletions.
11 changes: 10 additions & 1 deletion crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,16 @@ impl Autodiff<'_> {
},
&Expr::Binary { op, left, right } => match op {
// boring cases
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr {
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq => self.code.push(Instr {
var,
expr: Expr::Binary { op, left, right },
}),
Expand Down
8 changes: 8 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ pub enum Binop {
Iff,
Xor,

// `Fin` -> `Fin` -> `Bool`
INeq,
ILt,
ILeq,
IEq,
IGt,
IGeq,

// `F64` -> `F64` -> `Bool`
Neq,
Lt,
Expand Down
7 changes: 7 additions & 0 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
Binop::Iff => Val::Bool(x.bool() == y.bool()),
Binop::Xor => Val::Bool(x.bool() != y.bool()),

Binop::INeq => Val::Bool(x.fin() != y.fin()),
Binop::ILt => Val::Bool(x.fin() < y.fin()),
Binop::ILeq => Val::Bool(x.fin() <= y.fin()),
Binop::IEq => Val::Bool(x.fin() == y.fin()),
Binop::IGt => Val::Bool(x.fin() > y.fin()),
Binop::IGeq => Val::Bool(x.fin() >= y.fin()),

Binop::Neq => Val::Bool(x.f64() != y.f64()),
Binop::Lt => Val::Bool(x.f64() < y.f64()),
Binop::Leq => Val::Bool(x.f64() <= y.f64()),
Expand Down
17 changes: 16 additions & 1 deletion crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,12 @@ impl<'a> Transpose<'a> {
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down Expand Up @@ -704,7 +710,16 @@ impl<'a> Transpose<'a> {
}
_ => {
let (a, b) = match op {
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right),
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq => (left, right),
Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down
6 changes: 6 additions & 0 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,12 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
Binop::Or => self.wasm.instruction(&Instruction::I32Or),
Binop::Iff => self.wasm.instruction(&Instruction::I32Eq),
Binop::Xor => self.wasm.instruction(&Instruction::I32Xor),
Binop::INeq => self.wasm.instruction(&Instruction::I32Ne),
Binop::ILt => self.wasm.instruction(&Instruction::I32LtU),
Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU),
Binop::IEq => self.wasm.instruction(&Instruction::I32Eq),
Binop::IGt => self.wasm.instruction(&Instruction::I32GtU),
Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU),
Binop::Neq => self.wasm.instruction(&Instruction::F64Ne),
Binop::Lt => self.wasm.instruction(&Instruction::F64Lt),
Binop::Leq => self.wasm.instruction(&Instruction::F64Le),
Expand Down
80 changes: 80 additions & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,86 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::INeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer less than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer less than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IEq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer greater than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "integer greater than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
Expand Down
6 changes: 6 additions & 0 deletions crates/web/src/pprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,12 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?,
Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?,
Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?,
Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?,
Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?,
Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?,
Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Expand Down
42 changes: 42 additions & 0 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -954,6 +954,48 @@ export const xor = (p: Bool, q: Bool): Bool => {
return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q)));
};

/** Return an abstract boolean for if `i` is not equal to `j`. */
export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than `j`. */
export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than or equal to `j`. */
export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is equal to `j`. */
export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than `j`. */
export const igt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than or equal to `j`. */
export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract value selecting between `then` and `els` via `cond`. */
export const select = <const T>(
cond: Bool,
Expand Down
89 changes: 89 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@ import {
floor,
fn,
gt,
ieq,
iff,
igeq,
igt,
ileq,
ilt,
ineq,
interp,
jvp,
mul,
Expand Down Expand Up @@ -942,4 +948,87 @@ describe("valid", () => {
const h = await compile(g);
expect(h({ v: [2], i: 0 })).toEqual({ v: [1], i: 0 });
});

test("index comparison", async () => {
const f = fn(
[2, 2],
{ neq: Bool, lt: Bool, leq: Bool, eq: Bool, gt: Bool, geq: Bool },
(i, j) => ({
neq: ineq(2, i, j),
lt: ilt(2, i, j),
leq: ileq(2, i, j),
eq: ieq(2, i, j),
gt: igt(2, i, j),
geq: igeq(2, i, j),
}),
);

let g = interp(f);
expect(g(0, 0)).toEqual({
neq: false,
lt: false,
leq: true,
eq: true,
gt: false,
geq: true,
});
expect(g(0, 1)).toEqual({
neq: true,
lt: true,
leq: true,
eq: false,
gt: false,
geq: false,
});
expect(g(1, 0)).toEqual({
neq: true,
lt: false,
leq: false,
eq: false,
gt: true,
geq: true,
});
expect(g(1, 1)).toEqual({
neq: false,
lt: false,
leq: true,
eq: true,
gt: false,
geq: true,
});

g = await compile(f);
expect(g(0, 0)).toEqual({
neq: false,
lt: false,
leq: true,
eq: true,
gt: false,
geq: true,
});
expect(g(0, 1)).toEqual({
neq: true,
lt: true,
leq: true,
eq: false,
gt: false,
geq: false,
});
expect(g(1, 0)).toEqual({
neq: true,
lt: false,
leq: false,
eq: false,
gt: true,
geq: true,
});
expect(g(1, 1)).toEqual({
neq: false,
lt: false,
leq: true,
eq: true,
gt: false,
geq: true,
});
});
});
6 changes: 6 additions & 0 deletions packages/core/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,13 @@ export {
fn,
geq,
gt,
ieq,
iff,
igeq,
igt,
ileq,
ilt,
ineq,
interp,
jvp,
leq,
Expand Down

0 comments on commit b2e5d8e

Please sign in to comment.