Skip to content

Commit

Permalink
add Regression Tests (#428)
Browse files Browse the repository at this point in the history
* add Regression Tests

* fix

* fix

* fix

* change seed

* StableRNG 12345
  • Loading branch information
prbzrg authored Sep 8, 2024
1 parent f45226f commit b116e16
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ jobs:
- CondRNODE
- CondFFJORD
- CondPlanar
- Regression
version:
- '1'
# - '1.9'
Expand Down
4 changes: 4 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ComputationalResources = "ed09eef8-17a6-5b46-8889-db040fac31e3"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Expand All @@ -15,6 +16,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -28,6 +30,7 @@ ComponentArrays = "0.15"
ComputationalResources = "0.3"
DataFrames = "1"
DifferentiationInterface = "0.5"
Distances = "0.10"
Distributions = "0.25"
Enzyme = "0.12"
GPUArraysCore = "0.1"
Expand All @@ -36,6 +39,7 @@ Lux = "1"
LuxCUDA = "0.3"
MLJBase = "1"
SciMLBase = "2"
StableRNGs = "1"
TerminalLoggers = "0.1"
Zygote = "0.6"
cuDNN = "1"
Expand Down
44 changes: 44 additions & 0 deletions test/regression_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
Test.@testset "Regression Tests" begin
rng = StableRNGs.StableRNG(12345)
nvars = 2^3
naugs = nvars
n_in = nvars + naugs
n = 2^10
nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))

icnf = ContinuousNormalizingFlows.construct(
ContinuousNormalizingFlows.RNODE,
nn,
nvars,
naugs;
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
tspan = (0.0f0, 13.0f0),
steer_rate = 1.0f-1,
λ₃ = 1.0f-2,
rng,
)
ps, st = Lux.setup(icnf.rng, icnf)
ps = ComponentArrays.ComponentArray(ps)

data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
r = rand(icnf.rng, data_dist, nvars, n)
r = convert.(Float32, r)

df = DataFrames.DataFrame(transpose(r), :auto)
model = ContinuousNormalizingFlows.ICNFModel(icnf)

mach = MLJBase.machine(model, df)
MLJBase.fit!(mach)

d = ContinuousNormalizingFlows.ICNFDist(mach, ContinuousNormalizingFlows.TestMode())
actual_pdf = Distributions.pdf.(data_dist, r)
estimated_pdf = Distributions.pdf(d, r)

mad_ = Distances.meanad(estimated_pdf, actual_pdf)
msd_ = Distances.msd(estimated_pdf, actual_pdf)
tv_dis = Distances.totalvariation(estimated_pdf, actual_pdf) / n

Test.@test mad_ <= 1.0f-1
Test.@test msd_ <= 1.0f-1
Test.@test tv_dis <= 1.0f-1
end
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import ADTypes,
cuDNN,
DataFrames,
DifferentiationInterface,
Distances,
Distributions,
Enzyme,
GPUArraysCore,
Expand All @@ -15,6 +16,7 @@ import ADTypes,
LuxCUDA,
MLJBase,
SciMLBase,
StableRNGs,
TerminalLoggers,
Test,
Zygote,
Expand Down Expand Up @@ -47,4 +49,8 @@ Test.@testset "Overall" begin
if GROUP == "All" || GROUP == "Instability"
include("instability_tests.jl")
end

if GROUP == "All" || GROUP == "Regression"
include("regression_tests.jl")
end
end

0 comments on commit b116e16

Please sign in to comment.