Skip to content

Commit

Permalink
Fix interface (#98)
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt authored Mar 24, 2024
1 parent 0854f78 commit 93814a9
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 12 deletions.
3 changes: 2 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R,
out, pb!! = rule(fx...)
@assert _typeof(tangent(out)) == T
ty = increment!!(tangent(out), ȳ)
return primal(out), pb!!(ty, map(tangent, fx)...)
v = copy(primal(out))
return v, pb!!(ty, map(tangent, fx)...)
end

"""
Expand Down
19 changes: 8 additions & 11 deletions test/interface.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
@testset "interface" begin
f = (x, y) -> x * y + sin(x) * cos(y)
x = 5.0
y = 4.0
rule = build_rrule(f, x, y)
v, grad = value_and_gradient!!(rule, f, x, y)
@test v f(x, y)
@test grad isa Tuple{NoTangent, Float64, Float64}

v, grad2 = value_and_pullback!!(rule, 1.0, f, x, y)
@test v f(x, y)
@test grad == grad2
@testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[
(1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0),
([1.0, 1.0], x -> [sin(x), sin(2x)], 3.0),
]
rule = build_rrule(f, x...)
v, grad2 = value_and_pullback!!(rule, ȳ, f, x...)
@test v f(x...)
end
end

0 comments on commit 93814a9

Please sign in to comment.