Skip to content

Commit

Permalink
add DI support (#399)
Browse files Browse the repository at this point in the history
* add DI support

* fix

* fix jac

* use DI in tests

* use DI in benchmark

* replace ZygoteMatrixMode

* only test gradient

* more test

* broken for x on `ReverseDiff`

* correct broken

* jacobian_batched for DIVecJacMatrixMode

* fix

* jacobian_batched for DIJacVecMatrixMode

* fix

* fix jb
  • Loading branch information
prbzrg authored Apr 12, 2024
1 parent 2c53568 commit d8eb45a
Show file tree
Hide file tree
Showing 14 changed files with 249 additions and 258 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -30,7 +31,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -50,6 +50,7 @@ ComputationalResources = "0.3"
DataFrames = "1"
Dates = "1"
DifferentialEquations = "7"
DifferentiationInterface = "0.1"
Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "1"
Expand All @@ -68,7 +69,6 @@ Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypes = "3"
SparseDiffTools = "2"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
60 changes: 30 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
# ContinuousNormalizingFlows.jl

[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia

## Citing

See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).

## Installation

```julia
# ContinuousNormalizingFlows.jl

[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia

## Citing

See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).

## Installation

```julia
using Pkg
Pkg.add("ContinuousNormalizingFlows")
```

## Usage

```julia
```

## Usage

```julia
# Enable Logging
using Logging, TerminalLoggers
global_logger(TerminalLogger())
Expand All @@ -48,7 +48,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
compute_mode = ZygoteMatrixMode, # process data in batches
compute_mode = DIVecJacMatrixMode, # process data in batches
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
# resource = CUDALibs(), # process data by GPU
Expand Down Expand Up @@ -105,4 +105,4 @@ lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "estimated")
axislegend(ax)
save("result-fig.svg", f)
save("result-fig.png", f)
```
```
4 changes: 4 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.1"
Lux = "0.5"
PkgBenchmark = "0.2"
StableRNGs = "1"
Expand Down
29 changes: 18 additions & 11 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using ContinuousNormalizingFlows,
BenchmarkTools, ComponentArrays, Lux, PkgBenchmark, StableRNGs, Zygote
ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Lux,
PkgBenchmark,
StableRNGs,
Zygote

SUITE = BenchmarkGroup()

Expand All @@ -26,7 +33,7 @@ icnf = construct(
nn,
nvars,
naugs;
compute_mode = ZygoteMatrixMode,
compute_mode = DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -41,24 +48,24 @@ diff_loss_tt(x) = loss(icnf, TestMode(), r, x, st)

diff_loss_tn(ps)
diff_loss_tt(ps)
Zygote.gradient(diff_loss_tn, ps)
Zygote.gradient(diff_loss_tt, ps)
DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
@benchmarkable Zygote.gradient(diff_loss_tn, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
@benchmarkable Zygote.gradient(diff_loss_tt, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)

icnf2 = construct(
RNODE,
nn,
nvars,
naugs;
inplace = true,
compute_mode = ZygoteMatrixMode,
compute_mode = DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -70,13 +77,13 @@ diff_loss_tt2(x) = loss(icnf2, TestMode(), r, x, st)

diff_loss_tn2(ps)
diff_loss_tt2(ps)
Zygote.gradient(diff_loss_tn2, ps)
Zygote.gradient(diff_loss_tt2, ps)
DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
GC.gc()

SUITE["main"]["inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
@benchmarkable Zygote.gradient(diff_loss_tn2, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
@benchmarkable Zygote.gradient(diff_loss_tt2, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
2 changes: 1 addition & 1 deletion src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using AbstractDifferentiation,
DataFrames,
Dates,
DifferentialEquations,
DifferentiationInterface,
Distributions,
DistributionsAD,
FillArrays,
Expand All @@ -27,7 +28,6 @@ using AbstractDifferentiation,
ScientificTypes,
SciMLBase,
SciMLSensitivity,
SparseDiffTools,
Statistics,
Zygote

Expand Down
6 changes: 1 addition & 5 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ function construct(
Eye{data_type}(nvars + naugmented),
),
differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(),
autodiff_backend::ADTypes.AbstractADType = ifelse(
compute_mode <: SDJacVecMatrixMode,
AutoForwardDiff(),
AutoZygote(),
),
autodiff_backend::ADTypes.AbstractADType = AutoZygote(),
sol_kwargs::NamedTuple = (
save_everystep = false,
alg = Tsit5(; thread = OrdinaryDiffEq.True()),
Expand Down
Loading

0 comments on commit d8eb45a

Please sign in to comment.