Skip to content

Commit

Permalink
use set_runtime_activity
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 3, 2024
1 parent d7ede3b commit c5a04a6
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
Expand Down
10 changes: 8 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ icnf = ContinuousNormalizingFlows.construct(
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down Expand Up @@ -83,7 +86,10 @@ icnf2 = ContinuousNormalizingFlows.construct(
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down
5 changes: 4 additions & 1 deletion src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ function construct(
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
Expand Down
20 changes: 16 additions & 4 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,28 @@ Test.@testset "Call Tests" begin
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
]
data_types = Type{<:AbstractFloat}[Float32]
Expand Down
20 changes: 16 additions & 4 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,28 @@ Test.@testset "Fit Tests" begin
ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Reverse, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
]
data_types = Type{<:AbstractFloat}[Float32]
Expand Down
5 changes: 4 additions & 1 deletion test/instability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ Test.@testset "Instability" begin
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down
5 changes: 4 additions & 1 deletion test/regression_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ Test.@testset "Regression Tests" begin
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(; mode = Enzyme.Forward, function_annotation = Enzyme.Const),
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
Expand Down

0 comments on commit c5a04a6

Please sign in to comment.