From 6e014a8012c1a6f1873422bbac3db5f5ca7490fe Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sun, 24 Mar 2024 11:47:08 +0000 Subject: [PATCH] Fix interface --- src/interface.jl | 3 ++- test/interface.jl | 19 ++++++++----------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/src/interface.jl b/src/interface.jl index 35c487464..3e35018fb 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -10,7 +10,8 @@ function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, out, pb!! = rule(fx...) @assert _typeof(tangent(out)) == T ty = increment!!(tangent(out), ȳ) - return primal(out), pb!!(ty, map(tangent, fx)...) + v = copy(primal(out)) + return v, pb!!(ty, map(tangent, fx)...) end """ diff --git a/test/interface.jl b/test/interface.jl index e5409f084..94c35d659 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,13 +1,10 @@ @testset "interface" begin - f = (x, y) -> x * y + sin(x) * cos(y) - x = 5.0 - y = 4.0 - rule = build_rrule(f, x, y) - v, grad = value_and_gradient!!(rule, f, x, y) - @test v ≈ f(x, y) - @test grad isa Tuple{NoTangent, Float64, Float64} - - v, grad2 = value_and_pullback!!(rule, 1.0, f, x, y) - @test v ≈ f(x, y) - @test grad == grad2 + @testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[ + (1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0), + ([1.0, 1.0], x -> [sin(x), sin(2x)], 3.0), + ] + rule = build_rrule(f, x...) + v, grad2 = value_and_pullback!!(rule, ȳ, f, x...) + @test v ≈ f(x...) + end end