From 23a9b9eaad0930829815531e221d7793ab73401c Mon Sep 17 00:00:00 2001 From: Sam Estep Date: Sun, 24 Sep 2023 01:27:38 -0400 Subject: [PATCH] Start trying to compile accumulation --- crates/wasm/src/lib.rs | 285 +++++++++++++++++++++++++------- packages/core/src/index.test.ts | 32 ++++ 2 files changed, 253 insertions(+), 64 deletions(-) diff --git a/crates/wasm/src/lib.rs b/crates/wasm/src/lib.rs index 1ec8d03..8153c24 100644 --- a/crates/wasm/src/lib.rs +++ b/crates/wasm/src/lib.rs @@ -141,22 +141,43 @@ impl Layout { Layout::Unit => (0, 1), Layout::U8 => (1, 1), Layout::U16 => (2, 2), - Layout::U32 | Layout::Ref => (4, 4), + Layout::U32 => (4, 4), Layout::F64 => (8, 8), + Layout::Ref => unreachable!(), } } - fn size(self) -> Size { - let (size, _) = self.size_align(); - size + fn aligned(self) -> Size { + let (s, a) = self.size_align(); + align(s, a) } } type Local = u32; +#[derive(Clone, Copy)] +struct Accum { + /// The ID of the zero function. + zero: u32, + + /// The allocation cost of the zero function. + cost: Size, + + // The ID of the add function, which has no allocation cost. + add: u32, +} + +struct Meta { + ty: Ty, + layout: Layout, + accum: Option, + members: Option>, +} + struct Codegen<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> { - layouts: &'b [(Layout, Option>)], + metas: &'b [Meta], imports: &'b Imports, + extras: usize, funcs: &'b Funcs<'a, T>, costs: &'b [Size], refs: &'b T, @@ -168,9 +189,8 @@ struct Codegen<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> { } impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { - fn layout(&self, t: id::Ty) -> (Layout, Option<&'b [Size]>) { - let (layout, members) = &self.layouts[self.types[t.ty()].ty()]; - (*layout, members.as_ref().map(|x| &x[..])) + fn meta(&self, t: id::Ty) -> &'b Meta { + &self.metas[self.types[t.ty()].ty()] } fn get(&mut self, x: id::Var) { @@ -303,12 +323,12 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { .instruction(&Instruction::I32Const(val.try_into().unwrap())); } Expr::Array { elems } => { - let (layout, _) = - self.layout(match self.def.types[self.def.vars[instr.var.var()].ty()] { + let &Meta { layout, .. } = + self.meta(match self.def.types[self.def.vars[instr.var.var()].ty()] { Ty::Array { elem, .. } => elem, _ => unreachable!(), }); - let size = layout.size(); + let size = layout.aligned(); for (i, &elem) in elems.iter().enumerate() { self.pointer(); self.get(elem); @@ -319,20 +339,20 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { } Expr::Tuple { members } => { let mut size = 0; - let (_, mems) = self.layout(self.def.vars[instr.var.var()]); - for (&member, &offset) in members.iter().zip(mems.unwrap().iter()) { - let (layout, _) = self.layout(self.def.vars[member.var()]); + let Meta { members: mems, .. } = self.meta(self.def.vars[instr.var.var()]); + for (&member, &offset) in members.iter().zip(mems.as_ref().unwrap().iter()) { + let &Meta { layout, .. } = self.meta(self.def.vars[member.var()]); self.pointer(); self.get(member); self.store(layout, offset); - size = size.max(offset + layout.size()); + size = size.max(offset + layout.aligned()); } self.pointer(); self.bump(size); } &Expr::Index { array, index } => { - let (layout, _) = self.layout(self.def.vars[instr.var.var()]); - let size = layout.size(); + let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); + let size = layout.aligned(); self.get(array); self.get(index); self.u32_const(size); @@ -341,9 +361,9 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.load(layout, 0); } &Expr::Member { tuple, member } => { - let (_, members) = self.layout(self.def.vars[tuple.var()]); - let offset = members.unwrap()[member.member()]; - let (layout, _) = self.layout(self.def.vars[instr.var.var()]); + let Meta { members, .. } = self.meta(self.def.vars[tuple.var()]); + let offset = members.as_ref().unwrap()[member.member()]; + let &Meta { layout, .. } = self.meta(self.def.vars[instr.var.var()]); self.get(tuple); self.load(layout, offset); } @@ -426,7 +446,7 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { self.pointer(); let j = self.funcs.get_index_of(&(ByAddress(def), gens)).unwrap(); self.bump(self.costs[j]); - self.imports.len() + j + self.imports.len() + self.extras + j } Node::Opaque { def, .. } => { self.imports.get_index_of(&(def, gens)).unwrap() @@ -436,12 +456,12 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { .instruction(&Instruction::Call(i.try_into().unwrap())); } Expr::For { arg, body, ret } => { - let n = u_size(match self.def.types[self.def.vars[arg.var()].ty()] { + let n = u_size(match self.meta(self.def.vars[arg.var()]).ty { Ty::Fin { size } => size, _ => unreachable!(), }); - let (layout, _) = self.layout(self.def.vars[ret.var()]); - let size = layout.size(); + let &Meta { layout, .. } = self.meta(self.def.vars[ret.var()]); + let size = layout.aligned(); self.pointer(); self.set(instr.var); @@ -478,12 +498,56 @@ impl<'a, 'b, O: Hash + Eq, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> { continue; } - &Expr::Accum { shape: _ } => todo!(), - &Expr::Add { - accum: _, - addend: _, - } => todo!(), - &Expr::Resolve { var: _ } => todo!(), + &Expr::Accum { shape } => { + let meta = self.meta(self.def.vars[shape.var()]); + match &meta.ty { + Ty::Unit | Ty::Bool | Ty::Fin { .. } => self.get(shape), + Ty::F64 => { + self.pointer(); + self.pointer(); + self.wasm.instruction(&Instruction::F64Const(0.)); + self.store(Layout::F64, 0); + self.bump(8); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + let Accum { zero, cost, .. } = meta.accum.unwrap(); + self.pointer(); + self.get(shape); + self.wasm.instruction(&Instruction::Call(zero)); + self.pointer(); + self.bump(cost); + } + } + } + &Expr::Add { accum, addend } => { + let meta = self.meta(self.def.vars[addend.var()]); + match &meta.ty { + Ty::Unit | Ty::Bool | Ty::Fin { .. } => {} + Ty::F64 => { + self.get(accum); + self.get(accum); + self.load(Layout::F64, 0); + self.get(addend); + self.wasm.instruction(&Instruction::F64Add); + self.store(Layout::F64, 0); + } + Ty::Generic { .. } | Ty::Ref { .. } => unreachable!(), + Ty::Array { .. } | Ty::Tuple { .. } => { + self.get(accum); + self.get(addend); + self.wasm + .instruction(&Instruction::Call(meta.accum.unwrap().add)); + } + } + self.wasm.instruction(&Instruction::I32Const(0)); + } + &Expr::Resolve { var } => { + self.get(var); + if let Ty::F64 = &self.meta(self.def.vars[var.var()]).ty { + self.load(Layout::F64, 0); + } + } } self.set(instr.var); } @@ -531,12 +595,33 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> funcs, } = topsort; - let mut layouts: Vec<(Layout, Option>)> = vec![]; - for ty in types.iter() { - let (layout, members) = match ty { - Ty::Unit => (Layout::Unit, None), - Ty::Bool => (Layout::U8, None), - Ty::F64 => (Layout::F64, None), + let mut func_types: IndexSet<(Box<[ValType]>, ValType)> = IndexSet::new(); + + let mut import_section = ImportSection::new(); + for (i, (params, ret)) in imports.values().enumerate() { + let (type_index, _) = func_types.insert_full(( + params.iter().map(|t| val_type(&types[t.ty()])).collect(), + val_type(&types[ret.ty()]), + )); + import_section.import( + "", + &i.to_string(), + EntityType::Function((1 + type_index).try_into().unwrap()), + ); + } + + let mut costs = vec![]; + + let mut function_section = FunctionSection::new(); + let mut code_section = CodeSection::new(); + + let mut metas: Vec = vec![]; + let mut extras: usize = 0; + for ty in types.into_iter() { + let (layout, cost, members) = match &ty { + Ty::Unit => (Layout::Unit, None, None), + Ty::Bool => (Layout::U8, None, None), + Ty::F64 => (Layout::F64, None, None), &Ty::Fin { size } => ( if size <= 256 { Layout::U8 @@ -546,16 +631,84 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> Layout::U32 }, None, + None, ), Ty::Generic { .. } => unreachable!(), - Ty::Ref { .. } => (Layout::Ref, None), - Ty::Array { .. } => (Layout::U32, None), + Ty::Ref { .. } => (Layout::Ref, None, None), + Ty::Array { index, elem } => { + let n = match metas[index.ty()].ty { + Ty::Fin { size } => size, + _ => unreachable!(), + }; + let meta = &metas[elem.ty()]; + let size = meta.layout.aligned(); + + let mut zero = Function::new([(2, ValType::I32)]); + let mut add = Function::new([(1, ValType::I32)]); + + if n > 0 { + let total = size * u_size(n); + + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::I32Const(total.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalTee(2)); + zero.instruction(&Instruction::LocalSet(3)); + zero.instruction(&Instruction::Loop(BlockType::Empty)); + + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::I32Const(total.try_into().unwrap())); + add.instruction(&Instruction::I32Add); + add.instruction(&Instruction::LocalSet(2)); + add.instruction(&Instruction::Loop(BlockType::Empty)); + + match &meta.ty { + Ty::Unit => {} + Ty::Bool | Ty::Fin { .. } => todo!(), + Ty::F64 => { + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::F64Const(0.)); + todo!() + } + Ty::Generic { id } => todo!(), + Ty::Ref { inner } => todo!(), + Ty::Array { index, elem } => todo!(), + Ty::Tuple { members } => todo!(), + } + + zero.instruction(&Instruction::LocalGet(0)); + zero.instruction(&Instruction::I32Const(size.try_into().unwrap())); + zero.instruction(&Instruction::I32Add); + zero.instruction(&Instruction::LocalTee(0)); + zero.instruction(&Instruction::LocalGet(2)); + zero.instruction(&Instruction::I32LtU); + zero.instruction(&Instruction::BrIf(0)); + zero.instruction(&Instruction::End); + + add.instruction(&Instruction::LocalGet(0)); + add.instruction(&Instruction::I32Const(size.try_into().unwrap())); + add.instruction(&Instruction::I32Add); + add.instruction(&Instruction::LocalTee(0)); + add.instruction(&Instruction::LocalGet(2)); + add.instruction(&Instruction::I32LtU); + add.instruction(&Instruction::BrIf(0)); + add.instruction(&Instruction::End); + } + + code_section.function(&zero); + code_section.function(&add); + ( + Layout::U32, + Some((size + meta.accum.map_or(0, |acc| acc.cost)) * u_size(n)), + None, + ) + } Ty::Tuple { members } => { let mut mems: Vec<_> = members .iter() .enumerate() .map(|(i, t)| { - let (layout, _) = layouts[t.ty()]; + let Meta { layout, .. } = metas[t.ty()]; let (size, align) = layout.size_align(); (i, size, align) }) @@ -568,33 +721,33 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> offsets[i] = offset; offset += s; } - (Layout::U32, Some(offsets.into())) - } - }; - layouts.push((layout, members)); - } - let mut func_types: IndexSet<(Box<[ValType]>, ValType)> = IndexSet::new(); + let mut zero = Function::new([]); - let mut import_section = ImportSection::new(); - for (i, (params, ret)) in imports.values().enumerate() { - let (type_index, _) = func_types.insert_full(( - params.iter().map(|t| val_type(&types[t.ty()])).collect(), - val_type(&types[ret.ty()]), - )); - import_section.import( - "", - &i.to_string(), - EntityType::Function(type_index.try_into().unwrap()), - ); - } + let mut add = Function::new([]); - let mut costs = vec![]; + code_section.function(&zero); + code_section.function(&add); + (Layout::U32, Some(todo!()), Some(offsets.into())) + } + }; + metas.push(Meta { + ty, + layout, + accum: cost.map(|cost| { + let zero = extras.try_into().unwrap(); + function_section.function(0); + let add = (extras + 1).try_into().unwrap(); + function_section.function(0); + extras += 2; + Accum { zero, cost, add } + }), + members, + }); + } - let mut function_section = FunctionSection::new(); - let mut code_section = CodeSection::new(); for ((def, _), (refs, def_types)) in funcs.iter() { - let vt = |t: id::Ty| val_type(&types[def_types[t.ty()].ty()]); + let vt = |t: id::Ty| val_type(&metas[def_types[t.ty()].ty()].ty); let params: Local = (def.params.len() + 1).try_into().unwrap(); let mut locals = vec![None; def.vars.len()]; @@ -610,7 +763,7 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> .collect(), vt(def.vars[def.ret.var()]), )); - function_section.function(type_index.try_into().unwrap()); + function_section.function((1 + type_index).try_into().unwrap()); let mut i32s = 0; for (i, &t) in def.vars.iter().enumerate() { @@ -632,8 +785,9 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> let locals = locals.into_iter().map(Option::unwrap).collect::>(); let mut codegen = Codegen { - layouts: &layouts, + metas: &metas, imports: &imports, + extras, funcs: &funcs, costs: &costs, refs, @@ -651,6 +805,7 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> } let mut type_section = TypeSection::new(); + type_section.function([ValType::I32, ValType::I32], []); for (params, ret) in func_types { type_section.function(params.into_vec(), [ret]); } @@ -669,7 +824,9 @@ pub fn compile<'a, O: Hash + Eq, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) -> export_section.export( "f", wasm_encoder::ExportKind::Func, - (imports.len() + funcs.len() - 1).try_into().unwrap(), + (imports.len() + extras + funcs.len() - 1) + .try_into() + .unwrap(), ); export_section.export("m", wasm_encoder::ExportKind::Memory, 0); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 0c82768..d45bd32 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -736,4 +736,36 @@ describe("valid", () => { (await compile(g))(); expect(i).toEqual(0); }); + + test("compile VJP", async () => { + const f = fn( + [Vec(2, { p: Bool, x: Real } as const)], + { p: Vec(2, Bool), x: Vec(2, Real) }, + (v) => ({ + p: vec(2, Bool, (i) => not(v[i].p)), + x: vec(2, Real, (i) => { + const { p, x } = v[i]; + return select(p, Real, mul(x, x), x); + }), + }), + ); + const g = fn([Bool, Real, Bool, Real], Real, (p1, x1, q1, y1) => { + const { ret, grad } = vjp(f)([ + { p: p1, x: x1 }, + { p: q1, x: y1 }, + ]); + const { x } = ret; + const x2 = x[0]; + const y2 = x[1]; + const v = grad({ p: [true, false] as any, x: [2, 3] as any }); + const { x: x3 } = v[0]; + const { x: y3 } = v[1]; + return mul(sub(x3, y2), sub(y3, x2)); + }); + const h = await compile(g); + expect(h(true, 2, true, 3)).toBe(-14); + expect(h(true, 5, false, 7)).toBe(-286); + expect(h(false, 11, true, 13)).toBe(-11189); + expect(h(false, 17, false, 19)).toBe(238); + }); });