Skip to content

Commit

Permalink
fix ad for benchmark
Browse files Browse the repository at this point in the history
  • Loading branch information
prbzrg committed Jan 8, 2024
1 parent cc6f3e8 commit 1559c8e
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ n = 2^6
r = rand(Float32, nvars, n)
nn = Dense(nvars => nvars, tanh)

icnf = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode)
icnf = construct(
RNODE,
nn,
nvars;
compute_mode = ZygoteMatrixMode,
sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad,
)
ps, st = Lux.setup(icnf.rng, icnf)
ps = ComponentArray(ps)

Expand All @@ -37,7 +43,14 @@ SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
@benchmarkable Zygote.gradient(loss, icnf, TestMode(), r, ps, st)

icnf2 = construct(RNODE, nn, nvars; compute_mode = ZygoteMatrixMode, inplace = true)
icnf2 = construct(
RNODE,
nn,
nvars;
compute_mode = ZygoteMatrixMode,
inplace = true,
sol_kwargs = ContinuousNormalizingFlows.sol_kwargs_defaults.medium_noad,
)

loss(icnf2, TrainMode(), r, ps, st)
loss(icnf2, TestMode(), r, ps, st)
Expand Down

0 comments on commit 1559c8e

Please sign in to comment.