Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Oct 3, 2024
1 parent c5a04a6 commit ffe1684
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# compute_mode = DIJacVecMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Forward), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
# inplace = true, # use the inplace version of functions
# resource = CUDALibs(), # process data by GPU
tspan = (0.0f0, 13.0f0), # have bigger time span
Expand Down
4 changes: 2 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ icnf = ContinuousNormalizingFlows.construct(
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down Expand Up @@ -87,7 +87,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down
2 changes: 1 addition & 1 deletion src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function construct(
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::ComputeMode = DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down
8 changes: 4 additions & 4 deletions test/call_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,25 +37,25 @@ Test.@testset "Call Tests" begin
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down
8 changes: 4 additions & 4 deletions test/fit_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,25 +34,25 @@ Test.@testset "Fit Tests" begin
ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()),
ContinuousNormalizingFlows.DIVecJacVectorMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecVectorMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIVecJacMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Reverse),
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
function_annotation = Enzyme.Const,
),
),
ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down
2 changes: 1 addition & 1 deletion test/instability_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ Test.@testset "Instability" begin
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down
2 changes: 1 addition & 1 deletion test/regression_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ Test.@testset "Regression Tests" begin
naugs;
compute_mode = ContinuousNormalizingFlows.DIJacVecMatrixMode(
ADTypes.AutoEnzyme(;
mode = set_runtime_activity(Enzyme.Forward),
mode = Enzyme.set_runtime_activity(Enzyme.Forward),
function_annotation = Enzyme.Const,
),
),
Expand Down

0 comments on commit ffe1684

Please sign in to comment.