Skip to content

Commit

Permalink
Flatten reference scopes (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep authored Sep 7, 2023
1 parent 1a0b8dc commit e2d86f0
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 127 deletions.
78 changes: 25 additions & 53 deletions crates/autodiff/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,18 @@ impl Autodiff<'_> {

fn unpack(&mut self, var: id::Var) {
let i = var.var();
if let Ty::F64 = self.old_types[self.old_vars[i].ty()] {
let x = self.real(Expr::Member {
tuple: var,
member: RE,
});
let dx = self.dual(Expr::Member {
tuple: var,
member: DU,
});
self.unpacked[i] = Some((x, dx))
}
}

fn maybe_unpack(&mut self, var: id::Var) {
if self.unpacked[var.var()].is_none() {
self.unpack(var);
if self.unpacked[i].is_none() {
if let Ty::F64 = self.old_types[self.old_vars[i].ty()] {
let x = self.real(Expr::Member {
tuple: var,
member: RE,
});
let dx = self.dual(Expr::Member {
tuple: var,
member: DU,
});
self.unpacked[i] = Some((x, dx))
}
}
}

Expand Down Expand Up @@ -88,7 +84,7 @@ impl Autodiff<'_> {
fn block(mut self, orig: &[Instr]) -> Box<[Instr]> {
for Instr { var, expr } in orig {
self.instr(*var, expr);
self.maybe_unpack(*var);
self.unpack(*var);
}
self.code.into()
}
Expand Down Expand Up @@ -140,6 +136,14 @@ impl Autodiff<'_> {
var,
expr: Expr::Select { cond, then, els },
}),
&Expr::Read { var: orig } => self.code.push(Instr {
var,
expr: Expr::Read { var: orig },
}),
&Expr::Accum { shape } => self.code.push(Instr {
var,
expr: Expr::Accum { shape },
}),
&Expr::Ask { var } => self.code.push(Instr {
var,
expr: Expr::Ask { var },
Expand All @@ -148,6 +152,10 @@ impl Autodiff<'_> {
var,
expr: Expr::Add { accum, addend },
}),
&Expr::Resolve { var: container } => self.code.push(Instr {
var,
expr: Expr::Resolve { var: container },
}),

// less boring cases
Expr::Call { id, generics, args } => self.code.push(Instr {
Expand All @@ -169,42 +177,6 @@ impl Autodiff<'_> {
},
})
}
Expr::Read {
var: orig,
arg,
body,
ret,
} => {
let body = self.child(body);
self.code.push(Instr {
var,
expr: Expr::Read {
var: *orig,
arg: *arg,
body,
ret: *ret,
},
});
self.unpack(*ret); // not `maybe_unpack`, because those vars might be scoped wrong
}
Expr::Accum {
shape,
arg,
body,
ret,
} => {
let body = self.child(body);
self.code.push(Instr {
var,
expr: Expr::Accum {
shape: *shape,
arg: *arg,
body,
ret: *ret,
},
});
self.unpack(*ret); // not `maybe_unpack`, because those vars might be scoped wrong
}

// interesting cases
&Expr::F64 { val } => {
Expand Down
21 changes: 9 additions & 12 deletions crates/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,25 +176,16 @@ pub enum Expr {
/// Variable from `body` holding an array element.
ret: id::Var,
},
/// Scope for a `Ref` with `Constraint::Read`. Returns `Unit`.

/// Start a scope for a `Ref` with `Constraint::Read`.
Read {
/// Contents of the `Ref`.
var: id::Var,
/// Has type `Ref` with scope `arg` and inner type same as `var`.
arg: id::Var,
body: Box<[Instr]>,
/// Variable from `body` holding the result of this block; escapes into outer scope.
ret: id::Var,
},
/// Scope for a `Ref` with `Constraint::Accum`. Returns the final contents of the `Ref`.
/// Start a scope for a `Ref` with `Constraint::Accum`.
Accum {
/// Topology of the `Ref`.
shape: id::Var,
/// Has type `Ref` with scope `arg` and inner type same as `shape`.
arg: id::Var,
body: Box<[Instr]>,
/// Variable from `body` holding the result of this block; escapes into outer scope.
ret: id::Var,
},

/// Read from a `Ref` whose `scope` satisfies `Constraint::Read`.
Expand All @@ -209,6 +200,12 @@ pub enum Expr {
/// Must be of the `Ref`'s inner type.
addend: id::Var,
},

/// Consume a `Ref` to get its contained value.
Resolve {
/// The `Ref`, which must be in scope.
var: id::Var,
},
}

#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
Expand Down
25 changes: 5 additions & 20 deletions crates/interp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -256,32 +256,17 @@ impl<'a, 'b, O: Opaque, T: Refs<'a, Opaque = O>> Interpreter<'a, 'b, O, T> {
(0..n).map(|i| self.block(*arg, body, *ret, Val::Fin(i)).clone()),
))
}
Expr::Read {
var,
arg,
body,
ret,
} => {
let r = Val::Ref(Rc::new(self.get(*var).clone()));
self.block(*arg, body, *ret, r);
Val::Unit
}
Expr::Accum {
shape,
arg,
body,
ret,
} => {
let x = Val::Ref(Rc::new(self.get(*shape).zero()));
self.block(*arg, body, *ret, x.clone());
x.inner().clone()
}

&Expr::Read { var } => Val::Ref(Rc::new(self.get(var).clone())),
&Expr::Accum { shape } => Val::Ref(Rc::new(self.get(shape).zero())),

&Expr::Ask { var } => self.get(var).inner().clone(),
&Expr::Add { accum, addend } => {
self.get(accum).inner().add(self.get(addend));
Val::Unit
}

&Expr::Resolve { var } => self.get(var).inner().clone(),
}
}

Expand Down
45 changes: 3 additions & 42 deletions crates/web/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -384,46 +384,13 @@ pub fn pprint(f: &Func) -> Result<String, JsError> {
}
writeln!(&mut s, "}}")?
}
rose::Expr::Read {
var,
arg,
body,
ret,
} => {
writeln!(&mut s, "read x{} {{", var.var())?;
for _ in 0..spaces {
write!(&mut s, " ")?;
}
let x = arg.var();
writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?;
print_block(s, def, spaces + 2, body, *ret)?;
for _ in 0..spaces {
write!(&mut s, " ")?;
}
writeln!(&mut s, "}}")?
}
rose::Expr::Accum {
shape,
arg,
body,
ret,
} => {
writeln!(&mut s, "accum x{} {{", shape.var())?;
for _ in 0..spaces {
write!(&mut s, " ")?;
}
let x = arg.var();
writeln!(&mut s, " x{x}: T{}", def.vars[x].ty())?;
print_block(s, def, spaces + 2, body, *ret)?;
for _ in 0..spaces {
write!(&mut s, " ")?;
}
writeln!(&mut s, "}}")?
}
rose::Expr::Read { var } => writeln!(&mut s, "read x{}", var.var())?,
rose::Expr::Accum { shape } => writeln!(&mut s, "accum x{}", shape.var())?,
rose::Expr::Ask { var } => writeln!(&mut s, "ask x{}", var.var())?,
rose::Expr::Add { accum, addend } => {
writeln!(&mut s, "x{} += x{}", accum.var(), addend.var())?
}
rose::Expr::Resolve { var } => writeln!(&mut s, "resolve x{}", var.var())?,
}
Ok(())
}
Expand Down Expand Up @@ -1121,12 +1088,6 @@ impl Block {
let var = instr.var;
code.push(instr);
f.extra(var, code);
match &code.last().unwrap().expr {
&rose::Expr::Read { ret, .. } | &rose::Expr::Accum { ret, .. } => {
f.extra(ret, code);
}
_ => {}
}
}
}

Expand Down

0 comments on commit e2d86f0

Please sign in to comment.