Skip to content

Commit

Permalink
Allow compiled functions to share memory (#125)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Mar 20, 2024
1 parent 8e3b8fd commit 48eb14e
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 22 deletions.
40 changes: 24 additions & 16 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
};
use wasm_encoder::{
BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection,
Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType,
Instruction, MemArg, MemoryType, Module, TypeSection, ValType,
};

/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be.
Expand Down Expand Up @@ -976,18 +976,23 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {

/// A WebAssembly module for a graph of functions.
///
/// The module exports its memory with name `"m"` and its entrypoint function with name `"f"`. The
/// function takes one parameter in addition to its original parameters, which must be an
/// 8-byte-aligned pointer to the start of the memory region it can use for allocation. The memory
/// is the exact number of pages necessary to accommodate the function's own memory allocation as
/// well as memory allocation for all of its parameters, with each node in each parameter's memory
/// allocation tree being 8-byte aligned. That is, the function's last argument should be just large
/// enough to accommodate those allocations for all the parameters; in that case, no memory will be
/// The module exports its entrypoint function with name `"f"`. The function takes one parameter in
/// addition to its original parameters, which must be an 8-byte-aligned pointer to the start of the
/// memory region it can use for allocation.
///
/// Under module name `"m"`, the module imports a memory whose minimum number of pages is the exact
/// number of pages necessary to accommodate the function's own memory allocation as well as memory
/// allocation for all of its parameters, with each node in each parameter's memory allocation tree
/// being 8-byte aligned. That is, the function's last argument should be just large enough to
/// accommodate those allocations for all the parameters; in that case, no memory will be
/// incorrectly overwritten and no out-of-bounds memory accesses will occur.
pub struct Wasm<O> {
/// The bytes of the WebAssembly module binary.
pub bytes: Vec<u8>,

/// The minimum number of pages required by the imported memory.
pub pages: u64,

/// All the opaque functions that the WebAssembly module must import, in order.
///
/// The module name for each import is the empty string, and the field name is the base-ten
Expand Down Expand Up @@ -1390,7 +1395,6 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
type_section.function(params.into_vec(), results.into_vec());
}

let mut memory_section = MemorySection::new();
let page_size = 65536; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size
let cost = funcs.last().map_or(0, |((def, _), (_, def_types, _))| {
def.params
Expand All @@ -1400,12 +1404,16 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
.sum()
}) + costs.last().unwrap_or(&0);
let pages = ((cost + page_size - 1) / page_size).into(); // round up to a whole number of pages
memory_section.memory(MemoryType {
minimum: pages,
maximum: Some(pages),
memory64: false,
shared: false,
});
import_section.import(
"m",
"",
MemoryType {
minimum: pages,
maximum: None,
memory64: false,
shared: false,
},
);

let mut export_section = ExportSection::new();
export_section.export(
Expand All @@ -1419,11 +1427,11 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
module.section(&type_section);
module.section(&import_section);
module.section(&function_section);
module.section(&memory_section);
module.section(&export_section);
module.section(&code_section);
Wasm {
bytes: module.finish(),
pages,
imports,
}
}
8 changes: 7 additions & 1 deletion crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,14 @@ impl Func {

/// Compile the call graph subtended by this function to WebAssembly.
pub fn compile(&self) -> Wasm {
let rose_wasm::Wasm { bytes, imports } = rose_wasm::compile(self.node());
let rose_wasm::Wasm {
bytes,
pages,
imports,
} = rose_wasm::compile(self.node());
Wasm {
bytes: Some(bytes),
pages,
imports: Some(
imports
.into_keys()
Expand Down Expand Up @@ -488,6 +493,7 @@ impl Func {
#[wasm_bindgen]
pub struct Wasm {
bytes: Option<Vec<u8>>,
pub pages: u64,
imports: Option<Vec<js_sys::Function>>,
}

Expand Down
30 changes: 25 additions & 5 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -814,19 +814,39 @@ const getMeta = (
} else return undefined;
};

/** Concretize the abstract function `f` using the compiler. */
interface CompileOptions {
memory?: WebAssembly.Memory;
}

/**
* Concretize the abstract function `f` using the compiler.
*
* Creates a new memory if `opts.memory` is not provided, otherwise attempts to
* grow the provided memory to be large enough.
*/
export const compile = async <const A extends readonly any[], const R>(
f: Fn & ((...args: A) => R),
opts?: CompileOptions,
): Promise<(...args: JsArgs<A>) => ToJs<R>> => {
const func = f[inner];
const res = func.compile();
const bytes = res.bytes()!;
const pages = Number(res.pages);
const imports = res.imports()!;
res.free();
const instance = await WebAssembly.instantiate(
await WebAssembly.compile(bytes),
{ "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])) },
);
let memory = opts?.memory;
if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages });
else {
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
const pageSize = 65536;
const delta = pages - memory.buffer.byteLength / pageSize;
if (delta > 0) memory.grow(delta);
}
const mod = await WebAssembly.compile(bytes);
const instance = await WebAssembly.instantiate(mod, {
m: { "": memory },
"": Object.fromEntries(imports.map((g, i) => [i.toString(), g])),
});
const { f: g, m } = instance.exports;
const metas: (Meta | undefined)[] = [];
const n = func.numTypes();
Expand Down
28 changes: 28 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,34 @@ describe("valid", () => {
expect(g(2, 3)).toBeCloseTo(-0.7785390719815313);
});

test("compile with shared memory", async () => {
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
const pageSize = 65536;

const memory = new WebAssembly.Memory({ initial: 0 });
expect(memory.buffer.byteLength).toBe(0);

const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y));
const fCompiled = await compile(f, { memory });
expect(memory.buffer.byteLength).toBe(pageSize);
expect(fCompiled([2, 3])).toBe(6);

const n = 10000;
const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) =>
vec(n, Real, (i) => mul(a[i], b[i])),
);
const gCompiled = await compile(g, { memory });
expect(memory.buffer.byteLength).toBeGreaterThan(pageSize);
const a = [];
const b = [];
for (let i = 1; i <= n; ++i) {
a.push(i);
b.push(1 / i);
}
const c = gCompiled(a, b);
for (let i = 0; i < n; ++i) expect(c[i]).toBeCloseTo(1);
});

test("compile opaque function", async () => {
const f = opaque([Real], Real, Math.sin);
const g = await compile(f);
Expand Down

0 comments on commit 48eb14e

Please sign in to comment.