Skip to content

Commit

Permalink
Support index arithmetic (#124)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Jan 15, 2024
1 parent cf95b07 commit 27a8b79
Show file tree
Hide file tree
Showing 10 changed files with 324 additions and 3 deletions.
12 changes: 11 additions & 1 deletion crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,17 @@ impl Autodiff<'_> {
},
&Expr::Binary { op, left, right } => match op {
// boring cases
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => self.code.push(Instr {
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd => self.code.push(Instr {
var,
expr: Expr::Binary { op, left, right },
}),
Expand Down
11 changes: 11 additions & 0 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,17 @@ pub enum Binop {
Iff,
Xor,

// `Fin` -> `Fin` -> `Bool`
INeq,
ILt,
ILeq,
IEq,
IGt,
IGeq,

// `Fin` -> `Fin` -> `Fin`
IAdd,

// `F64` -> `F64` -> `Bool`
Neq,
Lt,
Expand Down
9 changes: 9 additions & 0 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,15 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
Binop::Iff => Val::Bool(x.bool() == y.bool()),
Binop::Xor => Val::Bool(x.bool() != y.bool()),

Binop::INeq => Val::Bool(x.fin() != y.fin()),
Binop::ILt => Val::Bool(x.fin() < y.fin()),
Binop::ILeq => Val::Bool(x.fin() <= y.fin()),
Binop::IEq => Val::Bool(x.fin() == y.fin()),
Binop::IGt => Val::Bool(x.fin() > y.fin()),
Binop::IGeq => Val::Bool(x.fin() >= y.fin()),

Binop::IAdd => Val::Fin(x.fin() + y.fin()),

Binop::Neq => Val::Bool(x.f64() != y.f64()),
Binop::Lt => Val::Bool(x.f64() < y.f64()),
Binop::Leq => Val::Bool(x.f64() <= y.f64()),
Expand Down
19 changes: 18 additions & 1 deletion crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,13 @@ impl<'a> Transpose<'a> {
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd
| Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down Expand Up @@ -704,7 +711,17 @@ impl<'a> Transpose<'a> {
}
_ => {
let (a, b) = match op {
Binop::And | Binop::Or | Binop::Iff | Binop::Xor => (left, right),
Binop::And
| Binop::Or
| Binop::Iff
| Binop::Xor
| Binop::INeq
| Binop::ILt
| Binop::ILeq
| Binop::IEq
| Binop::IGt
| Binop::IGeq
| Binop::IAdd => (left, right),
Binop::Neq
| Binop::Lt
| Binop::Leq
Expand Down
7 changes: 7 additions & 0 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,13 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {
Binop::Or => self.wasm.instruction(&Instruction::I32Or),
Binop::Iff => self.wasm.instruction(&Instruction::I32Eq),
Binop::Xor => self.wasm.instruction(&Instruction::I32Xor),
Binop::INeq => self.wasm.instruction(&Instruction::I32Ne),
Binop::ILt => self.wasm.instruction(&Instruction::I32LtU),
Binop::ILeq => self.wasm.instruction(&Instruction::I32LeU),
Binop::IEq => self.wasm.instruction(&Instruction::I32Eq),
Binop::IGt => self.wasm.instruction(&Instruction::I32GtU),
Binop::IGeq => self.wasm.instruction(&Instruction::I32GeU),
Binop::IAdd => self.wasm.instruction(&Instruction::I32Add),
Binop::Neq => self.wasm.instruction(&Instruction::F64Ne),
Binop::Lt => self.wasm.instruction(&Instruction::F64Lt),
Binop::Leq => self.wasm.instruction(&Instruction::F64Le),
Expand Down
92 changes: 92 additions & 0 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1318,6 +1318,98 @@ impl Block {
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ineq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::INeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index less than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ilt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index less than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ileq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::ILeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn ieq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IEq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index greater than" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igt(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGt,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index greater than or equal" instruction on `left` and
/// `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn igeq(&mut self, f: &mut FuncBuilder, left: usize, right: usize) -> usize {
let t = id::ty(f.ty_bool());
let expr = rose::Expr::Binary {
op: rose::Binop::IGeq,
left: id::var(left),
right: id::var(right),
};
self.instr(f, t, expr)
}

/// Return the variable ID for a new "index add" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have the same `Fin` type.
pub fn iadd(&mut self, f: &mut FuncBuilder, t: usize, left: usize, right: usize) -> usize {
let expr = rose::Expr::Binary {
op: rose::Binop::IAdd,
left: id::var(left),
right: id::var(right),
};
self.instr(f, id::ty(t), expr)
}

/// Return the variable ID for a new "not equal" instruction on `left` and `right`.
///
/// Assumes `left` and `right` are defined, in scope, and have 64-bit floating point type.
Expand Down
7 changes: 7 additions & 0 deletions crates/web/src/pprint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ impl<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>> Function<'a, '_, O, T> {
Binop::Or => writeln!(f, "x{} or x{}", left.var(), right.var())?,
Binop::Iff => writeln!(f, "x{} iff x{}", left.var(), right.var())?,
Binop::Xor => writeln!(f, "x{} xor x{}", left.var(), right.var())?,
Binop::INeq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::ILt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::ILeq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Binop::IEq => writeln!(f, "x{} == x{}", left.var(), right.var())?,
Binop::IGt => writeln!(f, "x{} > x{}", left.var(), right.var())?,
Binop::IGeq => writeln!(f, "x{} >= x{}", left.var(), right.var())?,
Binop::IAdd => writeln!(f, "x{} + x{}", left.var(), right.var())?,
Binop::Neq => writeln!(f, "x{} != x{}", left.var(), right.var())?,
Binop::Lt => writeln!(f, "x{} < x{}", left.var(), right.var())?,
Binop::Leq => writeln!(f, "x{} <= x{}", left.var(), right.var())?,
Expand Down
52 changes: 51 additions & 1 deletion packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ type Zero = typeof zeroSymbol;
export type Tan = Zero | Var;

/** An abstract natural number, which can be used to index into a vector. */
type Nat = number | symbol;
export type Nat = number | symbol;

/** The portion of an abstract vector that can be directly indexed. */
interface VecIndex<T> {
Expand Down Expand Up @@ -954,6 +954,56 @@ export const xor = (p: Bool, q: Bool): Bool => {
return newVar(ctx.block.xor(ctx.func, boolId(ctx, p), boolId(ctx, q)));
};

/** Return an abstract boolean for if `i` is not equal to `j`. */
export const ineq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ineq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than `j`. */
export const ilt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ilt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is less than or equal to `j`. */
export const ileq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ileq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is equal to `j`. */
export const ieq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.ieq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than `j`. */
export const igt = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igt(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return an abstract boolean for if `i` is greater than or equal to `j`. */
export const igeq = (ty: Nats, i: Nat, j: Nat): Bool => {
const ctx = getCtx();
const t = tyId(ctx, ty);
return newVar(ctx.block.igeq(ctx.func, valId(ctx, t, i), valId(ctx, t, j)));
};

/** Return the abstract index `i` plus the abstract index `y`. */
export const iadd = (ty: Nats, i: Nat, j: Nat): Nat => {
const ctx = getCtx();
const t = tyId(ctx, ty);
const k = ctx.block.iadd(ctx.func, t, valId(ctx, t, i), valId(ctx, t, j));
return idVal(ctx, t, k) as Nat;
};

/** Return an abstract value selecting between `then` and `els` via `cond`. */
export const select = <const T>(
cond: Bool,
Expand Down
Loading

0 comments on commit 27a8b79

Please sign in to comment.