From 5397d284e89f08afbb22cce32483c1ada6baec10 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Sat, 23 Nov 2024 16:33:47 +0000 Subject: [PATCH] Run everything on 1.10 (#385) * Run everything on 1.10 * Expand buildkite runner into matrix * Also run integration and ext testing on LTS * Try LTS * Just write 1.10 * Julia version number in matrix * Document support policy * Document support policy * Restrict Julia compat properly * Fix DynamicPPL tests on 1.10 * Simplify gitignore * Formatting * Relax performance bounds on nnlib scatter * Improve error message * Fix a problem * Fix error introduced in the last PR * Fix coverage * Fix coverage * Fix LuxLib * Prevent small union inlining * use static if everywhere * Revert previous change * Patch the problem * Run error printing * Bump patch --- .buildkite/pipeline.yml | 7 ++- .github/workflows/CI.yml | 9 +--- .gitignore | 6 +-- Project.toml | 4 +- README.md | 3 ++ SUPPORT_POLICY.md | 42 ++++++++++++++++ ext/MooncakeLuxLibSLEEFPiratesExtension.jl | 27 ++++++++++ src/interpreter/abstract_interpretation.jl | 4 +- src/interpreter/contexts.jl | 2 +- src/interpreter/ir_normalisation.jl | 2 +- src/interpreter/ir_utils.jl | 4 +- src/interpreter/s2s_reverse_mode_ad.jl | 58 ++++++++++++++++++---- src/rrules/builtins.jl | 9 +++- src/rrules/fastmath.jl | 3 ++ src/rrules/misc.jl | 2 +- src/test_utils.jl | 2 +- test/ext/dynamic_ppl/dynamic_ppl.jl | 3 +- test/ext/nnlib/nnlib.jl | 4 +- test/interpreter/s2s_reverse_mode_ad.jl | 21 ++++---- test/rrules/builtins.jl | 6 +++ 20 files changed, 170 insertions(+), 48 deletions(-) create mode 100644 SUPPORT_POLICY.md diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 5eed9ba0c..4aa3d9cc6 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -2,10 +2,10 @@ env: SECRET_CODECOV_TOKEN: "nkcRFVXdaPNAbiI0x3qK/XUG8rWjBc8fU73YEyP35SeS465XORqrIYrHUbHuJTRyeyqNRdsHaBcV1P7TBbKAaTQAjHQ1Q0KYfd0uRMSWpZSCgTBz5AwttAxVfFrX+Ky3PzTi2TfDe0uPFZtFo0Asq6sUEr1on+Oo+j+q6br2NK6CrA5yKKuTX4Q2V/UPOIK4vNXY3+zDTKSNtr+HQOlcVEeRIk/0ZQ78Cjd52flEaVw8GWo/CC4YBzLtcOZgaFdgOTEDNHMr0mw6zLE4Y6nxq4lHVSoraSjxjhkB0pXTZ1c51yHX8Jc+q6HC5s87+2Zq5YtsuQSGao+eMtkTAYwfLw==;U2FsdGVkX18z27J3+gNgxsPNnXA0ad4LvZnXeohTam7/6UPqX5+3BYI0tAiVkCho4vlJyL7dd8JEyNtk9BFXsg==" steps: - - label: "Julia v1" + - label: "Julia v{{matrix}}" plugins: - JuliaCI/julia#v1: - version: "1" + version: "{{matrix}}" - JuliaCI/julia-coverage#v1: dirs: - src @@ -19,3 +19,6 @@ steps: env: LABEL: cuda TEST_TYPE: ext + matrix: + - "1" + - "1.10" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index ea578eae3..06d4d0dd0 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -40,13 +40,11 @@ jobs: 'rrules/twice_precision', ] version: + - 'lts' - '1' arch: - x64 include: - - test_group: 'basic' - version: '1.10' - arch: x64 - test_group: 'basic' version: '1.10' arch: x86 @@ -95,12 +93,9 @@ jobs: ] version: - '1' + - 'lts' arch: - x64 - include: - - test_group: {test_type: 'integration_testing', label: 'turing'} - version: '1.10' - arch: x64 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 diff --git a/.gitignore b/.gitignore index 2d83a4fe5..bcbf1f024 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,8 @@ -/Manifest.toml -/Manifest-v1.11.toml +Manifest* dev -bench/Manifest.toml analysis_results .vscode profile.pb.gz scratch.jl docs/build/ docs/site/ -docs/Manifest.toml -Manifest.toml diff --git a/Project.toml b/Project.toml index c51fa023f..cf98145b4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.48" +version = "0.4.49" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -66,7 +66,7 @@ Setfield = "1" SpecialFunctions = "2" StableRNGs = "1" Test = "1" -julia = "1" +julia = "1.10" [extras] AllocCheck = "9b6a8646-10ed-4001-bbdc-1d2f46dfbb1a" diff --git a/README.md b/README.md index 2c844d934..6872107aa 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,9 @@ If you encounter a new version of `Mooncake.jl` in the wild, please consult this # Getting Started +Check that you're running a version of Julia that Mooncake.jl supports. +See the `SUPPORT_POLICY.md` file for more info. + There are several ways to interact with `Mooncake.jl`. The one that we recommend people begin with is [`DifferentiationInterface.jl`](https://github.com/gdalle/DifferentiationInterface.jl/). For example, use it as follows to compute the gradient of a function mapping a `Vector{Float64}` to `Float64`. diff --git a/SUPPORT_POLICY.md b/SUPPORT_POLICY.md new file mode 100644 index 000000000..0b4bcef34 --- /dev/null +++ b/SUPPORT_POLICY.md @@ -0,0 +1,42 @@ +# Summary + +At any given point in time, `Mooncake.jl` supports the current Long Term Support (LTS) release of Julia, and the latest release version of Julia 1. +Consequently, the versions of Julia which are officially supported by `Mooncake.jl` will change (almost) _immediately_ whenever a new Julia LTS version is declared, or a minor release of Julia is made. + +For example, the LTS is 1.10 and the latest release is 1.11 at the time of writing. When 1.12 is released, we will +1. bump the Julia compat bounds in `Mooncake.jl` to require either 1.10 or 1.12, +1. cease to run CI on 1.11, +1. cease to provide bug fixes for 1.11, +1. cease to accept 1.11-specific bug fixes, as we will not be running CI for 1.11 and therefore will not be able to test that they have worked. + +In short: as far as `Mooncake.jl`'s future releases, 1.11 ceases to exist the moment 1.12 is released. + +Note that these changes are not applied retrospectively to existing releases of `Mooncake.jl`. +Suppose that `Mooncake.jl` is at `v0.4.50` when 1.12 is released. +Then the above changes would be relevant to `Mooncake.jl` versions `v0.4.51` and higher. + +# Patch Versions + +The above only discussed minor versions of Julia (1.10, 1.11, 1.12, etc). +However, it also applies to patch versions of Julia. +For example, at the time of writing, Julia version 1.10.6 is _actually_ the LTS, and 1.11.1 the current release of Julia. +The moment that 1.10.7 is released, we will cease to run any CI on 1.10.6, and will not accept fixes for it. +The same is true of 1.11.2. + +Since patch releases of Julia are less invasive than minor releases, this should generally not cause users problems. + +# Context + +In order to support a particular version of Julia, we must +1. always run CI for that version, +1. accept and proactively produce fixes for that version, +1. maintain version-specific code in the `Mooncake.jl` codebase. + +This requires a surprisingly amount of overhead to the development of `Mooncake.jl`, and has the potential to substantially increase the complexity of the codebase. +All of this makes it harder to improve `Mooncake.jl`. +Consequently, this policy represents a decision to tradeoff support for a range of minor Julia versions in exchange for easing the development burden associated to `Mooncake.jl`. + +## Why not gently drop support? + +In the JuliaGaussianProcesses ecosystem, we had a loosely-defined policy of keeping support for an older version until we ran into a large problem which could not be fixed easily, at which point we would drop support. +While this sounds appealing, in practice it makes it hard to know exactly when to drop support for a particular version of Julia, increases the burden for maintainers, and makes it hard for users to know exactly what to expect. diff --git a/ext/MooncakeLuxLibSLEEFPiratesExtension.jl b/ext/MooncakeLuxLibSLEEFPiratesExtension.jl index 906dc30f2..97f55835f 100644 --- a/ext/MooncakeLuxLibSLEEFPiratesExtension.jl +++ b/ext/MooncakeLuxLibSLEEFPiratesExtension.jl @@ -4,6 +4,8 @@ using LuxLib, Mooncake, SLEEFPirates using Base: IEEEFloat using Mooncake: @from_rrule, DefaultCtx +@static if VERSION >= v"1.11" + # Workaround for package load order problems. See # https://github.com/JuliaLang/julia/issues/56204#issuecomment-2419553167 for more context. function __init__() @@ -31,4 +33,29 @@ function __init__() end end +else + +for f in Any[ + LuxLib.NNlib.sigmoid_fast, + LuxLib.NNlib.softplus, + LuxLib.NNlib.logsigmoid, + LuxLib.NNlib.swish, + LuxLib.NNlib.lisht, + Base.tanh, + LuxLib.NNlib.tanh_fast, +] + f_fast = LuxLib.Impl.sleefpirates_fast_act(f) + @eval @from_rrule DefaultCtx Tuple{typeof($f_fast), IEEEFloat} + @eval @from_rrule( + DefaultCtx, + Tuple{ + typeof(Broadcast.broadcasted), + typeof($f_fast), + Union{IEEEFloat, Array{<:IEEEFloat}}, + }, + ) +end + +end + end diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index d8b15322e..6442be1d7 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -107,7 +107,7 @@ function CC.method_table(interp::MooncakeInterpreter) return CC.OverlayMethodTable(interp.world, mooncake_method_table) end -if VERSION < v"1.11.0" +@static if VERSION < v"1.11.0" CC.get_world_counter(interp::MooncakeInterpreter) = interp.world get_inference_world(interp::CC.AbstractInterpreter) = CC.get_world_counter(interp) else @@ -160,7 +160,7 @@ function Core.Compiler.abstract_call_gf_by_type( end end -if VERSION < v"1.11-" +@static if VERSION < v"1.11-" function CC.inlining_policy( interp::MooncakeInterpreter{C}, diff --git a/src/interpreter/contexts.jl b/src/interpreter/contexts.jl index b83c92cca..5b82e47c7 100644 --- a/src/interpreter/contexts.jl +++ b/src/interpreter/contexts.jl @@ -32,7 +32,7 @@ Observe that this information means that whether or not something is a primitive particular context depends only on static information, not any run-time information that might live in a particular instance of `Ctx`. """ -is_primitive(::Type{MinimalCtx}, ::Any) = false +is_primitive(::Type{MinimalCtx}, sig::Type{<:Tuple}) = false is_primitive(::Type{DefaultCtx}, sig) = is_primitive(MinimalCtx, sig) """ diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 6c835f023..21873c966 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -214,7 +214,7 @@ __get_arg(x::QuoteNode) = x.value __get_arg(x) = x # memoryrefget and memoryrefset! were introduced in 1.11. -if VERSION >= v"1.11-" +@static if VERSION >= v"1.11-" """ lift_memoryrefget_and_memoryrefset_builtins(inst) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 9e736daf4..56883b76d 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -190,7 +190,7 @@ function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) ir = __strip_coverage!(ir) ir = CC.sroa_pass!(ir, inline_state) - if VERSION < v"1.11-" + @static if VERSION < v"1.11-" ir = CC.adce_pass!(ir, inline_state) else ir, _ = CC.adce_pass!(ir, inline_state) @@ -227,7 +227,7 @@ function lookup_ir(interp::CC.AbstractInterpreter, tt::Type{<:Tuple}; optimize_u asts = [] for match in get_matches(matches.matches) match = match::Core.MethodMatch - if VERSION < v"1.11-" + @static if VERSION < v"1.11-" meth = Base.func_for_method_checked(match.method, tt, match.sparams) (code, ty) = CC.typeinf_ircode( interp, meth, match.spec_types, match.sparams, optimize_until diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 0de5480ba..1896981ce 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -213,6 +213,10 @@ get_primal_type(::ADInfo, x) = _typeof(x) function get_primal_type(::ADInfo, x::GlobalRef) return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty end +function get_primal_type(::ADInfo, x::Expr) + x.head === :boundscheck && return Bool + error("Unrecognised expression $x found in argument slot.") +end """ get_rev_data_id(info::ADInfo, x) @@ -394,7 +398,11 @@ function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo) rvs = new_inst(Expr(:call, increment_ref!, rdata_id, Argument(2))) return ad_stmt_info(line, nothing, inc_args(stmt), rvs) else - fwds = ReturnNode(const_codual(stmt.val, info)) + const_id = ID() + fwds = [ + (const_id, new_inst(const_codual_stmt(stmt.val, info))), + (ID(), new_inst(ReturnNode(const_id))), + ] return ad_stmt_info(line, nothing, fwds, nothing) end end @@ -457,7 +465,11 @@ function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo) else # If the value of the PiNode is a constant / QuoteNode etc, then there is nothing to # do on the reverse-pass. - fwds = PiNode(const_codual(stmt.val, info), fcodual_type(_type(stmt.typ))) + const_id = ID() + fwds = [ + (const_id, new_inst(const_codual_stmt(stmt.val, info))), + (line, new_inst(PiNode(const_id, fcodual_type(_type(stmt.typ))))), + ] rvs = nothing end @@ -475,11 +487,11 @@ end function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo) isconst(stmt) && return const_ad_stmt(stmt, line, info) - x = const_codual(getglobal(stmt.mod, stmt.name), info) - globalref_id = ID() + const_id, globalref_id = ID(), ID() fwds = [ (globalref_id, new_inst(stmt)), - (line, new_inst(Expr(:call, __verify_const, globalref_id, x))), + (const_id, new_inst(const_codual_stmt(getglobal(stmt.mod, stmt.name), info))), + (line, new_inst(Expr(:call, __verify_const, globalref_id, const_id))), ] return ad_stmt_info(line, nothing, fwds, nothing) end @@ -502,8 +514,22 @@ make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info) Implementation of `make_ad_stmts!` used for constants. """ function const_ad_stmt(stmt, line::ID, info::ADInfo) - x = const_codual(stmt, info) - return ad_stmt_info(line, nothing, x isa ID ? Expr(:call, identity, x) : x, nothing) + return ad_stmt_info(line, nothing, const_codual_stmt(stmt, info), nothing) +end + +""" + const_codual_stmt(stmt, info::ADInfo) + +Returns a `:call` expression which will return a `CoDual` whose primal is `stmt`, and whose +tangent is whatever `uninit_tangent` returns. +""" +function const_codual_stmt(stmt, info::ADInfo) + v = get_const_primal_value(stmt) + if safe_for_literal(v) + return Expr(:call, uninit_fcodual, v) + else + return Expr(:call, identity, add_data!(info, uninit_fcodual(v))) + end end """ @@ -519,10 +545,21 @@ function const_codual(stmt, info::ADInfo) return safe_for_literal(v) ? x : add_data!(info, x) end -safe_for_literal(v) = v isa String || v isa Type || isbitstype(_typeof(v)) +function safe_for_literal(v) + v isa Expr && v.head === :boundscheck && return true + v isa String && return true + v isa Type && return true + v isa Tuple && all(safe_for_literal, v) && return true + isbitstype(_typeof(v)) && return true + return false +end inc_or_const(stmt, info::ADInfo) = is_active(stmt) ? __inc(stmt) : const_codual(stmt, info) +function inc_or_const_stmt(stmt, info::ADInfo) + return is_active(stmt) ? Expr(:call, identity, __inc(stmt)) : const_codual_stmt(stmt, info) +end + """ get_const_primal_value(x::GlobalRef) @@ -616,7 +653,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) # Make arguments to rrule call. Things which are not already CoDual must be made so. codual_arg_ids = map(_ -> ID(), collect(args)) codual_args = map(args, codual_arg_ids) do arg, id - return (id, new_inst(Expr(:call, identity, inc_or_const(arg, info)))) + return (id, new_inst(inc_or_const_stmt(arg, info))) end # Make call to rule. @@ -691,8 +728,7 @@ function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) elseif Meta.isexpr(stmt, :copyast) # Get constant out and shove it in shared storage. - x = const_codual(stmt.args[1], info) - return ad_stmt_info(line, nothing, Expr(:call, identity, x), nothing) + return ad_stmt_info(line, nothing, const_codual_stmt(stmt.args[1], info), nothing) elseif Meta.isexpr(stmt, :loopinfo) # Cannot pass loopinfo back through the optimiser for some reason. diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 6303013c5..c862aaefe 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -29,10 +29,17 @@ function rrule!!(f::CoDual{<:Core.Builtin}, args...) "which is specialised to this case. " * "Either way, please consider commenting on " * "https://github.com/compintell/Mooncake.jl/issues/208/ so that the issue can be " * - "fixed more widely." + "fixed more widely.\n" * + "For reproducibility, note that the full signature is:\n" * + "$(typeof((f, args...)))" )) end +function Base.showerror(io::IO, err::MissingRuleForBuiltinException) + print(io, "MissingRuleForBuiltinException: ") + println(io, err.msg) +end + module IntrinsicsWrappers using Base: IEEEFloat diff --git a/src/rrules/fastmath.jl b/src/rrules/fastmath.jl index 4814fcbf6..80d014e0e 100644 --- a/src/rrules/fastmath.jl +++ b/src/rrules/fastmath.jl @@ -26,6 +26,9 @@ function rrule!!(::CoDual{typeof(Base.FastMath.sincos)}, x::CoDual{P}) where {P< return CoDual(y, NoFData()), sincos_fast_adj!! end +@is_primitive MinimalCtx Tuple{typeof(Base.log), Union{IEEEFloat, Int}} +@zero_adjoint MinimalCtx Tuple{typeof(log), Int} + function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:fastmath}) test_cases = Any[ (false, :stability_and_allocs, nothing, Base.FastMath.exp10_fast, 0.5), diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index 823711073..7fe309c31 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -34,7 +34,7 @@ # Required to avoid an ambiguity. @zero_adjoint MinimalCtx Tuple{Type{Symbol}, TypeVar, Type} -if VERSION >= v"1.11-" +@static if VERSION >= v"1.11-" @zero_adjoint MinimalCtx Tuple{typeof(Random.hash_seed), Vararg} @zero_adjoint MinimalCtx Tuple{typeof(Base.dataids), Memory} end diff --git a/src/test_utils.jl b/src/test_utils.jl index b222afaa5..0568da6e1 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -610,7 +610,7 @@ end _new_excluded(::Type) = false _new_excluded(::Type{<:Union{String}}) = true -if VERSION < v"1.11-" +@static if VERSION < v"1.11-" # Prior to 1.11, Arrays are special objects, with special constructors that don't # involve calling the `:new` instruction. From 1.11 onwards, they behave more like # regular mutable composite types, so calling `_new_` becomes meaningful. diff --git a/test/ext/dynamic_ppl/dynamic_ppl.jl b/test/ext/dynamic_ppl/dynamic_ppl.jl index 8d58636c5..7da8546df 100644 --- a/test/ext/dynamic_ppl/dynamic_ppl.jl +++ b/test/ext/dynamic_ppl/dynamic_ppl.jl @@ -7,5 +7,6 @@ using DynamicPPL: istrans, VarInfo using Mooncake.TestUtils: test_rule @testset "DynamicPPLMooncakeExt" begin - test_rule(StableRNG(123456), istrans, VarInfo(); unsafe_perturb=true) + rng = StableRNG(123456) + test_rule(rng, istrans, VarInfo(); unsafe_perturb=true, interface_only=true) end diff --git a/test/ext/nnlib/nnlib.jl b/test/ext/nnlib/nnlib.jl index 39fae7fb8..8ee065e5a 100644 --- a/test/ext/nnlib/nnlib.jl +++ b/test/ext/nnlib/nnlib.jl @@ -86,8 +86,8 @@ using NNlib: dropout (false, :none, true, NNlib.unfold, x, dense_cdims), # scatter - (false, :stability, true, NNlib.scatter, +, randn(2), [1, 3]), - (false, :stability, true, Core.kwcall, (;), NNlib.scatter, +, randn(2), [1, 3]), + (false, :none, true, NNlib.scatter, +, randn(2), [1, 3]), + (false, :none, true, Core.kwcall, (;), NNlib.scatter, +, randn(2), [1, 3]), # conv (false, :none, true, Core.kwcall, (;), conv, x, w, dense_cdims), diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl index bdb8bc7b3..40b8be7d5 100644 --- a/test/interpreter/s2s_reverse_mode_ad.jl +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -54,6 +54,8 @@ end @test Mooncake.get_primal_type(info, GlobalRef(Main, :___y)) == Float64 @test Mooncake.get_primal_type(info, 5) == Int @test Mooncake.get_primal_type(info, QuoteNode(:hello)) == Symbol + @test Mooncake.get_primal_type(info, Expr(:boundscheck)) == Bool + @test_throws ErrorException Mooncake.get_primal_type(info, Expr(:call)) end @testset "ADStmtInfo" begin # If the ID passes as the comms channel doesn't appear in the stmts for the forwards @@ -103,13 +105,13 @@ end @testset "literal" begin stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info) @test stmt_info isa ADStmtInfo - @test stmt_info.fwds[1][2].stmt isa ReturnNode + @test stmt_info.fwds[2][2].stmt isa ReturnNode end @testset "GlobalRef" begin node = ReturnNode(GlobalRef(S2SGlobals, :const_float)) stmt_info = make_ad_stmts!(node, line, info) @test stmt_info isa ADStmtInfo - @test stmt_info.fwds[1][2].stmt isa ReturnNode + @test stmt_info.fwds[2][2].stmt isa ReturnNode end end @testset "IDGotoNode" begin @@ -143,13 +145,14 @@ end line = id_line_1 stmt_info = make_ad_stmts!(PiNode(nothing, Union{}), line, info) @test stmt_info isa ADStmtInfo + @test last(stmt_info.fwds)[1] == line end @testset "π (nothing, Nothing)" begin stmt_info = make_ad_stmts!(PiNode(nothing, Nothing), id_line_1, info) @test stmt_info isa ADStmtInfo - fwds_stmt = only(stmt_info.fwds)[2].stmt + @test last(stmt_info.fwds)[1] == id_line_1 + fwds_stmt = last(stmt_info.fwds)[2].stmt @test fwds_stmt isa PiNode - @test fwds_stmt.val == CoDual(nothing, NoFData()) @test fwds_stmt.typ == CoDual{Nothing, NoFData} @test only(stmt_info.rvs)[2].stmt === nothing end @@ -157,9 +160,9 @@ end node = PiNode(nothing, CC.Const(nothing)) stmt_info = make_ad_stmts!(node, id_line_1, info) @test stmt_info isa ADStmtInfo - fwds_stmt = only(stmt_info.fwds)[2].stmt + @test last(stmt_info.fwds)[1] == id_line_1 + fwds_stmt = last(stmt_info.fwds)[2].stmt @test fwds_stmt isa PiNode - @test fwds_stmt.val == CoDual(nothing, NoFData()) @test fwds_stmt.typ == CoDual{Nothing, NoFData} @test only(stmt_info.rvs)[2].stmt === nothing end @@ -167,9 +170,8 @@ end node = PiNode(GlobalRef(S2SGlobals, :const_float), Any) stmt_info = make_ad_stmts!(node, id_line_1, info) @test stmt_info isa ADStmtInfo - fwds_stmt = only(stmt_info.fwds)[2].stmt + fwds_stmt = last(stmt_info.fwds)[2].stmt @test fwds_stmt isa PiNode - @test fwds_stmt.val == CoDual(5.0, NoFData()) @test fwds_stmt.typ == CoDual @test only(stmt_info.rvs)[2].stmt === nothing end @@ -192,7 +194,8 @@ end @testset "differentiable const globals" begin stmt_info = make_ad_stmts!(GlobalRef(S2SGlobals, :const_float), ID(), info) @test stmt_info isa Mooncake.ADStmtInfo - @test only(stmt_info.fwds)[2].stmt isa CoDual{Float64} + @test only(stmt_info.fwds)[2].stmt isa Expr + @test only(stmt_info.fwds)[2].stmt.args[1] === Mooncake.uninit_fcodual end end @testset "PhiCNode" begin diff --git a/test/rrules/builtins.jl b/test/rrules/builtins.jl index df5acb933..4a8beed43 100644 --- a/test/rrules/builtins.jl +++ b/test/rrules/builtins.jl @@ -30,6 +30,12 @@ invoke(Mooncake.rrule!!, Tuple{CoDual{<:Core.Builtin}}, zero_fcodual(getfield)), ) + # Check that Base.showerror runs. + @test ==( + showerror(IOBuffer(; write=true), Mooncake.MissingRuleForBuiltinException("hmm")), + nothing, + ) + # Unhandled intrinsic throws an intelligible error. @test_throws( Mooncake.IntrinsicsWrappers.MissingIntrinsicWrapperException,