diff --git a/Cargo.lock b/Cargo.lock index 1202eef..bbe4553 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,12 @@ version = "3.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535" +[[package]] +name = "by_address" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf8dba2868114ed769a1f2590fc9ae5eb331175b44313b6c9b922f8f7ca813d0" + [[package]] name = "cfg-if" version = "1.0.0" @@ -128,6 +134,12 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "leb128" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67" + [[package]] name = "log" version = "0.4.17" @@ -196,10 +208,21 @@ dependencies = [ "rose", ] +[[package]] +name = "rose-wasm" +version = "0.3.2" +dependencies = [ + "by_address", + "indexmap", + "rose", + "wasm-encoder", +] + [[package]] name = "rose-web" version = "0.3.2" dependencies = [ + "by_address", "console_error_panic_hook", "enumset", "indexmap", @@ -208,6 +231,7 @@ dependencies = [ "rose-autodiff", "rose-interp", "rose-transpose", + "rose-wasm", "serde", "serde-wasm-bindgen", "wasm-bindgen", @@ -378,6 +402,15 @@ version = "0.2.87" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ca6ad05a4870b2bf5fe995117d3728437bd27d7cd5f06f13c17443ef369775a1" +[[package]] +name = "wasm-encoder" +version = "0.33.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b39de0723a53d3c8f54bed106cfbc0d06b3e4d945c5c5022115a61e3b29183ae" +dependencies = [ + "leb128", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/crates/interp/src/lib.rs b/crates/interp/src/lib.rs index 436f444..35c2062 100644 --- a/crates/interp/src/lib.rs +++ b/crates/interp/src/lib.rs @@ -123,9 +123,6 @@ impl Val { /// This is meant to be used to pull all the types from a callee into a broader context. The /// `generics` are the IDs of all the types provided as generic type parameters for the callee. The /// `types are the IDs of all the types that have been pulled in so far. -/// -/// The interpreter is meant to be used with no generic "free variables," and does not do any scope -/// checking, so all scopes are replaced with a block ID of zero. fn resolve(typemap: &mut IndexSet, generics: &[id::Ty], types: &[id::Ty], ty: &Ty) -> id::Ty { let resolved = match ty { Ty::Generic { id } => return generics[id.generic()], diff --git a/crates/wasm/Cargo.toml b/crates/wasm/Cargo.toml new file mode 100644 index 0000000..0a02305 --- /dev/null +++ b/crates/wasm/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "rose-wasm" +version = "0.3.2" +publish = false +edition = "2021" + +[dependencies] +by_address = "1" +indexmap = "2" +rose = { path = "../core" } +wasm-encoder = "0.33" diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs new file mode 100644 index 0000000..c66348f --- /dev/null +++ b/crates/wasm/src/lib.rs @@ -0,0 +1,1168 @@ +use by_address::ByAddress; +use indexmap::{map::Entry, IndexMap, IndexSet}; +use rose::{id, Binop, Expr, Func, Instr, Node, Refs, Ty, Unop}; +use std::{hash::Hash, mem::take}; +use wasm_encoder::{ + BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection, + Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType, +}; + +/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be. +/// +/// This is meant to be used to pull all the types from a callee into a broader context. The +/// `generics` are the IDs of all the types provided as generic type parameters for the callee. The +/// `types are the IDs of all the types that have been pulled in so far. +fn resolve(typemap: &mut IndexSet, generics: &[id::Ty], types: &[id::Ty], ty: &Ty) -> id::Ty { + let resolved = match ty { + Ty::Generic { id } => return generics[id.generic()], + + Ty::Unit => Ty::Unit, + Ty::Bool => Ty::Bool, + Ty::F64 => Ty::F64, + &Ty::Fin { size } => Ty::Fin { size }, + + Ty::Ref { inner } => Ty::Ref { + inner: types[inner.ty()], + }, + Ty::Array { index, elem } => Ty::Array { + index: types[index.ty()], + elem: types[elem.ty()], + }, + Ty::Tuple { members } => Ty::Tuple { + members: members.iter().map(|&x| types[x.ty()]).collect(), + }, + }; + let (i, _) = typemap.insert_full(resolved); + id::ty(i) +} + +/// An index of opaque functions. +/// +/// Each key holds the opaque function itself followed by the generic parameters used for this +/// particular instance. The value is the resolved type signature of the function according to a +/// global type index. +type Imports = IndexMap<(O, Box<[id::Ty]>), (Box<[id::Ty]>, id::Ty)>; + +/// An index of transparent functions. +/// +/// Each key holds a reference to the function itself followed by the generic parameters used for +/// this particular instance. The value holds the function's immediate callees (see `rose::Refs`) +/// followed by a mapping from the function's own type indices to resolved type indices in a +/// global type index. +type Funcs<'a, T> = IndexMap<(ByAddress<&'a Func>, Box<[id::Ty]>), (T, Box<[id::Ty]>)>; + +/// Computes a topological sort of a call graph via depth-first search. +struct Topsort<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> { + /// All types seen so far. + types: IndexSet, + + /// All opaque functions seen so far. + imports: Imports, + + /// All transparent functions seen so far, in topological sorted order. + funcs: Funcs<'a, T>, +} + +impl<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>> Topsort<'a, O, T> { + /// Search in the given `block` of `f`, using `refs` to resolve immediate function calls. + /// + /// The `types` argument is the resolved type ID for each of `f.types` in `self.types`. + fn block(&mut self, refs: &T, f: &'a Func, types: &[id::Ty], block: &[Instr]) { + for instr in block.iter() { + match &instr.expr { + Expr::Unit + | Expr::Bool { .. } + | Expr::F64 { .. } + | Expr::Fin { .. } + | Expr::Array { .. } + | Expr::Tuple { .. } + | Expr::Index { .. } + | Expr::Member { .. } + | Expr::Slice { .. } + | Expr::Field { .. } + | Expr::Unary { .. } + | Expr::Binary { .. } + | Expr::Select { .. } + | Expr::Accum { .. } + | Expr::Add { .. } + | Expr::Resolve { .. } => {} + Expr::For { body, .. } => { + self.block(refs, f, types, body); + } + Expr::Call { id, generics, args } => { + let gens = generics.iter().map(|t| types[t.ty()]).collect(); + match refs.get(*id).unwrap() { + Node::Transparent { refs, def } => { + let key = (ByAddress(def), gens); + if !self.funcs.contains_key(&key) { + let (_, gens) = key; // get back `gens` to please the borrow checker + self.func(refs, def, gens); + } + } + Node::Opaque { def, .. } => { + let resolved = ( + args.iter().map(|x| types[f.vars[x.var()].ty()]).collect(), + types[f.vars[instr.var.var()].ty()], + ); + match self.imports.entry((def, gens)) { + Entry::Occupied(entry) => { + // we should never see the same exact opaque function with the + // same generic type parameters but multiple different type + // signatures + assert_eq!(entry.get(), &resolved); + } + Entry::Vacant(entry) => { + entry.insert(resolved); + } + } + } + } + } + } + } + } + + /// Search from `def` with the given `generics`, using `refs` to resolve immediate calls. + fn func(&mut self, refs: T, def: &'a Func, generics: Box<[id::Ty]>) { + let mut types = vec![]; + for ty in def.types.iter() { + types.push(resolve(&mut self.types, &generics, &types, ty)); + } + self.block(&refs, def, &types, &def.body); + let prev = self + .funcs + .insert((ByAddress(def), generics), (refs, types.into())); + // we're doing depth-first search on a DAG, so even if we wait until this last moment to + // mark the node as visited, we still can't have seen it already + assert!(prev.is_none()); + } +} + +/// Return the WebAssembly value type used to represent a local of type `ty`. +fn val_type(ty: &Ty) -> ValType { + match ty { + Ty::Unit + | Ty::Bool + | Ty::Fin { .. } + | Ty::Ref { .. } + | Ty::Array { .. } + | Ty::Tuple { .. } => ValType::I32, + Ty::F64 => ValType::F64, + Ty::Generic { .. } => unreachable!(), + } +} + +/// A WebAssembly memory offset or size. +type Size = u32; + +/// Convert a `usize` to a `Size`. +/// +/// This will always succeed if the compiler itself is running inside WebAssembly. +fn u_size(x: usize) -> Size { + x.try_into().unwrap() +} + +/// Round up `size` to the nearest multiple of `align`. +fn aligned(size: Size, align: Size) -> Size { + (size + align - 1) & !(align - 1) +} + +/// The layout of a type in memory. +#[derive(Clone, Copy)] +enum Layout { + /// The unit type. Zero-sized. + Unit, + + /// An unsigned 8-bit integer. + U8, + + /// An unsigned 16-bit integer. + U16, + + /// An unsigned 32-bit integer. + U32, + + /// A 64-bit floating-point number. + F64, + + /// `Ty::Ref` cannot be stored in memory. + Ref, +} + +impl Layout { + /// Return the size and alignment of this `Layout`, in bytes. + fn size_align(self) -> (Size, Size) { + match self { + Self::Unit => (0, 1), + Self::U8 => (1, 1), + Self::U16 => (2, 2), + Self::U32 => (4, 4), + Self::F64 => (8, 8), + Self::Ref => unreachable!(), + } + } + + /// Return the size of this `Layout`, which is always aligned. + fn size(self) -> Size { + let (size, _) = self.size_align(); + size // no need to use alignment, because every possible `Layout` size is already aligned + } + + /// Emit a load instruction for this layout with the given byte offset. + fn load(self, function: &mut Function, offset: Size) { + let offset = offset.into(); + match self { + Self::Unit => { + function.instruction(&Instruction::Drop); + function.instruction(&Instruction::I32Const(0)); + } + Self::U8 => { + function.instruction(&Instruction::I32Load8U(MemArg { + offset, + align: 0, + memory_index: 0, + })); + } + Self::U16 => { + function.instruction(&Instruction::I32Load16U(MemArg { + offset, + align: 1, + memory_index: 0, + })); + } + Self::U32 => { + function.instruction(&Instruction::I32Load(MemArg { + offset, + align: 2, + memory_index: 0, + })); + } + Self::F64 => { + function.instruction(&Instruction::F64Load(MemArg { + offset, + align: 3, + memory_index: 0, + })); + } + Self::Ref => unreachable!(), + } + } + + /// Emit a store instruction for this layout with the given byte offset. + fn store(self, function: &mut Function, offset: Size) { + let offset = offset.into(); + match self { + Self::Unit => { + function.instruction(&Instruction::Drop); + function.instruction(&Instruction::Drop); + } + Self::U8 => { + function.instruction(&Instruction::I32Store8(MemArg { + offset, + align: 0, + memory_index: 0, + })); + } + Self::U16 => { + function.instruction(&Instruction::I32Store16(MemArg { + offset, + align: 1, + memory_index: 0, + })); + } + Self::U32 => { + function.instruction(&Instruction::I32Store(MemArg { + offset, + align: 2, + memory_index: 0, + })); + } + Self::F64 => { + function.instruction(&Instruction::F64Store(MemArg { + offset, + align: 3, + memory_index: 0, + })); + } + Self::Ref => unreachable!(), + } + } +} + +/// The index of a WebAssembly local. +type Local = u32; + +/// Information about a type that has functions for accumulation. +#[derive(Clone, Copy)] +struct Accum { + /// The ID of the zero function. + zero: u32, + + /// The allocation cost of the zero function. + cost: Size, + + // The ID of the add function, which has no allocation cost. + add: u32, +} + +/// Information about a type that is necessary for code generation. +struct Meta { + /// The type. + ty: Ty, + + /// The layout of the type. + layout: Layout, + + /// Zero and add functions for accumulation, if this type is an array or tuple. + accum: Option, + + /// Offsets of each member of a tuple. + members: Option>, +} + +/// Generates WebAssembly code for a function. +struct Codegen<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> { + /// Metadata about all the types in the global type index. + metas: &'b [Meta], + + /// All opaque functions. + imports: &'b Imports, + + /// The number of opaque functions plus the number of accumulation functions (zeros and adds). + extras: usize, + + /// All transparent functions. + funcs: &'b Funcs<'a, T>, + + /// The allocation cost of each transparent function. + costs: &'b [Size], + + /// To resolve calls. + refs: &'b T, + + /// The definition of the particular function we're generating code for. + def: &'b Func, + + /// Mapping from this function's type indices to type indices in the global type index. + types: &'b [id::Ty], + + /// The WebAssembly local assigned to each variable in this function. + locals: &'b [Local], + + /// The amount of memory allocated so far in the current block. + /// + /// This is for the block and not the entire function, because for instance, a loop's total + /// allocation cost depends both on its block's allocation cost and on the number of iterations. + offset: Size, + + /// The WebAssembly function under construction. + wasm: Function, +} + +impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { + /// Return metadata for the type ID `t` in the current function. + /// + /// Do not use this if your type ID is already resolved to refer to the global type index. + fn meta(&self, t: id::Ty) -> &'b Meta { + &self.metas[self.types[t.ty()].ty()] + } + + /// Emit an instruction to push the value of `x` onto the stack. + fn get(&mut self, x: id::Var) { + self.wasm + .instruction(&Instruction::LocalGet(self.locals[x.var()])); + } + + /// Emit an instruction to pop the top of the stack and store it in `x`. + fn set(&mut self, x: id::Var) { + self.wasm + .instruction(&Instruction::LocalSet(self.locals[x.var()])); + } + + /// Emit an instruction to store the stack top in `x` without popping it. + fn tee(&mut self, x: id::Var) { + self.wasm + .instruction(&Instruction::LocalTee(self.locals[x.var()])); + } + + /// Emit an instruction to push the current memory allocation pointer onto the stack. + fn pointer(&mut self) { + self.wasm + .instruction(&Instruction::LocalGet(u_size(self.def.params.len()))); + } + + /// Emit an instruction to push the constant integer value `x` onto the stack. + fn u32_const(&mut self, x: u32) { + self.wasm.instruction(&Instruction::I32Const(x as i32)); + } + + /// Emit instructions to increase the memory allocation pointer by `size` bytes. + fn bump(&mut self, size: Size) { + let aligned = aligned(size, 8); + self.pointer(); + self.u32_const(aligned); + self.wasm.instruction(&Instruction::I32Add); + self.wasm + .instruction(&Instruction::LocalSet(u_size(self.def.params.len()))); + self.offset += aligned; + } + + /// Emit instruction(s) to load a value with the given `layout` and `offset`. + fn load(&mut self, layout: Layout, offset: Size) { + layout.load(&mut self.wasm, offset) + } + + /// Emit instruction(s) to store a value with the given `layout` and `offset`. + fn store(&mut self, layout: Layout, offset: Size) { + layout.store(&mut self.wasm, offset) + } + + /// Generate code for the given `block`. + fn block(&mut self, block: &[Instr]) { + for instr in block.iter() { + match &instr.expr { + Expr::Unit => { + self.wasm.instruction(&Instruction::I32Const(0)); + } + &Expr::Bool { val } => { + self.wasm.instruction(&Instruction::I32Const(val.into())); + } + &Expr::F64 { val } => { + self.wasm.instruction(&Instruction::F64Const(val)); + } + &Expr::Fin { val } => { + self.wasm + .instruction(&Instruction::I32Const(val.try_into().unwrap())); + } + Expr::Array { elems } => { + let &Meta { layout, .. } = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Array { elem, .. } => elem, + _ => unreachable!(), + }); + let size = layout.size(); + for (i, &elem) in elems.iter().enumerate() { + self.pointer(); + self.get(elem); + self.store(layout, size * u_size(i)); + } + self.pointer(); + self.bump(size * u_size(elems.len())); + } + Expr::Tuple { members } => { + let Meta { members: mems, .. } = self.meta(self.def.vars[instr.var.var()]); + let mut size = 0; + for (&member, &offset) in members.iter().zip(mems.as_ref().unwrap().iter()) { + let &Meta { layout, .. } = self.meta(self.def.vars[member.var()]); + self.pointer(); + self.get(member); + self.store(layout, offset); + size = size.max(offset + layout.size()); + } + self.pointer(); + self.bump(size); + } + &Expr::Index { array, index } => { + let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); + let size = layout.size(); + self.get(array); + self.get(index); + self.u32_const(size); + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + self.load(layout, 0); + } + &Expr::Member { tuple, member } => { + let Meta { members, .. } = self.meta(self.def.vars[tuple.var()]); + let offset = members.as_ref().unwrap()[member.member()]; + let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); + self.get(tuple); + self.load(layout, offset); + } + &Expr::Slice { array, index } => { + let meta = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + let size = meta.layout.size(); + self.get(array); + self.get(index); + self.u32_const(size); + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + if let Ty::Array { .. } | Ty::Tuple { .. } = &meta.ty { + // if this array holds primitives then we just want a pointer to the + // element, but if it's actually another composite value then it's already a + // pointer, so we need to do a load because otherwise we'd have a pointer to + // a pointer instead of just one direct pointer + self.load(meta.layout, 0); + } + } + &Expr::Field { tuple, member } => { + let Meta { members, .. } = + self.meta(match self.def.types[self.def.vars[tuple.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + let offset = members.as_ref().unwrap()[member.member()]; + let meta = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { + Ty::Ref { inner } => inner, + _ => unreachable!(), + }); + self.get(tuple); + match &meta.ty { + Ty::Unit | Ty::Bool | Ty::F64 | Ty::Fin { .. } => { + // if this array holds primitives then we just want a pointer to the + // element + self.u32_const(offset); + self.wasm.instruction(&Instruction::I32Add); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + // if this array holds other composite values then each element is + // already a pointer, so we need to do a load because otherwise we'd + // have a pointer to a pointer instead of just one direct pointer + self.load(meta.layout, offset); + } + } + } + &Expr::Unary { op, arg } => match op { + Unop::Not => { + self.get(arg); + self.wasm.instruction(&Instruction::I32Eqz); + } + Unop::Neg => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Neg); + } + Unop::Abs => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Abs); + } + Unop::Sign => { + // TODO: `f64.const` instructions are always 8 bytes, much larger than most + // instructions; maybe we should just keep this constant in a local + self.wasm.instruction(&Instruction::F64Const(1.)); + self.get(arg); + self.wasm.instruction(&Instruction::F64Copysign); + } + Unop::Ceil => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Ceil); + } + Unop::Floor => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Floor); + } + Unop::Trunc => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Trunc); + } + Unop::Sqrt => { + self.get(arg); + self.wasm.instruction(&Instruction::F64Sqrt); + } + }, + &Expr::Binary { op, left, right } => { + self.get(left); + self.get(right); + match op { + Binop::And => self.wasm.instruction(&Instruction::I32And), + Binop::Or => self.wasm.instruction(&Instruction::I32Or), + Binop::Iff => self.wasm.instruction(&Instruction::I32Eq), + Binop::Xor => self.wasm.instruction(&Instruction::I32Xor), + Binop::Neq => self.wasm.instruction(&Instruction::F64Ne), + Binop::Lt => self.wasm.instruction(&Instruction::F64Lt), + Binop::Leq => self.wasm.instruction(&Instruction::F64Le), + Binop::Eq => self.wasm.instruction(&Instruction::F64Eq), + Binop::Gt => self.wasm.instruction(&Instruction::F64Gt), + Binop::Geq => self.wasm.instruction(&Instruction::F64Ge), + Binop::Add => self.wasm.instruction(&Instruction::F64Add), + Binop::Sub => self.wasm.instruction(&Instruction::F64Sub), + Binop::Mul => self.wasm.instruction(&Instruction::F64Mul), + Binop::Div => self.wasm.instruction(&Instruction::F64Div), + }; + } + &Expr::Select { cond, then, els } => { + self.get(then); + self.get(els); + self.get(cond); + self.wasm.instruction(&Instruction::Select); + } + Expr::Call { id, generics, args } => { + let gens = generics + .iter() + .map(|t| self.types[self.def.vars[t.ty()].ty()]) + .collect(); + for &arg in args.iter() { + self.get(arg); + } + let i = match self.refs.get(*id).unwrap() { + Node::Transparent { def, .. } => { + self.pointer(); + let j = self.funcs.get_index_of(&(ByAddress(def), gens)).unwrap(); + self.bump(self.costs[j]); + self.extras + j + } + Node::Opaque { def, .. } => { + self.imports.get_index_of(&(def, gens)).unwrap() + } + }; + self.wasm + .instruction(&Instruction::Call(i.try_into().unwrap())); + } + Expr::For { arg, body, ret } => { + let n = u_size(match self.meta(self.def.vars[arg.var()]).ty { + Ty::Fin { size } => size, + _ => unreachable!(), + }); + let &Meta { layout, .. } = self.meta(self.def.vars[ret.var()]); + let size = layout.size(); + + // we need to set the local now rather than later, because we're going to bump + // the pointer for the array itself and possibly in the loop body, but we still + // need to know this pointer so we can use it to store each element of the array + self.pointer(); + self.set(instr.var); + + // we put the bounds check at the end of the loop, so if it's going to execute + // zero times then we need to make sure not to enter it at all; easiest way is + // to just not emit the loop instructions at all + if n > 0 { + self.bump(size * n); + let offset = take(&mut self.offset); + + self.wasm.instruction(&Instruction::I32Const(0)); + self.set(*arg); + self.wasm.instruction(&Instruction::Loop(BlockType::Empty)); + + self.block(body); + + self.get(instr.var); + self.get(*arg); + self.u32_const(size); + self.wasm.instruction(&Instruction::I32Mul); + self.wasm.instruction(&Instruction::I32Add); + self.get(*ret); + self.store(layout, 0); + + self.get(*arg); + self.wasm.instruction(&Instruction::I32Const(1)); + self.wasm.instruction(&Instruction::I32Add); + self.tee(*arg); + self.u32_const(n); + self.wasm.instruction(&Instruction::I32LtU); + self.wasm.instruction(&Instruction::BrIf(0)); + self.wasm.instruction(&Instruction::End); + + self.offset = offset + self.offset * n; + } + + continue; + } + &Expr::Accum { shape } => { + let meta = self.meta(self.def.vars[shape.var()]); + match &meta.ty { + // this is a bit subtle: usually a `Ref` variable is a pointer, and that is + // also true for `Ref` variables to values of these three discrete + // continuous types if they come from an `Expr::Slice` or `Expr::Field`; + // but, if we're directly starting an accumulator for a discrete primitive + // value, then its value can't be modified, so we can just store it directly + // instead of allocating extra memory; this works because the WebAssembly + // value types for all these discrete primitive types are the same as for + // pointers, and it's OK to have the representation be different depending + // on whether it's directly introduced by `Expr::Accum` or not, because + // those are the only ones on which we can use `Expr::Resolve`, and `Ref`s + // cannot be directly read before they're resolved anyway + Ty::Unit | Ty::Bool | Ty::Fin { .. } => self.get(shape), + Ty::F64 => { + self.pointer(); + self.pointer(); + // TODO: `f64.const` instructions are always 8 bytes, much larger than + // most instructions; maybe we should just keep this constant in a local + self.wasm.instruction(&Instruction::F64Const(0.)); + self.store(Layout::F64, 0); + self.bump(8); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + let Accum { zero, cost, .. } = meta.accum.unwrap(); + self.pointer(); + self.get(shape); + self.wasm.instruction(&Instruction::Call(zero)); + self.pointer(); + self.bump(cost); + } + } + } + &Expr::Add { accum, addend } => { + let meta = self.meta(self.def.vars[addend.var()]); + match &meta.ty { + Ty::Unit | Ty::Bool | Ty::Fin { .. } => {} + Ty::F64 => { + self.get(accum); + self.get(accum); + self.load(Layout::F64, 0); + self.get(addend); + self.wasm.instruction(&Instruction::F64Add); + self.store(Layout::F64, 0); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + self.get(accum); + self.get(addend); + self.wasm + .instruction(&Instruction::Call(meta.accum.unwrap().add)); + } + } + self.wasm.instruction(&Instruction::I32Const(0)); + } + &Expr::Resolve { var } => { + self.get(var); + if let Ty::F64 = &self.meta(self.def.vars[instr.var.var()]).ty { + // as explained above, if the `inner` value is a discrete primitive type + // then we cheated and stored it directly in the local so there's nothing to + // do, and if it's a pointer then we still just need a pointer so there's + // also nothing to do; but if it's a continuous primitive type then we + // introduced an extra layer of indirection so we need to loads + self.load(Layout::F64, 0); + } + } + } + self.set(instr.var); + } + } +} + +/// A WebAssembly module for a graph of functions. +/// +/// The module exports its memory with name `"m"` and its entrypoint function with name `"f"`. The +/// function takes one parameter in addition to its original parameters, which must be an +/// 8-byte-aligned pointer to the start of the memory region it can use for allocation. The memory +/// is the exact number of pages necessary to accommodate the function's own memory allocation as +/// well as memory allocation for all of its parameters, with each node in each parameter's memory +/// allocation tree being 8-byte aligned. That is, the function's last argument should be just large +/// enough to accommodate those allocations for all the parameters; in that case, no memory will be +/// incorrectly overwritten and no out-of-bounds memory accesses will occur. +pub struct Wasm { + /// The bytes of the WebAssembly module binary. + pub bytes: Vec, + + /// All the opaque functions that the WebAssembly module must import, in order. + /// + /// The module name for each import is the empty string, and the field name is the base-ten + /// representation of its index in this collection. + pub imports: Imports, +} + +/// Compile `f` and all its direct and indirect callees to a WebAssembly module. +pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> Wasm { + let mut topsort = Topsort { + types: IndexSet::new(), + imports: IndexMap::new(), + funcs: IndexMap::new(), + }; + match f { + Node::Transparent { refs, def } => { + topsort.func(refs, def, [].into()); + } + Node::Opaque { + types, + params, + ret, + def, + .. + } => { + // if `f` itself is an opaque function then the graph of all callees has only one node + let mut def_types = vec![]; + for ty in types.iter() { + def_types.push(resolve(&mut topsort.types, &[], &def_types, ty)); + } + topsort.imports.insert( + (def, [].into()), + ( + params.iter().map(|t| def_types[t.ty()]).collect(), + def_types[ret.ty()], + ), + ); + } + } + let Topsort { + types, + imports, + funcs, + } = topsort; + + // we add to this lazily as we generate our imports, functions, and code, after which we'll + // generate the actual function types section right near the end; it doesn't matter as long as + // the order we actually add the sections to the module itself is correct + let mut func_types: IndexSet<(Box<[ValType]>, ValType)> = IndexSet::new(); + + let mut import_section = ImportSection::new(); + for (i, (params, ret)) in imports.values().enumerate() { + let (type_index, _) = func_types.insert_full(( + params.iter().map(|t| val_type(&types[t.ty()])).collect(), + val_type(&types[ret.ty()]), + )); + // we reserve type index zero for the type with two `i32` params and no results, which we + // use for accumulation zero and add functions; we don't include that in the `func_types` + // index itself, because that index only holds function types with exactly one result + import_section.import( + "", + &i.to_string(), + EntityType::Function((1 + type_index).try_into().unwrap()), + ); + } + + let mut function_section = FunctionSection::new(); + let mut code_section = CodeSection::new(); + + let mut metas: Vec = vec![]; + let mut extras: usize = imports.len(); + for ty in types.into_iter() { + let (layout, cost, members) = match &ty { + Ty::Unit => (Layout::Unit, None, None), + Ty::Bool => (Layout::U8, None, None), + Ty::F64 => (Layout::F64, None, None), + &Ty::Fin { size } => ( + if size <= 1 { + Layout::Unit + } else if size <= 256 { + Layout::U8 + } else if size <= 65536 { + Layout::U16 + } else { + Layout::U32 + }, + None, + None, + ), + Ty::Generic { .. } => unreachable!(), + Ty::Ref { .. } => (Layout::Ref, None, None), + Ty::Array { index, elem } => { + let n = u_size(match metas[index.ty()].ty { + Ty::Fin { size } => size, + _ => unreachable!(), + }); + let meta = &metas[elem.ty()]; + let size = meta.layout.size(); + + // for both the zero function and the add function, the first parameter is a pointer + // to the accumulator value, and the second parameter is the pointer to the other + // value (the shape for zero, or the addend for add) + + // the first local is a pointer to the end of the accumulator array, used for bounds + // checking; the second local is a memory allocation pointer, used as the + // accumulator pointer for calls to the zero function for elements if this array + // stores composite values + let mut zero = Function::new([(2, ValType::I32)]); + let mut total = aligned(size * n, 8); + // same as zero, the local is a pointer to the end of the accumulator array, used + // for bounds checking + let mut add = Function::new([(1, ValType::I32)]); + + if n > 0 { + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::I32Const(total.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalTee(2)); + zero.instruction(&Instruction::LocalSet(3)); + zero.instruction(&Instruction::Loop(BlockType::Empty)); + + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::I32Const(total.try_into().unwrap())); + add.instruction(&Instruction::I32Add); + add.instruction(&Instruction::LocalSet(2)); + add.instruction(&Instruction::Loop(BlockType::Empty)); + + match &meta.ty { + Ty::Unit => {} + Ty::Bool | Ty::Fin { .. } => { + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut zero, 0); + meta.layout.store(&mut zero, 0); + } + Ty::F64 => { + zero.instruction(&Instruction::LocalGet(0)); + // TODO: `f64.const` instructions are always 8 bytes, much larger than + // most instructions; maybe we should just keep this constant in a local + zero.instruction(&Instruction::F64Const(0.)); + meta.layout.store(&mut zero, 0); + + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::LocalGet(0)); + meta.layout.load(&mut add, 0); + add.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut add, 0); + add.instruction(&Instruction::F64Add); + meta.layout.store(&mut add, 0); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + let accum = meta.accum.unwrap(); + let cost = accum.cost; + + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::LocalGet(3)); + meta.layout.store(&mut zero, 0); + zero.instruction(&Instruction::LocalGet(3)); + zero.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut zero, 0); + zero.instruction(&Instruction::Call(accum.zero)); + zero.instruction(&Instruction::LocalGet(3)); + zero.instruction(&Instruction::I32Const(cost.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalSet(3)); + + total += cost * n; + + add.instruction(&Instruction::LocalGet(0)); + meta.layout.load(&mut add, 0); + add.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut add, 0); + add.instruction(&Instruction::Call(accum.add)); + } + } + + zero.instruction(&Instruction::LocalGet(1)); + zero.instruction(&Instruction::I32Const(size.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalSet(1)); + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::I32Const(size.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalTee(0)); + zero.instruction(&Instruction::LocalGet(2)); + zero.instruction(&Instruction::I32LtU); + zero.instruction(&Instruction::BrIf(0)); + zero.instruction(&Instruction::End); + + add.instruction(&Instruction::LocalGet(1)); + add.instruction(&Instruction::I32Const(size.try_into().unwrap())); + add.instruction(&Instruction::I32Add); + add.instruction(&Instruction::LocalSet(1)); + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::I32Const(size.try_into().unwrap())); + add.instruction(&Instruction::I32Add); + add.instruction(&Instruction::LocalTee(0)); + add.instruction(&Instruction::LocalGet(2)); + add.instruction(&Instruction::I32LtU); + add.instruction(&Instruction::BrIf(0)); + add.instruction(&Instruction::End); + } + + zero.instruction(&Instruction::End); + code_section.function(&zero); + + add.instruction(&Instruction::End); + code_section.function(&add); + + (Layout::U32, Some(total), None) + } + Ty::Tuple { members } => { + let mut mems: Vec<_> = members + .iter() + .enumerate() + .map(|(i, t)| { + let Meta { layout, .. } = metas[t.ty()]; + let (size, align) = layout.size_align(); + (i, size, align) + }) + .collect(); + mems.sort_unstable_by_key(|&(_, _, align)| align); + let mut offsets = vec![0; members.len()]; + let mut offset = 0; + for (i, s, a) in mems { + offset = aligned(offset, a); + offsets[i] = offset; + offset += s; + } + + // the local is a memory allocation pointer, used as the accumulator pointer for + // calls to the zero function for composite elements of the tuple + let mut zero = Function::new([(1, ValType::I32)]); + let mut total = aligned(offset, 8); + let mut add = Function::new([]); + + for (member, &offset) in members.iter().zip(offsets.iter()) { + let meta = &metas[member.ty()]; + + match &meta.ty { + Ty::Unit => {} + Ty::Bool | Ty::Fin { .. } => { + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut zero, offset); + meta.layout.store(&mut zero, offset); + } + Ty::F64 => { + zero.instruction(&Instruction::LocalGet(0)); + // TODO: `f64.const` instructions are always 8 bytes, much larger than + // most instructions; maybe we should just keep this constant in a local + zero.instruction(&Instruction::F64Const(0.)); + meta.layout.store(&mut zero, offset); + + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::LocalGet(0)); + meta.layout.load(&mut add, offset); + add.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut add, offset); + add.instruction(&Instruction::F64Add); + meta.layout.store(&mut add, offset); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + let accum = meta.accum.unwrap(); + let cost = accum.cost; + + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::I32Const(total.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalTee(2)); + meta.layout.store(&mut zero, offset); + zero.instruction(&Instruction::LocalGet(2)); + zero.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut zero, offset); + zero.instruction(&Instruction::Call(accum.zero)); + + total += cost; + + add.instruction(&Instruction::LocalGet(0)); + meta.layout.load(&mut add, offset); + add.instruction(&Instruction::LocalGet(1)); + meta.layout.load(&mut add, offset); + add.instruction(&Instruction::Call(accum.add)); + } + } + } + + zero.instruction(&Instruction::End); + code_section.function(&zero); + + add.instruction(&Instruction::End); + code_section.function(&add); + + (Layout::U32, Some(total), Some(offsets.into())) + } + }; + metas.push(Meta { + ty, + layout, + accum: cost.map(|cost| { + let zero = extras.try_into().unwrap(); + function_section.function(0); + let add = (extras + 1).try_into().unwrap(); + function_section.function(0); + extras += 2; + Accum { zero, cost, add } + }), + members, + }); + } + + let mut costs = vec![]; // allocation cost of each function, in bytes + for ((def, _), (refs, def_types)) in funcs.iter() { + let vt = |t: id::Ty| val_type(&metas[def_types[t.ty()].ty()].ty); // short for `ValType` + let params: Local = (def.params.len() + 1).try_into().unwrap(); // extra pointer parameter + let mut locals = vec![None; def.vars.len()]; + + let (type_index, _) = func_types.insert_full(( + def.params + .iter() + .enumerate() + .map(|(i, param)| { + locals[param.var()] = Some(i.try_into().unwrap()); + vt(def.vars[param.var()]) + }) + .chain([ValType::I32]) // extra pointer parameter + .collect(), + vt(def.vars[def.ret.var()]), + )); + function_section.function((1 + type_index).try_into().unwrap()); + + let mut i32s = 0; + for (i, &t) in def.vars.iter().enumerate() { + if locals[i].is_none() { + if let ValType::I32 = vt(t) { + locals[i] = Some(params + i32s); + i32s += 1; + } + } + } + let mut f64s = 0; + for (i, &t) in def.vars.iter().enumerate() { + if locals[i].is_none() { + assert_eq!(vt(t), ValType::F64); + locals[i] = Some(params + i32s + f64s); + f64s += 1; + } + } + + let locals: Box<_> = locals.into_iter().map(Option::unwrap).collect(); + let mut codegen = Codegen { + metas: &metas, + imports: &imports, + extras, + funcs: &funcs, + costs: &costs, + refs, + def, + types: def_types, + locals: &locals, + offset: 0, + wasm: Function::new([(i32s, ValType::I32), (f64s, ValType::F64)]), + }; + codegen.block(&def.body); + codegen.get(def.ret); + codegen.wasm.instruction(&Instruction::End); + code_section.function(&codegen.wasm); + costs.push(codegen.offset); + } + + let mut type_section = TypeSection::new(); + type_section.function([ValType::I32, ValType::I32], []); // for accumulation functions + for (params, ret) in func_types { + type_section.function(params.into_vec(), [ret]); + } + + let mut memory_section = MemorySection::new(); + let page_size = 65536; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size + let cost = funcs.last().map_or(0, |((def, _), (_, def_types))| { + def.params + .iter() + .filter_map(|param| metas[def_types[def.vars[param.var()].ty()].ty()].accum) + .map(|accum| accum.cost) + .sum() + }) + costs.last().unwrap_or(&0); + let pages = ((cost + page_size - 1) / page_size).into(); // round up to a whole number of pages + memory_section.memory(MemoryType { + minimum: pages, + maximum: Some(pages), + memory64: false, + shared: false, + }); + + let mut export_section = ExportSection::new(); + export_section.export( + "f", + wasm_encoder::ExportKind::Func, + (extras + funcs.len() - 1).try_into().unwrap(), + ); + export_section.export("m", wasm_encoder::ExportKind::Memory, 0); + + let mut module = Module::new(); + module.section(&type_section); + module.section(&import_section); + module.section(&function_section); + module.section(&memory_section); + module.section(&export_section); + module.section(&code_section); + Wasm { + bytes: module.finish(), + imports, + } +} diff --git a/crates/web/Cargo.toml b/crates/web/Cargo.toml index 15e2f06..e76233f 100644 --- a/crates/web/Cargo.toml +++ b/crates/web/Cargo.toml @@ -8,6 +8,7 @@ edition = "2021" crate-type = ["cdylib"] [dependencies] +by_address = "1" console_error_panic_hook = { version = "0.1", optional = true } enumset = "1" indexmap = "2" @@ -16,6 +17,7 @@ rose = { path = "../core" } rose-autodiff = { path = "../autodiff" } rose-interp = { path = "../interp", features = ["serde"] } rose-transpose = { path = "../transpose" } +rose-wasm = { path = "../wasm" } serde = { version = "1", features = ["derive"] } serde-wasm-bindgen = "0.4" wasm-bindgen = "=0.2.87" # Must be this version of wbg diff --git a/crates/web/src/lib.rs b/crates/web/src/lib.rs index 16ddb60..748989a 100644 --- a/crates/web/src/lib.rs +++ b/crates/web/src/lib.rs @@ -1,3 +1,4 @@ +use by_address::ByAddress; use enumset::EnumSet; use indexmap::{IndexMap, IndexSet}; use rose::id; @@ -53,8 +54,9 @@ fn val_to_js(x: &rose_interp::Val) -> JsValue { } /// Reference to an opaque function that just points to a JavaScript function as its implementation. +#[derive(Clone, Copy, Eq, Hash, PartialEq)] struct Opaque<'a> { - f: &'a js_sys::Function, + f: ByAddress<&'a js_sys::Function>, } impl rose_interp::Opaque for Opaque<'_> { @@ -182,7 +184,7 @@ impl Func { types, params, ret: *ret, - def: Opaque { f: def }, + def: Opaque { f: ByAddress(def) }, }, } } @@ -209,44 +211,91 @@ impl Func { } } - /// Return true iff `t` is the ID of a finite integer type. - #[wasm_bindgen(js_name = "isFin")] - pub fn is_fin(&mut self, t: usize) -> bool { - let Pointee { inner, .. } = self.rc.as_ref(); - let ty = match inner { + /// Return the number of types defined in this function. + #[wasm_bindgen(js_name = "numTypes")] + pub fn num_types(&self) -> usize { + match &self.rc.as_ref().inner { + Inner::Transparent { def, .. } => def.types.len(), + Inner::Opaque { types, .. } => types.len(), + } + } + + /// Return the type with ID `t`, if it exists. + fn ty(&self, t: usize) -> Option<&rose::Ty> { + match &self.rc.as_ref().inner { Inner::Transparent { def, .. } => def.types.get(t), Inner::Opaque { types, .. } => types.get(t), - }; - matches!(ty, Some(rose::Ty::Fin { .. })) + } + } + + /// Return true iff `t` is the ID of a unit type. + #[wasm_bindgen(js_name = "isUnit")] + pub fn is_unit(&self, t: usize) -> bool { + matches!(self.ty(t), Some(rose::Ty::Unit)) + } + + /// Return true iff `t` is the ID of a boolean type. + #[wasm_bindgen(js_name = "isBool")] + pub fn is_bool(&self, t: usize) -> bool { + matches!(self.ty(t), Some(rose::Ty::Bool)) + } + + /// Return true iff `t` is the ID of a 64-bit floating-point type. + #[wasm_bindgen(js_name = "isF64")] + pub fn is_f64(&self, t: usize) -> bool { + matches!(self.ty(t), Some(rose::Ty::F64)) + } + + /// Return true iff `t` is the ID of a finite integer type. + #[wasm_bindgen(js_name = "isFin")] + pub fn is_fin(&self, t: usize) -> bool { + matches!(self.ty(t), Some(rose::Ty::Fin { .. })) + } + + /// Return true iff `t` is the ID of an array type. + #[wasm_bindgen(js_name = "isArray")] + pub fn is_array(&self, t: usize) -> bool { + matches!(self.ty(t), Some(rose::Ty::Array { .. })) + } + + /// Return true iff `t` is the ID of a struct type. + #[wasm_bindgen(js_name = "isStruct")] + pub fn is_struct(&self, t: usize) -> bool { + self.rc.as_ref().structs[t].is_some() + } + + /// Return the size of the finite integer type with ID `t`. + pub fn size(&self, t: usize) -> usize { + match self.ty(t).unwrap() { + &rose::Ty::Fin { size } => size, + _ => panic!("not a finite integer"), + } + } + + /// Return the ID of the index type for the array type with ID `t`. + pub fn index(&self, t: usize) -> usize { + match self.ty(t).unwrap() { + rose::Ty::Array { index, elem: _ } => index.ty(), + _ => panic!("not an array"), + } } /// Return the ID of the element type for the array type with ID `t`. pub fn elem(&self, t: usize) -> usize { - let Pointee { inner, .. } = self.rc.as_ref(); - match inner { - Inner::Transparent { def, .. } => match def.types[t] { - rose::Ty::Array { index: _, elem } => elem.ty(), - _ => panic!("not an array"), - }, - Inner::Opaque { .. } => panic!(), + match self.ty(t).unwrap() { + rose::Ty::Array { index: _, elem } => elem.ty(), + _ => panic!("not an array"), } } /// Return the string IDs for the struct type with ID `t`. pub fn keys(&self, t: usize) -> Box<[usize]> { - let Pointee { structs, .. } = self.rc.as_ref(); - structs[t].as_ref().unwrap().clone() + self.rc.as_ref().structs[t].as_ref().unwrap().clone() } /// Return the member type IDs for the struct type with ID `t`. pub fn mems(&self, t: usize) -> Box<[usize]> { - let Pointee { inner, .. } = self.rc.as_ref(); - let ty = match inner { - Inner::Transparent { def, .. } => def.types.get(t), - Inner::Opaque { types, .. } => types.get(t), - } - .unwrap(); - match ty { + match self.ty(t).unwrap() { rose::Ty::Tuple { members } => members.iter().map(|m| m.ty()).collect(), _ => panic!("not a struct"), } @@ -262,6 +311,21 @@ impl Func { Ok(to_js_value(&ret)?) } + /// Compile the call graph subtended by this function to WebAssembly. + pub fn compile(&self) -> Wasm { + let rose_wasm::Wasm { bytes, imports } = rose_wasm::compile(self.node()); + Wasm { + bytes: Some(bytes), + imports: Some( + imports + .into_keys() + .map(|(Opaque { f }, _)| (*f).clone()) + .collect(), + ), + } + } + + /// Set the JVP of this function to `f`. #[wasm_bindgen(js_name = "setJvp")] pub fn set_jvp(&self, f: &Func) { self.rc.as_ref().jvp.replace(Some(Rc::clone(&f.rc))); @@ -333,7 +397,7 @@ impl Func { } let (deps_fwd, deps_bwd): (Vec<_>, Vec<_>) = deps.iter().map(|f| f.transpose_pair()).unzip(); - let dep_types: Box<_> = deps_bwd + let dep_types: Box<_> = deps_fwd .iter() .map(|f| match &f.rc.as_ref().inner { Inner::Transparent { def, .. } => { @@ -403,6 +467,26 @@ impl Func { } } +/// A temporary object to hold a generated WebAssembly module and its imports. +#[wasm_bindgen] +pub struct Wasm { + bytes: Option>, + imports: Option>, +} + +#[wasm_bindgen] +impl Wasm { + /// Return the module binary. + pub fn bytes(&mut self) -> Option> { + self.bytes.take() + } + + /// Return the imports. + pub fn imports(&mut self) -> Option> { + self.imports.take() + } +} + /// A temporary object to hold the two passes of a transposed function before they are destructured. #[wasm_bindgen] pub struct Transpose { diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index ea9cc1f..17572c3 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -582,7 +582,7 @@ type JsArgs = { [K in keyof T]: ToJs; }; -/** Concretize the nullary abstract function `f` using the interpreter. */ +/** Concretize the abstract function `f` using the interpreter. */ export const interp = ( f: Fn & ((...args: A) => R), @@ -595,6 +595,228 @@ export const interp = return unpack(f, func.retType(), func.interp(vals)) as ToJs; }; +// https://github.com/rose-lang/rose/issues/116 + +// TODO: use something more like an enum +interface Layout { + size: number; + align: number; +} + +/** Round up `size` to the nearest multiple of `align`. */ +const aligned = ({ size, align }: Layout): number => + (size + align - 1) & ~(align - 1); + +/** An aligned `ArrayBuffer` view, or `undefined` for zero-sized types. */ +type View = undefined | Uint8Array | Uint16Array | Uint32Array | Float64Array; + +const getView = (buffer: ArrayBuffer, layout: Layout, offset: number): View => { + // this code assumes that the layout is uniquely determined by its `size` + const { size } = layout; + if (size === 0) return undefined; + else if (size === 1) return new Uint8Array(buffer, offset); + else if (size === 2) return new Uint16Array(buffer, offset); + else if (size === 4) return new Uint32Array(buffer, offset); + else if (size === 8) return new Float64Array(buffer, offset); + else throw Error("unknown layout"); +}; + +/** Memory representation for a type. */ +interface Meta { + /** Layout of an individual value of this type in memory. */ + layout: Layout; + + /** + * Return the Wasm representation of the JS value `x`. + * + * The given byte offset is only used for pointer types. + */ + encode: (x: unknown, offset: number) => number; + + /** Total memory cost of an object of this type, including sub-allocations. */ + cost: number; + + /** Return a JS value represented by the Wasm value `x`. */ + decode: (x: number) => unknown; +} + +/** + * Return enough information to encode and decode Wasm values with type ID `t`. + * + * The given function `f` must have already been compiled to WebAssembly, + * yielding the given `buffer` of sufficient size. The `metas` array should hold + * encoding/decoding information for all types with IDs less than `t`, or + * `undefined` for reference types and non-struct tuple types since those cannot + * appear in user-facing function signatures. + */ +const getMeta = ( + f: Fn, + buffer: ArrayBuffer, + metas: (Meta | undefined)[], + t: number, +): Meta | undefined => { + const func = f[inner]; + if (func.isUnit(t)) { + return { + layout: { size: 0, align: 1 }, + encode: () => 0, + cost: 0, + decode: () => null, + }; + } else if (func.isBool(t)) { + return { + layout: { size: 1, align: 1 }, + encode: (x) => (x ? 1 : 0), + cost: 0, + decode: Boolean, + }; + } else if (func.isF64(t)) { + return { + layout: { size: 8, align: 8 }, + encode: (x) => x as number, + cost: 0, + decode: (x) => x, + }; + } else if (func.isFin(t)) { + const size = func.size(t); + const layout = + size <= 1 + ? { size: 0, align: 1 } + : size <= 256 + ? { size: 1, align: 1 } + : size <= 65536 + ? { size: 2, align: 2 } + : { size: 4, align: 4 }; + return { layout, encode: (x) => x as number, cost: 0, decode: (x) => x }; + } else if (func.isArray(t)) { + const meta = metas[func.elem(t)]; + if (meta === undefined) return undefined; + const { layout, encode, cost, decode } = meta; + const n = func.size(func.index(t)); + const elem = aligned(layout); + const total = aligned({ size: n * elem, align: 8 }); + const view = getView(buffer, layout, 0); + return { + layout: { size: 4, align: 4 }, + encode: + view === undefined + ? (x, offset) => offset + : (x, offset) => { + let child = offset + total; + for (let i = 0; i < n; ++i) { + view[offset / elem + i] = encode((x as unknown[])[i], child); + child += cost; + } + return offset; + }, + cost: total + n * cost, + decode: + view === undefined + ? () => { + const arr: unknown[] = []; + // this code assumes that all values of all zero-sized types can + // be represented by zero + for (let i = 0; i < n; ++i) arr.push(decode(0)); + return arr; + } + : (x) => { + const arr: unknown[] = []; + for (let i = 0; i < n; ++i) arr.push(decode(view[x / elem + i])); + return arr; + }, + }; + } else if (func.isStruct(t)) { + const keys = func.keys(t); + const members = func.mems(t); + const n = keys.length; + const mems: { key: string; meta: Meta; view?: View; child?: number }[] = []; + for (let i = 0; i < n; ++i) { + const meta = metas[members[i]]; + if (meta === undefined) return undefined; + mems.push({ key: f[strings][keys[i]], meta }); + } + mems.sort((a, b) => a.meta.layout.align - b.meta.layout.align); + let cost = 0; + let offset = 0; + for (const mem of mems) { + const { meta } = mem; + mem.child = cost; + cost += meta.cost; + const { layout } = meta; + const { size, align } = layout; + offset = aligned({ size: offset, align }); + mem.view = getView(buffer, layout, offset); + offset += size; + } + const total = aligned({ size: offset, align: 8 }); + return { + layout: { size: 4, align: 4 }, + encode: (x, offset) => { + for (const { key, meta, view, child } of mems) { + // instead of mutating each element of `mems` above to add more data + // and then still having an `if` statement in here, it would be nicer + // to just map over `mems` above to produce an array of closures that + // can be called directly, with the condition on `view === undefined` + // being handled once rather than in every call to `encode` here + if (view !== undefined) { + view[offset / aligned(meta.layout)] = meta.encode( + (x as any)[key], + offset + total + child!, + ); + } + } + return offset; + }, + cost: total + cost, + decode: (x) => { + const obj: any = {}; + for (const { key, meta, view } of mems) { + if (view === undefined) { + // this code assumes that all values of all zero-sized types can be + // represented by zero + obj[key] = meta.decode(0); + } else { + obj[key] = meta.decode(view[x / aligned(meta.layout)]); + } + } + return obj; + }, + }; + } else return undefined; +}; + +/** Concretize the abstract function `f` using the compiler. */ +export const compile = async ( + f: Fn & ((...args: A) => R), +): Promise<(...args: JsArgs) => ToJs> => { + const func = f[inner]; + const res = func.compile(); + const bytes = res.bytes()!; + const imports = res.imports()!; + res.free(); + const instance = await WebAssembly.instantiate( + await WebAssembly.compile(bytes), + { "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])) }, + ); + const { f: g, m } = instance.exports; + const metas: (Meta | undefined)[] = []; + const n = func.numTypes(); + for (let t = 0; t < n; ++t) + metas.push(getMeta(f, (m as WebAssembly.Memory).buffer, metas, t)); + let total = 0; + const params = Array.from(func.paramTypes()).map((t) => { + const { encode, cost } = metas[t]!; + const offset = total; + total += cost; + return { encode, offset }; + }); + const { decode } = metas[func.retType()]!; + return (...args): any => { + const vals = params.map(({ encode, offset }, i) => encode(args[i], offset)); + return decode((g as any)(...vals, total)); + }; +}; + // https://www.typescriptlang.org/docs/handbook/2/conditional-types.html type ToJvp = [T] extends [Null] ? Null @@ -632,6 +854,7 @@ export const vjp = ( const tp = g[inner].transpose(); const fwdFunc = tp.fwd()!; const bwdFunc = tp.bwd()!; + tp.free(); const fwd = makeFn({ [inner]: fwdFunc, [strings]: [...f[strings]] }); const bwd = makeFn({ [inner]: bwdFunc, [strings]: [...f[strings]] }); return (arg: A) => { diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 58b18ea..ffc3662 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -7,11 +7,14 @@ import { Vec, abs, add, + and, ceil, + compile, div, floor, fn, gt, + iff, interp, jvp, mul, @@ -27,6 +30,7 @@ import { trunc, vec, vjp, + xor, } from "./index.js"; describe("invalid", () => { @@ -268,7 +272,7 @@ describe("valid", () => { expect(g(0, [true])).toBe(true); }); - test("matrix multiplication", () => { + test("matrix multiplication", async () => { const n = 6; const Rn = Vec(n, Real); @@ -303,7 +307,7 @@ describe("valid", () => { }), ); - const f = interp(mmul); + const f = await compile(mmul); expect( f( [ @@ -408,12 +412,12 @@ describe("valid", () => { expect(h({ re: 3, du: 1 })).toEqual({ re: 9, du: 6 }); }); - test("JVP with sharing in call graph", () => { + test("JVP with sharing in call graph", async () => { let f = fn([Real], Real, (x) => x); for (let i = 0; i < 20; ++i) { f = fn([Real], Real, (x) => add(f(x), f(x))); } - const g = interp(jvp(f)); + const g = await compile(jvp(f)); expect(g({ re: 2, du: 3 })).toEqual({ re: 2097152, du: 3145728 }); }); @@ -439,7 +443,7 @@ describe("valid", () => { expect(interp(g)()).toEqual([6, 3, 2]); }); - test("VJP with sharing in call graph", () => { + test("VJP with sharing in call graph", async () => { const iterate = ( n: number, f: (x: Real) => Real, @@ -465,7 +469,7 @@ describe("valid", () => { const { ret, grad } = g(x); return [ret, grad(y)]; }); - const v = interp(h)(2, 3); + const v = (await compile(h))(2, 3); expect(v[0]).toBeCloseTo(2); expect(v[1]).toBeCloseTo(3); }); @@ -614,4 +618,231 @@ describe("valid", () => { f = grad(f); expect(interp(f)(1)).toBeCloseTo(-Math.cos(1)); }); + + test("compile", async () => { + const f1 = fn([Real], Real, (x) => sqrt(x)); + const f2 = fn([Real, Real], Real, (x, y) => mul(x, f1(y))); + const f3 = fn([Real, Real], Real, (x, y) => mul(f1(x), y)); + const f = fn([Real, Real], Real, (x, y) => sub(f2(x, y), f3(x, y))); + const g = await compile(f); + expect(g(2, 3)).toBeCloseTo(-0.7785390719815313); + }); + + test("compile opaque function", async () => { + const f = opaque([Real], Real, Math.sin); + const g = await compile(f); + expect(g(1)).toBeCloseTo(Math.sin(1)); + }); + + test("compile calls to multiple opaque functions", async () => { + const sin = opaque([Real], Real, Math.sin); + const cos = opaque([Real], Real, Math.cos); + const f = fn([Real], Real, (x) => sub(sin(x), cos(x))); + const g = await compile(f); + expect(g(1)).toBeCloseTo(Math.sin(1) - Math.cos(1)); + }); + + test("compile opaque and transparent calls together", async () => { + const log = opaque([Real], Real, Math.log); + const f = fn([Real], Real, (x) => add(log(x), sqrt(x))); + const g = fn([Real], Real, (x) => add(f(x), x)); + const h = await compile(g); + expect(h(1)).toBeCloseTo(Math.log(1) + Math.sqrt(1) + 1); + }); + + test("compile array", async () => { + const f = fn([Vec(2, Real)], Real, (v) => mul(v[0], v[1])); + const g = fn([Real, Real], Real, (x, y) => f([x, y])); + const h = await compile(g); + expect(h(2, 3)).toBe(6); + }); + + test("compile null array", async () => { + const f = fn([Vec(2, Null)], Null, (v) => v[1]); + const g = fn([], Real, () => { + f([null, null]); + return 42; + }); + const h = await compile(g); + expect(h()).toBe(42); + }); + + test("compile struct", async () => { + const f = fn([{ x: Real, y: Real }], Real, ({ x, y }) => mul(x, y)); + const g = fn([Real, Real], Real, (x, y) => f({ x, y })); + const h = await compile(g); + expect(h(2, 3)).toBe(6); + }); + + test("compile logic", async () => { + const f = fn([Vec(3, Bool)], Bool, (v) => { + const p = v[0]; + const q = v[1]; + const r = v[2]; + return iff(and(or(p, not(q)), xor(r, q)), or(not(p), and(q, r))); + }); + const g = fn([Bool, Bool, Bool], Real, (p, q, r) => + select(f([p, q, r]), Real, -1, -2), + ); + const h = await compile(g); + expect(h(true, true, true)).toBe(-2); + expect(h(true, true, false)).toBe(-2); + expect(h(true, false, true)).toBe(-2); + expect(h(true, false, false)).toBe(-1); + expect(h(false, true, true)).toBe(-2); + expect(h(false, true, false)).toBe(-2); + expect(h(false, false, true)).toBe(-1); + expect(h(false, false, false)).toBe(-2); + }); + + test("compile signum", async () => { + const f = fn([Real], Real, (x) => sign(x)); + const g = await compile(f); + expect(g(-2)).toBe(-1); + expect(g(-0)).toBe(-1); + expect(g(0)).toBe(1); + expect(g(2)).toBe(1); + }); + + test("compile select", async () => { + const f = fn([Bool, Real, Real], Real, (p, x, y) => select(p, Real, x, y)); + const g = await compile(f); + expect(g(true, 2, 3)).toBe(2); + expect(g(false, 5, 7)).toBe(7); + }); + + test("compile vector comprehension", async () => { + const f = fn([Real, Vec(3, Real)], Vec(3, Real), (c, v) => + vec(3, Real, (i) => mul(c, v[i])), + ); + const g = fn([Real, Real, Real, Real], Real, (c, x, y, z) => { + const v = f(c, [x, y, z]); + return add(add(v[0], v[1]), v[2]); + }); + const h = await compile(g); + expect(h(2, 3, 5, 7)).toBe(30); + }); + + test("compile empty vector comprehension", async () => { + let i = 0; + const f = opaque([], Real, () => { + ++i; + return i; + }); + const g = fn([], Real, () => { + vec(0, Real, () => f()); + return 0; + }); + (await compile(g))(); + expect(i).toEqual(0); + }); + + test("compile VJP", async () => { + const f = fn( + [Vec(2, { p: Bool, x: Real } as const)], + { p: Vec(2, Bool), x: Vec(2, Real) }, + (v) => ({ + p: vec(2, Bool, (i) => not(v[i].p)), + x: vec(2, Real, (i) => { + const { p, x } = v[i]; + return select(p, Real, mul(x, x), x); + }), + }), + ); + const g = fn([Bool, Real, Bool, Real], Real, (p1, x1, q1, y1) => { + const { ret, grad } = vjp(f)([ + { p: p1, x: x1 }, + { p: q1, x: y1 }, + ]); + const { x } = ret; + const x2 = x[0]; + const y2 = x[1]; + const v = grad({ p: [true, false] as any, x: [2, 3] as any }); + const { x: x3 } = v[0]; + const { x: y3 } = v[1]; + return mul(sub(x3, y2), sub(y3, x2)); + }); + const h = await compile(g); + expect(h(true, 2, true, 3)).toBe(-14); + expect(h(true, 5, false, 7)).toBe(-286); + expect(h(false, 11, true, 13)).toBe(-11189); + expect(h(false, 17, false, 19)).toBe(238); + }); + + test("compile VJP with call", async () => { + const f = fn([Real], Real, (x) => x); + const g = fn([Real], Real, (x) => f(x)); + const h = fn([Real], Real, (x) => vjp(g)(x).ret); + expect((await compile(h))(1)).toBe(1); + }); + + test("compile VJP with opaque call", async () => { + const exp = opaque([Real], Real, Math.exp); + exp.jvp = fn([Dual], Dual, ({ re: x, du: dx }) => { + const y = exp(x); + return { re: y, du: mulLin(dx, y) }; + }); + const g = fn([Real], Real, (x) => exp(x)); + const h = fn([Real], Real, (x) => vjp(g)(x).ret); + expect((await compile(h))(1)).toBeCloseTo(Math.E); + }); + + test("compile nulls in signature", async () => { + const f = fn([Null], Null, (x) => x); + const g = await compile(f); + expect(g(null)).toBe(null); + }); + + test("compile booleans in signature", async () => { + const f = fn([Bool], Bool, (p) => not(p)); + const g = await compile(f); + expect(g(true)).toBe(false); + expect(g(false)).toBe(true); + }); + + test("compile null arrays in signature", async () => { + const f = fn([Vec(2, Null)], Vec(2, Null), (v) => v); + const g = await compile(f); + expect(g([null, null])).toEqual([null, null]); + }); + + test("compile byte index arrays in signature", async () => { + const n = 256; + const f = fn([Vec(3, n), Vec(3, 3)], Vec(3, n), (v, i) => + vec(3, n, (j) => v[i[j]]), + ); + const g = await compile(f); + expect(g([12, 221, 234], [1, 2, 0])).toEqual([221, 234, 12]); + }); + + test("compile structs in signature", async () => { + const Pair = { x: Real, y: Real } as const; + const f = fn([Pair], Pair, ({ x, y }) => ({ x: y, y: x })); + const g = await compile(f); + expect(g({ x: 2, y: 3 })).toEqual({ x: 3, y: 2 }); + }); + + test("compile zero-sized struct members in signature", async () => { + const Stuff = { a: Null, b: 0, c: 0, d: Null } as const; + const f = fn([Stuff], Stuff, ({ a, b, c, d }) => { + return { a: d, b: c, c: b, d: a }; + }); + const g = await compile(f); + const stuff = { a: null, b: 0, c: 0, d: null }; + expect(g(stuff)).toEqual(stuff); + }); + + test("compile nested structs in signature", async () => { + const Pair = { x: Real, y: Real } as const; + const Stuff = { p: Bool, q: Pair } as const; + const f = fn([Stuff], Stuff, ({ p, q }) => ({ + p: not(p), + q: { x: q.y, y: q.x }, + })); + const g = await compile(f); + expect(g({ p: true, q: { x: 2, y: 3 } })).toEqual({ + p: false, + q: { x: 3, y: 2 }, + }); + }); }); diff --git a/packages/core/src/index.ts b/packages/core/src/index.ts index 447c056..bf15ac5 100644 --- a/packages/core/src/index.ts +++ b/packages/core/src/index.ts @@ -11,6 +11,7 @@ export { addLin, and, ceil, + compile, div, divLin, eq,