Skip to content

Commit

Permalink
Allow custom opaque functions
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Aug 29, 2023
1 parent 8a07848 commit 7e381b4
Show file tree
Hide file tree
Showing 6 changed files with 649 additions and 497 deletions.
21 changes: 17 additions & 4 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,28 @@ pub struct Function {
pub body: Box<[Instr]>,
}

/// Wrapper for a `Function` that knows how to resolve its `id::Function`s.
pub trait FuncNode {
fn def(&self) -> &Function;
pub trait Refs<'a> {
type Opaque;

fn get(&self, id: id::Function) -> Option<Self>
fn get(&self, id: id::Function) -> Option<Node<'a, Self::Opaque, Self>>
where
Self: Sized;
}

pub enum Node<'a, O, T: Refs<'a, Opaque = O>> {
Transparent {
refs: T,
def: &'a Function,
},
Opaque {
generics: &'a [EnumSet<Constraint>],
types: &'a [Ty],
params: &'a [id::Ty],
ret: id::Ty,
def: O,
},
}

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug)]
pub struct Instr {
Expand Down
1 change: 1 addition & 0 deletions crates/frontend/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ indexmap = "2"
lalrpop-util = "0.19"
logos = "0.13"
rose = { path = "../core" }
rose-interp = { path = "../interp" }
thiserror = "1"

[build-dependencies]
Expand Down
34 changes: 19 additions & 15 deletions crates/frontend/src/translate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,26 @@ pub struct Module<'input> {
funcs: IndexMap<&'input str, rose::Function>,
}

#[derive(Clone, Copy, Debug)]
pub struct FuncRef<'input, 'a> {
m: &'a Module<'input>,
id: id::Function,
pub enum Opaque {}

impl rose_interp::Opaque for Opaque {
fn call(
&self,
_: &IndexSet<rose::Ty>,
_: &[id::Ty],
_: &[rose_interp::Val],
) -> rose_interp::Val {
match *self {}
}
}

impl<'input, 'a> rose::FuncNode for FuncRef<'input, 'a> {
fn def(&self) -> &rose::Function {
&self.m.funcs[self.id.function()]
}
impl<'input, 'a> rose::Refs<'a> for &'a Module<'input> {
type Opaque = Opaque;

fn get(&self, id: id::Function) -> Option<Self> {
Some(Self { m: self.m, id })
fn get(&self, id: id::Function) -> Option<ir::Node<'a, Opaque, Self>> {
self.funcs
.get_index(id.function())
.map(|(_, def)| ir::Node::Transparent { refs: *self, def })
}
}

Expand All @@ -134,12 +141,9 @@ impl Module<'_> {
self.types.get(name)
}

pub fn get_func(&self, name: &str) -> Option<FuncRef> {
pub fn get_func(&self, name: &str) -> Option<ir::Node<Opaque, &Module>> {
let i = self.funcs.get_index_of(name)?;
Some(FuncRef {
m: self,
id: id::function(i),
})
ir::Refs::get(&self, id::function(i))
}
}

Expand Down
167 changes: 126 additions & 41 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use indexmap::IndexSet;
use rose::{id, Binop, Expr, FuncNode, Ty, Unop};
use rose::{id, Binop, Expr, Function, Node, Refs, Ty, Unop};
use std::{cell::Cell, rc::Rc};

#[cfg(feature = "serde")]
Expand Down Expand Up @@ -129,24 +129,35 @@ fn resolve(typemap: &mut IndexSet<Ty>, generics: &[id::Ty], types: &[id::Ty], ty
id::ty(i)
}

struct Interpreter<'a, F: FuncNode> {
typemap: &'a mut IndexSet<Ty>,
f: &'a F, // reference instead of value because otherwise borrow checker complains in `fn block`
pub trait Opaque {
fn call(&self, types: &IndexSet<Ty>, generics: &[id::Ty], args: &[Val]) -> Val;
}

struct Interpreter<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> {
typemap: &'b mut IndexSet<Ty>,
refs: T,
def: &'a Function,
types: Vec<id::Ty>,
vars: Vec<Option<Val>>,
}

impl<'a, F: FuncNode> Interpreter<'a, F> {
fn new(typemap: &'a mut IndexSet<Ty>, f: &'a F, generics: &'a [id::Ty]) -> Self {
impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
fn new(
typemap: &'b mut IndexSet<Ty>,
refs: T,
def: &'a Function,
generics: &'b [id::Ty],
) -> Self {
let mut types = vec![];
for ty in f.def().types.iter() {
for ty in def.types.iter() {
types.push(resolve(typemap, generics, &types, ty));
}
Self {
typemap,
f,
refs,
def,
types,
vars: vec![None; f.def().vars.len()],
vars: vec![None; def.vars.len()],
}
}

Expand Down Expand Up @@ -229,10 +240,10 @@ impl<'a, F: FuncNode> Interpreter<'a, F> {
Expr::Call { id, generics, args } => {
let resolved: Vec<id::Ty> = generics.iter().map(|id| self.types[id.ty()]).collect();
let vals = args.iter().map(|id| self.vars[id.var()].clone().unwrap());
call(self.f.get(*id).unwrap(), self.typemap, &resolved, vals)
call(self.refs.get(*id).unwrap(), self.typemap, &resolved, vals)
}
Expr::For { arg, body, ret } => {
let n = match self.typemap[self.types[self.f.def().vars[arg.var()].ty()].ty()] {
let n = match self.typemap[self.types[self.def.vars[arg.var()].ty()].ty()] {
Ty::Fin { size } => size,
_ => unreachable!(),
};
Expand Down Expand Up @@ -279,30 +290,44 @@ impl<'a, F: FuncNode> Interpreter<'a, F> {
}

/// Assumes `generics` and `arg` are valid.
fn call(
f: impl FuncNode,
types: &mut IndexSet<Ty>,
generics: &[id::Ty],
fn call<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>>(
f: Node<'a, O, T>,
types: &'b mut IndexSet<Ty>,
generics: &'b [id::Ty],
args: impl Iterator<Item = Val>,
) -> Val {
let mut interp = Interpreter::new(types, &f, generics);
for (var, arg) in f.def().params.iter().zip(args) {
interp.vars[var.var()] = Some(arg.clone());
}
for instr in f.def().body.iter() {
interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr));
match f {
Node::Transparent { refs, def } => {
let mut interp = Interpreter::new(types, refs, def, generics);
for (var, arg) in def.params.iter().zip(args) {
interp.vars[var.var()] = Some(arg.clone());
}
for instr in def.body.iter() {
interp.vars[instr.var.var()] = Some(interp.expr(&instr.expr));
}
interp.vars[def.ret.var()].as_ref().unwrap().clone()
}
Node::Opaque {
generics: _,
types: _,
params: _,
ret: _,
def,
} => {
let vals: Box<[Val]> = args.collect();
def.call(types, generics, &vals)
}
}
interp.vars[f.def().ret.var()].as_ref().unwrap().clone()
}

#[derive(Debug, thiserror::Error)]
pub enum Error {}

/// Guaranteed not to panic if `f` is valid.
pub fn interp(
f: impl FuncNode,
pub fn interp<'a, O: Opaque, T: Refs<'a, Opaque = O>>(
f: Node<'a, O, T>,
mut types: IndexSet<Ty>,
generics: &[id::Ty],
generics: &'a [id::Ty],
args: impl Iterator<Item = Val>,
) -> Result<Val, Error> {
// TODO: check that `generics` and `arg` are valid
Expand All @@ -314,21 +339,56 @@ mod tests {
use super::*;
use rose::{Function, Instr};

#[derive(Clone, Copy, Debug)]
type CustomRef<'a> = &'a dyn Fn(&IndexSet<Ty>, &[id::Ty], &[Val]) -> Val;
type CustomBox = Box<dyn Fn(&IndexSet<Ty>, &[id::Ty], &[Val]) -> Val>;

struct Custom<'a> {
f: CustomRef<'a>,
}

impl Opaque for Custom<'_> {
fn call(&self, types: &IndexSet<Ty>, generics: &[id::Ty], args: &[Val]) -> Val {
(self.f)(types, generics, args)
}
}

struct FuncInSlice<'a> {
custom: &'a [CustomBox],
funcs: &'a [Function],
id: id::Function,
}

impl FuncNode for FuncInSlice<'_> {
fn def(&self) -> &Function {
&self.funcs[self.id.function()]
impl<'a> Refs<'a> for FuncInSlice<'a> {
type Opaque = Custom<'a>;

fn get(&self, id: id::Function) -> Option<Node<'a, Custom<'a>, Self>> {
if id.function() < self.id.function() {
node(self.custom, self.funcs, id)
} else {
None
}
}
}

fn get(&self, id: id::Function) -> Option<Self> {
Some(Self {
funcs: self.funcs,
id,
fn node<'a>(
custom: &'a [CustomBox],
funcs: &'a [Function],
id: id::Function,
) -> Option<Node<'a, Custom<'a>, FuncInSlice<'a>>> {
let n = custom.len();
let i = id.function();
if i < n {
Some(Node::Opaque {
generics: &[],
types: &[],
params: &[],
ret: id::ty(0),
def: Custom { f: &custom[i] },
})
} else {
funcs.get(i - n).map(|def| Node::Transparent {
refs: FuncInSlice { custom, funcs, id },
def,
})
}
}
Expand All @@ -352,10 +412,7 @@ mod tests {
.into(),
}];
let answer = interp(
FuncInSlice {
funcs: &funcs,
id: id::function(0),
},
node(&[], &funcs, id::function(0)).unwrap(),
IndexSet::new(),
&[],
[val_f64(2.), val_f64(2.)].into_iter(),
Expand Down Expand Up @@ -407,15 +464,43 @@ mod tests {
},
];
let answer = interp(
FuncInSlice {
funcs: &funcs,
id: id::function(1),
},
node(&[], &funcs, id::function(1)).unwrap(),
IndexSet::new(),
&[],
[].into_iter(),
)
.unwrap();
assert_eq!(answer, val_f64(1764.));
}

#[test]
fn test_custom() {
let custom: [CustomBox; 1] = [Box::new(|_, _, args| {
Val::F64(Cell::new(args[0].f64().powf(args[1].f64())))
})];
let funcs = [Function {
generics: [].into(),
types: [Ty::F64].into(),
vars: [id::ty(0), id::ty(0), id::ty(0)].into(),
params: [id::var(0), id::var(1)].into(),
ret: id::var(2),
body: [Instr {
var: id::var(2),
expr: Expr::Call {
id: id::function(0),
generics: [].into(),
args: [id::var(0), id::var(1)].into(),
},
}]
.into(),
}];
let answer = interp(
node(&custom, &funcs, id::function(1)).unwrap(),
IndexSet::new(),
&[],
[val_f64(std::f64::consts::E), val_f64(std::f64::consts::PI)].into_iter(),
)
.unwrap();
assert_eq!(answer, val_f64(23.140692632779263));
}
}
Loading

0 comments on commit 7e381b4

Please sign in to comment.