Skip to content

Commit

Permalink
Fix compile when memory grows afterward (#126)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Mar 27, 2024
1 parent 5c2c0b0 commit 0eff92a
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 40 deletions.
88 changes: 50 additions & 38 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -636,13 +636,13 @@ interface Layout {
const aligned = ({ size, align }: Layout): number =>
(size + align - 1) & ~(align - 1);

/** An aligned `ArrayBuffer` view, or `undefined` for zero-sized types. */
type View = undefined | Uint8Array | Uint16Array | Uint32Array | Float64Array;
/** An aligned `ArrayBuffer` view. */
type View = Uint8Array | Uint16Array | Uint32Array | Float64Array;

const getView = (buffer: ArrayBuffer, layout: Layout, offset: number): View => {
// this code assumes that the layout is uniquely determined by its `size`
const { size } = layout;
if (size === 0) return undefined;
if (size === 0) throw Error("zero-sized type");
else if (size === 1) return new Uint8Array(buffer, offset);
else if (size === 2) return new Uint16Array(buffer, offset);
else if (size === 4) return new Uint32Array(buffer, offset);
Expand All @@ -660,13 +660,13 @@ interface Meta {
*
* The given byte offset is only used for pointer types.
*/
encode: (x: unknown, offset: number) => number;
encode: (x: unknown, pointer: number, buffer: ArrayBuffer) => number;

/** Total memory cost of an object of this type, including sub-allocations. */
cost: number;

/** Return a JS value represented by the Wasm value `x`. */
decode: (x: number) => unknown;
decode: (x: number, buffer: ArrayBuffer) => unknown;
}

/**
Expand All @@ -680,7 +680,6 @@ interface Meta {
*/
const getMeta = (
f: Fn,
buffer: ArrayBuffer,
metas: (Meta | undefined)[],
t: number,
): Meta | undefined => {
Expand Down Expand Up @@ -724,41 +723,48 @@ const getMeta = (
const n = func.size(func.index(t));
const elem = aligned(layout);
const total = aligned({ size: n * elem, align: 8 });
const view = getView(buffer, layout, 0);
return {
layout: { size: 4, align: 4 },
encode:
view === undefined
? (x, offset) => offset
: (x, offset) => {
let child = offset + total;
layout.size === 0
? (x, pointer) => pointer
: (x, pointer, buffer) => {
const view = getView(buffer, layout, 0);
let child = pointer + total;
for (let i = 0; i < n; ++i) {
view[offset / elem + i] = encode((x as unknown[])[i], child);
view[pointer / elem + i] = encode(
(x as unknown[])[i],
child,
buffer,
);
child += cost;
}
return offset;
return pointer;
},
cost: total + n * cost,
decode:
view === undefined
? () => {
layout.size === 0
? (x, buffer) => {
const arr: unknown[] = [];
// this code assumes that all values of all zero-sized types can
// be represented by zero
for (let i = 0; i < n; ++i) arr.push(decode(0));
for (let i = 0; i < n; ++i) arr.push(decode(0, buffer));
return arr;
}
: (x) => {
: (x, buffer) => {
const view = getView(buffer, layout, 0);
const arr: unknown[] = [];
for (let i = 0; i < n; ++i) arr.push(decode(view[x / elem + i]));
for (let i = 0; i < n; ++i)
arr.push(decode(view[x / elem + i], buffer));
return arr;
},
};
} else if (func.isStruct(t)) {
const keys = func.keys(t);
const members = func.mems(t);
const n = keys.length;
const mems: { key: string; meta: Meta; view?: View; child?: number }[] = [];
const mems: { key: string; meta: Meta; offset?: number; child?: number }[] =
[];
for (let i = 0; i < n; ++i) {
const meta = metas[members[i]];
if (meta === undefined) return undefined;
Expand All @@ -774,38 +780,41 @@ const getMeta = (
const { layout } = meta;
const { size, align } = layout;
offset = aligned({ size: offset, align });
mem.view = getView(buffer, layout, offset);
mem.offset = offset;
offset += size;
}
const total = aligned({ size: offset, align: 8 });
return {
layout: { size: 4, align: 4 },
encode: (x, offset) => {
for (const { key, meta, view, child } of mems) {
encode: (x, pointer, buffer) => {
for (const { key, meta, offset, child } of mems) {
// instead of mutating each element of `mems` above to add more data
// and then still having an `if` statement in here, it would be nicer
// to just map over `mems` above to produce an array of closures that
// can be called directly, with the condition on `view === undefined`
// being handled once rather than in every call to `encode` here
if (view !== undefined) {
view[offset / aligned(meta.layout)] = meta.encode(
if (meta.layout.size > 0) {
const view = getView(buffer, meta.layout, offset!);
view[pointer / aligned(meta.layout)] = meta.encode(
(x as any)[key],
offset + total + child!,
pointer + total + child!,
buffer,
);
}
}
return offset;
return pointer;
},
cost: total + cost,
decode: (x) => {
decode: (x, buffer) => {
const obj: any = {};
for (const { key, meta, view } of mems) {
if (view === undefined) {
for (const { key, meta, offset } of mems) {
if (meta.layout.size === 0) {
// this code assumes that all values of all zero-sized types can be
// represented by zero
obj[key] = meta.decode(0);
obj[key] = meta.decode(0, buffer);
} else {
obj[key] = meta.decode(view[x / aligned(meta.layout)]);
const view = getView(buffer, meta.layout, offset!);
obj[key] = meta.decode(view[x / aligned(meta.layout)], buffer);
}
}
return obj;
Expand Down Expand Up @@ -834,9 +843,11 @@ export const compile = async <const A extends readonly any[], const R>(
const pages = Number(res.pages);
const imports = res.imports()!;
res.free();
let memory = opts?.memory;
if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages });
let memory: WebAssembly.Memory;
const given = opts?.memory;
if (given === undefined) memory = new WebAssembly.Memory({ initial: pages });
else {
memory = given;
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
const pageSize = 65536;
const delta = pages - memory.buffer.byteLength / pageSize;
Expand All @@ -847,11 +858,10 @@ export const compile = async <const A extends readonly any[], const R>(
m: { "": memory },
"": Object.fromEntries(imports.map((g, i) => [i.toString(), g])),
});
const { f: g, m } = instance.exports;
const { f: g } = instance.exports;
const metas: (Meta | undefined)[] = [];
const n = func.numTypes();
for (let t = 0; t < n; ++t)
metas.push(getMeta(f, (m as WebAssembly.Memory).buffer, metas, t));
for (let t = 0; t < n; ++t) metas.push(getMeta(f, metas, t));
let total = 0;
const params = Array.from(func.paramTypes()).map((t) => {
const { encode, cost } = metas[t]!;
Expand All @@ -861,8 +871,10 @@ export const compile = async <const A extends readonly any[], const R>(
});
const { decode } = metas[func.retType()]!;
return (...args): any => {
const vals = params.map(({ encode, offset }, i) => encode(args[i], offset));
return decode((g as any)(...vals, total));
const vals = params.map(({ encode, offset }, i) =>
encode(args[i], offset, memory.buffer),
);
return decode((g as any)(...vals, total), memory.buffer);
};
};

Expand Down
6 changes: 4 additions & 2 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -643,17 +643,19 @@ describe("valid", () => {
const memory = new WebAssembly.Memory({ initial: 0 });
expect(memory.buffer.byteLength).toBe(0);

const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y));
const f = fn([Vec(2, Real)], Vec(2, Real), ([x, y]) => [y, x]);
const fCompiled = await compile(f, { memory });
expect(memory.buffer.byteLength).toBe(pageSize);
expect(fCompiled([2, 3])).toBe(6);

const n = 10000;
const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) =>
vec(n, Real, (i) => mul(a[i], b[i])),
);
const gCompiled = await compile(g, { memory });
expect(memory.buffer.byteLength).toBeGreaterThan(pageSize);

expect(fCompiled([2, 3])).toEqual([3, 2]);

const a = [];
const b = [];
for (let i = 1; i <= n; ++i) {
Expand Down

0 comments on commit 0eff92a

Please sign in to comment.