Skip to content

Commit

Permalink
Allow custom opaque functions (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Aug 29, 2023
1 parent d6d72cc commit 2c23c1d
Show file tree
Hide file tree
Showing 11 changed files with 831 additions and 535 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 31 additions & 4 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,42 @@ pub struct Function {
pub body: Box<[Instr]>,
}

/// Wrapper for a `Function` that knows how to resolve its `id::Function`s.
pub trait FuncNode {
fn def(&self) -> &Function;
/// Resolves `id::Function`s.
pub trait Refs<'a> {
/// See `Node`.
type Opaque;

fn get(&self, id: id::Function) -> Option<Self>
/// Resolve `id` to a function node.
fn get(&self, id: id::Function) -> Option<Node<'a, Self::Opaque, Self>>
where
Self: Sized;
}

/// A node in a graph of functions.
#[derive(Clone, Debug, Copy)]
pub enum Node<'a, O, T: Refs<'a, Opaque = O>> {
/// A function with an explicit body.
Transparent {
/// To traverse the graph by resolving functions called by this one.
refs: T,
/// The signature and definition of this function.
def: &'a Function,
},
/// A function with an opaque body.
Opaque {
/// Generic type parameters.
generics: &'a [EnumSet<Constraint>],
/// Types used in this function's signature.
types: &'a [Ty],
/// Parameter types.
params: &'a [id::Ty],
/// Return type.
ret: id::Ty,
/// Definition of this function; semantics may vary.
def: O,
},
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Instr {
Expand Down
1 change: 1 addition & 0 deletions crates/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ indexmap = "2"
lalrpop-util = "0.19"
logos = "0.13"
rose = { path = "../core" }
rose-interp = { path = "../interp" }
thiserror = "1"

[build-dependencies]
Expand Down
27 changes: 10 additions & 17 deletions crates/frontend/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::{ast, tokens};
use enumset::EnumSet;
use indexmap::{IndexMap, IndexSet};
use rose::{self as ir, id};
use std::{collections::HashMap, ops::Range};
use std::{collections::HashMap, convert::Infallible, ops::Range};

#[derive(Debug, thiserror::Error)]
pub enum TypeError {
Expand Down Expand Up @@ -113,19 +113,15 @@ pub struct Module<'input> {
funcs: IndexMap<&'input str, rose::Function>,
}

#[derive(Clone, Copy, Debug)]
pub struct FuncRef<'input, 'a> {
m: &'a Module<'input>,
id: id::Function,
}
type Opaque = Infallible;

impl<'input, 'a> rose::FuncNode for FuncRef<'input, 'a> {
fn def(&self) -> &rose::Function {
&self.m.funcs[self.id.function()]
}
impl<'input, 'a> rose::Refs<'a> for &'a Module<'input> {
type Opaque = Opaque;

fn get(&self, id: id::Function) -> Option<Self> {
Some(Self { m: self.m, id })
fn get(&self, id: id::Function) -> Option<ir::Node<'a, Opaque, Self>> {
self.funcs
.get_index(id.function())
.map(|(_, def)| ir::Node::Transparent { refs: *self, def })
}
}

Expand All @@ -134,12 +130,9 @@ impl Module<'_> {
self.types.get(name)
}

pub fn get_func(&self, name: &str) -> Option<FuncRef> {
pub fn get_func(&self, name: &str) -> Option<ir::Node<Opaque, &Module>> {
let i = self.funcs.get_index_of(name)?;
Some(FuncRef {
m: self,
id: id::function(i),
})
ir::Refs::get(&self, id::function(i))
}
}

Expand Down
178 changes: 136 additions & 42 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use indexmap::IndexSet;
use rose::{id, Binop, Expr, FuncNode, Ty, Unop};
use std::{cell::Cell, rc::Rc};
use rose::{id, Binop, Expr, Function, Node, Refs, Ty, Unop};
use std::{cell::Cell, convert::Infallible, rc::Rc};

#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -129,24 +129,44 @@ fn resolve(typemap: &mut IndexSet<Ty>, generics: &[id::Ty], types: &[id::Ty], ty
id::ty(i)
}

struct Interpreter<'a, F: FuncNode> {
typemap: &'a mut IndexSet<Ty>,
f: &'a F, // reference instead of value because otherwise borrow checker complains in `fn block`
/// An opaque function that can be called by the interpreter.
pub trait Opaque {
fn call(&self, types: &IndexSet<Ty>, generics: &[id::Ty], args: &[Val]) -> Val;
}

impl Opaque for Infallible {
fn call(&self, _: &IndexSet<Ty>, _: &[id::Ty], _: &[Val]) -> Val {
match *self {}
}
}

/// basically, the `'a` lifetime is for the graph of functions, and the `'b` lifetime is just for
/// this particular instance of interpretation
struct Interpreter<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> {
typemap: &'b mut IndexSet<Ty>,
refs: T,
def: &'a Function,
types: Vec<id::Ty>,
vars: Vec<Option<Val>>,
}

impl<'a, F: FuncNode> Interpreter<'a, F> {
fn new(typemap: &'a mut IndexSet<Ty>, f: &'a F, generics: &'a [id::Ty]) -> Self {
impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
fn new(
typemap: &'b mut IndexSet<Ty>,
refs: T,
def: &'a Function,
generics: &'b [id::Ty],
) -> Self {
let mut types = vec![];
for ty in f.def().types.iter() {
for ty in def.types.iter() {
types.push(resolve(typemap, generics, &types, ty));
}
Self {
typemap,
f,
refs,
def,
types,
vars: vec![None; f.def().vars.len()],
vars: vec![None; def.vars.len()],
}
}

Expand Down Expand Up @@ -229,10 +249,10 @@ impl<'a, F: FuncNode> Interpreter<'a, F> {
Expr::Call { id, generics, args } => {
let resolved: Vec<id::Ty> = generics.iter().map(|id| self.types[id.ty()]).collect();
let vals = args.iter().map(|id| self.vars[id.var()].clone().unwrap());
call(self.f.get(*id).unwrap(), self.typemap, &resolved, vals)
call(self.refs.get(*id).unwrap(), self.typemap, &resolved, vals)
}
Expr::For { arg, body, ret } => {
let n = match self.typemap[self.types[self.f.def().vars[arg.var()].ty()].ty()] {
let n = match self.typemap[self.types[self.def.vars[arg.var()].ty()].ty()] {
Ty::Fin { size } => size,
_ => unreachable!(),
};
Expand Down Expand Up @@ -279,30 +299,44 @@ impl<'a, F: FuncNode> Interpreter<'a, F> {
}

/// Assumes `generics` and `arg` are valid.
fn call(
f: impl FuncNode,
types: &mut IndexSet<Ty>,
generics: &[id::Ty],
fn call<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>>(
f: Node<'a, O, T>,
types: &'b mut IndexSet<Ty>,
generics: &'b [id::Ty],
args: impl Iterator<Item = Val>,
) -> Val {
let mut interp = Interpreter::new(types, &f, generics);
for (var, arg) in f.def().params.iter().zip(args) {
interp.vars[var.var()] = Some(arg.clone());
}
for instr in f.def().body.iter() {
interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr));
match f {
Node::Transparent { refs, def } => {
let mut interp = Interpreter::new(types, refs, def, generics);
for (var, arg) in def.params.iter().zip(args) {
interp.vars[var.var()] = Some(arg.clone());
}
for instr in def.body.iter() {
interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr));
}
interp.vars[def.ret.var()].as_ref().unwrap().clone()
}
Node::Opaque {
generics: _,
types: _,
params: _,
ret: _,
def,
} => {
let vals: Box<[Val]> = args.collect();
def.call(types, generics, &vals)
}
}
interp.vars[f.def().ret.var()].as_ref().unwrap().clone()
}

#[derive(Debug, thiserror::Error)]
pub enum Error {}

/// Guaranteed not to panic if `f` is valid.
pub fn interp(
f: impl FuncNode,
pub fn interp<'a, O: Opaque, T: Refs<'a, Opaque = O>>(
f: Node<'a, O, T>,
mut types: IndexSet<Ty>,
generics: &[id::Ty],
generics: &'a [id::Ty],
args: impl Iterator<Item = Val>,
) -> Result<Val, Error> {
// TODO: check that `generics` and `arg` are valid
Expand All @@ -314,21 +348,56 @@ mod tests {
use super::*;
use rose::{Function, Instr};

#[derive(Clone, Copy, Debug)]
type CustomRef<'a> = &'a dyn Fn(&IndexSet<Ty>, &[id::Ty], &[Val]) -> Val;
type CustomBox = Box<dyn Fn(&IndexSet<Ty>, &[id::Ty], &[Val]) -> Val>;

struct Custom<'a> {
f: CustomRef<'a>,
}

impl Opaque for Custom<'_> {
fn call(&self, types: &IndexSet<Ty>, generics: &[id::Ty], args: &[Val]) -> Val {
(self.f)(types, generics, args)
}
}

struct FuncInSlice<'a> {
custom: &'a [CustomBox],
funcs: &'a [Function],
id: id::Function,
}

impl FuncNode for FuncInSlice<'_> {
fn def(&self) -> &Function {
&self.funcs[self.id.function()]
impl<'a> Refs<'a> for FuncInSlice<'a> {
type Opaque = Custom<'a>;

fn get(&self, id: id::Function) -> Option<Node<'a, Custom<'a>, Self>> {
if id.function() < self.id.function() {
node(self.custom, self.funcs, id)
} else {
None
}
}
}

fn get(&self, id: id::Function) -> Option<Self> {
Some(Self {
funcs: self.funcs,
id,
fn node<'a>(
custom: &'a [CustomBox],
funcs: &'a [Function],
id: id::Function,
) -> Option<Node<'a, Custom<'a>, FuncInSlice<'a>>> {
let n = custom.len();
let i = id.function();
if i < n {
Some(Node::Opaque {
generics: &[],
types: &[],
params: &[],
ret: id::ty(0),
def: Custom { f: &custom[i] },
})
} else {
funcs.get(i - n).map(|def| Node::Transparent {
refs: FuncInSlice { custom, funcs, id },
def,
})
}
}
Expand All @@ -352,10 +421,7 @@ mod tests {
.into(),
}];
let answer = interp(
FuncInSlice {
funcs: &funcs,
id: id::function(0),
},
node(&[], &funcs, id::function(0)).unwrap(),
IndexSet::new(),
&[],
[val_f64(2.), val_f64(2.)].into_iter(),
Expand Down Expand Up @@ -407,15 +473,43 @@ mod tests {
},
];
let answer = interp(
FuncInSlice {
funcs: &funcs,
id: id::function(1),
},
node(&[], &funcs, id::function(1)).unwrap(),
IndexSet::new(),
&[],
[].into_iter(),
)
.unwrap();
assert_eq!(answer, val_f64(1764.));
}

#[test]
fn test_custom() {
let custom: [CustomBox; 1] = [Box::new(|_, _, args| {
Val::F64(Cell::new(args[0].f64().powf(args[1].f64())))
})];
let funcs = [Function {
generics: [].into(),
types: [Ty::F64].into(),
vars: [id::ty(0), id::ty(0), id::ty(0)].into(),
params: [id::var(0), id::var(1)].into(),
ret: id::var(2),
body: [Instr {
var: id::var(2),
expr: Expr::Call {
id: id::function(0),
generics: [].into(),
args: [id::var(0), id::var(1)].into(),
},
}]
.into(),
}];
let answer = interp(
node(&custom, &funcs, id::function(1)).unwrap(),
IndexSet::new(),
&[],
[val_f64(std::f64::consts::E), val_f64(std::f64::consts::PI)].into_iter(),
)
.unwrap();
assert_eq!(answer, val_f64(23.140692632779263));
}
}
Loading

0 comments on commit 2c23c1d

Please sign in to comment.