Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forwards-Mode Design Docs #386

Merged
merged 15 commits into from
Nov 27, 2024
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.50"
version = "0.4.51"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ makedocs(;
"Developer Documentation" => [
joinpath("developer_documentation", "running_tests_locally.md"),
joinpath("developer_documentation", "developer_tools.md"),
joinpath("developer_documentation", "forwards_mode_design.md"),
joinpath("developer_documentation", "internal_docstrings.md"),
],
"known_limitations.md",
Expand Down
347 changes: 347 additions & 0 deletions docs/src/developer_documentation/forwards_mode_design.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/src/developer_documentation/internal_docstrings.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ The purpose of this is to make it easy for developers to find docstrings straigh
Modules = [Mooncake]
Public = false
```

```@docs
Mooncake.IntrinsicsWrappers
```
2 changes: 1 addition & 1 deletion src/debug_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ for `DebugRRule` for details.
return y::CoDual, DebugPullback(pb, primal(y), map(primal, x))
end

# DerivedRRule adds a method to this function.
# DerivedRule adds a method to this function.
verify_args(_, x) = nothing

@noinline function verify_fwds_inputs(rule, @nospecialize(x::Tuple))
Expand Down
13 changes: 13 additions & 0 deletions src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ If anything else, just return `inst`. See `Mooncake._foreigncall_` for details.
to be called in the context of an `IRCode`, in which case the values of `sp_map` are given
by the `sptypes` field of said `IRCode`. The keys should generally be obtained from the
`Method` from which the `IRCode` is derived. See `Mooncake.normalise!` for more details.

The purpose of this transformation is to make it possible to differentiate `:foreigncall`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
function foreigncall_to_call(inst, sp_map::Dict{Symbol,CC.VarState})
if Meta.isexpr(inst, :foreigncall)
Expand Down Expand Up @@ -146,6 +149,9 @@ end

If instruction `x` is a `:new` expression, replace it with a `:call` to `Mooncake._new_`.
Otherwise, return `x`.

The purpose of this transformation is to make it possible to differentiate `:new`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x

Expand All @@ -154,6 +160,9 @@ new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x

If instruction `x` is a `:splatnew` expression, replace it with a `:call` to
`Mooncake._splat_new_`. Otherwise return `x`.

The purpose of this transformation is to make it possible to differentiate `:splatnew`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
splatnew_to_call(x) = Meta.isexpr(x, :splatnew) ? Expr(:call, _splat_new_, x.args...) : x

Expand All @@ -165,6 +174,10 @@ the corresponding `function` from `Mooncake.IntrinsicsWrappers`, else return `in

`cglobal` is a special case -- it requires that its first argument be static in exactly the
same way as `:foreigncall`. See `IntrinsicsWrappers.__cglobal` for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for
intrinsic calls using dispatch in a type-stable way. See [`IntrinsicsWrappers`](@ref) for
more context.
"""
function intrinsic_to_function(inst)
return Meta.isexpr(inst, :call) ? Expr(:call, lift_intrinsic(inst.args...)...) : inst
Expand Down
41 changes: 41 additions & 0 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@ function Base.showerror(io::IO, err::MissingRuleForBuiltinException)
return println(io, err.msg)
end

"""
module IntrinsicsWrappers

The purpose of this `module` is to associate to each function in `Core.Intrinsics` a regular
Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each
`Core.IntrinsicFunction` in `Core.Intrinsics` does _not_ have its own type. Rather, they
are instances of `Core.IntrinsicFunction`. To see this, observe that
```jldoctest
julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction

julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction
```

While we could simply write a rule for `Core.IntrinsicFunction`, this would (naively) lead
to a large list of conditionals of the form
```julia
if f === Core.Intrinsics.add_float
# return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
# return add_float and its pullback
elseif
...
end
```
which has the potential to cause quite substantial type instabilities.
(This might not be true anymore -- see extended help for more context).

Instead, we map each `Core.IntrinsicFunction` to one of the regular Julia functions in
`Mooncake.IntrinsicsWrappers`, to which we can dispatch in the usual way.

# Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in
version 1.10, we actually _could_ get away with just writing a single method of `rrule!!` to
handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should
investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .
"""
module IntrinsicsWrappers

using Base: IEEEFloat
Expand Down
8 changes: 4 additions & 4 deletions test/ext/special_functions/special_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ using Mooncake.TestUtils: test_rule
test_rule(StableRNG(123456), f, x...; perf_flag)
end
@testset for (perf_flag, f, x...) in [
(:allocs, logerf, 0.3, 0.5), # first branch
(:allocs, logerf, 1.1, 1.2), # second branch
(:allocs, logerf, -1.2, -1.1), # third branch
(:allocs, logerf, 0.3, 1.1), # fourth branch
(:none, logerf, 0.3, 0.5), # first branch
(:none, logerf, 1.1, 1.2), # second branch
(:none, logerf, -1.2, -1.1), # third branch
(:none, logerf, 0.3, 1.1), # fourth branch
(:allocs, SpecialFunctions.loggammadiv, 1.0, 9.0),
(:allocs, SpecialFunctions.gammax, 1.0),
(:allocs, SpecialFunctions.rgammax, 3.0, 6.0),
Expand Down
Loading