diff --git a/README.md b/README.md index 56650216..f1fb963a 100644 --- a/README.md +++ b/README.md @@ -21,88 +21,88 @@ See [`CITATION.bib`](CITATION.bib) for the relevant reference(s). ## Installation ```julia -using Pkg -Pkg.add("ContinuousNormalizingFlows") +using Pkg +Pkg.add("ContinuousNormalizingFlows") ``` ## Usage ```julia -# Enable Logging -using Logging, TerminalLoggers -global_logger(TerminalLogger()) - -# Parameters -nvars = 1 -naugs = nvars -# n_in = nvars # without augmentation -n_in = nvars + naugs # with augmentation -n = 1024 - -# Model -using ContinuousNormalizingFlows, Lux, ADTypes, Zygote #, CUDA, ComputationalResources -nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh)) -# icnf = construct(RNODE, nn, nvars) # use defaults -icnf = construct( - RNODE, - nn, - nvars, # number of variables - naugs; # number of augmented dimensions - compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches - tspan = (0.0f0, 13.0f0), # have bigger time span - steer_rate = 1.0f-1, # add random noise to end of the time span - # resource = CUDALibs(), # process data by GPU - # inplace = true, # use the inplace version of functions -) - -# Data -using Distributions -data_dist = Beta{Float32}(2.0f0, 4.0f0) -r = rand(data_dist, nvars, n) -r = convert.(Float32, r) - -# Fit It -using DataFrames, MLJBase #, ForwardDiff, ADTypes, OptimizationOptimisers -df = DataFrame(transpose(r), :auto) -# model = ICNFModel(icnf) # use defaults -model = ICNFModel( - icnf; - batch_size = 256, # have bigger batchs - # n_epochs = 100, # have less epochs - # optimizers = (Adam(),), # use a different optimizer - # adtype = AutoForwardDiff(), # use ForwardDiff -) -mach = machine(model, df) -fit!(mach) -ps, st = fitted_params(mach) - -# Store It -using JLD2, UnPack -jldsave("fitted.jld2"; ps, st) # save -@unpack ps, st = load("fitted.jld2") # load - -# Use It -d = ICNFDist(icnf, TestMode(), ps, st) # direct way -# d = ICNFDist(mach, TestMode()) # alternative way -actual_pdf = pdf.(data_dist, vec(r)) -estimated_pdf = pdf(d, r) -new_data = rand(d, n) - -# Evaluate It -using Distances -mad_ = meanad(estimated_pdf, actual_pdf) -msd_ = msd(estimated_pdf, actual_pdf) -tv_dis = totalvariation(estimated_pdf, actual_pdf) / n -res_df = DataFrame(; mad_, msd_, tv_dis) -display(res_df) - -# Plot It -using CairoMakie -f = Figure() -ax = Makie.Axis(f[1, 1]; title = "Result") -lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "actual") -lines!(ax, 0.0f0 .. 
1.0f0, x -> pdf(d, vcat(x)); label = "estimated") -axislegend(ax) -save("result-fig.svg", f) -save("result-fig.png", f) +# Enable Logging +using Logging, TerminalLoggers +global_logger(TerminalLogger()) + +# Parameters +nvars = 1 +naugs = nvars +# n_in = nvars # without augmentation +n_in = nvars + naugs # with augmentation +n = 1024 + +# Model +using ContinuousNormalizingFlows, Lux, ADTypes, Zygote #, CUDA, ComputationalResources +nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh)) +# icnf = construct(RNODE, nn, nvars) # use defaults +icnf = construct( + RNODE, + nn, + nvars, # number of variables + naugs; # number of augmented dimensions + compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches + tspan = (0.0f0, 13.0f0), # use a bigger time span + steer_rate = 1.0f-1, # add random noise to the end of the time span + # resource = CUDALibs(), # process data on the GPU + # inplace = true, # use the inplace version of functions +) + +# Data +using Distributions +data_dist = Beta{Float32}(2.0f0, 4.0f0) +r = rand(data_dist, nvars, n) +r = convert.(Float32, r) + +# Fit It +using DataFrames, MLJBase #, ForwardDiff, ADTypes, OptimizationOptimisers +df = DataFrame(transpose(r), :auto) +# model = ICNFModel(icnf) # use defaults +model = ICNFModel( + icnf; + batch_size = 256, # use a bigger batch size + # n_epochs = 100, # use fewer epochs + # optimizers = (Adam(),), # use a different optimizer + # adtype = AutoForwardDiff(), # use ForwardDiff +) +mach = machine(model, df) +fit!(mach) +ps, st = fitted_params(mach) + +# Store It +using JLD2, UnPack +jldsave("fitted.jld2"; ps, st) # save +@unpack ps, st = load("fitted.jld2") # load + +# Use It +d = ICNFDist(icnf, TestMode(), ps, st) # direct way +# d = ICNFDist(mach, TestMode()) # alternative way +actual_pdf = pdf.(data_dist, vec(r)) +estimated_pdf = pdf(d, r) +new_data = rand(d, n) + +# Evaluate It +using Distances +mad_ = meanad(estimated_pdf, actual_pdf) +msd_ = msd(estimated_pdf, actual_pdf) +tv_dis = totalvariation(estimated_pdf, actual_pdf) / n +res_df = DataFrame(; mad_, msd_, tv_dis) +display(res_df) + +# Plot It +using CairoMakie +f = Figure() +ax = Makie.Axis(f[1, 1]; title = "Result") +lines!(ax, 0.0f0 .. 1.0f0, x -> pdf(data_dist, x); label = "actual") +lines!(ax, 0.0f0 .. 
1.0f0, x -> pdf(d, vcat(x)); label = "estimated") +axislegend(ax) +save("result-fig.svg", f) +save("result-fig.png", f) ``` diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 38d5d313..06c64d71 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,116 +1,116 @@ -import ADTypes, - BenchmarkTools, - ComponentArrays, - DifferentiationInterface, - Lux, - PkgBenchmark, - StableRNGs, - Zygote, - ContinuousNormalizingFlows - -SUITE = BenchmarkTools.BenchmarkGroup() - -SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"]) - -SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"]) -SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"]) - -SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"]) -SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"]) - -SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"]) -SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"]) - -rng = StableRNGs.StableRNG(12345) -nvars = 2^3 -naugs = nvars -n_in = nvars + naugs -n = 2^6 -nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh)) - -icnf = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.RNODE, - nn, - nvars, - naugs; - compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 13.0f0), - steer_rate = 1.0f-1, - λ₃ = 1.0f-2, - rng, -) -ps, st = Lux.setup(icnf.rng, icnf) -ps = ComponentArrays.ComponentArray(ps) -r = rand(icnf.rng, Float32, nvars, n) - -function diff_loss_tn(x) - ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, x, st) -end -function diff_loss_tt(x) - ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TestMode(), r, x, st) -end - -diff_loss_tn(ps) -diff_loss_tt(ps) -DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps) -DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps) -GC.gc() - -SUITE["main"]["no_inplace"]["direct"]["train"] = - BenchmarkTools.@benchmarkable diff_loss_tn(ps) -SUITE["main"]["no_inplace"]["direct"]["test"] = - BenchmarkTools.@benchmarkable diff_loss_tt(ps) -SUITE["main"]["no_inplace"]["AD-1-order"]["train"] = - BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( - diff_loss_tn, - ADTypes.AutoZygote(), - ps, - ) -SUITE["main"]["no_inplace"]["AD-1-order"]["test"] = - BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( - diff_loss_tt, - ADTypes.AutoZygote(), - ps, - ) - -icnf2 = ContinuousNormalizingFlows.construct( - ContinuousNormalizingFlows.RNODE, - nn, - nvars, - naugs; - inplace = true, - compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - tspan = (0.0f0, 13.0f0), - steer_rate = 1.0f-1, - λ₃ = 1.0f-2, - rng, -) - -function diff_loss_tn2(x) - ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TrainMode(), r, x, st) -end -function diff_loss_tt2(x) - ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TestMode(), r, x, st) -end - -diff_loss_tn2(ps) -diff_loss_tt2(ps) -DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps) -DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps) -GC.gc() - -SUITE["main"]["inplace"]["direct"]["train"] = - BenchmarkTools.@benchmarkable diff_loss_tn2(ps) -SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps) -SUITE["main"]["inplace"]["AD-1-order"]["train"] = - 
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( - diff_loss_tn2, - ADTypes.AutoZygote(), - ps, - ) -SUITE["main"]["inplace"]["AD-1-order"]["test"] = - BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( - diff_loss_tt2, - ADTypes.AutoZygote(), - ps, - ) +import ADTypes, + BenchmarkTools, + ComponentArrays, + DifferentiationInterface, + Lux, + PkgBenchmark, + StableRNGs, + Zygote, + ContinuousNormalizingFlows + +SUITE = BenchmarkTools.BenchmarkGroup() + +SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"]) + +SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"]) +SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"]) + +SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"]) +SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"]) + +SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"]) +SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"]) + +rng = StableRNGs.StableRNG(12345) +nvars = 2^3 +naugs = nvars +n_in = nvars + naugs +n = 2^6 +nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh)) + +icnf = ContinuousNormalizingFlows.construct( + ContinuousNormalizingFlows.RNODE, + nn, + nvars, + naugs; + compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + tspan = (0.0f0, 13.0f0), + steer_rate = 1.0f-1, + λ₃ = 1.0f-2, + rng, +) +ps, st = Lux.setup(icnf.rng, icnf) +ps = ComponentArrays.ComponentArray(ps) +r = rand(icnf.rng, Float32, nvars, n) + +function diff_loss_tn(x) + ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, x, st) +end +function diff_loss_tt(x) + ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TestMode(), r, x, st) +end + +diff_loss_tn(ps) +diff_loss_tt(ps) +DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps) +DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps) +GC.gc() + +SUITE["main"]["no_inplace"]["direct"]["train"] = + BenchmarkTools.@benchmarkable diff_loss_tn(ps) +SUITE["main"]["no_inplace"]["direct"]["test"] = + BenchmarkTools.@benchmarkable diff_loss_tt(ps) +SUITE["main"]["no_inplace"]["AD-1-order"]["train"] = + BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( + diff_loss_tn, + ADTypes.AutoZygote(), + ps, + ) +SUITE["main"]["no_inplace"]["AD-1-order"]["test"] = + BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( + diff_loss_tt, + ADTypes.AutoZygote(), + ps, + ) + +icnf2 = ContinuousNormalizingFlows.construct( + ContinuousNormalizingFlows.RNODE, + nn, + nvars, + naugs; + inplace = true, + compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + tspan = (0.0f0, 13.0f0), + steer_rate = 1.0f-1, + λ₃ = 1.0f-2, + rng, +) + +function diff_loss_tn2(x) + ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TrainMode(), r, x, st) +end +function diff_loss_tt2(x) + ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TestMode(), r, x, st) +end + +diff_loss_tn2(ps) +diff_loss_tt2(ps) +DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps) +DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps) +GC.gc() + +SUITE["main"]["inplace"]["direct"]["train"] = + BenchmarkTools.@benchmarkable diff_loss_tn2(ps) +SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps) +SUITE["main"]["inplace"]["AD-1-order"]["train"] = + 
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( + diff_loss_tn2, + ADTypes.AutoZygote(), + ps, + ) +SUITE["main"]["inplace"]["AD-1-order"]["test"] = + BenchmarkTools.@benchmarkable DifferentiationInterface.gradient( + diff_loss_tt2, + ADTypes.AutoZygote(), + ps, + ) diff --git a/docs/make.jl b/docs/make.jl index 2d2f1d7a..ad0876af 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,25 +1,25 @@ -import Documenter, ContinuousNormalizingFlows - -Documenter.DocMeta.setdocmeta!( - ContinuousNormalizingFlows, - :DocTestSetup, - :(using ContinuousNormalizingFlows); - recursive = true, -) - -Documenter.makedocs(; - modules = [ContinuousNormalizingFlows], - authors = "Hossein Pourbozorg and contributors", - repo = "https://github.com/impICNF/ContinuousNormalizingFlows.jl/blob/{commit}{path}#{line}", - sitename = "ContinuousNormalizingFlows.jl", - format = Documenter.HTML(; - canonical = "https://impICNF.github.io/ContinuousNormalizingFlows.jl", - edit_link = "main", - ), - pages = ["Home" => "index.md"], -) - -Documenter.deploydocs(; - repo = "github.com/impICNF/ContinuousNormalizingFlows.jl", - devbranch = "main", -) +import Documenter, ContinuousNormalizingFlows + +Documenter.DocMeta.setdocmeta!( + ContinuousNormalizingFlows, + :DocTestSetup, + :(using ContinuousNormalizingFlows); + recursive = true, +) + +Documenter.makedocs(; + modules = [ContinuousNormalizingFlows], + authors = "Hossein Pourbozorg and contributors", + repo = "https://github.com/impICNF/ContinuousNormalizingFlows.jl/blob/{commit}{path}#{line}", + sitename = "ContinuousNormalizingFlows.jl", + format = Documenter.HTML(; + canonical = "https://impICNF.github.io/ContinuousNormalizingFlows.jl", + edit_link = "main", + ), + pages = ["Home" => "index.md"], +) + +Documenter.deploydocs(; + repo = "github.com/impICNF/ContinuousNormalizingFlows.jl", + devbranch = "main", +) diff --git a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl index 7f424f50..555adc30 100644 --- a/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl +++ b/ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl @@ -1,17 +1,17 @@ -module ContinuousNormalizingFlowsCUDAExt - -import CUDA, ComputationalResources, ContinuousNormalizingFlows - -@inline function ContinuousNormalizingFlows.rng_AT(::ComputationalResources.CUDALibs) - CUDA.CURAND.default_rng() -end - -@inline function ContinuousNormalizingFlows.base_AT( - ::ComputationalResources.CUDALibs, - ::ContinuousNormalizingFlows.AbstractICNF{T}, - dims..., -) where {T <: AbstractFloat} - CUDA.CuArray{T}(undef, dims...) -end - -end +module ContinuousNormalizingFlowsCUDAExt + +import CUDA, ComputationalResources, ContinuousNormalizingFlows + +@inline function ContinuousNormalizingFlows.rng_AT(::ComputationalResources.CUDALibs) + CUDA.CURAND.default_rng() +end + +@inline function ContinuousNormalizingFlows.base_AT( + ::ComputationalResources.CUDALibs, + ::ContinuousNormalizingFlows.AbstractICNF{T}, + dims..., +) where {T <: AbstractFloat} + CUDA.CuArray{T}(undef, dims...) 
+end + +end diff --git a/src/ContinuousNormalizingFlows.jl b/src/ContinuousNormalizingFlows.jl index 38fc4607..dc597f11 100644 --- a/src/ContinuousNormalizingFlows.jl +++ b/src/ContinuousNormalizingFlows.jl @@ -1,84 +1,84 @@ -module ContinuousNormalizingFlows - -import AbstractDifferentiation, - ADTypes, - Base.Iterators, - ChainRulesCore, - ComponentArrays, - ComputationalResources, - DataFrames, - Dates, - DifferentiationInterface, - Distributions, - DistributionsAD, - FillArrays, - LinearAlgebra, - Lux, - LuxCore, - MLJBase, - MLJModelInterface, - MLUtils, - NNlib, - Octavian, - Optimisers, - Optimization, - OptimizationOptimisers, - OrdinaryDiffEq, - Random, - ScientificTypesBase, - SciMLBase, - SciMLSensitivity, - Statistics, - Zygote - -export construct, - inference, - generate, - loss, - ICNF, - RNODE, - CondRNODE, - FFJORD, - CondFFJORD, - Planar, - CondPlanar, - TestMode, - TrainMode, - ADVecJacVectorMode, - ADJacVecVectorMode, - DIVecJacVectorMode, - DIJacVecVectorMode, - DIVecJacMatrixMode, - DIJacVecMatrixMode, - ICNFModel, - CondICNFModel, - CondLayer, - PlanarLayer, - MulLayer - -include(joinpath("layers", "cond_layer.jl")) -include(joinpath("layers", "planar_layer.jl")) -include(joinpath("layers", "mul_layer.jl")) - -include("types.jl") - -include("base_icnf.jl") - -include("icnf.jl") - -include("utils.jl") - -include(joinpath("exts", "mlj_ext", "core.jl")) -include(joinpath("exts", "mlj_ext", "core_icnf.jl")) -include(joinpath("exts", "mlj_ext", "core_cond_icnf.jl")) - -include(joinpath("exts", "dist_ext", "core.jl")) -include(joinpath("exts", "dist_ext", "core_icnf.jl")) -include(joinpath("exts", "dist_ext", "core_cond_icnf.jl")) - -""" -Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia -""" -ContinuousNormalizingFlows - -end +module ContinuousNormalizingFlows + +import AbstractDifferentiation, + ADTypes, + Base.Iterators, + ChainRulesCore, + ComponentArrays, + ComputationalResources, + DataFrames, + Dates, + DifferentiationInterface, + Distributions, + DistributionsAD, + FillArrays, + LinearAlgebra, + Lux, + LuxCore, + MLJBase, + MLJModelInterface, + MLUtils, + NNlib, + Octavian, + Optimisers, + Optimization, + OptimizationOptimisers, + OrdinaryDiffEq, + Random, + ScientificTypesBase, + SciMLBase, + SciMLSensitivity, + Statistics, + Zygote + +export construct, + inference, + generate, + loss, + ICNF, + RNODE, + CondRNODE, + FFJORD, + CondFFJORD, + Planar, + CondPlanar, + TestMode, + TrainMode, + ADVecJacVectorMode, + ADJacVecVectorMode, + DIVecJacVectorMode, + DIJacVecVectorMode, + DIVecJacMatrixMode, + DIJacVecMatrixMode, + ICNFModel, + CondICNFModel, + CondLayer, + PlanarLayer, + MulLayer + +include(joinpath("layers", "cond_layer.jl")) +include(joinpath("layers", "planar_layer.jl")) +include(joinpath("layers", "mul_layer.jl")) + +include("types.jl") + +include("base_icnf.jl") + +include("icnf.jl") + +include("utils.jl") + +include(joinpath("exts", "mlj_ext", "core.jl")) +include(joinpath("exts", "mlj_ext", "core_icnf.jl")) +include(joinpath("exts", "mlj_ext", "core_cond_icnf.jl")) + +include(joinpath("exts", "dist_ext", "core.jl")) +include(joinpath("exts", "dist_ext", "core_icnf.jl")) +include(joinpath("exts", "dist_ext", "core_cond_icnf.jl")) + +""" +Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia +""" +ContinuousNormalizingFlows + +end diff --git a/src/base_icnf.jl b/src/base_icnf.jl index b4d92624..58413cfd 100644 --- a/src/base_icnf.jl +++ b/src/base_icnf.jl @@ -1,538 +1,538 @@ -function construct( - 
aicnf::Type{<:AbstractICNF}, - nn::LuxCore.AbstractExplicitLayer, - nvars::Int, - naugmented::Int = 0; - data_type::Type{<:AbstractFloat} = Float32, - compute_mode::ComputeMode = ADVecJacVectorMode(AbstractDifferentiation.ZygoteBackend()), - inplace::Bool = false, - cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, - resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(), - basedist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - tspan::NTuple{2} = (zero(data_type), one(data_type)), - steer_rate::AbstractFloat = zero(data_type), - epsdist::Distributions.Distribution = Distributions.MvNormal( - FillArrays.Zeros{data_type}(nvars + naugmented), - FillArrays.Eye{data_type}(nvars + naugmented), - ), - sol_kwargs::NamedTuple = (save_everystep = false,), - rng::Random.AbstractRNG = rng_AT(resource), - λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1e-2) - else - zero(data_type) - end, - λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} - convert(data_type, 1e-2) - else - zero(data_type) - end, - λ₃::AbstractFloat = zero(data_type), -) - steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) - - ICNF{ - data_type, - typeof(compute_mode), - inplace, - cond, - !iszero(naugmented), - !iszero(steer_rate), - !iszero(λ₁), - !iszero(λ₂), - !iszero(λ₃), - typeof(nn), - typeof(nvars), - typeof(resource), - typeof(basedist), - typeof(tspan), - typeof(steerdist), - typeof(epsdist), - typeof(sol_kwargs), - typeof(rng), - }( - nn, - nvars, - naugmented, - compute_mode, - resource, - basedist, - tspan, - steerdist, - epsdist, - sol_kwargs, - rng, - λ₁, - λ₂, - λ₃, - ) -end - -@inline function n_augment(::AbstractICNF, ::Mode) - 0 -end - -function Base.show(io::IO, icnf::AbstractICNF) - print( - io, - typeof(icnf), - "<", - "Number of Variables: ", - icnf.nvars, - ", Number of Augmentations: ", - n_augment_input(icnf), - ", Time Span: ", - icnf.tspan, - ">", - ) -end - -@inline function n_augment_input( - icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, true}, -) where {INPLACE, COND} - icnf.naugmented -end - -@inline function n_augment_input(::AbstractICNF) - 0 -end - -@inline function steer_tspan( - icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true}, - ::TrainMode, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED} - t₀, t₁ = icnf.tspan - Δt = abs(t₁ - t₀) - r = convert(T, rand(icnf.rng, icnf.steerdist)) - t₁_new = muladd(Δt, r, t₁) - (t₀, t₁_new) -end - -@inline function steer_tspan(icnf::AbstractICNF, ::Mode) - icnf.tspan -end - -@inline function rng_AT(::ComputationalResources.AbstractResource) - Random.default_rng() -end - -@inline function base_AT( - ::ComputationalResources.AbstractResource, - ::AbstractICNF{T}, - dims..., -) where {T <: AbstractFloat} - Array{T}(undef, dims...) -end - -ChainRulesCore.@non_differentiable base_AT(::Any...) - -function base_sol( - icnf::AbstractICNF{T, <:ComputeMode, INPLACE}, - prob::SciMLBase.AbstractODEProblem{<:AbstractVecOrMat{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE} - sol = SciMLBase.solve(prob; icnf.sol_kwargs...) 
- get_fsol(sol) -end - -function inference_sol( - icnf::AbstractICNF{T, <:VectorMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - n_aug = n_augment(icnf, mode) - fsol = base_sol(icnf, prob) - z = fsol[begin:(end - n_aug - 1)] - Δlogp = fsol[(end - n_aug)] - augs = fsol[(end - n_aug + 1):end] - logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z)) - logp̂x = logpz - Δlogp - Ȧ = if (NORM_Z_AUG && AUGMENTED) - n_aug_input = n_augment_input(icnf) - z_aug = z[(end - n_aug_input + 1):end] - LinearAlgebra.norm(z_aug) - else - zero(T) - end - (logp̂x, vcat(augs, Ȧ)) -end - -function inference_sol( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - n_aug = n_augment(icnf, mode) - fsol = base_sol(icnf, prob) - z = fsol[begin:(end - n_aug - 1), :] - Δlogp = fsol[(end - n_aug), :] - augs = fsol[(end - n_aug + 1):end, :] - logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z)) - logp̂x = logpz - Δlogp - Ȧ = transpose(if (NORM_Z_AUG && AUGMENTED) - n_aug_input = n_augment_input(icnf) - z_aug = z[(end - n_aug_input + 1):end, :] - LinearAlgebra.norm.(eachcol(z_aug)) - else - zrs_aug = similar(augs, size(augs, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs_aug, zero(T)) - zrs_aug - end) - (logp̂x, eachrow(vcat(augs, Ȧ))) -end - -function generate_sol( - icnf::AbstractICNF{T, <:VectorMode, INPLACE}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - fsol = base_sol(icnf, prob) - fsol[begin:(end - n_aug_input - n_aug - 1)] -end - -function generate_sol( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE}, - mode::Mode, - prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - fsol = base_sol(icnf, prob) - fsol[begin:(end - n_aug_input - n_aug - 1), :] -end - -@inline function get_fsol(sol::SciMLBase.AbstractODESolution) - last(sol.u) -end - -@inline function get_fsol(sol::AbstractArray{T, N}) where {T, N} - selectdim(sol, N, lastindex(sol, N)) -end - -function inference_prob( - icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, - mode::Mode, - xs::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - zrs = similar(xs, n_aug_input + n_aug + 1) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = icnf.nn - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(xs, zrs), - steer_tspan(icnf, mode), - ps, - ) -end - -function inference_prob( - icnf::AbstractICNF{T, <:VectorMode, INPLACE, true}, - mode::Mode, - xs::AbstractVector{<:Real}, - ys::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - zrs = similar(xs, n_aug_input + n_aug + 1) - 
ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = CondLayer(icnf.nn, ys) - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(xs, zrs), - steer_tspan(icnf, mode), - ps, - ) -end - -function inference_prob( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE, false}, - mode::Mode, - xs::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = icnf.nn - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(xs, zrs), - steer_tspan(icnf, mode), - ps, - ) -end - -function inference_prob( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE, true}, - mode::Mode, - xs::AbstractMatrix{<:Real}, - ys::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = CondLayer(icnf.nn, ys) - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(xs, zrs), - steer_tspan(icnf, mode), - ps, - ) -end - -function generate_prob( - icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, - mode::Mode, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.basedist, new_xs) - zrs = similar(new_xs, n_aug + 1) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = icnf.nn - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(new_xs, zrs), - reverse(steer_tspan(icnf, mode)), - ps, - ) -end - -function generate_prob( - icnf::AbstractICNF{T, <:VectorMode, INPLACE, true}, - mode::Mode, - ys::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.basedist, new_xs) - zrs = similar(new_xs, n_aug + 1) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = CondLayer(icnf.nn, ys) - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(new_xs, zrs), - reverse(steer_tspan(icnf, mode)), - ps, - ) -end - -function generate_prob( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE, false}, - mode::Mode, - ps::Any, - st::NamedTuple, - n::Int, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) - 
Random.rand!(icnf.rng, icnf.basedist, new_xs) - zrs = similar(new_xs, n_aug + 1, n) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = icnf.nn - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(new_xs, zrs), - reverse(steer_tspan(icnf, mode)), - ps, - ) -end - -function generate_prob( - icnf::AbstractICNF{T, <:MatrixMode, INPLACE, true}, - mode::Mode, - ys::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, - n::Int, -) where {T <: AbstractFloat, INPLACE} - n_aug = n_augment(icnf, mode) - n_aug_input = n_augment_input(icnf) - new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) - Random.rand!(icnf.rng, icnf.basedist, new_xs) - zrs = similar(new_xs, n_aug + 1, n) - ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) - ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) - Random.rand!(icnf.rng, icnf.epsdist, ϵ) - nn = CondLayer(icnf.nn, ys) - SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( - make_ode_func(icnf, mode, nn, st, ϵ), - vcat(new_xs, zrs), - reverse(steer_tspan(icnf, mode)), - ps, - ) -end - -@inline function inference( - icnf::AbstractICNF, - mode::Mode, - xs::AbstractVecOrMat{<:Real}, - ps::Any, - st::NamedTuple, -) - inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ps, st)) -end - -@inline function inference( - icnf::AbstractICNF, - mode::Mode, - xs::AbstractVecOrMat{<:Real}, - ys::AbstractVecOrMat{<:Real}, - ps::Any, - st::NamedTuple, -) - inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ys, ps, st)) -end - -@inline function generate( - icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, - mode::Mode, - ps::Any, - st::NamedTuple, -) - generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st)) -end - -@inline function generate( - icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, - mode::Mode, - ys::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) - generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st)) -end - -@inline function generate( - icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, - mode::Mode, - ps::Any, - st::NamedTuple, - n::Int, -) - generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st, n)) -end - -@inline function generate( - icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, - mode::Mode, - ys::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, - n::Int, -) - generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st, n)) -end - -@inline function loss( - icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, - mode::Mode, - xs::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) - -first(inference(icnf, mode, xs, ps, st)) -end - -@inline function loss( - icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, - mode::Mode, - xs::AbstractVector{<:Real}, - ys::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) - -first(inference(icnf, mode, xs, ys, ps, st)) -end - -@inline function loss( - icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, - mode::Mode, - xs::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) - -Statistics.mean(first(inference(icnf, mode, xs, ps, st))) -end - -@inline function loss( - icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, - mode::Mode, - xs::AbstractMatrix{<:Real}, - ys::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) - -Statistics.mean(first(inference(icnf, mode, xs, ys, ps, st))) -end - -@inline function make_ode_func( - icnf::AbstractICNF{T, CM, INPLACE}, - mode::Mode, - 
nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVecOrMat{T}, -) where {T <: AbstractFloat, CM, INPLACE} - function ode_func_op(u, p, t) - augmented_f(u, p, t, icnf, mode, nn, st, ϵ) - end - - function ode_func_ip(du, u, p, t) - augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ) - end - - ifelse(INPLACE, ode_func_ip, ode_func_op) -end - -@inline function (icnf::AbstractICNF{T, CM, INPLACE, false})( - xs::AbstractVecOrMat, - ps::Any, - st::NamedTuple, -) where {T, CM, INPLACE} - first(inference(icnf, TrainMode(), xs, ps, st)), st -end - -@inline function (icnf::AbstractICNF{T, CM, INPLACE, true})( - xs_ys::Tuple, - ps::Any, - st::NamedTuple, -) where {T, CM, INPLACE} - xs, ys = xs_ys - first(inference(icnf, TrainMode(), xs, ys, ps, st)), st -end +function construct( + aicnf::Type{<:AbstractICNF}, + nn::LuxCore.AbstractExplicitLayer, + nvars::Int, + naugmented::Int = 0; + data_type::Type{<:AbstractFloat} = Float32, + compute_mode::ComputeMode = ADVecJacVectorMode(AbstractDifferentiation.ZygoteBackend()), + inplace::Bool = false, + cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar}, + resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(), + basedist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + tspan::NTuple{2} = (zero(data_type), one(data_type)), + steer_rate::AbstractFloat = zero(data_type), + epsdist::Distributions.Distribution = Distributions.MvNormal( + FillArrays.Zeros{data_type}(nvars + naugmented), + FillArrays.Eye{data_type}(nvars + naugmented), + ), + sol_kwargs::NamedTuple = (save_everystep = false,), + rng::Random.AbstractRNG = rng_AT(resource), + λ₁::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} + convert(data_type, 1e-2) + else + zero(data_type) + end, + λ₂::AbstractFloat = if aicnf <: Union{RNODE, CondRNODE} + convert(data_type, 1e-2) + else + zero(data_type) + end, + λ₃::AbstractFloat = zero(data_type), +) + steerdist = Distributions.Uniform{data_type}(-steer_rate, steer_rate) + + ICNF{ + data_type, + typeof(compute_mode), + inplace, + cond, + !iszero(naugmented), + !iszero(steer_rate), + !iszero(λ₁), + !iszero(λ₂), + !iszero(λ₃), + typeof(nn), + typeof(nvars), + typeof(resource), + typeof(basedist), + typeof(tspan), + typeof(steerdist), + typeof(epsdist), + typeof(sol_kwargs), + typeof(rng), + }( + nn, + nvars, + naugmented, + compute_mode, + resource, + basedist, + tspan, + steerdist, + epsdist, + sol_kwargs, + rng, + λ₁, + λ₂, + λ₃, + ) +end + +@inline function n_augment(::AbstractICNF, ::Mode) + 0 +end + +function Base.show(io::IO, icnf::AbstractICNF) + print( + io, + typeof(icnf), + "<", + "Number of Variables: ", + icnf.nvars, + ", Number of Augmentations: ", + n_augment_input(icnf), + ", Time Span: ", + icnf.tspan, + ">", + ) +end + +@inline function n_augment_input( + icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, true}, +) where {INPLACE, COND} + icnf.naugmented +end + +@inline function n_augment_input(::AbstractICNF) + 0 +end + +@inline function steer_tspan( + icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true}, + ::TrainMode, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED} + t₀, t₁ = icnf.tspan + Δt = abs(t₁ - t₀) + r = convert(T, rand(icnf.rng, icnf.steerdist)) + t₁_new = muladd(Δt, r, t₁) + (t₀, t₁_new) +end + +@inline function steer_tspan(icnf::AbstractICNF, ::Mode) + icnf.tspan +end + +@inline function 
rng_AT(::ComputationalResources.AbstractResource) + Random.default_rng() +end + +@inline function base_AT( + ::ComputationalResources.AbstractResource, + ::AbstractICNF{T}, + dims..., +) where {T <: AbstractFloat} + Array{T}(undef, dims...) +end + +ChainRulesCore.@non_differentiable base_AT(::Any...) + +function base_sol( + icnf::AbstractICNF{T, <:ComputeMode, INPLACE}, + prob::SciMLBase.AbstractODEProblem{<:AbstractVecOrMat{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE} + sol = SciMLBase.solve(prob; icnf.sol_kwargs...) + get_fsol(sol) +end + +function inference_sol( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} + n_aug = n_augment(icnf, mode) + fsol = base_sol(icnf, prob) + z = fsol[begin:(end - n_aug - 1)] + Δlogp = fsol[(end - n_aug)] + augs = fsol[(end - n_aug + 1):end] + logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp + Ȧ = if (NORM_Z_AUG && AUGMENTED) + n_aug_input = n_augment_input(icnf) + z_aug = z[(end - n_aug_input + 1):end] + LinearAlgebra.norm(z_aug) + else + zero(T) + end + (logp̂x, vcat(augs, Ȧ)) +end + +function inference_sol( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} + n_aug = n_augment(icnf, mode) + fsol = base_sol(icnf, prob) + z = fsol[begin:(end - n_aug - 1), :] + Δlogp = fsol[(end - n_aug), :] + augs = fsol[(end - n_aug + 1):end, :] + logpz = oftype(Δlogp, Distributions.logpdf(icnf.basedist, z)) + logp̂x = logpz - Δlogp + Ȧ = transpose(if (NORM_Z_AUG && AUGMENTED) + n_aug_input = n_augment_input(icnf) + z_aug = z[(end - n_aug_input + 1):end, :] + LinearAlgebra.norm.(eachcol(z_aug)) + else + zrs_aug = similar(augs, size(augs, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs_aug, zero(T)) + zrs_aug + end) + (logp̂x, eachrow(vcat(augs, Ȧ))) +end + +function generate_sol( + icnf::AbstractICNF{T, <:VectorMode, INPLACE}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractVector{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + fsol = base_sol(icnf, prob) + fsol[begin:(end - n_aug_input - n_aug - 1)] +end + +function generate_sol( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE}, + mode::Mode, + prob::SciMLBase.AbstractODEProblem{<:AbstractMatrix{<:Real}, NTuple{2, T}, INPLACE}, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + fsol = base_sol(icnf, prob) + fsol[begin:(end - n_aug_input - n_aug - 1), :] +end + +@inline function get_fsol(sol::SciMLBase.AbstractODESolution) + last(sol.u) +end + +@inline function get_fsol(sol::AbstractArray{T, N}) where {T, N} + selectdim(sol, N, lastindex(sol, N)) +end + +function inference_prob( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, + mode::Mode, + xs::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + zrs = similar(xs, n_aug_input + n_aug + 1) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + 
Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = icnf.nn + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(xs, zrs), + steer_tspan(icnf, mode), + ps, + ) +end + +function inference_prob( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, true}, + mode::Mode, + xs::AbstractVector{<:Real}, + ys::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + zrs = similar(xs, n_aug_input + n_aug + 1) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = CondLayer(icnf.nn, ys) + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(xs, zrs), + steer_tspan(icnf, mode), + ps, + ) +end + +function inference_prob( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, false}, + mode::Mode, + xs::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = icnf.nn + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(xs, zrs), + steer_tspan(icnf, mode), + ps, + ) +end + +function inference_prob( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, true}, + mode::Mode, + xs::AbstractMatrix{<:Real}, + ys::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + zrs = similar(xs, n_aug_input + n_aug + 1, size(xs, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, size(xs, 2)) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = CondLayer(icnf.nn, ys) + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(xs, zrs), + steer_tspan(icnf, mode), + ps, + ) +end + +function generate_prob( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, false}, + mode::Mode, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + Random.rand!(icnf.rng, icnf.basedist, new_xs) + zrs = similar(new_xs, n_aug + 1) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = icnf.nn + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(new_xs, zrs), + reverse(steer_tspan(icnf, mode)), + ps, + ) +end + +function generate_prob( + icnf::AbstractICNF{T, <:VectorMode, INPLACE, true}, + mode::Mode, + ys::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + Random.rand!(icnf.rng, icnf.basedist, new_xs) + zrs = similar(new_xs, n_aug + 1) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input) + 
Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = CondLayer(icnf.nn, ys) + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(new_xs, zrs), + reverse(steer_tspan(icnf, mode)), + ps, + ) +end + +function generate_prob( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, false}, + mode::Mode, + ps::Any, + st::NamedTuple, + n::Int, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) + Random.rand!(icnf.rng, icnf.basedist, new_xs) + zrs = similar(new_xs, n_aug + 1, n) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = icnf.nn + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(new_xs, zrs), + reverse(steer_tspan(icnf, mode)), + ps, + ) +end + +function generate_prob( + icnf::AbstractICNF{T, <:MatrixMode, INPLACE, true}, + mode::Mode, + ys::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, + n::Int, +) where {T <: AbstractFloat, INPLACE} + n_aug = n_augment(icnf, mode) + n_aug_input = n_augment_input(icnf) + new_xs = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) + Random.rand!(icnf.rng, icnf.basedist, new_xs) + zrs = similar(new_xs, n_aug + 1, n) + ChainRulesCore.@ignore_derivatives fill!(zrs, zero(T)) + ϵ = base_AT(icnf.resource, icnf, icnf.nvars + n_aug_input, n) + Random.rand!(icnf.rng, icnf.epsdist, ϵ) + nn = CondLayer(icnf.nn, ys) + SciMLBase.ODEProblem{INPLACE, SciMLBase.FullSpecialize}( + make_ode_func(icnf, mode, nn, st, ϵ), + vcat(new_xs, zrs), + reverse(steer_tspan(icnf, mode)), + ps, + ) +end + +@inline function inference( + icnf::AbstractICNF, + mode::Mode, + xs::AbstractVecOrMat{<:Real}, + ps::Any, + st::NamedTuple, +) + inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ps, st)) +end + +@inline function inference( + icnf::AbstractICNF, + mode::Mode, + xs::AbstractVecOrMat{<:Real}, + ys::AbstractVecOrMat{<:Real}, + ps::Any, + st::NamedTuple, +) + inference_sol(icnf, mode, inference_prob(icnf, mode, xs, ys, ps, st)) +end + +@inline function generate( + icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, + mode::Mode, + ps::Any, + st::NamedTuple, +) + generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st)) +end + +@inline function generate( + icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, + mode::Mode, + ys::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) + generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st)) +end + +@inline function generate( + icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, + mode::Mode, + ps::Any, + st::NamedTuple, + n::Int, +) + generate_sol(icnf, mode, generate_prob(icnf, mode, ps, st, n)) +end + +@inline function generate( + icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, + mode::Mode, + ys::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, + n::Int, +) + generate_sol(icnf, mode, generate_prob(icnf, mode, ys, ps, st, n)) +end + +@inline function loss( + icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, + mode::Mode, + xs::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) + -first(inference(icnf, mode, xs, ps, st)) +end + +@inline function loss( + icnf::AbstractICNF{<:AbstractFloat, <:VectorMode}, + mode::Mode, + xs::AbstractVector{<:Real}, + ys::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) + -first(inference(icnf, mode, xs, ys, ps, st)) +end + 
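The `loss` methods here are thin negations of `inference`: vector mode negates a single log-likelihood, matrix mode negates the batch mean. A minimal sketch of that relationship, assuming `icnf`, `ps`, `st`, and the data matrix `r` from the README example:

```julia
import Statistics

# Per-sample log-densities; `first` drops the augmentation/regularization terms
# that `inference` also returns.
logp, _ = inference(icnf, TrainMode(), r, ps, st)
nll = -Statistics.mean(logp) # what the matrix-mode `loss` computes
# `inference` re-samples the trace-estimation noise ϵ internally, so repeated
# calls are stochastic; compare losses in expectation rather than exactly.
```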
+@inline function loss( + icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, + mode::Mode, + xs::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) + -Statistics.mean(first(inference(icnf, mode, xs, ps, st))) +end + +@inline function loss( + icnf::AbstractICNF{<:AbstractFloat, <:MatrixMode}, + mode::Mode, + xs::AbstractMatrix{<:Real}, + ys::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) + -Statistics.mean(first(inference(icnf, mode, xs, ys, ps, st))) +end + +@inline function make_ode_func( + icnf::AbstractICNF{T, CM, INPLACE}, + mode::Mode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVecOrMat{T}, +) where {T <: AbstractFloat, CM, INPLACE} + function ode_func_op(u, p, t) + augmented_f(u, p, t, icnf, mode, nn, st, ϵ) + end + + function ode_func_ip(du, u, p, t) + augmented_f(du, u, p, t, icnf, mode, nn, st, ϵ) + end + + ifelse(INPLACE, ode_func_ip, ode_func_op) +end + +@inline function (icnf::AbstractICNF{T, CM, INPLACE, false})( + xs::AbstractVecOrMat, + ps::Any, + st::NamedTuple, +) where {T, CM, INPLACE} + first(inference(icnf, TrainMode(), xs, ps, st)), st +end + +@inline function (icnf::AbstractICNF{T, CM, INPLACE, true})( + xs_ys::Tuple, + ps::Any, + st::NamedTuple, +) where {T, CM, INPLACE} + xs, ys = xs_ys + first(inference(icnf, TrainMode(), xs, ys, ps, st)), st +end diff --git a/src/exts/dist_ext/core.jl b/src/exts/dist_ext/core.jl index fa765382..b2e9d451 100644 --- a/src/exts/dist_ext/core.jl +++ b/src/exts/dist_ext/core.jl @@ -1,16 +1,16 @@ -export ICNFDist, CondICNFDist - -abstract type ICNFDistribution{AICNF <: AbstractICNF} <: - Distributions.ContinuousMultivariateDistribution end - -function Base.length(d::ICNFDistribution) - d.m.nvars -end - -function Base.eltype(::ICNFDistribution{AICNF}) where {AICNF <: AbstractICNF} - first(AICNF.parameters) -end - -function Base.broadcastable(d::ICNFDistribution) - Ref(d) -end +export ICNFDist, CondICNFDist + +abstract type ICNFDistribution{AICNF <: AbstractICNF} <: + Distributions.ContinuousMultivariateDistribution end + +function Base.length(d::ICNFDistribution) + d.m.nvars +end + +function Base.eltype(::ICNFDistribution{AICNF}) where {AICNF <: AbstractICNF} + first(AICNF.parameters) +end + +function Base.broadcastable(d::ICNFDistribution) + Ref(d) +end diff --git a/src/exts/dist_ext/core_cond_icnf.jl b/src/exts/dist_ext/core_cond_icnf.jl index 37fa0cbe..1c34b08b 100644 --- a/src/exts/dist_ext/core_cond_icnf.jl +++ b/src/exts/dist_ext/core_cond_icnf.jl @@ -1,61 +1,61 @@ -struct CondICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF - mode::Mode - ys::AbstractVecOrMat{<:Real} - ps::Any - st::NamedTuple -end - -function CondICNFDist( - mach::MLJBase.Machine{<:CondICNFModel}, - mode::Mode, - ys::AbstractVecOrMat{<:Real}, -) - (ps, st) = MLJModelInterface.fitted_params(mach) - CondICNFDist(mach.model.m, mode, ys, ps, st) -end - -function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(Distributions._logpdf(d, hcat(x))) - else - error("Not Implemented") - end -end -function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) - else - error("Not 
Implemented") - end -end -function Distributions._rand!( - rng::Random.AbstractRNG, - d::CondICNFDist, - x::AbstractVector{<:Real}, -) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - x .= generate(d.m, d.mode, d.ys, d.ps, d.st) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - x .= Distributions._rand!(rng, d, hcat(x)) - else - error("Not Implemented") - end -end -function Distributions._rand!( - rng::Random.AbstractRNG, - d::CondICNFDist, - A::AbstractMatrix{<:Real}, -) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) - else - error("Not Implemented") - end -end +struct CondICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} + m::AICNF + mode::Mode + ys::AbstractVecOrMat{<:Real} + ps::Any + st::NamedTuple +end + +function CondICNFDist( + mach::MLJBase.Machine{<:CondICNFModel}, + mode::Mode, + ys::AbstractVecOrMat{<:Real}, +) + (ps, st) = MLJModelInterface.fitted_params(mach) + CondICNFDist(mach.model.m, mode, ys, ps, st) +end + +function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + first(inference(d.m, d.mode, x, d.ys, d.ps, d.st)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(Distributions._logpdf(d, hcat(x))) + else + error("Not Implemented") + end +end +function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + Distributions._logpdf.(d, eachcol(A)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st)) + else + error("Not Implemented") + end +end +function Distributions._rand!( + rng::Random.AbstractRNG, + d::CondICNFDist, + x::AbstractVector{<:Real}, +) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + x .= generate(d.m, d.mode, d.ys, d.ps, d.st) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + x .= Distributions._rand!(rng, d, hcat(x)) + else + error("Not Implemented") + end +end +function Distributions._rand!( + rng::Random.AbstractRNG, + d::CondICNFDist, + A::AbstractMatrix{<:Real}, +) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) 
+ elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2)) + else + error("Not Implemented") + end +end diff --git a/src/exts/dist_ext/core_icnf.jl b/src/exts/dist_ext/core_icnf.jl index ea7448a0..47a8fd02 100644 --- a/src/exts/dist_ext/core_icnf.jl +++ b/src/exts/dist_ext/core_icnf.jl @@ -1,58 +1,58 @@ -struct ICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} - m::AICNF - mode::Mode - ps::Any - st::NamedTuple -end - -function ICNFDist(mach::MLJBase.Machine{<:ICNFModel}, mode::Mode) - (ps, st) = MLJModelInterface.fitted_params(mach) - ICNFDist(mach.model.m, mode, ps, st) -end - -function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - first(inference(d.m, d.mode, x, d.ps, d.st)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(Distributions._logpdf(d, hcat(x))) - else - error("Not Implemented") - end -end - -function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - Distributions._logpdf.(d, eachcol(A)) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - first(inference(d.m, d.mode, A, d.ps, d.st)) - else - error("Not Implemented") - end -end - -function Distributions._rand!( - rng::Random.AbstractRNG, - d::ICNFDist, - x::AbstractVector{<:Real}, -) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - x .= generate(d.m, d.mode, d.ps, d.st) - elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - x .= Distributions._rand!(rng, d, hcat(x)) - else - error("Not Implemented") - end -end -function Distributions._rand!( - rng::Random.AbstractRNG, - d::ICNFDist, - A::AbstractMatrix{<:Real}, -) - if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} - A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) 
- elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} - A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) - else - error("Not Implemented") - end -end +struct ICNFDist{AICNF <: AbstractICNF} <: ICNFDistribution{AICNF} + m::AICNF + mode::Mode + ps::Any + st::NamedTuple +end + +function ICNFDist(mach::MLJBase.Machine{<:ICNFModel}, mode::Mode) + (ps, st) = MLJModelInterface.fitted_params(mach) + ICNFDist(mach.model.m, mode, ps, st) +end + +function Distributions._logpdf(d::ICNFDist, x::AbstractVector{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + first(inference(d.m, d.mode, x, d.ps, d.st)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(Distributions._logpdf(d, hcat(x))) + else + error("Not Implemented") + end +end + +function Distributions._logpdf(d::ICNFDist, A::AbstractMatrix{<:Real}) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + Distributions._logpdf.(d, eachcol(A)) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + first(inference(d.m, d.mode, A, d.ps, d.st)) + else + error("Not Implemented") + end +end + +function Distributions._rand!( + rng::Random.AbstractRNG, + d::ICNFDist, + x::AbstractVector{<:Real}, +) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + x .= generate(d.m, d.mode, d.ps, d.st) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + x .= Distributions._rand!(rng, d, hcat(x)) + else + error("Not Implemented") + end +end +function Distributions._rand!( + rng::Random.AbstractRNG, + d::ICNFDist, + A::AbstractMatrix{<:Real}, +) + if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode} + A .= hcat(Distributions._rand!.(rng, d, eachcol(A))...) + elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode} + A .= generate(d.m, d.mode, d.ps, d.st, size(A, 2)) + else + error("Not Implemented") + end +end diff --git a/src/exts/mlj_ext/core.jl b/src/exts/mlj_ext/core.jl index bcc10edb..cdf45030 100644 --- a/src/exts/mlj_ext/core.jl +++ b/src/exts/mlj_ext/core.jl @@ -1,21 +1,21 @@ -function MLJModelInterface.fitted_params(::MLJICNF, fitresult) - (ps, st) = fitresult - (learned_parameters = ps, states = st) -end - -@inline function make_opt_loss( - icnf::AbstractICNF{T, CM, INPLACE, COND}, - mode::Mode, - st::NamedTuple, - loss_::Function, -) where {T, CM, INPLACE, COND} - function opt_loss_org(u, p, xs) - loss_(icnf, mode, xs, u, st) - end - - function opt_loss_cond(u, p, xs, ys) - loss_(icnf, mode, xs, ys, u, st) - end - - ifelse(COND, opt_loss_cond, opt_loss_org) -end +function MLJModelInterface.fitted_params(::MLJICNF, fitresult) + (ps, st) = fitresult + (learned_parameters = ps, states = st) +end + +@inline function make_opt_loss( + icnf::AbstractICNF{T, CM, INPLACE, COND}, + mode::Mode, + st::NamedTuple, + loss_::Function, +) where {T, CM, INPLACE, COND} + function opt_loss_org(u, p, xs) + loss_(icnf, mode, xs, u, st) + end + + function opt_loss_cond(u, p, xs, ys) + loss_(icnf, mode, xs, ys, u, st) + end + + ifelse(COND, opt_loss_cond, opt_loss_org) +end diff --git a/src/exts/mlj_ext/core_cond_icnf.jl b/src/exts/mlj_ext/core_cond_icnf.jl index 7220cb9e..86dfe4f8 100644 --- a/src/exts/mlj_ext/core_cond_icnf.jl +++ b/src/exts/mlj_ext/core_cond_icnf.jl @@ -1,157 +1,157 @@ -mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF - loss::Function - - optimizers::Tuple - n_epochs::Int - adtype::ADTypes.AbstractADType - - use_batch::Bool - batch_size::Int -end - -function CondICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = 
(Optimisers.Lion(),), - n_epochs::Int = 300, - adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - use_batch::Bool = true, - batch_size::Int = 32, -) - CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) -end - -function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) - X, Y = XY - x = collect(transpose(MLJModelInterface.matrix(X))) - y = collect(transpose(MLJModelInterface.matrix(Y))) - ps, st = LuxCore.setup(model.m.rng, model.m) - ps = ComponentArrays.ComponentArray(ps) - if model.m.resource isa ComputationalResources.CUDALibs - gdev = Lux.gpu_device() - x = gdev(x) - y = gdev(y) - ps = gdev(ps) - st = gdev(st) - end - optfunc = SciMLBase.OptimizationFunction( - make_opt_loss(model.m, TrainMode(), st, model.loss), - model.adtype, - ) - optprob = SciMLBase.OptimizationProblem(optfunc, ps) - tst_overall = @timed for opt in model.optimizers - tst_epochs = @timed for ep in 1:(model.n_epochs) - if model.use_batch - if model.m.compute_mode isa VectorMode - data = MLUtils.DataLoader( - (x, y); - batchsize = -1, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - elseif model.m.compute_mode isa MatrixMode - data = MLUtils.DataLoader( - (x, y); - batchsize = model.batch_size, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - else - error("Not Implemented") - end - else - data = [(x, y)] - end - optprob_re = SciMLBase.remake(optprob; u0 = ps) - tst_one = - @timed res = SciMLBase.solve(optprob_re, opt, data; progress = true) - ps .= res.u - @info( - "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_one.time, - "garbage collection time (seconds)" = tst_one.gctime, - "allocated (bytes)" = tst_one.bytes, - "final loss value" = res.objective, - ) - end - @info( - "Fitting (all epochs) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_epochs.time, - "garbage collection time (seconds)" = tst_epochs.gctime, - "allocated (bytes)" = tst_epochs.bytes, - ) - end - @info( - "Fitting - Overall", - "elapsed time (seconds)" = tst_overall.time, - "garbage collection time (seconds)" = tst_overall.gctime, - "allocated (bytes)" = tst_overall.bytes, - ) - - fitresult = (ps, st) - cache = nothing - report = (stats = tst_overall,) - (fitresult, cache, report) -end - -function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) - Xnew, Ynew = XYnew - xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) - if model.m.resource isa ComputationalResources.CUDALibs - gdev = Lux.gpu_device() - xnew = gdev(xnew) - ynew = gdev(ynew) - end - (ps, st) = fitresult - - tst = @timed if model.m.compute_mode isa VectorMode - logp̂x = broadcast( - (x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)), - eachcol(xnew), - eachcol(ynew), - ) - elseif model.m.compute_mode isa MatrixMode - logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st)) - else - error("Not Implemented") - end - @info( - "Transforming", - "elapsed time (seconds)" = tst.time, - "garbage collection time (seconds)" = tst.gctime, - "allocated (bytes)" = tst.bytes, - ) - - DataFrames.DataFrame(; px = exp.(logp̂x)) -end - -MLJBase.metadata_pkg( - CondICNFModel; - package_name = "ContinuousNormalizingFlows", - package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", - package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", - is_pure_julia = true, - package_license = "MIT", - is_wrapper = false, -) 
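# Editor's sketch (hedged): `CondICNFModel` follows the same MLJ workflow as
# `ICNFModel`, except that `fit` and `transform` take a tuple of two tables —
# the samples `X` and their row-aligned conditioning values `Y` — matching the
# declared `input_scitype = Tuple{Table, Table}` below. A minimal sketch,
# assuming a conditional flow `icnf` (e.g. a `CondRNODE`) is already built:
#
#     using DataFrames, MLJBase
#     df_x = DataFrame(transpose(xs), :auto)  # data, one row per sample
#     df_y = DataFrame(transpose(ys), :auto)  # conditions, same row order
#     model = CondICNFModel(icnf; batch_size = 256)
#     mach = machine(model, (df_x, df_y))
#     fit!(mach)
#     px = transform(mach, (df_x, df_y))      # DataFrame with density column `px`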
-MLJBase.metadata_model( - CondICNFModel; - input_scitype = Tuple{ - ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}, - ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}, - }, - target_scitype = ScientificTypesBase.Table{ - AbstractVector{ScientificTypesBase.Continuous}, - }, - output_scitype = ScientificTypesBase.Table{ - AbstractVector{ScientificTypesBase.Continuous}, - }, - supports_weights = false, - load_path = "ContinuousNormalizingFlows.CondICNFModel", -) +mutable struct CondICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} + m::AICNF + loss::Function + + optimizers::Tuple + n_epochs::Int + adtype::ADTypes.AbstractADType + + use_batch::Bool + batch_size::Int +end + +function CondICNFModel( + m::AbstractICNF, + loss::Function = loss; + optimizers::Tuple = (Optimisers.Lion(),), + n_epochs::Int = 300, + adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), + use_batch::Bool = true, + batch_size::Int = 32, +) + CondICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) +end + +function MLJModelInterface.fit(model::CondICNFModel, verbosity, XY) + X, Y = XY + x = collect(transpose(MLJModelInterface.matrix(X))) + y = collect(transpose(MLJModelInterface.matrix(Y))) + ps, st = LuxCore.setup(model.m.rng, model.m) + ps = ComponentArrays.ComponentArray(ps) + if model.m.resource isa ComputationalResources.CUDALibs + gdev = Lux.gpu_device() + x = gdev(x) + y = gdev(y) + ps = gdev(ps) + st = gdev(st) + end + optfunc = SciMLBase.OptimizationFunction( + make_opt_loss(model.m, TrainMode(), st, model.loss), + model.adtype, + ) + optprob = SciMLBase.OptimizationProblem(optfunc, ps) + tst_overall = @timed for opt in model.optimizers + tst_epochs = @timed for ep in 1:(model.n_epochs) + if model.use_batch + if model.m.compute_mode isa VectorMode + data = MLUtils.DataLoader( + (x, y); + batchsize = -1, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + elseif model.m.compute_mode isa MatrixMode + data = MLUtils.DataLoader( + (x, y); + batchsize = model.batch_size, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + else + error("Not Implemented") + end + else + data = [(x, y)] + end + optprob_re = SciMLBase.remake(optprob; u0 = ps) + tst_one = + @timed res = SciMLBase.solve(optprob_re, opt, data; progress = true) + ps .= res.u + @info( + "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_one.time, + "garbage collection time (seconds)" = tst_one.gctime, + "allocated (bytes)" = tst_one.bytes, + "final loss value" = res.objective, + ) + end + @info( + "Fitting (all epochs) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_epochs.time, + "garbage collection time (seconds)" = tst_epochs.gctime, + "allocated (bytes)" = tst_epochs.bytes, + ) + end + @info( + "Fitting - Overall", + "elapsed time (seconds)" = tst_overall.time, + "garbage collection time (seconds)" = tst_overall.gctime, + "allocated (bytes)" = tst_overall.bytes, + ) + + fitresult = (ps, st) + cache = nothing + report = (stats = tst_overall,) + (fitresult, cache, report) +end + +function MLJModelInterface.transform(model::CondICNFModel, fitresult, XYnew) + Xnew, Ynew = XYnew + xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) + ynew = collect(transpose(MLJModelInterface.matrix(Ynew))) + if model.m.resource isa ComputationalResources.CUDALibs + gdev = Lux.gpu_device() + xnew = gdev(xnew) + ynew = gdev(ynew) + end + (ps, st) = fitresult + + tst = @timed if 
model.m.compute_mode isa VectorMode + logp̂x = broadcast( + (x, y) -> first(inference(model.m, TestMode(), x, y, ps, st)), + eachcol(xnew), + eachcol(ynew), + ) + elseif model.m.compute_mode isa MatrixMode + logp̂x = first(inference(model.m, TestMode(), xnew, ynew, ps, st)) + else + error("Not Implemented") + end + @info( + "Transforming", + "elapsed time (seconds)" = tst.time, + "garbage collection time (seconds)" = tst.gctime, + "allocated (bytes)" = tst.bytes, + ) + + DataFrames.DataFrame(; px = exp.(logp̂x)) +end + +MLJBase.metadata_pkg( + CondICNFModel; + package_name = "ContinuousNormalizingFlows", + package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", + package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", + is_pure_julia = true, + package_license = "MIT", + is_wrapper = false, +) +MLJBase.metadata_model( + CondICNFModel; + input_scitype = Tuple{ + ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}, + ScientificTypesBase.Table{AbstractVector{ScientificTypesBase.Continuous}}, + }, + target_scitype = ScientificTypesBase.Table{ + AbstractVector{ScientificTypesBase.Continuous}, + }, + output_scitype = ScientificTypesBase.Table{ + AbstractVector{ScientificTypesBase.Continuous}, + }, + supports_weights = false, + load_path = "ContinuousNormalizingFlows.CondICNFModel", +) diff --git a/src/exts/mlj_ext/core_icnf.jl b/src/exts/mlj_ext/core_icnf.jl index 510f4cac..218282e0 100644 --- a/src/exts/mlj_ext/core_icnf.jl +++ b/src/exts/mlj_ext/core_icnf.jl @@ -1,148 +1,148 @@ -mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} - m::AICNF - loss::Function - - optimizers::Tuple - n_epochs::Int - adtype::ADTypes.AbstractADType - - use_batch::Bool - batch_size::Int -end - -function ICNFModel( - m::AbstractICNF, - loss::Function = loss; - optimizers::Tuple = (Optimisers.Lion(),), - n_epochs::Int = 300, - adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), - use_batch::Bool = true, - batch_size::Int = 32, -) - ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) -end - -function MLJModelInterface.fit(model::ICNFModel, verbosity, X) - x = collect(transpose(MLJModelInterface.matrix(X))) - ps, st = LuxCore.setup(model.m.rng, model.m) - ps = ComponentArrays.ComponentArray(ps) - if model.m.resource isa ComputationalResources.CUDALibs - gdev = Lux.gpu_device() - x = gdev(x) - ps = gdev(ps) - st = gdev(st) - end - optfunc = SciMLBase.OptimizationFunction( - make_opt_loss(model.m, TrainMode(), st, model.loss), - model.adtype, - ) - optprob = SciMLBase.OptimizationProblem(optfunc, ps) - - tst_overall = @timed for opt in model.optimizers - tst_epochs = @timed for ep in 1:(model.n_epochs) - if model.use_batch - if model.m.compute_mode isa VectorMode - data = MLUtils.DataLoader( - (x,); - batchsize = -1, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - elseif model.m.compute_mode isa MatrixMode - data = MLUtils.DataLoader( - (x,); - batchsize = model.batch_size, - shuffle = true, - partial = true, - parallel = false, - buffer = false, - ) - else - error("Not Implemented") - end - else - data = [(x,)] - end - optprob_re = SciMLBase.remake(optprob; u0 = ps) - tst_one = - @timed res = SciMLBase.solve(optprob_re, opt, data; progress = true) - ps .= res.u - @info( - "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_one.time, - "garbage collection time (seconds)" = tst_one.gctime, - "allocated (bytes)" = tst_one.bytes, - "final loss value" = res.objective, 
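# (Per-epoch loop: `remake(optprob; u0 = ps)` restarts the optimizer from the
# current parameters, `solve` runs one pass over the freshly shuffled
# DataLoader, and wall time, GC time, allocations, and the final loss are
# logged per epoch, per optimizer, and for the overall fit.)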
- ) - end - @info( - "Fitting (all epochs) - $(typeof(opt).name.name)", - "elapsed time (seconds)" = tst_epochs.time, - "garbage collection time (seconds)" = tst_epochs.gctime, - "allocated (bytes)" = tst_epochs.bytes, - ) - end - @info( - "Fitting - Overall", - "elapsed time (seconds)" = tst_overall.time, - "garbage collection time (seconds)" = tst_overall.gctime, - "allocated (bytes)" = tst_overall.bytes, - ) - - fitresult = (ps, st) - cache = nothing - report = (stats = tst_overall,) - (fitresult, cache, report) -end - -function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) - xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) - if model.m.resource isa ComputationalResources.CUDALibs - gdev = Lux.gpu_device() - xnew = gdev(xnew) - end - (ps, st) = fitresult - - tst = @timed if model.m.compute_mode isa VectorMode - logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew)) - elseif model.m.compute_mode isa MatrixMode - logp̂x = first(inference(model.m, TestMode(), xnew, ps, st)) - else - error("Not Implemented") - end - - @info( - "Transforming", - "elapsed time (seconds)" = tst.time, - "garbage collection time (seconds)" = tst.gctime, - "allocated (bytes)" = tst.bytes, - ) - - DataFrames.DataFrame(; px = exp.(logp̂x)) -end - -MLJBase.metadata_pkg( - ICNFModel; - package_name = "ContinuousNormalizingFlows", - package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", - package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", - is_pure_julia = true, - package_license = "MIT", - is_wrapper = false, -) -MLJBase.metadata_model( - ICNFModel; - input_scitype = ScientificTypesBase.Table{ - AbstractVector{ScientificTypesBase.Continuous}, - }, - target_scitype = ScientificTypesBase.Table{ - AbstractVector{ScientificTypesBase.Continuous}, - }, - output_scitype = ScientificTypesBase.Table{ - AbstractVector{ScientificTypesBase.Continuous}, - }, - supports_weights = false, - load_path = "ContinuousNormalizingFlows.ICNFModel", -) +mutable struct ICNFModel{AICNF <: AbstractICNF} <: MLJICNF{AICNF} + m::AICNF + loss::Function + + optimizers::Tuple + n_epochs::Int + adtype::ADTypes.AbstractADType + + use_batch::Bool + batch_size::Int +end + +function ICNFModel( + m::AbstractICNF, + loss::Function = loss; + optimizers::Tuple = (Optimisers.Lion(),), + n_epochs::Int = 300, + adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(), + use_batch::Bool = true, + batch_size::Int = 32, +) + ICNFModel(m, loss, optimizers, n_epochs, adtype, use_batch, batch_size) +end + +function MLJModelInterface.fit(model::ICNFModel, verbosity, X) + x = collect(transpose(MLJModelInterface.matrix(X))) + ps, st = LuxCore.setup(model.m.rng, model.m) + ps = ComponentArrays.ComponentArray(ps) + if model.m.resource isa ComputationalResources.CUDALibs + gdev = Lux.gpu_device() + x = gdev(x) + ps = gdev(ps) + st = gdev(st) + end + optfunc = SciMLBase.OptimizationFunction( + make_opt_loss(model.m, TrainMode(), st, model.loss), + model.adtype, + ) + optprob = SciMLBase.OptimizationProblem(optfunc, ps) + + tst_overall = @timed for opt in model.optimizers + tst_epochs = @timed for ep in 1:(model.n_epochs) + if model.use_batch + if model.m.compute_mode isa VectorMode + data = MLUtils.DataLoader( + (x,); + batchsize = -1, + shuffle = true, + partial = true, + parallel = false, + buffer = false, + ) + elseif model.m.compute_mode isa MatrixMode + data = MLUtils.DataLoader( + (x,); + batchsize = model.batch_size, + shuffle = true, + partial = true, + parallel = false, + buffer = 
false, + ) + else + error("Not Implemented") + end + else + data = [(x,)] + end + optprob_re = SciMLBase.remake(optprob; u0 = ps) + tst_one = + @timed res = SciMLBase.solve(optprob_re, opt, data; progress = true) + ps .= res.u + @info( + "Fitting (epoch: $ep of $(model.n_epochs)) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_one.time, + "garbage collection time (seconds)" = tst_one.gctime, + "allocated (bytes)" = tst_one.bytes, + "final loss value" = res.objective, + ) + end + @info( + "Fitting (all epochs) - $(typeof(opt).name.name)", + "elapsed time (seconds)" = tst_epochs.time, + "garbage collection time (seconds)" = tst_epochs.gctime, + "allocated (bytes)" = tst_epochs.bytes, + ) + end + @info( + "Fitting - Overall", + "elapsed time (seconds)" = tst_overall.time, + "garbage collection time (seconds)" = tst_overall.gctime, + "allocated (bytes)" = tst_overall.bytes, + ) + + fitresult = (ps, st) + cache = nothing + report = (stats = tst_overall,) + (fitresult, cache, report) +end + +function MLJModelInterface.transform(model::ICNFModel, fitresult, Xnew) + xnew = collect(transpose(MLJModelInterface.matrix(Xnew))) + if model.m.resource isa ComputationalResources.CUDALibs + gdev = Lux.gpu_device() + xnew = gdev(xnew) + end + (ps, st) = fitresult + + tst = @timed if model.m.compute_mode isa VectorMode + logp̂x = broadcast(x -> first(inference(model.m, TestMode(), x, ps, st)), eachcol(xnew)) + elseif model.m.compute_mode isa MatrixMode + logp̂x = first(inference(model.m, TestMode(), xnew, ps, st)) + else + error("Not Implemented") + end + + @info( + "Transforming", + "elapsed time (seconds)" = tst.time, + "garbage collection time (seconds)" = tst.gctime, + "allocated (bytes)" = tst.bytes, + ) + + DataFrames.DataFrame(; px = exp.(logp̂x)) +end + +MLJBase.metadata_pkg( + ICNFModel; + package_name = "ContinuousNormalizingFlows", + package_uuid = "00b1973d-5b2e-40bf-8604-5c9c1d8f50ac", + package_url = "https://github.com/impICNF/ContinuousNormalizingFlows.jl", + is_pure_julia = true, + package_license = "MIT", + is_wrapper = false, +) +MLJBase.metadata_model( + ICNFModel; + input_scitype = ScientificTypesBase.Table{ + AbstractVector{ScientificTypesBase.Continuous}, + }, + target_scitype = ScientificTypesBase.Table{ + AbstractVector{ScientificTypesBase.Continuous}, + }, + output_scitype = ScientificTypesBase.Table{ + AbstractVector{ScientificTypesBase.Continuous}, + }, + supports_weights = false, + load_path = "ContinuousNormalizingFlows.ICNFModel", +) diff --git a/src/icnf.jl b/src/icnf.jl index 47ae82c6..1b03c2b7 100644 --- a/src/icnf.jl +++ b/src/icnf.jl @@ -1,654 +1,654 @@ -struct Planar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondPlanar{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct FFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end -struct CondFFJORD{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -struct RNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, 
AUGMENTED, STEER, NORM_Z_AUG} end -struct CondRNODE{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end - -""" -Implementation of ICNF. - -Refs: - -[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) - -[Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367) - -[Finlay, Chris, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam M. Oberman. "How to train your neural ODE: the world of Jacobian and kinetic regularization." arXiv preprint arXiv:2002.02798 (2020).](https://arxiv.org/abs/2002.02798) -""" -struct ICNF{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z, - NORM_J, - NORM_Z_AUG, - NN <: LuxCore.AbstractExplicitLayer, - NVARS <: Int, - RESOURCE <: ComputationalResources.AbstractResource, - BASEDIST <: Distributions.Distribution, - TSPAN <: NTuple{2, T}, - STEERDIST <: Distributions.Distribution, - EPSDIST <: Distributions.Distribution, - SOL_KWARGS <: NamedTuple, - RNG <: Random.AbstractRNG, -} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} - nn::NN - nvars::NVARS - naugmented::NVARS - - compute_mode::CM - resource::RESOURCE - basedist::BASEDIST - tspan::TSPAN - steerdist::STEERDIST - epsdist::EPSDIST - sol_kwargs::SOL_KWARGS - rng::RNG - λ₁::T - λ₂::T - λ₃::T -end - -@inline function n_augment(::ICNF, ::TrainMode) - 2 -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADVectorMode, false}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, J = AbstractDifferentiation.value_and_jacobian(icnf.compute_mode.adback, snn, z) - l̇ = -LinearAlgebra.tr(only(J)) - vcat(ż, l̇) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADVectorMode, true}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, J = AbstractDifferentiation.value_and_jacobian(icnf.compute_mode.adback, snn, z) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.tr(only(J)) - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVectorMode, false}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z) - l̇ = -LinearAlgebra.tr(J) - vcat(ż, l̇) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVectorMode, true}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - 
n_aug - 1)] - ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.tr(J) - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:MatrixMode, false}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, J = jacobian_batched(icnf, snn, z) - l̇ = -transpose(LinearAlgebra.tr.(J)) - vcat(ż, l̇) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:MatrixMode, true}, - mode::TestMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, J = jacobian_batched(icnf, snn, z) - du[begin:(end - n_aug - 1), :] .= ż - du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J)) - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, VJ = AbstractDifferentiation.value_and_pullback_function( - icnf.compute_mode.adback, - snn, - z, - ) - ϵJ = only(VJ(ϵ)) - l̇ = -LinearAlgebra.dot(ϵJ, ϵ) - Ė = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - ṅ = if NORM_J - LinearAlgebra.norm(ϵJ) - else - zero(T) - end - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, VJ = AbstractDifferentiation.value_and_pullback_function( - icnf.compute_mode.adback, - snn, - z, - ) - ϵJ = only(VJ(ϵ)) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ) - du[(end - n_aug + 1)] = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - du[(end - n_aug + 2)] = if NORM_J - LinearAlgebra.norm(ϵJ) - else - zero(T) - end - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż_JV = AbstractDifferentiation.value_and_pushforward_function( - icnf.compute_mode.adback, - snn, - z, - ) - ż, Jϵ = ż_JV(ϵ) - Jϵ = only(Jϵ) - l̇ = -LinearAlgebra.dot(ϵ, Jϵ) - Ė = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - ṅ = if NORM_J - LinearAlgebra.norm(Jϵ) - else - zero(T) - end - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:ADJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, 
NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż_JV = AbstractDifferentiation.value_and_pushforward_function( - icnf.compute_mode.adback, - snn, - z, - ) - ż, Jϵ = ż_JV(ϵ) - Jϵ = only(Jϵ) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ) - du[(end - n_aug + 1)] = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - du[(end - n_aug + 2)] = if NORM_J - LinearAlgebra.norm(Jϵ) - else - zero(T) - end - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) - l̇ = -LinearAlgebra.dot(ϵJ, ϵ) - Ė = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - ṅ = if NORM_J - LinearAlgebra.norm(ϵJ) - else - zero(T) - end - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ) - du[(end - n_aug + 1)] = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - du[(end - n_aug + 2)] = if NORM_J - LinearAlgebra.norm(ϵJ) - else - zero(T) - end - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) - l̇ = -LinearAlgebra.dot(ϵ, Jϵ) - Ė = if NORM_Z - LinearAlgebra.norm(ż) - else - zero(T) - end - ṅ = if NORM_J - LinearAlgebra.norm(Jϵ) - else - zero(T) - end - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractVector{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1)] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) - du[begin:(end - n_aug - 1)] .= ż - du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ) - du[(end - n_aug + 1)] = if NORM_Z - LinearAlgebra.norm(ż) - else 
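# (Augmented-state layout in TrainMode: [z; logdensity; Ė; ṅ]. The term
# -dot(ϵ, Jϵ) is a Hutchinson-style stochastic estimate of -tr(J); when a
# regularizer is disabled via the compile-time flags NORM_Z / NORM_J, its
# slot is simply held at zero so the ODE state keeps a fixed size.)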
- zero(T) - end - du[(end - n_aug + 2)] = if NORM_J - LinearAlgebra.norm(Jϵ) - else - zero(T) - end - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) - l̇ = -sum(ϵJ .* ϵ; dims = 1) - Ė = transpose(if NORM_Z - LinearAlgebra.norm.(eachcol(ż)) - else - zrs_Ė = similar(ż, size(ż, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T)) - zrs_Ė - end) - ṅ = transpose(if NORM_J - LinearAlgebra.norm.(eachcol(ϵJ)) - else - zrs_ṅ = similar(ż, size(ż, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T)) - zrs_ṅ - end) - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, ϵJ = - DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) - du[begin:(end - n_aug - 1), :] .= ż - du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1)) - du[(end - n_aug + 1), :] .= if NORM_Z - LinearAlgebra.norm.(eachcol(ż)) - else - zero(T) - end - du[(end - n_aug + 2), :] .= if NORM_J - LinearAlgebra.norm.(eachcol(ϵJ)) - else - zero(T) - end - nothing -end - -function augmented_f( - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) - l̇ = -sum(ϵ .* Jϵ; dims = 1) - Ė = transpose(if NORM_Z - LinearAlgebra.norm.(eachcol(ż)) - else - zrs_Ė = similar(ż, size(ż, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T)) - zrs_Ė - end) - ṅ = transpose(if NORM_J - LinearAlgebra.norm.(eachcol(Jϵ)) - else - zrs_ṅ = similar(ż, size(ż, 2)) - ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T)) - zrs_ṅ - end) - vcat(ż, l̇, Ė, ṅ) -end - -function augmented_f( - du::Any, - u::Any, - p::Any, - ::Any, - icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, - mode::TrainMode, - nn::LuxCore.AbstractExplicitLayer, - st::NamedTuple, - ϵ::AbstractMatrix{T}, -) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} - n_aug = n_augment(icnf, mode) - snn = Lux.StatefulLuxLayer{true}(nn, p, st) - z = u[begin:(end - n_aug - 1), :] - ż, Jϵ = - DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) - du[begin:(end - n_aug - 1), :] .= ż - du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1)) - du[(end - n_aug + 1), :] .= if NORM_Z - LinearAlgebra.norm.(eachcol(ż)) - else - zero(T) - end - du[(end - n_aug + 2), :] .= if NORM_J - 
LinearAlgebra.norm.(eachcol(Jϵ)) - else - zero(T) - end - nothing -end - -@inline function loss( - icnf::ICNF{<:AbstractFloat, <:VectorMode}, - mode::TrainMode, - xs::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) - -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ -end - -@inline function loss( - icnf::ICNF{<:AbstractFloat, <:VectorMode}, - mode::TrainMode, - xs::AbstractVector{<:Real}, - ys::AbstractVector{<:Real}, - ps::Any, - st::NamedTuple, -) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) - -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ -end - -@inline function loss( - icnf::ICNF{<:AbstractFloat, <:MatrixMode}, - mode::TrainMode, - xs::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) - Statistics.mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) -end - -@inline function loss( - icnf::ICNF{<:AbstractFloat, <:MatrixMode}, - mode::TrainMode, - xs::AbstractMatrix{<:Real}, - ys::AbstractMatrix{<:Real}, - ps::Any, - st::NamedTuple, -) - logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) - Statistics.mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) -end +struct Planar{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +struct CondPlanar{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end + +struct FFJORD{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +struct CondFFJORD{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end + +struct RNODE{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end +struct CondRNODE{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} end + +""" +Implementation of ICNF. + +Refs: + +[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) + +[Grathwohl, Will, Ricky TQ Chen, Jesse Bettencourt, Ilya Sutskever, and David Duvenaud. "Ffjord: Free-form continuous dynamics for scalable reversible generative models." arXiv preprint arXiv:1810.01367 (2018).](https://arxiv.org/abs/1810.01367) + +[Finlay, Chris, Jörn-Henrik Jacobsen, Levon Nurbekyan, and Adam M. Oberman. "How to train your neural ODE: the world of Jacobian and kinetic regularization." 
arXiv preprint arXiv:2002.02798 (2020).](https://arxiv.org/abs/2002.02798) +""" +struct ICNF{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z, + NORM_J, + NORM_Z_AUG, + NN <: LuxCore.AbstractExplicitLayer, + NVARS <: Int, + RESOURCE <: ComputationalResources.AbstractResource, + BASEDIST <: Distributions.Distribution, + TSPAN <: NTuple{2, T}, + STEERDIST <: Distributions.Distribution, + EPSDIST <: Distributions.Distribution, + SOL_KWARGS <: NamedTuple, + RNG <: Random.AbstractRNG, +} <: AbstractICNF{T, CM, INPLACE, COND, AUGMENTED, STEER, NORM_Z_AUG} + nn::NN + nvars::NVARS + naugmented::NVARS + + compute_mode::CM + resource::RESOURCE + basedist::BASEDIST + tspan::TSPAN + steerdist::STEERDIST + epsdist::EPSDIST + sol_kwargs::SOL_KWARGS + rng::RNG + λ₁::T + λ₂::T + λ₃::T +end + +@inline function n_augment(::ICNF, ::TrainMode) + 2 +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADVectorMode, false}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, J = AbstractDifferentiation.value_and_jacobian(icnf.compute_mode.adback, snn, z) + l̇ = -LinearAlgebra.tr(only(J)) + vcat(ż, l̇) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADVectorMode, true}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, J = AbstractDifferentiation.value_and_jacobian(icnf.compute_mode.adback, snn, z) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.tr(only(J)) + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVectorMode, false}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z) + l̇ = -LinearAlgebra.tr(J) + vcat(ż, l̇) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVectorMode, true}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, J = DifferentiationInterface.value_and_jacobian(snn, icnf.compute_mode.adback, z) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.tr(J) + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:MatrixMode, false}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, J = jacobian_batched(icnf, snn, z) + l̇ = -transpose(LinearAlgebra.tr.(J)) + vcat(ż, l̇) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:MatrixMode, true}, + mode::TestMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat} + n_aug = 
n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, J = jacobian_batched(icnf, snn, z) + du[begin:(end - n_aug - 1), :] .= ż + du[(end - n_aug), :] .= -(LinearAlgebra.tr.(J)) + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, VJ = AbstractDifferentiation.value_and_pullback_function( + icnf.compute_mode.adback, + snn, + z, + ) + ϵJ = only(VJ(ϵ)) + l̇ = -LinearAlgebra.dot(ϵJ, ϵ) + Ė = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + ṅ = if NORM_J + LinearAlgebra.norm(ϵJ) + else + zero(T) + end + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, VJ = AbstractDifferentiation.value_and_pullback_function( + icnf.compute_mode.adback, + snn, + z, + ) + ϵJ = only(VJ(ϵ)) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ) + du[(end - n_aug + 1)] = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + du[(end - n_aug + 2)] = if NORM_J + LinearAlgebra.norm(ϵJ) + else + zero(T) + end + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż_JV = AbstractDifferentiation.value_and_pushforward_function( + icnf.compute_mode.adback, + snn, + z, + ) + ż, Jϵ = ż_JV(ϵ) + Jϵ = only(Jϵ) + l̇ = -LinearAlgebra.dot(ϵ, Jϵ) + Ė = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + ṅ = if NORM_J + LinearAlgebra.norm(Jϵ) + else + zero(T) + end + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:ADJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż_JV = AbstractDifferentiation.value_and_pushforward_function( + icnf.compute_mode.adback, + snn, + z, + ) + ż, Jϵ = ż_JV(ϵ) + Jϵ = only(Jϵ) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ) + du[(end - n_aug + 1)] = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + du[(end - n_aug + 2)] = if NORM_J + LinearAlgebra.norm(Jϵ) + else + zero(T) + end + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVecJacVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + 
nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, ϵJ = + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + l̇ = -LinearAlgebra.dot(ϵJ, ϵ) + Ė = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + ṅ = if NORM_J + LinearAlgebra.norm(ϵJ) + else + zero(T) + end + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVecJacVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, ϵJ = + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.dot(ϵJ, ϵ) + du[(end - n_aug + 1)] = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + du[(end - n_aug + 2)] = if NORM_J + LinearAlgebra.norm(ϵJ) + else + zero(T) + end + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIJacVecVectorMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, Jϵ = + DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + l̇ = -LinearAlgebra.dot(ϵ, Jϵ) + Ė = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + ṅ = if NORM_J + LinearAlgebra.norm(Jϵ) + else + zero(T) + end + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIJacVecVectorMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractVector{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1)] + ż, Jϵ = + DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + du[begin:(end - n_aug - 1)] .= ż + du[(end - n_aug)] = -LinearAlgebra.dot(ϵ, Jϵ) + du[(end - n_aug + 1)] = if NORM_Z + LinearAlgebra.norm(ż) + else + zero(T) + end + du[(end - n_aug + 2)] = if NORM_J + LinearAlgebra.norm(Jϵ) + else + zero(T) + end + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVecJacMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, ϵJ = + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + l̇ = -sum(ϵJ .* ϵ; dims = 1) + Ė = transpose(if NORM_Z + LinearAlgebra.norm.(eachcol(ż)) + else + zrs_Ė = similar(ż, size(ż, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T)) + zrs_Ė + end) + ṅ = transpose(if 
NORM_J + LinearAlgebra.norm.(eachcol(ϵJ)) + else + zrs_ṅ = similar(ż, size(ż, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T)) + zrs_ṅ + end) + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIVecJacMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, ϵJ = + DifferentiationInterface.value_and_pullback(snn, icnf.compute_mode.adback, z, ϵ) + du[begin:(end - n_aug - 1), :] .= ż + du[(end - n_aug), :] .= -vec(sum(ϵJ .* ϵ; dims = 1)) + du[(end - n_aug + 1), :] .= if NORM_Z + LinearAlgebra.norm.(eachcol(ż)) + else + zero(T) + end + du[(end - n_aug + 2), :] .= if NORM_J + LinearAlgebra.norm.(eachcol(ϵJ)) + else + zero(T) + end + nothing +end + +function augmented_f( + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIJacVecMatrixMode, false, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, Jϵ = + DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + l̇ = -sum(ϵ .* Jϵ; dims = 1) + Ė = transpose(if NORM_Z + LinearAlgebra.norm.(eachcol(ż)) + else + zrs_Ė = similar(ż, size(ż, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs_Ė, zero(T)) + zrs_Ė + end) + ṅ = transpose(if NORM_J + LinearAlgebra.norm.(eachcol(Jϵ)) + else + zrs_ṅ = similar(ż, size(ż, 2)) + ChainRulesCore.@ignore_derivatives fill!(zrs_ṅ, zero(T)) + zrs_ṅ + end) + vcat(ż, l̇, Ė, ṅ) +end + +function augmented_f( + du::Any, + u::Any, + p::Any, + ::Any, + icnf::ICNF{T, <:DIJacVecMatrixMode, true, COND, AUGMENTED, STEER, NORM_Z, NORM_J}, + mode::TrainMode, + nn::LuxCore.AbstractExplicitLayer, + st::NamedTuple, + ϵ::AbstractMatrix{T}, +) where {T <: AbstractFloat, COND, AUGMENTED, STEER, NORM_Z, NORM_J} + n_aug = n_augment(icnf, mode) + snn = Lux.StatefulLuxLayer{true}(nn, p, st) + z = u[begin:(end - n_aug - 1), :] + ż, Jϵ = + DifferentiationInterface.value_and_pushforward(snn, icnf.compute_mode.adback, z, ϵ) + du[begin:(end - n_aug - 1), :] .= ż + du[(end - n_aug), :] .= -vec(sum(ϵ .* Jϵ; dims = 1)) + du[(end - n_aug + 1), :] .= if NORM_Z + LinearAlgebra.norm.(eachcol(ż)) + else + zero(T) + end + du[(end - n_aug + 2), :] .= if NORM_J + LinearAlgebra.norm.(eachcol(Jϵ)) + else + zero(T) + end + nothing +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:VectorMode}, + mode::TrainMode, + xs::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) + -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:VectorMode}, + mode::TrainMode, + xs::AbstractVector{<:Real}, + ys::AbstractVector{<:Real}, + ps::Any, + st::NamedTuple, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) + -logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:MatrixMode}, + mode::TrainMode, + xs::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ps, st) + 
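# Per-batch objective: mean over samples of the negative log-likelihood plus
# the λ-weighted regularizers returned by `inference` (Ė: norm of the flow
# velocity ż, ṅ: norm of the Jacobian estimate, Ȧ: a norm on the augmented
# dimensions, cf. the NORM_Z_AUG type flag).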
Statistics.mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) +end + +@inline function loss( + icnf::ICNF{<:AbstractFloat, <:MatrixMode}, + mode::TrainMode, + xs::AbstractMatrix{<:Real}, + ys::AbstractMatrix{<:Real}, + ps::Any, + st::NamedTuple, +) + logp̂x, (Ė, ṅ, Ȧ) = inference(icnf, mode, xs, ys, ps, st) + Statistics.mean(-logp̂x + icnf.λ₁ * Ė + icnf.λ₂ * ṅ + icnf.λ₃ * Ȧ) +end diff --git a/src/layers/cond_layer.jl b/src/layers/cond_layer.jl index 2e5537a1..5fce0af1 100644 --- a/src/layers/cond_layer.jl +++ b/src/layers/cond_layer.jl @@ -1,9 +1,9 @@ -struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <: - LuxCore.AbstractExplicitContainerLayer{(:nn,)} - nn::NN - ys::AT -end - -@inline function (m::CondLayer)(z::AbstractVecOrMat, ps::Any, st::NamedTuple) - LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) -end +struct CondLayer{NN <: LuxCore.AbstractExplicitLayer, AT <: AbstractArray} <: + LuxCore.AbstractExplicitContainerLayer{(:nn,)} + nn::NN + ys::AT +end + +@inline function (m::CondLayer)(z::AbstractVecOrMat, ps::Any, st::NamedTuple) + LuxCore.apply(m.nn, vcat(z, m.ys), ps, st) +end diff --git a/src/layers/mul_layer.jl b/src/layers/mul_layer.jl index 02804a40..616d9fc3 100644 --- a/src/layers/mul_layer.jl +++ b/src/layers/mul_layer.jl @@ -1,35 +1,35 @@ -struct MulLayer{F1, F2, NVARS <: Int} <: LuxCore.AbstractExplicitLayer - activation::F1 - nvars::NVARS - init_weight::F2 -end - -function MulLayer( - nvars::Int, - activation::Any = identity; - init_weight::Any = Lux.glorot_uniform, - allow_fast_activation::Bool = true, -) - activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) - MulLayer{typeof(activation), typeof(init_weight), typeof(nvars)}( - activation, - nvars, - init_weight, - ) -end - -function LuxCore.initialparameters(rng::Random.AbstractRNG, m::MulLayer) - (weight = m.init_weight(rng, m.nvars, m.nvars),) -end - -function LuxCore.parameterlength(m::MulLayer) - m.nvars * m.nvars -end - -function LuxCore.outputsize(m::MulLayer) - (m.nvars,) -end - -@inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) - m.activation.(Octavian.matmul(ps.weight, x)), st -end +struct MulLayer{F1, F2, NVARS <: Int} <: LuxCore.AbstractExplicitLayer + activation::F1 + nvars::NVARS + init_weight::F2 +end + +function MulLayer( + nvars::Int, + activation::Any = identity; + init_weight::Any = Lux.glorot_uniform, + allow_fast_activation::Bool = true, +) + activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) + MulLayer{typeof(activation), typeof(init_weight), typeof(nvars)}( + activation, + nvars, + init_weight, + ) +end + +function LuxCore.initialparameters(rng::Random.AbstractRNG, m::MulLayer) + (weight = m.init_weight(rng, m.nvars, m.nvars),) +end + +function LuxCore.parameterlength(m::MulLayer) + m.nvars * m.nvars +end + +function LuxCore.outputsize(m::MulLayer) + (m.nvars,) +end + +@inline function (m::MulLayer)(x::AbstractVecOrMat, ps::Any, st::NamedTuple) + m.activation.(Octavian.matmul(ps.weight, x)), st +end diff --git a/src/layers/planar_layer.jl b/src/layers/planar_layer.jl index 6b087abe..6e3077f3 100644 --- a/src/layers/planar_layer.jl +++ b/src/layers/planar_layer.jl @@ -1,103 +1,103 @@ -""" -Implementation of Planar Layer from - -[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." 
arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) -""" -struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <: - LuxCore.AbstractExplicitLayer - activation::F1 - nvars::NVARS - init_weight::F2 - init_bias::F3 - n_cond::NVARS -end - -function PlanarLayer( - nvars::Int, - activation::Any = identity; - init_weight::Any = Lux.glorot_uniform, - init_bias::Any = Lux.zeros32, - use_bias::Bool = true, - allow_fast_activation::Bool = true, - n_cond::Int = 0, -) - activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) - PlanarLayer{ - use_bias, - !iszero(n_cond), - typeof(activation), - typeof(init_weight), - typeof(init_bias), - typeof(nvars), - }( - activation, - nvars, - init_weight, - init_bias, - n_cond, - ) -end - -function LuxCore.initialparameters( - rng::Random.AbstractRNG, - layer::PlanarLayer{use_bias, cond}, -) where {use_bias, cond} - ifelse( - use_bias, - ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), - b = layer.init_bias(rng, 1), - ), - ( - u = layer.init_weight(rng, layer.nvars), - w = layer.init_weight( - rng, - ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), - ), - ), - ) -end - -function LuxCore.parameterlength(m::PlanarLayer{use_bias, cond}) where {use_bias, cond} - m.nvars + ifelse(cond, (m.nvars + m.n_cond), m.nvars) + ifelse(use_bias, 1, 0) -end - -function LuxCore.outputsize(m::PlanarLayer) - (m.nvars,) -end - -@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st -end - -@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st -end - -@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) - ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st -end - -@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) - ps.u * m.activation.(transpose(ps.w) * z), st -end - -@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) - m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st -end - -@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) - m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st -end - -@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) - m.activation.(LinearAlgebra.dot(ps.w, z)), st -end - -@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) - m.activation.(transpose(ps.w) * z), st -end +""" +Implementation of Planar Layer from + +[Chen, Ricky TQ, Yulia Rubanova, Jesse Bettencourt, and David Duvenaud. "Neural Ordinary Differential Equations." 
arXiv preprint arXiv:1806.07366 (2018).](https://arxiv.org/abs/1806.07366) +""" +struct PlanarLayer{use_bias, cond, F1, F2, F3, NVARS <: Int} <: + LuxCore.AbstractExplicitLayer + activation::F1 + nvars::NVARS + init_weight::F2 + init_bias::F3 + n_cond::NVARS +end + +function PlanarLayer( + nvars::Int, + activation::Any = identity; + init_weight::Any = Lux.glorot_uniform, + init_bias::Any = Lux.zeros32, + use_bias::Bool = true, + allow_fast_activation::Bool = true, + n_cond::Int = 0, +) + activation = ifelse(allow_fast_activation, NNlib.fast_act(activation), activation) + PlanarLayer{ + use_bias, + !iszero(n_cond), + typeof(activation), + typeof(init_weight), + typeof(init_bias), + typeof(nvars), + }( + activation, + nvars, + init_weight, + init_bias, + n_cond, + ) +end + +function LuxCore.initialparameters( + rng::Random.AbstractRNG, + layer::PlanarLayer{use_bias, cond}, +) where {use_bias, cond} + ifelse( + use_bias, + ( + u = layer.init_weight(rng, layer.nvars), + w = layer.init_weight( + rng, + ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), + ), + b = layer.init_bias(rng, 1), + ), + ( + u = layer.init_weight(rng, layer.nvars), + w = layer.init_weight( + rng, + ifelse(cond, (layer.nvars + layer.n_cond), layer.nvars), + ), + ), + ) +end + +function LuxCore.parameterlength(m::PlanarLayer{use_bias, cond}) where {use_bias, cond} + m.nvars + ifelse(cond, (m.nvars + m.n_cond), m.nvars) + ifelse(use_bias, 1, 0) +end + +function LuxCore.outputsize(m::PlanarLayer) + (m.nvars,) +end + +@inline function (m::PlanarLayer{true})(z::AbstractVector, ps::Any, st::NamedTuple) + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st +end + +@inline function (m::PlanarLayer{true})(z::AbstractMatrix, ps::Any, st::NamedTuple) + ps.u * m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st +end + +@inline function (m::PlanarLayer{false})(z::AbstractVector, ps::Any, st::NamedTuple) + ps.u * m.activation.(LinearAlgebra.dot(ps.w, z)), st +end + +@inline function (m::PlanarLayer{false})(z::AbstractMatrix, ps::Any, st::NamedTuple) + ps.u * m.activation.(transpose(ps.w) * z), st +end + +@inline function pl_h(m::PlanarLayer{true}, z::AbstractVector, ps::Any, st::NamedTuple) + m.activation.(LinearAlgebra.dot(ps.w, z) + only(ps.b)), st +end + +@inline function pl_h(m::PlanarLayer{true}, z::AbstractMatrix, ps::Any, st::NamedTuple) + m.activation.(muladd(transpose(ps.w), z, only(ps.b))), st +end + +@inline function pl_h(m::PlanarLayer{false}, z::AbstractVector, ps::Any, st::NamedTuple) + m.activation.(LinearAlgebra.dot(ps.w, z)), st +end + +@inline function pl_h(m::PlanarLayer{false}, z::AbstractMatrix, ps::Any, st::NamedTuple) + m.activation.(transpose(ps.w) * z), st +end diff --git a/src/types.jl b/src/types.jl index ae9716e7..c3139349 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,51 +1,51 @@ -abstract type Mode end -struct TestMode <: Mode end -struct TrainMode <: Mode end - -abstract type ComputeMode{ADBack} end -abstract type VectorMode{ADBack} <: ComputeMode{ADBack} end -abstract type MatrixMode{ADBack} <: ComputeMode{ADBack} end - -abstract type ADVectorMode{ADBack} <: VectorMode{ADBack} end -struct ADVecJacVectorMode{ADBack <: AbstractDifferentiation.AbstractBackend} <: - ADVectorMode{ADBack} - adback::ADBack -end -struct ADJacVecVectorMode{ADBack <: AbstractDifferentiation.AbstractBackend} <: - ADVectorMode{ADBack} - adback::ADBack -end - -abstract type DIVectorMode{ADBack} <: VectorMode{ADBack} end -struct DIVecJacVectorMode{ADBack <: ADTypes.AbstractADType} <: DIVectorMode{ADBack} 
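# (Compute-mode naming: the prefix picks the AD wrapper — AD* takes an
# AbstractDifferentiation backend, DI* an ADTypes backend driven through
# DifferentiationInterface; VecJac/JacVec picks vector-Jacobian (pullback)
# vs Jacobian-vector (pushforward) products for the divergence estimate;
# VectorMode processes one sample at a time, MatrixMode whole batches.)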
- adback::ADBack -end -struct DIJacVecVectorMode{ADBack <: ADTypes.AbstractADType} <: DIVectorMode{ADBack} - adback::ADBack -end - -abstract type DIMatrixMode{ADBack} <: MatrixMode{ADBack} end -struct DIVecJacMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBack} - adback::ADBack -end -struct DIJacVecMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBack} - adback::ADBack -end - -Base.Base.@deprecate_binding SDVecJacMatrixMode DIVecJacMatrixMode true -Base.Base.@deprecate_binding SDJacVecMatrixMode DIJacVecMatrixMode true - -Base.Base.@deprecate_binding ZygoteVectorMode DIVecJacVectorMode true -Base.Base.@deprecate_binding ZygoteMatrixMode DIVecJacMatrixMode true - -abstract type AbstractICNF{ - T <: AbstractFloat, - CM <: ComputeMode, - INPLACE, - COND, - AUGMENTED, - STEER, - NORM_Z_AUG, -} <: LuxCore.AbstractExplicitContainerLayer{(:nn,)} end - -abstract type MLJICNF{AICNF <: AbstractICNF} <: MLJModelInterface.Unsupervised end +abstract type Mode end +struct TestMode <: Mode end +struct TrainMode <: Mode end + +abstract type ComputeMode{ADBack} end +abstract type VectorMode{ADBack} <: ComputeMode{ADBack} end +abstract type MatrixMode{ADBack} <: ComputeMode{ADBack} end + +abstract type ADVectorMode{ADBack} <: VectorMode{ADBack} end +struct ADVecJacVectorMode{ADBack <: AbstractDifferentiation.AbstractBackend} <: + ADVectorMode{ADBack} + adback::ADBack +end +struct ADJacVecVectorMode{ADBack <: AbstractDifferentiation.AbstractBackend} <: + ADVectorMode{ADBack} + adback::ADBack +end + +abstract type DIVectorMode{ADBack} <: VectorMode{ADBack} end +struct DIVecJacVectorMode{ADBack <: ADTypes.AbstractADType} <: DIVectorMode{ADBack} + adback::ADBack +end +struct DIJacVecVectorMode{ADBack <: ADTypes.AbstractADType} <: DIVectorMode{ADBack} + adback::ADBack +end + +abstract type DIMatrixMode{ADBack} <: MatrixMode{ADBack} end +struct DIVecJacMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBack} + adback::ADBack +end +struct DIJacVecMatrixMode{ADBack <: ADTypes.AbstractADType} <: DIMatrixMode{ADBack} + adback::ADBack +end + +Base.Base.@deprecate_binding SDVecJacMatrixMode DIVecJacMatrixMode true +Base.Base.@deprecate_binding SDJacVecMatrixMode DIJacVecMatrixMode true + +Base.Base.@deprecate_binding ZygoteVectorMode DIVecJacVectorMode true +Base.Base.@deprecate_binding ZygoteMatrixMode DIVecJacMatrixMode true + +abstract type AbstractICNF{ + T <: AbstractFloat, + CM <: ComputeMode, + INPLACE, + COND, + AUGMENTED, + STEER, + NORM_Z_AUG, +} <: LuxCore.AbstractExplicitContainerLayer{(:nn,)} end + +abstract type MLJICNF{AICNF <: AbstractICNF} <: MLJModelInterface.Unsupervised end diff --git a/src/utils.jl b/src/utils.jl index d691d6f7..4c0c389c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,52 +1,52 @@ -@inline function jacobian_batched( - icnf::AbstractICNF{T, <:DIVecJacMatrixMode}, - f::Lux.StatefulLuxLayer, - xs::AbstractMatrix{<:Real}, -) where {T} - y = f(xs) - z = similar(xs) - ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(xs, 1) - ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) - res[i, :, :] = DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, z) - ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) - end - y, eachslice(copy(res); dims = 3) -end - -@inline function jacobian_batched( - icnf::AbstractICNF{T, <:DIJacVecMatrixMode}, - f::Lux.StatefulLuxLayer, - xs::AbstractMatrix{<:Real}, -) where {T} - y = f(xs) - z = similar(xs) - 
ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) - res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) - for i in axes(xs, 1) - ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) - res[:, i, :] = - DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, z) - ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) - end - y, eachslice(copy(res); dims = 3) -end - -@inline function jacobian_batched( - icnf::AbstractICNF{T, <:DIMatrixMode}, - f::Lux.StatefulLuxLayer, - xs::AbstractMatrix{<:Real}, -) where {T} - y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs) - y, split_jac(J, size(xs, 1)) -end - -@inline function split_jac(x::AbstractMatrix{<:Real}, sz::Integer) - ( - x[i:j, i:j] for (i, j) in zip( - firstindex(x, 1):sz:lastindex(x, 1), - (firstindex(x, 1) + sz - 1):sz:lastindex(x, 1), - ) - ) -end +@inline function jacobian_batched( + icnf::AbstractICNF{T, <:DIVecJacMatrixMode}, + f::Lux.StatefulLuxLayer, + xs::AbstractMatrix{<:Real}, +) where {T} + y = f(xs) + z = similar(xs) + ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) + res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + for i in axes(xs, 1) + ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) + res[i, :, :] = DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, z) + ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) + end + y, eachslice(copy(res); dims = 3) +end + +@inline function jacobian_batched( + icnf::AbstractICNF{T, <:DIJacVecMatrixMode}, + f::Lux.StatefulLuxLayer, + xs::AbstractMatrix{<:Real}, +) where {T} + y = f(xs) + z = similar(xs) + ChainRulesCore.@ignore_derivatives fill!(z, zero(T)) + res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2)) + for i in axes(xs, 1) + ChainRulesCore.@ignore_derivatives z[i, :] .= one(T) + res[:, i, :] = + DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, z) + ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T) + end + y, eachslice(copy(res); dims = 3) +end + +@inline function jacobian_batched( + icnf::AbstractICNF{T, <:DIMatrixMode}, + f::Lux.StatefulLuxLayer, + xs::AbstractMatrix{<:Real}, +) where {T} + y, J = DifferentiationInterface.value_and_jacobian(f, icnf.compute_mode.adback, xs) + y, split_jac(J, size(xs, 1)) +end + +@inline function split_jac(x::AbstractMatrix{<:Real}, sz::Integer) + ( + x[i:j, i:j] for (i, j) in zip( + firstindex(x, 1):sz:lastindex(x, 1), + (firstindex(x, 1) + sz - 1):sz:lastindex(x, 1), + ) + ) +end diff --git a/test/call_tests.jl b/test/call_tests.jl index c8ba5ae2..9d2f9419 100644 --- a/test/call_tests.jl +++ b/test/call_tests.jl @@ -1,256 +1,256 @@ -Test.@testset "Call Tests" begin - mts = if GROUP == "RNODE" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.RNODE] - elseif GROUP == "FFJORD" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.FFJORD] - elseif GROUP == "Planar" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.Planar] - elseif GROUP == "CondRNODE" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondRNODE] - elseif GROUP == "CondFFJORD" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondFFJORD] - elseif GROUP == "CondPlanar" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondPlanar] - else - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ - ContinuousNormalizingFlows.RNODE, - ContinuousNormalizingFlows.FFJORD, - ContinuousNormalizingFlows.Planar, - 
ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - ] - end - omodes = ContinuousNormalizingFlows.Mode[ - ContinuousNormalizingFlows.TrainMode(), - ContinuousNormalizingFlows.TestMode(), - ] - ndata_ = Int[4] - nvars_ = Int[2] - aug_steers = Bool[false, true] - inplaces = Bool[false, true] - adb_list = AbstractDifferentiation.AbstractBackend[ - AbstractDifferentiation.ZygoteBackend(), - AbstractDifferentiation.ReverseDiffBackend(), - AbstractDifferentiation.ForwardDiffBackend(), - ] - adtypes = ADTypes.AbstractADType[ - ADTypes.AutoZygote(), - ADTypes.AutoReverseDiff(), - ADTypes.AutoForwardDiff(), - ] - compute_modes = ContinuousNormalizingFlows.ComputeMode[ - ContinuousNormalizingFlows.ADVecJacVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.ADJacVecVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), - ] - data_types = Type{<:AbstractFloat}[Float32] - resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] - if CUDA.has_cuda_gpu() && USE_GPU - push!(resources, ComputationalResources.CUDALibs()) - gdev = Lux.gpu_device() - end - - Test.@testset "$resource | $data_type | $compute_mode | inplace = $inplace | aug & steer = $aug_steer | nvars = $nvars | $omode | $mt" for resource in - resources, - data_type in data_types, - compute_mode in compute_modes, - inplace in inplaces, - aug_steer in aug_steers, - nvars in nvars_, - ndata in ndata_, - omode in omodes, - mt in mts - - data_dist = - Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) - data_dist2 = - Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...) 
- if compute_mode isa ContinuousNormalizingFlows.VectorMode - r = convert.(data_type, rand(data_dist, nvars)) - r2 = convert.(data_type, rand(data_dist2, nvars)) - elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode - r = convert.(data_type, rand(data_dist, nvars, ndata)) - r2 = convert.(data_type, rand(data_dist2, nvars, ndata)) - end - - nn = ifelse( - mt <: Union{ - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - }, - ifelse( - mt <: ContinuousNormalizingFlows.CondPlanar, - ifelse( - aug_steer, - Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer( - nvars * 2, - tanh; - n_cond = nvars, - ), - ), - Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars, tanh; n_cond = nvars), - ), - ), - ifelse( - aug_steer, - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars, tanh)), - ), - ), - ifelse( - mt <: ContinuousNormalizingFlows.Planar, - ifelse( - aug_steer, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars, tanh)), - ), - ifelse( - aug_steer, - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars => nvars, tanh)), - ), - ), - ) - icnf = ifelse( - aug_steer, - ContinuousNormalizingFlows.construct( - mt, - nn, - nvars, - nvars; - data_type, - compute_mode, - inplace, - resource, - steer_rate = convert(data_type, 1e-1), - λ₃ = convert(data_type, 1e-2), - ), - ContinuousNormalizingFlows.construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - ), - ) - ps, st = Lux.setup(icnf.rng, icnf) - ps = ComponentArrays.ComponentArray(ps) - if resource isa ComputationalResources.CUDALibs - r = gdev(r) - r2 = gdev(r2) - ps = gdev(ps) - st = gdev(st) - end - - if mt <: Union{ - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - } - Test.@test !isnothing( - ContinuousNormalizingFlows.inference(icnf, omode, r, r2, ps, st), - ) - if compute_mode isa ContinuousNormalizingFlows.MatrixMode - Test.@test !isnothing( - ContinuousNormalizingFlows.generate(icnf, omode, r2, ps, st, ndata), - ) - else - Test.@test !isnothing( - ContinuousNormalizingFlows.generate(icnf, omode, r2, ps, st), - ) - end - - Test.@test !isnothing( - ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st), - ) - Test.@test !isnothing(icnf((r, r2), ps, st)) - - diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st) - diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st) - else - Test.@test !isnothing( - ContinuousNormalizingFlows.inference(icnf, omode, r, ps, st), - ) - if compute_mode isa ContinuousNormalizingFlows.MatrixMode - Test.@test !isnothing( - ContinuousNormalizingFlows.generate(icnf, omode, ps, st, ndata), - ) - else - Test.@test !isnothing( - ContinuousNormalizingFlows.generate(icnf, omode, ps, st), - ) - end - - Test.@test !isnothing(ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st)) - Test.@test !isnothing(icnf(r, ps, st)) - - diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, x, st) - diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st) - end - - if mt <: Union{ - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - } - d = ContinuousNormalizingFlows.CondICNFDist(icnf, omode, r2, ps, st) - else - d = ContinuousNormalizingFlows.ICNFDist(icnf, omode, 
ps, st) - end - - Test.@test !isnothing(Distributions.logpdf(d, r)) - Test.@test !isnothing(Distributions.pdf(d, r)) - Test.@test !isnothing(rand(d)) - Test.@test !isnothing(rand(d, ndata)) - - Test.@testset "$(typeof(adb).name.name)" for adb in adb_list - Test.@testset "Loss" begin - Test.@testset "ps" begin - Test.@test !isnothing( - AbstractDifferentiation.gradient(adb, diff_loss, ps), - ) - end - Test.@testset "x" begin - Test.@test !isnothing( - AbstractDifferentiation.gradient(adb, diff2_loss, r), - ) broken = - (GROUP != "All") && - adb isa AbstractDifferentiation.ReverseDiffBackend && - compute_mode isa ContinuousNormalizingFlows.MatrixMode && - VERSION >= v"1.10" - end - end - end - Test.@testset "$(typeof(adtype).name.name)" for adtype in adtypes - Test.@testset "Loss" begin - Test.@testset "ps" begin - Test.@test !isnothing( - DifferentiationInterface.gradient(diff_loss, adtype, ps), - ) - end - Test.@testset "x" begin - Test.@test !isnothing( - DifferentiationInterface.gradient(diff2_loss, adtype, r), - ) broken = - (GROUP != "All") && - adtype isa ADTypes.AutoReverseDiff && - compute_mode isa ContinuousNormalizingFlows.MatrixMode && - VERSION >= v"1.10" - end - end - end - end -end +Test.@testset "Call Tests" begin + mts = if GROUP == "RNODE" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.RNODE] + elseif GROUP == "FFJORD" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.FFJORD] + elseif GROUP == "Planar" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.Planar] + elseif GROUP == "CondRNODE" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondRNODE] + elseif GROUP == "CondFFJORD" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondFFJORD] + elseif GROUP == "CondPlanar" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondPlanar] + else + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ + ContinuousNormalizingFlows.RNODE, + ContinuousNormalizingFlows.FFJORD, + ContinuousNormalizingFlows.Planar, + ContinuousNormalizingFlows.CondRNODE, + ContinuousNormalizingFlows.CondFFJORD, + ContinuousNormalizingFlows.CondPlanar, + ] + end + omodes = ContinuousNormalizingFlows.Mode[ + ContinuousNormalizingFlows.TrainMode(), + ContinuousNormalizingFlows.TestMode(), + ] + ndata_ = Int[4] + nvars_ = Int[2] + aug_steers = Bool[false, true] + inplaces = Bool[false, true] + adb_list = AbstractDifferentiation.AbstractBackend[ + AbstractDifferentiation.ZygoteBackend(), + AbstractDifferentiation.ReverseDiffBackend(), + AbstractDifferentiation.ForwardDiffBackend(), + ] + adtypes = ADTypes.AbstractADType[ + ADTypes.AutoZygote(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoForwardDiff(), + ] + compute_modes = ContinuousNormalizingFlows.ComputeMode[ + ContinuousNormalizingFlows.ADVecJacVectorMode( + AbstractDifferentiation.ZygoteBackend(), + ), + ContinuousNormalizingFlows.ADJacVecVectorMode( + AbstractDifferentiation.ZygoteBackend(), + ), + ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), + ] + data_types = Type{<:AbstractFloat}[Float32] + resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] + if CUDA.has_cuda_gpu() && USE_GPU + push!(resources, ComputationalResources.CUDALibs()) + 
gdev = Lux.gpu_device() + end + + Test.@testset "$resource | $data_type | $compute_mode | inplace = $inplace | aug & steer = $aug_steer | nvars = $nvars | $omode | $mt" for resource in + resources, + data_type in data_types, + compute_mode in compute_modes, + inplace in inplaces, + aug_steer in aug_steers, + nvars in nvars_, + ndata in ndata_, + omode in omodes, + mt in mts + + data_dist = + Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) + data_dist2 = + Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...) + if compute_mode isa ContinuousNormalizingFlows.VectorMode + r = convert.(data_type, rand(data_dist, nvars)) + r2 = convert.(data_type, rand(data_dist2, nvars)) + elseif compute_mode isa ContinuousNormalizingFlows.MatrixMode + r = convert.(data_type, rand(data_dist, nvars, ndata)) + r2 = convert.(data_type, rand(data_dist2, nvars, ndata)) + end + + nn = ifelse( + mt <: Union{ + ContinuousNormalizingFlows.CondRNODE, + ContinuousNormalizingFlows.CondFFJORD, + ContinuousNormalizingFlows.CondPlanar, + }, + ifelse( + mt <: ContinuousNormalizingFlows.CondPlanar, + ifelse( + aug_steer, + Lux.Chain( + ContinuousNormalizingFlows.PlanarLayer( + nvars * 2, + tanh; + n_cond = nvars, + ), + ), + Lux.Chain( + ContinuousNormalizingFlows.PlanarLayer(nvars, tanh; n_cond = nvars), + ), + ), + ifelse( + aug_steer, + Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars * 2 => nvars, tanh)), + ), + ), + ifelse( + mt <: ContinuousNormalizingFlows.Planar, + ifelse( + aug_steer, + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), + Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars, tanh)), + ), + ifelse( + aug_steer, + Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), + Lux.Chain(Lux.Dense(nvars => nvars, tanh)), + ), + ), + ) + icnf = ifelse( + aug_steer, + ContinuousNormalizingFlows.construct( + mt, + nn, + nvars, + nvars; + data_type, + compute_mode, + inplace, + resource, + steer_rate = convert(data_type, 1e-1), + λ₃ = convert(data_type, 1e-2), + ), + ContinuousNormalizingFlows.construct( + mt, + nn, + nvars; + data_type, + compute_mode, + inplace, + resource, + ), + ) + ps, st = Lux.setup(icnf.rng, icnf) + ps = ComponentArrays.ComponentArray(ps) + if resource isa ComputationalResources.CUDALibs + r = gdev(r) + r2 = gdev(r2) + ps = gdev(ps) + st = gdev(st) + end + + if mt <: Union{ + ContinuousNormalizingFlows.CondRNODE, + ContinuousNormalizingFlows.CondFFJORD, + ContinuousNormalizingFlows.CondPlanar, + } + Test.@test !isnothing( + ContinuousNormalizingFlows.inference(icnf, omode, r, r2, ps, st), + ) + if compute_mode isa ContinuousNormalizingFlows.MatrixMode + Test.@test !isnothing( + ContinuousNormalizingFlows.generate(icnf, omode, r2, ps, st, ndata), + ) + else + Test.@test !isnothing( + ContinuousNormalizingFlows.generate(icnf, omode, r2, ps, st), + ) + end + + Test.@test !isnothing( + ContinuousNormalizingFlows.loss(icnf, omode, r, r2, ps, st), + ) + Test.@test !isnothing(icnf((r, r2), ps, st)) + + diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, r2, x, st) + diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, r2, ps, st) + else + Test.@test !isnothing( + ContinuousNormalizingFlows.inference(icnf, omode, r, ps, st), + ) + if compute_mode isa ContinuousNormalizingFlows.MatrixMode + Test.@test !isnothing( + ContinuousNormalizingFlows.generate(icnf, omode, ps, st, ndata), + ) + else + Test.@test !isnothing( + ContinuousNormalizingFlows.generate(icnf, omode, ps, st), 
+ ) + end + + Test.@test !isnothing(ContinuousNormalizingFlows.loss(icnf, omode, r, ps, st)) + Test.@test !isnothing(icnf(r, ps, st)) + + diff_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, r, x, st) + diff2_loss = x -> ContinuousNormalizingFlows.loss(icnf, omode, x, ps, st) + end + + if mt <: Union{ + ContinuousNormalizingFlows.CondRNODE, + ContinuousNormalizingFlows.CondFFJORD, + ContinuousNormalizingFlows.CondPlanar, + } + d = ContinuousNormalizingFlows.CondICNFDist(icnf, omode, r2, ps, st) + else + d = ContinuousNormalizingFlows.ICNFDist(icnf, omode, ps, st) + end + + Test.@test !isnothing(Distributions.logpdf(d, r)) + Test.@test !isnothing(Distributions.pdf(d, r)) + Test.@test !isnothing(rand(d)) + Test.@test !isnothing(rand(d, ndata)) + + Test.@testset "$(typeof(adb).name.name)" for adb in adb_list + Test.@testset "Loss" begin + Test.@testset "ps" begin + Test.@test !isnothing( + AbstractDifferentiation.gradient(adb, diff_loss, ps), + ) + end + Test.@testset "x" begin + Test.@test !isnothing( + AbstractDifferentiation.gradient(adb, diff2_loss, r), + ) broken = + (GROUP != "All") && + adb isa AbstractDifferentiation.ReverseDiffBackend && + compute_mode isa ContinuousNormalizingFlows.MatrixMode && + VERSION >= v"1.10" + end + end + end + Test.@testset "$(typeof(adtype).name.name)" for adtype in adtypes + Test.@testset "Loss" begin + Test.@testset "ps" begin + Test.@test !isnothing( + DifferentiationInterface.gradient(diff_loss, adtype, ps), + ) + end + Test.@testset "x" begin + Test.@test !isnothing( + DifferentiationInterface.gradient(diff2_loss, adtype, r), + ) broken = + (GROUP != "All") && + adtype isa ADTypes.AutoReverseDiff && + compute_mode isa ContinuousNormalizingFlows.MatrixMode && + VERSION >= v"1.10" + end + end + end + end +end diff --git a/test/fit_tests.jl b/test/fit_tests.jl index a63387ea..a59b1b50 100644 --- a/test/fit_tests.jl +++ b/test/fit_tests.jl @@ -1,186 +1,186 @@ -Test.@testset "Fit Tests" begin - mts = if GROUP == "RNODE" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.RNODE] - elseif GROUP == "FFJORD" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.FFJORD] - elseif GROUP == "Planar" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.Planar] - elseif GROUP == "CondRNODE" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondRNODE] - elseif GROUP == "CondFFJORD" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondFFJORD] - elseif GROUP == "CondPlanar" - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondPlanar] - else - Type{<:ContinuousNormalizingFlows.AbstractICNF}[ - ContinuousNormalizingFlows.RNODE, - ContinuousNormalizingFlows.FFJORD, - ContinuousNormalizingFlows.Planar, - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - ] - end - n_epochs_ = Int[2] - ndata_ = Int[4] - nvars_ = Int[2] - aug_steers = Bool[false, true] - inplaces = Bool[false, true] - adtypes = ADTypes.AbstractADType[ - ADTypes.AutoZygote(), - ADTypes.AutoReverseDiff(), - ADTypes.AutoForwardDiff(), - ] - compute_modes = ContinuousNormalizingFlows.ComputeMode[ - ContinuousNormalizingFlows.ADVecJacVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.ADJacVecVectorMode( - AbstractDifferentiation.ZygoteBackend(), - ), - ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), - 
ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), - ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), - ] - data_types = Type{<:AbstractFloat}[Float32] - resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] - if CUDA.has_cuda_gpu() && USE_GPU - push!(resources, ComputationalResources.CUDALibs()) - end - - Test.@testset "$resource | $data_type | $compute_mode | $adtype | inplace = $inplace | aug & steer = $aug_steer | nvars = $nvars | $mt" for resource in - resources, - data_type in data_types, - compute_mode in compute_modes, - adtype in adtypes, - inplace in inplaces, - aug_steer in aug_steers, - nvars in nvars_, - ndata in ndata_, - n_epochs in n_epochs_, - mt in mts - - data_dist = - Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) - data_dist2 = - Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...) - r = convert.(data_type, rand(data_dist, nvars, ndata)) - r2 = convert.(data_type, rand(data_dist2, nvars, ndata)) - df = DataFrames.DataFrame(transpose(r), :auto) - df2 = DataFrames.DataFrame(transpose(r2), :auto) - - nn = ifelse( - mt <: Union{ - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - }, - ifelse( - mt <: ContinuousNormalizingFlows.CondPlanar, - ifelse( - aug_steer, - Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer( - nvars * 2, - tanh; - n_cond = nvars, - ), - ), - Lux.Chain( - ContinuousNormalizingFlows.PlanarLayer(nvars, tanh; n_cond = nvars), - ), - ), - ifelse( - aug_steer, - Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars * 2 => nvars, tanh)), - ), - ), - ifelse( - mt <: ContinuousNormalizingFlows.Planar, - ifelse( - aug_steer, - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)), - Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars, tanh)), - ), - ifelse( - aug_steer, - Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)), - Lux.Chain(Lux.Dense(nvars => nvars, tanh)), - ), - ), - ) - icnf = ifelse( - aug_steer, - ContinuousNormalizingFlows.construct( - mt, - nn, - nvars, - nvars; - data_type, - compute_mode, - inplace, - resource, - steer_rate = convert(data_type, 1e-1), - λ₃ = convert(data_type, 1e-2), - ), - ContinuousNormalizingFlows.construct( - mt, - nn, - nvars; - data_type, - compute_mode, - inplace, - resource, - ), - ) - if mt <: Union{ - ContinuousNormalizingFlows.CondRNODE, - ContinuousNormalizingFlows.CondFFJORD, - ContinuousNormalizingFlows.CondPlanar, - } - model = ContinuousNormalizingFlows.CondICNFModel(icnf; n_epochs, adtype) - mach = MLJBase.machine(model, (df, df2)) - - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, (df, df2))) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - - Test.@test !isnothing( - ContinuousNormalizingFlows.CondICNFDist( - mach, - ContinuousNormalizingFlows.TrainMode(), - r2, - ), - ) - Test.@test !isnothing( - ContinuousNormalizingFlows.CondICNFDist( - mach, - ContinuousNormalizingFlows.TestMode(), - r2, - ), - ) - else - model = ContinuousNormalizingFlows.ICNFModel(icnf; n_epochs, adtype) - mach = MLJBase.machine(model, df) - - Test.@test !isnothing(MLJBase.fit!(mach)) - Test.@test !isnothing(MLJBase.transform(mach, df)) - Test.@test !isnothing(MLJBase.fitted_params(mach)) - - Test.@test !isnothing( - ContinuousNormalizingFlows.ICNFDist( - mach, - 
ContinuousNormalizingFlows.TrainMode(), - ), - ) - Test.@test !isnothing( - ContinuousNormalizingFlows.ICNFDist( - mach, - ContinuousNormalizingFlows.TestMode(), - ), - ) - end - end -end +Test.@testset "Fit Tests" begin + mts = if GROUP == "RNODE" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.RNODE] + elseif GROUP == "FFJORD" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.FFJORD] + elseif GROUP == "Planar" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.Planar] + elseif GROUP == "CondRNODE" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondRNODE] + elseif GROUP == "CondFFJORD" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondFFJORD] + elseif GROUP == "CondPlanar" + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ContinuousNormalizingFlows.CondPlanar] + else + Type{<:ContinuousNormalizingFlows.AbstractICNF}[ + ContinuousNormalizingFlows.RNODE, + ContinuousNormalizingFlows.FFJORD, + ContinuousNormalizingFlows.Planar, + ContinuousNormalizingFlows.CondRNODE, + ContinuousNormalizingFlows.CondFFJORD, + ContinuousNormalizingFlows.CondPlanar, + ] + end + n_epochs_ = Int[2] + ndata_ = Int[4] + nvars_ = Int[2] + aug_steers = Bool[false, true] + inplaces = Bool[false, true] + adtypes = ADTypes.AbstractADType[ + ADTypes.AutoZygote(), + ADTypes.AutoReverseDiff(), + ADTypes.AutoForwardDiff(), + ] + compute_modes = ContinuousNormalizingFlows.ComputeMode[ + ContinuousNormalizingFlows.ADVecJacVectorMode( + AbstractDifferentiation.ZygoteBackend(), + ), + ContinuousNormalizingFlows.ADJacVecVectorMode( + AbstractDifferentiation.ZygoteBackend(), + ), + ContinuousNormalizingFlows.DIVecJacVectorMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIJacVecVectorMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + ContinuousNormalizingFlows.DIJacVecMatrixMode(ADTypes.AutoZygote()), + ] + data_types = Type{<:AbstractFloat}[Float32] + resources = ComputationalResources.AbstractResource[ComputationalResources.CPU1()] + if CUDA.has_cuda_gpu() && USE_GPU + push!(resources, ComputationalResources.CUDALibs()) + end + + Test.@testset "$resource | $data_type | $compute_mode | $adtype | inplace = $inplace | aug & steer = $aug_steer | nvars = $nvars | $mt" for resource in + resources, + data_type in data_types, + compute_mode in compute_modes, + adtype in adtypes, + inplace in inplaces, + aug_steer in aug_steers, + nvars in nvars_, + ndata in ndata_, + n_epochs in n_epochs_, + mt in mts + + data_dist = + Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (2, 4))...) + data_dist2 = + Distributions.Beta{data_type}(convert(Tuple{data_type, data_type}, (4, 2))...) 
+        r = convert.(data_type, rand(data_dist, nvars, ndata))
+        r2 = convert.(data_type, rand(data_dist2, nvars, ndata))
+        df = DataFrames.DataFrame(transpose(r), :auto)
+        df2 = DataFrames.DataFrame(transpose(r2), :auto)
+
+        nn = ifelse(
+            mt <: Union{
+                ContinuousNormalizingFlows.CondRNODE,
+                ContinuousNormalizingFlows.CondFFJORD,
+                ContinuousNormalizingFlows.CondPlanar,
+            },
+            ifelse(
+                mt <: ContinuousNormalizingFlows.CondPlanar,
+                ifelse(
+                    aug_steer,
+                    Lux.Chain(
+                        ContinuousNormalizingFlows.PlanarLayer(
+                            nvars * 2,
+                            tanh;
+                            n_cond = nvars,
+                        ),
+                    ),
+                    Lux.Chain(
+                        ContinuousNormalizingFlows.PlanarLayer(nvars, tanh; n_cond = nvars),
+                    ),
+                ),
+                ifelse(
+                    aug_steer,
+                    Lux.Chain(Lux.Dense(nvars * 3 => nvars * 2, tanh)),
+                    Lux.Chain(Lux.Dense(nvars * 2 => nvars, tanh)),
+                ),
+            ),
+            ifelse(
+                mt <: ContinuousNormalizingFlows.Planar,
+                ifelse(
+                    aug_steer,
+                    Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars * 2, tanh)),
+                    Lux.Chain(ContinuousNormalizingFlows.PlanarLayer(nvars, tanh)),
+                ),
+                ifelse(
+                    aug_steer,
+                    Lux.Chain(Lux.Dense(nvars * 2 => nvars * 2, tanh)),
+                    Lux.Chain(Lux.Dense(nvars => nvars, tanh)),
+                ),
+            ),
+        )
+        icnf = ifelse(
+            aug_steer,
+            ContinuousNormalizingFlows.construct(
+                mt,
+                nn,
+                nvars,
+                nvars;
+                data_type,
+                compute_mode,
+                inplace,
+                resource,
+                steer_rate = convert(data_type, 1e-1),
+                λ₃ = convert(data_type, 1e-2),
+            ),
+            ContinuousNormalizingFlows.construct(
+                mt,
+                nn,
+                nvars;
+                data_type,
+                compute_mode,
+                inplace,
+                resource,
+            ),
+        )
+        if mt <: Union{
+            ContinuousNormalizingFlows.CondRNODE,
+            ContinuousNormalizingFlows.CondFFJORD,
+            ContinuousNormalizingFlows.CondPlanar,
+        }
+            model = ContinuousNormalizingFlows.CondICNFModel(icnf; n_epochs, adtype)
+            mach = MLJBase.machine(model, (df, df2))
+
+            Test.@test !isnothing(MLJBase.fit!(mach))
+            Test.@test !isnothing(MLJBase.transform(mach, (df, df2)))
+            Test.@test !isnothing(MLJBase.fitted_params(mach))
+
+            Test.@test !isnothing(
+                ContinuousNormalizingFlows.CondICNFDist(
+                    mach,
+                    ContinuousNormalizingFlows.TrainMode(),
+                    r2,
+                ),
+            )
+            Test.@test !isnothing(
+                ContinuousNormalizingFlows.CondICNFDist(
+                    mach,
+                    ContinuousNormalizingFlows.TestMode(),
+                    r2,
+                ),
+            )
+        else
+            model = ContinuousNormalizingFlows.ICNFModel(icnf; n_epochs, adtype)
+            mach = MLJBase.machine(model, df)
+
+            Test.@test !isnothing(MLJBase.fit!(mach))
+            Test.@test !isnothing(MLJBase.transform(mach, df))
+            Test.@test !isnothing(MLJBase.fitted_params(mach))
+
+            Test.@test !isnothing(
+                ContinuousNormalizingFlows.ICNFDist(
+                    mach,
+                    ContinuousNormalizingFlows.TrainMode(),
+                ),
+            )
+            Test.@test !isnothing(
+                ContinuousNormalizingFlows.ICNFDist(
+                    mach,
+                    ContinuousNormalizingFlows.TestMode(),
+                ),
+            )
+        end
+    end
+end
diff --git a/test/instability_tests.jl b/test/instability_tests.jl
index 3517a7c9..5d7f7bac 100644
--- a/test/instability_tests.jl
+++ b/test/instability_tests.jl
@@ -1,40 +1,40 @@
-Test.@testset "Instability" begin
-    JET.test_package(
-        ContinuousNormalizingFlows;
-        target_modules = [ContinuousNormalizingFlows],
-        mode = :sound,
-    )
-
-    nvars = 2^3
-    naugs = nvars
-    n_in = nvars + naugs
-    n = 2^6
-    nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
-
-    icnf = ContinuousNormalizingFlows.construct(
-        ContinuousNormalizingFlows.RNODE,
-        nn,
-        nvars,
-        naugs;
-        compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
-        tspan = (0.0f0, 13.0f0),
-        steer_rate = 1.0f-1,
-        λ₃ = 1.0f-2,
-    )
-    ps, st = Lux.setup(icnf.rng, icnf)
-    ps = ComponentArrays.ComponentArray(ps)
-    r = rand(icnf.rng, Float32, nvars, n)
-
-    ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st)
-    JET.test_call(
-        ContinuousNormalizingFlows.loss,
-        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
-        target_modules = [ContinuousNormalizingFlows],
-        mode = :sound,
-    )
-    JET.test_opt(
-        ContinuousNormalizingFlows.loss,
-        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
-        target_modules = [ContinuousNormalizingFlows],
-    )
-end
+Test.@testset "Instability" begin
+    JET.test_package(
+        ContinuousNormalizingFlows;
+        target_modules = [ContinuousNormalizingFlows],
+        mode = :sound,
+    )
+
+    nvars = 2^3
+    naugs = nvars
+    n_in = nvars + naugs
+    n = 2^6
+    nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
+
+    icnf = ContinuousNormalizingFlows.construct(
+        ContinuousNormalizingFlows.RNODE,
+        nn,
+        nvars,
+        naugs;
+        compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
+        tspan = (0.0f0, 13.0f0),
+        steer_rate = 1.0f-1,
+        λ₃ = 1.0f-2,
+    )
+    ps, st = Lux.setup(icnf.rng, icnf)
+    ps = ComponentArrays.ComponentArray(ps)
+    r = rand(icnf.rng, Float32, nvars, n)
+
+    ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st)
+    JET.test_call(
+        ContinuousNormalizingFlows.loss,
+        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
+        target_modules = [ContinuousNormalizingFlows],
+        mode = :sound,
+    )
+    JET.test_opt(
+        ContinuousNormalizingFlows.loss,
+        Base.typesof(icnf, ContinuousNormalizingFlows.TrainMode(), r, ps, st);
+        target_modules = [ContinuousNormalizingFlows],
+    )
+end
diff --git a/test/quality_tests.jl b/test/quality_tests.jl
index 5ab576c3..4f899c43 100644
--- a/test/quality_tests.jl
+++ b/test/quality_tests.jl
@@ -1,6 +1,6 @@
-Test.@testset "Quality" begin
-    Test.@testset "Method ambiguity" begin
-        Aqua.test_ambiguities(ContinuousNormalizingFlows)
-    end
-    Aqua.test_all(ContinuousNormalizingFlows; ambiguities = (GROUP == "All"))
-end
+Test.@testset "Quality" begin
+    Test.@testset "Method ambiguity" begin
+        Aqua.test_ambiguities(ContinuousNormalizingFlows)
+    end
+    Aqua.test_all(ContinuousNormalizingFlows; ambiguities = (GROUP == "All"))
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index d6b257b6..341b06bf 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,52 +1,52 @@
-import AbstractDifferentiation,
-    ADTypes,
-    Aqua,
-    ComponentArrays,
-    ComputationalResources,
-    CUDA,
-    cuDNN,
-    DataFrames,
-    DifferentiationInterface,
-    Distributions,
-    ForwardDiff,
-    JET,
-    Logging,
-    Lux,
-    LuxCUDA,
-    MLJBase,
-    ReverseDiff,
-    SciMLBase,
-    TerminalLoggers,
-    Test,
-    Zygote,
-    ContinuousNormalizingFlows
-
-GROUP = get(ENV, "GROUP", "All")
-USE_GPU = get(ENV, "USE_GPU", "Yes") == "Yes"
-
-if (GROUP == "All")
-    GC.enable_logging(true)
-
-    debuglogger = TerminalLoggers.TerminalLogger(stderr, Logging.Debug)
-    Logging.global_logger(debuglogger)
-else
-    warnlogger = TerminalLoggers.TerminalLogger(stderr, Logging.Warn)
-    Logging.global_logger(warnlogger)
-end
-
-Test.@testset "Overall" begin
-    if GROUP == "All" ||
-       GROUP in ["RNODE", "FFJORD", "Planar", "CondRNODE", "CondFFJORD", "CondPlanar"]
-        CUDA.allowscalar() do
-            include("smoke_tests.jl")
-        end
-    end
-
-    if GROUP == "All" || GROUP == "Quality"
-        include("quality_tests.jl")
-    end
-
-    if GROUP == "All" || GROUP == "Instability"
-        include("instability_tests.jl")
-    end
-end
+import AbstractDifferentiation,
+    ADTypes,
+    Aqua,
+    ComponentArrays,
+    ComputationalResources,
+    CUDA,
+    cuDNN,
+    DataFrames,
+    DifferentiationInterface,
+    Distributions,
+    ForwardDiff,
+    JET,
+    Logging,
+    Lux,
+    LuxCUDA,
+    MLJBase,
+    ReverseDiff,
+    SciMLBase,
+    TerminalLoggers,
+    Test,
+    Zygote,
+    ContinuousNormalizingFlows
+
+GROUP = get(ENV, "GROUP", "All")
+USE_GPU = get(ENV, "USE_GPU", "Yes") == "Yes"
+
+if (GROUP == "All")
+    GC.enable_logging(true)
+
+    debuglogger = TerminalLoggers.TerminalLogger(stderr, Logging.Debug)
+    Logging.global_logger(debuglogger)
+else
+    warnlogger = TerminalLoggers.TerminalLogger(stderr, Logging.Warn)
+    Logging.global_logger(warnlogger)
+end
+
+Test.@testset "Overall" begin
+    if GROUP == "All" ||
+       GROUP in ["RNODE", "FFJORD", "Planar", "CondRNODE", "CondFFJORD", "CondPlanar"]
+        CUDA.allowscalar() do
+            include("smoke_tests.jl")
+        end
+    end
+
+    if GROUP == "All" || GROUP == "Quality"
+        include("quality_tests.jl")
+    end
+
+    if GROUP == "All" || GROUP == "Instability"
+        include("instability_tests.jl")
+    end
+end
diff --git a/test/smoke_tests.jl b/test/smoke_tests.jl
index 01a9a0ff..8cce6d1a 100644
--- a/test/smoke_tests.jl
+++ b/test/smoke_tests.jl
@@ -1,4 +1,4 @@
-Test.@testset "Smoke Tests" begin
-    include("call_tests.jl")
-    include("fit_tests.jl")
-end
+Test.@testset "Smoke Tests" begin
+    include("call_tests.jl")
+    include("fit_tests.jl")
+end
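
A note on the planar-layer diff above: `PlanarLayer` implements z ↦ u .* h.(wᵀz + b), dropping the bias when `use_bias = false` and widening the weight vector `w` by `n_cond` entries when a conditioning input is used. A minimal sketch of calling it directly, assuming only what the diff itself shows (the input values are illustrative):

```julia
import ContinuousNormalizingFlows, Lux, Random

# use_bias defaults to true and n_cond to 0, so ps has fields u, w, and b here.
layer = ContinuousNormalizingFlows.PlanarLayer(2, tanh)
ps, st = Lux.setup(Random.default_rng(), layer)

z = randn(Float32, 2)      # vector method: u .* h.(dot(w, z) + b)
y, _ = layer(z, ps, st)

zs = randn(Float32, 2, 8)  # matrix method: one column per sample
ys, _ = layer(zs, ps, st)
```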
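
Likewise, `split_jac` in the utils.jl diff slices a block-diagonal batched Jacobian into its per-sample blocks. A toy illustration of the same indexing scheme, with made-up values (this is just the arithmetic, not the package API):

```julia
# For an (n*b)×(n*b) matrix whose n×n diagonal blocks are per-sample Jacobians,
# split_jac takes x[i:j, i:j] with i stepping by n and j = i + n - 1.
J = [1 2 0 0; 3 4 0 0; 0 0 5 6; 0 0 7 8]           # two 2×2 blocks
blocks = collect(J[i:(i + 1), i:(i + 1)] for i in 1:2:3)
# blocks[1] == [1 2; 3 4]; blocks[2] == [5 6; 7 8]
```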
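
Finally, `test/runtests.jl` above keys the whole suite off two environment variables: `GROUP` picks which testset files run, and `USE_GPU` must be exactly `"Yes"` for the CUDA resource to be exercised. Running a single group locally might look like this (the `Pkg.test` call is the standard one; the variable names and values come straight from the diff):

```julia
import Pkg

ENV["GROUP"] = "Instability"  # or "All", "Quality", "RNODE", "FFJORD", "Planar",
                              # "CondRNODE", "CondFFJORD", "CondPlanar"
ENV["USE_GPU"] = "No"         # anything but "Yes" keeps the tests on CPU1()

Pkg.test("ContinuousNormalizingFlows")  # the spawned test process inherits ENV
```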