Skip to content

Commit

Permalink
use Enzyme: test it & make it default
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Sep 7, 2024
1 parent 44e6f38 commit fbe1660
Show file tree
Hide file tree
Showing 12 changed files with 51 additions and 15 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Expand Down Expand Up @@ -51,6 +52,7 @@ Dates = "1"
DifferentiationInterface = "0.5"
Distributions = "0.25"
DistributionsAD = "0.6"
Enzyme = "0.12"
FillArrays = "1"
LinearAlgebra = "1"
Lux = "0.5"
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,15 @@ n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux, ADTypes, Zygote #, CUDA, ComputationalResources
using ContinuousNormalizingFlows, Lux, ADTypes, Enzyme #, CUDA, ComputationalResources
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
# icnf = construct(RNODE, nn, nvars) # use defaults
icnf = construct(
RNODE,
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches
compute_mode = DIJacVecMatrixMode(AutoEnzyme(; function_annotation = Enzyme.Const)), # process data in batches
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
# resource = CUDALibs(), # process data by GPU
Expand Down
2 changes: 2 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Expand All @@ -13,6 +14,7 @@ ADTypes = "1"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.5"
Enzyme = "0.12"
Lux = "0.5"
PkgBenchmark = "0.2"
StableRNGs = "1"
Expand Down
11 changes: 9 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,15 @@ import ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Enzyme,
Lux,
PkgBenchmark,
StableRNGs,
Zygote,
ContinuousNormalizingFlows

Enzyme.API.runtimeActivity!(true)

SUITE = BenchmarkTools.BenchmarkGroup()

SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])
Expand All @@ -33,7 +36,9 @@ icnf = ContinuousNormalizingFlows.construct(
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down Expand Up @@ -79,7 +84,9 @@ icnf2 = ContinuousNormalizingFlows.construct(
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down
1 change: 1 addition & 0 deletions src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import AbstractDifferentiation,
DifferentiationInterface,
Distributions,
DistributionsAD,
Enzyme,
FillArrays,
LinearAlgebra,
Lux,
Expand Down
4 changes: 3 additions & 1 deletion src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ function construct(
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = ADVecJacVectorMode(AbstractDifferentiation.ZygoteBackend()),
compute_mode::ComputeMode = DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),
Expand Down
6 changes: 0 additions & 6 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ struct DIJacVecMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBa
adback::ADBack
end

Base.Base.@deprecate_binding SDVecJacMatrixMode DIVecJacMatrixMode true
Base.Base.@deprecate_binding SDJacVecMatrixMode DIJacVecMatrixMode true

Base.Base.@deprecate_binding ZygoteVectorMode DIVecJacVectorMode true
Base.Base.@deprecate_binding ZygoteMatrixMode DIVecJacMatrixMode true

abstract type AbstractICNF{
T <: AbstractFloat,
CM <: ComputeMode,
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
Expand All @@ -30,6 +31,7 @@ ComputationalResources = "0.3"
DataFrames = "1"
DifferentiationInterface = "0.5"
Distributions = "0.25"
Enzyme = "0.12"
GPUArraysCore = "0.1"
JET = "0.9"
Lux = "0.5"
Expand Down
12 changes: 12 additions & 0 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,18 @@ Test.@testset "Call Tests" begin
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
]
data_types = Type{<:AbstractFloat}[Float32]
resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()]
Expand Down
12 changes: 12 additions & 0 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,18 @@ Test.@testset "Fit Tests" begin
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
]
data_types = Type{<:AbstractFloat}[Float32]
resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()]
Expand Down
4 changes: 3 additions & 1 deletion test/instability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ Test.@testset "Instability" begin
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; function_annotation = Enzyme.Const),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import AbstractDifferentiation,
DataFrames,
DifferentiationInterface,
Distributions,
Enzyme,
GPUArraysCore,
JET,
Logging,
Expand All @@ -20,6 +21,8 @@ import AbstractDifferentiation,
Zygote,
ContinuousNormalizingFlows

Enzyme.API.runtimeActivity!(true)

GROUP = get(ENV, "GROUP", "All")
USE_GPU = get(ENV, "USE_GPU", "Yes") == "Yes"

Expand All @@ -28,9 +31,6 @@ if (GROUP == "All")

debuglogger = TerminalLoggers.TerminalLogger(stderr, Logging.Debug)
Logging.global_logger(debuglogger)
else
warnlogger = TerminalLoggers.TerminalLogger(stderr, Logging.Warn)
Logging.global_logger(warnlogger)
end

Test.@testset "Overall" begin
Expand Down

0 comments on commit fbe1660

Please sign in to comment.