Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow compiled functions to share memory #125

Merged
merged 2 commits into from
Mar 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions crates/wasm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::{
};
use wasm_encoder::{
BlockType, CodeSection, EntityType, ExportSection, Function, FunctionSection, ImportSection,
Instruction, MemArg, MemorySection, MemoryType, Module, TypeSection, ValType,
Instruction, MemArg, MemoryType, Module, TypeSection, ValType,
};

/// Resolve `ty` via `generics` and `types`, then return its ID in `typemap`, inserting if need be.
Expand Down Expand Up @@ -976,18 +976,23 @@ impl<'a, 'b, O: Eq + Hash, T: Refs<'a, Opaque = O>> Codegen<'a, 'b, O, T> {

/// A WebAssembly module for a graph of functions.
///
/// The module exports its memory with name `"m"` and its entrypoint function with name `"f"`. The
/// function takes one parameter in addition to its original parameters, which must be an
/// 8-byte-aligned pointer to the start of the memory region it can use for allocation. The memory
/// is the exact number of pages necessary to accommodate the function's own memory allocation as
/// well as memory allocation for all of its parameters, with each node in each parameter's memory
/// allocation tree being 8-byte aligned. That is, the function's last argument should be just large
/// enough to accommodate those allocations for all the parameters; in that case, no memory will be
/// The module exports its entrypoint function with name `"f"`. The function takes one parameter in
/// addition to its original parameters, which must be an 8-byte-aligned pointer to the start of the
/// memory region it can use for allocation.
///
/// Under module name `"m"`, the module imports a memory whose minimum number of pages is the exact
/// number of pages necessary to accommodate the function's own memory allocation as well as memory
/// allocation for all of its parameters, with each node in each parameter's memory allocation tree
/// being 8-byte aligned. That is, the function's last argument should be just large enough to
/// accommodate those allocations for all the parameters; in that case, no memory will be
/// incorrectly overwritten and no out-of-bounds memory accesses will occur.
pub struct Wasm<O> {
/// The bytes of the WebAssembly module binary.
pub bytes: Vec<u8>,

/// The minimum number of pages required by the imported memory.
pub pages: u64,

/// All the opaque functions that the WebAssembly module must import, in order.
///
/// The module name for each import is the empty string, and the field name is the base-ten
Expand Down Expand Up @@ -1390,7 +1395,6 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
type_section.function(params.into_vec(), results.into_vec());
}

let mut memory_section = MemorySection::new();
let page_size = 65536; // https://webassembly.github.io/spec/core/exec/runtime.html#page-size
let cost = funcs.last().map_or(0, |((def, _), (_, def_types, _))| {
def.params
Expand All @@ -1400,12 +1404,16 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
.sum()
}) + costs.last().unwrap_or(&0);
let pages = ((cost + page_size - 1) / page_size).into(); // round up to a whole number of pages
memory_section.memory(MemoryType {
minimum: pages,
maximum: Some(pages),
memory64: false,
shared: false,
});
import_section.import(
"m",
"",
MemoryType {
minimum: pages,
maximum: None,
memory64: false,
shared: false,
},
);

let mut export_section = ExportSection::new();
export_section.export(
Expand All @@ -1419,11 +1427,11 @@ pub fn compile<'a, O: Eq + Hash, T: Refs<'a, Opaque = O>>(f: Node<'a, O, T>) ->
module.section(&type_section);
module.section(&import_section);
module.section(&function_section);
module.section(&memory_section);
module.section(&export_section);
module.section(&code_section);
Wasm {
bytes: module.finish(),
pages,
imports,
}
}
8 changes: 7 additions & 1 deletion crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,14 @@ impl Func {

/// Compile the call graph subtended by this function to WebAssembly.
pub fn compile(&self) -> Wasm {
let rose_wasm::Wasm { bytes, imports } = rose_wasm::compile(self.node());
let rose_wasm::Wasm {
bytes,
pages,
imports,
} = rose_wasm::compile(self.node());
Wasm {
bytes: Some(bytes),
pages,
imports: Some(
imports
.into_keys()
Expand Down Expand Up @@ -488,6 +493,7 @@ impl Func {
#[wasm_bindgen]
pub struct Wasm {
bytes: Option<Vec<u8>>,
pub pages: u64,
imports: Option<Vec<js_sys::Function>>,
}

Expand Down
30 changes: 25 additions & 5 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -814,19 +814,39 @@ const getMeta = (
} else return undefined;
};

/** Concretize the abstract function `f` using the compiler. */
interface CompileOptions {
memory?: WebAssembly.Memory;
}

/**
* Concretize the abstract function `f` using the compiler.
*
* Creates a new memory if `opts.memory` is not provided, otherwise attempts to
* grow the provided memory to be large enough.
*/
export const compile = async <const A extends readonly any[], const R>(
f: Fn & ((...args: A) => R),
opts?: CompileOptions,
): Promise<(...args: JsArgs<A>) => ToJs<R>> => {
const func = f[inner];
const res = func.compile();
const bytes = res.bytes()!;
const pages = Number(res.pages);
const imports = res.imports()!;
res.free();
const instance = await WebAssembly.instantiate(
await WebAssembly.compile(bytes),
{ "": Object.fromEntries(imports.map((g, i) => [i.toString(), g])) },
);
let memory = opts?.memory;
if (memory === undefined) memory = new WebAssembly.Memory({ initial: pages });
else {
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
const pageSize = 65536;
const delta = pages - memory.buffer.byteLength / pageSize;
if (delta > 0) memory.grow(delta);
}
const mod = await WebAssembly.compile(bytes);
const instance = await WebAssembly.instantiate(mod, {
m: { "": memory },
"": Object.fromEntries(imports.map((g, i) => [i.toString(), g])),
});
const { f: g, m } = instance.exports;
const metas: (Meta | undefined)[] = [];
const n = func.numTypes();
Expand Down
28 changes: 28 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,34 @@ describe("valid", () => {
expect(g(2, 3)).toBeCloseTo(-0.7785390719815313);
});

test("compile with shared memory", async () => {
// https://webassembly.github.io/spec/core/exec/runtime.html#page-size
const pageSize = 65536;

const memory = new WebAssembly.Memory({ initial: 0 });
expect(memory.buffer.byteLength).toBe(0);

const f = fn([Vec(2, Real)], Real, ([x, y]) => mul(x, y));
const fCompiled = await compile(f, { memory });
expect(memory.buffer.byteLength).toBe(pageSize);
expect(fCompiled([2, 3])).toBe(6);

const n = 10000;
const g = fn([Vec(n, Real), Vec(n, Real)], Vec(n, Real), (a, b) =>
vec(n, Real, (i) => mul(a[i], b[i])),
);
const gCompiled = await compile(g, { memory });
expect(memory.buffer.byteLength).toBeGreaterThan(pageSize);
const a = [];
const b = [];
for (let i = 1; i <= n; ++i) {
a.push(i);
b.push(1 / i);
}
const c = gCompiled(a, b);
for (let i = 0; i < n; ++i) expect(c[i]).toBeCloseTo(1);
});

test("compile opaque function", async () => {
const f = opaque([Real], Real, Math.sin);
const g = await compile(f);
Expand Down
Loading