diff --git a/crates/transpose/src/lib.rs b/crates/transpose/src/lib.rs index c606cc5..40c3e35 100644 --- a/crates/transpose/src/lib.rs +++ b/crates/transpose/src/lib.rs @@ -615,6 +615,8 @@ impl<'a> Transpose<'a> { expr: Expr::Unary { op, arg: x }, }); self.keep(var); + let lin = self.accum(var); + self.resolve(lin); } } self.prims[var.var()] = Some(Src(None)); @@ -723,6 +725,8 @@ impl<'a> Transpose<'a> { }, }); self.keep(var); + let lin = self.accum(var); + self.resolve(lin); } } self.prims[var.var()] = Some(Src(None)); diff --git a/packages/core/src/index.test.ts b/packages/core/src/index.test.ts index 51e8ba0..58b18ea 100644 --- a/packages/core/src/index.test.ts +++ b/packages/core/src/index.test.ts @@ -484,6 +484,14 @@ describe("valid", () => { expect(h(false, 7)).toEqual({ x: 7, stuff: { a: null, b: false, c: 3 } }); }); + test("VJP with logic", () => { + const f = fn([Bool], Bool, (p) => not(p)); + const g = fn([Bool], Bool, (p) => vjp(f)(p).ret); + const h = interp(g); + expect(h(true)).toBe(false); + expect(h(false)).toBe(true); + }); + test("VJP with select on null", () => { const f = fn([Null], Null, () => select(true, Null, null, null)); const g = fn([], Null, () => vjp(f)(null).ret);