Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Start forward mode AD #389

Draft
wants to merge 34 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
be316ff
Start forward mode prototype
gdalle Nov 24, 2024
deac913
First working autodiff
gdalle Nov 24, 2024
9c96c8d
Docstring
gdalle Nov 24, 2024
136aff6
Apply suggestions from code review
gdalle Nov 24, 2024
f65cc53
Moving files around
gdalle Nov 24, 2024
053a8bb
Primitives already known
gdalle Nov 24, 2024
6d8ec04
Merge branch 'main' into gd/forward
gdalle Nov 25, 2024
a3107a8
Keep pushing forward (pun intended)
gdalle Nov 25, 2024
2836ac8
Still buggy, don't touch
gdalle Nov 25, 2024
09d63bd
Keep instruction mapping one to one
gdalle Nov 26, 2024
fa679eb
Use replace_call
gdalle Nov 26, 2024
a68257c
Ignore code cov
gdalle Nov 27, 2024
7a096ba
No Aqua piracies test
gdalle Nov 27, 2024
46c3e5a
Start control flow
gdalle Nov 28, 2024
ad3f98a
Fix intrinsic
gdalle Nov 28, 2024
9071574
Import
gdalle Nov 28, 2024
dcfe282
Typos
gdalle Nov 28, 2024
e44380d
Co-authored-by: Will Tebbutt <[email protected]>
gdalle Dec 6, 2024
dd89e57
Figure out incremental additions
gdalle Dec 6, 2024
9bdb57f
Initial test case additions
willtebbutt Dec 6, 2024
4bb9911
Formatting
willtebbutt Dec 6, 2024
9b037e7
Add verify_dual_type
willtebbutt Dec 6, 2024
6dea624
test_frule_interface runs
willtebbutt Dec 6, 2024
a614846
Fix ReturnNode
willtebbutt Dec 6, 2024
eadae95
Correctness testing runs
willtebbutt Dec 6, 2024
345b3fd
Add randn_dual
willtebbutt Dec 6, 2024
f58c394
Improve sin and cos frules
willtebbutt Dec 6, 2024
c8d8895
Performance tests run
willtebbutt Dec 6, 2024
578e41b
Tidy up implementation
willtebbutt Dec 6, 2024
b5d34b2
Standard testing infrastructure
willtebbutt Dec 6, 2024
205e716
Fix typos
willtebbutt Dec 6, 2024
d328db0
Fix return node to return dual
gdalle Dec 6, 2024
66a48c8
Handle PiNode
gdalle Dec 6, 2024
e455cf6
Deleted line
gdalle Dec 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions src/Mooncake.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Base:
arrayset, TwicePrecision, twiceprecision
using Base.Experimental: @opaque
using Base.Iterators: product
using Base.Meta: isexpr
using Core:
Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode,
PiNode, SSAValue, Argument, OpaqueClosure, compilerbarrier
Expand All @@ -34,6 +35,13 @@ using FunctionWrappers: FunctionWrapper
# Needs to be defined before various other things.
function _foreigncall_ end

"""
frule!!(f::Dual, x::Dual...)

Performs AD in forward mode, possibly modifying the inputs, and returns a `Dual`.
"""
function frule!! end

"""
rrule!!(f::CoDual, x::CoDual...)

Expand Down Expand Up @@ -61,6 +69,7 @@ function rrule!! end

include("utils.jl")
include("tangents.jl")
include("dual.jl")
include("fwds_rvs_data.jl")
include("codual.jl")
include("debug_mode.jl")
Expand All @@ -72,12 +81,15 @@ include(joinpath("interpreter", "ir_utils.jl"))
include(joinpath("interpreter", "bbcode.jl"))
include(joinpath("interpreter", "ir_normalisation.jl"))
include(joinpath("interpreter", "zero_like_rdata.jl"))
include(joinpath("interpreter", "s2s_forward_mode_ad.jl"))
include(joinpath("interpreter", "s2s_reverse_mode_ad.jl"))

include("tools_for_rules.jl")
include("test_utils.jl")
include("test_resources.jl")

include(joinpath("frules", "basic.jl"))

include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
include(joinpath("rrules", "blas.jl"))
include(joinpath("rrules", "builtins.jl"))
Expand Down Expand Up @@ -118,9 +130,13 @@ export
_add_to_primal,
_diff,
_dot,
Dual,
zero_dual,
zero_codual,
codual_type,
frule!!,
rrule!!,
build_frule,
build_rrule,
value_and_gradient!!,
value_and_pullback!!,
Expand Down
1 change: 1 addition & 0 deletions src/debug_mode.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
DebugFRule(rule) = rule # TODO: make it non-trivial

Check warning on line 1 in src/debug_mode.jl

View check run for this annotation

Codecov / codecov/patch

src/debug_mode.jl#L1

Added line #L1 was not covered by tests

"""
DebugPullback(pb, y, x)
Expand Down
13 changes: 13 additions & 0 deletions src/dual.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
struct Dual{P,T}
primal::P
tangent::T
end

primal(x::Dual) = x.primal
tangent(x::Dual) = x.tangent
Base.copy(x::Dual) = Dual(copy(primal(x)), copy(tangent(x)))
_copy(x::P) where {P<:Dual} = x

Check warning on line 9 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L6-L9

Added lines #L6 - L9 were not covered by tests

zero_dual(x) = Dual(x, zero_tangent(x))

Check warning on line 11 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L11

Added line #L11 was not covered by tests

dual_type(::Type{P}) where {P} = Dual{P,tangent_type(P)}

Check warning on line 13 in src/dual.jl

View check run for this annotation

Codecov / codecov/patch

src/dual.jl#L13

Added line #L13 was not covered by tests
11 changes: 11 additions & 0 deletions src/frules/basic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
frule!!(f::F, args::Vararg{Dual,N}) where {F,N} = frule!!(zero_dual(f), args...)

Check warning on line 1 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L1

Added line #L1 was not covered by tests
gdalle marked this conversation as resolved.
Show resolved Hide resolved

@is_primitive MinimalCtx Tuple{typeof(sin),Number}
gdalle marked this conversation as resolved.
Show resolved Hide resolved
function frule!!(::Dual{typeof(sin)}, x::Dual{<:Number})
return Dual(sin(primal(x)), cos(primal(x)) * tangent(x))

Check warning on line 5 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L4-L5

Added lines #L4 - L5 were not covered by tests
end

@is_primitive MinimalCtx Tuple{typeof(cos),Number}
function frule!!(::Dual{typeof(cos)}, x::Dual{<:Number})
return Dual(cos(primal(x)), -sin(primal(x)) * tangent(x))

Check warning on line 10 in src/frules/basic.jl

View check run for this annotation

Codecov / codecov/patch

src/frules/basic.jl#L9-L10

Added lines #L9 - L10 were not covered by tests
end
147 changes: 147 additions & 0 deletions src/interpreter/s2s_forward_mode_ad.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
function build_frule(args...; debug_mode=false)
interp = get_interpreter()
sig = _typeof(TestUtils.__get_primals(args))
return build_frule(interp, sig; debug_mode)

Check warning on line 4 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L1-L4

Added lines #L1 - L4 were not covered by tests
end

function build_frule(

Check warning on line 7 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L7

Added line #L7 was not covered by tests
interp::MooncakeInterpreter{C},
sig_or_mi;
debug_mode=false,
silence_debug_messages=true,
) where {C}
# To avoid segfaults, ensure that we bail out if the interpreter's world age is greater
# than the current world age.
if Base.get_world_counter() > interp.world
throw(

Check warning on line 16 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L15-L16

Added lines #L15 - L16 were not covered by tests
ArgumentError(
"World age associated to interp is behind current world age. Please " *
"a new interpreter for the current world age.",
),
)
end

# If we're compiling in debug mode, let the user know by default.
if !silence_debug_messages && debug_mode
@info "Compiling rule for $sig_or_mi in debug mode. Disable for best performance."

Check warning on line 26 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L25-L26

Added lines #L25 - L26 were not covered by tests
end

# If we have a hand-coded rule, just use that.
_is_primitive(C, sig_or_mi) && return (debug_mode ? DebugFRule(frule!!) : frule!!)

Check warning on line 30 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L30

Added line #L30 was not covered by tests


# We don't have a hand-coded rule, so derived one.
lock(MOONCAKE_INFERENCE_LOCK)
try

Check warning on line 35 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L34-L35

Added lines #L34 - L35 were not covered by tests
# If we've already derived the OpaqueClosures and info, do not re-derive, just
# create a copy and pass in new shared data.
oc_cache_key = ClosureCacheKey(interp.world, (sig_or_mi, debug_mode))
if haskey(interp.oc_cache, oc_cache_key)
return _copy(interp.oc_cache[oc_cache_key])

Check warning on line 40 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L38-L40

Added lines #L38 - L40 were not covered by tests
else
# Derive forward-pass IR, and shove in a `MistyClosure`.
forward_ir = generate_forward_ir(interp, sig_or_mi; debug_mode)
fwd_oc = MistyClosure(forward_ir; do_compile=true)
raw_rule = DerivedFRule(fwd_oc)
rule = debug_mode ? DebugFRule(raw_rule) : raw_rule
interp.oc_cache[oc_cache_key] = rule
return rule

Check warning on line 48 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L43-L48

Added lines #L43 - L48 were not covered by tests
end
catch e
rethrow(e)

Check warning on line 51 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L51

Added line #L51 was not covered by tests
finally
unlock(MOONCAKE_INFERENCE_LOCK)

Check warning on line 53 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L53

Added line #L53 was not covered by tests
end
end

function generate_forward_ir(

Check warning on line 57 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L57

Added line #L57 was not covered by tests
interp::MooncakeInterpreter,
sig_or_mi;
debug_mode=false,
do_inline=true,
)
# Reset id count. This ensures that the IDs generated are the same each time this
# function runs.
seed_id!()

Check warning on line 65 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L65

Added line #L65 was not covered by tests

# Grab code associated to the primal.
primal_ir, _ = lookup_ir(interp, sig_or_mi)

Check warning on line 68 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L68

Added line #L68 was not covered by tests

# Normalise the IR.
isva, spnames = is_vararg_and_sparam_names(sig_or_mi)
ir = normalise!(primal_ir, spnames)

Check warning on line 72 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L71-L72

Added lines #L71 - L72 were not covered by tests

fwd_ir = dualize_ir(ir)
opt_fwd_ir = optimise_ir!(fwd_ir; do_inline)
return opt_fwd_ir

Check warning on line 76 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L74-L76

Added lines #L74 - L76 were not covered by tests
end

function dualize_ir(ir::IRCode)
new_stmts_stmt = map(make_fwd_ad_stmt, ir.stmts.stmt)
new_stmts_type = map(dual_type, ir.stmts.type)
gdalle marked this conversation as resolved.
Show resolved Hide resolved
new_stmts_info = ir.stmts.info
new_stmts_line = ir.stmts.line
new_stmts_flag = ir.stmts.flag
gdalle marked this conversation as resolved.
Show resolved Hide resolved
new_stmts = CC.InstructionStream(

Check warning on line 85 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L79-L85

Added lines #L79 - L85 were not covered by tests
new_stmts_stmt,
new_stmts_type,
new_stmts_info,
new_stmts_line,
new_stmts_flag,
)
new_cfg = ir.cfg
new_linetable = ir.linetable
rule_type = Any
new_argtypes = convert(Vector{Any}, vcat(rule_type, map(make_fwd_argtype, ir.argtypes)))
new_meta = ir.meta
new_sptypes = ir.sptypes
return IRCode(new_stmts, new_cfg, new_linetable, new_argtypes, new_meta, new_sptypes)

Check warning on line 98 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L92-L98

Added lines #L92 - L98 were not covered by tests
end

make_fwd_argtype(::Type{P}) where {P} = dual_type(P)
make_fwd_argtype(c::Core.Const) = Dual # TODO: refine to type of const

Check warning on line 102 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L101-L102

Added lines #L101 - L102 were not covered by tests

function make_fwd_ad_stmt(stmt::Expr)
interp = get_interpreter() # TODO: pass it around
C = context_type(interp)
if isexpr(stmt, :invoke) || isexpr(stmt, :call)
mi = stmt.args[1]::Core.MethodInstance
sig = mi.specTypes
if is_primitive(C, sig)
shifted_args = map(stmt.args) do a
if a isa Core.Argument
Core.Argument(a.n + 1)

Check warning on line 113 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L104-L113

Added lines #L104 - L113 were not covered by tests
else
a

Check warning on line 115 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L115

Added line #L115 was not covered by tests
end
end
gdalle marked this conversation as resolved.
Show resolved Hide resolved
new_stmt = Expr(

Check warning on line 118 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L118

Added line #L118 was not covered by tests
:call,
:($frule!!),
stmt.args[2],
shifted_args[3:end]...
gdalle marked this conversation as resolved.
Show resolved Hide resolved
)
return new_stmt

Check warning on line 124 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L124

Added line #L124 was not covered by tests
else
throw(ArgumentError("Recursing into non-primitive calls is not yet supported in forward mode"))

Check warning on line 126 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L126

Added line #L126 was not covered by tests
end
return stmt

Check warning on line 128 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L128

Added line #L128 was not covered by tests
else
throw(ArgumentError("Expressions of type `:$(stmt.head)` are not yet supported in forward mode"))

Check warning on line 130 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L130

Added line #L130 was not covered by tests
end
return stmt

Check warning on line 132 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L132

Added line #L132 was not covered by tests
gdalle marked this conversation as resolved.
Show resolved Hide resolved
end

function make_fwd_ad_stmt(stmt::ReturnNode)
return stmt

Check warning on line 136 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L135-L136

Added lines #L135 - L136 were not covered by tests
end

struct DerivedFRule{Tfwd_oc}
fwd_oc::Tfwd_oc
end

_copy(rule::DerivedFRule) = deepcopy(rule)

Check warning on line 143 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L143

Added line #L143 was not covered by tests

@inline function (fwd::DerivedFRule)(args::Vararg{Dual,N}) where {N}
return fwd.fwd_oc.oc(args...)

Check warning on line 146 in src/interpreter/s2s_forward_mode_ad.jl

View check run for this annotation

Codecov / codecov/patch

src/interpreter/s2s_forward_mode_ad.jl#L145-L146

Added lines #L145 - L146 were not covered by tests
end
2 changes: 1 addition & 1 deletion src/test_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -470,7 +470,7 @@ function test_rrule_performance(
end
end

__get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs)
__get_primals(xs) = map(x -> x isa Union{Dual,CoDual} ? primal(x) : x, xs)

@doc"""
test_rule(
Expand Down
27 changes: 27 additions & 0 deletions test/forward.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
using Mooncake
gdalle marked this conversation as resolved.
Show resolved Hide resolved
using Test

x, dx = 2.0, 3.0
xdual = Dual(x, dx)

@testset "Manual frule" begin
sin_rule = build_frule(sin, x)
ydual = sin_rule(zero_dual(sin), xdual)

@test primal(ydual) == sin(x)
@test tangent(ydual) == dx * cos(x)
end

function func(x)
y = sin(x)
z = cos(y)
return z
end

@testset "Automatic frule" begin
func_rule = build_frule(func, x)
ydual = func_rule(zero_dual(func), xdual)

@test primal(ydual) == cos(sin(x))
@test tangent(ydual) ≈ dx * -sin(sin(x)) * cos(x)
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ include("front_matter.jl")
include("config.jl")
include("developer_tools.jl")
include("test_utils.jl")
include("forward.jl")
elseif test_group == "rrules/avoiding_non_differentiable_code"
include(joinpath("rrules", "avoiding_non_differentiable_code.jl"))
elseif test_group == "rrules/blas"
Expand Down
Loading