Skip to content

Commit

Permalink
Support AD back in compute_mode (#422)
Browse files Browse the repository at this point in the history
* update types and usages

* update icnf

* fix readme

* fix type check

* fix
  • Loading branch information
prbzrg authored May 23, 2024
1 parent 413cade commit 3a0a570
Show file tree
Hide file tree
Showing 11 changed files with 133 additions and 118 deletions.
62 changes: 31 additions & 31 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
# ContinuousNormalizingFlows.jl

[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia

## Citing

See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).

## Installation

```julia
# ContinuousNormalizingFlows.jl

[![deps](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/deps.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows?t=2)
[![version](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/version.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![pkgeval](https://juliahub.com/docs/General/ContinuousNormalizingFlows/stable/pkgeval.svg)](https://juliahub.com/ui/Packages/General/ContinuousNormalizingFlows)
[![Stable](https://img.shields.io/badge/docs-stable-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/stable)
[![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://impICNF.github.io/ContinuousNormalizingFlows.jl/dev)
[![Build Status](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/impICNF/ContinuousNormalizingFlows.jl/actions/workflows/CI.yml?query=branch%3Amain)
[![Coverage](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl/branch/main/graph/badge.svg)](https://codecov.io/gh/impICNF/ContinuousNormalizingFlows.jl)
[![Coverage](https://coveralls.io/repos/github/impICNF/ContinuousNormalizingFlows.jl/badge.svg?branch=main)](https://coveralls.io/github/impICNF/ContinuousNormalizingFlows.jl?branch=main)
[![Aqua](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)
[![PkgEval](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.svg)](https://JuliaCI.github.io/NanosoldierReports/pkgeval_badges/C/ContinuousNormalizingFlows.html)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor%27s%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia

## Citing

See [`CITATION.bib`](CITATION.bib) for the relevant reference(s).

## Installation

```julia
using Pkg
Pkg.add("ContinuousNormalizingFlows")
```

## Usage

```julia
```

## Usage

```julia
# Enable Logging
using Logging, TerminalLoggers
global_logger(TerminalLogger())
Expand All @@ -40,15 +40,15 @@ n_in = nvars + naugs # with augmentation
n = 1024

# Model
using ContinuousNormalizingFlows, Lux #, CUDA, ComputationalResources
using ContinuousNormalizingFlows, Lux, ADTypes, Zygote #, CUDA, ComputationalResources
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
# icnf = construct(RNODE, nn, nvars) # use defaults
icnf = construct(
RNODE,
nn,
nvars, # number of variables
naugs; # number of augmented dimensions
compute_mode = DIVecJacMatrixMode, # process data in batches
compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches
tspan = (0.0f0, 13.0f0), # have bigger time span
steer_rate = 1.0f-1, # add random noise to end of the time span
# resource = CUDALibs(), # process data by GPU
Expand Down Expand Up @@ -105,4 +105,4 @@ lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(d, vcat(x)); label = "estimated")
axislegend(ax)
save("result-fig.svg", f)
save("result-fig.png", f)
```
```
4 changes: 2 additions & 2 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ icnf = ContinuousNormalizingFlows.construct(
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down Expand Up @@ -79,7 +79,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
nvars,
naugs;
inplace = true,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
Expand Down
11 changes: 3 additions & 8 deletions src/base_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ function construct(
nvars::Int,
naugmented::Int = 0;
data_type::Type{<:AbstractFloat} = Float32,
compute_mode::Type{<:ComputeMode} = ADVecJacVectorMode,
compute_mode::ComputeMode = ADVecJacVectorMode(AbstractDifferentiation.ZygoteBackend()),
inplace::Bool = false,
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),
Expand All @@ -18,8 +18,6 @@ function construct(
FillArrays.Zeros{data_type}(nvars + naugmented),
FillArrays.Eye{data_type}(nvars + naugmented),
),
differentiation_backend::AbstractDifferentiation.AbstractBackend = AbstractDifferentiation.ZygoteBackend(),
autodiff_backend::ADTypes.AbstractADType = ADTypes.AutoZygote(),
sol_kwargs::NamedTuple = (save_everystep = false,),
rng::Random.AbstractRNG = rng_AT(resource),
λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE}
Expand All @@ -38,7 +36,7 @@ function construct(

ICNF{
data_type,
compute_mode,
typeof(compute_mode),
inplace,
cond,
!iszero(naugmented),
Expand All @@ -53,21 +51,18 @@ function construct(
typeof(tspan),
typeof(steerdist),
typeof(epsdist),
typeof(differentiation_backend),
typeof(autodiff_backend),
typeof(sol_kwargs),
typeof(rng),
}(
nn,
nvars,
naugmented,
compute_mode,
resource,
basedist,
tspan,
steerdist,
epsdist,
differentiation_backend,
autodiff_backend,
sol_kwargs,
rng,
λ₁,
Expand Down
16 changes: 7 additions & 9 deletions src/exts/mlj_ext/core_cond_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}

use_batch::Bool
batch_size::Int

compute_mode::Type{<:ComputeMode}
end

function CondICNFModel(
m::AbstractICNF{<:AbstractFloat, CM},
m::AbstractICNF,
loss::Function = loss;
optimizers::Tuple = (Optimisers.Lion(),),
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
) where {CM <: ComputeMode}
CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, CM)
)
CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size)
end

function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
Expand All @@ -45,7 +43,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
tst_overall = @timed for opt in model.optimizers
tst_epochs = @timed for ep in 1:(model.n_epochs)
if model.use_batch
if model.compute_mode <: VectorMode
if model.m.compute_mode isa VectorMode
data = MLUtils.DataLoader(
(x, y);
batchsize = -1,
Expand All @@ -54,7 +52,7 @@ function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY)
parallel = false,
buffer = false,
)
elseif model.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode
data = MLUtils.DataLoader(
(x, y);
batchsize = model.batch_size,
Expand Down Expand Up @@ -112,13 +110,13 @@ function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew)
end
(ps, st) = fitresult

tst = @timed if model.compute_mode <: VectorMode
tst = @timed if model.m.compute_mode isa VectorMode
logp̂x = broadcast(
(x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)),
eachcol(xnew),
eachcol(ynew),
)
elseif model.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode
logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st))
else
error("Not Implemented")
Expand Down
16 changes: 7 additions & 9 deletions src/exts/mlj_ext/core_icnf.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,18 @@ mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF}

use_batch::Bool
batch_size::Int

compute_mode::Type{<:ComputeMode}
end

function ICNFModel(
m::AbstractICNF{<:AbstractFloat, CM},
m::AbstractICNF,
loss::Function = loss;
optimizers::Tuple = (Optimisers.Lion(),),
n_epochs::Int = 300,
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
use_batch::Bool = true,
batch_size::Int = 32,
) where {CM <: ComputeMode}
ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size, CM)
)
ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size)
end

function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
Expand All @@ -43,7 +41,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
tst_overall = @timed for opt in model.optimizers
tst_epochs = @timed for ep in 1:(model.n_epochs)
if model.use_batch
if model.compute_mode <: VectorMode
if model.m.compute_mode isa VectorMode
data = MLUtils.DataLoader(
(x,);
batchsize = -1,
Expand All @@ -52,7 +50,7 @@ function MLJModelInterface.fit(model::ICNFModel, verbosity, X)
parallel = false,
buffer = false,
)
elseif model.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode
data = MLUtils.DataLoader(
(x,);
batchsize = model.batch_size,
Expand Down Expand Up @@ -107,9 +105,9 @@ function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew)
end
(ps, st) = fitresult

tst = @timed if model.compute_mode <: VectorMode
tst = @timed if model.m.compute_mode isa VectorMode
logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew))
elseif model.compute_mode <: MatrixMode
elseif model.m.compute_mode isa MatrixMode
logp̂x = first(inference(model.m, TestMode(), xnew, ps, st))
else
error("Not Implemented")
Expand Down
Loading

0 comments on commit 3a0a570

Please sign in to comment.