Skip to content

Commit

Permalink
Fix transpose for not again
Browse files Browse the repository at this point in the history
  • Loading branch information
samestep committed Sep 23, 2023
1 parent 861dfaa commit ac8fa7f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
4 changes: 4 additions & 0 deletions crates/transpose/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,8 @@ impl<'a> Transpose<'a> {
expr: Expr::Unary { op, arg: x },
});
self.keep(var);
let lin = self.accum(var);
self.resolve(lin);
}
}
self.prims[var.var()] = Some(Src(None));
Expand Down Expand Up @@ -723,6 +725,8 @@ impl<'a> Transpose<'a> {
},
});
self.keep(var);
let lin = self.accum(var);
self.resolve(lin);
}
}
self.prims[var.var()] = Some(Src(None));
Expand Down
8 changes: 8 additions & 0 deletions packages/core/src/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,14 @@ describe("valid", () => {
expect(h(false, 7)).toEqual({ x: 7, stuff: { a: null, b: false, c: 3 } });
});

test("VJP with logic", () => {
const f = fn([Bool], Bool, (p) => not(p));
const g = fn([Bool], Bool, (p) => vjp(f)(p).ret);
const h = interp(g);
expect(h(true)).toBe(false);
expect(h(false)).toBe(true);
});

test("VJP with select on null", () => {
const f = fn([Null], Null, () => select(true, Null, null, null));
const g = fn([], Null, () => vjp(f)(null).ret);
Expand Down

0 comments on commit ac8fa7f

Please sign in to comment.