diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 533b2529..3c08aeef 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,6 +26,7 @@ jobs: - CondRNODE - CondFFJORD - CondPlanar + - Regression version: - '1' # - '1.9' diff --git a/test/Project.toml b/test/Project.toml index db262567..877bdc2e 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -6,6 +6,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" @@ -15,6 +16,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -28,6 +30,7 @@ ComponentArrays = "0.15" ComputationalResources = "0.3" DataFrames = "1" DifferentiationInterface = "0.5" +Distances = "0.10" Distributions = "0.25" Enzyme = "0.12" GPUArraysCore = "0.1" @@ -36,6 +39,7 @@ Lux = "1" LuxCUDA = "0.3" MLJBase = "1" SciMLBase = "2" +StableRNGs = "1" TerminalLoggers = "0.1" Zygote = "0.6" cuDNN = "1" diff --git a/test/regression_tests.jl b/test/regression_tests.jl new file mode 100644 index 00000000..5393fc05 --- /dev/null +++ b/test/regression_tests.jl @@ -0,0 +1,44 @@ +Test.@testset "Regression Tests" begin + rng = StableRNGs.StableRNG(12345) + nvars = 2^3 + naugs = nvars + n_in = nvars + naugs + n = 2^10 + nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh)) + + icnf = ContinuousNormalizingFlows.construct( + ContinuousNormalizingFlows.RNODE, + nn, + nvars, + naugs; + compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()), + tspan = (0.0f0, 13.0f0), + steer_rate = 1.0f-1, + λ₃ = 1.0f-2, + rng, + ) + ps, st = Lux.setup(icnf.rng, icnf) + ps = ComponentArrays.ComponentArray(ps) + + data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0) + r = rand(icnf.rng, data_dist, nvars, n) + r = convert.(Float32, r) + + df = DataFrames.DataFrame(transpose(r), :auto) + model = ContinuousNormalizingFlows.ICNFModel(icnf) + + mach = MLJBase.machine(model, df) + MLJBase.fit!(mach) + + d = ContinuousNormalizingFlows.ICNFDist(mach, ContinuousNormalizingFlows.TestMode()) + actual_pdf = Distributions.pdf.(data_dist, r) + estimated_pdf = Distributions.pdf(d, r) + + mad_ = Distances.meanad(estimated_pdf, actual_pdf) + msd_ = Distances.msd(estimated_pdf, actual_pdf) + tv_dis = Distances.totalvariation(estimated_pdf, actual_pdf) / n + + Test.@test mad_ <= 1.0f-1 + Test.@test msd_ <= 1.0f-1 + Test.@test tv_dis <= 1.0f-1 +end diff --git a/test/runtests.jl b/test/runtests.jl index a6885118..c541f946 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,7 @@ import ADTypes, cuDNN, DataFrames, DifferentiationInterface, + Distances, Distributions, Enzyme, GPUArraysCore, @@ -15,6 +16,7 @@ import ADTypes, LuxCUDA, MLJBase, SciMLBase, + StableRNGs, TerminalLoggers, Test, Zygote, @@ -47,4 +49,8 @@ Test.@testset "Overall" begin if GROUP == "All" || GROUP == "Instability" include("instability_tests.jl") end + + if GROUP == "All" || GROUP == "Regression" + include("regression_tests.jl") + end end