Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full qualification by import #407

Merged
merged 21 commits into from
Apr 27, 2024
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ always_for_in = true
whitespace_typedefs = true
whitespace_ops_in_indices = true
remove_extra_newlines = true
import_to_using = true
import_to_using = false
pipe_to_function_call = true
short_to_long_function_def = true
long_to_short_function_def = false
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ jobs:
- name: Install dependencies
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add(["PkgBenchmark", "BenchmarkCI"])
- name: Run benchmarks
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.judge(; baseline="origin/main", retune=true, verbose=true)
- name: Print judgement
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.displayjudgement()
- name: Post results
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.postjudge()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ jobs:
- name: Pkg.add
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add("CompatHelper")
- name: CompatHelper.main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
shell: julia --color=yes {0}
run: |
using CompatHelper
import CompatHelper
CompatHelper.main(; include_jll = true, subdirs = ["", "docs", "test", "benchmark"])
11 changes: 5 additions & 6 deletions .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ jobs:
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
Pkg.develop(PackageSpec(path=pwd()))
import Pkg
Pkg.develop(Pkg.PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-docdeploy@v1
Expand All @@ -40,7 +40,6 @@ jobs:
- name: Run doctests
shell: julia --project=docs --color=yes {0}
run: |
using Documenter
using ContinuousNormalizingFlows
DocMeta.setdocmeta!(ContinuousNormalizingFlows, :DocTestSetup, :(using ContinuousNormalizingFlows); recursive=true)
doctest(ContinuousNormalizingFlows)
import Documenter, ContinuousNormalizingFlows
Documenter.DocMeta.setdocmeta!(ContinuousNormalizingFlows, :DocTestSetup, :(using ContinuousNormalizingFlows); recursive=true)
Documenter.doctest(ContinuousNormalizingFlows)
6 changes: 3 additions & 3 deletions .github/workflows/Formatter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
- name: Install JuliaFormatter and format
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add("JuliaFormatter")
using JuliaFormatter
format(".")
import JuliaFormatter
JuliaFormatter.format(".")

# https://github.com/marketplace/actions/create-pull-request
# https://github.com/peter-evans/create-pull-request#reference-example
Expand Down
8 changes: 6 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

Expand Down Expand Up @@ -65,10 +67,12 @@ Octavian = "0.3.27"
Optimisers = "0.3"
Optimization = "3.15"
OptimizationOptimisers = "0.2"
OrdinaryDiffEq = "6"
Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypes = "3"
ScientificTypesBase = "3"
Static = "0.8"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
99 changes: 63 additions & 36 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,89 +1,116 @@
using ContinuousNormalizingFlows,
ADTypes,
import ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Lux,
PkgBenchmark,
StableRNGs,
Zygote
Zygote,
ContinuousNormalizingFlows

SUITE = BenchmarkGroup()
SUITE = BenchmarkTools.BenchmarkGroup()

SUITE["main"] = BenchmarkGroup(["package", "simple"])
SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])

SUITE["main"]["no_inplace"] = BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkGroup(["inplace"])
SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])

SUITE["main"]["no_inplace"]["direct"] = BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkGroup(["gradient"])
SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

SUITE["main"]["inplace"]["direct"] = BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkGroup(["gradient"])
SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

rng = StableRNG(12345)
rng = StableRNGs.StableRNG(12345)
nvars = 2^3
naugs = nvars
n_in = nvars + naugs
n = 2^6
nn = Chain(Dense(n_in => n_in, tanh))
nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))

icnf = construct(
RNODE,
icnf = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.RNODE,
nn,
nvars,
naugs;
compute_mode = DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
rng,
)
ps, st = Lux.setup(icnf.rng, icnf)
ps = ComponentArray(ps)
ps = ComponentArrays.ComponentArray(ps)
r = rand(icnf.rng, Float32, nvars, n)

diff_loss_tn(x) = loss(icnf, TrainMode(), r, x, st)
diff_loss_tt(x) = loss(icnf, TestMode(), r, x, st)
function diff_loss_tn(x)
ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, x, st)
end
function diff_loss_tt(x)
ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TestMode(), r, x, st)
end

diff_loss_tn(ps)
diff_loss_tt(ps)
DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] =
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoZygote(),
ps,
)

icnf2 = construct(
RNODE,
icnf2 = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.RNODE,
nn,
nvars,
naugs;
inplace = true,
compute_mode = DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
rng,
)

diff_loss_tn2(x) = loss(icnf2, TrainMode(), r, x, st)
diff_loss_tt2(x) = loss(icnf2, TestMode(), r, x, st)
function diff_loss_tn2(x)
ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TrainMode(), r, x, st)
end
function diff_loss_tt2(x)
ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TestMode(), r, x, st)
end

diff_loss_tn2(ps)
diff_loss_tt2(ps)
DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn2,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt2,
ADTypes.AutoZygote(),
ps,
)
12 changes: 7 additions & 5 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
using ContinuousNormalizingFlows
using Documenter
import Documenter, ContinuousNormalizingFlows

DocMeta.setdocmeta!(
Documenter.DocMeta.setdocmeta!(
ContinuousNormalizingFlows,
:DocTestSetup,
:(using ContinuousNormalizingFlows);
recursive = true,
)

makedocs(;
Documenter.makedocs(;
modules = [ContinuousNormalizingFlows],
authors = "Hossein Pourbozorg <[email protected]> and contributors",
repo = "https://github.com/impICNF/ContinuousNormalizingFlows.jl/blob/{commit}{path}#{line}",
Expand All @@ -20,4 +19,7 @@ makedocs(;
pages = ["Home" => "index.md"],
)

deploydocs(; repo = "github.com/impICNF/ContinuousNormalizingFlows.jl", devbranch = "main")
Documenter.deploydocs(;
repo = "github.com/impICNF/ContinuousNormalizingFlows.jl",
devbranch = "main",
)
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
module ContinuousNormalizingFlowsCUDAExt

using ContinuousNormalizingFlows, CUDA
using ContinuousNormalizingFlows.ComputationalResources
import CUDA, ComputationalResources, ContinuousNormalizingFlows

@inline function ContinuousNormalizingFlows.rng_AT(::CUDALibs)
CURAND.default_rng()
@inline function ContinuousNormalizingFlows.rng_AT(::ComputationalResources.CUDALibs)
CUDA.CURAND.default_rng()

Check warning on line 6 in ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl#L5-L6

Added lines #L5 - L6 were not covered by tests
end

@inline function ContinuousNormalizingFlows.base_AT(
::CUDALibs,
::ComputationalResources.CUDALibs,
::ContinuousNormalizingFlows.AbstractICNF{T},
dims...,
) where {T <: AbstractFloat}
CuArray{T}(undef, dims...)
CUDA.CuArray{T}(undef, dims...)

Check warning on line 14 in ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/ContinuousNormalizingFlowsCUDAExt/ContinuousNormalizingFlowsCUDAExt.jl#L14

Added line #L14 was not covered by tests
end

end
41 changes: 36 additions & 5 deletions src/ContinuousNormalizingFlows.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module ContinuousNormalizingFlows

using AbstractDifferentiation,
import AbstractDifferentiation,
ADTypes,
Base.Iterators,
ChainRulesCore,
Expand All @@ -24,13 +24,40 @@ using AbstractDifferentiation,
Optimisers,
Optimization,
OptimizationOptimisers,
OrdinaryDiffEq,
Random,
ScientificTypes,
ScientificTypesBase,
Static,
SciMLBase,
SciMLSensitivity,
Statistics,
Zygote

export construct,
inference,
generate,
loss,
ICNF,
RNODE,
CondRNODE,
FFJORD,
CondFFJORD,
Planar,
CondPlanar,
TestMode,
TrainMode,
ADVecJacVectorMode,
ADJacVecVectorMode,
DIVecJacVectorMode,
DIJacVecVectorMode,
DIVecJacMatrixMode,
DIJacVecMatrixMode,
ICNFModel,
CondICNFModel,
CondLayer,
PlanarLayer,
MulLayer

include(joinpath("layers", "cond_layer.jl"))
include(joinpath("layers", "planar_layer.jl"))
include(joinpath("layers", "mul_layer.jl"))
Expand All @@ -43,9 +70,13 @@ include("icnf.jl")

include("utils.jl")

include(joinpath("cores", "core.jl"))
include(joinpath("cores", "core_icnf.jl"))
include(joinpath("cores", "core_cond_icnf.jl"))
include(joinpath("exts", "mlj_ext", "core.jl"))
include(joinpath("exts", "mlj_ext", "core_icnf.jl"))
include(joinpath("exts", "mlj_ext", "core_cond_icnf.jl"))

include(joinpath("exts", "dist_ext", "core.jl"))
include(joinpath("exts", "dist_ext", "core_icnf.jl"))
include(joinpath("exts", "dist_ext", "core_cond_icnf.jl"))

"""
Implementations of Infinitesimal Continuous Normalizing Flows Algorithms in Julia
Expand Down
Loading
Loading