From 2d2ddbd1c67f52ae61a7be730938da3d8ad4eef1 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Mon, 18 Nov 2024 16:13:35 +0000 Subject: [PATCH] Fix Test Loads (#374) * Change how tests are run * Update action * Fix CI formatting * Fix CI * Fix CI file * Load up Test globally * Fix labelling * Fix a couple of load paths * Improve runner * Fix CI * Fix typo * Fix up CI * Do not specify project * Fix typo * Fix CI * Fix CI * Fix CI * More fixes * Finish rename and fix loading * Fix loading for main tests * Move stuff around * Fix CI * Do not load up front_matter twice * Fix LuxLib tests * Hopefully fix buildkite * Qualify bitcast uses * Qualify more globalrefs * Rule for compilerbarrier * Comment out failing test + leave issue number reference * Maybe fix buildkite * Hmmm * Correct file naming * Hopefully run codecov conditionally * Hopefully fix CI * Fix CUDA runner * Use JET in CUDA extension tests * Enable coverage from buildkite * Split up rules testing * Move coverage plugin to single command section * Update the docs * Tweak pipeline labelling * Bump patch version --- .buildkite/pipeline.yml | 5 +- .github/workflows/CI.yml | 100 +++++++++++++----- Project.toml | 2 +- docs/src/running_tests_locally.md | 13 ++- src/Mooncake.jl | 2 +- src/rrules/builtins.jl | 8 +- test/ext/cuda/Project.toml | 2 + test/ext/cuda/cuda.jl | 7 +- .../{di.jl => differentiation_interface.jl} | 0 test/ext/dynamic_ppl/Project.toml | 1 + test/ext/dynamic_ppl/dynamic_ppl.jl | 6 +- test/ext/luxlib/Project.toml | 2 + test/ext/luxlib/luxlib.jl | 5 +- test/ext/nnlib/Project.toml | 2 + test/ext/nnlib/nnlib.jl | 20 +++- test/ext/special_functions/Project.toml | 3 + .../special_functions/special_functions.jl | 7 +- test/integration_testing/array/Project.toml | 5 + test/integration_testing/{ => array}/array.jl | 15 ++- .../battery_tests/Project.toml | 7 ++ .../{ => battery_tests}/battery_tests.jl | 21 ++-- .../bijectors/Project.toml | 2 +- .../bijectors/bijectors.jl | 19 ++-- test/integration_testing/diff_tests.jl | 11 -- .../diff_tests/Project.toml | 5 + .../diff_tests/diff_tests.jl | 18 ++++ .../distributions/Project.toml | 3 + .../distributions/distributions.jl | 18 +++- test/integration_testing/gp/Project.toml | 2 + test/integration_testing/gp/gp.jl | 9 +- .../logexpfunctions/Project.toml | 2 + .../logexpfunctions/logexpfunctions.jl | 21 ++-- test/integration_testing/lux/Project.toml | 1 + test/integration_testing/lux/lux.jl | 10 +- .../misc_abstract_array/Project.toml | 6 ++ .../misc_abstract_array.jl | 21 ++-- .../temporalgps/Project.toml | 1 + .../temporalgps/temporalgps.jl | 7 +- test/integration_testing/turing/Project.toml | 1 + test/integration_testing/turing/turing.jl | 8 +- test/run_extra.jl | 3 + test/runtests.jl | 99 ++++++----------- 42 files changed, 321 insertions(+), 179 deletions(-) rename test/ext/differentiation_interface/{di.jl => differentiation_interface.jl} (100%) create mode 100644 test/integration_testing/array/Project.toml rename test/integration_testing/{ => array}/array.jl (98%) create mode 100644 test/integration_testing/battery_tests/Project.toml rename test/integration_testing/{ => battery_tests}/battery_tests.jl (64%) delete mode 100644 test/integration_testing/diff_tests.jl create mode 100644 test/integration_testing/diff_tests/Project.toml create mode 100644 test/integration_testing/diff_tests/diff_tests.jl create mode 100644 test/integration_testing/misc_abstract_array/Project.toml rename test/integration_testing/{ => misc_abstract_array}/misc_abstract_array.jl (82%) create mode 100644 test/run_extra.jl diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index af00de5d7..5eed9ba0c 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -6,15 +6,16 @@ steps: plugins: - JuliaCI/julia#v1: version: "1" - - JuliaCI/julia-test#v1: ~ - JuliaCI/julia-coverage#v1: dirs: - src - ext + command: julia --code-coverage=user --eval='include("test/run_extra.jl")' agents: queue: "juliagpu" cuda: "*" if: build.message !~ /\[skip tests\]/ timeout_in_minutes: 60 env: - TEST_GROUP: "gpu" + LABEL: cuda + TEST_TYPE: ext diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index e82fe22b2..ea578eae3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -20,26 +20,25 @@ jobs: strategy: fail-fast: false matrix: - test_group: - - 'aqua' - - 'basic' - - 'rrules' - - 'ext/differentiation_interface' - - 'ext/dynamic_ppl' - - 'ext/luxlib' - - 'ext/nnlib' - - 'ext/special_functions' - - 'integration_testing/array' - - 'integration_testing/bijectors' - - 'integration_testing/diff_tests' - - 'integration_testing/distributions' - - 'integration_testing/gp' - - 'integration_testing/logexpfunctions' - - 'integration_testing/lux' - - 'integration_testing/misc' - - 'integration_testing/misc_abstract_array' - - 'integration_testing/temporalgps' - - 'integration_testing/turing' + test_group: [ + 'aqua', + 'basic', + 'rrules/avoiding_non_differentiable_code', + 'rrules/blas', + 'rrules/builtins', + 'rrules/fastmath', + 'rrules/foreigncall', + 'rrules/functionwrappers', + 'rrules/iddict', + 'rrules/lapack', + 'rrules/linear_algebra', + 'rrules/low_level_maths', + 'rrules/memory', + 'rrules/misc', + 'rrules/new', + 'rrules/tasks', + 'rrules/twice_precision', + ] version: - '1' arch: @@ -48,13 +47,9 @@ jobs: - test_group: 'basic' version: '1.10' arch: x64 - - test_group: 'integration_testing/turing' - version: '1.10' - arch: x64 - test_group: 'basic' version: '1.10' arch: x86 - steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -73,6 +68,63 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false + extra: + name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }} + runs-on: ubuntu-latest + if: github.event_name != 'schedule' + strategy: + fail-fast: false + matrix: + test_group: [ + {test_type: 'ext', label: 'differentiation_interface'}, + {test_type: 'ext', label: 'dynamic_ppl'}, + {test_type: 'ext', label: 'luxlib'}, + {test_type: 'ext', label: 'nnlib'}, + {test_type: 'ext', label: 'special_functions'}, + {test_type: 'integration_testing', label: 'array'}, + {test_type: 'integration_testing', label: 'bijectors'}, + {test_type: 'integration_testing', label: 'diff_tests'}, + {test_type: 'integration_testing', label: 'distributions'}, + {test_type: 'integration_testing', label: 'gp'}, + {test_type: 'integration_testing', label: 'logexpfunctions'}, + {test_type: 'integration_testing', label: 'lux'}, + {test_type: 'integration_testing', label: 'battery_tests'}, + {test_type: 'integration_testing', label: 'misc_abstract_array'}, + {test_type: 'integration_testing', label: 'temporalgps'}, + {test_type: 'integration_testing', label: 'turing'}, + ] + version: + - '1' + arch: + - x64 + include: + - test_group: {test_type: 'integration_testing', label: 'turing'} + version: '1.10' + arch: x64 + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + arch: ${{ matrix.arch }} + include-all-prereleases: false + - uses: julia-actions/cache@v2 + - uses: julia-actions/julia-buildpkg@v1 + - run: | + if [ ${{ matrix.test_group.test_type }} == 'ext' ]; then + julia --code-coverage=user --eval 'include("test/run_extra.jl")' + else + julia --eval 'include("test/run_extra.jl")' + fi + env: + LABEL: ${{ matrix.test_group.label }} + TEST_TYPE: ${{ matrix.test_group.test_type }} + - uses: julia-actions/julia-processcoverage@v1 + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + fail_ci_if_error: false perf: name: "Performance (${{ matrix.perf_group }})" runs-on: ubuntu-latest diff --git a/Project.toml b/Project.toml index 7b6ad3b0d..9a39d6c86 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.43" +version = "0.4.44" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/docs/src/running_tests_locally.md b/docs/src/running_tests_locally.md index 1b44deee0..755b46ef6 100644 --- a/docs/src/running_tests_locally.md +++ b/docs/src/running_tests_locally.md @@ -1,6 +1,8 @@ # Running Tests Locally -Mooncake.jl's test suite is fairly extensive. While you can use `Pkg.test` to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl. When editing some code, you typically only want to run the tests associated with it, not the entire test suite. +Mooncake.jl's test suite is fairly extensive. While you can use `Pkg.test` to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite. + +## Main Testing Functionality Mooncake's tests are organised as follows: 1. Things that are required for most / all test suites are loaded up in `test/front_matter.jl`. @@ -21,3 +23,12 @@ The purpose of this approach is to: If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice. + +## Extension and Integration Testing + +Mooncake now has quite a lot of package extensions, and a large number of integration tests. +Unfortunately, these come with a lot of additional dependencies. +To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the `test/ext` and `test/integration_testing` directories respectively. + +These directories comprise a single `.jl` file, and a `Project.toml`. +You should run these tests by simply `include`ing the `.jl` file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests. diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 6cd360006..e8bb0996f 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -24,7 +24,7 @@ using Base.Experimental: @opaque using Base.Iterators: product using Core: Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode, - PiNode, SSAValue, Argument, OpaqueClosure + PiNode, SSAValue, Argument, OpaqueClosure, compilerbarrier using Core.Compiler: IRCode, NewInstruction using Core.Intrinsics: pointerref, pointerset using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm! diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 4e0ed19a1..912c72d80 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -494,7 +494,11 @@ function rrule!!(f::CoDual{typeof(Core.apply_type)}, args...) return CoDual{_typeof(T), NoFData}(T, NoFData()), NoPullback(f, args...) end -# Core.compilerbarrier +function rrule!!(::CoDual{typeof(compilerbarrier)}, setting::CoDual{Symbol}, val::CoDual) + compilerbarrier_pb(dout) = NoRData(), NoRData(), dout + return compilerbarrier(setting.x, val), compilerbarrier_pb +end + # Core.donotdelete # Core.finalizer # Core.get_binding_type @@ -801,7 +805,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) (true, :stability, nothing, Core._typevar, :T, Union{}, Any), (false, :none, _range, Core.apply_type, Vector, Float64), (false, :none, _range, Core.apply_type, Array, Float64, 2), - # Core.compilerbarrier -- NEEDS IMPLEMENTING AND TESTING + (false, :none, nothing, compilerbarrier, :type, 5.0), # Core.const_arrayref -- NEEDS IMPLEMENTING AND TESTING # Core.donotdelete -- NEEDS IMPLEMENTING AND TESTING # Core.finalizer -- NEEDS IMPLEMENTING AND TESTING diff --git a/test/ext/cuda/Project.toml b/test/ext/cuda/Project.toml index f674a7383..d387b554c 100644 --- a/test/ext/cuda/Project.toml +++ b/test/ext/cuda/Project.toml @@ -1,4 +1,6 @@ [deps] CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/ext/cuda/cuda.jl b/test/ext/cuda/cuda.jl index dc4a14a38..82e4fc9a5 100644 --- a/test/ext/cuda/cuda.jl +++ b/test/ext/cuda/cuda.jl @@ -2,19 +2,20 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using CUDA, Mooncake, Test +using CUDA, JET, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_tangent, test_rule @testset "cuda" begin # Check we can operate on CuArrays. test_tangent( - Xoshiro(123456), CuArray{Float32, 2, CUDA.DeviceMemory}(undef, 8, 8); + StableRNG(123456), CuArray{Float32, 2, CUDA.DeviceMemory}(undef, 8, 8); interface_only=false, ) # Check we can instantiate a CuArray. test_rule( - sr(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; + StableRNG(123456), CuArray{Float32, 1, CUDA.DeviceMemory}, undef, 256; interface_only=true, is_primitive=true, debug_mode=true, ) end diff --git a/test/ext/differentiation_interface/di.jl b/test/ext/differentiation_interface/differentiation_interface.jl similarity index 100% rename from test/ext/differentiation_interface/di.jl rename to test/ext/differentiation_interface/differentiation_interface.jl diff --git a/test/ext/dynamic_ppl/Project.toml b/test/ext/dynamic_ppl/Project.toml index c7862b1e9..a30309cad 100644 --- a/test/ext/dynamic_ppl/Project.toml +++ b/test/ext/dynamic_ppl/Project.toml @@ -1,4 +1,5 @@ [deps] DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/ext/dynamic_ppl/dynamic_ppl.jl b/test/ext/dynamic_ppl/dynamic_ppl.jl index c2098be0c..8d58636c5 100644 --- a/test/ext/dynamic_ppl/dynamic_ppl.jl +++ b/test/ext/dynamic_ppl/dynamic_ppl.jl @@ -2,8 +2,10 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using DynamicPPL, Mooncake, Test +using DynamicPPL, Mooncake, StableRNGs, Test +using DynamicPPL: istrans, VarInfo +using Mooncake.TestUtils: test_rule @testset "DynamicPPLMooncakeExt" begin - test_rule(sr(123456), DynamicPPL.istrans, DynamicPPL.VarInfo(); unsafe_perturb=true) + test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true) end diff --git a/test/ext/luxlib/Project.toml b/test/ext/luxlib/Project.toml index 2e39ac752..97fc1d149 100644 --- a/test/ext/luxlib/Project.toml +++ b/test/ext/luxlib/Project.toml @@ -1,8 +1,10 @@ [deps] +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] diff --git a/test/ext/luxlib/luxlib.jl b/test/ext/luxlib/luxlib.jl index 933bfe09f..11e77e393 100644 --- a/test/ext/luxlib/luxlib.jl +++ b/test/ext/luxlib/luxlib.jl @@ -2,8 +2,9 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Lux, LuxLib, Mooncake, NNlib, Test +using JET, Lux, LuxLib, Mooncake, NNlib, StableRNGs, Test using LuxLib.Impl: SLEEFActivations +using Mooncake.TestUtils: test_rule @testset "luxlib" begin @testset "$(typeof(fargs))" for (interface_only, perf_flag, is_primitive, fargs...) in vcat( @@ -65,6 +66,6 @@ using LuxLib.Impl: SLEEFActivations ) end), ) - test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) + test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/ext/nnlib/Project.toml b/test/ext/nnlib/Project.toml index 6c5b16304..e764674ef 100644 --- a/test/ext/nnlib/Project.toml +++ b/test/ext/nnlib/Project.toml @@ -1,4 +1,6 @@ [deps] +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/ext/nnlib/nnlib.jl b/test/ext/nnlib/nnlib.jl index 2c38ccea3..39fae7fb8 100644 --- a/test/ext/nnlib/nnlib.jl +++ b/test/ext/nnlib/nnlib.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Mooncake, NNlib, Test +using JET, Mooncake, NNlib, StableRNGs, Test +using Mooncake.TestUtils: test_rule using NNlib: dropout @testset "nnlib" begin @@ -27,9 +28,18 @@ using NNlib: dropout (false, :none, true, batched_mul, randn(3, 2, 3), randn(2, 5, 3)), # dropout - (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=1), randn(2, 2), 0.5), - (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=2), randn(2, 2), 0.1), - (true, :none, false, (x, p) -> dropout(sr(1), x, p; dims=(1, 2)), randn(2, 2), 0.4), + ( + true, :none, false, + (x, p) -> dropout(StableRNG(1), x, p; dims=1), randn(2, 2), 0.5, + ), + ( + true, :none, false, + (x, p) -> dropout(StableRNG(1), x, p; dims=2), randn(2, 2), 0.1, + ), + ( + true, :none, false, + (x, p) -> dropout(StableRNG(1), x, p; dims=(1, 2)), randn(2, 2), 0.4, + ), # softmax (false, :stability, true, softmax, randn(2)), @@ -106,6 +116,6 @@ using NNlib: dropout (false, :none, false, x -> pad_constant(x, 1, 2.0; dims=:), x), ] @info "$(typeof(fargs))" - test_rule(sr(1), fargs...; perf_flag, is_primitive, interface_only) + test_rule(StableRNG(123), fargs...; perf_flag, is_primitive, interface_only) end end diff --git a/test/ext/special_functions/Project.toml b/test/ext/special_functions/Project.toml index 70ae9dcfe..81fe63f5c 100644 --- a/test/ext/special_functions/Project.toml +++ b/test/ext/special_functions/Project.toml @@ -1,4 +1,7 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/ext/special_functions/special_functions.jl b/test/ext/special_functions/special_functions.jl index 2b7b715ea..5f5fffac7 100644 --- a/test/ext/special_functions/special_functions.jl +++ b/test/ext/special_functions/special_functions.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Mooncake, SpecialFunctions, Test +using AllocCheck, JET, Mooncake, SpecialFunctions, StableRNGs, Test +using Mooncake.TestUtils: test_rule # Rules in this file are only lightly tester, because they are all just @from_rrule rules. @testset "special_functions" begin @@ -44,7 +45,7 @@ using Mooncake, SpecialFunctions, Test (:stability_and_allocs, ellipe, 0.3), (:stability_and_allocs, logfactorial, 3), ] - test_rule(Xoshiro(123456), f, x...; perf_flag) + test_rule(StableRNG(123456), f, x...; perf_flag) end @testset for (perf_flag, f, x...) in [ (:allocs, logerf, 0.3, 0.5), # first branch @@ -61,6 +62,6 @@ using Mooncake, SpecialFunctions, Test (:allocs, SpecialFunctions.loggamma1p, -0.3), (:none, SpecialFunctions.lambdaeta, 5.0), ] - test_rule(Xoshiro(123456), f, x...; perf_flag, is_primitive=false) + test_rule(StableRNG(123456), f, x...; perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/array/Project.toml b/test/integration_testing/array/Project.toml new file mode 100644 index 000000000..c91a24417 --- /dev/null +++ b/test/integration_testing/array/Project.toml @@ -0,0 +1,5 @@ +[deps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/array.jl b/test/integration_testing/array/array.jl similarity index 98% rename from test/integration_testing/array.jl rename to test/integration_testing/array/array.jl index 7b114dcac..84c7fc4a9 100644 --- a/test/integration_testing/array.jl +++ b/test/integration_testing/array/array.jl @@ -1,3 +1,12 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using LinearAlgebra, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule + +sr(n::Int) = StableRNG(n) + _getter() = 5.0 @testset "array" begin test_cases = vcat( @@ -17,7 +26,7 @@ _getter() = 5.0 (false, *, randn(sr(19)), randn(sr(20), 2), transpose(randn(sr(18), 2, 1))), (false, *, adjoint(randn(sr(22), 2, 2)), randn(sr(21), 2)), (false, *, Diagonal(randn(sr(23), 2)), randn(sr(24), 2)), - (false, *, randn(sr(27), 2)', Diagonal(randn(sr(26), 2)), randn(sr(25), 2)), + # (false, *, randn(sr(27), 2)', Diagonal(randn(sr(26), 2)), randn(sr(25), 2)), https://github.com/compintell/Mooncake.jl/issues/319 (false, *, randn(sr(28), 2, 3)', randn(sr(29), 2)), (false, *, 4.0 * I, randn(sr(30), 2)), (false, *, 3.5 * I, randn(sr(31), 2, 3)), @@ -393,7 +402,7 @@ _getter() = 5.0 ], vec(reduce( vcat, - map(product( + map(Iterators.product( [adjoint(randn(sr(0), 2, 3)), transpose(randn(sr(1), 2, 3))], [randn(sr(3), 2), randn(sr(2), 2, 3)], [randn(sr(4)), randn(sr(5), 1), randn(sr(6), 3)], @@ -501,7 +510,7 @@ _getter() = 5.0 ] ) @testset for (interface_only, f, x...) in test_cases - @info _typeof((f, x...)) + @info typeof((f, x...)) test_rule(sr(123456), f, x...; interface_only, is_primitive=false, debug_mode=false) end end diff --git a/test/integration_testing/battery_tests/Project.toml b/test/integration_testing/battery_tests/Project.toml new file mode 100644 index 000000000..625c65412 --- /dev/null +++ b/test/integration_testing/battery_tests/Project.toml @@ -0,0 +1,7 @@ +[deps] +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/battery_tests.jl b/test/integration_testing/battery_tests/battery_tests.jl similarity index 64% rename from test/integration_testing/battery_tests.jl rename to test/integration_testing/battery_tests/battery_tests.jl index cb319a162..caafbcdf8 100644 --- a/test/integration_testing/battery_tests.jl +++ b/test/integration_testing/battery_tests/battery_tests.jl @@ -1,3 +1,10 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using JET, LinearAlgebra, Mooncake, Random, StableRNGs, Test +using Mooncake: TestResources + @testset "battery_tests" begin @testset "$(typeof(p))" for p in vcat( [ @@ -29,18 +36,18 @@ view(randn(3, 3), 1:2, 1:1), Xoshiro(123456), Ref(5.0), - StructFoo(5.0, randn(5)), - MutableFoo(5.0, randn(5)), - TypeStableMutableStruct{Vector{Float64}}(5.0, randn(5)), - StructFoo(5.0), - MutableFoo(5.0), - TypeStableMutableStruct{Vector{Float64}}(5.0), + TestResources.StructFoo(5.0, randn(5)), + TestResources.MutableFoo(5.0, randn(5)), + TestResources.TypeStableMutableStruct{Vector{Float64}}(5.0, randn(5)), + TestResources.StructFoo(5.0), + TestResources.MutableFoo(5.0), + TestResources.TypeStableMutableStruct{Vector{Float64}}(5.0), LowerTriangular(randn(3, 3)), UpperTriangular(randn(3, 3)), UnitLowerTriangular(randn(3, 3)), UnitUpperTriangular(randn(2, 2)), ] ) - TestUtils.test_data(sr(123), p) + Mooncake.TestUtils.test_data(StableRNG(123), p) end end diff --git a/test/integration_testing/bijectors/Project.toml b/test/integration_testing/bijectors/Project.toml index 54a732057..8f0df0b2f 100644 --- a/test/integration_testing/bijectors/Project.toml +++ b/test/integration_testing/bijectors/Project.toml @@ -2,4 +2,4 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" diff --git a/test/integration_testing/bijectors/bijectors.jl b/test/integration_testing/bijectors/bijectors.jl index c53b4f266..54d67bc44 100644 --- a/test/integration_testing/bijectors/bijectors.jl +++ b/test/integration_testing/bijectors/bijectors.jl @@ -2,9 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) -using Bijectors: Bijectors, inverse -using LinearAlgebra: LinearAlgebra -using Random: randn +using Bijectors, LinearAlgebra, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule """ Type for specifying a test case for `test_rule`. @@ -21,7 +20,7 @@ TestCase(f, arg; name = nothing, broken=false) = TestCase(f, arg, name, broken) """ A helper function that returns a TestCase that evaluates bijector(inverse(bijector)(x)) """ -function b_binv_test_case(bijector, dim; name = nothing, rng = Xoshiro(23)) +function b_binv_test_case(bijector, dim; name = nothing, rng = StableRNG(23)) if name === nothing name = string(bijector) end @@ -73,7 +72,7 @@ end binv = Bijectors.inverse(b) return binv(b(x)) end, - randn(Xoshiro(23)); + randn(StableRNG(23)); name = "RationalQuadraticSpline on scalar", ), TestCase( @@ -82,7 +81,7 @@ end binv = Bijectors.inverse(b) return binv(b(x)) end, - randn(Xoshiro(23), 7); + randn(StableRNG(23), 7); name = "OrderedBijector", ), TestCase( @@ -96,7 +95,7 @@ end return Bijectors.logpdf(flow.dist, x) - Bijectors.logabsdetjac(flow.transform, x) end, - randn(Xoshiro(23), 7); + randn(StableRNG(23), 7); name = "PlanarLayer7", # TODO(mhauru) Broken on v1.11 due to # https://github.com/compintell/Mooncake.jl/issues/319 @@ -115,7 +114,7 @@ end Bijectors.logabsdetjac(flow.transform, x), ) end, - randn(Xoshiro(23), 11); + randn(StableRNG(23), 11); name = "PlanarLayer11", ), ] @@ -123,11 +122,11 @@ end @testset "$(case.name)" for case in test_cases if case.broken @test_broken begin - test_rule(Xoshiro(123456), case.func, case.arg; is_primitive=false) + test_rule(StableRNG(123456), case.func, case.arg; is_primitive=false) true end else - rng = Xoshiro(123456) + rng = StableRNG(123456) test_rule(rng, case.func, case.arg; is_primitive=false, unsafe_perturb=true) end end diff --git a/test/integration_testing/diff_tests.jl b/test/integration_testing/diff_tests.jl deleted file mode 100644 index aa0bf95bf..000000000 --- a/test/integration_testing/diff_tests.jl +++ /dev/null @@ -1,11 +0,0 @@ -@testset "diff_tests" begin - @testset "$f, $(_typeof(x))" for (n, (interface_only, f, x...)) in enumerate(vcat( - TestResources.DIFFTESTS_FUNCTIONS[1:6], # skipping DiffTests.num2arr_1. See https://github.com/JuliaLang/julia/issues/56193 - TestResources.DIFFTESTS_FUNCTIONS[8:66], # skipping sparse_ldiv - TestResources.DIFFTESTS_FUNCTIONS[68:89], # skipping sparse_ldiv - TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv - )) - @info "$n: $(_typeof((f, x...)))" - test_rule(sr(123456), f, x...; is_primitive=false) - end -end diff --git a/test/integration_testing/diff_tests/Project.toml b/test/integration_testing/diff_tests/Project.toml new file mode 100644 index 000000000..9ff6746f7 --- /dev/null +++ b/test/integration_testing/diff_tests/Project.toml @@ -0,0 +1,5 @@ +[deps] +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/diff_tests/diff_tests.jl b/test/integration_testing/diff_tests/diff_tests.jl new file mode 100644 index 000000000..228352112 --- /dev/null +++ b/test/integration_testing/diff_tests/diff_tests.jl @@ -0,0 +1,18 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using Mooncake, Random, StableRNGs, Test +using Mooncake.TestUtils: test_rule + +@testset "diff_tests" begin + @testset "$f, $(typeof(x))" for (n, (interface_only, f, x...)) in enumerate(vcat( + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[1:6], # skipping DiffTests.num2arr_1. See https://github.com/JuliaLang/julia/issues/56193 + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[8:66], # skipping sparse_ldiv + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[68:89], # skipping sparse_ldiv + Mooncake.TestResources.DIFFTESTS_FUNCTIONS[91:end], # skipping sparse_ldiv + )) + @info "$n: $(typeof((f, x...)))" + test_rule(StableRNG(123456), f, x...; is_primitive=false) + end +end diff --git a/test/integration_testing/distributions/Project.toml b/test/integration_testing/distributions/Project.toml index fdadd3808..cb50ea6a0 100644 --- a/test/integration_testing/distributions/Project.toml +++ b/test/integration_testing/distributions/Project.toml @@ -1,7 +1,10 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/distributions/distributions.jl b/test/integration_testing/distributions/distributions.jl index afaf3919e..19d72fa37 100644 --- a/test/integration_testing/distributions/distributions.jl +++ b/test/integration_testing/distributions/distributions.jl @@ -2,10 +2,22 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Distributions, FillArrays, Mooncake, LinearAlgebra, PDMats, Test +using + AllocCheck, + JET, + Distributions, + FillArrays, + Mooncake, + LinearAlgebra, + PDMats, + StableRNGs, + Test + +using Mooncake.TestUtils: test_rule _sym(A) = A'A _pdmat(A) = PDMat(_sym(A) + 5I) +sr(n::Int) = StableRNG(n) @testset "distributions" begin logpdf_test_cases = Any[ @@ -269,11 +281,11 @@ _pdmat(A) = PDMat(_sym(A) + 5I) @testset "$(typeof(d))" for (perf_flag, d, x) in logpdf_test_cases @info "$(map(typeof, (d, x)))" - test_rule(sr(123456), logpdf, d, x; perf_flag, is_primitive=false) + test_rule(StableRNG(123456), logpdf, d, x; perf_flag, is_primitive=false) end @testset "$name" for (perf_flag, name, f, x) in work_around_test_cases @info "$name" - test_rule(sr(123456), f, x...; perf_flag=perf_flag, is_primitive=false) + test_rule(StableRNG(123456), f, x...; perf_flag=perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/gp/Project.toml b/test/integration_testing/gp/Project.toml index 0bc743d29..fb7806f29 100644 --- a/test/integration_testing/gp/Project.toml +++ b/test/integration_testing/gp/Project.toml @@ -1,5 +1,7 @@ [deps] AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/gp/gp.jl b/test/integration_testing/gp/gp.jl index 55f74459f..0933161a7 100644 --- a/test/integration_testing/gp/gp.jl +++ b/test/integration_testing/gp/gp.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using AbstractGPs, KernelFunctions, Mooncake, Test +using AbstractGPs, KernelFunctions, LinearAlgebra, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule @testset "gp" begin ks = Any[ @@ -36,16 +37,16 @@ using AbstractGPs, KernelFunctions, Mooncake, Test ], ) fx = GP(k)(x1, 1.1) - @testset "$(_typeof(args))" for args in Any[ + @testset "$(typeof(args))" for args in Any[ (kernelmatrix, k, x1, x2), (kernelmatrix_diag, k, x1, x2), (kernelmatrix, k, x1), (kernelmatrix_diag, k, x1), - (fx -> rand(Xoshiro(123456), fx), fx), + (fx -> rand(StableRNG(123456), fx), fx), (logpdf, fx, rand(fx)), ] @info typeof(args) - test_rule(sr(123456), args...; is_primitive=false, unsafe_perturb=true) + test_rule(StableRNG(123456), args...; is_primitive=false, unsafe_perturb=true) end end end diff --git a/test/integration_testing/logexpfunctions/Project.toml b/test/integration_testing/logexpfunctions/Project.toml index 028f788a5..9fa86b8ea 100644 --- a/test/integration_testing/logexpfunctions/Project.toml +++ b/test/integration_testing/logexpfunctions/Project.toml @@ -1,4 +1,6 @@ [deps] +AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/logexpfunctions/logexpfunctions.jl b/test/integration_testing/logexpfunctions/logexpfunctions.jl index 38a8d91da..df3cec0a8 100644 --- a/test/integration_testing/logexpfunctions/logexpfunctions.jl +++ b/test/integration_testing/logexpfunctions/logexpfunctions.jl @@ -2,7 +2,10 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using LogExpFunctions, Mooncake, Test +using AllocCheck, LogExpFunctions, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule + +sr(n::Int) = StableRNG(n) @testset "logexpfunctions" begin @testset for (perf_flag, f, x...) in [ @@ -24,13 +27,13 @@ using LogExpFunctions, Mooncake, Test (:allocs, logmxp1, 0.02), (:allocs, logaddexp, -0.5, 0.4), (:allocs, logsubexp, -0.5, -5.0), - (:allocs, logsumexp, randn(5)), - (:allocs, logsumexp, randn(5, 4)), - (:allocs, logsumexp, randn(5, 4, 3)), - (:none, x -> logsumexp(x; dims=1), randn(5, 4)), - (:none, x -> logsumexp(x; dims=2), randn(5, 4)), - (:none, logsumexp!, rand(5), randn(5, 4)), - (:none, softmax, randn(10)), + (:allocs, logsumexp, randn(sr(1), 5)), + (:allocs, logsumexp, randn(sr(2), 5, 4)), + (:allocs, logsumexp, randn(sr(3), 5, 4, 3)), + (:none, x -> logsumexp(x; dims=1), randn(sr(4), 5, 4)), + (:none, x -> logsumexp(x; dims=2), randn(sr(5), 5, 4)), + (:none, logsumexp!, rand(sr(6), 5), randn(sr(7), 5, 4)), + (:none, softmax, randn(sr(7), 10)), (:allocs, cloglog, 0.5), (:allocs, cexpexp, -0.3), (:allocs, loglogistic, 0.5), @@ -38,6 +41,6 @@ using LogExpFunctions, Mooncake, Test (:allocs, log1mlogistic, -0.9), (:allocs, logit1mexp, -0.6), ] - test_rule(Xoshiro(123456), f, x...; perf_flag, is_primitive=false) + test_rule(sr(123456), f, x...; perf_flag, is_primitive=false) end end diff --git a/test/integration_testing/lux/Project.toml b/test/integration_testing/lux/Project.toml index 63c312f38..eea37ba69 100644 --- a/test/integration_testing/lux/Project.toml +++ b/test/integration_testing/lux/Project.toml @@ -1,4 +1,5 @@ [deps] Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/lux/lux.jl b/test/integration_testing/lux/lux.jl index 9249533d5..27d6e0b7c 100644 --- a/test/integration_testing/lux/lux.jl +++ b/test/integration_testing/lux/lux.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Mooncake, Lux, Test +using Mooncake, Lux, StableRNGs, Test +using Mooncake.TestUtils: test_rule @testset "lux" begin @testset "$(typeof(f))" for (f, x_f32) in Any[ @@ -43,9 +44,10 @@ using Mooncake, Lux, Test (Chain(Conv((3, 3), 2 => 6), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), (Chain(Conv((3, 3), 2 => 6, tanh), InstanceNorm(6)), randn(Float32, 6, 6, 2, 2)), ] - @info "$(_typeof((f, x_f32...)))" - ps, st = f32(Lux.setup(sr(123456), f)) + @info "$(typeof((f, x_f32...)))" + rng = StableRNG(123456) + ps, st = f32(Lux.setup(rng, f)) x = f32(x_f32) - test_rule(sr(123456), f, x, ps, st; is_primitive=false, interface_only=true) + test_rule(rng, f, x, ps, st; is_primitive=false, interface_only=true) end end diff --git a/test/integration_testing/misc_abstract_array/Project.toml b/test/integration_testing/misc_abstract_array/Project.toml new file mode 100644 index 000000000..da249598f --- /dev/null +++ b/test/integration_testing/misc_abstract_array/Project.toml @@ -0,0 +1,6 @@ +[deps] +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/misc_abstract_array.jl b/test/integration_testing/misc_abstract_array/misc_abstract_array.jl similarity index 82% rename from test/integration_testing/misc_abstract_array.jl rename to test/integration_testing/misc_abstract_array/misc_abstract_array.jl index 82f9250fa..61fd1d6cb 100644 --- a/test/integration_testing/misc_abstract_array.jl +++ b/test/integration_testing/misc_abstract_array/misc_abstract_array.jl @@ -1,3 +1,10 @@ +using Pkg +Pkg.activate(@__DIR__) +Pkg.develop(; path = joinpath(@__DIR__, "..", "..", "..")) + +using LinearAlgebra, Mooncake, Random, StableRNGs, Test +using Mooncake.TestUtils: test_rule + @testset "misc_abstract_array" begin @testset for (interface_only, f, x...) in vcat( [ @@ -6,11 +13,11 @@ (false, setindex!, randn(5), 4.0, 3), (false, setindex!, randn(5, 4), 3.0, 1, 3), (false, x -> getglobal(Main, :sin)(x), 5.0), - (false, x -> pointerref(bitcast(Ptr{Float64}, pointer_from_objref(Ref(x))), 1, 1), 5.0), - (false, (v, x) -> (pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5)), - (false, x -> (pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5)), + (false, x -> Base.pointerref(Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x))), 1, 1), 5.0), + (false, (v, x) -> (Base.pointerset(pointer(x), v, 2, 1); x), 3.0, randn(5)), + (false, x -> (Base.pointerset(pointer(x), UInt8(3), 2, 1); x), rand(UInt8, 5)), (false, x -> Ref(x)[], 5.0), - (false, x -> unsafe_load(bitcast(Ptr{Float64}, pointer_from_objref(Ref(x)))), 5.0), + (false, x -> unsafe_load(Base.bitcast(Ptr{Float64}, pointer_from_objref(Ref(x)))), 5.0), (false, x -> unsafe_load(Base.unsafe_convert(Ptr{Float64}, x)), randn(5)), (false, view, randn(5, 4), 1, 1), (false, view, randn(5, 4), 2:3, 1), @@ -57,7 +64,7 @@ )) do (A, B, C) (false, mul!, A, B, C, randn(), randn()) end), - vec(map(product( + vec(map(Iterators.product( Any[ LowerTriangular(randn(3, 3)), UpperTriangular(randn(3, 3)), @@ -83,7 +90,7 @@ (false, mul!, A, B, C, randn(), randn()) end), ) - @info "$(_typeof((f, x...)))" - test_rule(Xoshiro(123456), f, x...; interface_only, is_primitive=false) + @info "$(typeof((f, x...)))" + test_rule(StableRNG(123456), f, x...; interface_only, is_primitive=false) end end diff --git a/test/integration_testing/temporalgps/Project.toml b/test/integration_testing/temporalgps/Project.toml index 31bdda5c0..091503857 100644 --- a/test/integration_testing/temporalgps/Project.toml +++ b/test/integration_testing/temporalgps/Project.toml @@ -2,5 +2,6 @@ AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/temporalgps/temporalgps.jl b/test/integration_testing/temporalgps/temporalgps.jl index 3e03fdec1..b9e244b58 100644 --- a/test/integration_testing/temporalgps/temporalgps.jl +++ b/test/integration_testing/temporalgps/temporalgps.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using AbstractGPs, KernelFunctions, Mooncake, TemporalGPs, Test +using AbstractGPs, KernelFunctions, Mooncake, StableRNGs, TemporalGPs, Test +using Mooncake.TestUtils: test_rule build_gp(k) = to_sde(GP(k), SArrayStorage(Float64)) @@ -21,8 +22,8 @@ temporalgps_logpdf_tester(k, x, y, s) = logpdf(build_gp(k)(x, s), y) s = 1.0 y = rand(build_gp(k)(x, s)) f = temporalgps_logpdf_tester - sig = _typeof((temporalgps_logpdf_tester, k, x, y, s)) + sig = typeof((temporalgps_logpdf_tester, k, x, y, s)) @info "$sig" - test_rule(Xoshiro(123456), f, k, x, y, s; is_primitive=false) + test_rule(StableRNG(123456), f, k, x, y, s; is_primitive=false) end end diff --git a/test/integration_testing/turing/Project.toml b/test/integration_testing/turing/Project.toml index 33ae14f28..9e8a0256e 100644 --- a/test/integration_testing/turing/Project.toml +++ b/test/integration_testing/turing/Project.toml @@ -2,4 +2,5 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/test/integration_testing/turing/turing.jl b/test/integration_testing/turing/turing.jl index b380e4946..227150954 100644 --- a/test/integration_testing/turing/turing.jl +++ b/test/integration_testing/turing/turing.jl @@ -2,7 +2,8 @@ using Pkg Pkg.activate(@__DIR__) Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) -using Distributions, DynamicPPL, Mooncake, Test +using Distributions, DynamicPPL, Mooncake, StableRNGs, Test +using Mooncake.TestUtils: test_rule @model function simple_model() y ~ Normal() @@ -117,7 +118,8 @@ end ], ) @info name - f, x = build_turing_problem(sr(123), model, ex) - test_rule(sr(123456), f, x; interface_only, is_primitive=false, unsafe_perturb=true) + f, x = build_turing_problem(StableRNG(123), model, ex) + rng = StableRNG(123456) + test_rule(rng, f, x; interface_only, is_primitive=false, unsafe_perturb=true) end end diff --git a/test/run_extra.jl b/test/run_extra.jl new file mode 100644 index 000000000..b299096ca --- /dev/null +++ b/test/run_extra.jl @@ -0,0 +1,3 @@ +# Uses in the `extra` CI job. +using Test +include(joinpath(@__DIR__, ENV["TEST_TYPE"], ENV["LABEL"], ENV["LABEL"] * ".jl")) diff --git a/test/runtests.jl b/test/runtests.jl index eea7bdc62..79a402b17 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -23,76 +23,39 @@ include("front_matter.jl") include("interface.jl") include("config.jl") include("developer_tools.jl") - elseif test_group == "rrules" include("test_utils.jl") - @testset "rrules" begin - @info "avoiding_non_differentiable_code" - include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) - @info "blas" - include(joinpath("rrules", "blas.jl")) - @info "builtins" - include(joinpath("rrules", "builtins.jl")) - @info "fastmath" - include(joinpath("rrules", "fastmath.jl")) - @info "foreigncall" - include(joinpath("rrules", "foreigncall.jl")) - @info "function_wrappers" - include(joinpath("rrules", "function_wrappers.jl")) - @info "iddict" - include(joinpath("rrules", "iddict.jl")) - @info "lapack" - include(joinpath("rrules", "lapack.jl")) - @info "linear_algebra" - include(joinpath("rrules", "linear_algebra.jl")) - @info "low_level_maths" - include(joinpath("rrules", "low_level_maths.jl")) - @info "misc" - include(joinpath("rrules", "misc.jl")) - @info "new" - include(joinpath("rrules", "new.jl")) - @info "tasks" - include(joinpath("rrules", "tasks.jl")) - @info "twice_precision" - include(joinpath("rrules", "twice_precision.jl")) - @static if VERSION >= v"1.11.0-rc4" - @info "memory" - include(joinpath("rrules", "memory.jl")) - end + elseif test_group == "rrules/avoiding_non_differentiable_code" + include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) + elseif test_group == "rrules/blas" + include(joinpath("rrules", "blas.jl")) + elseif test_group == "rrules/builtins" + include(joinpath("rrules", "builtins.jl")) + elseif test_group == "rrules/fastmath" + include(joinpath("rrules", "fastmath.jl")) + elseif test_group == "rrules/foreigncall" + include(joinpath("rrules", "foreigncall.jl")) + elseif test_group == "rrules/functionwrappers" + include(joinpath("rrules", "function_wrappers.jl")) + elseif test_group == "rrules/iddict" + include(joinpath("rrules", "iddict.jl")) + elseif test_group == "rrules/lapack" + include(joinpath("rrules", "lapack.jl")) + elseif test_group == "rrules/linear_algebra" + include(joinpath("rrules", "linear_algebra.jl")) + elseif test_group == "rrules/low_level_maths" + include(joinpath("rrules", "low_level_maths.jl")) + elseif test_group == "rrules/misc" + include(joinpath("rrules", "misc.jl")) + elseif test_group == "rrules/new" + include(joinpath("rrules", "new.jl")) + elseif test_group == "rrules/tasks" + include(joinpath("rrules", "tasks.jl")) + elseif test_group == "rrules/twice_precision" + include(joinpath("rrules", "twice_precision.jl")) + elseif test_group == "rrules/memory" + @static if VERSION >= v"1.11.0-rc4" + include(joinpath("rrules", "memory.jl")) end - elseif test_group == "gpu" - include(joinpath("ext", "cuda", "cuda.jl")) - elseif test_group == "ext/differentiation_interface" - include(joinpath("ext", "differentiation_interface", "di.jl")) - elseif test_group == "ext/dynamic_ppl" - include(joinpath("ext", "dynamic_ppl", "dynamic_ppl.jl")) - elseif test_group == "ext/luxlib" - include(joinpath("ext", "luxlib", "luxlib.jl")) - elseif test_group == "ext/nnlib" - include(joinpath("ext", "nnlib", "nnlib.jl")) - elseif test_group == "ext/special_functions" - include(joinpath("ext", "special_functions", "special_functions.jl")) - elseif test_group == "integration_testing/array" - include(joinpath("integration_testing", "array.jl")) - elseif test_group == "integration_testing/bijectors" - include(joinpath("integration_testing", "bijectors", "bijectors.jl")) - elseif test_group == "integration_testing/diff_tests" - include(joinpath("integration_testing", "diff_tests.jl")) - elseif test_group == "integration_testing/distributions" - include(joinpath("integration_testing", "distributions", "distributions.jl")) - elseif test_group == "integration_testing/gp" - include(joinpath("integration_testing", "gp", "gp.jl")) - elseif test_group == "integration_testing/logexpfunctions" - include(joinpath("integration_testing", "logexpfunctions", "logexpfunctions.jl")) - elseif test_group == "integration_testing/lux" - include(joinpath("integration_testing", "lux", "lux.jl")) - elseif test_group == "integration_testing/misc" - include(joinpath("integration_testing", "battery_tests.jl")) - elseif test_group == "integration_testing/misc_abstract_array" - include(joinpath("integration_testing", "misc_abstract_array.jl")) - elseif test_group == "integration_testing/temporalgps" - include(joinpath("integration_testing", "temporalgps", "temporalgps.jl")) - elseif test_group == "integration_testing/turing" - include(joinpath("integration_testing", "turing", "turing.jl")) else throw(error("test_group=$(test_group) is not recognised")) end