Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add DI support #399

Merged
merged 15 commits into from
Apr 12, 2024
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -30,7 +31,6 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
SparseDiffTools = "47a9eef4-7e08-11e9-0b38-333d64bd3804"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand All @@ -50,6 +50,7 @@ ComputationalResources = "0.3"
DataFrames = "1"
Dates = "1"
DifferentialEquations = "7"
DifferentiationInterface = "0.1"
Distributions = "0.25"
DistributionsAD = "0.6"
FillArrays = "1"
Expand All @@ -68,7 +69,6 @@ Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypes = "3"
SparseDiffTools = "2"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
60 changes: 30 additions & 30 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
# ContinuousNormalizingFlows.jl

[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia

## Citing

See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).

## Installation

```julia
# ContinuousNormalizingFlows.jl
[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
## Citing
See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).
## Installation
```julia
using Pkg
prbzrg marked this conversation as resolved.
Show resolved Hide resolved
Pkg.add("ContinuousNormalizingFlows")
```

## Usage

```julia
```
## Usage
```julia
# Enable Logging
prbzrg marked this conversation as resolved.
Show resolved Hide resolved
using Logging, TerminalLoggers
global_logger(TerminalLogger())
Expand All @@ -48,7 +48,7 @@ icnf = construct(
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
compute_mode = ZygoteMatrixMode, # process data in batches
compute_mode = DIVecJacMatrixMode, # process data in batches
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
# resource = CUDALibs(), # process data by GPU
Expand Down Expand Up @@ -105,4 +105,4 @@ lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "estimated")
axislegend(ax)
save("result-fig.svg", f)
save("result-fig.png", f)
```
```
4 changes: 4 additions & 0 deletions benchmark/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2"
BenchmarkTools = "1"
ComponentArrays = "0.15"
DifferentiationInterface = "0.1"
Lux = "0.5"
PkgBenchmark = "0.2"
StableRNGs = "1"
Expand Down
29 changes: 18 additions & 11 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
using ContinuousNormalizingFlows,
BenchmarkTools, ComponentArrays, Lux, PkgBenchmark, StableRNGs, Zygote
ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Lux,
PkgBenchmark,
StableRNGs,
Zygote

SUITE = BenchmarkGroup()

Expand All @@ -26,7 +33,7 @@ icnf = construct(
nn,
nvars,
naugs;
compute_mode = ZygoteMatrixMode,
compute_mode = DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -41,24 +48,24 @@ diff_loss_tt(x) = loss(icnf, TestMode(), r, x, st)

diff_loss_tn(ps)
diff_loss_tt(ps)
Zygote.gradient(diff_loss_tn, ps)
Zygote.gradient(diff_loss_tt, ps)
DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
@benchmarkable Zygote.gradient(diff_loss_tn, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
@benchmarkable Zygote.gradient(diff_loss_tt, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)

icnf2 = construct(
RNODE,
nn,
nvars,
naugs;
inplace = true,
compute_mode = ZygoteMatrixMode,
compute_mode = DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand All @@ -70,13 +77,13 @@ diff_loss_tt2(x) = loss(icnf2, TestMode(), r, x, st)

diff_loss_tn2(ps)
diff_loss_tt2(ps)
Zygote.gradient(diff_loss_tn2, ps)
Zygote.gradient(diff_loss_tt2, ps)
DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
GC.gc()

SUITE["main"]["inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
@benchmarkable Zygote.gradient(diff_loss_tn2, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
@benchmarkable Zygote.gradient(diff_loss_tt2, ps)
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
2 changes: 1 addition & 1 deletion src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using AbstractDifferentiation,
DataFrames,
Dates,
DifferentialEquations,
DifferentiationInterface,
Distributions,
DistributionsAD,
FillArrays,
Expand All @@ -27,7 +28,6 @@ using AbstractDifferentiation,
ScientificTypes,
SciMLBase,
SciMLSensitivity,
SparseDiffTools,
Statistics,
Zygote

Expand Down
6 changes: 1 addition & 5 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,7 @@ function construct(
Eye{data_type}(nvars + naugmented),
),
differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(),
autodiff_backend::ADTypes.AbstractADType = ifelse(
compute_mode <: SDJacVecMatrixMode,
AutoForwardDiff(),
AutoZygote(),
),
autodiff_backend::ADTypes.AbstractADType = AutoZygote(),
sol_kwargs::NamedTuple = (
save_everystep = false,
alg = Tsit5(; thread = OrdinaryDiffEq.True()),
Expand Down
Loading
Loading