Skip to content

Commit

Permalink
test with zygote checkpointed & forwarddiff
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Aug 1, 2023
1 parent 3145244 commit 4c1a96c
Showing 1 changed file with 72 additions and 4 deletions.
76 changes: 72 additions & 4 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,12 +94,29 @@
# @test !isnothing(AbstractDifferentiation.hessian(adb, diff_loss, ps))
end

diff_loss2(x) = Zygote.checkpointed(diff_loss, x)
diff_loss3(x) = Zygote.forwarddiff(diff_loss, x)
diff_loss4(x) = Zygote.forwarddiff(diff_loss2, x)
@test !isnothing(Zygote.gradient(diff_loss, ps))
@test !isnothing(Zygote.jacobian(diff_loss, ps))
@test !isnothing(Zygote.forwarddiff(diff_loss, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss, ps))
@test !isnothing(Zygote.gradient(diff_loss2, ps))
@test !isnothing(Zygote.jacobian(diff_loss2, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss2, ps))
@test !isnothing(Zygote.gradient(diff_loss3, ps))
@test !isnothing(Zygote.jacobian(diff_loss3, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss3, ps))
@test !isnothing(Zygote.gradient(diff_loss4, ps))
@test !isnothing(Zygote.jacobian(diff_loss4, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss4, ps))

@test !isnothing(ReverseDiff.gradient(diff_loss, ps))
@test_throws MethodError !isnothing(ReverseDiff.jacobian(diff_loss, ps))
Expand Down Expand Up @@ -187,12 +204,29 @@
# @test !isnothing(AbstractDifferentiation.hessian(adb, diff_loss, ps))
end

diff_loss2(x) = Zygote.checkpointed(diff_loss, x)
diff_loss3(x) = Zygote.forwarddiff(diff_loss, x)
diff_loss4(x) = Zygote.forwarddiff(diff_loss2, x)
@test !isnothing(Zygote.gradient(diff_loss, ps))
@test !isnothing(Zygote.jacobian(diff_loss, ps))
@test !isnothing(Zygote.forwarddiff(diff_loss, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss, ps))
@test !isnothing(Zygote.gradient(diff_loss2, ps))
@test !isnothing(Zygote.jacobian(diff_loss2, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss2, ps))
@test !isnothing(Zygote.gradient(diff_loss3, ps))
@test !isnothing(Zygote.jacobian(diff_loss3, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss3, ps))
@test !isnothing(Zygote.gradient(diff_loss4, ps))
@test !isnothing(Zygote.jacobian(diff_loss4, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss4, ps))

@test !isnothing(ReverseDiff.gradient(diff_loss, ps))
@test_throws MethodError !isnothing(ReverseDiff.jacobian(diff_loss, ps))
Expand Down Expand Up @@ -290,12 +324,29 @@
# @test !isnothing(AbstractDifferentiation.hessian(adb, diff_loss, ps))
end

diff_loss2(x) = Zygote.checkpointed(diff_loss, x)
diff_loss3(x) = Zygote.forwarddiff(diff_loss, x)
diff_loss4(x) = Zygote.forwarddiff(diff_loss2, x)
@test !isnothing(Zygote.gradient(diff_loss, ps))
@test !isnothing(Zygote.jacobian(diff_loss, ps))
@test !isnothing(Zygote.forwarddiff(diff_loss, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss, ps))
@test !isnothing(Zygote.gradient(diff_loss2, ps))
@test !isnothing(Zygote.jacobian(diff_loss2, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss2, ps))
@test !isnothing(Zygote.gradient(diff_loss3, ps))
@test !isnothing(Zygote.jacobian(diff_loss3, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss3, ps))
@test !isnothing(Zygote.gradient(diff_loss4, ps))
@test !isnothing(Zygote.jacobian(diff_loss4, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss4, ps))

@test !isnothing(ReverseDiff.gradient(diff_loss, ps))
@test_throws MethodError !isnothing(ReverseDiff.jacobian(diff_loss, ps))
Expand Down Expand Up @@ -388,12 +439,29 @@
# @test !isnothing(AbstractDifferentiation.hessian(adb, diff_loss, ps))
end

diff_loss2(x) = Zygote.checkpointed(diff_loss, x)
diff_loss3(x) = Zygote.forwarddiff(diff_loss, x)
diff_loss4(x) = Zygote.forwarddiff(diff_loss2, x)
@test !isnothing(Zygote.gradient(diff_loss, ps))
@test !isnothing(Zygote.jacobian(diff_loss, ps))
@test !isnothing(Zygote.forwarddiff(diff_loss, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian(diff_loss, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss, ps))
@test !isnothing(Zygote.gradient(diff_loss2, ps))
@test !isnothing(Zygote.jacobian(diff_loss2, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian(diff_loss2, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss2, ps))
@test !isnothing(Zygote.gradient(diff_loss3, ps))
@test !isnothing(Zygote.jacobian(diff_loss3, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian(diff_loss3, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss3, ps))
@test !isnothing(Zygote.gradient(diff_loss4, ps))
@test !isnothing(Zygote.jacobian(diff_loss4, ps))
# @test !isnothing(Zygote.diaghessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian(diff_loss4, ps))
# @test !isnothing(Zygote.hessian_reverse(diff_loss4, ps))

@test !isnothing(ReverseDiff.gradient(diff_loss, ps))
@test_throws MethodError !isnothing(ReverseDiff.jacobian(diff_loss, ps))
Expand Down

0 comments on commit 4c1a96c

Please sign in to comment.