diff --git a/packages/core/src/impl.ts b/packages/core/src/impl.ts index 43ec980..7d1a4ce 100644 --- a/packages/core/src/impl.ts +++ b/packages/core/src/impl.ts @@ -636,13 +636,13 @@ interface Layout { const aligned = ({ size, align }: Layout): number => (size + align - 1) & ~(align - 1); -/** An aligned `ArrayBuffer` view, or `undefined` for zero-sized types. */ -type View = undefined | Uint8Array | Uint16Array | Uint32Array | Float64Array; +/** An aligned `ArrayBuffer` view. */ +type View = Uint8Array | Uint16Array | Uint32Array | Float64Array; const getView = (buffer: ArrayBuffer, layout: Layout, offset: number): View => { // this code assumes that the layout is uniquely determined by its `size` const { size } = layout; - if (size === 0) return undefined; + if (size === 0) throw Error("zero-sized type"); else if (size === 1) return new Uint8Array(buffer, offset); else if (size === 2) return new Uint16Array(buffer, offset); else if (size === 4) return new Uint32Array(buffer, offset); @@ -660,13 +660,13 @@ interface Meta { * * The given byte offset is only used for pointer types. */ - encode: (x: unknown, offset: number) => number; + encode: (x: unknown, pointer: number, buffer: ArrayBuffer) => number; /** Total memory cost of an object of this type, including sub-allocations. */ cost: number; /** Return a JS value represented by the Wasm value `x`. */ - decode: (x: number) => unknown; + decode: (x: number, buffer: ArrayBuffer) => unknown; } /** @@ -680,7 +680,6 @@ interface Meta { */ const getMeta = ( f: Fn, - buffer: ArrayBuffer, metas: (Meta | undefined)[], t: number, ): Meta | undefined => { @@ -724,33 +723,39 @@ const getMeta = ( const n = func.size(func.index(t)); const elem = aligned(layout); const total = aligned({ size: n * elem, align: 8 }); - const view = getView(buffer, layout, 0); return { layout: { size: 4, align: 4 }, encode: - view === undefined - ? (x, offset) => offset - : (x, offset) => { - let child = offset + total; + layout.size === 0 + ? (x, pointer) => pointer + : (x, pointer, buffer) => { + const view = getView(buffer, layout, 0); + let child = pointer + total; for (let i = 0; i < n; ++i) { - view[offset / elem + i] = encode((x as unknown[])[i], child); + view[pointer / elem + i] = encode( + (x as unknown[])[i], + child, + buffer, + ); child += cost; } - return offset; + return pointer; }, cost: total + n * cost, decode: - view === undefined - ? () => { + layout.size === 0 + ? (x, buffer) => { const arr: unknown[] = []; // this code assumes that all values of all zero-sized types can // be represented by zero - for (let i = 0; i < n; ++i) arr.push(decode(0)); + for (let i = 0; i < n; ++i) arr.push(decode(0, buffer)); return arr; } - : (x) => { + : (x, buffer) => { + const view = getView(buffer, layout, 0); const arr: unknown[] = []; - for (let i = 0; i < n; ++i) arr.push(decode(view[x / elem + i])); + for (let i = 0; i < n; ++i) + arr.push(decode(view[x / elem + i], buffer)); return arr; }, }; @@ -758,7 +763,8 @@ const getMeta = ( const keys = func.keys(t); const members = func.mems(t); const n = keys.length; - const mems: { key: string; meta: Meta; view?: View; child?: number }[] = []; + const mems: { key: string; meta: Meta; offset?: number; child?: number }[] = + []; for (let i = 0; i < n; ++i) { const meta = metas[members[i]]; if (meta === undefined) return undefined; @@ -774,38 +780,41 @@ const getMeta = ( const { layout } = meta; const { size, align } = layout; offset = aligned({ size: offset, align }); - mem.view = getView(buffer, layout, offset); + mem.offset = offset; offset += size; } const total = aligned({ size: offset, align: 8 }); return { layout: { size: 4, align: 4 }, - encode: (x, offset) => { - for (const { key, meta, view, child } of mems) { + encode: (x, pointer, buffer) => { + for (const { key, meta, offset, child } of mems) { // instead of mutating each element of `mems` above to add more data // and then still having an `if` statement in here, it would be nicer // to just map over `mems` above to produce an array of closures that // can be called directly, with the condition on `view === undefined` // being handled once rather than in every call to `encode` here - if (view !== undefined) { - view[offset / aligned(meta.layout)] = meta.encode( + if (meta.layout.size > 0) { + const view = getView(buffer, meta.layout, offset!); + view[pointer / aligned(meta.layout)] = meta.encode( (x as any)[key], - offset + total + child!, + pointer + total + child!, + buffer, ); } } - return offset; + return pointer; }, cost: total + cost, - decode: (x) => { + decode: (x, buffer) => { const obj: any = {}; - for (const { key, meta, view } of mems) { - if (view === undefined) { + for (const { key, meta, offset } of mems) { + if (meta.layout.size === 0) { // this code assumes that all values of all zero-sized types can be // represented by zero - obj[key] = meta.decode(0); + obj[key] = meta.decode(0, buffer); } else { - obj[key] = meta.decode(view[x / aligned(meta.layout)]); + const view = getView(buffer, meta.layout, offset!); + obj[key] = meta.decode(view[x / aligned(meta.layout)], buffer); } } return obj; @@ -834,9 +843,11 @@ export const compile = async ( const pages = Number(res.pages); const imports = res.imports()!; res.free(); - let memory = opts?.memory; - if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages }); + let memory: WebAssembly.Memory; + const given = opts?.memory; + if (given === undefined) memory = new WebAssembly.Memory({ initial: pages }); else { + memory = given; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size const pageSize = 65536; const delta = pages - memory.buffer.byteLength / pageSize; @@ -847,11 +858,10 @@ export const compile = async ( m: { "": memory }, "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])), }); - const { f: g, m } = instance.exports; + const { f: g } = instance.exports; const metas: (Meta | undefined)[] = []; const n = func.numTypes(); - for (let t = 0; t < n; ++t) - metas.push(getMeta(f, (m as WebAssembly.Memory).buffer, metas, t)); + for (let t = 0; t < n; ++t) metas.push(getMeta(f, metas, t)); let total = 0; const params = Array.from(func.paramTypes()).map((t) => { const { encode, cost } = metas[t]!; @@ -861,8 +871,10 @@ export const compile = async ( }); const { decode } = metas[func.retType()]!; return (...args): any => { - const vals = params.map(({ encode, offset }, i) => encode(args[i], offset)); - return decode((g as any)(...vals, total)); + const vals = params.map(({ encode, offset }, i) => + encode(args[i], offset, memory.buffer), + ); + return decode((g as any)(...vals, total), memory.buffer); }; }; diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 11d6da1..4a60f35 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -643,10 +643,9 @@ describe("valid", () => { const memory = new WebAssembly.Memory({ initial: 0 }); expect(memory.buffer.byteLength).toBe(0); - const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y)); + const f = fn([Vec(2, Real)], Vec(2, Real), ([x, y]) => [y, x]); const fCompiled = await compile(f, { memory }); expect(memory.buffer.byteLength).toBe(pageSize); - expect(fCompiled([2, 3])).toBe(6); const n = 10000; const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) => @@ -654,6 +653,9 @@ describe("valid", () => { ); const gCompiled = await compile(g, { memory }); expect(memory.buffer.byteLength).toBeGreaterThan(pageSize); + + expect(fCompiled([2, 3])).toEqual([3, 2]); + const a = []; const b = []; for (let i = 1; i <= n; ++i) {