Skip to content

Commit

Permalink
Forwards-Mode Design Docs (#386)
Browse files Browse the repository at this point in the history
* Partial forwards-mode design doc

* Improve ir translation docstrings

* More docs

* Improve intrinsic_to_function docstring

* Tidy up

* Discuss batched mode

* Discuss abstractions more precisely

* Clarify that compile time means rule compliation time

* Comparison with ForwardDiff

* Small tidy up

* Loosen perf bounds

* Typo

* Minor typos

* Bump patch version
  • Loading branch information
willtebbutt authored Nov 27, 2024
1 parent d9a6952 commit 0f37c07
Show file tree
Hide file tree
Showing 8 changed files with 412 additions and 6 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.50"
version = "0.4.51"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
1 change: 1 addition & 0 deletions docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ makedocs(;
"Developer Documentation" => [
joinpath("developer_documentation", "running_tests_locally.md"),
joinpath("developer_documentation", "developer_tools.md"),
joinpath("developer_documentation", "forwards_mode_design.md"),
joinpath("developer_documentation", "internal_docstrings.md"),
],
"known_limitations.md",
Expand Down
347 changes: 347 additions & 0 deletions docs/src/developer_documentation/forwards_mode_design.md

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions docs/src/developer_documentation/internal_docstrings.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,7 @@ The purpose of this is to make it easy for developers to find docstrings straigh
Modules = [Mooncake]
Public = false
```

```@docs
Mooncake.IntrinsicsWrappers
```
2 changes: 1 addition & 1 deletion src/debug_mode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ for `DebugRRule` for details.
return y::CoDual, DebugPullback(pb, primal(y), map(primal, x))
end

# DerivedRRule adds a method to this function.
# DerivedRule adds a method to this function.
verify_args(_, x) = nothing

@noinline function verify_fwds_inputs(rule, @nospecialize(x::Tuple))
Expand Down
13 changes: 13 additions & 0 deletions src/interpreter/ir_normalisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,9 @@ If anything else, just return `inst`. See `Mooncake._foreigncall_` for details.
to be called in the context of an `IRCode`, in which case the values of `sp_map` are given
by the `sptypes` field of said `IRCode`. The keys should generally be obtained from the
`Method` from which the `IRCode` is derived. See `Mooncake.normalise!` for more details.
The purpose of this transformation is to make it possible to differentiate `:foreigncall`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
function foreigncall_to_call(inst, sp_map::Dict{Symbol,CC.VarState})
if Meta.isexpr(inst, :foreigncall)
Expand Down Expand Up @@ -146,6 +149,9 @@ end
If instruction `x` is a `:new` expression, replace it with a `:call` to `Mooncake._new_`.
Otherwise, return `x`.
The purpose of this transformation is to make it possible to differentiate `:new`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x

Expand All @@ -154,6 +160,9 @@ new_to_call(x) = Meta.isexpr(x, :new) ? Expr(:call, _new_, x.args...) : x
If instruction `x` is a `:splatnew` expression, replace it with a `:call` to
`Mooncake._splat_new_`. Otherwise return `x`.
The purpose of this transformation is to make it possible to differentiate `:splatnew`
expressions in the same way as a primitive `:call` expression, i.e. via an `rrule!!`.
"""
splatnew_to_call(x) = Meta.isexpr(x, :splatnew) ? Expr(:call, _splat_new_, x.args...) : x

Expand All @@ -165,6 +174,10 @@ the corresponding `function` from `Mooncake.IntrinsicsWrappers`, else return `in
`cglobal` is a special case -- it requires that its first argument be static in exactly the
same way as `:foreigncall`. See `IntrinsicsWrappers.__cglobal` for more info.
The purpose of this transformation is to make it possible to use dispatch to write rules for
intrinsic calls using dispatch in a type-stable way. See [`IntrinsicsWrappers`](@ref) for
more context.
"""
function intrinsic_to_function(inst)
return Meta.isexpr(inst, :call) ? Expr(:call, lift_intrinsic(inst.args...)...) : inst
Expand Down
41 changes: 41 additions & 0 deletions src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,47 @@ function Base.showerror(io::IO, err::MissingRuleForBuiltinException)
return println(io, err.msg)
end

"""
module IntrinsicsWrappers
The purpose of this `module` is to associate to each function in `Core.Intrinsics` a regular
Julia function.
To understand the rationale for this observe that, unlike regular Julia functions, each
`Core.IntrinsicFunction` in `Core.Intrinsics` does _not_ have its own type. Rather, they
are instances of `Core.IntrinsicFunction`. To see this, observe that
```jldoctest
julia> typeof(Core.Intrinsics.add_float)
Core.IntrinsicFunction
julia> typeof(Core.Intrinsics.sub_float)
Core.IntrinsicFunction
```
While we could simply write a rule for `Core.IntrinsicFunction`, this would (naively) lead
to a large list of conditionals of the form
```julia
if f === Core.Intrinsics.add_float
# return add_float and its pullback
elseif f === Core.Intrinsics.sub_float
# return add_float and its pullback
elseif
...
end
```
which has the potential to cause quite substantial type instabilities.
(This might not be true anymore -- see extended help for more context).
Instead, we map each `Core.IntrinsicFunction` to one of the regular Julia functions in
`Mooncake.IntrinsicsWrappers`, to which we can dispatch in the usual way.
# Extended Help
It is possible that owing to improvements in constant propagation in the Julia compiler in
version 1.10, we actually _could_ get away with just writing a single method of `rrule!!` to
handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should
investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .
"""
module IntrinsicsWrappers

using Base: IEEEFloat
Expand Down
8 changes: 4 additions & 4 deletions test/ext/special_functions/special_functions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,10 @@ using Mooncake.TestUtils: test_rule
test_rule(StableRNG(123456), f, x...; perf_flag)
end
@testset for (perf_flag, f, x...) in [
(:allocs, logerf, 0.3, 0.5), # first branch
(:allocs, logerf, 1.1, 1.2), # second branch
(:allocs, logerf, -1.2, -1.1), # third branch
(:allocs, logerf, 0.3, 1.1), # fourth branch
(:none, logerf, 0.3, 0.5), # first branch
(:none, logerf, 1.1, 1.2), # second branch
(:none, logerf, -1.2, -1.1), # third branch
(:none, logerf, 0.3, 1.1), # fourth branch
(:allocs, SpecialFunctions.loggammadiv, 1.0, 9.0),
(:allocs, SpecialFunctions.gammax, 1.0),
(:allocs, SpecialFunctions.rgammax, 3.0, 6.0),
Expand Down

2 comments on commit 0f37c07

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120254

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.51 -m "<description of version>" 0f37c079bd1ae064e7b84696eed4a1f7eb763f1f
git push origin v0.4.51

Please sign in to comment.