Skip to content

Commit

Permalink
Handle interpreter translation more carefully (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Aug 30, 2023
1 parent c386542 commit ea00c0a
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 37 deletions.
35 changes: 24 additions & 11 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,17 @@ impl Func {
}
}

/// Return true iff `t` is the ID of a finite integer type.
#[wasm_bindgen(js_name = "isFin")]
pub fn is_fin(&mut self, t: usize) -> bool {
let (inner, _) = self.rc.as_ref();
let ty = match inner {
Inner::Transparent { def, .. } => def.types.get(t),
Inner::Opaque { types, .. } => types.get(t),
};
matches!(ty, Some(rose::Ty::Fin { .. }))
}

/// Return the ID of the element type for the array type with ID `t`.
pub fn elem(&self, t: usize) -> usize {
let (inner, _) = self.rc.as_ref();
Expand All @@ -200,21 +211,23 @@ impl Func {
}
}

/// Return the string ID for member `i` of the struct type with ID `t`.
pub fn key(&self, t: usize, i: usize) -> usize {
/// Return the string IDs for the struct type with ID `t`.
pub fn keys(&self, t: usize) -> Box<[usize]> {
let (_, structs) = self.rc.as_ref();
structs[t].as_ref().unwrap()[i]
structs[t].as_ref().unwrap().clone()
}

/// Return the member type ID for member `i` of the struct type with ID `t`.
pub fn mem(&self, t: usize, i: usize) -> usize {
/// Return the member type IDs for the struct type with ID `t`.
pub fn mems(&self, t: usize) -> Box<[usize]> {
let (inner, _) = self.rc.as_ref();
match inner {
Inner::Transparent { def, .. } => match &def.types[t] {
rose::Ty::Tuple { members } => members[i].ty(),
_ => panic!("not a struct"),
},
Inner::Opaque { .. } => panic!(),
let ty = match inner {
Inner::Transparent { def, .. } => def.types.get(t),
Inner::Opaque { types, .. } => types.get(t),
}
.unwrap();
match ty {
rose::Ty::Tuple { members } => members.iter().map(|m| m.ty()).collect(),
_ => panic!("not a struct"),
}
}

Expand Down
61 changes: 36 additions & 25 deletions packages/core/src/impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -368,9 +368,7 @@ type ToSymbolic<T> = T extends Nulls
? Nat
: T extends Vecs<any, infer V>
? Vec<ToSymbolic<V>>
: {
[K in keyof T]: ToSymbolic<T[K]>;
};
: { [K in keyof T]: ToSymbolic<T[K]> };

/**
* Map from a type of a type to the type of the abstract values it represents.
Expand All @@ -389,9 +387,7 @@ type ToValue<T> = T extends Nulls
? Nat
: T extends Vecs<any, infer V>
? Vec<ToValue<V>> | ToValue<V>[]
: {
[K in keyof T]: ToValue<T[K]>;
};
: { [K in keyof T]: ToValue<T[K]> };

/** Map from a parameter type array to a symbolic parameter value type array. */
type SymbolicParams<T> = {
Expand Down Expand Up @@ -458,7 +454,7 @@ export const fn = <const P extends readonly any[], const R>(
export const custom = <const P extends readonly Reals[], const R extends Reals>(
params: P,
ret: R,
f: (...args: ToJs<SymbolicParams<P>>) => ToJs<ToValue<R>>,
f: (...args: JsArgs<SymbolicParams<P>>) => ToJs<ToValue<R>>,
): Fn & ((...args: ValueParams<P>) => ToSymbolic<R>) => {
// TODO: support more complicated signatures for opaque functions
const func = new wasm.Func(params.length, f);
Expand All @@ -475,20 +471,25 @@ export const custom = <const P extends readonly Reals[], const R extends Reals>(
type Js = null | boolean | number | Js[] | { [K: string]: Js };

/** Translate from the interpreteer's raw format to a concrete value. */
const pack = (f: Fn, t: number, x: Js): RawVal => {
const pack = (f: Fn, t: number, x: unknown): RawVal => {
const func = f[inner];
if (x === null) return "Unit";
if (typeof x === "boolean") return { Bool: x };
if (typeof x === "number") return { F64: x };
if (Array.isArray(x))
return { Array: x.map((y) => pack(f, func.elem(t), y)) };
else
return {
Tuple: Object.entries(x).map(([k, y], i) => [
f[strings][func.key(t, i)],
pack(f, func.mem(t, i), y),
]),
};
else if (typeof x === "number")
return func.isFin(t) ? { Fin: x } : { F64: x };
else if (typeof x === "object") {
if (x === null) return "Unit";
else if (Array.isArray(x))
return { Array: x.map((y) => pack(f, func.elem(t), y)) };
else {
const keys = func.keys(t);
const mems = func.mems(t);
const vals: RawVal[] = [];
for (let i = 0; i < keys.length; ++i) {
vals.push(pack(f, mems[i], (x as any)[f[strings][keys[i]]]));
}
return { Tuple: vals };
}
} else throw Error("invalid value");
};

/** Translate a concrete value from the interpreter's raw format. */
Expand All @@ -501,19 +502,29 @@ const unpack = (f: Fn, t: number, x: RawVal): Js => {
if ("Ref" in x) throw Error("Ref not supported");
if ("Array" in x)
return x.Array.map((y: RawVal) => unpack(f, func.elem(t), y));
else
else {
const keys = func.keys(t);
const mems = func.mems(t);
return Object.fromEntries(
x.Tuple.map((y: RawVal, i: number) => [
f[strings][func.key(t, i)],
unpack(f, func.mem(t, i), y),
f[strings][keys[i]],
unpack(f, mems[i], y),
]),
);
}
};

/** Map from an abstract value type to its corresponding concrete value type. */
type ToJs<T> = T extends ArrayOrVec<infer V>
? ToJs<V>[]
: Exclude<{ [K in keyof T]: ToJs<T[K]> }, Var | symbol>;
// https://www.typescriptlang.org/docs/handbook/2/conditional-types.html
type ToJs<T> = [T] extends [Null]
? null
: [T] extends [Bool]
? boolean
: [T] extends [Real]
? number
: [T] extends [Nat]
? number
: { [K in keyof T]: ToJs<T[K]> };

/** Map from an abstract value type array to a concrete argument type array. */
type JsArgs<T> = {
Expand Down
21 changes: 20 additions & 1 deletion packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,18 @@ describe("invalid", () => {
});

describe("valid", () => {
test("null", () => {
test("return null", () => {
const f = fn([], Null, () => null);
const g = interp(f);
expect(g()).toBe(null);
});

test("interp with null", () => {
const f = fn([Null], Null, (x) => x);
const g = interp(f);
expect(g(null)).toBe(null);
});

test("2 + 2 = 4", () => {
const f = fn([Real, Real], Real, (x, y) => add(x, y));
const g = interp(f);
Expand Down Expand Up @@ -171,6 +177,13 @@ describe("valid", () => {
expect(h()).toEqual([1, 2, 0]);
});

test("interp with index value", () => {
const n = 1;
const f = fn([n, Vec(n, Bool)], Bool, (i, v) => v[i]);
const g = interp(f);
expect(g(0, [true])).toBe(true);
});

test("matrix multiplication", () => {
const n = 6;

Expand Down Expand Up @@ -286,6 +299,12 @@ describe("valid", () => {
expect(g([3, 5])).toEqual([0, 1]);
});

test("interp struct arg", () => {
const f = fn([{ x: Real }], Real, (p) => p.x);
const g = interp(f);
expect(g({ x: 42 })).toBe(42);
});

test("custom unary function", () => {
const log = custom([Real], Real, Math.log);
const f = interp(log);
Expand Down

0 comments on commit ea00c0a

Please sign in to comment.