From be316ff54ac02bbe4de40d1a89775eed65741860 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 11:36:11 +0100 Subject: [PATCH 01/32] Start forward mode prototype --- src/Mooncake.jl | 3 +++ src/dual.jl | 16 ++++++++++++++++ src/frules/basic.jl | 0 3 files changed, 19 insertions(+) create mode 100644 src/dual.jl create mode 100644 src/frules/basic.jl diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 5ee8f9e50..501ed446a 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -61,6 +61,7 @@ function rrule!! end include("utils.jl") include("tangents.jl") +include("dual.jl") include("fwds_rvs_data.jl") include("codual.jl") include("debug_mode.jl") @@ -78,6 +79,8 @@ include("tools_for_rules.jl") include("test_utils.jl") include("test_resources.jl") +include(joinpath("frules", "basic.jl")) + include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) include(joinpath("rrules", "blas.jl")) include(joinpath("rrules", "builtins.jl")) diff --git a/src/dual.jl b/src/dual.jl new file mode 100644 index 000000000..43617505d --- /dev/null +++ b/src/dual.jl @@ -0,0 +1,16 @@ +struct Dual{P, T} + x::P + dx::T +end + +function Dual(x::P, dx::T) where {P,T} + if T != tangent_type(P) + throw(ArgumentError("Tried to build a `Dual(x, dx)` with `x::$P` and `dx::$T` but the correct tangent type is `$(tangent_type(P))`") + end + return Dual{P,T}(x, dx) +end + +primal(x::Dual) = x.x +tangent(x::Dual) = x.dx +Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) +_copy(x::P) where {P<:Dual} = x diff --git a/src/frules/basic.jl b/src/frules/basic.jl new file mode 100644 index 000000000..e69de29bb From deac913521616a6f7f4fff5d04003f898481d94f Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 15:59:11 +0100 Subject: [PATCH 02/32] First working autodiff --- src/Mooncake.jl | 14 +++ src/debug_mode.jl | 1 + src/dual.jl | 21 ++-- src/frules/basic.jl | 11 ++ src/interpreter/s2s_forward_mode_ad.jl | 147 +++++++++++++++++++++++++ src/test_utils.jl | 2 +- test/forward.jl | 27 +++++ test/runtests.jl | 1 + 8 files changed, 211 insertions(+), 13 deletions(-) create mode 100644 src/interpreter/s2s_forward_mode_ad.jl create mode 100644 test/forward.jl diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 501ed446a..40170ae92 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -22,6 +22,7 @@ using Base: arrayset, TwicePrecision, twiceprecision using Base.Experimental: @opaque using Base.Iterators: product +using Base.Meta: isexpr using Core: Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode, PiNode, SSAValue, Argument, OpaqueClosure, compilerbarrier @@ -34,6 +35,14 @@ using FunctionWrappers: FunctionWrapper # Needs to be defined before various other things. function _foreigncall_ end +""" + frule!!(f::Dual, x::Dual...) + +Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`. +``` +""" +function frule!! end + """ rrule!!(f::CoDual, x::CoDual...) @@ -73,6 +82,7 @@ include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) +include(joinpath("interpreter", "s2s_forward_mode_ad.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) include("tools_for_rules.jl") @@ -121,9 +131,13 @@ export _add_to_primal, _diff, _dot, + Dual, + zero_dual, zero_codual, codual_type, + frule!!, rrule!!, + build_frule, build_rrule, value_and_gradient!!, value_and_pullback!!, diff --git a/src/debug_mode.jl b/src/debug_mode.jl index 25342afac..a9be1b09e 100644 --- a/src/debug_mode.jl +++ b/src/debug_mode.jl @@ -1,3 +1,4 @@ +DebugFRule(rule) = rule # TODO: make it non-trivial """ DebugPullback(pb, y, x) diff --git a/src/dual.jl b/src/dual.jl index 43617505d..a0b736e12 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -1,16 +1,13 @@ -struct Dual{P, T} - x::P - dx::T +struct Dual{P,T} + primal::P + tangent::T end -function Dual(x::P, dx::T) where {P,T} - if T != tangent_type(P) - throw(ArgumentError("Tried to build a `Dual(x, dx)` with `x::$P` and `dx::$T` but the correct tangent type is `$(tangent_type(P))`") - end - return Dual{P,T}(x, dx) -end - -primal(x::Dual) = x.x -tangent(x::Dual) = x.dx +primal(x::Dual) = x.primal +tangent(x::Dual) = x.tangent Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) _copy(x::P) where {P<:Dual} = x + +zero_dual(x) = Dual(x, zero_tangent(x)) + +dual_type(::Type{P}) where {P} = Dual{P,tangent_type(P)} diff --git a/src/frules/basic.jl b/src/frules/basic.jl index e69de29bb..f5901b83c 100644 --- a/src/frules/basic.jl +++ b/src/frules/basic.jl @@ -0,0 +1,11 @@ +frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...) + +@is_primitive MinimalCtx Tuple{typeof(sin),Number} +function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number}) + return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) +end + +@is_primitive MinimalCtx Tuple{typeof(cos),Number} +function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number}) + return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) +end diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl new file mode 100644 index 000000000..c7646c255 --- /dev/null +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -0,0 +1,147 @@ +function build_frule(args...; debug_mode=false) + interp = get_interpreter() + sig = _typeof(TestUtils.__get_primals(args)) + return build_frule(interp, sig; debug_mode) +end + +function build_frule( + interp::MooncakeInterpreter{C}, + sig_or_mi; + debug_mode=false, + silence_debug_messages=true, +) where {C} + # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater + # than the current world age. + if Base.get_world_counter() > interp.world + throw( + ArgumentError( + "World age associated to interp is behind current world age. Please " * + "a new interpreter for the current world age.", + ), + ) + end + + # If we're compiling in debug mode, let the user know by default. + if !silence_debug_messages && debug_mode + @info "Compiling rule for $sig_or_mi in debug mode. Disable for best performance." + end + + # If we have a hand-coded rule, just use that. + _is_primitive(C, sig_or_mi) && return (debug_mode ? DebugFRule(frule!!) : frule!!) + + + # We don't have a hand-coded rule, so derived one. + lock(MOONCAKE_INFERENCE_LOCK) + try + # If we've already derived the OpaqueClosures and info, do not re-derive, just + # create a copy and pass in new shared data. + oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode)) + if haskey(interp.oc_cache, oc_cache_key) + return _copy(interp.oc_cache[oc_cache_key]) + else + # Derive forward-pass IR, and shove in a `MistyClosure`. + forward_ir = generate_forward_ir(interp, sig_or_mi; debug_mode) + fwd_oc = MistyClosure(forward_ir; do_compile=true) + raw_rule = DerivedFRule(fwd_oc) + rule = debug_mode ? DebugFRule(raw_rule) : raw_rule + interp.oc_cache[oc_cache_key] = rule + return rule + end + catch e + rethrow(e) + finally + unlock(MOONCAKE_INFERENCE_LOCK) + end +end + +function generate_forward_ir( + interp::MooncakeInterpreter, + sig_or_mi; + debug_mode=false, + do_inline=true, +) + # Reset id count. This ensures that the IDs generated are the same each time this + # function runs. + seed_id!() + + # Grab code associated to the primal. + primal_ir, _ = lookup_ir(interp, sig_or_mi) + + # Normalise the IR. + isva, spnames = is_vararg_and_sparam_names(sig_or_mi) + ir = normalise!(primal_ir, spnames) + + fwd_ir = dualize_ir(ir) + opt_fwd_ir = optimise_ir!(fwd_ir; do_inline) + return opt_fwd_ir +end + +function dualize_ir(ir::IRCode) + new_stmts_stmt = map(make_fwd_ad_stmt, ir.stmts.stmt) + new_stmts_type = map(dual_type, ir.stmts.type) + new_stmts_info = ir.stmts.info + new_stmts_line = ir.stmts.line + new_stmts_flag = ir.stmts.flag + new_stmts = CC.InstructionStream( + new_stmts_stmt, + new_stmts_type, + new_stmts_info, + new_stmts_line, + new_stmts_flag, + ) + new_cfg = ir.cfg + new_linetable = ir.linetable + rule_type = Any + new_argtypes = convert(Vector{Any}, vcat(rule_type, map(make_fwd_argtype, ir.argtypes))) + new_meta = ir.meta + new_sptypes = ir.sptypes + return IRCode(new_stmts, new_cfg, new_linetable, new_argtypes, new_meta, new_sptypes) +end + +make_fwd_argtype(::Type{P}) where {P} = dual_type(P) +make_fwd_argtype(c::Core.Const) = Dual # TODO: refine to type of const + +function make_fwd_ad_stmt(stmt::Expr) + interp = get_interpreter() # TODO: pass it around + C = context_type(interp) + if isexpr(stmt, :invoke) || isexpr(stmt, :call) + mi = stmt.args[1]::Core.MethodInstance + sig = mi.specTypes + if is_primitive(C, sig) + shifted_args = map(stmt.args) do a + if a isa Core.Argument + Core.Argument(a.n + 1) + else + a + end + end + new_stmt = Expr( + :call, + :($frule!!), + stmt.args[2], + shifted_args[3:end]... + ) + return new_stmt + else + throw(ArgumentError("Recursing into non-primitive calls is not yet supported in forward mode")) + end + return stmt + else + throw(ArgumentError("Expressions of type `:$(stmt.head)` are not yet supported in forward mode")) + end + return stmt +end + +function make_fwd_ad_stmt(stmt::ReturnNode) + return stmt +end + +struct DerivedFRule{Tfwd_oc} + fwd_oc::Tfwd_oc +end + +_copy(rule::DerivedFRule) = deepcopy(rule) + +@inline function (fwd::DerivedFRule)(args::Vararg{Dual,N}) where {N} + return fwd.fwd_oc.oc(args...) +end diff --git a/src/test_utils.jl b/src/test_utils.jl index 0568da6e1..72ff1168e 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -470,7 +470,7 @@ function test_rrule_performance( end end -__get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) +__get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) @doc""" test_rule( diff --git a/test/forward.jl b/test/forward.jl new file mode 100644 index 000000000..ce6d0e5d6 --- /dev/null +++ b/test/forward.jl @@ -0,0 +1,27 @@ +using Mooncake +using Test + +x, dx = 2.0, 3.0 +xdual = Dual(x, dx) + +@testset "Manual frule" begin + sin_rule = build_frule(sin, x) + ydual = sin_rule(zero_dual(sin), xdual) + + @test primal(ydual) == sin(x) + @test tangent(ydual) == dx * cos(x) +end + +function func(x) + y = sin(x) + z = cos(y) + return z +end + +@testset "Automatic frule" begin + func_rule = build_frule(func, x) + ydual = func_rule(zero_dual(func), xdual) + + @test primal(ydual) == cos(sin(x)) + @test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x) +end diff --git a/test/runtests.jl b/test/runtests.jl index 56cf74d9b..d60604e5e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -24,6 +24,7 @@ include("front_matter.jl") include("config.jl") include("developer_tools.jl") include("test_utils.jl") + include("forward.jl") elseif test_group == "rrules/avoiding_non_differentiable_code" include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) elseif test_group == "rrules/blas" From 9c96c8d93e6e10f094bf940a6aac4334f7fb3883 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 16:07:52 +0100 Subject: [PATCH 03/32] Docstring --- src/Mooncake.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 40170ae92..35749e00c 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -39,7 +39,6 @@ function _foreigncall_ end frule!!(f::Dual, x::Dual...) Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`. -``` """ function frule!! end From 136aff643240d7d0a612fe7ea5bdbf060905a569 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 18:18:10 +0100 Subject: [PATCH 04/32] Apply suggestions from code review Co-authored-by: Will Tebbutt Signed-off-by: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> --- src/interpreter/s2s_forward_mode_ad.jl | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index c7646c255..bd201be5d 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -78,10 +78,10 @@ end function dualize_ir(ir::IRCode) new_stmts_stmt = map(make_fwd_ad_stmt, ir.stmts.stmt) - new_stmts_type = map(dual_type, ir.stmts.type) + new_stmts_type = fill(Any, length(ir.stmts.type)) new_stmts_info = ir.stmts.info new_stmts_line = ir.stmts.line - new_stmts_flag = ir.stmts.flag + new_stmts_flag = fill(CC.IR_FLAG_REFINED, length(ir.stmts.flag)) new_stmts = CC.InstructionStream( new_stmts_stmt, new_stmts_type, @@ -108,13 +108,7 @@ function make_fwd_ad_stmt(stmt::Expr) mi = stmt.args[1]::Core.MethodInstance sig = mi.specTypes if is_primitive(C, sig) - shifted_args = map(stmt.args) do a - if a isa Core.Argument - Core.Argument(a.n + 1) - else - a - end - end + shifted_args = inc_args(stmt.args) new_stmt = Expr( :call, :($frule!!), @@ -129,7 +123,6 @@ function make_fwd_ad_stmt(stmt::Expr) else throw(ArgumentError("Expressions of type `:$(stmt.head)` are not yet supported in forward mode")) end - return stmt end function make_fwd_ad_stmt(stmt::ReturnNode) From f65cc53f1471e9d6bdc65d543eddd2d5e89ac8d9 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 18:19:29 +0100 Subject: [PATCH 05/32] Moving files around --- src/frules/basic.jl | 4 ++-- test/{forward.jl => interpreter/s2s_forward_mode_ad.jl} | 0 test/runtests.jl | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) rename test/{forward.jl => interpreter/s2s_forward_mode_ad.jl} (100%) diff --git a/src/frules/basic.jl b/src/frules/basic.jl index f5901b83c..dc83b2a9b 100644 --- a/src/frules/basic.jl +++ b/src/frules/basic.jl @@ -1,11 +1,11 @@ frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...) -@is_primitive MinimalCtx Tuple{typeof(sin),Number} +@is_primitive MinimalCtx Tuple{typeof(sin),IEEEFloat} function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number}) return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) end -@is_primitive MinimalCtx Tuple{typeof(cos),Number} +@is_primitive MinimalCtx Tuple{typeof(cos),IEEEFloat} function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number}) return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) end diff --git a/test/forward.jl b/test/interpreter/s2s_forward_mode_ad.jl similarity index 100% rename from test/forward.jl rename to test/interpreter/s2s_forward_mode_ad.jl diff --git a/test/runtests.jl b/test/runtests.jl index d60604e5e..e2551b8fb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ include("front_matter.jl") include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) include(joinpath("interpreter", "zero_like_rdata.jl")) + include(joinpath("interpreter", "s2s_forward_mode_ad.jl")) include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) end include("tools_for_rules.jl") @@ -24,7 +25,6 @@ include("front_matter.jl") include("config.jl") include("developer_tools.jl") include("test_utils.jl") - include("forward.jl") elseif test_group == "rrules/avoiding_non_differentiable_code" include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) elseif test_group == "rrules/blas" From 053a8bb57e0d5447e3bb6d4c398cebb2cb347e73 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Sun, 24 Nov 2024 19:14:37 +0100 Subject: [PATCH 06/32] Primitives already known --- src/frules/basic.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/frules/basic.jl b/src/frules/basic.jl index dc83b2a9b..936538677 100644 --- a/src/frules/basic.jl +++ b/src/frules/basic.jl @@ -1,11 +1,11 @@ frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...) -@is_primitive MinimalCtx Tuple{typeof(sin),IEEEFloat} +# @is_primitive MinimalCtx Tuple{typeof(sin),IEEEFloat} function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number}) return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) end -@is_primitive MinimalCtx Tuple{typeof(cos),IEEEFloat} +# @is_primitive MinimalCtx Tuple{typeof(cos),IEEEFloat} function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number}) return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) end From a3107a85210149e5ed37f1f53ad456e06aa82b70 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 25 Nov 2024 11:54:55 +0100 Subject: [PATCH 07/32] Keep pushing forward (pun intended) --- src/Mooncake.jl | 4 +- src/dual.jl | 11 +- src/frules/basic.jl | 11 -- src/interpreter/diffractor_compiler_utils.jl | 127 ++++++++++++++++ src/interpreter/s2s_forward_mode_ad.jl | 150 +++++++++++-------- src/rrules/low_level_maths.jl | 6 + 6 files changed, 234 insertions(+), 75 deletions(-) delete mode 100644 src/frules/basic.jl create mode 100644 src/interpreter/diffractor_compiler_utils.jl diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 6169b8328..1e727261b 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -83,6 +83,8 @@ pb!!(1.0) """ function rrule!! end +include("interpreter/diffractor_compiler_utils.jl") + include("utils.jl") include("tangents.jl") include("dual.jl") @@ -104,8 +106,6 @@ include("tools_for_rules.jl") include("test_utils.jl") include("test_resources.jl") -include(joinpath("frules", "basic.jl")) - include(joinpath("rrules", "avoiding_non_differentiable_code.jl")) include(joinpath("rrules", "blas.jl")) include(joinpath("rrules", "builtins.jl")) diff --git a/src/dual.jl b/src/dual.jl index a0b736e12..9210cb9a1 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -10,4 +10,13 @@ _copy(x::P) where {P<:Dual} = x zero_dual(x) = Dual(x, zero_tangent(x)) -dual_type(::Type{P}) where {P} = Dual{P,tangent_type(P)} +function dual_type(::Type{P}) where {P} + P == DataType && return Dual + P isa Union && return Union{dual_type(P.a),dual_type(P.b)} + P <: UnionAll && return Dual # P is abstract, so we don't know its tangent type. + return isconcretetype(P) ? Dual{P,tangent_type(P)} : Dual +end + +function dual_type(p::Type{Type{P}}) where {P} + return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent} +end diff --git a/src/frules/basic.jl b/src/frules/basic.jl deleted file mode 100644 index 936538677..000000000 --- a/src/frules/basic.jl +++ /dev/null @@ -1,11 +0,0 @@ -frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...) - -# @is_primitive MinimalCtx Tuple{typeof(sin),IEEEFloat} -function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number}) - return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) -end - -# @is_primitive MinimalCtx Tuple{typeof(cos),IEEEFloat} -function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number}) - return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) -end diff --git a/src/interpreter/diffractor_compiler_utils.jl b/src/interpreter/diffractor_compiler_utils.jl new file mode 100644 index 000000000..a6b5590f9 --- /dev/null +++ b/src/interpreter/diffractor_compiler_utils.jl @@ -0,0 +1,127 @@ +# TODO: figure out if we need this + +#! format: off + +# Utilities that should probably go into CC +using Core.Compiler: IRCode, CFG, BasicBlock, BBIdxIter + +function Base.push!(cfg::CFG, bb::BasicBlock) + @assert cfg.blocks[end].stmts.stop+1 == bb.stmts.start + push!(cfg.blocks, bb) + push!(cfg.index, bb.stmts.start) +end + +if VERSION < v"1.11.0-DEV.258" + Base.getindex(ir::IRCode, ssa::SSAValue) = CC.getindex(ir, ssa) +end + +if VERSION < v"1.12.0-DEV.1268" + if isdefined(CC, :Future) + Base.isready(future::CC.Future) = CC.isready(future) + Base.getindex(future::CC.Future) = CC.getindex(future) + Base.setindex!(future::CC.Future, value) = CC.setindex!(future, value) + end + + Base.iterate(c::CC.IncrementalCompact, args...) = CC.iterate(c, args...) + Base.iterate(p::CC.Pair, args...) = CC.iterate(p, args...) + Base.iterate(urs::CC.UseRefIterator, args...) = CC.iterate(urs, args...) + Base.iterate(x::CC.BBIdxIter, args...) = CC.iterate(x, args...) + Base.getindex(urs::CC.UseRefIterator, args...) = CC.getindex(urs, args...) + Base.getindex(urs::CC.UseRef, args...) = CC.getindex(urs, args...) + Base.getindex(c::CC.IncrementalCompact, args...) = CC.getindex(c, args...) + Base.setindex!(c::CC.IncrementalCompact, args...) = CC.setindex!(c, args...) + Base.setindex!(urs::CC.UseRef, args...) = CC.setindex!(urs, args...) + + Base.copy(ir::IRCode) = CC.copy(ir) + + CC.BasicBlock(x::UnitRange) = + BasicBlock(StmtRange(first(x), last(x))) + CC.BasicBlock(x::UnitRange, preds::Vector{Int}, succs::Vector{Int}) = + BasicBlock(StmtRange(first(x), last(x)), preds, succs) + Base.length(c::CC.NewNodeStream) = CC.length(c) + Base.setindex!(i::CC.Instruction, args...) = CC.setindex!(i, args...) + Base.size(x::CC.UnitRange) = CC.size(x) + + CC.get(a::Dict, b, c) = Base.get(a,b,c) + CC.haskey(a::Dict, b) = Base.haskey(a, b) + CC.setindex!(a::Dict, b, c) = setindex!(a, b, c) +end + +CC.NewInstruction(@nospecialize node) = + NewInstruction(node, Any, CC.NoCallInfo(), nothing, CC.IR_FLAG_REFINED) + +Base.setproperty!(x::CC.Instruction, f::Symbol, v) = CC.setindex!(x, v, f) + +Base.getproperty(x::CC.Instruction, f::Symbol) = CC.getindex(x, f) + +function Base.setindex!(ir::IRCode, ni::NewInstruction, i::Int) + stmt = ir.stmts[i] + stmt.inst = ni.stmt + stmt.type = ni.type + stmt.flag = something(ni.flag, 0) # fixes 1.9? + @static if VERSION ≥ v"1.12.0-DEV.173" + stmt.line = something(ni.line, CC.NoLineUpdate) + else + stmt.line = something(ni.line, 0) + end + return ni +end + +function Base.push!(ir::IRCode, ni::NewInstruction) + # TODO: This should be a check in insert_node! + @assert length(ir.new_nodes.stmts) == 0 + @static if isdefined(CC, :add!) + # Julia 1.7 & 1.8 + ir[CC.add!(ir.stmts)] = ni + else + # Re-named in https://github.com/JuliaLang/julia/pull/47051 + ir[CC.add_new_idx!(ir.stmts)] = ni + end + ir +end + +function Base.iterate(it::Iterators.Reverse{BBIdxIter}, + (bb, idx)::Tuple{Int, Int}=(length(it.itr.ir.cfg.blocks), length(it.itr.ir.stmts)+1)) + idx == 1 && return nothing + active_bb = it.itr.ir.cfg.blocks[bb] + if idx == first(active_bb.stmts) + bb -= 1 + end + return (bb, idx - 1), (bb, idx - 1) +end + +Base.lastindex(x::CC.InstructionStream) = + CC.length(x) + +""" + find_end_of_phi_block(ir::IRCode, start_search_idx::Int) + +Finds the last index within the same basic block, on or after the `start_search_idx` which is not within a phi block. +A phi-block is a run on PhiNodes or nothings that must be the first statements within the basic block. + +If `start_search_idx` is not within a phi block to begin with, then just returns `start_search_idx` +""" +function find_end_of_phi_block(ir::IRCode, start_search_idx::Int) + # Short-cut for early exit: + stmt = ir.stmts[start_search_idx][:inst] + stmt !== nothing && !isa(stmt, PhiNode) && return start_search_idx + + # Actually going to have to go digging throught the IR to out if were are in a phi block + bb=CC.block_for_inst(ir.cfg, start_search_idx) + end_search_idx=ir.cfg.blocks[bb].stmts[end] + for idx in (start_search_idx):(end_search_idx-1) + stmt = ir.stmts[idx+1][:inst] + # next statment is no longer in a phi block, so safe to insert + stmt !== nothing && !isa(stmt, PhiNode) && return idx + end + return end_search_idx +end + +function replace_call!(ir::IRCode, idx::SSAValue, new_call::Expr) + ir[idx][:inst] = new_call + ir[idx][:type] = Any + ir[idx][:info] = CC.NoCallInfo() + ir[idx][:flag] = CC.IR_FLAG_REFINED +end + +#! format: on diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index bd201be5d..7a7b8e94f 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -5,10 +5,7 @@ function build_frule(args...; debug_mode=false) end function build_frule( - interp::MooncakeInterpreter{C}, - sig_or_mi; - debug_mode=false, - silence_debug_messages=true, + interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false, silence_debug_messages=true ) where {C} # To avoid segfaults, ensure that we bail out if the interpreter's world age is greater # than the current world age. @@ -29,7 +26,6 @@ function build_frule( # If we have a hand-coded rule, just use that. _is_primitive(C, sig_or_mi) && return (debug_mode ? DebugFRule(frule!!) : frule!!) - # We don't have a hand-coded rule, so derived one. lock(MOONCAKE_INFERENCE_LOCK) try @@ -37,7 +33,7 @@ function build_frule( # create a copy and pass in new shared data. oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode)) if haskey(interp.oc_cache, oc_cache_key) - return _copy(interp.oc_cache[oc_cache_key]) + return interp.oc_cache[oc_cache_key] else # Derive forward-pass IR, and shove in a `MistyClosure`. forward_ir = generate_forward_ir(interp, sig_or_mi; debug_mode) @@ -54,11 +50,16 @@ function build_frule( end end +struct DerivedFRule{Tfwd_oc} + fwd_oc::Tfwd_oc +end + +@inline function (fwd::DerivedFRule)(args::Vararg{Dual,N}) where {N} + return fwd.fwd_oc(args...) +end + function generate_forward_ir( - interp::MooncakeInterpreter, - sig_or_mi; - debug_mode=false, - do_inline=true, + interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true ) # Reset id count. This ensures that the IDs generated are the same each time this # function runs. @@ -71,70 +72,97 @@ function generate_forward_ir( isva, spnames = is_vararg_and_sparam_names(sig_or_mi) ir = normalise!(primal_ir, spnames) - fwd_ir = dualize_ir(ir) - opt_fwd_ir = optimise_ir!(fwd_ir; do_inline) + # Differentiate the IR + fwd_ir = copy(ir) + for i in 1:length(ir.stmts) + # betting on the fact that lines don't change before compact! is called, even with insertions + stmt = fwd_ir[SSAValue(i)][:stmt] + make_fwd_ad_stmts!(fwd_ir, ir, interp, stmt, i; debug_mode) + end + pushfirst!(fwd_ir.argtypes, Any) # the rule will be the first argument + fwd_ir_compact = CC.compact!(fwd_ir) + + # Optimize the IR + opt_fwd_ir = optimise_ir!(fwd_ir_compact; do_inline) return opt_fwd_ir end -function dualize_ir(ir::IRCode) - new_stmts_stmt = map(make_fwd_ad_stmt, ir.stmts.stmt) - new_stmts_type = fill(Any, length(ir.stmts.type)) - new_stmts_info = ir.stmts.info - new_stmts_line = ir.stmts.line - new_stmts_flag = fill(CC.IR_FLAG_REFINED, length(ir.stmts.flag)) - new_stmts = CC.InstructionStream( - new_stmts_stmt, - new_stmts_type, - new_stmts_info, - new_stmts_line, - new_stmts_flag, - ) - new_cfg = ir.cfg - new_linetable = ir.linetable - rule_type = Any - new_argtypes = convert(Vector{Any}, vcat(rule_type, map(make_fwd_argtype, ir.argtypes))) - new_meta = ir.meta - new_sptypes = ir.sptypes - return IRCode(new_stmts, new_cfg, new_linetable, new_argtypes, new_meta, new_sptypes) -end +""" + make_fwd_ad_stmts!(ir, stmt, i) -make_fwd_argtype(::Type{P}) where {P} = dual_type(P) -make_fwd_argtype(c::Core.Const) = Dual # TODO: refine to type of const +Modify `ir` in-place to transform statement `stmt`, originally at position `SSAValue(i)`, into one or more derivative statements (which get inserted). +""" +function make_fwd_ad_stmts! end -function make_fwd_ad_stmt(stmt::Expr) - interp = get_interpreter() # TODO: pass it around +function make_fwd_ad_stmts!( + fwd_ir::IRCode, + ir::IRCode, + ::MooncakeInterpreter, + stmt::ReturnNode, + i::Integer; + kwargs..., +) + inst = fwd_ir[SSAValue(i)] + # the return node becomes a Dual so it changes type + # flag to re-run type inference + inst[:type] = Any + inst[:flag] = CC.IR_FLAG_REFINED + return nothing +end + +function make_fwd_ad_stmts!( + fwd_ir::IRCode, + ir::IRCode, + interp::MooncakeInterpreter, + stmt::Expr, + i::Integer; + debug_mode, +) + inst = fwd_ir[SSAValue(i)] C = context_type(interp) if isexpr(stmt, :invoke) || isexpr(stmt, :call) - mi = stmt.args[1]::Core.MethodInstance - sig = mi.specTypes + sig, mi = if isexpr(stmt, :invoke) + mi = stmt.args[1]::Core.MethodInstance + mi.specTypes, mi + else + sig_types = map(Base.Fix1(get_forward_primal_type, ir), stmt.args) + Tuple{sig_types...}, missing + end + shifted_args = inc_args(stmt.args) if is_primitive(C, sig) - shifted_args = inc_args(stmt.args) - new_stmt = Expr( - :call, - :($frule!!), - stmt.args[2], - shifted_args[3:end]... + inst[:stmt] = Expr(:call, frule!!, shifted_args[2:end]...) + inst[:info] = CC.NoCallInfo() + inst[:type] = Any + inst[:flag] = CC.IR_FLAG_REFINED + elseif isexpr(stmt, :invoke) + rule = build_frule(interp, mi; debug_mode) + # modify the original statement to use `rule` + inst[:stmt] = Expr(:call, rule, shifted_args[2:end]...) + inst[:info] = CC.NoCallInfo() + inst[:type] = Any + inst[:flag] = CC.IR_FLAG_REFINED + elseif isexpr(stmt, :call) + throw( + ArgumentError("Expressions of type `:call` not supported in forward mode") ) - return new_stmt - else - throw(ArgumentError("Recursing into non-primitive calls is not yet supported in forward mode")) end - return stmt else - throw(ArgumentError("Expressions of type `:$(stmt.head)` are not yet supported in forward mode")) + throw( + ArgumentError( + "Expressions of type `:$(stmt.head)` are not yet supported in forward mode" + ), + ) end end -function make_fwd_ad_stmt(stmt::ReturnNode) - return stmt +get_forward_primal_type(ir::IRCode, a::Argument) = ir.arg_types[a] +get_forward_primal_type(ir::IRCode, ssa::SSAValue) = ir[ssa][:type] +get_forward_primal_type(::IRCode, x::QuoteNode) = _typeof(x.value) +get_forward_primal_type(::IRCode, x) = _typeof(x) +function get_forward_primal_type(::IRCode, x::GlobalRef) + return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty end - -struct DerivedFRule{Tfwd_oc} - fwd_oc::Tfwd_oc -end - -_copy(rule::DerivedFRule) = deepcopy(rule) - -@inline function (fwd::DerivedFRule)(args::Vararg{Dual,N}) where {N} - return fwd.fwd_oc.oc(args...) +function get_forward_primal_type(::IRCode, x::Expr) + x.head === :boundscheck && return Bool + return error("Unrecognised expression $x found in argument slot.") end diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index 2c297ff83..e37f1669d 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -35,6 +35,9 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) end @is_primitive MinimalCtx Tuple{typeof(sin),<:IEEEFloat} +function frule!!(::Dual{typeof(sin)}, x::Dual{<:IEEEFloat}) + return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) +end function rrule!!(::CoDual{typeof(sin),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) sin_pullback!!(dy::P) = NoRData(), dy * c @@ -42,6 +45,9 @@ function rrule!!(::CoDual{typeof(sin),NoFData}, x::CoDual{P,NoFData}) where {P<: end @is_primitive MinimalCtx Tuple{typeof(cos),<:IEEEFloat} +function frule!!(::Dual{typeof(cos)}, x::Dual{<:IEEEFloat}) + return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) +end function rrule!!(::CoDual{typeof(cos),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) cos_pullback!!(dy::P) = NoRData(), -dy * s From 2836ac8f6dbe56d12980923ed09153c1159d4dca Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 25 Nov 2024 15:17:17 +0100 Subject: [PATCH 08/32] Still buggy, don't touch --- .github/workflows/CI.yml | 170 +++--------------------- src/interpreter/s2s_forward_mode_ad.jl | 98 +++++++++----- test/interpreter/s2s_forward_mode_ad.jl | 37 +++--- 3 files changed, 108 insertions(+), 197 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 49a250381..753bef71c 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -23,31 +23,31 @@ jobs: test_group: [ 'quality', 'basic', - 'rrules/avoiding_non_differentiable_code', - 'rrules/blas', - 'rrules/builtins', - 'rrules/fastmath', - 'rrules/foreigncall', - 'rrules/functionwrappers', - 'rrules/iddict', - 'rrules/lapack', - 'rrules/linear_algebra', - 'rrules/low_level_maths', - 'rrules/memory', - 'rrules/misc', - 'rrules/new', - 'rrules/tasks', - 'rrules/twice_precision', + # 'rrules/avoiding_non_differentiable_code', + # 'rrules/blas', + # 'rrules/builtins', + # 'rrules/fastmath', + # 'rrules/foreigncall', + # 'rrules/functionwrappers', + # 'rrules/iddict', + # 'rrules/lapack', + # 'rrules/linear_algebra', + # 'rrules/low_level_maths', + # 'rrules/memory', + # 'rrules/misc', + # 'rrules/new', + # 'rrules/tasks', + # 'rrules/twice_precision', ] version: - - 'lts' + # - 'lts' - '1' arch: - x64 - include: - - test_group: 'basic' - version: '1.10' - arch: x86 + # include: + # - test_group: 'basic' + # version: '1.10' + # arch: x86 steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -66,132 +66,4 @@ jobs: files: lcov.info token: ${{ secrets.CODECOV_TOKEN }} fail_ci_if_error: false - extra: - name: ${{matrix.test_group.test_type}}-${{ matrix.test_group.label }}-${{ matrix.version }}-${{ matrix.arch }} - runs-on: ubuntu-latest - if: github.event_name != 'schedule' - strategy: - fail-fast: false - matrix: - test_group: [ - {test_type: 'ext', label: 'differentiation_interface'}, - {test_type: 'ext', label: 'dynamic_ppl'}, - {test_type: 'ext', label: 'luxlib'}, - {test_type: 'ext', label: 'nnlib'}, - {test_type: 'ext', label: 'special_functions'}, - {test_type: 'integration_testing', label: 'array'}, - {test_type: 'integration_testing', label: 'bijectors'}, - {test_type: 'integration_testing', label: 'diff_tests'}, - {test_type: 'integration_testing', label: 'distributions'}, - {test_type: 'integration_testing', label: 'gp'}, - {test_type: 'integration_testing', label: 'logexpfunctions'}, - {test_type: 'integration_testing', label: 'lux'}, - {test_type: 'integration_testing', label: 'battery_tests'}, - {test_type: 'integration_testing', label: 'misc_abstract_array'}, - {test_type: 'integration_testing', label: 'temporalgps'}, - {test_type: 'integration_testing', label: 'turing'}, - ] - version: - - '1' - - 'lts' - arch: - - x64 - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: ${{ matrix.version }} - arch: ${{ matrix.arch }} - include-all-prereleases: false - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - run: | - if [ ${{ matrix.test_group.test_type }} == 'ext' ]; then - julia --code-coverage=user --eval 'include("test/run_extra.jl")' - else - julia --eval 'include("test/run_extra.jl")' - fi - env: - LABEL: ${{ matrix.test_group.label }} - TEST_TYPE: ${{ matrix.test_group.test_type }} - - uses: julia-actions/julia-processcoverage@v1 - - uses: codecov/codecov-action@v4 - with: - files: lcov.info - token: ${{ secrets.CODECOV_TOKEN }} - fail_ci_if_error: false - perf: - name: "Performance (${{ matrix.perf_group }})" - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - perf_group: - - 'hand_written' - - 'derived' - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: x64 - include-all-prereleases: false - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()' - env: - PERF_GROUP: ${{ matrix.perf_group }} - shell: bash - compperf: - name: "Performance (inter-AD)" - runs-on: ubuntu-latest - if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository - strategy: - fail-fast: false - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: x64 - include-all-prereleases: false - - uses: julia-actions/cache@v2 - - uses: julia-actions/julia-buildpkg@v1 - - run: mkdir bench_results - - run: julia --project=bench --eval 'include("bench/run_benchmarks.jl"); main()' - env: - PERF_GROUP: 'comparison' - GKSwstype: '100' - shell: bash - - uses: actions/upload-artifact@v4 - with: - name: benchmarking-results - path: bench_results/ - # Useful code for testing action. - # - run: | - # text="this is line one - # this is line two - # this is line three" - # echo "$text" > benchmark_results.txt - - name: Read file content - id: read-file - run: | - { - echo "table<> $GITHUB_OUTPUT - - name: Find Comment - uses: peter-evans/find-comment@v3 - id: fc - with: - issue-number: ${{ github.event.pull_request.number }} - comment-author: github-actions[bot] - - id: post-report-as-pr-comment - name: Post Report as Pull Request Comment - uses: peter-evans/create-or-update-comment@v4 - with: - issue-number: ${{ github.event.pull_request.number }} - body: "Performance Ratio:\nRatio of time to compute gradient and time to compute function.\nWarning: results are very approximate! See [here](https://github.com/compintell/Mooncake.jl/tree/main/bench#inter-framework-benchmarking) for more context.\n```\n${{ steps.read-file.outputs.table }}\n```" - comment-id: ${{ steps.fc.outputs.comment-id }} - edit-mode: replace + \ No newline at end of file diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 7a7b8e94f..2d8c0cacf 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -36,8 +36,9 @@ function build_frule( return interp.oc_cache[oc_cache_key] else # Derive forward-pass IR, and shove in a `MistyClosure`. - forward_ir = generate_forward_ir(interp, sig_or_mi; debug_mode) - fwd_oc = MistyClosure(forward_ir; do_compile=true) + dual_ir = generate_dual_ir(interp, sig_or_mi; debug_mode) + return dual_ir # TODO: remove + fwd_oc = MistyClosure(dual_ir; do_compile=true) raw_rule = DerivedFRule(fwd_oc) rule = debug_mode ? DebugFRule(raw_rule) : raw_rule interp.oc_cache[oc_cache_key] = rule @@ -58,7 +59,36 @@ end return fwd.fwd_oc(args...) end -function generate_forward_ir( +mutable struct PrimalAndDualIR + primal::IRCode + dual::IRCode + dual_to_primal_SSA::Vector{Int} + primal_current_line::Int +end + +function PrimalAndDualIR(primal::IRCode) + dual = copy(primal) + dual_to_primal_SSA = collect(1:length(primal.stmts)) + primal_current_line = 0 + + # Modify argument types: + # - add one for the rule in front + # - convert the rest to dual types + for (a, P) in enumerate(primal.argtypes) + if P isa DataType + dual.argtypes[a] = dual_type(P) + elseif P isa Core.Const + dual.argtypes[a] = Dual # TODO: improve + end + end + pushfirst!(dual.argtypes, Any) + + return PrimalAndDualIR(primal, dual, dual_to_primal_SSA, primal_current_line) +end + +export PrimalAndDualIR + +function generate_dual_ir( interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true ) # Reset id count. This ensures that the IDs generated are the same each time this @@ -73,36 +103,31 @@ function generate_forward_ir( ir = normalise!(primal_ir, spnames) # Differentiate the IR - fwd_ir = copy(ir) + pdir = PrimalAndDualIR(ir) for i in 1:length(ir.stmts) # betting on the fact that lines don't change before compact! is called, even with insertions - stmt = fwd_ir[SSAValue(i)][:stmt] - make_fwd_ad_stmts!(fwd_ir, ir, interp, stmt, i; debug_mode) + stmt = ir[SSAValue(i)][:stmt] + make_fwd_ad_stmts!(pdir, interp, stmt, i; debug_mode) end - pushfirst!(fwd_ir.argtypes, Any) # the rule will be the first argument - fwd_ir_compact = CC.compact!(fwd_ir) + return pdir.dual # TODO: remove # Optimize the IR - opt_fwd_ir = optimise_ir!(fwd_ir_compact; do_inline) - return opt_fwd_ir + opt_dual_ir = optimise_ir!(dual_ir; do_inline) + return opt_dual_ir end """ - make_fwd_ad_stmts!(ir, stmt, i) + make_fwd_ad_stmts!(pdir::PrimalAndDualIR, interpreter, stmt, i) -Modify `ir` in-place to transform statement `stmt`, originally at position `SSAValue(i)`, into one or more derivative statements (which get inserted). +Modify the dual part of `pdir` in-place to transform statement `stmt`, located at position `i` in the primal, into one or more derivative statements. """ function make_fwd_ad_stmts! end function make_fwd_ad_stmts!( - fwd_ir::IRCode, - ir::IRCode, - ::MooncakeInterpreter, - stmt::ReturnNode, - i::Integer; - kwargs..., + pdir::PrimalAndDualIR, ::MooncakeInterpreter, stmt::ReturnNode, i::Integer; kwargs... ) - inst = fwd_ir[SSAValue(i)] + (; dual) = pdir + inst = dual[SSAValue(i)] # the return node becomes a Dual so it changes type # flag to re-run type inference inst[:type] = Any @@ -111,30 +136,37 @@ function make_fwd_ad_stmts!( end function make_fwd_ad_stmts!( - fwd_ir::IRCode, - ir::IRCode, - interp::MooncakeInterpreter, - stmt::Expr, - i::Integer; - debug_mode, + pdir::PrimalAndDualIR, interp::MooncakeInterpreter, stmt::Expr, i::Integer; debug_mode ) - inst = fwd_ir[SSAValue(i)] + (; primal, dual, dual_to_primal_SSA) = pdir + i2 = findfirst(>=(i), dual_to_primal_SSA) C = context_type(interp) if isexpr(stmt, :invoke) || isexpr(stmt, :call) sig, mi = if isexpr(stmt, :invoke) mi = stmt.args[1]::Core.MethodInstance mi.specTypes, mi else - sig_types = map(Base.Fix1(get_forward_primal_type, ir), stmt.args) + sig_types = map(stmt.args) do a + get_forward_primal_type(ir, a) + end Tuple{sig_types...}, missing end - shifted_args = inc_args(stmt.args) + shifted_args = inc_args(stmt).args if is_primitive(C, sig) - inst[:stmt] = Expr(:call, frule!!, shifted_args[2:end]...) - inst[:info] = CC.NoCallInfo() - inst[:type] = Any - inst[:flag] = CC.IR_FLAG_REFINED + # insert instruction defining dual function + fd_stmt = Expr(:call, zero_dual, shifted_args[2]) + fd_inst = CC.NewInstruction(fd_stmt, Any) + CC.insert_node!(dual, SSAValue(i2), fd_inst, false) + dual[SSAValue(i2)][:stmt] = Expr( + :call, frule!!, SSAValue(i2), shifted_args[3:end]... + ) + dual[SSAValue(i2)][:info] = CC.NoCallInfo() + dual[SSAValue(i2)][:type] = Any + dual[SSAValue(i2)][:flag] = CC.IR_FLAG_REFINED + + insert!(dual_to_primal_SSA, i2, i) elseif isexpr(stmt, :invoke) + error("Not there yet") rule = build_frule(interp, mi; debug_mode) # modify the original statement to use `rule` inst[:stmt] = Expr(:call, rule, shifted_args[2:end]...) @@ -155,7 +187,7 @@ function make_fwd_ad_stmts!( end end -get_forward_primal_type(ir::IRCode, a::Argument) = ir.arg_types[a] +get_forward_primal_type(ir::IRCode, a::Argument) = ir.argtypes[a.n] get_forward_primal_type(ir::IRCode, ssa::SSAValue) = ir[ssa][:type] get_forward_primal_type(::IRCode, x::QuoteNode) = _typeof(x.value) get_forward_primal_type(::IRCode, x) = _typeof(x) diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index ce6d0e5d6..41bb5cdfd 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -1,27 +1,34 @@ +using MistyClosures using Mooncake using Test x, dx = 2.0, 3.0 xdual = Dual(x, dx) -@testset "Manual frule" begin - sin_rule = build_frule(sin, x) - ydual = sin_rule(zero_dual(sin), xdual) +sin_rule = build_frule(sin, x) +ydual = sin_rule(zero_dual(sin), xdual) - @test primal(ydual) == sin(x) - @test tangent(ydual) == dx * cos(x) -end +@test primal(ydual) == sin(x) +@test tangent(ydual) == dx * cos(x) function func(x) - y = sin(x) - z = cos(y) - return z + z = cos(x) + w = sin(z) + return w end -@testset "Automatic frule" begin - func_rule = build_frule(func, x) - ydual = func_rule(zero_dual(func), xdual) +ir = Base.code_ircode(func, (typeof(x),))[1][1] +dual_ir = build_frule(func, x) +comp = CC.compact!(dual_ir) - @test primal(ydual) == cos(sin(x)) - @test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x) -end +dual_ir |> typeof + +oc = MistyClosure(dual_ir) + +oc.oc(zero_dual(func), xdual) + +func_rule = build_frule(func, x) +ydual = func_rule(zero_dual(func), xdual) + +@test primal(ydual) == cos(sin(x)) +@test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x) From 09d63bd7c9545fc35eb34061aad298fbee20b3cc Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 26 Nov 2024 08:03:28 +0100 Subject: [PATCH 09/32] Keep instruction mapping one to one --- .gitignore | 1 + src/Mooncake.jl | 2 + src/interpreter/s2s_forward_mode_ad.jl | 126 +++++++++++------------- test/interpreter/s2s_forward_mode_ad.jl | 16 +-- 4 files changed, 61 insertions(+), 84 deletions(-) diff --git a/.gitignore b/.gitignore index bcbf1f024..c7342b02f 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,4 @@ profile.pb.gz scratch.jl docs/build/ docs/site/ +playground.jl \ No newline at end of file diff --git a/src/Mooncake.jl b/src/Mooncake.jl index 1e727261b..a3be1a619 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -58,6 +58,8 @@ Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual` """ function frule!! end +_frule!!_funcnotdual(f, args::Vararg{Any,N}) where {N} = frule!!(zero_dual(f), args...) + """ rrule!!(f::CoDual, x::CoDual...) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 2d8c0cacf..b3dc4ff3b 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -37,9 +37,8 @@ function build_frule( else # Derive forward-pass IR, and shove in a `MistyClosure`. dual_ir = generate_dual_ir(interp, sig_or_mi; debug_mode) - return dual_ir # TODO: remove - fwd_oc = MistyClosure(dual_ir; do_compile=true) - raw_rule = DerivedFRule(fwd_oc) + dual_oc = MistyClosure(dual_ir; do_compile=true) + raw_rule = DerivedFRule(dual_oc) rule = debug_mode ? DebugFRule(raw_rule) : raw_rule interp.oc_cache[oc_cache_key] = rule return rule @@ -59,35 +58,6 @@ end return fwd.fwd_oc(args...) end -mutable struct PrimalAndDualIR - primal::IRCode - dual::IRCode - dual_to_primal_SSA::Vector{Int} - primal_current_line::Int -end - -function PrimalAndDualIR(primal::IRCode) - dual = copy(primal) - dual_to_primal_SSA = collect(1:length(primal.stmts)) - primal_current_line = 0 - - # Modify argument types: - # - add one for the rule in front - # - convert the rest to dual types - for (a, P) in enumerate(primal.argtypes) - if P isa DataType - dual.argtypes[a] = dual_type(P) - elseif P isa Core.Const - dual.argtypes[a] = Dual # TODO: improve - end - end - pushfirst!(dual.argtypes, Any) - - return PrimalAndDualIR(primal, dual, dual_to_primal_SSA, primal_current_line) -end - -export PrimalAndDualIR - function generate_dual_ir( interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true ) @@ -99,80 +69,94 @@ function generate_dual_ir( primal_ir, _ = lookup_ir(interp, sig_or_mi) # Normalise the IR. - isva, spnames = is_vararg_and_sparam_names(sig_or_mi) - ir = normalise!(primal_ir, spnames) - - # Differentiate the IR - pdir = PrimalAndDualIR(ir) - for i in 1:length(ir.stmts) - # betting on the fact that lines don't change before compact! is called, even with insertions - stmt = ir[SSAValue(i)][:stmt] - make_fwd_ad_stmts!(pdir, interp, stmt, i; debug_mode) + _, spnames = is_vararg_and_sparam_names(sig_or_mi) + primal_ir = normalise!(primal_ir, spnames) + + # Keep a copy of the primal IR around + dual_ir = copy(primal_ir) + + # Modify dual argument types: + # - add one for the rule in front + # - convert the rest to dual types + for (a, P) in enumerate(primal_ir.argtypes) + if P isa DataType + dual_ir.argtypes[a] = dual_type(P) + elseif P isa Core.Const + dual_ir.argtypes[a] = Dual # TODO: improve + end + end + pushfirst!(dual_ir.argtypes, Any) + + # Differentiate dual IR + for i in 1:length(primal_ir.stmts) + stmt = primal_ir[SSAValue(i)][:stmt] + make_fwd_ad_stmts!(dual_ir, primal_ir, interp, stmt, i; debug_mode) end - return pdir.dual # TODO: remove - # Optimize the IR + # Optimize dual IR opt_dual_ir = optimise_ir!(dual_ir; do_inline) return opt_dual_ir end """ - make_fwd_ad_stmts!(pdir::PrimalAndDualIR, interpreter, stmt, i) + make_fwd_ad_stmts!(dual_ir, primal_ir, interpreter, stmt, i) -Modify the dual part of `pdir` in-place to transform statement `stmt`, located at position `i` in the primal, into one or more derivative statements. +Modify `dual_ir` in-place to transform statement `stmt`, located at position `i` in `primal_ir`, into one or more derivative statements. + +!!! warning + We must enforce the invariant that this function never deletes nor adds instructions. """ function make_fwd_ad_stmts! end function make_fwd_ad_stmts!( - pdir::PrimalAndDualIR, ::MooncakeInterpreter, stmt::ReturnNode, i::Integer; kwargs... + dual_ir::IRCode, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::ReturnNode, + i::Integer; + kwargs..., ) - (; dual) = pdir - inst = dual[SSAValue(i)] + dual_inst = dual_ir[SSAValue(i)] # the return node becomes a Dual so it changes type # flag to re-run type inference - inst[:type] = Any - inst[:flag] = CC.IR_FLAG_REFINED + dual_inst[:type] = Any + dual_inst[:flag] = CC.IR_FLAG_REFINED return nothing end function make_fwd_ad_stmts!( - pdir::PrimalAndDualIR, interp::MooncakeInterpreter, stmt::Expr, i::Integer; debug_mode + dual_ir::IRCode, + primal_ir::IRCode, + interp::MooncakeInterpreter, + stmt::Expr, + i::Integer; + debug_mode, ) - (; primal, dual, dual_to_primal_SSA) = pdir - i2 = findfirst(>=(i), dual_to_primal_SSA) C = context_type(interp) + dual_inst = dual_ir[SSAValue(i)] if isexpr(stmt, :invoke) || isexpr(stmt, :call) sig, mi = if isexpr(stmt, :invoke) mi = stmt.args[1]::Core.MethodInstance mi.specTypes, mi else sig_types = map(stmt.args) do a - get_forward_primal_type(ir, a) + get_forward_primal_type(primal_ir, a) end Tuple{sig_types...}, missing end shifted_args = inc_args(stmt).args if is_primitive(C, sig) - # insert instruction defining dual function - fd_stmt = Expr(:call, zero_dual, shifted_args[2]) - fd_inst = CC.NewInstruction(fd_stmt, Any) - CC.insert_node!(dual, SSAValue(i2), fd_inst, false) - dual[SSAValue(i2)][:stmt] = Expr( - :call, frule!!, SSAValue(i2), shifted_args[3:end]... - ) - dual[SSAValue(i2)][:info] = CC.NoCallInfo() - dual[SSAValue(i2)][:type] = Any - dual[SSAValue(i2)][:flag] = CC.IR_FLAG_REFINED - - insert!(dual_to_primal_SSA, i2, i) + dual_inst[:stmt] = Expr(:call, _frule!!_funcnotdual, shifted_args[2:end]...) + dual_inst[:info] = CC.NoCallInfo() + dual_inst[:type] = Any + dual_inst[:flag] = CC.IR_FLAG_REFINED elseif isexpr(stmt, :invoke) - error("Not there yet") rule = build_frule(interp, mi; debug_mode) # modify the original statement to use `rule` - inst[:stmt] = Expr(:call, rule, shifted_args[2:end]...) - inst[:info] = CC.NoCallInfo() - inst[:type] = Any - inst[:flag] = CC.IR_FLAG_REFINED + dual_inst[:stmt] = Expr(:call, rule, shifted_args[2:end]...) + dual_inst[:info] = CC.NoCallInfo() + dual_inst[:type] = Any + dual_inst[:flag] = CC.IR_FLAG_REFINED elseif isexpr(stmt, :call) throw( ArgumentError("Expressions of type `:call` not supported in forward mode") diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index 41bb5cdfd..c057a9c4f 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -12,21 +12,11 @@ ydual = sin_rule(zero_dual(sin), xdual) @test tangent(ydual) == dx * cos(x) function func(x) - z = cos(x) - w = sin(z) - return w + y = sin(x) + z = cos(y) + return z end -ir = Base.code_ircode(func, (typeof(x),))[1][1] -dual_ir = build_frule(func, x) -comp = CC.compact!(dual_ir) - -dual_ir |> typeof - -oc = MistyClosure(dual_ir) - -oc.oc(zero_dual(func), xdual) - func_rule = build_frule(func, x) ydual = func_rule(zero_dual(func), xdual) From fa679eb396c71e0bbce5f3ffadcbea4e36330fe6 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 26 Nov 2024 08:08:12 +0100 Subject: [PATCH 10/32] Use replace_call --- src/interpreter/s2s_forward_mode_ad.jl | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index b3dc4ff3b..6fb179538 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -133,7 +133,6 @@ function make_fwd_ad_stmts!( debug_mode, ) C = context_type(interp) - dual_inst = dual_ir[SSAValue(i)] if isexpr(stmt, :invoke) || isexpr(stmt, :call) sig, mi = if isexpr(stmt, :invoke) mi = stmt.args[1]::Core.MethodInstance @@ -146,17 +145,12 @@ function make_fwd_ad_stmts!( end shifted_args = inc_args(stmt).args if is_primitive(C, sig) - dual_inst[:stmt] = Expr(:call, _frule!!_funcnotdual, shifted_args[2:end]...) - dual_inst[:info] = CC.NoCallInfo() - dual_inst[:type] = Any - dual_inst[:flag] = CC.IR_FLAG_REFINED + call_frule = Expr(:call, _frule!!_funcnotdual, shifted_args[2:end]...) + replace_call!(dual_ir, SSAValue(i), call_frule) elseif isexpr(stmt, :invoke) rule = build_frule(interp, mi; debug_mode) - # modify the original statement to use `rule` - dual_inst[:stmt] = Expr(:call, rule, shifted_args[2:end]...) - dual_inst[:info] = CC.NoCallInfo() - dual_inst[:type] = Any - dual_inst[:flag] = CC.IR_FLAG_REFINED + call_rule = Expr(:call, rule, shifted_args[2:end]...) + replace_call!(dual_ir, SSAValue(i), call_rule) elseif isexpr(stmt, :call) throw( ArgumentError("Expressions of type `:call` not supported in forward mode") From a68257c9123b6dc980b68ad45d11cf47b6e82380 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 27 Nov 2024 16:51:07 +0100 Subject: [PATCH 11/32] Ignore code cov --- src/interpreter/diffractor_compiler_utils.jl | 2 +- src/interpreter/s2s_forward_mode_ad.jl | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/interpreter/diffractor_compiler_utils.jl b/src/interpreter/diffractor_compiler_utils.jl index a6b5590f9..d341ed2df 100644 --- a/src/interpreter/diffractor_compiler_utils.jl +++ b/src/interpreter/diffractor_compiler_utils.jl @@ -117,7 +117,7 @@ function find_end_of_phi_block(ir::IRCode, start_search_idx::Int) return end_search_idx end -function replace_call!(ir::IRCode, idx::SSAValue, new_call::Expr) +function replace_call!(ir::IRCode, idx::SSAValue, new_call) ir[idx][:inst] = new_call ir[idx][:type] = Any ir[idx][:info] = CC.NoCallInfo() diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 6fb179538..cc451adf6 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -156,6 +156,8 @@ function make_fwd_ad_stmts!( ArgumentError("Expressions of type `:call` not supported in forward mode") ) end + elseif Meta.isexpr(stmt, :code_coverage_effect) + replace_call!(dual_ir, SSAValue(i), nothing) else throw( ArgumentError( From 7a096ba129d2d9f5c3747438eeca20ad3b7f4fb8 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Wed, 27 Nov 2024 19:09:46 +0100 Subject: [PATCH 12/32] No Aqua piracies test --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 9e6f43c9b..2878ff33a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ include("front_matter.jl") @testset "Mooncake.jl" begin if test_group == "quality" - Aqua.test_all(Mooncake) + Aqua.test_all(Mooncake; piracies=false) # TODO: toggle once Diffractor code is removed @test JuliaFormatter.format(Mooncake; verbose=false, overwrite=false) elseif test_group == "basic" include("utils.jl") From 46c3e5a37958d48607f2a62a33c17bfdc8e3f8dd Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:09:20 +0100 Subject: [PATCH 13/32] Start control flow --- src/Mooncake.jl | 2 - src/dual.jl | 6 + src/interpreter/bbcode.jl | 9 ++ src/interpreter/ir_normalisation.jl | 1 + src/interpreter/ir_utils.jl | 3 +- src/interpreter/s2s_forward_mode_ad.jl | 147 +++++++++++++++++++----- src/rrules/builtins.jl | 5 +- test/interpreter/s2s_forward_mode_ad.jl | 9 +- 8 files changed, 148 insertions(+), 34 deletions(-) diff --git a/src/Mooncake.jl b/src/Mooncake.jl index a3be1a619..1e727261b 100644 --- a/src/Mooncake.jl +++ b/src/Mooncake.jl @@ -58,8 +58,6 @@ Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual` """ function frule!! end -_frule!!_funcnotdual(f, args::Vararg{Any,N}) where {N} = frule!!(zero_dual(f), args...) - """ rrule!!(f::CoDual, x::CoDual...) diff --git a/src/dual.jl b/src/dual.jl index 9210cb9a1..363bb664b 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -20,3 +20,9 @@ end function dual_type(p::Type{Type{P}}) where {P} return @isdefined(P) ? Dual{Type{P},NoTangent} : Dual{_typeof(p),NoTangent} end + +_primal(x) = x +_primal(x::Dual) = primal(x) + +make_dual(x) = zero_dual(x) +make_dual(x::Dual) = x diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 46f9920aa..1d0e5c618 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -891,6 +891,15 @@ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) inc_args(x::IDGotoNode) = x +function inc_args(x::PhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return PhiNode(x.edges, new_values) +end function inc_args(x::IDPhiNode) new_values = Vector{Any}(undef, length(x.values)) for n in eachindex(x.values) diff --git a/src/interpreter/ir_normalisation.jl b/src/interpreter/ir_normalisation.jl index 570607457..3031fa66a 100644 --- a/src/interpreter/ir_normalisation.jl +++ b/src/interpreter/ir_normalisation.jl @@ -56,6 +56,7 @@ function _interpolate_boundschecks!(statements::Vector{Any}) if stmt isa Expr && stmt.head == :boundscheck && length(stmt.args) == 1 def = SSAValue(n) val = only(stmt.args) + # TODO: this could just be `statements[n] = val` (Valentin C says) for (m, stmt) in enumerate(statements) statements[m] = replace_uses_with!(stmt, def, val) end diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index d4a518bb5..0f27d4c6b 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -332,7 +332,8 @@ function replace_uses_with!(stmt, def::Union{Argument,SSAValue}, val) elseif stmt isa GotoIfNot if stmt.cond == def @assert val isa Bool - return val === true ? nothing : GotoNode(stmt.dest) + # nothing is not a Terminator + return val === true ? GotoIfNot(val, stmt.dest) : GotoNode(stmt.dest) else return stmt end diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index cc451adf6..b29d98ef4 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -72,7 +72,14 @@ function generate_dual_ir( _, spnames = is_vararg_and_sparam_names(sig_or_mi) primal_ir = normalise!(primal_ir, spnames) - # Keep a copy of the primal IR around + # Insert statements into primal IR + for i in 1:length(primal_ir.stmts) + stmt = primal_ir[SSAValue(i)][:stmt] + insert_fwd_ad_stmts!(primal_ir, interp, stmt, i; debug_mode) + end + primal_ir = CC.compact!(primal_ir) + + # Keep a copy of the primal IR with the insertions dual_ir = copy(primal_ir) # Modify dual argument types: @@ -87,28 +94,95 @@ function generate_dual_ir( end pushfirst!(dual_ir.argtypes, Any) - # Differentiate dual IR - for i in 1:length(primal_ir.stmts) - stmt = primal_ir[SSAValue(i)][:stmt] - make_fwd_ad_stmts!(dual_ir, primal_ir, interp, stmt, i; debug_mode) + # Modify dual IR without insertions + for i in 1:length(dual_ir.stmts) + stmt = dual_ir[SSAValue(i)][:stmt] + modify_fwd_ad_stmts!(dual_ir, primal_ir, interp, stmt, i; debug_mode) end + dual_ir = CC.compact!(dual_ir) # skippable but good practice + + CC.verify_ir(dual_ir) # Optimize dual IR opt_dual_ir = optimise_ir!(dual_ir; do_inline) return opt_dual_ir end -""" - make_fwd_ad_stmts!(dual_ir, primal_ir, interpreter, stmt, i) +## Insertion (only GotoIfNot) + +function insert_fwd_ad_stmts!( + primal_ir::IRCode, ::MooncakeInterpreter, stmt, i::Integer; kwargs... +) + return nothing +end + +function insert_fwd_ad_stmts!( + primal_ir::IRCode, ::MooncakeInterpreter, stmt::GotoIfNot, i::Integer; kwargs... +) + get_primal_inst = CC.NewInstruction(Expr(:call, _primal, stmt.cond), Any) + return CC.insert_node!(primal_ir, CC.SSAValue(i), get_primal_inst, false) +end + +## Modification + +function modify_fwd_ad_stmts!( + dual_ir::IRCode, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::Nothing, + i::Integer; + kwargs..., +) + return nothing +end + +function modify_fwd_ad_stmts!( + dual_ir::IRCode, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::Union{GotoNode,Core.GotoIfNot}, + i::Integer; + kwargs..., +) + return nothing +end + +function modify_fwd_ad_stmts!( + dual_ir::IRCode, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::GotoIfNot, + i::Integer; + kwargs..., +) + return Mooncake.replace_call!( + dual_ir, CC.SSAValue(i), Core.GotoIfNot(CC.SSAValue(i - 1), stmt.dest) + ) +end + +# TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) +_frule!!_makedual(f, args::Vararg{Any,N}) where {N} = frule!!(make_dual.(args)...) + +struct DynamicFRule{V} + cache::V + debug_mode::Bool +end + +DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) -Modify `dual_ir` in-place to transform statement `stmt`, located at position `i` in `primal_ir`, into one or more derivative statements. +_copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) -!!! warning - We must enforce the invariant that this function never deletes nor adds instructions. -""" -function make_fwd_ad_stmts! end +function (dynamic_rule::DynamicFRule)(args::Vararg{Any,N}) where {N} + sig = Tuple{map(_typeof ∘ primal, args)...} + rule = get(dynamic_rule.cache, sig, nothing) + if rule === nothing + rule = build_frule(get_interpreter(), sig; debug_mode=dynamic_rule.debug_mode) + dynamic_rule.cache[sig] = rule + end + return rule(args...) +end -function make_fwd_ad_stmts!( +function modify_fwd_ad_stmts!( dual_ir::IRCode, primal_ir::IRCode, ::MooncakeInterpreter, @@ -116,15 +190,28 @@ function make_fwd_ad_stmts!( i::Integer; kwargs..., ) - dual_inst = dual_ir[SSAValue(i)] # the return node becomes a Dual so it changes type # flag to re-run type inference - dual_inst[:type] = Any - dual_inst[:flag] = CC.IR_FLAG_REFINED + dual_ir[SSAValue(i)][:type] = Any + dual_ir[SSAValue(i)][:flag] = CC.IR_FLAG_REFINED return nothing end -function make_fwd_ad_stmts!( +function modify_fwd_ad_stmts!( + dual_ir::IRCode, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::PhiNode, + i::Integer; + kwargs..., +) + dual_ir[SSAValue(i)][:stmt] = inc_args(stmt) # TODO: translate constants into constant Duals + dual_ir[SSAValue(i)][:type] = Any + dual_ir[SSAValue(i)][:flag] = CC.IR_FLAG_REFINED + return nothing +end + +function modify_fwd_ad_stmts!( dual_ir::IRCode, primal_ir::IRCode, interp::MooncakeInterpreter, @@ -132,7 +219,6 @@ function make_fwd_ad_stmts!( i::Integer; debug_mode, ) - C = context_type(interp) if isexpr(stmt, :invoke) || isexpr(stmt, :call) sig, mi = if isexpr(stmt, :invoke) mi = stmt.args[1]::Core.MethodInstance @@ -143,18 +229,23 @@ function make_fwd_ad_stmts!( end Tuple{sig_types...}, missing end - shifted_args = inc_args(stmt).args - if is_primitive(C, sig) - call_frule = Expr(:call, _frule!!_funcnotdual, shifted_args[2:end]...) + shifted_args = if isexpr(stmt, :invoke) + inc_args(stmt).args[2:end] # first arg is method instance + else + inc_args(stmt).args + end + if is_primitive(context_type(interp), sig) + call_frule = Expr(:call, _frule!!_makedual, shifted_args...) replace_call!(dual_ir, SSAValue(i), call_frule) - elseif isexpr(stmt, :invoke) - rule = build_frule(interp, mi; debug_mode) - call_rule = Expr(:call, rule, shifted_args[2:end]...) + else + if isexpr(stmt, :invoke) + rule = build_frule(interp, mi; debug_mode) + else + @assert isexpr(stmt, :call) + rule = DynamicFRule(debug_mode) + end + call_rule = Expr(:call, rule, shifted_args...) replace_call!(dual_ir, SSAValue(i), call_rule) - elseif isexpr(stmt, :call) - throw( - ArgumentError("Expressions of type `:call` not supported in forward mode") - ) end elseif Meta.isexpr(stmt, :code_coverage_effect) replace_call!(dual_ir, SSAValue(i), nothing) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 66e055d57..aa8f3015f 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -100,9 +100,12 @@ macro inactive_intrinsic(name) $name(x...) = Intrinsics.$name(x...) (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name - function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any,N}) where {N} + function frule!!(f::Dual{typeof($name)}, args::Vararg{Any,N}) where {N} return Mooncake.zero_adjoint(f, args...) end + function frule!!(f::Dual{typeof($name)}, args::Vararg{Any,N}) where {N} + return Mooncake.zero_dual(f(args...)) + end end return esc(expr) end diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index c057a9c4f..bd4a483b5 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -13,11 +13,16 @@ ydual = sin_rule(zero_dual(sin), xdual) function func(x) y = sin(x) - z = cos(y) + if x[1] > 0 + z = cos(y) + else + z = sin(y) + end return z end -func_rule = build_frule(func, x) +ir = Base.code_ircode(func, (Int,))[1][1] +irfunc_rule = build_frule(func, x) ydual = func_rule(zero_dual(func), xdual) @test primal(ydual) == cos(sin(x)) From ad3f98a1c0fe35e8340c41af80e274ea594d1380 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:31:45 +0100 Subject: [PATCH 14/32] Fix intrinsic --- src/rrules/builtins.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index aa8f3015f..a93d4b367 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -100,7 +100,7 @@ macro inactive_intrinsic(name) $name(x...) = Intrinsics.$name(x...) (is_primitive)(::Type{MinimalCtx}, ::Type{<:Tuple{typeof($name),Vararg}}) = true translate(::Val{Intrinsics.$name}) = $name - function frule!!(f::Dual{typeof($name)}, args::Vararg{Any,N}) where {N} + function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any,N}) where {N} return Mooncake.zero_adjoint(f, args...) end function frule!!(f::Dual{typeof($name)}, args::Vararg{Any,N}) where {N} From 9071574858b2eb292ab3b90b2dc6322db5a68884 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:33:59 +0100 Subject: [PATCH 15/32] Import --- src/rrules/builtins.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index a93d4b367..5eab2a997 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -49,6 +49,7 @@ using Mooncake import ..Mooncake: rrule!!, CoDual, + Dual, primal, tangent, zero_tangent, @@ -66,7 +67,8 @@ import ..Mooncake: NoRData, rdata, increment_rdata!!, - zero_fcodual + zero_fcodual, + zero_dual using Core.Intrinsics: atomic_pointerref From dcfe282bd655be410f2ba264ec37dce0fe6daeaa Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 28 Nov 2024 17:40:54 +0100 Subject: [PATCH 16/32] Typos --- src/interpreter/s2s_forward_mode_ad.jl | 2 +- test/interpreter/s2s_forward_mode_ad.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index b29d98ef4..543fa8ffe 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -161,7 +161,7 @@ function modify_fwd_ad_stmts!( end # TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) -_frule!!_makedual(f, args::Vararg{Any,N}) where {N} = frule!!(make_dual.(args)...) +_frule!!_makedual(f, args::Vararg{Any,N}) where {N} = frule!!(make_dual.((f, args...))...) struct DynamicFRule{V} cache::V diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index bd4a483b5..699da8c35 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -22,7 +22,7 @@ function func(x) end ir = Base.code_ircode(func, (Int,))[1][1] -irfunc_rule = build_frule(func, x) +func_rule = build_frule(func, x) ydual = func_rule(zero_dual(func), xdual) @test primal(ydual) == cos(sin(x)) From dd89e57ffcebf022f667baf44adf41a068b3ce77 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 12:04:19 +0100 Subject: [PATCH 17/32] Figure out incremental additions --- src/interpreter/diffractor_compiler_utils.jl | 2 +- src/interpreter/s2s_forward_mode_ad.jl | 103 ++++++------ src/rrules/builtins.jl | 7 +- src/rrules/misc.jl | 10 ++ src/test_utils.jl | 166 ++++++++++++++++++- test/interpreter/s2s_forward_mode_ad.jl | 28 ++-- 6 files changed, 241 insertions(+), 75 deletions(-) diff --git a/src/interpreter/diffractor_compiler_utils.jl b/src/interpreter/diffractor_compiler_utils.jl index d341ed2df..86b6693f3 100644 --- a/src/interpreter/diffractor_compiler_utils.jl +++ b/src/interpreter/diffractor_compiler_utils.jl @@ -117,7 +117,7 @@ function find_end_of_phi_block(ir::IRCode, start_search_idx::Int) return end_search_idx end -function replace_call!(ir::IRCode, idx::SSAValue, new_call) +function replace_call!(ir, idx::SSAValue, new_call) ir[idx][:inst] = new_call ir[idx][:type] = Any ir[idx][:info] = CC.NoCallInfo() diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 543fa8ffe..7901d33a7 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -32,17 +32,17 @@ function build_frule( # If we've already derived the OpaqueClosures and info, do not re-derive, just # create a copy and pass in new shared data. oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode)) - if haskey(interp.oc_cache, oc_cache_key) - return interp.oc_cache[oc_cache_key] - else - # Derive forward-pass IR, and shove in a `MistyClosure`. - dual_ir = generate_dual_ir(interp, sig_or_mi; debug_mode) - dual_oc = MistyClosure(dual_ir; do_compile=true) - raw_rule = DerivedFRule(dual_oc) - rule = debug_mode ? DebugFRule(raw_rule) : raw_rule - interp.oc_cache[oc_cache_key] = rule - return rule - end + # if haskey(interp.oc_cache, oc_cache_key) + # return interp.oc_cache[oc_cache_key] + # else + # Derive forward-pass IR, and shove in a `MistyClosure`. + dual_ir = generate_dual_ir(interp, sig_or_mi; debug_mode) + dual_oc = MistyClosure(dual_ir; do_compile=true) + raw_rule = DerivedFRule(dual_oc) + rule = debug_mode ? DebugFRule(raw_rule) : raw_rule + interp.oc_cache[oc_cache_key] = rule + return rule + # end catch e rethrow(e) finally @@ -72,13 +72,6 @@ function generate_dual_ir( _, spnames = is_vararg_and_sparam_names(sig_or_mi) primal_ir = normalise!(primal_ir, spnames) - # Insert statements into primal IR - for i in 1:length(primal_ir.stmts) - stmt = primal_ir[SSAValue(i)][:stmt] - insert_fwd_ad_stmts!(primal_ir, interp, stmt, i; debug_mode) - end - primal_ir = CC.compact!(primal_ir) - # Keep a copy of the primal IR with the insertions dual_ir = copy(primal_ir) @@ -94,39 +87,27 @@ function generate_dual_ir( end pushfirst!(dual_ir.argtypes, Any) - # Modify dual IR without insertions - for i in 1:length(dual_ir.stmts) - stmt = dual_ir[SSAValue(i)][:stmt] - modify_fwd_ad_stmts!(dual_ir, primal_ir, interp, stmt, i; debug_mode) + # Modify dual IR incrementally + dual_ir_comp = CC.IncrementalCompact(dual_ir) + for ((_, i), inst) in dual_ir_comp + modify_fwd_ad_stmts!(dual_ir_comp, primal_ir, interp, inst, i; debug_mode) end - dual_ir = CC.compact!(dual_ir) # skippable but good practice + dual_ir_comp = CC.finish(dual_ir_comp) + dual_ir_comp = CC.compact!(dual_ir_comp) - CC.verify_ir(dual_ir) + CC.verify_ir(dual_ir_comp) # Optimize dual IR - opt_dual_ir = optimise_ir!(dual_ir; do_inline) + opt_dual_ir = optimise_ir!(dual_ir_comp; do_inline=false) # TODO: toggle + # @info "Inferred dual IR" + # display(opt_dual_ir) # TODO: toggle return opt_dual_ir end -## Insertion (only GotoIfNot) - -function insert_fwd_ad_stmts!( - primal_ir::IRCode, ::MooncakeInterpreter, stmt, i::Integer; kwargs... -) - return nothing -end - -function insert_fwd_ad_stmts!( - primal_ir::IRCode, ::MooncakeInterpreter, stmt::GotoIfNot, i::Integer; kwargs... -) - get_primal_inst = CC.NewInstruction(Expr(:call, _primal, stmt.cond), Any) - return CC.insert_node!(primal_ir, CC.SSAValue(i), get_primal_inst, false) -end - ## Modification function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, ::MooncakeInterpreter, stmt::Nothing, @@ -137,10 +118,10 @@ function modify_fwd_ad_stmts!( end function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, ::MooncakeInterpreter, - stmt::Union{GotoNode,Core.GotoIfNot}, + stmt::GotoNode, i::Integer; kwargs..., ) @@ -148,16 +129,28 @@ function modify_fwd_ad_stmts!( end function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, ::MooncakeInterpreter, - stmt::GotoIfNot, + stmt::Core.GotoIfNot, i::Integer; kwargs..., ) - return Mooncake.replace_call!( - dual_ir, CC.SSAValue(i), Core.GotoIfNot(CC.SSAValue(i - 1), stmt.dest) + # replace GotoIfNot with the call to primal + Mooncake.replace_call!(dual_ir, CC.SSAValue(i), Expr(:call, _primal, stmt.cond)) + # reinsert the GotoIfNot right after the call to primal + # (incremental insertion cannot be done before "where we are") + new_gotoifnot_inst = CC.NewInstruction( + Core.GotoIfNot(CC.SSAValue(i), stmt.dest), # + Any, + CC.NoCallInfo(), + Int32(1), # meaningless + CC.IR_FLAG_REFINED, ) + # stick the new instruction in the previous CFG block + reverse_affinity = true + CC.insert_node_here!(dual_ir, new_gotoifnot_inst, reverse_affinity) + return nothing end # TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) @@ -173,17 +166,18 @@ DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) _copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) function (dynamic_rule::DynamicFRule)(args::Vararg{Any,N}) where {N} - sig = Tuple{map(_typeof ∘ primal, args)...} + args_dual = map(make_dual, args) # TODO: don't turn everything into a Dual, be clever with Argument and SSAValue + sig = Tuple{map(_typeof ∘ primal, args_dual)...} rule = get(dynamic_rule.cache, sig, nothing) if rule === nothing rule = build_frule(get_interpreter(), sig; debug_mode=dynamic_rule.debug_mode) dynamic_rule.cache[sig] = rule end - return rule(args...) + return rule(args_dual...) end function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, ::MooncakeInterpreter, stmt::ReturnNode, @@ -198,7 +192,7 @@ function modify_fwd_ad_stmts!( end function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, ::MooncakeInterpreter, stmt::PhiNode, @@ -212,7 +206,7 @@ function modify_fwd_ad_stmts!( end function modify_fwd_ad_stmts!( - dual_ir::IRCode, + dual_ir::CC.IncrementalCompact, primal_ir::IRCode, interp::MooncakeInterpreter, stmt::Expr, @@ -244,10 +238,13 @@ function modify_fwd_ad_stmts!( @assert isexpr(stmt, :call) rule = DynamicFRule(debug_mode) end + # TODO: could this insertion of a naked rule in the IR cause a memory leak? call_rule = Expr(:call, rule, shifted_args...) replace_call!(dual_ir, SSAValue(i), call_rule) end - elseif Meta.isexpr(stmt, :code_coverage_effect) + elseif isexpr(stmt, :boundscheck) + nothing + elseif isexpr(stmt, :code_coverage_effect) replace_call!(dual_ir, SSAValue(i), nothing) else throw( diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index 03e1d3281..af6f61195 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -89,6 +89,7 @@ using Core: Intrinsics using Mooncake import ..Mooncake: rrule!!, + frule!!, CoDual, Dual, primal, @@ -146,8 +147,10 @@ macro inactive_intrinsic(name) function rrule!!(f::CoDual{typeof($name)}, args::Vararg{Any,N}) where {N} return Mooncake.zero_adjoint(f, args...) end - function frule!!(f::Dual{typeof($name)}, args::Vararg{Any,N}) where {N} - return Mooncake.zero_dual(f(args...)) + function frule!!(f::Dual{typeof($name)}, args::Vararg{Dual,N}) where {N} + f_primal = primal(f) + args_primal = map(primal, args) + return zero_dual(f_primal(args_primal...)) end end return esc(expr) diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index f6f4de643..4b88417fc 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -58,6 +58,16 @@ This approach is identical to the one taken by `Zygote.jl` to circumvent the sam lgetfield(x, ::Val{f}) where {f} = getfield(x, f) @is_primitive MinimalCtx Tuple{typeof(lgetfield),Any,Val} +@inline function frule!!(::Dual{typeof(lgetfield)}, x::Dual, ::Dual{Val{f}}) where {f} + P = typeof(primal(x)) + primal_field = getfield(primal(x), f) + tangent_field = if tangent_type(P) === NoTangent + NoTangent() + else + getfield(tangent(x).fields, f) + end + return Dual(primal_field, tangent_field) +end @inline function rrule!!( ::CoDual{typeof(lgetfield)}, x::CoDual{P,F}, ::CoDual{Val{f}} ) where {P,F<:StandardFDataType,f} diff --git a/src/test_utils.jl b/src/test_utils.jl index 57b4461c9..ec60f1dec 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -350,7 +350,70 @@ function address_maps_are_consistent(x::AddressMap, y::AddressMap) end # Assumes that the interface has been tested, and we can simply check for numerical issues. -function test_rule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool) +function test_frule_correctness(rng::AbstractRNG, x_ẋ...; rule, unsafe_perturb::Bool) + # TODO: Will can fix it + #= + @nospecialize rng x_ẋ + + x_ẋ = map(_deepcopy, x_ẋ) # defensive copy + + # Run original function on deep-copies of inputs. + x = map(primal, x_ẋ) + # ẋ = map(tangent, x_ẋ) + x_primal = _deepcopy(x) + y_primal = x_primal[1](x_primal[2:end]...) + + # Use finite differences to estimate Frechet derivative. + ẋ = map(_x -> randn_tangent(rng, _x), x) + ε = 1e-7 + x′ = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) + y′ = x′[1](x′[2:end]...) + ẏ = _scale(1 / ε, _diff(y′, y_primal)) + ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) + + # Run rule on copies of `f` and `x`. We use randomly generated tangents so that we + # can later verify that non-zero values do not get propagated by the rule. + ẋ_zero = map(zero_tangent, x) + x_ẋ_rule = map((x, ẋ) -> dual_type(_typeof(x))(_deepcopy(x), ẋ), x, ẋ_zero) + inputs_address_map = populate_address_map( + map(primal, x_ẋ_rule), map(tangent, x_ẋ_rule) + ) + y_ẏ_rule = rule(x_ẋ_rule...) + + # Verify that inputs / outputs are the same under `f` and its rrule. + @test has_equal_data(x_primal, map(primal, x_ẋ_rule)) + @test has_equal_data(y_primal, primal(y_ẏ_rule)) + + # Query both `x_ẋ` and `y`, because `x_ẋ` may have been mutated by `f`. + outputs_address_map = populate_address_map( + (map(primal, x_x̄_rule)..., primal(y_ȳ_rule)), + (map(tangent, x_x̄_rule)..., tangent(y_ȳ_rule)), + ) + @test address_maps_are_consistent(inputs_address_map, outputs_address_map) + + # Run reverse-pass. + ȳ_delta = randn_tangent(rng, primal(y_ȳ_rule)) + x̄_delta = map(Base.Fix1(randn_tangent, rng) ∘ primal, x_x̄_rule) + + ȳ_init = set_to_zero!!(zero_tangent(primal(y_ȳ_rule), tangent(y_ȳ_rule))) + x̄_init = map(set_to_zero!!, x̄_zero) + ȳ = increment!!(ȳ_init, ȳ_delta) + map(increment!!, x̄_init, x̄_delta) + x̄_rvs_inc = pb!!(Mooncake.rdata(ȳ)) + x̄_rvs = increment!!(map(rdata, x̄_delta), x̄_rvs_inc) + x̄ = map(tangent, x̄_fwds, x̄_rvs) + + # Check that inputs have been returned to their original value. + @test all(map(has_equal_data_up_to_undefs, x, map(primal, x_x̄_rule))) + + # pullbacks increment, so have to compare to the incremented quantity. + @test _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post) ≈ _dot(x̄, ẋ) rtol = 1e-3 atol = + 1e-3 + =# +end + +# Assumes that the interface has been tested, and we can simply check for numerical issues. +function test_rrule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool) @nospecialize rng x_x̄ x_x̄ = map(_deepcopy, x_x̄) # defensive copy @@ -419,7 +482,53 @@ _deepcopy(x::Module) = x rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any} -function test_rrule_interface(f_f̄, x_x̄...; rule) +function test_frule_interface(f_ḟ, x_ẋ...; frule) + @nospecialize f_ḟ x_ẋ + + # Pull out primals and run primal computation. + f = primal(f_ḟ) + ḟ = tangent(f_ḟ) + x_ẋ = map(_deepcopy, x_ẋ) + x = map(primal, x_ẋ) + ẋ = map(tangent, x_ẋ) + + # Run the primal programme. Bail out early if this doesn't work. + y = try + f(deepcopy(x)...) + catch e + display(e) + println() + throw(ArgumentError("Primal evaluation does not work.")) + end + + # Check that input types are valid. + @test _typeof(tangent(f_ḟ)) == tangent_type(_typeof(primal(f_ḟ))) + for x_ẋ_component in x_ẋ + @test _typeof(tangent(x_ẋ_component)) == + tangent_type(_typeof(primal(x_ẋ_component))) + end + + # Run the rrule, check it has output a thing of the correct type, and extract results. + # Throw a meaningful exception if the rrule doesn't run at all. + rrule_ret = try + rule(f_ḟ, x_ẋ...) + catch e + display(e) + println() + throw( + ArgumentError( + "rule for $(_typeof(f_ḟ)) with argument types $(_typeof(x_ẋ)) does not run.", + ), + ) + end + y_ẏ = rrule_ret + + # Check that returned fdata type is correct. + @test y_ẏ isa Dual + @test typeof(y_ẏ.dx) == tangent_type(typeof(y_ẏ.x)) +end + +function test_rrule_interface(f_f̄, x_x̄...; rrule) @nospecialize f_f̄ x_x̄ # Pull out primals and run primal computation. @@ -501,6 +610,12 @@ function __forwards_and_backwards(rule, x_x̄::Vararg{Any,N}) where {N} return pb!!(Mooncake.zero_rdata(primal(out))) end +function test_frule_performance( + performance_checks_flag::Symbol, rule::R, f_ḟ::F, x_ẋ::Vararg{Any,N} +) where {R,F,N} + @warn "No performance test for frule yet" +end + function test_rrule_performance( performance_checks_flag::Symbol, rule::R, f_f̄::F, x_x̄::Vararg{Any,N} ) where {R,F,N} @@ -616,17 +731,34 @@ function test_rule( interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), debug_mode::Bool=false, unsafe_perturb::Bool=false, + forward::Bool=false, ) @nospecialize rng x # Construct the rule. sig = _typeof(__get_primals(x)) - rule = Mooncake.build_rrule(interp, sig; debug_mode) + if forward + frule = Mooncake.build_frule(interp, sig; debug_mode) + rrule = missing + else + frule = missing + rrule = Mooncake.build_rrule(interp, sig; debug_mode) + end # If something is primitive, then the rule should be `rrule!!`. - is_primitive && @test rule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) + if forward + is_primitive && @test frule == frule!! + else + is_primitive && @test rrule == (debug_mode ? Mooncake.DebugRRule(rrule!!) : rrule!!) + end # Generate random tangents for anything that is not already a CoDual. + x_ẋ = map(x -> if x isa Dual + x + else + zero_dual(x) + end, x) + x_x̄ = map(x -> if x isa CoDual x elseif interface_only @@ -636,16 +768,34 @@ function test_rule( end, x) # Test that the interface is basically satisfied (checks types / memory addresses). - test_rrule_interface(x_x̄...; rule) + if forward + test_frule_interface(x_ẋ...; frule) + else + test_rrule_interface(x_x̄...; rrule) + end # Test that answers are numerically correct / consistent. - interface_only || test_rule_correctness(rng, x_x̄...; rule, unsafe_perturb) + if forward + interface_only || test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb) + else + interface_only || test_rrule_correctness(rng, x_x̄...; rrule, unsafe_perturb) + end # Test the performance of the rule. - test_rrule_performance(perf_flag, rule, x_x̄...) + if forward + test_rrule_performance(perf_flag, rrule, x_ẋ...) + else + test_rrule_performance(perf_flag, rrule, x_x̄...) + end # Test the interface again, in order to verify that caching is working correctly. - return test_rrule_interface(x_x̄...; rule=Mooncake.build_rrule(interp, sig; debug_mode)) + if forward + test_frule_interface(x_ẋ...; frule=Mooncake.build_frule(interp, sig; debug_mode)) + else + test_rrule_interface(x_x̄...; rrule=Mooncake.build_rrule(interp, sig; debug_mode)) + end + + return nothing end function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index 699da8c35..48ba0d18a 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -1,7 +1,10 @@ using MistyClosures using Mooncake using Test +using Core.Compiler: SSAValue +const CC = Core.Compiler +#= x, dx = 2.0, 3.0 xdual = Dual(x, dx) @@ -10,20 +13,23 @@ ydual = sin_rule(zero_dual(sin), xdual) @test primal(ydual) == sin(x) @test tangent(ydual) == dx * cos(x) +=# -function func(x) - y = sin(x) - if x[1] > 0 - z = cos(y) +function func2(x) + if x > 0.0 + y = sin(x) else - z = sin(y) + y = cos(x) end - return z + return y end -ir = Base.code_ircode(func, (Int,))[1][1] -func_rule = build_frule(func, x) -ydual = func_rule(zero_dual(func), xdual) +x = 1.0 +xdual = Dual(1.0, 2.0) -@test primal(ydual) == cos(sin(x)) -@test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x) +ir = Base.code_ircode(func2, (typeof(x),))[1][1] + +func_rule = build_frule(func2, x) +ydual = func_rule(zero_dual(func2), xdual) + +2cos(1) From 9bdb57f680646b1f0a0f51eaf1a6a9e0cdf30c8f Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 12:30:19 +0000 Subject: [PATCH 18/32] Initial test case additions --- test/interpreter/s2s_forward_mode_ad.jl | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index 48ba0d18a..d8420614b 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -4,6 +4,18 @@ using Test using Core.Compiler: SSAValue const CC = Core.Compiler +@testset "s2s_forward_mode_ad" begin + test_cases = collect(enumerate(TestResources.generate_test_functions()))[1:1] + @testset "$(_typeof((f, x...)))" for (n, (interface_only, _, _, f, x...)) in test_cases + sig = _typeof((f, x...)) + @info "$n: $sig" + TestUtils.test_rule( + Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false + ) + end +end + + #= x, dx = 2.0, 3.0 xdual = Dual(x, dx) From 4bb9911efb7331b5fcda06229335dee6d8fb13fa Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 12:31:32 +0000 Subject: [PATCH 19/32] Formatting --- src/test_utils.jl | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index ec60f1dec..b4933a6df 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -753,11 +753,7 @@ function test_rule( end # Generate random tangents for anything that is not already a CoDual. - x_ẋ = map(x -> if x isa Dual - x - else - zero_dual(x) - end, x) + x_ẋ = map(x -> x isa Dual ? x : zero_dual(x), x) x_x̄ = map(x -> if x isa CoDual x From 9b037e74c6538b41f724b0d3a09cae5def16074b Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 13:58:23 +0000 Subject: [PATCH 20/32] Add verify_dual_type --- src/dual.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/dual.jl b/src/dual.jl index 363bb664b..4ac784e35 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -26,3 +26,12 @@ _primal(x::Dual) = primal(x) make_dual(x) = zero_dual(x) make_dual(x::Dual) = x + +""" + verify_dual_type(x::Dual) + +Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`. +""" +function verify_dual_type(x::Dual) + return tangent_type(typeof(primal(x))) == typeof(tangent(x)) +end From 6dea624246bfe64a9c3e2380dfb6aa4dd0f9286b Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 13:59:06 +0000 Subject: [PATCH 21/32] test_frule_interface runs --- src/test_utils.jl | 159 +++++++++++++++++++++------------------------- 1 file changed, 73 insertions(+), 86 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index b4933a6df..420ba31c1 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -482,50 +482,36 @@ _deepcopy(x::Module) = x rrule_output_type(::Type{Ty}) where {Ty} = Tuple{Mooncake.fcodual_type(Ty),Any} -function test_frule_interface(f_ḟ, x_ẋ...; frule) - @nospecialize f_ḟ x_ẋ +function test_frule_interface(x_ẋ...; frule) + @nospecialize x_ẋ # Pull out primals and run primal computation. - f = primal(f_ḟ) - ḟ = tangent(f_ḟ) x_ẋ = map(_deepcopy, x_ẋ) x = map(primal, x_ẋ) - ẋ = map(tangent, x_ẋ) # Run the primal programme. Bail out early if this doesn't work. y = try - f(deepcopy(x)...) - catch e - display(e) - println() - throw(ArgumentError("Primal evaluation does not work.")) + x[1](deepcopy(x[2:end])...) + catch + throw(ArgumentError("Primal does not run, signature is $(_typeof(x_ẋ)).")) end # Check that input types are valid. - @test _typeof(tangent(f_ḟ)) == tangent_type(_typeof(primal(f_ḟ))) for x_ẋ_component in x_ẋ - @test _typeof(tangent(x_ẋ_component)) == - tangent_type(_typeof(primal(x_ẋ_component))) + @test Mooncake.verify_dual_type(x_ẋ_component) end - # Run the rrule, check it has output a thing of the correct type, and extract results. - # Throw a meaningful exception if the rrule doesn't run at all. - rrule_ret = try - rule(f_ḟ, x_ẋ...) - catch e - display(e) - println() - throw( - ArgumentError( - "rule for $(_typeof(f_ḟ)) with argument types $(_typeof(x_ẋ)) does not run.", - ), - ) + # Run the frule, check it has output a thing of the correct type, and extract results. + # Throw a meaningful exception if the frule doesn't run at all. + y_ẏ = try + frule(x_ẋ...) + catch + throw(ArgumentError("rule does not run, signature is $(_typeof(x_ẋ)).")) end - y_ẏ = rrule_ret # Check that returned fdata type is correct. @test y_ẏ isa Dual - @test typeof(y_ẏ.dx) == tangent_type(typeof(y_ẏ.x)) + @test Mooncake.verify_dual_type(y_ẏ) end function test_rrule_interface(f_f̄, x_x̄...; rrule) @@ -663,65 +649,66 @@ end __get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs) -@doc """ - test_rule( - rng, x...; - interface_only=false, - is_primitive::Bool=true, - perf_flag::Symbol=:none, - interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), - debug_mode::Bool=false, - unsafe_perturb::Bool=false, - ) - - Run standardised tests on the `rule` for `x`. - The first element of `x` should be the primal function to test, and each other element a - positional argument. - In most cases, elements of `x` can just be the primal values, and `randn_tangent` can be - relied upon to generate an appropriate tangent to test. Some notable exceptions exist - though, in partcular `Ptr`s. In this case, the argument for which `randn_tangent` cannot be - readily defined should be a `CoDual` containing the primal, and a _manually_ constructed - tangent field. - - This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will use an - `rrule!!` if one exists, and derive a rule otherwise. - - # Arguments - - `rng::AbstractRNG`: a random number generator - - `x...`: the function (first element) and its arguments (the remainder) - - # Keyword Arguments - - `interface_only::Bool=false`: test only that the interface is satisfied, without testing - correctness. This should generally be set to `false` (the default value), and only - enabled if the testing infrastructure is unable to test correctness for some reason - e.g. the returned value of the function is a `Ptr`, and appropriate tangents cannot, - therefore, be generated for it automatically. - - `is_primitive::Bool=true`: check whether the thing that you are testing has a hand-written - `rrule!!`. This option is helpful if you are testing a new `rrule!!`, as it enables you - to verify that your method of `is_primitive` has returned the correct value, and that - you are actually testing a method of the `rrule!!` function -- a common mistake when - authoring a new `rrule!!` is to implement `is_primitive` incorrectly and to accidentally - wind up testing a rule which Mooncake has derived, as opposed to the one that you have - written. If you are testing something for which you have not - hand-written an `rrule!!`, or which you do not care whether it has a hand-written - `rrule!!` or not, you should set it to `false`. - - `perf_flag::Symbol=:none`: the value of this symbol determines what kind of performance - tests should be performed. By default, none are performed. If you believe that a rule - should be allocation-free (iff the primal is allocation free), set this to `:allocs`. If - you hand-write an `rrule!!` and believe that your test case should be type stable, set - this to `:stability` (at present we cannot verify whether a derived rule is type stable - for technical reasons). If you believe that a hand-written rule should be _both_ - allocation-free and type-stable, set this to `:stability_and_allocs`. - - `interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter()`: the abstract - interpreter to be used when testing this rule. The default should generally be used. - - `debug_mode::Bool=false`: whether or not the rule should be tested in debug mode. - Typically this should be left at its default `false` value, but if you are finding that - the tests are failing for a given rule, you may wish to temporarily set it to `true` in - order to get access to additional information and automated testing. - - `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`. - Should usually be left `false` -- consult the docstring for `_add_to_primal` for more - info on when you might wish to set it to `true`. - """ +""" + test_rule( + rng, x...; + interface_only=false, + is_primitive::Bool=true, + perf_flag::Symbol=:none, + interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), + debug_mode::Bool=false, + unsafe_perturb::Bool=false, + forward::Bool=false, + ) + +Run standardised tests on the `rule` for `x`. +The first element of `x` should be the primal function to test, and each other element a +positional argument. +In most cases, elements of `x` can just be the primal values, and `randn_tangent` can be +relied upon to generate an appropriate tangent to test. Some notable exceptions exist +though, in partcular `Ptr`s. In this case, the argument for which `randn_tangent` cannot be +readily defined should be a `CoDual` containing the primal, and a _manually_ constructed +tangent field. + +This function uses [`Mooncake.build_rrule`](@ref) to construct a rule. This will use an +`rrule!!` if one exists, and derive a rule otherwise. + +# Arguments +- `rng::AbstractRNG`: a random number generator +- `x...`: the function (first element) and its arguments (the remainder) + +# Keyword Arguments +- `interface_only::Bool=false`: test only that the interface is satisfied, without testing + correctness. This should generally be set to `false` (the default value), and only + enabled if the testing infrastructure is unable to test correctness for some reason + e.g. the returned value of the function is a `Ptr`, and appropriate tangents cannot, + therefore, be generated for it automatically. +- `is_primitive::Bool=true`: check whether the thing that you are testing has a hand-written + `rrule!!`. This option is helpful if you are testing a new `rrule!!`, as it enables you + to verify that your method of `is_primitive` has returned the correct value, and that + you are actually testing a method of the `rrule!!` function -- a common mistake when + authoring a new `rrule!!` is to implement `is_primitive` incorrectly and to accidentally + wind up testing a rule which Mooncake has derived, as opposed to the one that you have + written. If you are testing something for which you have not + hand-written an `rrule!!`, or which you do not care whether it has a hand-written + `rrule!!` or not, you should set it to `false`. +- `perf_flag::Symbol=:none`: the value of this symbol determines what kind of performance + tests should be performed. By default, none are performed. If you believe that a rule + should be allocation-free (iff the primal is allocation free), set this to `:allocs`. If + you hand-write an `rrule!!` and believe that your test case should be type stable, set + this to `:stability` (at present we cannot verify whether a derived rule is type stable + for technical reasons). If you believe that a hand-written rule should be _both_ + allocation-free and type-stable, set this to `:stability_and_allocs`. +- `interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter()`: the abstract + interpreter to be used when testing this rule. The default should generally be used. +- `debug_mode::Bool=false`: whether or not the rule should be tested in debug mode. + Typically this should be left at its default `false` value, but if you are finding that + the tests are failing for a given rule, you may wish to temporarily set it to `true` in + order to get access to additional information and automated testing. +- `unsafe_perturb::Bool=false`: value passed as the third argument to `_add_to_primal`. + Should usually be left `false` -- consult the docstring for `_add_to_primal` for more + info on when you might wish to set it to `true`. +""" function test_rule( rng::AbstractRNG, x...; From a6148463a3e8313e62228f00153ccd48ffe07af7 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:19:03 +0000 Subject: [PATCH 22/32] Fix ReturnNode --- src/interpreter/s2s_forward_mode_ad.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 7901d33a7..74412f862 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -186,6 +186,7 @@ function modify_fwd_ad_stmts!( ) # the return node becomes a Dual so it changes type # flag to re-run type inference + dual_ir[SSAValue(i)][:stmt] = inc_args(stmt) dual_ir[SSAValue(i)][:type] = Any dual_ir[SSAValue(i)][:flag] = CC.IR_FLAG_REFINED return nothing From eadae95026d7d35b3b96333a86be5860e1ac54b7 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:19:26 +0000 Subject: [PATCH 23/32] Correctness testing runs --- src/test_utils.jl | 53 +++++++++++++++++------------------------------ 1 file changed, 19 insertions(+), 34 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 420ba31c1..2f9a45f60 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -114,6 +114,7 @@ using Mooncake: instantiate, can_produce_zero_rdata_from_type, increment_rdata!!, + dual_type, fcodual_type, verify_fdata_type, verify_rdata_type, @@ -350,35 +351,32 @@ function address_maps_are_consistent(x::AddressMap, y::AddressMap) end # Assumes that the interface has been tested, and we can simply check for numerical issues. -function test_frule_correctness(rng::AbstractRNG, x_ẋ...; rule, unsafe_perturb::Bool) - # TODO: Will can fix it - #= +function test_frule_correctness(rng::AbstractRNG, x_ẋ...; frule, unsafe_perturb::Bool) @nospecialize rng x_ẋ x_ẋ = map(_deepcopy, x_ẋ) # defensive copy # Run original function on deep-copies of inputs. x = map(primal, x_ẋ) - # ẋ = map(tangent, x_ẋ) + ẋ = map(tangent, x_ẋ) x_primal = _deepcopy(x) y_primal = x_primal[1](x_primal[2:end]...) - # Use finite differences to estimate Frechet derivative. - ẋ = map(_x -> randn_tangent(rng, _x), x) + # Use finite differences to estimate Frechet derivative at ẋ. ε = 1e-7 x′ = _add_to_primal(x, _scale(ε, ẋ), unsafe_perturb) y′ = x′[1](x′[2:end]...) - ẏ = _scale(1 / ε, _diff(y′, y_primal)) - ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) + ẏ_fd = _scale(1 / ε, _diff(y′, y_primal)) + ẋ_fd = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) - # Run rule on copies of `f` and `x`. We use randomly generated tangents so that we - # can later verify that non-zero values do not get propagated by the rule. - ẋ_zero = map(zero_tangent, x) - x_ẋ_rule = map((x, ẋ) -> dual_type(_typeof(x))(_deepcopy(x), ẋ), x, ẋ_zero) + # Use AD to compute Frechet derivative at ẋ. + x_ẋ_rule = map((x, ẋ) -> dual_type(_typeof(x))(_deepcopy(x), ẋ), x, ẋ) inputs_address_map = populate_address_map( map(primal, x_ẋ_rule), map(tangent, x_ẋ_rule) ) - y_ẏ_rule = rule(x_ẋ_rule...) + y_ẏ_rule = frule(x_ẋ_rule...) + ẋ_ad = map(tangent, x_ẋ_rule) + ẏ_ad = tangent(y_ẏ_rule) # Verify that inputs / outputs are the same under `f` and its rrule. @test has_equal_data(x_primal, map(primal, x_ẋ_rule)) @@ -386,30 +384,17 @@ function test_frule_correctness(rng::AbstractRNG, x_ẋ...; rule, unsafe_perturb # Query both `x_ẋ` and `y`, because `x_ẋ` may have been mutated by `f`. outputs_address_map = populate_address_map( - (map(primal, x_x̄_rule)..., primal(y_ȳ_rule)), - (map(tangent, x_x̄_rule)..., tangent(y_ȳ_rule)), + (map(primal, x_ẋ_rule)..., primal(y_ẏ_rule)), + (map(tangent, x_ẋ_rule)..., tangent(y_ẏ_rule)), ) - @test address_maps_are_consistent(inputs_address_map, outputs_address_map) - - # Run reverse-pass. - ȳ_delta = randn_tangent(rng, primal(y_ȳ_rule)) - x̄_delta = map(Base.Fix1(randn_tangent, rng) ∘ primal, x_x̄_rule) - - ȳ_init = set_to_zero!!(zero_tangent(primal(y_ȳ_rule), tangent(y_ȳ_rule))) - x̄_init = map(set_to_zero!!, x̄_zero) - ȳ = increment!!(ȳ_init, ȳ_delta) - map(increment!!, x̄_init, x̄_delta) - x̄_rvs_inc = pb!!(Mooncake.rdata(ȳ)) - x̄_rvs = increment!!(map(rdata, x̄_delta), x̄_rvs_inc) - x̄ = map(tangent, x̄_fwds, x̄_rvs) - # Check that inputs have been returned to their original value. - @test all(map(has_equal_data_up_to_undefs, x, map(primal, x_x̄_rule))) + # Check that all aliasing structure is correct. + @test address_maps_are_consistent(inputs_address_map, outputs_address_map) - # pullbacks increment, so have to compare to the incremented quantity. - @test _dot(ȳ_delta, ẏ) + _dot(x̄_delta, ẋ_post) ≈ _dot(x̄, ẋ) rtol = 1e-3 atol = - 1e-3 - =# + # Any linear projection of the outputs ought to do. + x̄ = map(Base.Fix1(randn_tangent, rng), x′) + ȳ = randn_tangent(rng, y′) + @test _dot(ȳ, ẏ_fd) + _dot(x̄, ẋ_fd) ≈ _dot(ȳ, ẏ_ad) + _dot(x̄, ẋ_ad) rtol=1e-3 atol=1e-3 end # Assumes that the interface has been tested, and we can simply check for numerical issues. From 345b3fd03e79d0aa421a9d4d6c8a40bcc9ef1499 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:39:52 +0000 Subject: [PATCH 24/32] Add randn_dual --- src/dual.jl | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index 4ac784e35..bfc9ae5fe 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -9,6 +9,7 @@ Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x))) _copy(x::P) where {P<:Dual} = x zero_dual(x) = Dual(x, zero_tangent(x)) +randn_dual(rng::AbstractRNG, x) = Dual(x, randn_tangent(rng, x)) function dual_type(::Type{P}) where {P} P == DataType && return Dual @@ -32,6 +33,4 @@ make_dual(x::Dual) = x Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`. """ -function verify_dual_type(x::Dual) - return tangent_type(typeof(primal(x))) == typeof(tangent(x)) -end +verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)) From f58c3947263b5fa7339698a74242d575c758dcd3 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:40:38 +0000 Subject: [PATCH 25/32] Improve sin and cos frules --- src/rrules/low_level_maths.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index e37f1669d..7c66ef20e 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -36,7 +36,8 @@ end @is_primitive MinimalCtx Tuple{typeof(sin),<:IEEEFloat} function frule!!(::Dual{typeof(sin)}, x::Dual{<:IEEEFloat}) - return Dual(sin(primal(x)), cos(primal(x)) * tangent(x)) + s, c = sincos(primal(x)) + return Dual(s, c * tangent(x)) end function rrule!!(::CoDual{typeof(sin),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) @@ -46,7 +47,8 @@ end @is_primitive MinimalCtx Tuple{typeof(cos),<:IEEEFloat} function frule!!(::Dual{typeof(cos)}, x::Dual{<:IEEEFloat}) - return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x)) + s, c = sincos(primal(x)) + return Dual(c, -s * tangent(x)) end function rrule!!(::CoDual{typeof(cos),NoFData}, x::CoDual{P,NoFData}) where {P<:IEEEFloat} s, c = sincos(primal(x)) From c8d88950b53c3c596758f72d486acab4627c0306 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:40:59 +0000 Subject: [PATCH 26/32] Performance tests run --- src/test_utils.jl | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 2f9a45f60..8d0733435 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -115,6 +115,7 @@ using Mooncake: can_produce_zero_rdata_from_type, increment_rdata!!, dual_type, + randn_dual, fcodual_type, verify_fdata_type, verify_rdata_type, @@ -576,6 +577,8 @@ function test_rrule_interface(f_f̄, x_x̄...; rrule) @test all(map((a, b) -> _typeof(a) == _typeof(rdata(b)), x̄_new, x̄)) end +__forwards(frule::F, x_ẋ::Vararg{Any,N}) where {F,N} = frule(x_ẋ...) + function __forwards_and_backwards(rule, x_x̄::Vararg{Any,N}) where {N} out, pb!! = rule(x_x̄...) return pb!!(Mooncake.zero_rdata(primal(out))) @@ -584,7 +587,40 @@ end function test_frule_performance( performance_checks_flag::Symbol, rule::R, f_ḟ::F, x_ẋ::Vararg{Any,N} ) where {R,F,N} - @warn "No performance test for frule yet" + + # Verify that a valid performance flag has been passed. + valid_flags = (:none, :stability, :allocs, :stability_and_allocs) + if !in(performance_checks_flag, valid_flags) + throw( + ArgumentError( + "performance_checks=$performance_checks_flag. Must be one of $valid_flags" + ), + ) + end + performance_checks_flag == :none && return nothing + + if performance_checks_flag in (:stability, :stability_and_allocs) + + # Test primal stability. + test_opt(Shim(), primal(f_ḟ), map(_typeof ∘ primal, x_ẋ)) + + # Test forwards-mode stability. + @show (_typeof(f_ḟ), map(_typeof, x_ẋ)...), rule + test_opt(Shim(), rule, (_typeof(f_ḟ), map(_typeof, x_ẋ)...)) + end + + if performance_checks_flag in (:allocs, :stability_and_allocs) + f = primal(f_ḟ) + x = map(primal, x_ẋ) + + # Test allocations in primal. + f(x...) + @test (@allocations f(x...)) == 0 + + # Test allocations in forwards-mode. + __forwards(rule, f_ḟ, x_ẋ...) + @test (@allocations __forwards(rule, f_ḟ, x_ẋ...)) == 0 + end end function test_rrule_performance( @@ -725,7 +761,7 @@ function test_rule( end # Generate random tangents for anything that is not already a CoDual. - x_ẋ = map(x -> x isa Dual ? x : zero_dual(x), x) + x_ẋ = map(x -> x isa Dual ? x : randn_dual(rng, x), x) x_x̄ = map(x -> if x isa CoDual x @@ -751,7 +787,7 @@ function test_rule( # Test the performance of the rule. if forward - test_rrule_performance(perf_flag, rrule, x_ẋ...) + test_frule_performance(perf_flag, frule, x_ẋ...) else test_rrule_performance(perf_flag, rrule, x_x̄...) end From 578e41be7d0f925e4294a9bc814fe8d764be3cd3 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:41:31 +0000 Subject: [PATCH 27/32] Tidy up implementation --- src/interpreter/s2s_forward_mode_ad.jl | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 74412f862..3bba2ea54 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -82,7 +82,7 @@ function generate_dual_ir( if P isa DataType dual_ir.argtypes[a] = dual_type(P) elseif P isa Core.Const - dual_ir.argtypes[a] = Dual # TODO: improve + dual_ir.argtypes[a] = dual_type(_typeof(P.val)) end end pushfirst!(dual_ir.argtypes, Any) @@ -98,7 +98,7 @@ function generate_dual_ir( CC.verify_ir(dual_ir_comp) # Optimize dual IR - opt_dual_ir = optimise_ir!(dual_ir_comp; do_inline=false) # TODO: toggle + opt_dual_ir = optimise_ir!(dual_ir_comp; do_inline) # TODO: toggle # @info "Inferred dual IR" # display(opt_dual_ir) # TODO: toggle return opt_dual_ir @@ -154,7 +154,9 @@ function modify_fwd_ad_stmts!( end # TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) -_frule!!_makedual(f, args::Vararg{Any,N}) where {N} = frule!!(make_dual.((f, args...))...) +function _frule!!_makedual(f::F, args::Vararg{Any,N}) where {F,N} + return frule!!(tuple_map(make_dual, (f, args...))...) +end struct DynamicFRule{V} cache::V From b5d34b245c677195e4577215f07586d7e28d3f89 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 14:41:49 +0000 Subject: [PATCH 28/32] Standard testing infrastructure --- test/interpreter/s2s_forward_mode_ad.jl | 67 +++++++++++++------------ 1 file changed, 36 insertions(+), 31 deletions(-) diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index d8420614b..5b1440e7a 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -1,47 +1,52 @@ -using MistyClosures -using Mooncake -using Test -using Core.Compiler: SSAValue -const CC = Core.Compiler +# using MistyClosures +# using Mooncake +# using Test +# using Core.Compiler: SSAValue +# const CC = Core.Compiler @testset "s2s_forward_mode_ad" begin - test_cases = collect(enumerate(TestResources.generate_test_functions()))[1:1] - @testset "$(_typeof((f, x...)))" for (n, (interface_only, _, _, f, x...)) in test_cases + test_cases = collect(enumerate(TestResources.generate_test_functions()))[3:4] + @testset "$(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases sig = _typeof((f, x...)) @info "$n: $sig" TestUtils.test_rule( - Xoshiro(123456), f, x...; perf_flag, interface_only, is_primitive=false + Xoshiro(123456), + f, + x...; + perf_flag=pf, + interface_only=int_only, + is_primitive=false, + forward=true, ) end end +# #= +# x, dx = 2.0, 3.0 +# xdual = Dual(x, dx) -#= -x, dx = 2.0, 3.0 -xdual = Dual(x, dx) +# sin_rule = build_frule(sin, x) +# ydual = sin_rule(zero_dual(sin), xdual) -sin_rule = build_frule(sin, x) -ydual = sin_rule(zero_dual(sin), xdual) +# @test primal(ydual) == sin(x) +# @test tangent(ydual) == dx * cos(x) +# =# -@test primal(ydual) == sin(x) -@test tangent(ydual) == dx * cos(x) -=# +# function func2(x) +# if x > 0.0 +# y = sin(x) +# else +# y = cos(x) +# end +# return y +# end -function func2(x) - if x > 0.0 - y = sin(x) - else - y = cos(x) - end - return y -end - -x = 1.0 -xdual = Dual(1.0, 2.0) +# x = 1.0 +# xdual = Dual(1.0, 2.0) -ir = Base.code_ircode(func2, (typeof(x),))[1][1] +# ir = Base.code_ircode(func2, (typeof(x),))[1][1] -func_rule = build_frule(func2, x) -ydual = func_rule(zero_dual(func2), xdual) +# func_rule = build_frule(func2, x) +# ydual = func_rule(zero_dual(func2), xdual) -2cos(1) +# 2cos(1) From 205e716d5121e02faf48b9de681f09a11e5f63e9 Mon Sep 17 00:00:00 2001 From: willtebbutt Date: Fri, 6 Dec 2024 15:14:16 +0000 Subject: [PATCH 29/32] Fix typos --- src/test_utils.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/test_utils.jl b/src/test_utils.jl index 8d0733435..95ac58c49 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -399,7 +399,7 @@ function test_frule_correctness(rng::AbstractRNG, x_ẋ...; frule, unsafe_pertur end # Assumes that the interface has been tested, and we can simply check for numerical issues. -function test_rrule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb::Bool) +function test_rrule_correctness(rng::AbstractRNG, x_x̄...; rrule, unsafe_perturb::Bool) @nospecialize rng x_x̄ x_x̄ = map(_deepcopy, x_x̄) # defensive copy @@ -428,7 +428,7 @@ function test_rrule_correctness(rng::AbstractRNG, x_x̄...; rule, unsafe_perturb inputs_address_map = populate_address_map( map(primal, x_x̄_rule), map(tangent, x_x̄_rule) ) - y_ȳ_rule, pb!! = rule(x_x̄_rule...) + y_ȳ_rule, pb!! = rrule(x_x̄_rule...) # Verify that inputs / outputs are the same under `f` and its rrule. @test has_equal_data(x_primal, map(primal, x_x̄_rule)) @@ -533,7 +533,7 @@ function test_rrule_interface(f_f̄, x_x̄...; rrule) # Throw a meaningful exception if the rrule doesn't run at all. x_addresses = map(get_address, x) rrule_ret = try - rule(f_fwds, x_fwds...) + rrule(f_fwds, x_fwds...) catch e display(e) println() From d328db0bf6bcaa255005b5aabc647c84764ed068 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 18:00:05 +0100 Subject: [PATCH 30/32] Fix return node to return dual --- .github/workflows/documentation.yml | 32 ----------- src/dual.jl | 4 +- src/interpreter/s2s_forward_mode_ad.jl | 70 ++++++++++++++----------- src/test_utils.jl | 67 ++++++++++++++--------- test/interpreter/s2s_forward_mode_ad.jl | 44 ++-------------- 5 files changed, 86 insertions(+), 131 deletions(-) delete mode 100644 .github/workflows/documentation.yml diff --git a/.github/workflows/documentation.yml b/.github/workflows/documentation.yml deleted file mode 100644 index 0ec2baa23..000000000 --- a/.github/workflows/documentation.yml +++ /dev/null @@ -1,32 +0,0 @@ -name: Documentation - -on: - push: - branches: - - main - tags: '*' - pull_request: - -jobs: - build: - permissions: - contents: write - pull-requests: read - statuses: write - actions: write - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v4 - - uses: julia-actions/setup-julia@v2 - with: - version: '1' - arch: x64 - include-all-prereleases: false - - name: Install dependencies - run: julia --project=docs/ -e 'using Pkg; Pkg.develop(PackageSpec(path=pwd())); Pkg.update(); Pkg.instantiate()' - - name: Build and deploy - env: - GKSwstype: nul # turn off GR's interactive plotting for notebooks - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # If authenticating with GitHub Actions token - DOCUMENTER_KEY: ${{ secrets.DOCUMENTER_KEY }} # For authentication with SSH deploy key - run: julia --project=docs/ docs/make.jl diff --git a/src/dual.jl b/src/dual.jl index bfc9ae5fe..ae802ab60 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -25,8 +25,8 @@ end _primal(x) = x _primal(x::Dual) = primal(x) -make_dual(x) = zero_dual(x) -make_dual(x::Dual) = x +_dual(x) = zero_dual(x) +_dual(x::Dual) = x """ verify_dual_type(x::Dual) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 3bba2ea54..8eff21206 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -104,7 +104,7 @@ function generate_dual_ir( return opt_dual_ir end -## Modification +## Modification of IR nodes function modify_fwd_ad_stmts!( dual_ir::CC.IncrementalCompact, @@ -137,7 +137,9 @@ function modify_fwd_ad_stmts!( kwargs..., ) # replace GotoIfNot with the call to primal - Mooncake.replace_call!(dual_ir, CC.SSAValue(i), Expr(:call, _primal, stmt.cond)) + Mooncake.replace_call!( + dual_ir, CC.SSAValue(i), Expr(:call, _primal, inc_args(stmt).cond) + ) # reinsert the GotoIfNot right after the call to primal # (incremental insertion cannot be done before "where we are") new_gotoifnot_inst = CC.NewInstruction( @@ -153,31 +155,6 @@ function modify_fwd_ad_stmts!( return nothing end -# TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) -function _frule!!_makedual(f::F, args::Vararg{Any,N}) where {F,N} - return frule!!(tuple_map(make_dual, (f, args...))...) -end - -struct DynamicFRule{V} - cache::V - debug_mode::Bool -end - -DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) - -_copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) - -function (dynamic_rule::DynamicFRule)(args::Vararg{Any,N}) where {N} - args_dual = map(make_dual, args) # TODO: don't turn everything into a Dual, be clever with Argument and SSAValue - sig = Tuple{map(_typeof ∘ primal, args_dual)...} - rule = get(dynamic_rule.cache, sig, nothing) - if rule === nothing - rule = build_frule(get_interpreter(), sig; debug_mode=dynamic_rule.debug_mode) - dynamic_rule.cache[sig] = rule - end - return rule(args_dual...) -end - function modify_fwd_ad_stmts!( dual_ir::CC.IncrementalCompact, primal_ir::IRCode, @@ -186,11 +163,13 @@ function modify_fwd_ad_stmts!( i::Integer; kwargs..., ) - # the return node becomes a Dual so it changes type - # flag to re-run type inference - dual_ir[SSAValue(i)][:stmt] = inc_args(stmt) - dual_ir[SSAValue(i)][:type] = Any - dual_ir[SSAValue(i)][:flag] = CC.IR_FLAG_REFINED + # make sure that we always return a Dual even when it's a constant + Mooncake.replace_call!(dual_ir, CC.SSAValue(i), Expr(:call, _dual, inc_args(stmt).val)) + # return the result from the previous Dual conversion + new_return_inst = CC.NewInstruction( + Core.ReturnNode(CC.SSAValue(i)), Any, CC.NoCallInfo(), Int32(1), CC.IR_FLAG_REFINED + ) + CC.insert_node_here!(dual_ir, new_return_inst, true) return nothing end @@ -208,6 +187,33 @@ function modify_fwd_ad_stmts!( return nothing end +## Modification of IR nodes - expressions + +# TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) +function _frule!!_makedual(f::F, args::Vararg{Any,N}) where {F,N} + return frule!!(tuple_map(_dual, (f, args...))...) +end + +struct DynamicFRule{V} + cache::V + debug_mode::Bool +end + +DynamicFRule(debug_mode::Bool) = DynamicFRule(Dict{Any,Any}(), debug_mode) + +_copy(x::P) where {P<:DynamicFRule} = P(Dict{Any,Any}(), x.debug_mode) + +function (dynamic_rule::DynamicFRule)(args::Vararg{Any,N}) where {N} + args_dual = map(_dual, args) # TODO: don't turn everything into a Dual, be clever with Argument and SSAValue + sig = Tuple{map(_typeof ∘ primal, args_dual)...} + rule = get(dynamic_rule.cache, sig, nothing) + if rule === nothing + rule = build_frule(get_interpreter(), sig; debug_mode=dynamic_rule.debug_mode) + dynamic_rule.cache[sig] = rule + end + return rule(args_dual...) +end + function modify_fwd_ad_stmts!( dual_ir::CC.IncrementalCompact, primal_ir::IRCode, diff --git a/src/test_utils.jl b/src/test_utils.jl index 95ac58c49..0fde81f87 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -395,7 +395,8 @@ function test_frule_correctness(rng::AbstractRNG, x_ẋ...; frule, unsafe_pertur # Any linear projection of the outputs ought to do. x̄ = map(Base.Fix1(randn_tangent, rng), x′) ȳ = randn_tangent(rng, y′) - @test _dot(ȳ, ẏ_fd) + _dot(x̄, ẋ_fd) ≈ _dot(ȳ, ẏ_ad) + _dot(x̄, ẋ_ad) rtol=1e-3 atol=1e-3 + @test _dot(ȳ, ẏ_fd) + _dot(x̄, ẋ_fd) ≈ _dot(ȳ, ẏ_ad) + _dot(x̄, ẋ_ad) rtol = 1e-3 atol = + 1e-3 end # Assumes that the interface has been tested, and we can simply check for numerical issues. @@ -771,35 +772,51 @@ function test_rule( zero_codual(x) end, x) - # Test that the interface is basically satisfied (checks types / memory addresses). - if forward - test_frule_interface(x_ẋ...; frule) - else - test_rrule_interface(x_x̄...; rrule) - end + testset = @testset "$(typeof(x))" begin + # Test that the interface is basically satisfied (checks types / memory addresses). + @testset "Interface (1)" begin + if forward + test_frule_interface(x_ẋ...; frule) + else + test_rrule_interface(x_x̄...; rrule) + end + end - # Test that answers are numerically correct / consistent. - if forward - interface_only || test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb) - else - interface_only || test_rrule_correctness(rng, x_x̄...; rrule, unsafe_perturb) - end + # Test that answers are numerically correct / consistent. + @testset "Correctness" begin + if forward + interface_only || + test_frule_correctness(rng, x_ẋ...; frule, unsafe_perturb) + else + interface_only || + test_rrule_correctness(rng, x_x̄...; rrule, unsafe_perturb) + end + end - # Test the performance of the rule. - if forward - test_frule_performance(perf_flag, frule, x_ẋ...) - else - test_rrule_performance(perf_flag, rrule, x_x̄...) - end + # Test the performance of the rule. + @testset "Performance" begin + if forward + test_frule_performance(perf_flag, frule, x_ẋ...) + else + test_rrule_performance(perf_flag, rrule, x_x̄...) + end + end - # Test the interface again, in order to verify that caching is working correctly. - if forward - test_frule_interface(x_ẋ...; frule=Mooncake.build_frule(interp, sig; debug_mode)) - else - test_rrule_interface(x_x̄...; rrule=Mooncake.build_rrule(interp, sig; debug_mode)) + # Test the interface again, in order to verify that caching is working correctly. + @testset "Interface (2)" begin + if forward + test_frule_interface( + x_ẋ...; frule=Mooncake.build_frule(interp, sig; debug_mode) + ) + else + test_rrule_interface( + x_x̄...; rrule=Mooncake.build_rrule(interp, sig; debug_mode) + ) + end + end end - return nothing + return testset end function run_hand_written_rrule!!_test_cases(rng_ctor, v::Val) diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index 5b1440e7a..99e3d01f4 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -1,12 +1,6 @@ -# using MistyClosures -# using Mooncake -# using Test -# using Core.Compiler: SSAValue -# const CC = Core.Compiler - -@testset "s2s_forward_mode_ad" begin - test_cases = collect(enumerate(TestResources.generate_test_functions()))[3:4] - @testset "$(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases +@testset verbose = true "s2s_forward_mode_ad" begin + test_cases = collect(enumerate(TestResources.generate_test_functions()))[begin:5] + @testset "$n: $(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases sig = _typeof((f, x...)) @info "$n: $sig" TestUtils.test_rule( @@ -19,34 +13,4 @@ forward=true, ) end -end - -# #= -# x, dx = 2.0, 3.0 -# xdual = Dual(x, dx) - -# sin_rule = build_frule(sin, x) -# ydual = sin_rule(zero_dual(sin), xdual) - -# @test primal(ydual) == sin(x) -# @test tangent(ydual) == dx * cos(x) -# =# - -# function func2(x) -# if x > 0.0 -# y = sin(x) -# else -# y = cos(x) -# end -# return y -# end - -# x = 1.0 -# xdual = Dual(1.0, 2.0) - -# ir = Base.code_ircode(func2, (typeof(x),))[1][1] - -# func_rule = build_frule(func2, x) -# ydual = func_rule(zero_dual(func2), xdual) - -# 2cos(1) +end; From 66a48c8df68f2b0fd00047023750f443cda7dad0 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 18:51:06 +0100 Subject: [PATCH 31/32] Handle PiNode --- src/dual.jl | 2 ++ src/interpreter/bbcode.jl | 3 ++- src/interpreter/s2s_forward_mode_ad.jl | 32 +++++++++++++++++++++---- src/rrules/builtins.jl | 10 ++++++++ src/tools_for_rules.jl | 6 ++--- test/interpreter/s2s_forward_mode_ad.jl | 11 +++++++-- 6 files changed, 54 insertions(+), 10 deletions(-) diff --git a/src/dual.jl b/src/dual.jl index ae802ab60..e8a99ed35 100644 --- a/src/dual.jl +++ b/src/dual.jl @@ -34,3 +34,5 @@ _dual(x::Dual) = x Check that the type of `tangent(x)` is the tangent type of the type of `primal(x)`. """ verify_dual_type(x::Dual) = tangent_type(typeof(primal(x))) == typeof(tangent(x)) + +@inline uninit_dual(x::P) where {P} = Dual(x, uninit_tangent(x)) diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index 1d0e5c618..b93ac4b33 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -889,8 +889,9 @@ Increment by `1` the `n` field of any `Argument`s present in `stmt`. """ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x +inc_args(x::GotoIfNot) = GotoIfNot(__inc(x.cond), x.dest) inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) -inc_args(x::IDGotoNode) = x +inc_args(x::PiNode) = PiNode(__inc(x.val), x.typ) function inc_args(x::PhiNode) new_values = Vector{Any}(undef, length(x.values)) for n in eachindex(x.values) diff --git a/src/interpreter/s2s_forward_mode_ad.jl b/src/interpreter/s2s_forward_mode_ad.jl index 8eff21206..90d3d3943 100644 --- a/src/interpreter/s2s_forward_mode_ad.jl +++ b/src/interpreter/s2s_forward_mode_ad.jl @@ -187,11 +187,35 @@ function modify_fwd_ad_stmts!( return nothing end +function modify_fwd_ad_stmts!( + dual_ir::CC.IncrementalCompact, + primal_ir::IRCode, + ::MooncakeInterpreter, + stmt::PiNode, + i::Integer; + kwargs..., +) + dual_ir[SSAValue(i)][:stmt] = inc_args( + PiNode(stmt.val, Dual{stmt.typ,tangent_type(stmt.typ)}) + ) # TODO: improve? + dual_ir[SSAValue(i)][:type] = Any + dual_ir[SSAValue(i)][:flag] = CC.IR_FLAG_REFINED + return nothing +end + ## Modification of IR nodes - expressions +struct DualArguments{FR} + frule::FR +end + +function Base.show(io::IO, da::DualArguments) + return print(io, "DualArguments($(da.frule))") +end + # TODO: wrapping in Dual must not be systematic (e.g. Argument or SSAValue) -function _frule!!_makedual(f::F, args::Vararg{Any,N}) where {F,N} - return frule!!(tuple_map(_dual, (f, args...))...) +function (da::DualArguments)(f::F, args::Vararg{Any,N}) where {F,N} + return da.frule(tuple_map(_dual, (f, args...))...) end struct DynamicFRule{V} @@ -238,7 +262,7 @@ function modify_fwd_ad_stmts!( inc_args(stmt).args end if is_primitive(context_type(interp), sig) - call_frule = Expr(:call, _frule!!_makedual, shifted_args...) + call_frule = Expr(:call, DualArguments(frule!!), shifted_args...) replace_call!(dual_ir, SSAValue(i), call_frule) else if isexpr(stmt, :invoke) @@ -248,7 +272,7 @@ function modify_fwd_ad_stmts!( rule = DynamicFRule(debug_mode) end # TODO: could this insertion of a naked rule in the IR cause a memory leak? - call_rule = Expr(:call, rule, shifted_args...) + call_rule = Expr(:call, DualArguments(rule), shifted_args...) replace_call!(dual_ir, SSAValue(i), call_rule) end elseif isexpr(stmt, :boundscheck) diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index af6f61195..230fee015 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -164,6 +164,11 @@ function rrule!!(::CoDual{typeof(abs_float)}, x) end @intrinsic add_float +function frule!!(::Dual{typeof(add_float)}, a, b) + c = add_float(primal(a), primal(b)) + d = add_float(tangent(a), tangent(b)) + return Dual(c, d) +end function rrule!!(::CoDual{typeof(add_float)}, a, b) add_float_pb!!(c̄) = NoRData(), c̄, c̄ c = add_float(primal(a), primal(b)) @@ -350,6 +355,11 @@ end @inactive_intrinsic lt_float_fast @intrinsic mul_float +function frule!!(::Dual{typeof(mul_float)}, a, b) + p = mul_float(primal(a), primal(b)) + dp = add_float(mul_float(primal(a), tangent(b)), mul_float(primal(b), tangent(a))) + return Dual(p, dp) +end function rrule!!(::CoDual{typeof(mul_float)}, a, b) _a = primal(a) _b = primal(b) diff --git a/src/tools_for_rules.jl b/src/tools_for_rules.jl index 65ee34368..dfe8fa738 100644 --- a/src/tools_for_rules.jl +++ b/src/tools_for_rules.jl @@ -18,7 +18,7 @@ function parse_signature_expr(sig::Expr) return arg_type_symbols, where_params end -function construct_def(arg_names, arg_types, where_params, body) +function construct_rrule_def(arg_names, arg_types, where_params, body) name = :(Mooncake.rrule!!) arg_exprs = map((n, t) -> :($n::$t), arg_names, arg_types) def = Dict(:head => :function, :name => name, :args => arg_exprs, :body => body) @@ -216,7 +216,7 @@ macro zero_adjoint(ctx, sig) # Return code to create a method of is_primitive and a rule. ex = quote Mooncake.is_primitive(::Type{$ctx}, ::Type{<:$sig}) = true - $(construct_def(arg_names, arg_types, where_params, body)) + $(construct_rrule_def(arg_names, arg_types, where_params, body)) end return esc(ex) end @@ -330,7 +330,7 @@ end function construct_rrule_wrapper_def(arg_names, arg_types, where_params) body = Expr(:call, rrule_wrapper, arg_names...) - return construct_def(arg_names, arg_types, where_params, body) + return construct_rrule_def(arg_names, arg_types, where_params, body) end @doc """ diff --git a/test/interpreter/s2s_forward_mode_ad.jl b/test/interpreter/s2s_forward_mode_ad.jl index 99e3d01f4..3673c67a0 100644 --- a/test/interpreter/s2s_forward_mode_ad.jl +++ b/test/interpreter/s2s_forward_mode_ad.jl @@ -1,6 +1,13 @@ +#= +Failing cases: +- 7: need help for frule of getfield +- 10: need help to adapt @zero_adjoint to forward mode +=# +working_cases = vcat(1:6, 8:9) + @testset verbose = true "s2s_forward_mode_ad" begin - test_cases = collect(enumerate(TestResources.generate_test_functions()))[begin:5] - @testset "$n: $(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases + test_cases = collect(enumerate(TestResources.generate_test_functions()))[working_cases] + @testset "$(_typeof((f, x...)))" for (n, (int_only, pf, _, f, x...)) in test_cases sig = _typeof((f, x...)) @info "$n: $sig" TestUtils.test_rule( From e455cf6d68b39b4c8f1705a3a28dc06b3dab204e Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Fri, 6 Dec 2024 19:17:52 +0100 Subject: [PATCH 32/32] Deleted line --- src/interpreter/bbcode.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl index b93ac4b33..65033713b 100644 --- a/src/interpreter/bbcode.jl +++ b/src/interpreter/bbcode.jl @@ -891,6 +891,7 @@ inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) inc_args(x::ReturnNode) = isdefined(x, :val) ? ReturnNode(__inc(x.val)) : x inc_args(x::GotoIfNot) = GotoIfNot(__inc(x.cond), x.dest) inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) +inc_args(x::IDGotoNode) = x inc_args(x::PiNode) = PiNode(__inc(x.val), x.typ) function inc_args(x::PhiNode) new_values = Vector{Any}(undef, length(x.values))