Skip to content

Commit

Permalink
only enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Aug 3, 2024
1 parent b1eee06 commit 5a44fc4
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 29 deletions.
34 changes: 17 additions & 17 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,28 +29,28 @@ Test.@testset "Call Tests" begin
nvars_ = Int[2]
aug_steers = Bool[false, true]
inplaces = Bool[false, true]
adb_list =
AbstractDifferentiation.AbstractBackend[AbstractDifferentiation.ZygoteBackend(),
# AbstractDifferentiation.ReverseDiffBackend(),
# AbstractDifferentiation.ForwardDiffBackend(),
]
adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(),
# ADTypes.AutoEnzyme(Enzyme.Forward),
adb_list = AbstractDifferentiation.AbstractBackend[
# AbstractDifferentiation.ZygoteBackend(),
# AbstractDifferentiation.ReverseDiffBackend(),
# AbstractDifferentiation.ForwardDiffBackend(),
]
adtypes = ADTypes.AbstractADType[ADTypes.AutoEnzyme(Enzyme.Forward),
# ADTypes.AutoEnzyme(Enzyme.Reverse),
# ADTypes.AutoZygote(),
# ADTypes.AutoReverseDiff(),
# ADTypes.AutoForwardDiff(),
]
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.ADVecJacVectorMode(
AbstractDifferentiation.ZygoteBackend(),
),
ContinuousNormalizingFlows.ADJacVecVectorMode(
AbstractDifferentiation.ZygoteBackend(),
),
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.ADVecJacVectorMode(
# AbstractDifferentiation.ZygoteBackend(),
# ),
# ContinuousNormalizingFlows.ADJacVecVectorMode(
# AbstractDifferentiation.ZygoteBackend(),
# ),
# ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme(Enzyme.Reverse)),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme(Enzyme.Forward)),
# ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme(Enzyme.Reverse)),
Expand Down
24 changes: 12 additions & 12 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,23 @@ Test.@testset "Fit Tests" begin
nvars_ = Int[2]
aug_steers = Bool[false, true]
inplaces = Bool[false, true]
adtypes = ADTypes.AbstractADType[ADTypes.AutoZygote(),
# ADTypes.AutoEnzyme(Enzyme.Forward),
adtypes = ADTypes.AbstractADType[ADTypes.AutoEnzyme(Enzyme.Forward),
# ADTypes.AutoEnzyme(Enzyme.Reverse),
# ADTypes.AutoZygote(),
# ADTypes.AutoReverseDiff(),
# ADTypes.AutoForwardDiff(),
]
compute_modes = ContinuousNormalizingFlows.ComputeMode[
ContinuousNormalizingFlows.ADVecJacVectorMode(
AbstractDifferentiation.ZygoteBackend(),
),
ContinuousNormalizingFlows.ADJacVecVectorMode(
AbstractDifferentiation.ZygoteBackend(),
),
ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.ADVecJacVectorMode(
# AbstractDifferentiation.ZygoteBackend(),
# ),
# ContinuousNormalizingFlows.ADJacVecVectorMode(
# AbstractDifferentiation.ZygoteBackend(),
# ),
# ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
# ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoEnzyme(Enzyme.Reverse)),
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoEnzyme(Enzyme.Forward)),
# ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoEnzyme(Enzyme.Reverse)),
Expand Down

0 comments on commit 5a44fc4

Please sign in to comment.