Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Full qualification by import #407

Merged
merged 21 commits into from
Apr 27, 2024
2 changes: 1 addition & 1 deletion .JuliaFormatter.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ always_for_in = true
whitespace_typedefs = true
whitespace_ops_in_indices = true
remove_extra_newlines = true
import_to_using = true
import_to_using = false
pipe_to_function_call = true
short_to_long_function_def = true
long_to_short_function_def = false
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/Benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,22 @@ jobs:
- name: Install dependencies
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add(["PkgBenchmark", "BenchmarkCI"])
- name: Run benchmarks
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.judge(; baseline="origin/main", retune=true, verbose=true)
- name: Print judgement
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.displayjudgement()
- name: Post results
shell: julia --color=yes {0}
run: |
using BenchmarkCI
import BenchmarkCI
BenchmarkCI.postjudge()
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
4 changes: 2 additions & 2 deletions .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,13 @@ jobs:
- name: Pkg.add
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add("CompatHelper")
- name: CompatHelper.main
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
shell: julia --color=yes {0}
run: |
using CompatHelper
import CompatHelper
CompatHelper.main(; include_jll = true, subdirs = ["", "docs", "test", "benchmark"])
9 changes: 4 additions & 5 deletions .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
- name: Configure doc environment
shell: julia --project=docs --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.develop(PackageSpec(path=pwd()))
Pkg.instantiate()
- uses: julia-actions/julia-buildpkg@v1
Expand All @@ -40,7 +40,6 @@ jobs:
- name: Run doctests
shell: julia --project=docs --color=yes {0}
run: |
using Documenter
using ContinuousNormalizingFlows
DocMeta.setdocmeta!(ContinuousNormalizingFlows, :DocTestSetup, :(using ContinuousNormalizingFlows); recursive=true)
doctest(ContinuousNormalizingFlows)
import Documenter, ContinuousNormalizingFlows
Documenter.DocMeta.setdocmeta!(ContinuousNormalizingFlows, :DocTestSetup, :(using ContinuousNormalizingFlows); recursive=true)
Documenter.doctest(ContinuousNormalizingFlows)
6 changes: 3 additions & 3 deletions .github/workflows/Formatter.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@ jobs:
- name: Install JuliaFormatter and format
shell: julia --color=yes {0}
run: |
using Pkg
import Pkg
Pkg.add("JuliaFormatter")
using JuliaFormatter
format(".")
import JuliaFormatter
JuliaFormatter.format(".")

# https://github.com/marketplace/actions/create-pull-request
# https://github.com/peter-evans/create-pull-request#reference-example
Expand Down
11 changes: 8 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,21 @@ Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
ContinuousNormalizingFlowsCUDAExt = "CUDA"
ContinuousNormalizingFlowsCUDAExt = ["CUDA", "ComputationalResources"]
ContinuousNormalizingFlowsDistributionsExt = "Distributions"

[compat]
ADTypes = "0.2, 1"
Expand All @@ -65,10 +68,12 @@ Octavian = "0.3.27"
Optimisers = "0.3"
Optimization = "3.15"
OptimizationOptimisers = "0.2"
OrdinaryDiffEq = "6"
Random = "1"
SciMLBase = "2"
SciMLSensitivity = "7"
ScientificTypes = "3"
ScientificTypesBase = "3"
Static = "0.8"
Statistics = "1"
Zygote = "0.6"
julia = "1.9"
99 changes: 63 additions & 36 deletions benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
@@ -1,89 +1,116 @@
using ContinuousNormalizingFlows,
ADTypes,
import ADTypes,
BenchmarkTools,
ComponentArrays,
DifferentiationInterface,
Lux,
PkgBenchmark,
StableRNGs,
Zygote
Zygote,
ContinuousNormalizingFlows

SUITE = BenchmarkGroup()
SUITE = BenchmarkTools.BenchmarkGroup()

SUITE["main"] = BenchmarkGroup(["package", "simple"])
SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])

SUITE["main"]["no_inplace"] = BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkGroup(["inplace"])
SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])

SUITE["main"]["no_inplace"]["direct"] = BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkGroup(["gradient"])
SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

SUITE["main"]["inplace"]["direct"] = BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkGroup(["gradient"])
SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])

rng = StableRNG(12345)
rng = StableRNGs.StableRNG(12345)
nvars = 2^3
naugs = nvars
n_in = nvars + naugs
n = 2^6
nn = Chain(Dense(n_in => n_in, tanh))
nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))

icnf = construct(
RNODE,
icnf = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.RNODE,
nn,
nvars,
naugs;
compute_mode = DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
rng,
)
ps, st = Lux.setup(icnf.rng, icnf)
ps = ComponentArray(ps)
ps = ComponentArrays.ComponentArray(ps)
r = rand(icnf.rng, Float32, nvars, n)

diff_loss_tn(x) = loss(icnf, TrainMode(), r, x, st)
diff_loss_tt(x) = loss(icnf, TestMode(), r, x, st)
function diff_loss_tn(x)
ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TrainMode(), r, x, st)
end
function diff_loss_tt(x)
ContinuousNormalizingFlows.loss(icnf, ContinuousNormalizingFlows.TestMode(), r, x, st)
end

diff_loss_tn(ps)
diff_loss_tt(ps)
DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["no_inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
SUITE["main"]["no_inplace"]["direct"]["test"] =
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt,
ADTypes.AutoZygote(),
ps,
)

icnf2 = construct(
RNODE,
icnf2 = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.RNODE,
nn,
nvars,
naugs;
inplace = true,
compute_mode = DIVecJacMatrixMode,
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode,
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
rng,
)

diff_loss_tn2(x) = loss(icnf2, TrainMode(), r, x, st)
diff_loss_tt2(x) = loss(icnf2, TestMode(), r, x, st)
function diff_loss_tn2(x)
ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TrainMode(), r, x, st)
end
function diff_loss_tt2(x)
ContinuousNormalizingFlows.loss(icnf2, ContinuousNormalizingFlows.TestMode(), r, x, st)
end

diff_loss_tn2(ps)
diff_loss_tt2(ps)
DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
GC.gc()

SUITE["main"]["inplace"]["direct"]["train"] = @benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = @benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["direct"]["train"] =
BenchmarkTools.@benchmarkable diff_loss_tn2(ps)
SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps)
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tn2, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tn2,
ADTypes.AutoZygote(),
ps,
)
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
@benchmarkable DifferentiationInterface.gradient(diff_loss_tt2, AutoZygote(), ps)
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
diff_loss_tt2,
ADTypes.AutoZygote(),
ps,
)
12 changes: 7 additions & 5 deletions docs/make.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
using ContinuousNormalizingFlows
using Documenter
import Documenter, ContinuousNormalizingFlows

DocMeta.setdocmeta!(
Documenter.DocMeta.setdocmeta!(
ContinuousNormalizingFlows,
:DocTestSetup,
:(using ContinuousNormalizingFlows);
recursive = true,
)

makedocs(;
Documenter.makedocs(;
modules = [ContinuousNormalizingFlows],
authors = "Hossein Pourbozorg <[email protected]> and contributors",
repo = "https://github.com/impICNF/ContinuousNormalizingFlows.jl/blob/{commit}{path}#{line}",
Expand All @@ -20,4 +19,7 @@ makedocs(;
pages = ["Home" => "index.md"],
)

deploydocs(; repo = "github.com/impICNF/ContinuousNormalizingFlows.jl", devbranch = "main")
Documenter.deploydocs(;
repo = "github.com/impICNF/ContinuousNormalizingFlows.jl",
devbranch = "main",
)
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
module ContinuousNormalizingFlowsCUDAExt

using ContinuousNormalizingFlows, CUDA
using ContinuousNormalizingFlows.ComputationalResources
import CUDA, ComputationalResources, ContinuousNormalizingFlows

@inline function ContinuousNormalizingFlows.rng_AT(::CUDALibs)
CURAND.default_rng()
@inline function ContinuousNormalizingFlows.rng_AT(::ComputationalResources.CUDALibs)
CUDA.CURAND.default_rng()
end

@inline function ContinuousNormalizingFlows.base_AT(
::CUDALibs,
::CUDA.CUDALibs,
::ContinuousNormalizingFlows.AbstractICNF{T},
dims...,
) where {T <: AbstractFloat}
CuArray{T}(undef, dims...)
CUDA.CuArray{T}(undef, dims...)
end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module ContinuousNormalizingFlowsDistributionsExt

import Distributions, ContinuousNormalizingFlows

export ICNFDist, CondICNFDist

include("core.jl")
include("core_icnf.jl")
include("core_cond_icnf.jl")

end
prbzrg marked this conversation as resolved.
Show resolved Hide resolved
16 changes: 16 additions & 0 deletions ext/ContinuousNormalizingFlowsDistributionsExt/core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
abstract type ICNFDistribution{AICNF <: ContinuousNormalizingFlows.AbstractICNF} <:
ContinuousMultivariateDistribution end

function Base.length(d::ICNFDistribution)
d.m.nvars
end

function Base.eltype(
::ICNFDistribution{AICNF},
) where {AICNF <: ContinuousNormalizingFlows.AbstractICNF}
first(AICNF.parameters)
end

function Base.broadcastable(d::ICNFDistribution)
Ref(d)
end
prbzrg marked this conversation as resolved.
Show resolved Hide resolved
Loading
Loading