From c17774ba6b48c4f2ebb40de2fbe048f9527ecabd Mon Sep 17 00:00:00 2001 From: "Documenter.jl" Date: Wed, 27 Nov 2024 10:47:56 +0000 Subject: [PATCH] build based on 73e1ec6 --- previews/PR386/.documenter-siteinfo.json | 2 +- .../developer_tools/index.html | 6 +- .../forwards_mode_design/index.html | 6 +- .../internal_docstrings/index.html | 54 +++++++++--------- .../running_tests_locally/index.html | 2 +- previews/PR386/index.html | 2 +- previews/PR386/known_limitations/index.html | 2 +- previews/PR386/objects.inv | Bin 6774 -> 6819 bytes previews/PR386/search_index.js | 2 +- .../algorithmic_differentiation/index.html | 2 +- .../introduction/index.html | 2 +- .../rule_system/index.html | 6 +- .../PR386/utilities/debug_mode/index.html | 4 +- .../utilities/debugging_and_mwes/index.html | 4 +- .../utilities/tools_for_rules/index.html | 8 +-- 15 files changed, 51 insertions(+), 51 deletions(-) diff --git a/previews/PR386/.documenter-siteinfo.json b/previews/PR386/.documenter-siteinfo.json index 674d978e4..025773cdd 100644 --- a/previews/PR386/.documenter-siteinfo.json +++ b/previews/PR386/.documenter-siteinfo.json @@ -1 +1 @@ -{"documenter":{"julia_version":"1.11.1","generation_timestamp":"2024-11-23T15:21:22","documenter_version":"1.8.0"}} \ No newline at end of file +{"documenter":{"julia_version":"1.11.1","generation_timestamp":"2024-11-27T10:47:47","documenter_version":"1.8.0"}} \ No newline at end of file diff --git a/previews/PR386/developer_documentation/developer_tools/index.html b/previews/PR386/developer_documentation/developer_tools/index.html index c3d661b31..e92249c67 100644 --- a/previews/PR386/developer_documentation/developer_tools/index.html +++ b/previews/PR386/developer_documentation/developer_tools/index.html @@ -2,16 +2,16 @@ Developer Tools · Mooncake.jl

Developer Tools

Mooncake.jl offers developers to a few convenience functions which give access to the IR that it generates in order to perform AD. These are lightweight wrappers around internals which save you from having to dig in to the objects created by build_rrule.

Since these provide access to internals, they do not follow the usual rules of semver, and may change without notice!

Mooncake.primal_irFunction
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true
source
Mooncake.fwd_irFunction
fwd_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_irFunction
rvs_ir(
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_irFunction
rvs_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
+true

Arguments

Keyword Arguments

source diff --git a/previews/PR386/developer_documentation/forwards_mode_design/index.html b/previews/PR386/developer_documentation/forwards_mode_design/index.html index 132c3cfd0..14eb0a870 100644 --- a/previews/PR386/developer_documentation/forwards_mode_design/index.html +++ b/previews/PR386/developer_documentation/forwards_mode_design/index.html @@ -1,5 +1,5 @@ -Forwards-Mode Design · Mooncake.jl

Forwards-Mode Design

Disclaimer: this document refers to an as-yet-unimplemented forwards-mode AD. This will disclaimer will be removed once it has been implemented.

The purpose of this document is to explain how forwards-mode AD in Mooncake.jl is implemented. It should do so to a sufficient level of depth to enable the interested reader to read to the forwards-mode AD code in Mooncake.jl and understand what is going on.

This document

  1. specifies the semantics of a "rule" for forwards-mode AD,
  2. specifies how to implement rules by-hand for primitives, and
  3. specifies how to derive rules from IRCode algorithmically in general.
  4. discusses batched forwards-mode
  5. concludes with some notable technical differences between our forwards-mode AD implementation details and reverse-mode AD implementation details.

Forwards-Rule Interface

Loosely, a rule for a function simultaneously

  1. performs same computation as the original function, and
  2. computes the Frechet derivative.

This is best made concrete through a worked example. Consider a function call

z = f(x, y)

where f itself may contain data / state which is modified by executing f. rule_for_f is some callable which claims to be a forwards-rule for f. For rule_for_f to be a valid forwards-rule for f, it must be applicable to Duals as follows:

z_dz = rule_for_f(Dual(f, df), Dual(x, dx), Dual(y, dy))::Dual

where:

  1. rule_for_f is a callable. It might be written by-hand, or derived algorithmically.
  2. df, dx, and dy are tangents for f, x, and y respectively. Before executing rule_for_f, they are inputs to the derivative of (f, x, y). After executing they are outputs of this derivative.
  3. z_dz is a Dual containing the primal and the component of the derivative of (f, x, y) to (df, dx, dy) associated to z.
  4. running rule_for_f leaves f, x, and y in the same state that running f does.

We refer readers to Algorithmic Differentiation to explain what we mean when we talk about the "derivative" above. We also discussed some worked examples shortly.

Note that rule_for_f is an as-yet-unspecified callable which we introduced purely to specify the interface that a forwards-rule must satisfy. In Hand-Written Rules and Derived Rules below, we introduce two concrete ways to produce rules for f.

Tangent Types

We will use the type system documented in Representing Gradients. This means that every primal type has a unique tangent type. Moreover, if a Dual is defined as follows:

struct Dual{P, T}
+Forwards-Mode Design · Mooncake.jl

Forwards-Mode Design

Disclaimer: this document refers to an as-yet-unimplemented forwards-mode AD. This will disclaimer will be removed once it has been implemented.

The purpose of this document is to explain how forwards-mode AD in Mooncake.jl is implemented. It should do so to a sufficient level of depth to enable the interested reader to read to the forwards-mode AD code in Mooncake.jl and understand what is going on.

This document

  1. specifies the semantics of a "rule" for forwards-mode AD,
  2. specifies how to implement rules by-hand for primitives, and
  3. specifies how to derive rules from IRCode algorithmically in general.
  4. discusses batched forwards-mode
  5. discusses some notable technical differences between our forwards-mode AD implementation details and reverse-mode AD implementation details, and
  6. concludes with a brief comparison with ForwardDiff.jl.

Forwards-Rule Interface

Loosely, a rule for a function simultaneously

  1. performs same computation as the original function, and
  2. computes the Frechet derivative.

This is best made concrete through a worked example. Consider a function call

z = f(x, y)

where f itself may contain data / state which is modified by executing f. rule_for_f is some callable which claims to be a forwards-rule for f. For rule_for_f to be a valid forwards-rule for f, it must be applicable to Duals as follows:

z_dz = rule_for_f(Dual(f, df), Dual(x, dx), Dual(y, dy))::Dual

where:

  1. rule_for_f is a callable. It might be written by-hand, or derived algorithmically.
  2. df, dx, and dy are tangents for f, x, and y respectively. Before executing rule_for_f, they are inputs to the derivative of (f, x, y). After executing they are outputs of this derivative.
  3. z_dz is a Dual containing the primal and the component of the derivative of (f, x, y) to (df, dx, dy) associated to z.
  4. running rule_for_f leaves f, x, and y in the same state that running f does.

We refer readers to Algorithmic Differentiation to explain what we mean when we talk about the "derivative" above. We also discussed some worked examples shortly.

Note that rule_for_f is an as-yet-unspecified callable which we introduced purely to specify the interface that a forwards-rule must satisfy. In Hand-Written Rules and Derived Rules below, we introduce two concrete ways to produce rules for f.

Tangent Types

We will use the type system documented in Representing Gradients. This means that every primal type has a unique tangent type. Moreover, if a Dual is defined as follows:

struct Dual{P, T}
     primal::P
     tangent::T
 end

it must always hold that T = tangent_type(P).

Testing

Suppose that we have (somehow) produced a supposed forwards-rule. To check that it is correctly implemented, we must

  1. all primal state after running the rule is approximately the same as all primal state after running the primal, and
  2. the inner product between all tangents (both output and input) and a random tangent vector after running the rule is approximately the same as the estimate of the same quantity produced by finite differencing or reverse-mode AD.

We already have the functionality to do this in a very general way (see Mooncake.TestUtils.test_rule).

Hand-Written Rules

Hand-written rules are implemented by writing methods of two functions: is_primitive and frule!!.

is_primitive

is_primitive(::Type{<:Union{MinimalForwardsCtx, DefaultForwardsCtx}}, signature::Type{<:Tuple}) should return true if AD must attempt to differentiate a call by passing the arguments to frule!!, and false otherwise. The Mooncake.@is_primitive macro can be used to implement this straightforwardly.

frule!!

Methods of frule!! do the actual differentiation, and must satisfy the Forwards-Rule Interface discussed above.

In what follows, we will refer to frule!!s for signatures. For example, the frule!! for signature Tuple{typeof(sin), Float64} is the rule which would differentiate calls like sin(5.0).

Simple Scalar Function

Recall that for $y = \sin(x)$ we have that $\dot{y} = \cos(x) \dot{x}$. So the frule!! for signature Tuple{typeof(sin), Float64} is:

function frule!!(::Dual{typeof(sin)}, x::Dual{Float64})
@@ -35,5 +35,5 @@
 2 1 ─ %1 = invoke rule_for_g($(Dual(Main.g, NoTangent())), _3::Dual{Float64, Float64})::Dual{Float64, Float64}
 3 │   %2 = invoke rule_for_h($(Dual(Main.h, NoTangent())), _3::Dual{Float64, Float64}, %1::Dual{Float64, Float64})::Dual{Float64, Float64}
 4 └──      return %2
-   => Dual{Float64, Float64}

Observe that:

  1. All Arguments have been incremented by 1. i.e. _2 has been replaced with _3. This corresponds to the fact that the arguments to the rule have all been shuffled along by one, and the rule itself is now the first argument.
  2. Everything has been turned into a Dual.
  3. Constants such as Dual(Main.g, NoTangent()) appear directly in the code (here as QuoteNodes).

(In practice it might be that we actually construct the Dualed constants on the lines immediately preceding a call and rely on the compiler to optimise them back into the call directly).

Here, as before, we have not specified exactly what rule_for_f, rule_for_g, and rule_for_h are. This is intentional – they are just callables satisfying the Forwards-Rule Interface. In the following we show how to derive rule_for_f, and show how rule_for_g and rule_for_h might be methods of Mooncake.frule!!, or themselves derived rules.

Rule Derivation Outline

Equipped with some intuition about what a derived rule ought to look like, we examine how we go about producing it algorithmically.

Rule derivation is implemented via the function Mooncake.build_frule. This function accepts as arguments a context and a signature / Base.MethodInstance / MistyClosure and, roughly speaking, does the following:

  1. Look up the optimised Compiler.IRCode.
  2. Apply a series of standardising transformations to the IRCode.
  3. Transform each statement according to a set of rules to produce a new IRCode.
  4. Apply standard Julia optimisations to this new IRCode.
  5. Put this code inside a MistyClosure in order to produce a executable object.
  6. Wrap this MistyClosure in a DerivedFRule to handle various bits of book-keeping around varargs.

In order:

Looking up the Compiler.IRCode.

This is done using Mooncake.lookup_ir. This function has methods with will return the IRCode associated to:

  1. signatures (e.g. Tuple{typeof(f), Float64})
  2. Base.MethodInstances (relevant for :invoke expressions – see Statement Transformation below)
  3. MistyClosures.MistyClosure objects, which is essential when computing higher order derivatives and Hessians by applying Mooncake.jl to itself.

Standardisation

We apply the following transformations to the Julia IR. They can all be found in ir_normalisation.jl:

  1. Mooncake.foreigncall_to_call: convert Expr(:foreigncall, ...) expressions into Expr(:call, Mooncake._foreigncall_, ...) expressions.
  2. Mooncake.new_to_call: convert Expr(:new, ...) expressions to Expr(:call, Mooncake._new_, ...) expressions.
  3. Mooncake.splatnew_to_call: convert Expr(:splatnew, ...) expressions to Expr(:call, Mooncake._splat_new_...) expressions.
  4. Mooncake.intrinsic_to_function: convert Expr(:call, ::IntrinsicFunction, ...) to calls to the corresponding function in Mooncake.IntrinsicsWrappers.

The purpose of converting Expr(:foreigncall...), Expr(:new, ...) and Expr(:splatnew, ...) into Expr(:call, ...)s is to enable us to differentiate such expressions by adding methods to frule!!(::Dual{typeof(Mooncake._foreigncall_)}), frule!!(::Dual{typeof(Mooncake._new_)}), and frule!!(::Dual{typeof(Mooncake._splat_new_)}), in exactly the same way that we would for any other regular Julia function.

The purpose of translating Expr(:call, ::IntrinsicFunction, ...) is to do with type stability – see the docstring for the Mooncake.IntrinsicsWrappers module for more info.

Statement Transformation

Each statment which can appear in the Julia IR is transformed by a method of Mooncake.make_fwds_ad_stmts. Consequently, this transformation phase simply corresponds to iterating through all of the expressions in the IRCode, applying Mooncake.make_fwd_ad_stmts to each to produce new IRCode. To understand how to modify IRCode and insert new instructions, see Oxinabox's Gist.

We provide here a high-level summary of the transformations for the most important Julia IR statements, and refer readers to the methods of Mooncake.make_fwds_ad_stmts for the definitive explanation of what transformation is applied, and the rationale for applying it. In particular there are quite a number more statements which can appear in Julia IR than those listed here and, for those we do list here, there are typically a few edge cases left out.

Expr(:invoke, method_instance, f, x...) and Expr(:call, f, x...)

:call expressions correspond to dynamic dispatch, while :invoke expressions correspond to static dispatch. That is, if you see an :invoke expression, you know for sure that the compiler knows enough information about the types of f and x to prove exactly which specialisation of which method to call. This specialisation is method_instance. This typically happens when the compiler is able to prove the types of f and x. Conversely, a :call expression typically occurs when the compiler has not been able to deduce the exact types of f and x, and therefore not been able to figure out what to call. It therefore has to wait until runtime to figure out what to call, resulting in dynamic dispatch.

As we saw earlier, the idea is to translate these kinds of expressions into something vaguely along the lines of

Expr(:call, rule_for_f, f, x...)

There are three cases to consider, in order of preference:

Primitives:

If is_primitive returns true when applied to the signature constructed from the static types of f and x, then we simply replace the expression with Expr(:call, frule!!, f, x...), regardless whether we have an :invoke or :call expression. (Due to the Standardisation steps, it regularly happens that we see :call expressions in which we actually do know enough type information to do this, e.g. for Mooncake._new_ :call expressions).

Static Dispatch:

In the case of :invoke nodes we know for sure at rule compilation time what rule_for_f must be. We derive a rule for the call by passing method_instance to Mooncake.build_frule. (In practice, we might do this lazily, but while retaining enough information to maintain type stability. See the Mooncake.LazyDerivedRule for how this is handled in reverse-mode).

Dynamic Dispatch:

If we have a :call expression and are not able to prove that is_primitive will return true, we must defer dispatch until runtime. We do this by replacing the :call expression with a call to a DynamicFRule, which simply constructs (or retrieves from a cache) the rule at runtime. Reverse-mode utilises a similar strategy via Mooncake.DynamicDerivedRule.

The above was written in terms of f and x. In practice, of course, we encounter various kinds of constants (e.g. Base.sin), Arguments (e.g. _3), and Core.SSAValues (e.g. %5). The translation rules for these are:

  1. constants are turned into constant duals in which the tangent is zero,
  2. Arguments are incremented by 1.
  3. SSAValues are left as-is.

Core.GotoNodes

These remain entirely unchanged.

Core.GotoIfNot

These require minor modification. Suppose that a Core.GotoIfNot of the form Core.GotoIfNot(%5, 4) is encountered in the primal. Since %5 will be a Dual in the derived rule, we must pull out the primal field, and pass that to the conditional instead. Therefore, these statments get lowered to two lines in the derived rule. For example, Core.GotoIfNot(%5, 4) would be translated to:

%n = getfield(%5, :primal)
-Core.GotoIfNot(%n, 4)

Core.PhiNode

Core.PhiNode looks something like the following in the general case:

φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)

They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.

Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:

  1. Core.SSAValues are unchanged.
  2. Core.Arguments are incremented by 1 (as always).
  3. constants are translated into constant duals.
  4. undefined values remain undefined.

So the above example would be translated into something like

φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)

Optimisation

The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.

Put IRCode in MistyClosure

Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.

Put the MistyClosure in a DerivedFRule

See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the "rule" that users get.

Batch Mode

So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.

In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of "batched tangents". Then, instead of propagating a "primal-tangent" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).

Forwards vs Reverse Implementation

The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:

  1. forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.
  2. forwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.
+ => Dual{Float64, Float64}

Observe that:

  1. All Arguments have been incremented by 1. i.e. _2 has been replaced with _3. This corresponds to the fact that the arguments to the rule have all been shuffled along by one, and the rule itself is now the first argument.
  2. Everything has been turned into a Dual.
  3. Constants such as Dual(Main.g, NoTangent()) appear directly in the code (here as QuoteNodes).

(In practice it might be that we actually construct the Dualed constants on the lines immediately preceding a call and rely on the compiler to optimise them back into the call directly).

Here, as before, we have not specified exactly what rule_for_f, rule_for_g, and rule_for_h are. This is intentional – they are just callables satisfying the Forwards-Rule Interface. In the following we show how to derive rule_for_f, and show how rule_for_g and rule_for_h might be methods of Mooncake.frule!!, or themselves derived rules.

Rule Derivation Outline

Equipped with some intuition about what a derived rule ought to look like, we examine how we go about producing it algorithmically.

Rule derivation is implemented via the function Mooncake.build_frule. This function accepts as arguments a context and a signature / Base.MethodInstance / MistyClosure and, roughly speaking, does the following:

  1. Look up the optimised Compiler.IRCode.
  2. Apply a series of standardising transformations to the IRCode.
  3. Transform each statement according to a set of rules to produce a new IRCode.
  4. Apply standard Julia optimisations to this new IRCode.
  5. Put this code inside a MistyClosure in order to produce a executable object.
  6. Wrap this MistyClosure in a DerivedFRule to handle various bits of book-keeping around varargs.

In order:

Looking up the Compiler.IRCode.

This is done using Mooncake.lookup_ir. This function has methods with will return the IRCode associated to:

  1. signatures (e.g. Tuple{typeof(f), Float64})
  2. Base.MethodInstances (relevant for :invoke expressions – see Statement Transformation below)
  3. MistyClosures.MistyClosure objects, which is essential when computing higher order derivatives and Hessians by applying Mooncake.jl to itself.

Standardisation

We apply the following transformations to the Julia IR. They can all be found in ir_normalisation.jl:

  1. Mooncake.foreigncall_to_call: convert Expr(:foreigncall, ...) expressions into Expr(:call, Mooncake._foreigncall_, ...) expressions.
  2. Mooncake.new_to_call: convert Expr(:new, ...) expressions to Expr(:call, Mooncake._new_, ...) expressions.
  3. Mooncake.splatnew_to_call: convert Expr(:splatnew, ...) expressions to Expr(:call, Mooncake._splat_new_...) expressions.
  4. Mooncake.intrinsic_to_function: convert Expr(:call, ::IntrinsicFunction, ...) to calls to the corresponding function in Mooncake.IntrinsicsWrappers.

The purpose of converting Expr(:foreigncall...), Expr(:new, ...) and Expr(:splatnew, ...) into Expr(:call, ...)s is to enable us to differentiate such expressions by adding methods to frule!!(::Dual{typeof(Mooncake._foreigncall_)}), frule!!(::Dual{typeof(Mooncake._new_)}), and frule!!(::Dual{typeof(Mooncake._splat_new_)}), in exactly the same way that we would for any other regular Julia function.

The purpose of translating Expr(:call, ::IntrinsicFunction, ...) is to do with type stability – see the docstring for the Mooncake.IntrinsicsWrappers module for more info.

Statement Transformation

Each statment which can appear in the Julia IR is transformed by a method of Mooncake.make_fwds_ad_stmts. Consequently, this transformation phase simply corresponds to iterating through all of the expressions in the IRCode, applying Mooncake.make_fwd_ad_stmts to each to produce new IRCode. To understand how to modify IRCode and insert new instructions, see Oxinabox's Gist.

We provide here a high-level summary of the transformations for the most important Julia IR statements, and refer readers to the methods of Mooncake.make_fwds_ad_stmts for the definitive explanation of what transformation is applied, and the rationale for applying it. In particular there are quite a number more statements which can appear in Julia IR than those listed here and, for those we do list here, there are typically a few edge cases left out.

Expr(:invoke, method_instance, f, x...) and Expr(:call, f, x...)

:call expressions correspond to dynamic dispatch, while :invoke expressions correspond to static dispatch. That is, if you see an :invoke expression, you know for sure that the compiler knows enough information about the types of f and x to prove exactly which specialisation of which method to call. This specialisation is method_instance. This typically happens when the compiler is able to prove the types of f and x. Conversely, a :call expression typically occurs when the compiler has not been able to deduce the exact types of f and x, and therefore not been able to figure out what to call. It therefore has to wait until runtime to figure out what to call, resulting in dynamic dispatch.

As we saw earlier, the idea is to translate these kinds of expressions into something vaguely along the lines of

Expr(:call, rule_for_f, f, x...)

There are three cases to consider, in order of preference:

Primitives:

If is_primitive returns true when applied to the signature constructed from the static types of f and x, then we simply replace the expression with Expr(:call, frule!!, f, x...), regardless whether we have an :invoke or :call expression. (Due to the Standardisation steps, it regularly happens that we see :call expressions in which we actually do know enough type information to do this, e.g. for Mooncake._new_ :call expressions).

Static Dispatch:

In the case of :invoke nodes we know for sure at rule compilation time what rule_for_f must be. We derive a rule for the call by passing method_instance to Mooncake.build_frule. (In practice, we might do this lazily, but while retaining enough information to maintain type stability. See the Mooncake.LazyDerivedRule for how this is handled in reverse-mode).

Dynamic Dispatch:

If we have a :call expression and are not able to prove that is_primitive will return true, we must defer dispatch until runtime. We do this by replacing the :call expression with a call to a DynamicFRule, which simply constructs (or retrieves from a cache) the rule at runtime. Reverse-mode utilises a similar strategy via Mooncake.DynamicDerivedRule.

The above was written in terms of f and x. In practice, of course, we encounter various kinds of constants (e.g. Base.sin), Arguments (e.g. _3), and Core.SSAValues (e.g. %5). The translation rules for these are:

  1. constants are turned into constant duals in which the tangent is zero,
  2. Arguments are incremented by 1.
  3. SSAValues are left as-is.

Core.GotoNodes

These remain entirely unchanged.

Core.GotoIfNot

These require minor modification. Suppose that a Core.GotoIfNot of the form Core.GotoIfNot(%5, 4) is encountered in the primal. Since %5 will be a Dual in the derived rule, we must pull out the primal field, and pass that to the conditional instead. Therefore, these statments get lowered to two lines in the derived rule. For example, Core.GotoIfNot(%5, 4) would be translated to:

%n = getfield(%5, :primal)
+Core.GotoIfNot(%n, 4)

Core.PhiNode

Core.PhiNode looks something like the following in the general case:

φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)

They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.

Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:

  1. Core.SSAValues are unchanged.
  2. Core.Arguments are incremented by 1 (as always).
  3. constants are translated into constant duals.
  4. undefined values remain undefined.

So the above example would be translated into something like

φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)

Optimisation

The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.

Put IRCode in MistyClosure

Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.

Put the MistyClosure in a DerivedFRule

See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the "rule" that users get.

Batch Mode

So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.

In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of "batched tangents". Then, instead of propagating a "primal-tangent" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).

Forwards vs Reverse Implementation

The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:

  1. forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.
  2. forwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.

Comparison with ForwardDiff.jl

With reference to the limitations of ForwardDiff.jl, there are a few noteworthy differences between ForwardDiff.jl and this implementation:

  1. :foreigncalls pose much less of a problem for Mooncake's forward-mode than for ForwardDiff.jl, because we can write a rule for any method of any function. In essence, you can only (reliably) write rules for ForwardDiff.jl via dispatch on ForwardDiff.Dual.
  2. the target function can be of any arity in Mooncake.jl, but must be unary in ForwardDiff.jl.
  3. there are no limitations on the argument type constraints that Mooncake.jl can handle, while ForwardDiff.jl requires that argument type constraints be <:Real or arrays of <:Real.
  4. No special storage types are required with Mooncake.jl, while ForwardDiff.jl requires that any container you write to is able to contain ForwardDiff.Duals.
diff --git a/previews/PR386/developer_documentation/internal_docstrings/index.html b/previews/PR386/developer_documentation/internal_docstrings/index.html index 32da2e0f2..2ccf895a2 100644 --- a/previews/PR386/developer_documentation/internal_docstrings/index.html +++ b/previews/PR386/developer_documentation/internal_docstrings/index.html @@ -1,17 +1,17 @@ -Internal Docstrings · Mooncake.jl

Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.

The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

Mooncake.TerminatorType
Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}

A Union of the possible types of a terminator node.

source
Core.Compiler.IRCodeMethod
IRCode(bb_code::BBCode)

Produce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.

All IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.

In the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.

source
Mooncake.ADInfoType
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
source
Mooncake.ADStmtInfoType
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback
source
Mooncake.BBCodeMethod
BBCode(ir::IRCode)

Convert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.

All PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.

See IRCode for conversion back to IRCode.

Note that IRCode(BBCode(ir)) should be equal to the identity function.

source
Mooncake.BBCodeMethod
BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})

Make a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.

source
Mooncake.BBCodeType
BBCode(
+Internal Docstrings · Mooncake.jl

Internal Docstrings

Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.

The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.

Mooncake.TerminatorType
Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}

A Union of the possible types of a terminator node.

source
Core.Compiler.IRCodeMethod
IRCode(bb_code::BBCode)

Produce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.

All IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.

In the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.

source
Mooncake.ADInfoType
ADInfo

This data structure is used to hold "global" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.

  • interp: a MooncakeInterpreter.
  • block_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.
  • block_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.
  • entry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.
  • shared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.
  • arg_types: a map from Argument to its static type.
  • ssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.
  • arg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.
  • ssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.
  • debug_mode: if true, run in "debug mode" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.
  • is_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.
  • lazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.
source
Mooncake.ADStmtInfoType
ADStmtInfo

Data structure which contains the result of make_ad_stmts!. Fields are

  • line: the ID associated to the primal line from which this is derived
  • comms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.
  • fwds: the instructions which run the forwards-pass of AD
  • rvs: the instructions which run the reverse-pass of AD / the pullback
source
Mooncake.BBCodeMethod
BBCode(ir::IRCode)

Convert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.

All PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.

See IRCode for conversion back to IRCode.

Note that IRCode(BBCode(ir)) should be equal to the identity function.

source
Mooncake.BBCodeMethod
BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})

Make a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.

source
Mooncake.BBCodeType
BBCode(
     blocks::Vector{BBlock}
     argtypes::Vector{Any}
     sptypes::Vector{CC.VarState}
     linetable::Vector{Core.LineInfoNode}
     meta::Vector{Expr}
-)

A BBCode is a data structure which is similar to IRCode, but adds additional structure.

In particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.

The consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.

Additionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.

Finally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.

source
Mooncake.BBlockMethod
BBlock(id::ID, inst_pairs::Vector{IDInstPair})

Convenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.

source
Mooncake.BBlockType
BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)

A basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.

Each BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.

The nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.

Note that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.

source
Mooncake.BlockStackType

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.

source
Mooncake.CannotProduceZeroRDataFromTypeType
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

source
Mooncake.ConfigType
Config(; debug_mode=false, silence_debug_messages=false)

Configuration struct for use with ADTypes.AutoMooncake.

source
Mooncake.DebugPullbackMethod
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

source
Mooncake.DebugPullbackType
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse pass counterpart to DebugRRule

source
Mooncake.DebugRRuleMethod
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

source
Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
Mooncake.DefaultCtxType
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

source
Mooncake.DynamicDerivedRuleType
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.

source
Mooncake.FDataType
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

source
Mooncake.IDType
ID()

An ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.

This counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.

source
Mooncake.IDPhiNodeType
IDPhiNode(edges::Vector{ID}, values::Vector{Any})

Like a PhiNode, but edges are IDs rather than Int32s.

source
Mooncake.InstVectorType
const InstVector = Vector{NewInstruction}

Note: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.

source
Mooncake.LazyDerivedRuleType
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

source
Mooncake.LazyZeroRDataType
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.

Users should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

source
Mooncake.MinimalCtxType
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

source
Mooncake.NoPullbackMethod
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

source
Mooncake.RRuleZeroWrapperType
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

source
Mooncake.SharedDataPairsType
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

source
Mooncake.StackType
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.

source
Mooncake.SwitchType
Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)

A switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:

goto dests[1] if not conds[1]
+)

A BBCode is a data structure which is similar to IRCode, but adds additional structure.

In particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.

The consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.

Additionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.

Finally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.

source
Mooncake.BBlockMethod
BBlock(id::ID, inst_pairs::Vector{IDInstPair})

Convenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.

source
Mooncake.BBlockType
BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)

A basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.

Each BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.

The nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.

Note that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.

source
Mooncake.BlockStackType

The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.

source
Mooncake.CannotProduceZeroRDataFromTypeType
CannotProduceZeroRDataFromType()

Returned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.

source
Mooncake.ConfigType
Config(; debug_mode=false, silence_debug_messages=false)

Configuration struct for use with ADTypes.AutoMooncake.

source
Mooncake.DebugPullbackMethod
(pb::DebugPullback)(dy)

Apply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.

source
Mooncake.DebugPullbackType
DebugPullback(pb, y, x)

Construct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function

  • checks that dy has the correct rdata type for y, and
  • checks that each element of dx has the correct rdata type for x.

Reverse pass counterpart to DebugRRule

source
Mooncake.DebugRRuleMethod
(rule::DebugRRule)(x::CoDual...)

Apply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.

source
Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source
Mooncake.DefaultCtxType
struct DefaultCtx end

Context for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.

source
Mooncake.DynamicDerivedRuleType
DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)

For internal use only.

A callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.

This is used to implement dynamic dispatch.

source
Mooncake.FDataType
FData(data::NamedTuple)

The component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.

source
Mooncake.IDType
ID()

An ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.

This counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.

source
Mooncake.IDPhiNodeType
IDPhiNode(edges::Vector{ID}, values::Vector{Any})

Like a PhiNode, but edges are IDs rather than Int32s.

source
Mooncake.InstVectorType
const InstVector = Vector{NewInstruction}

Note: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.

source
Mooncake.LazyDerivedRuleType
LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)

For internal use only.

A type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.

If debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.

Note: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.

source
Mooncake.LazyZeroRDataType
LazyZeroRData{P, Tdata}()

This type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.

Users should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.

source
Mooncake.MinimalCtxType
struct MinimalCtx end

Functions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.

source
Mooncake.NoPullbackMethod
NoPullback(args::CoDual...)

Construct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.

The advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.

source
Mooncake.RRuleZeroWrapperType
RRuleZeroWrapper(rule)

This struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.

source
Mooncake.SharedDataPairsType
SharedDataPairs()

A data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.

This is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.

source
Mooncake.StackType
Stack{T}()

A stack specialised for reverse-mode AD.

Semantically equivalent to a usual stack, but never de-allocates memory once allocated.

source
Mooncake.SwitchType
Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)

A switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:

goto dests[1] if not conds[1]
 goto dests[2] if not conds[2]
 ...
 goto dests[N] if not conds[N]
-goto fallthrough_dest

where the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.

Switch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.

source
Mooncake.ZeroRDataType
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

source
Base.insert!Method
Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing

Inserts stmt and id into bb immediately before the nth instruction.

source
Mooncake.__flatten_varargsMethod
__flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).

source
Mooncake.__insts_to_instruction_streamMethod
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

source
Mooncake.__line_numbers_to_block_numbers!Method
__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)

Converts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.

For context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.

This code is copied over directly from the body of Core.Compiler.inflate_ir!.

source
Mooncake.__pop_blk_stack!Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

source
Mooncake.__push_blk_stack!Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.

source
Mooncake.__run_rvs_pass!Method
__run_rvs_pass!(
+goto fallthrough_dest

where the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.

Switch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.

source
Mooncake.ZeroRDataType
ZeroRData()

Singleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.

If you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.

source
Base.insert!Method
Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing

Inserts stmt and id into bb immediately before the nth instruction.

source
Mooncake.__flatten_varargsMethod
__flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}

If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).

source
Mooncake.__insts_to_instruction_streamMethod
__insts_to_instruction_stream(insts::Vector{Any})

Produces an instruction stream whose

  • stmt (v1.11 and up) / inst (v1.10) field is insts,
  • type field is all Any,
  • info field is all Core.Compiler.NoCallInfo,
  • line field is all Int32(1), and
  • flag field is all Core.Compiler.IR_FLAG_REFINED.

As such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.

source
Mooncake.__line_numbers_to_block_numbers!Method
__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)

Converts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.

For context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.

This code is copied over directly from the body of Core.Compiler.inflate_ir!.

source
Mooncake.__pop_blk_stack!Method
__pop_blk_stack!(block_stack::BlockStack)

Equivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.

source
Mooncake.__push_blk_stack!Method
__push_blk_stack!(block_stack::BlockStack, id::Int32)

Equivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.

source
Mooncake.__run_rvs_pass!Method
__run_rvs_pass!(
     P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...
-) where {sig}

Used in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).

source
Mooncake.__unflatten_codual_varargsMethod
__unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).

source
Mooncake.__value_and_gradient!!Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
+) where {sig}

Used in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).

source
Mooncake.__unflatten_codual_varargsMethod
__unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}

If isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).

source
Mooncake.__value_and_gradient!!Method
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

Equivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.

# Set up the problem.
 f(x, y) = sum(x .* y)
 x = [2.0, 2.0]
 y = [1.0, 1.0]
@@ -29,25 +29,25 @@
 )
 # output
 
-(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))
source
Mooncake.__value_and_pullback!!Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.

source
Mooncake._block_nums_to_idsMethod
_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}

Assign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.

source
Mooncake._build_graph_of_cfgMethod
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.

source
Mooncake._compute_all_predecessorsMethod
_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._compute_all_successorsMethod
_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._control_flow_graphMethod
_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG

Internal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.

source
Mooncake._distance_to_entryMethod
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.

source
Mooncake._find_id_uses!Method
_find_id_uses!(d::Dict{ID, Bool}, x)

Helper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.

For example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.

source
Mooncake.__value_and_pullback!!Method
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)

Note: this is not part of the public Mooncake.jl interface, and may change without warning.

In-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.

source
Mooncake._block_nums_to_idsMethod
_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}

Assign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.

source
Mooncake._build_graph_of_cfgMethod
_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}

Builds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.

Returns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.

source
Mooncake._compute_all_predecessorsMethod
_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._compute_all_successorsMethod
_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}

Internal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.

source
Mooncake._control_flow_graphMethod
_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG

Internal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.

source
Mooncake._distance_to_entryMethod
_distance_to_entry(blks::Vector{BBlock})::Vector{Int}

For each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.

source
Mooncake._find_id_uses!Method
_find_id_uses!(d::Dict{ID, Bool}, x)

Helper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.

For example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.

source
Mooncake._foreigncall_Method
function _foreigncall_(
     ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...
-) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

source
Mooncake._ids_to_line_numbersMethod
_ids_to_line_numbers(bb_code::BBCode)::InstVector

For each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.

source
Mooncake._is_reachableMethod
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.

source
Mooncake._lines_to_blocksMethod
_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector

Pulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.

source
Mooncake._lower_switch_statementsMethod
_lower_switch_statements(bb_code::BBCode)

Converts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.

source
Mooncake._mapMethod
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type as element will always be assigned.

Requires that y and x have the same size.

source
Mooncake._new_Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.

source
Mooncake._remove_double_edgesMethod
_remove_double_edges(ir::BBCode)::BBCode

If the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.

source
Mooncake._sort_blocks!Method
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.

For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.

WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.

source
Mooncake._ssa_to_idsMethod
_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)

Produce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.

source
Mooncake._ssas_to_idsMethod
_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}

Assigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).

source
Mooncake._to_ssasMethod
_to_ssas(d::Dict, inst::NewInstruction)

Like _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).

source
Mooncake._typeofMethod
_typeof(x)

Central definition of typeof, which is specific to the use-required in this package.

source
Mooncake.ad_stmt_infoMethod
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

source
Mooncake.add_data!Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

source
Mooncake.add_data!Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.

source
Mooncake.add_data_if_not_singleton!Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.

source
Mooncake.characterise_unique_predecessor_blocksMethod
characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
-    Tuple{Dict{ID, Bool}, Dict{ID, Bool}}

We call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.

Returns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.

Context:

This information is important for optimising AD because knowing that b is a unique predecessor means that

  1. on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
  2. on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.

Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.

source
Mooncake.characterise_used_idsMethod
characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}

For each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.

source
Mooncake.collect_stmtsMethod
collect_stmts(ir::BBCode)::Vector{IDInstPair}

Produce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir.

source
Mooncake.collect_stmtsMethod
collect_stmts(bb::BBlock)::Vector{IDInstPair}

Returns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.

source
Mooncake.comms_channelMethod
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.

source
Mooncake.conclude_rvs_blockMethod
conclude_rvs_block(
+) where {name, RT, nreq, calling_convention}

:foreigncall nodes get translated into calls to this function. For example,

Expr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)

becomes

_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)

Please consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.

Credit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.

source
Mooncake._ids_to_line_numbersMethod
_ids_to_line_numbers(bb_code::BBCode)::InstVector

For each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.

source
Mooncake._is_reachableMethod
_is_reachable(blks::Vector{BBlock})::Vector{Bool}

Computes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.

source
Mooncake._lines_to_blocksMethod
_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector

Pulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.

source
Mooncake._lower_switch_statementsMethod
_lower_switch_statements(bb_code::BBCode)

Converts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.

source
Mooncake._mapMethod
_map(f, x...)

Same as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)

Similar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.

Requires that y, x1, and x2 have the same size.

source
Mooncake._map_if_assigned!Method
_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}

For all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.

Equivalent to map!(f, y, x) if P is a bits type as element will always be assigned.

Requires that y and x have the same size.

source
Mooncake._new_Method
_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}

One-liner which calls the :new instruction with type T with arguments x.

source
Mooncake._remove_double_edgesMethod
_remove_double_edges(ir::BBCode)::BBCode

If the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.

source
Mooncake._sort_blocks!Method
_sort_blocks!(ir::BBCode)::BBCode

Ensure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.

For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.

WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.

source
Mooncake._ssa_to_idsMethod
_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)

Produce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.

source
Mooncake._ssas_to_idsMethod
_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}

Assigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).

source
Mooncake._to_ssasMethod
_to_ssas(d::Dict, inst::NewInstruction)

Like _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).

source
Mooncake._typeofMethod
_typeof(x)

Central definition of typeof, which is specific to the use-required in this package.

source
Mooncake.ad_stmt_infoMethod
ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)

Convenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.

source
Mooncake.add_data!Method
add_data!(info::ADInfo, data)::ID

Equivalent to add_data!(info.shared_data_pairs, data).

source
Mooncake.add_data!Method
add_data!(p::SharedDataPairs, data)::ID

Puts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.

source
Mooncake.add_data_if_not_singleton!Method
add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)

Returns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.

source
Mooncake.characterise_unique_predecessor_blocksMethod
characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->
+    Tuple{Dict{ID, Bool}, Dict{ID, Bool}}

We call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.

Returns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.

Context:

This information is important for optimising AD because knowing that b is a unique predecessor means that

  1. on the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and
  2. on the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.

Utilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.

source
Mooncake.characterise_used_idsMethod
characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}

For each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.

source
Mooncake.collect_stmtsMethod
collect_stmts(ir::BBCode)::Vector{IDInstPair}

Produce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir.

source
Mooncake.collect_stmtsMethod
collect_stmts(bb::BBlock)::Vector{IDInstPair}

Returns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.

source
Mooncake.comms_channelMethod
comms_channel(info::ADStmtInfo)

Return the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.

source
Mooncake.conclude_rvs_blockMethod
conclude_rvs_block(
     blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
-)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

source
Mooncake.const_codualMethod
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

source
Mooncake.create_comms_insts!Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.

Returns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.

source
Mooncake.fdata_field_typeMethod
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
+)

Generates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.

source
Mooncake.const_codualMethod
const_codual(stmt, info::ADInfo)

Build a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.

source
Mooncake.const_codual_stmtMethod
const_codual_stmt(stmt, info::ADInfo)

Returns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.

source
Mooncake.create_comms_insts!Method
create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)

This function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.

For each basic block represented in ADStmts:

  1. create a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.
  2. create instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and
  3. create instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.

Returns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.

source
Mooncake.fdata_field_typeMethod
fdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.foreigncall_to_callMethod
foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})

If inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.

sp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.

The purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.forwards_pass_irMethod
forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the forwards-pass.

source
Mooncake.fwd_irMethod
fwd_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.gc_preserveMethod
gc_preserve(xs...)

A no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.

source
Mooncake.generate_irMethod
generate_ir(
     interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true
-)

Used by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.

source
Mooncake.id_to_line_mapMethod
id_to_line_map(ir::BBCode)

Produces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.

source
Mooncake.insert_before_terminator!Method
insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing

If the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
+)

Used by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.

source
Mooncake.get_rev_data_idMethod
get_rev_data_id(info::ADInfo, x)

Returns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.

source
Mooncake.get_tangent_fieldMethod
get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)

Gets the ith field of data in t.

Has the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.

source
Mooncake.id_to_line_mapMethod
id_to_line_map(ir::BBCode)

Produces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.

source
Mooncake.increment_and_get_rdata!Method
increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)

Increment fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.

source
Mooncake.increment_rdata!!Method
increment_rdata!!(t::T, r)::T where {T}

Increment the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.

source
Mooncake.infer_ir!Method
infer_ir!(ir::IRCode) -> IRCode

Runs type inference on ir, which mutates ir, and returns it.

Note: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.

source
Mooncake.insert_before_terminator!Method
insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing

If the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.

source
Mooncake.interpolate_boundschecks!Method
interpolate_boundschecks!(ir::IRCode)

For every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.

source
Mooncake.intrinsic_to_functionMethod
intrinsic_to_function(inst)

If inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.

cglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.

The purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.

source
Mooncake.ircodeFunction
ircode(
     inst::Vector{Any},
     argtypes::Vector{Any},
     sptypes::Vector{CC.VarState}=CC.VarState[],
-) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source
Mooncake.is_reachable_return_nodeMethod
is_reachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
-getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preserve(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
+) -> IRCode

Constructs an instance of an IRCode. This is useful for constructing test cases with known properties.

No optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.

Edges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.

source
Mooncake.is_always_fully_initialisedMethod
is_always_fully_initialised(P::DataType)::Bool

True if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.

source
Mooncake.is_always_initialisedMethod
is_always_initialised(P::DataType, n::Int)::Bool

True if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.

source
Mooncake.is_primitiveMethod
is_primitive(::Type{Ctx}, sig) where {Ctx}

Returns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.

is_primitive(DefaultCtx, Tuple{typeof(sin), Float64})

will return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.

Observe that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.

source
Mooncake.is_reachable_return_nodeMethod
is_reachable_return_node(x::ReturnNode)

Determine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_unreachable_return_nodeMethod
is_unreachable_return_node(x::ReturnNode)

Determine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.

source
Mooncake.is_usedMethod
is_used(info::ADInfo, id::ID)::Bool

Returns true if id is used by any of the lines in the ir, false otherwise.

source
Mooncake.is_vararg_and_sparam_namesMethod
is_vararg_and_sparam_names(m::Method)

Returns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.

source
Mooncake.lgetfieldMethod
lgetfield(x, f::Val)

An implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.

It will always be the case that

getfield(x, :f) === lgetfield(x, Val(:f))
+getfield(x, 2) === lgetfield(x, Val(2))

This approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.

source
Mooncake.lgetfieldMethod
lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}

Like getfield, but with the field and access order encoded as types.

source
Mooncake.lift_gc_preservationMethod
lift_gc_preserve(inst)

Expressions of the form

y = GC.@preserve x1 x2 foo(args...)

get lowered to

token = Expr(:gc_preserve_begin, x1, x2)
 y = expr
 Expr(:gc_preserve_end, token)

These expressions guarantee that any memory associated x1 and x2 not be freed until the :gc_preserve_end expression is reached.

In the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.

To achieve this, we replace the primal code with

# store `x` in `pb_gc_preserve` to prevent it from being freed.
 _, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)
@@ -63,11 +63,11 @@
 _, dargs... = foo_pb(dy)
 
 # No-op pullback associated to `gc_preserve`.
-pb_gc_preserve(NoRData())
source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lift_getfield_and_othersMethod
lift_getfield_and_others(inst)

Converts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.

Does the same for...

source
Mooncake.lookup_irMethod
lookup_ir(
     interp::AbstractInterpreter,
     sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},
-)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
-setfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
+)::Tuple{IRCode, T}

Get the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.

Returns a tuple containing the IRCode and its return type.

source
Mooncake.lsetfield!Method
lsetfield!(value, name::Val, x, [order::Val])

This function is to setfield! what lgetfield is to getfield. It will always hold that

setfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)
+setfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)
source
Mooncake.make_ad_stmts!Function
make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo

Every line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.

Translates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.

info is a data structure containing various bits of global information that certain types of nodes need access to.

source
Mooncake.make_switch_stmtsMethod
make_switch_stmts(
     pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo
 )

preds_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of

prev_block = pop!(block_stack)
 not_pred_was_1 = !(prev_block == ID(1))
@@ -76,29 +76,29 @@
     not_pred_was_1 => ID(1),
     not_pred_was_2 => ID(2),
     ID(3)
-)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.

source
Mooncake.new_instFunction
new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction

Create a NewInstruction with fields:

  • stmt = stmt
  • type = type
  • info = CC.NoCallInfo()
  • line = Int32(1)
  • flag = flag
source
Mooncake.new_inst_vecMethod
new_inst_vec(x::CC.InstructionStream)

Convert an Compiler.InstructionStream into a list of Compiler.NewInstructions.

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with:calls toMooncake.splatnew_`,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls to lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.phi_nodesMethod
phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}

Returns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
+)

In words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.

source
Mooncake.new_instFunction
new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction

Create a NewInstruction with fields:

  • stmt = stmt
  • type = type
  • info = CC.NoCallInfo()
  • line = Int32(1)
  • flag = flag
source
Mooncake.new_inst_vecMethod
new_inst_vec(x::CC.InstructionStream)

Convert an Compiler.InstructionStream into a list of Compiler.NewInstructions.

source
Mooncake.new_to_callMethod
new_to_call(x)

If instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.

The purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.normalise!Method
normalise!(ir::IRCode, spnames::Vector{Symbol})

Apply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace

  1. :foreigncall Exprs with :calls to Mooncake._foreigncall_,
  2. :new Exprs with :calls to Mooncake._new_,
  3. :splatnew Exprs with:calls toMooncake.splatnew_`,
  4. Core.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,
  5. getfield(x, 1) with lgetfield(x, Val(1)), and related transformations,
  6. memoryrefget calls to lmemoryrefget calls, and related transformations,
  7. gc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.

spnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.

Unfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.

source
Mooncake.optimise_ir!Method
optimise_ir!(ir::IRCode, show_ir=false)

Run a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.

source
Mooncake.phi_nodesMethod
phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}

Returns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.

source
Mooncake.primal_irMethod
primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Get the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).

For example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:

julia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_unreachable_blocks!Method
remove_unreachable_blocks!(ir::BBCode)::BBCode

If a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:

  1. it has no predecessors and it is not the first block, or
  2. all of its predecessors are themselves unreachable.

For example, consider the following IR:

julia> ir = Mooncake.ircode(
+true
source
Mooncake.pullback_irMethod
pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)

Produce the IR associated to the OpaqueClosure which runs most of the pullback.

source
Mooncake.pullback_typeMethod
pullback_type(Trule, arg_types)

Get a bound on the pullback type, given a rule and associated primal types.

source
Mooncake.rdata_field_typeMethod
rdata_field_type(::Type{P}, n::Int) where {P}

Returns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.

source
Mooncake.remove_unreachable_blocks!Method
remove_unreachable_blocks!(ir::BBCode)::BBCode

If a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:

  1. it has no predecessors and it is not the first block, or
  2. all of its predecessors are themselves unreachable.

For example, consider the following IR:

julia> ir = Mooncake.ircode(
            Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],
            Any[Any, Any, Any],
        );

There is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:

julia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))
-1 1 ─     return nothing

In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
+1 1 ─     return nothing

In the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.

source
Mooncake.replace_capturesMethod
replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}

Same as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.

source
Mooncake.replace_capturesMethod
replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}

Given an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.

source
Mooncake.replace_uses_with!Method
replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)

Replace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.

source
Mooncake.rrule_wrapperMethod
rrule_wrapper(f::CoDual, args::CoDual...)

Used to implement rrule!!s via ChainRulesCore.rrule.

Given a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:

Mooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}
 function Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)
     return rrule_wrapper(f, args...)
-end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
+end

Assumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.

Furthermore, it is essential that

  1. f(args) does not mutate f or args, and
  2. the result of f(args) does not alias any data stored in f or args.

Subject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.

source
Mooncake.rule_typeMethod
rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}

Compute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.

source
Mooncake.rvs_irMethod
rvs_ir(
     sig::Type{<:Tuple};
     interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true
 )::IRCode

!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.

Generate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.

For example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:

julia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode
 true
 julia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode
-true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
+true

Arguments

  • sig::Type{<:Tuple}: the signature of the call to be differentiated.

Keyword Arguments

  • interp: the interpreter to use to obtain the primal IR.
  • debug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.
  • do_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.
source
Mooncake.rvs_phi_blockMethod
rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)

Produces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.

For example, suppose that we encounter the following collection of PhiNodes at the start of some block:

%6 = φ (#2 => _1, #3 => %5)
 %7 = φ (#2 => 5., #3 => _2)

Let the tangent refs associated to %6, %7, and _1be denotedt%6,t%7, andt1resp., and letpredidbe#2`, then this function will produce a basic block of the form

increment_ref!(t_1, t%6)
 nothing
-goto #2

The call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

source
Mooncake.seed_id!Method
seed_id!()

Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.

This is akin to setting the random seed associated to a random number generator globally.

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
+goto #2

The call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.

The same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.

source
Mooncake.seed_id!Method
seed_id!()

Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.

This is akin to setting the random seed associated to a random number generator globally.

source
Mooncake.set_tangent_field!Method
set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}

Sets the value of the ith field of the data in t to value x.

Has the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.

source
Mooncake.shared_data_stmtsMethod
shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}

Produce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

IDInstPair[
     (ID(5), new_inst(:(getfield(_1, 1)))),
     (ID(3), new_inst(:(getfield(_1, 2)))),
-]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangent_field_typeMethod
tangent_field_type(::Type{P}, n::Int) where {P}

Returns the type that lives in the nth elements of fields in a Tangent / MutableTangent. Will either be the tangent_type of the nth fieldtype of P, or the tangent_type wrapped in a PossiblyUninitTangent. The latter case only occurs if it is possible for the field to be undefined.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.terminatorMethod
terminator(bb::BBlock)

Returns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
+]
source
Mooncake.shared_data_tupleMethod
shared_data_tuple(p::SharedDataPairs)::Tuple

Create the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.

For example, if p.pairs is

[(ID(5), 5.0), (ID(3), "hello")]

then the output of this function is

(5.0, "hello")
source
Mooncake.sparam_namesMethod
sparam_names(m::Core.Method)::Vector{Symbol}

Returns the names of all of the static parameters in m.

source
Mooncake.splatnew_to_callMethod
splatnew_to_call(x)

If instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.

The purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.

source
Mooncake.stmtMethod
stmt(ir::CC.InstructionStream)

Get the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.

source
Mooncake.tangent_field_typeMethod
tangent_field_type(::Type{P}, n::Int) where {P}

Returns the type that lives in the nth elements of fields in a Tangent / MutableTangent. Will either be the tangent_type of the nth fieldtype of P, or the tangent_type wrapped in a PossiblyUninitTangent. The latter case only occurs if it is possible for the field to be undefined.

source
Mooncake.tangent_test_casesMethod
tangent_test_cases()

Constructs a Vector of Tuples containing test cases for the tangent infrastructure.

If the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value

interface_only is a Bool which will be used to determine which subset of tests to run.

If the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).

Test cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.

source
Mooncake.terminatorMethod
terminator(bb::BBlock)

Returns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.

source
Mooncake.tuple_mapMethod
tuple_map(f::F, x::Tuple) where {F}

This function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.

As a consequence, if x is very long, this function may have very large compile times.

tuple_map(f::F, x::Tuple, y::Tuple) where {F}

Binary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.

source
Mooncake.uninit_fcodualMethod
uninit_fcodual(x)

Like zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.

source
Mooncake.uninit_tangentMethod
uninit_tangent(x)

Related to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.

source
Mooncake.verify_fdata_typeMethod
verify_fdata_type(P::Type, F::Type)::Nothing

Check that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_fdata_valueMethod
verify_fdata_value(p, f)::Nothing

Check that f cannot be proven to be invalid fdata for p.

This method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.

source
Mooncake.verify_rdata_typeMethod
verify_rdata_type(P::Type, R::Type)::Nothing

Check that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.

This applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.

source
Mooncake.verify_rdata_valueMethod
verify_rdata_value(p, r)::Nothing

Check that r cannot be proven to be invalid rdata for p.

This method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.

Put differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.

source
Mooncake.zero_adjointMethod
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
 
 julia> foo(x::Vararg{Int}) = 5
 foo (generic function with 1 method)
@@ -108,7 +108,7 @@
 julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
 
 julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
-(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData,

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.

If you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
+(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source
Mooncake.zero_like_rdata_from_typeMethod
zero_like_rdata_from_type(::Type{P}) where {P}

This is an internal implementation detail – you should generally not use this function.

Returns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData,

source
Mooncake.zero_like_rdata_typeMethod
zero_like_rdata_type(::Type{P}) where {P}

Indicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.

source
Mooncake.zero_rdata_from_typeMethod
zero_rdata_from_type(::Type{P}) where {P}

Returns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.

For example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.

However, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.

These considerations apply recursively to tuples / namedtuples / structs, etc.

If you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.

source
Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
 
 julia> using ChainRulesCore
 
@@ -153,7 +153,7 @@
        TestUtils.test_rule(
            Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
        )
-Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implemented more complicated method of is_primitive in the usual way.

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
+Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
Mooncake.@is_primitiveMacro
@is_primitive context_type signature

Creates a method of is_primitive which always returns true for the context_type and signature provided. For example

@is_primitive MinimalCtx Tuple{typeof(foo), Float64}

is equivalent to

is_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true

You should implemented more complicated method of is_primitive in the usual way.

source
Mooncake.@mooncake_overlayMacro
@mooncake_overlay method_expr

Define a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.

For example, suppose that you have a function

julia> foo(x::Float64) = bar(x)
 foo (generic function with 1 method)

where Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:

julia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)
 

When looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.

A Worked Example

To demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!

First, consider a simple example:

julia> scale(x) = 2x
 scale (generic function with 1 method)
@@ -173,7 +173,7 @@
 julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});
 
 julia> Mooncake.value_and_gradient!!(rule, scale, 5.0)
-(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
+(20.0, (NoTangent(), 4.0))
source
Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
 
 julia> foo(x) = 5
 foo (generic function with 1 method)
@@ -195,7 +195,7 @@
 true
 
 julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
-(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
+(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.IntrinsicsWrappersModule
module IntrinsicsWrappers

The purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.

To understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that

julia> typeof(Core.Intrinsics.add_float)
 Core.IntrinsicFunction
 
 julia> typeof(Core.Intrinsics.sub_float)
@@ -205,4 +205,4 @@
     # return add_float and its pullback
 elseif
     ...
-end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.

Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .

source
+end

which has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).

Instead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.

Extended Help

It is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .

source
diff --git a/previews/PR386/developer_documentation/running_tests_locally/index.html b/previews/PR386/developer_documentation/running_tests_locally/index.html index a619782ad..700abe2a3 100644 --- a/previews/PR386/developer_documentation/running_tests_locally/index.html +++ b/previews/PR386/developer_documentation/running_tests_locally/index.html @@ -1,2 +1,2 @@ -Running Tests Locally · Mooncake.jl

Running Tests Locally

Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.

There are two workflows for running tests, discussed below.

Main Testing Functionality

For all code in src, Mooncake's tests are organised as follows:

  1. Things that are required for most / all test suites are loaded up in test/front_matter.jl.
  2. The tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.

Thus, a workflow that I (Will) find works very well is the following:

  1. Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.
  2. start the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.
  3. using TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.
  4. Run the following: using Pkg; Pkg.activate("."); TestEnv.activate(); include("test/front_matter.jl"); to set up your environment.
  5. include whichever test file you want to run the tests from.
  6. Modify code, and re-include tests to check it has done was you need. Loop this until done.
  7. Make a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.

The purpose of this approach is to:

  1. Avoid restarting the REPL each time you make a change, and
  2. Run the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.

If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.

Extension and Integration Testing

Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.

These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.

+Running Tests Locally · Mooncake.jl

Running Tests Locally

Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.

There are two workflows for running tests, discussed below.

Main Testing Functionality

For all code in src, Mooncake's tests are organised as follows:

  1. Things that are required for most / all test suites are loaded up in test/front_matter.jl.
  2. The tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.

Thus, a workflow that I (Will) find works very well is the following:

  1. Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.
  2. start the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.
  3. using TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.
  4. Run the following: using Pkg; Pkg.activate("."); TestEnv.activate(); include("test/front_matter.jl"); to set up your environment.
  5. include whichever test file you want to run the tests from.
  6. Modify code, and re-include tests to check it has done was you need. Loop this until done.
  7. Make a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.

The purpose of this approach is to:

  1. Avoid restarting the REPL each time you make a change, and
  2. Run the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.

If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.

Extension and Integration Testing

Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.

These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.

diff --git a/previews/PR386/index.html b/previews/PR386/index.html index 308b8abe9..0d6bff850 100644 --- a/previews/PR386/index.html +++ b/previews/PR386/index.html @@ -1,2 +1,2 @@ -Mooncake.jl · Mooncake.jl

Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!

+Mooncake.jl · Mooncake.jl

Mooncake.jl

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the "Understanding Mooncake.jl" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.

Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!

diff --git a/previews/PR386/known_limitations/index.html b/previews/PR386/known_limitations/index.html index 7ef26f76f..d22425abf 100644 --- a/previews/PR386/known_limitations/index.html +++ b/previews/PR386/known_limitations/index.html @@ -38,4 +38,4 @@ Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0]) # output -(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.

+(4.0, (NoTangent(), [0.0, 1.0]))

The Solution

This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.

While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.

Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.

diff --git a/previews/PR386/objects.inv b/previews/PR386/objects.inv index ec1a9dc3e5a28500cd97a012ec6196a531eecd99..81f4968863bf5821ae88ae4d9a6fb96c550102a4 100644 GIT binary patch delta 6752 zcmV-m8lUC%G@~_;LjpB0kwrg$Hn!jAS1{NbyDh9Zaf^1ZQJ{4^adz%a9K)Wbxc7@7 zYbcKzjWpH6&TQfTzULfL5-Cyhk}OSdfgqmY@Nq~U9-cS-aK^u|B%PhjB9~dO8T$}L zq3e9+FaHQe8T=Y2sT0op@ZMg?UrzimjiXuSrhXKvZMUd?!Y0bMrU9IP;694|biVLi zd**u{k9nB-&X(h?X??=Z%OBZT|FH?0I19L)tdf*3wisQheWEVr-xGENKe1c-X_TdY z;HN%MPSPj}65ESn8=EHD5KWa%*c1yG6i}~;(%+2VWeCrVo4YE1KVf713{cK^wS#+{ z6wJ+DJT}2puO~aiG6rRT?8EPGlhKSn@E}_9*oGdm1uVDNiZ}~HY@G5WP3$0Yogi4L zBkk1i1lECy*%XV}dr{mtdRb0OyUn{6CmxNQIfmNpIzXOG9B6JT+Y#;`tHP1VqhPA(O)aZxz>97d0Sp&j@OU$oSEul;z! z{s)WL`?6^C`<#c??L5kYnf1R>X8o3Mi?E31_Y3tGICnf~#NxlL@8anDONKR=sKKBH zgBl9%|1v7j)&jJ(5N+3vt0Vp2zA>TwISn3IuwB-<_1{_GJJw|umPE7cd@~_|AoJNySI^iUlRi2< zhe6=~HN)UJgh~u+REXI5m=Pc&yvyL|#{L6!Y<8K#z z*0i_+HV#s~1nBVV*TDNI`i$F~Ev*zz-bJ)n`T>t$PHrv$u$S#7BRgl=l0oOJ+WAOd z`f+*J>^-7?@xH-vo!C!ckLVQ3N{H??8no+xW#8iX>V3n|Vd4XdS|7t;1++w5O7QWM zpZIrxKW_m!-88X2dc_H|-aa@%w*3jO(7=31hC{OYkW3EAK6=I3W95*e{N}XA`YpV1 ziA?t}z+<|JZ=j19F1~9g*^Y=-<|x}UY!96+W1irD0S7#O7dtZ_KJVFW*%K>zDxj4F z;>`N)aqgs&O8F`L?SLv(208lv+yWAX0~M5thC<4u)&+WW#;}^8RX+ax1sIi;^;_uw znemV(i4}R)bhYGkzfYa;9;aoUSeIibb>Qy}{@rd4+KO#6%F*ei90xzqNP@{(v9c+P zJcdnwiDAf_GGHIEI{Xy?j)*SunRQdh9($i#*_9MX-Z-f{7t7hISL#QIC>MSV7q}B! zH=H)wtwL;MLGo6tY&E)$;nSZPoL;y{iRZiFNtDI49=8!`VGON5;sJ~PM!&<&1`8}X zk~2$C>Nbm__lRgM`xHOq%+DDU#iRW0d)RV+b{qH`Yjo=q|8@)92keHKtKD^i_XLvd(9JM1XNkB1u&@fOmG@i?~6t7bbfd0Xqg%#gQV3Cu^QT3EU(PKHO4VGS97l zam73oB5Zvy-~oYRIa0DIz=9H4lgXMRyno?7hFcvEF5ZRrv0xW-{u)P4a2q69NT7Oe|S#8}XOf7vg? zen;0?D)|M({||nWt}cQo$>NsNfPN*zfyYAacO=8T?yZ3LE&F#5Pm0q&#&2Ei&N`3GQCq;xVt zv@RNVBiFNsq!9k}{!0+{6)KB=F-WRJ%tMWciW`W7jwMBhOJclQV<9R=G`t3UbeK8@ za&_Sh+IO^9b#ZoH=-f!eg#VvSpq(q4w|p?8NITE9ZzR#e|G$hp+^`of+RUH2I68d| z|6nD4OX&8RJAO#_ZcQ7s_U5A2OlwWrYhu+&>)|;rxBhVx!SY$Jzug^wCk?$Fo&Hny zcXaw+r0*s)EGjTD9DDukXza&aYE~^psv>H(o0?<;z3W+@mT&_RsSE8T8Mfe|LK7ZVO>tDY*I=#^Q5S=nCU|6u`K%aJepikEy z=+kdUrI~OWL*qB9!Msoz{2Qce+mSN*AHiA9pRx_a&MLbALhXBhD?3I8%|^X3>{X$*}q_2NGX7u6vQ~BY?ANpRWhG`+T5Tvkf`T9)@1j$R}BxPa#2Wnx(S2SfR z`Thdfy}Ua)bx~VRU-|Om4=pw+h@#Kg()MGCco$L3Ulw+M#`8O9;THBqt};LkwZ+QHw&PY54b94Sm@V$$*0R$G2uI|8&I@t^_X!xg zD5mp#%ydvMF5mS};?@GKQak~8K?KAi)+W$q0bS@POSuiqnhiTO<1SB_ zgP`_~a2vpXR9C73YCHUT@BDiug;-*EcARyJ#=xCT&C2+6(e*wmNJq#4x*Ty6+Qq8_ z!-fXBdpVEyb1-QO{*v;ZP^AUmbb05&DROHH#7w5`&zpn1=; zuXd6+M1cKSh2ZHBUh&6C2)3B4wc(N5+;jDQ&9y^z zBn+K@w9AegN9)xvMF&SeysCAf^gjs!nf)1gO}+i%@+AS)_hzAj&kqydMFO0!{cy)% z4LGfj11b4%!8GN-MWGxhD&`*)`HrgLdgMBTi=G#z0=%LCxTq-r)=flfAjK|CFRo{+Q;%tqsx~NOg;U#0GRutD<^kX(r=3OM5=P1p(DyplK{9gvt(3gd za-fm9t2&vxh{nKc9HR${6W`a%ubv8OcQu5bSeu4|TvyEmPk=5_j;7W0>y?6a1?}5^ zZJHXVl_PWS7q~N<*<@4kJ&+9mv!G7E(^Q&H_Gr?}Yc6bI&i~ z4$5La9J73v`N7PNkr!(;%j(u$-Be|N13}6+@9mdmJb+nA0Rc7(foQQq%Pa`)fNLsE zJC)rhT2k;fyK_C%x<!;1HxG_}{UniI z#?gY}g!(>XsPajLZ>H_}JecXGTrVjY8F{e@-yiADnO9K02~+=%Ka}QPZ)fs<-GLg* zFelZD>m@a=rnv638CEM-OBHtkGTK(GzB|rc)mMz)MQJqg(4nhuqI>Qz%hgy?{_XvHeMzx5ulw|)?#c*f)6 zD0;;wLGPV_itw+fSRB5;ykjPRiuyL{>(fxq9)VLkeO|4kB<&zXs1?NbqMPcOqqrZBGpSrIF@KPv#;nS; zUap7M-9d%sYipfl?fOX5B;0s{lqRTTnA=_uY>k%2KdQ;_v z>slH|f$as+qkSJc%lUzALlc0{wi&SR52QLG-@9Z9pBHxMEO;`YS!4rDt!*Qd-CA~n zdPx+g!Z$3b84HHdXs93=Z3L}#8imXM(n*pvV)NH8F`H-FBsPCF(sIZj$^OxoS^FFP zl%n$sDj<75twAE$XWg}b8fj3S_m`!ry%dZ3=J<1vO>WZ&&n6CEt7Ed2?mQ_?Z%>L- zKqfD$QUN^YVwFMIoS<)Vajsj^3_i`V41KSI0Y*i;wN2nbOCLyo-34;&yQy&PyS+XC zg#KjJ)3w$+ascPZs&6FDAb0-D=g@q}HaJl3?H;AsXxqk?Nw=!~*68SmS9M@Ij4O-z zB6{HPYleQ9e0I+d!~)k0S*J`Z;vf3xDhKGm4v7X}enOrT`$J5x6!;0xB^R^jx@6XG zSBtwS=$-HuK&rBTOtW)(!C?zwS^0~h>MhI3tqfvaCN#VKW*jagw?M4_f4$NNF=}5!t?3IP_nQ0Q`z{c z1%^{R`wfMYNfAcwio#_HP|B#>!Tijsz~UO0i> z#W2CtDRTj$BsB5k!@3v*OwV``{HkXA8k~Ku1JfP*41GJnqq9Q4VHkiIlHBr!o+iS0 zE-t$M)#^>4ArmQq)9F$r{cqV*sPI zB_{V1!URNrw(BH?tl7KTrU4L4aL^{D0mWy<+|&KK1jAdf1lNP{yj1{ICoP4yt+D~v z-Lbf4ha01dO|eSe!Fbbb=QmSbtli*OT~I|_s2^`>WGTyw*A`Ul6dt~0X0{{GcD?%p zGlw?;_(0+HBp~Z3?PAV(wsi7`R>&ZfC$! z)C2Gh2A}mY`Mn!2n&+@pz}M!m{Wji8h}Od-4=M3XSVQ&W;Fre%|IgjufmZt^t@Z~y z8o4#pbTs_>(V)K-uXZBgzKMkIHZ5gI#lsVQZ=`T%9>nPieF)Gf<^%&K9H;@&nsY@3 z0mg5CISD4R3SCxKOB`6MLwhsiMcK5Ot4DF!;p>Zz+B`o{Zv*QUp~;#_Um&+=+8 z{OZNgLIh>@t~cFt=t>f*QXg7#+*BP?`d5x-syfu@OsPAZ>>rjvZ!mC!Lu~$GLbh6e znH_Y1d#eWGraUWMIG%WZh;kS1Lt!{4_a#kmd=o`zOB*7fSc6ctmPi++nplXi7B@r7 z=;JY`RP`6>mjR;jLpR7~oNh|;7dNLDsk>{FgBh;E@F`sy=3r1DH8@}^PG+>P9j9;c zL^8CxpTyxPi|ZivuP!zi5@~@->g0!is3V1Q@-%Q_Wuuu6>6|W z3J2E-VFDB=ZWlh zD)7Wl9-QI}sSZU!8C@ajmkJWBaLix%vqkTkxR@)at3JSfV@Q!FU zrR*_fJ!+2o_aW*k>x$E3Ml7*#TCJK#bS`x>f~eSNCbH+ii{U zdxo$*QB>^op)X@Dy{;oYQp}c6=Rd1{+Wt zmB@@^|L6gLI;l;8aQOJha4x|nSa%D96OWX%Q=FMNs6nHIngQy6G(^adgXsgo+xO`( zCSDStw9!VdphMps@*%c_vRjK&Pi$80@rFv0@V*z`44;Z+^|ar6lcpeggf`M~BSHNO zwmp=X|FzI{8=AK%6B?a=f&Y_GyjgdEnb8bTwpbMM_OKsgX9-Qef)G~&yu(JhqPh7n ztUf9v&tUcOV|NdK@8V7pg{fNA16B4ur=B-|O`@hD8ZaMT4wF&!(_MZ#4AUqgw&jz% z2Xb*WLe}Y}mF2HS*q(SULeBn{s`js%FhYYyAFA78qWF04a7qdak}Q~zQhqi1I~}eQ zRKqlk(R7NtQ~$nD)v$4-VA}k%5QJ3ZZf0m$SjdJabo(cNs|2{jT4EP1=39^*ag{DH;m+ezl?PlOLlHP^w`t&q(r#xK7zdzbLU@apZ@Vlw3A6!;j&{Z)8q zZGg9b*DUyP;fD^YM~z0d-jfkj;4wxyx6e@G0hw8?$x||1f|3Tv)HEe6WLl9%PRywvzyOm_?dEjj95A6@yKi0DymMNSLXB7 zh7*y0*4cC>lrQq#xBfjvD*&1hx(;SrjoeqHj5`oLoQ+p&!&#l`;`Ve+^vHCG6I+75 z_Q)D?@`<%Af7t%;PTL<|;T6zLI!QYV^YBO+;$#c}-7)^t9R4{>^ym5ss$rm=MuBb@ z1llqNRK#l7cO1i<>{12Y7er(q_48{<%R<#P62>ik(f`+>XW5TWJBaffiL?A@Cmto9qzAs0uip+MqMKulxNM*v{#^!sQ~@#68-r@sgPOAom82Zi_3TYCptI7uPW&`XvL%!~n&v5JH z^gF!wHSOftFWJ@pNb;db-qpKJMC5RP>O^kWUdBey8c}=+Ieq|!35&A&|PW=_8(Ek%831ttB%RID?-7YXco6Fr*RJ@6B_#xU= zUVWRlADftaqr!asCJXUW@Br C8DHiA delta 6707 zcmV-38qDRRHTE=+Ljp87kwrg$wzA*%S1>p=&S_!AiCeVi8UN{KZx2E+8J1>7^WBtb_Y~n27cCt!RzSyF7rS^$BnSW2%4gADz>8DYa z`hlPNJUL0DC`fECifwF~Xk9c_I$={RU{F9%69wLk-(`r;jGIA~zn`!%eg-UOyb9nR z2L)rZ7mrOi)$2)sSjM1#jD7h1Z8Dni2OdOA9@_vRTflUSrHHdI#KtL4(!>rT*9n4^ z+S5)APhcLXm`$;my%)uey_e-QwcGStX@Y~51!6TC`+`+^bmCb;lWxy=;@^kcbZBaK z!Y*a0h06vwU}mf2mket#QG-DZ z1~nAg|7BE=tp#LjA=$3&S10>V7aVw>%X(WcdW}SESY8rd^00~F3t(n z0GA&=6Fio~8!3E$Mu!aztw^QN^u@Yx0yhh?>b&pH;c==$Cc9wR1ruGcyo-@?23TXm@zS>YHAl<`#}6&! zP)KX}@;N|TE4}C4cpr3%fSCl$G{B?p{uMobXU!s>SQ_ttu+GQUV-$a`?y)@veRO&b zoxuNVy1{)2l^E8jFtPJ7BSc1cmciDI{RaYUeazCp5BYX;V`P^wHFWHjK$!$OD%iI* zEslVVom3A2+Wh)4@IHz@`^_ALzvT7zQiY(Bfo(kDvU+zk@}53l!p}iS^Md$lQAS-~`$B+pNNf z@-7*6$?9D)*(Lkv6&rw+U5@hO%NpxP;l_zE9fp99={URr5HXxS*G{q>C##H6wrAKL zfGuO5;4TAVdKWu0A3pC1wuHnANCmcXm!4VQJBH5x^pp|t)LPdp(kAUF`S}KY~65LXt#>7Q4Yvkv9i_ZI)+bwX0UtVG$X<5 zh9^-L*C1}AS3(~eJQBnTc%$Fpc!3F)V!@du*kzkZ(IFyPOPJz^Tmm^`qIi@aKo47v zZiA3vjc$DsqHaM1fYmUImbp$_k^g_x-IqLowk+&>ge<*9ayurd<%GU zMBkdmPM84t3&T0G^JsIZOouwk2c8HK_5XMrS?9AqBA~b?y`ZX~z#oh1{}+g>P3MTl*pP))*Rva3+FL^ zoa(r7@hrTLh0K-{3id}F{6`o|DRj#p4i)+h0DzK|u zq@wWMMhkA~KssS72)ELAG9t0of`OQ(dF>XcTB!_&hXo4JJ2)(3C)n>CVW9X(0HaCO zGP&y+BH60vuTbbh#%%%FTOUC#37pk`E<;z~Fgh0i4$I-J+5bu4-_KJJYXcNLdgTV| zyKd$Uc{_&fYMkOjCe{!8eHTP`0P>R)`_5gBl#DB$H^7e*mP&a6>HmYDq^pY{O0u}+FaWL;IB;8ksBlM8 z{FIw5%~|(#MEcD(OCce!5&|*j90Ce-7VvLi;wi8tY1qmCWI^gLgDuS*jl)xWU|J6> z?_uF`Sw_Ejp2K1i!X%0sh+N*hRn zjwMxxOA@?VQz2R>G`a?QbeP(I2TFC}4BB_JR&{lDUKqYe#f1N#O`)AD=C8anqe?r^ zO<<(a!vDXFJY29BFWQWsKpdUEhJP>6KQLTYm33zg}7Er+_pFi%t^siqX zonGiL1W<+r3=7t5=+llj^y&HyefsUF^wn+S?fXXcS{EjRe}i^yJ5opgBRI?DQ?`NG zStSTy)V{Z}V-(PAv>U^!Wk5WCorkMq2LC--4Py7ej4wDm@;JBUS`SMz6(Cvu)^-=7jnc!gMn7ijU0!oS2o zr?A@M(^W*qYWNC&|DK-M2JdMQu5J<;#yhv{E_F^B(;$}5o0-j%YZf17Bef`j$3gw^zGVV zwzz{+%T6O;9FhAuFUT3(ClKtSnD+BA(^0*+eAhpVTZ^>H@dV-p6A+VF8$g>0bOBD5 zavOv-8&+z@U7jRSoai%G)PM5^QSF`JHlV4_R0Y*`^!48P_ev_UB=GDw>J*)UJG+{d z{^_9Wy;snGj!*)0InpMyi&qCm4Gs7`8yUQLL$sZ#nnRNGCD zpV(O#bH|;d@xX?2Jc~m+ENPOuzM;Mp095zqou! zNcFvcSZLt$!^C%y0q1i++%{N)PV4PJMn0S{%{XvT7zc`q`3FV6qiQ%Gxy|6B=ZUEx zuNVL>Y6gIH1JN2twM(-L>^3mn{1z>@w7v~sc2jZN&g!;5rMYlwgEj|R-6FApoYShR zUYVN7xGvJ3k8gp2Ap?JUdQxOy!PNVn!1{=PjCy=+QYVc9ZaF?lf;mxEM%rc>q9A~N z#Ys>lN2xOOJ&$#ij8cCqBd_!vXl(AP&gL$nF~}On7=hx%_x19tr-IvE4PhtNhM^GG zRWrjAs7sQgVKx1FrBGeL`*s_qCTZow-1`CU%w{(ERD2I~1E4IZ6Yw+&Uuc?jWmhSG z4CQw2#C(QQ$hG6g$Y7MrAye!*1Udu0ZX5nK$sVivjf0jlArl zol5#!ICsKXz-RBB@IG_y`6Zk|Sl3|Bf~lK z3d%R(>i_YF(%tLr%)UD?V;RPzdU3s^Ce;+joi@U1IdLjIRdlB`jizkZF` zJkln$`K!^EL;gtqkG|a6-{`v(gI~}9+52t{5-C3Gq1DKP>axEqRsE${)VIW+qihPB zMtn9&_*xy4uXN`@X?A;mGMoZ2c~O-M;64|#49fN#S}I$5GR}ATAZGe1(x$DN&$UFp zvKJ?YjoNNc3a=zcYGKHOx=dg2$KLJZHMAPp*SU2|n+A7Qmt6lXU zn^+%_NPe+JcOgMEaT!1PGP5k(<4hDgL(=#3fzpB~31{a^}$aE(@ z18gUFbXFKQ3>}a_l6&4T(nL7V#X;A1mG)PaqcW|)-RgAif9r(4H z$z;d-Gg|ZNA#*KbvW9cZ7|>|#iOJ)HFaeS6I!U2x_HMRmKm-#Uv`J||_p@T|>2X~` z;4MUg>%nAy-YSA>kd`9bR>gqp?nGR3z>P7)rkEw~V4`WZi<_xV)^7N#A*kYBs2^|X zWGTywCl6HW6cN7UX0{{GcD?%pbB8w}_)6 z?A_UI2&z9~?qGTtI9?NOXCPA4FW?&-KI`w~_wKxZXdc5>7rr)M+iw%Cgk(KT@{p3w zgf~>*4t}{E@c%sg9eA~0@@jvut&w{}OleiE1Yn?weTnZo^WhR6IP<_d<$r z=0Sq4Fopn~Vooq%!hsqRttD495a9hSC&5IPyb>u;8aFS?PseP`!-Sv5v9nS?qx@GG z<)GGob?@hFWh|y0Q#XEI8cOJEIkMXKQgo~u<7C_GHnH2jm)g0gMi#~>g$@RcWLu9> zYZ|IllYX-MP+Dt0>Of3X z4>*~P{%Qt1wNi0*V4fD;HeE_x8tv~oY;7Tbl1D#v=cKSoDxZWD{xBIP-)9Y)_ zSv_?VMBj9O@7nYjSDc#-`UvQnQ9I- z22<(*C;MN^pf?z}As{yYFd<*9%nmxhxm81PQ=XMB0#7_YM7;~wp>Uj2_>!hLzKJ4# z^ra1Pp;*IE^_IvGrJ7oZxE3cv%iG6ePO0iI(k}zt#t+>fn{hfRDPG*1UZftbO#x;& z3&W>$X&8gSfYcCxtpu6Tx^|qt#S_WU>Twc>V=S(t*uS~B^kFNcy+u9(*9qb9rd!{3 znU~w9Bkh2vd^@!JUk2}SRvur2rmt3iF)7F2VNz-@aPVC@c_k^)F9bYbQ9&nu^57I- z$POk7%BupReyQzag<~d*dzaRYLsPE}5MDJpLZoYj(pIMaRJ~wOKr^WS)Q~C{)%IPo zq?&)AV#KGia?xrj74o0S2VSxBiywPlsM>&x#(sfXZCwoc6!eP1)!*zF*kj0l`V}zl z--l>;t6Kq&8Oe7d0<-E90KJ)B`Rt#gG@-0|_oT=tgGUWcfE@jn>_FSxxdR0Rul8dn!#idUPRzJZ3jjU~d57nAIB( z5q?;o&$SVUb);)maCG&!#JcZ)&3FgPaP>+Q)c~lBRP9j|t-({XKdp_2Bs7HaN;(Ru zC;DXQpmp%1mJ*s(s%gK=Va8t-t|PbZ27q{@8iX=TT_sn@YhcWDbt;5Y=$8Tc^m~n)UUR*TAY`V;Z4@e9b)@j0Zx2sIzp^0|( zB9jC)=uMAiggU+EVaTKNf$;5ncNmlB2vpkmB31ye&06zX(vX5Dl2qEXT=c`spsm7>3#xw{GRLsRwd!HDcE3p_S#C|7}mC{FkG@rKBg7$`n}z&f0ff-brhLP)8<8e>)tH;|=a7{(h>O2|t8@hsJ=zsSL~@n<0@ zsVLpd(6O+P4NvIwPgV&?iM8Y$+Pz)Dc8o7H=-m$chC&VJZgRjCFDNIOt3UHRFxOm% zTenI!r+ZrYk=eV3C%RWp#}{-!A!|!lSMO`W)`ahFru#{^S;#h8)rPTcZ(&GlZRx%xLA!s48(vA9`{rPK zQ(kmcX+dFscW^buFapyJiS2Y2>0#|yL$TLN(T9=(V}m&TOiy&=4^t<<3p1=3N;Jc` z`!%F}S|e_~b`-0s>7r+Fs;(_V2G-22OeqsQ}EF99TX8_3Qt$+yJ%W zXUP>HpMYvUS`EJ(R?ZrA1Z$7h#v0F>wPwnb^taT150aOZQf>1qviE2OL`iP($ZHt~ zszIJt=Ka(Zi+(vosJSWbvmRmIxKp-|+rDfcifv<-3Xz)Ev?r7=^6jqvBSk9+nh3g% zW?PHgS7eMk5IdZWM{C1go$BiLbj|IN=@2`%M1Ad+HRM$jb6ftf{pOvv-@GC!pqrF{ zb{6J;w*uvDU*o;U9j`r_Z#@ok>2du4)!UDqUVhx}?qkcVkK(R|eaG>dlUZS+-Vre#(OBcN;&&Zb^bJ({lH!G6{fPATe{CMM9woE!}^`3Dvr+)!cvKIJiJ(z zY9M9@4q7yQg&=WdN~Qu3{TG-K$NCP5Av$E6`@4CJawlC7chJ5 zDb3=LmO|wjH&i@?SiK&*oL`jqb?D(fba~o{YX9=I;pGIHsG`CLKSZb6Mf&Aw5Y}pZ zO+ikM0ZC0jYMbk^NXRDI3}#bo#-_bpkvu3EewgSl`$;?1TA{5f(_LJ?>pzFL8GqG( zcd^M3Wsu`wBrPU2TBbm;2|uuxSQNCL^D&k(kKZ1TR49!b9;(VC-x`yrH#}K2-Tu`o zs66DuRrv_F&bz+DQ&!VXUbK>+_9v2(VfXGT5*eL3F=uFxzNpsCp`t>Y5b7>8bXvxp z^}TR5xz_3iKtq&`ygoD3_1G$+yM@<(C^v07v6$k-N8WJKi-tgJfn^c`UHdc2bN}{> z%a_!s_XwJ`DG3%zN`gf-CBee|`Rs$Aq$_!0KY6KsJGj&Y3qxvx#k$l43z?drqcN(I z#X^3vBVcravuZ=rZL<}dacNE7#S=V-a38RO%rjgGl$lCA?PlKfP>kDzO~IOf+jC+y z)(TX)89&&I6?N=74Hu#-Z=|n@sO7Irdjr4n5oj=VyyV zyo!o+|A&8by^8F>*Vd)zfZ59JM?amvP?aAsn_8T`+`6~7=bunB8OEGg!& const x = Ref(1.0);\n\njulia> function foo(y::Float64)\n x[] = y\n return x[]\n end\nfoo (generic function with 1 method)","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"x is a global variable (if you refer to it in your code, it appears as a GlobalRef in the AST or lowered code). For some technical reasons that are beyond the scope of this section, this package cannot propagate gradient information through x. foo is the identity function, so it should have gradient 1.0. However, if you differentiate this example, you'll see:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> rule = Mooncake.build_rrule(foo, 2.0);\n\njulia> Mooncake.value_and_gradient!!(rule, foo, 2.0)\n(2.0, (NoTangent(), 0.0))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Observe that while it has correctly computed the identity function, the gradient is zero.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The takehome: do not attempt to differentiate functions which modify global state. Uses of globals which does not involve mutating them is fine though.","category":"page"},{"location":"known_limitations/#Circular-References","page":"Known Limitations","title":"Circular References","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"To a large extent, Mooncake.jl does not presently support circular references in an automatic fashion. It is generally possible to hand-write solutions, so we explain some of the problems here, and the general approach to resolving them.","category":"page"},{"location":"known_limitations/#Tangent-Types","page":"Known Limitations","title":"Tangent Types","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Suppose that you have a type such as:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct A\n x::Float64\n a::A\n function A(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is a fairly canonical example of a self-referential type. There are a couple of things which will not work with it out-of-the-box. tangent_type(A) will produce a stack overflow error. To see this, note that it will in effect try to produce a tangent of type Tangent{Tuple{tangent_type(A)}} – the circular dependency on the tangent_type function causes real problems here.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In order to resolve this, you need to produce a tangent type by hand. You might go with something like","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct TangentForA\n x::Float64 # tangent type for Float64 is Float64\n a::TangentForA\n function TangentForA(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The point here is that you can manually resolve the circular dependency using a data structure which mimics the primal type. You will, however, need to implement similar methods for zero_tangent, randn_tangent, etc, and presumably need to implement additional getfield and setfield rules which are specific to this type.","category":"page"},{"location":"known_limitations/#Circular-References-in-General","page":"Known Limitations","title":"Circular References in General","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Consider a type of the form","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct Foo\n x\n Foo() = new()\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In this instance, tangent_type will work fine because Foo does not directly reference itself in its definition. Moreover, general uses of Foo will be fine.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"However, it's possible to construct an instance of Foo with a circular reference:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"f = Foo()\nf.x = f","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is actually fine provided we never attempt to call zero_tangent / randn_tangent / similar functionality on f once we've set its x field to itself. If we attempt to call such a function, we'll find ourselves with a stack overflow.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution This is a little tricker to handle. You could specialise zero_tangent etc for Foo, but this is something of a pain. Fortunately, it seems to be incredibly rare that this is ever a problem in practice. If we gain evidence that this is often a problem in practice, we'll look into supporting zero_tangent etc automatically for this case.","category":"page"},{"location":"known_limitations/#Tangent-Generation-and-Pointers","page":"Known Limitations","title":"Tangent Generation and Pointers","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In many use cases, a pointer provides the address of the start of a block of memory which has been allocated to e.g. store an array. However, we cannot get any of this context from the pointer itself – by just looking at a pointer, I cannot know whether its purpose is to refer to the start of a large block of memory, some proportion of the way through a block of memory, or even to keep track of a single address.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Recall that the tangent to a pointer is another pointer:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> Mooncake.tangent_type(Ptr{Float64})\nPtr{Float64}","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Plainly I cannot implement a method of zero_tangent for Ptr{Float64} because I don't know how much memory to allocate.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is, however, fine if a pointer appears half way through a function, having been derived from another data structure. e.g.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"function foo(x::Vector{Float64})\n p = pointer(x, 2)\n return unsafe_load(p)\nend\n\nrule = build_rrule(Tuple{typeof(foo), Vector{Float64}})\nMooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])\n\n# output\n(4.0, (NoTangent(), [0.0, 1.0]))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Algorithmic-Differentiation","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This section introduces the mathematics behind AD. Even if you have worked with AD before, we recommend reading in order to acclimatise yourself to the perspective that Mooncake.jl takes on the subject.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Derivatives","page":"Algorithmic Differentiation","title":"Derivatives","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"A foundation on which all of AD is built the the derivate – we require a fairly general definition of it, which we build up to here.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Scalar-to-Scalar Functions","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider first f RR to RR, which we require to be differentiable at x in RR. Its derivative at x is usually thought of as the scalar alpha in RR such that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = alpha textdx ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Loosely speaking, by this notation we mean that for arbitrary small changes textd x in the input to f, the change in the output textd f is alpha textdx. We refer readers to the first few minutes of the first lecture mentioned before for a more careful explanation.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Vector-to-Vector Functions","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The generalisation of this to Euclidean space should be familiar: if f RR^P to RR^Q is differentiable at a point x in RR^P, then the derivative of f at x is given by the Jacobian matrix at x, denoted Jx in RR^Q times P, such that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = Jx textdx ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is possible to stop here, as all the functions we shall need to consider can in principle be written as functions on some subset RR^P.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"However, when we consider differentiating computer programmes, we will have to deal with complicated nested data structures, e.g. structs inside Tuples inside Vectors etc. While all of these data structures can be mapped onto a flat vector in order to make sense of the Jacobian of a computer programme, this becomes very inconvenient very quickly. To see the problem, consider the Julia function whose input is of type Tuple{Tuple{Float64, Vector{Float64}}, Vector{Float64}, Float64} and whose output is of type Tuple{Vector{Float64}, Float64}. What kind of object might be use to represent the derivative of a function mapping between these two spaces? We certainly can treat these as structured \"view\" into a \"flat\" Vector{Float64}s, and then define a Jacobian, but actually finding this mapping is a tedious exercise, even if it quite obviously exists.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In fact, a more general formulation of the derivative is used all the time in the context of AD – the matrix calculus discussed by [1] and [2] (to name a couple) make use of a generalised form of the derivative in order to work with functions which map to and from matrices (albeit there are slight differences in naming conventions from text to text), without needing to \"flatten\" them into vectors in order to make sense of them.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In general, it will be much easier to avoid \"flattening\" operations wherever possible. In order to do so, we now introduce a generalised notion of the derivative.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Functions Between More General Spaces","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to avoid the difficulties described above, we consider functions f mathcalX to mathcalY, where mathcalX and mathcalY are finite dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars). This definition includes functions to / from RR, RR^D, but also real-valued matrices, and any other \"container\" for collections of real numbers. Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For such spaces, the derivative of f at x in mathcalX is the linear operator (read: linear function) D f x mathcalX to mathcalY satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = D f x (textd x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The purpose of this linear operator is to provide a linear approximation to f which is accurate for arguments which are very close to x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Please note that D f x is a single mathematical object, despite the fact that 3 separate symbols are used to denote it – D f x (dotx) denotes the application of the function D f x to argument dotx. Furthermore, the dot-notation (dotx) does not have anything to do with time-derivatives, it is simply common notation used in the AD literature to denote the arguments of derivatives.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So, instead of thinking of the derivative as a number or a matrix, we think about it as a function. We can express the previous notions of the derivative in this language.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the scalar case, rather than thinking of the derivative as being alpha, we think of it is a the linear operator D f x (dotx) = alpha dotx. Put differently, rather than thinking of the derivative as the slope of the tangent to f at x, think of it as the function decribing the tangent itself. Observe that up until now we had only considered inputs to D f x which were small (textd x) – here we extend it to the entire space mathcalX and denote inputs in this space dotx. Inputs dotx should be thoughts of as \"directions\", in the directional derivative sense (why this is true will be discussed later).","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Similarly, if mathcalX = RR^P and mathcalY = RR^Q then this operator can be specified in terms of the Jacobian matrix: D f x (dotx) = Jx dotx – brackets are used to emphasise that D f x is a function, and is being applied to dotx.[note_for_geometers]","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To reiterate, for the rest of this document, we define the derivative to be \"multiply by alpha\" or \"multiply by Jx\", rather than to be alpha or Jx. So whenever you see the word \"derivative\", you should think \"linear function\".","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The Chain Rule","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The chain rule is the result which makes AD work. Fortunately, it applies to this version of the derivative:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = g circ h implies D f x = (D g h(x)) circ (D h x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"By induction, this extends to a collection of N functions f_1 dots f_N:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = f_N circ dots circ f_1 implies D f x = (D f_N x_N) circ dots circ (D f_1 x_1)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where x_n+1 = f(x_n), and x_1 = x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"An aside: the definition of the Frechet Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This definition of the derivative has a name: the Frechet derivative. It is a generalisation of the Total Derivative. Formally, we say that a function f mathcalX to mathcalY is differentiable at a point x in mathcalX if there exists a linear operator D f x mathcalX to mathcalY (the derivative) satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"lim_textd h to 0 frac f(x + textd h) - f(x) + D f x (textd h) _mathcalY textdh _mathcalX = 0","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where cdot _mathcalX and cdot _mathcalY are the norms associated to Hilbert spaces mathcalX and mathcalY respectively. It is a good idea to consider what this looks like when mathcalX = mathcalY = RR and when mathcalX = mathcalY = RR^D. It is sometimes helpful to refer to this definition to e.g. verify the correctness of the derivative of a function – as with single-variable calculus, however, this is rare.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: what does Forwards-Mode AD compute?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"At this point we have enough machinery to discuss forwards-mode AD. Expressed in the language of linear operators and Hilbert spaces, the goal of forwards-mode AD is the following: given a function f which is differentiable at a point x, compute D f x (dotx) for a given vector dotx. If f RR^P to RR^Q, this is equivalent to computing Jx dotx, where Jx is the Jacobian of f at x. For the interested reader we provide a high-level explanation of how forwards-mode AD does this in How does Forwards-Mode AD work?.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: notation","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You may have noticed that we typically denote the argument to a derivative with a \"dot\" over it, e.g. dotx. This is something that we will do consistently, and we will use the same notation for the outputs of derivatives. Wherever you see a symbol with a \"dot\" over it, expect it to be an input or output of a derivative / forwards-mode AD.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Reverse-Mode-AD:-*what*-does-it-do?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: what does it do?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to explain what reverse-mode AD does, we first consider the \"vector-Jacobian product\" definition in Euclidean space which will be familiar to many readers. We then generalise.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in Euclidean space?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In this setting, the goal of reverse-mode AD is the following: given a function f RR^P to RR^Q which is differentiable at x in RR^P with Jacobian Jx at x, compute Jx^top bary for any bary in RR^Q. This is useful because we can obtain the gradient from this when Q = 1 by letting bary = 1.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Adjoint Operators","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to generalise this algorithm to work with linear operators, we must first generalise the idea of multiplying a vector by the transpose of the Jacobian. The relevant concept here is that of the adjoint operator. Specifically, the adjoint A^ast of linear operator A is the linear operator satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle A^ast bary dotx rangle = langle bary A dotx rangle","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where langle cdot cdot rangle denotes the inner-product. The relationship between the adjoint and matrix transpose is: if A (x) = J x for some matrix J, then A^ast (y) = J^top y.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Moreover, just as (A B)^top = B^top A^top when A and B are matrices, (A B)^ast = B^ast A^ast when A and B are linear operators. This result follows in short order from the definition of the adjoint operator – (and is a good exercise!)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in general?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Equipped with adjoints, we can express reverse-mode AD only in terms of linear operators, dispensing with the need to express everything in terms of Jacobians. The goal of reverse-mode AD is as follows: given a differentiable function f mathcalX to mathcalY, compute D f x^ast (bary) for some bary.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Notation: D f x^ast denotes the single mathematical object which is the adjoint of D f x. It is a linear function from mathcalY to mathcalX. We may occassionally write it as (D f x)^ast if there is some risk of confusion.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We will explain how reverse-mode AD goes about computing this after some worked examples.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Aside: Notation","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You will have noticed that arguments to adjoints have thus far always had a \"bar\" over them, e.g. bary. This notation is common in the AD literature and will be used throughout. Additionally, this \"bar\" notation will be used for the outputs of adjoints of derivatives. So wherever you see a symbol with a \"bar\" over it, think \"input or output of adjoint of derivative\".","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Some-Worked-Examples","page":"Algorithmic Differentiation","title":"Some Worked Examples","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now present some worked examples in order to prime intuition, and to introduce the important classes of problems that will be encountered when doing AD in the Julia language. We will put all of these problems in a single general framework later on.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#An-Example-with-Matrix-Calculus","page":"Algorithmic Differentiation","title":"An Example with Matrix Calculus","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We have introduced some mathematical abstraction in order to simplify the calculations involved in AD. To this end, we consider differentiating f(X) = X^top X. Results for this and similar operations are given by [1]. A similar operation, but which maps from matrices to RR is discussed in Lecture 4 part 2 of the MIT course mentioned previouly. Both [1] and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Following either resource will yield the derivative:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X (dotX) = dotX^top X + X^top dotX","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that this is indeed a linear operator (i.e. it is linear in its argument, dotX). (You can always plug it in to the definition of the Frechet derivative to confirm that it is indeed the derivative.)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to perform reverse-mode AD, we need to find the adjoint operator. Using the usual definition of the inner product between matrices,","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle X Y rangle = textrmtr (X^top Y)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"we can rearrange the inner product as follows:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle barY D f X (dotX) rangle = langle barY dotX^top X + X^top dotX rangle nonumber \n = textrmtr (barY^top dotX^top X) + textrmtr(barY^top X^top dotX) nonumber \n = textrmtr ( barY X^top^top dotX) + textrmtr( X barY^top dotX) nonumber \n = langle barY X^top + X barY dotX rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can read off the adjoint operator from the first argument to the inner product:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X^ast (barY) = barY X^top + X barY","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-of-a-Julia-function:-a-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a trivial example","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now turn to differentiating Julia functions (we use function to refer to the programming language construct, and function to refer to a more general mathematical concept). The way that Mooncake.jl handles immutable data is very similar to how Zygote / ChainRules do. For example, consider the Julia function","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64) = sin(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"If you've previously worked with ChainRules / Zygote, without thinking too hard about the formalisms we introduced previously (perhaps by considering a variety of partial derivatives) you can probably arrive at the following adjoint for the derivative of f:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> g * cos(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Implicitly, you have performed three steps:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"model f as a differentiable function,\ncompute its derivative, and\ncompute the adjoint of the derivative.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is helpful to work through this simple example in detail, as the steps involved apply more generally. The goal is to spell out the steps involved in detail, as this detail becomes helpful in more complicated examples. If at any point this exercise feels pedantic, we ask you to stick with it.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Obviously, we model the Julia function f as the function f RR to RR where","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x) = sin(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that, we've made (at least) two modelling assumptions here:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"a Float64 is modelled as a real number,\nthe Julia function sin is modelled as the usual mathematical function sin.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As promised we're being quite pedantic. While the first assumption is obvious and will remain true, we will shortly see examples where we have to work a bit harder to obtain a correspondence between a Julia function and a mathematical object.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical model, we can differentiate it:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = cos(x) dotx","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Given the derivative, we can find its adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle barf D f x(dotx) rangle = langle barf cos(x) dotx rangle = langle cos(x) barf dotx rangle","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"From here the adjoint can be read off from the first argument to the inner product:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (barf) = cos(x) barf","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-of-a-Julia-function:-a-slightly-less-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a slightly less trivial example","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now consider the Julia function","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Its adjoint is going to be something along the lines of","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> (g, (y[2] * g, y[1] * g))","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As before, we work through in detail.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"There are a couple of aspects of f which require thought:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"it has two arguments – we've only handled single argument functions previously, and\nthe second argument is a Tuple – we've not yet decided how to model this.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To this end, we define a mathematical notion of a tuple. A tuple is a collection of N elements, each of which is drawn from some set mathcalX_n. We denote by mathcalX = mathcalX_1 times dots times mathcalX_N the set of all N-tuples whose nth element is drawn from mathcalX_n. Provided that each mathcalX_n forms a finite Hilbert space, mathcalX forms a Hilbert space with","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"alpha x = (alpha x_1 dots alpha x_N),\nx + y = (x_1 + y_1 dots x_N + y_N), and\nlangle x y rangle = sum_n=1^N langle x_n y_n rangle.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can think of multi-argument functions as single-argument functions of a tuple, so a reasonable mathematical model for f might be a function f RR times RR times RR to RR, where","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x y) = x + y_1 y_2","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Note that while the function is written with two arguments, you should treat them as a single tuple, where we've assigned the name x to the first element, and y to the second.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical object, we can differentiate it:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y(dotx doty) = dotx + doty_1 y_2 + y_1 doty_2","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx y maps RR times RR times RR to RR, so D f x y^ast must map the other way. You should verify that the following follows quickly from the definition of the adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y^ast (barf) = (barf (barf y_2 barf y_1))","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-with-mutable-data","page":"Algorithmic Differentiation","title":"AD with mutable data","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the previous two examples there was an obvious mathematical model for the Julia function. Indeed this model was sufficiently obvious that it required little explanation. This is not always the case though, in particular, Julia functions which modify / mutate their inputs require a little more thought.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider the following Julia function:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This function squares each element of its input in-place, and returns the sum of the result. So what is an appropriate mathematical model for this function?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The trick is to distinguish between the state of x upon entry to / exit from f!. In particular, let phi_textf RR^N to RR^N times RR be given by","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"phi_textf(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where odot denotes the Hadamard / element-wise product (corresponds to line x .*= x in the above code). The point here is that the inputs to phi_textf are the inputs to x upon entry to f!, and the value returned from phi_textf is a tuple containing the both the inputs upon exit from f! and the value returned by f!.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The remaining steps are straightforward now that we have the model.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative of phi_textf is","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x(dotx) = (2 x odot dotx 2 sum_n=1^N x_n dotx_n)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The argument to the adjoint of the derivative must be a 2-tuple whose elements are drawn from RR^N times RR . Denote such a tuple as (bary_1 bary_2). Plugging this into an inner product with the derivative and rearranging yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle (bary_1 bary_2) D phi_textf x (dotx) rangle = langle (bary_1 bary_2) (2 x odot dotx 2 sum_n=1^N x_n dotx_n) rangle nonumber \n = langle bary_1 2 x odot dotx rangle + langle bary_2 2 sum_n=1^N x_n dotx_n rangle nonumber \n = langle 2x odot bary_1 dotx rangle + langle 2 bary_2 x dotx rangle nonumber \n = langle 2 (x odot bary_1 + bary_2 x) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So we can read off the adjoint to be","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x^ast (bary) = 2 (x odot bary_1 + bary_2 x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Reverse-Mode-AD:-*how*-does-it-do-it?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: how does it do it?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we know what it is that AD computes, we need a rough understanding of how it computes it.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In short: reverse-mode AD breaks down a \"complicated\" function f into the composition of a collection of \"simple\" functions f_1 dots f_N, applies the chain rule, and takes the adjoint.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Specifically, we assume that we can express any function f as f = f_N circ dots circ f_1, and that we can compute the adjoint of the derivative for each f_n. From this, we can obtain the adjoint of f by applying the chain rule to the derivatives and taking the adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x^ast = (D f_N x_N circ dots circ D f_1 x_1)^ast nonumber \n = D f_1 x_1^ast circ dots circ D f_N x_N^ast nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For example, suppose that f(x) = sin(cos(texttr(X^top X))). One option to compute its adjoint is to figure it out by hand directly (probably using the chain rule somewhere). Instead, we could notice that f = f_4 circ f_3 circ f_2 circ f_1 where f_4 = sin, f_3 = cos, f_2 = texttr and f_1(X) = X^top X. We could derive the adjoint for each of these functions (a fairly straightforward task), and then compute","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (bary) = (D f_1 x_1^ast circ D f_2 x_2^ast circ D f_3 x_3^ast circ D f_4 x_4^ast)(1)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"in order to obtain the gradient of f. Reverse-mode AD essentially just does this. Modern systems have hand-written adjoints for (hopefully!) all of the \"simple\" functions you may wish to build a function such as f from (often there are hundreds of these), and composes them to compute the adjoint of f. A sketch of a more generic algorithm is as follows.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-Pass:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"x_1 = x, n = 1\nconstruct D f_n x_n^ast\nlet x_n+1 = f_n (x_n)\nlet n = n + 1\nif n N + 1 then go to step 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Pass:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let barx_N+1 = bary\nlet n = n - 1\nlet barx_n = D f_n x_n^ast (barx_n+1)\nif n = 1 return barx_1 else go to step 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"How does this relate to vector-Jacobian products?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In Euclidean space, each derivative D f_n x_n(dotx_n) = J_nx_n dotx_n. Applying the chain rule to D f x and substituting this in yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx = J_Nx_N dots J_1x_1 ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Taking the transpose and multiplying from the left by bary yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx^top bary = Jx_1^top_1 dots Jx_N^top_N bary ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Comparing this with the expression in terms of adjoints and operators, we see that composition of adjoints of derivatives has been replaced with multiplying by transposed Jacobian matrices. This \"vector-Jacobian product\" expression is commonly used to explain AD, and is likely familiar to many readers.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Directional-Derivatives-and-Gradients","page":"Algorithmic Differentiation","title":"Directional Derivatives and Gradients","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now we turn to using reverse-mode AD to compute the gradient of a function. In short, given a function g mathcalX to RR with derivative D g x at x, its gradient is equal to D g x^ast (1). We explain why in this section.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative discussed here can be used to compute directional derivatives. Consider a function f mathcalX to RR with Frechet derivative D f x mathcalX to RR at x in mathcalX. Then D fx(dotx) returns the directional derivative in direction dotx.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Gradients are closely related to the adjoint of the derivative. Recall that the gradient of f at x is defined to be the vector nabla f (x) in mathcalX such that langle nabla f (x) dotx rangle gives the directional derivative of f at x in direction dotx. Having noted that D fx(dotx) is exactly this directional derivative, we can equivalently say that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx(dotx) = langle nabla f (x) dotx rangle ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The role of the adjoint is revealed when we consider f = mathcall circ g, where g mathcalX to mathcalY, mathcall(y) = langle bary y rangle, and bary in mathcalY is some fixed vector. Noting that D mathcall y(doty) = langle bary doty rangle, we apply the chain rule to obtain","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x (dotx) = (D mathcall g(x)) circ (D g x)(dotx) nonumber \n = langle bary D g x (dotx) rangle nonumber \n = langle D g x^ast (bary) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"from which we conclude that D g x^ast (bary) is the gradient of the composition l circ g at x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The above shows that if mathcalY = RR and g is the function we wish to compute the gradient of, we can simply set bary = 1 and compute D g x^ast (bary) to obtain the gradient of g at x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Summary","page":"Algorithmic Differentiation","title":"Summary","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Asides","page":"Algorithmic Differentiation","title":"Asides","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/#*How*-does-Forwards-Mode-AD-work?","page":"Algorithmic Differentiation","title":"How does Forwards-Mode AD work?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-mode AD achieves this by breaking down f into the composition f = f_N circ dots circ f_1, where each f_n is a simple function whose derivative (function) D f_n x_n we know for any given x_n. By the chain rule, we have that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = D f_N x_N circ dots circ D f_1 x_1 (dotx)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"which suggests the following algorithm:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let x_1 = x, dotx_1 = dotx, and n = 1\nlet dotx_n+1 = D f_n x_n (dotx_n)\nlet x_n+1 = f(x_n)\nlet n = n + 1\nif n = N+1 then return dotx_N+1, otherwise go to 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"When each function f_n maps between Euclidean spaces, the applications of derivatives D f_n x_n (dotx_n) are given by J_n dotx_n where J_n is the Jacobian of f_n at x_n.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).\n\n\n\nT. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).\n\n\n\n","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"[note_for_geometers]: in AD we only really need to discuss differentiatiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-Mode-Design","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Disclaimer: this document refers to an as-yet-unimplemented forwards-mode AD. This will disclaimer will be removed once it has been implemented.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of this document is to explain how forwards-mode AD in Mooncake.jl is implemented. It should do so to a sufficient level of depth to enable the interested reader to read to the forwards-mode AD code in Mooncake.jl and understand what is going on.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This document","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"specifies the semantics of a \"rule\" for forwards-mode AD,\nspecifies how to implement rules by-hand for primitives, and\nspecifies how to derive rules from IRCode algorithmically in general.\ndiscusses batched forwards-mode\nconcludes with some notable technical differences between our forwards-mode AD implementation details and reverse-mode AD implementation details.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-Rule-Interface","page":"Forwards-Mode Design","title":"Forwards-Rule Interface","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Loosely, a rule for a function simultaneously","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"performs same computation as the original function, and\ncomputes the Frechet derivative.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is best made concrete through a worked example. Consider a function call","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"z = f(x, y)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"where f itself may contain data / state which is modified by executing f. rule_for_f is some callable which claims to be a forwards-rule for f. For rule_for_f to be a valid forwards-rule for f, it must be applicable to Duals as follows:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"z_dz = rule_for_f(Dual(f, df), Dual(x, dx), Dual(y, dy))::Dual","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"where:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"rule_for_f is a callable. It might be written by-hand, or derived algorithmically.\ndf, dx, and dy are tangents for f, x, and y respectively. Before executing rule_for_f, they are inputs to the derivative of (f, x, y). After executing they are outputs of this derivative.\nz_dz is a Dual containing the primal and the component of the derivative of (f, x, y) to (df, dx, dy) associated to z.\nrunning rule_for_f leaves f, x, and y in the same state that running f does.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We refer readers to Algorithmic Differentiation to explain what we mean when we talk about the \"derivative\" above. We also discussed some worked examples shortly.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Note that rule_for_f is an as-yet-unspecified callable which we introduced purely to specify the interface that a forwards-rule must satisfy. In Hand-Written Rules and Derived Rules below, we introduce two concrete ways to produce rules for f.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Tangent-Types","page":"Forwards-Mode Design","title":"Tangent Types","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We will use the type system documented in Representing Gradients. This means that every primal type has a unique tangent type. Moreover, if a Dual is defined as follows:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"struct Dual{P, T}\n primal::P\n tangent::T\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"it must always hold that T = tangent_type(P).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Testing","page":"Forwards-Mode Design","title":"Testing","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Suppose that we have (somehow) produced a supposed forwards-rule. To check that it is correctly implemented, we must","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"all primal state after running the rule is approximately the same as all primal state after running the primal, and\nthe inner product between all tangents (both output and input) and a random tangent vector after running the rule is approximately the same as the estimate of the same quantity produced by finite differencing or reverse-mode AD.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We already have the functionality to do this in a very general way (see Mooncake.TestUtils.test_rule).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Hand-Written-Rules","page":"Forwards-Mode Design","title":"Hand-Written Rules","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Hand-written rules are implemented by writing methods of two functions: is_primitive and frule!!.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#is_primitive","page":"Forwards-Mode Design","title":"is_primitive","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"is_primitive(::Type{<:Union{MinimalForwardsCtx, DefaultForwardsCtx}}, signature::Type{<:Tuple}) should return true if AD must attempt to differentiate a call by passing the arguments to frule!!, and false otherwise. The Mooncake.@is_primitive macro can be used to implement this straightforwardly.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#frule!!","page":"Forwards-Mode Design","title":"frule!!","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Methods of frule!! do the actual differentiation, and must satisfy the Forwards-Rule Interface discussed above.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In what follows, we will refer to frule!!s for signatures. For example, the frule!! for signature Tuple{typeof(sin), Float64} is the rule which would differentiate calls like sin(5.0).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Simple-Scalar-Function","page":"Forwards-Mode Design","title":"Simple Scalar Function","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that for y = sin(x) we have that doty = cos(x) dotx. So the frule!! for signature Tuple{typeof(sin), Float64} is:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function frule!!(::Dual{typeof(sin)}, x::Dual{Float64})\n return Dual(sin(x.primal), cos(x.primal) * x.tangent)\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Pre-allocated-Matrix-Matrix-Multiply","page":"Forwards-Mode Design","title":"Pre-allocated Matrix-Matrix Multiply","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that for Z = X Y we have that dotZ = X dotY + dotX Y. So the frule!! for signature Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}} is:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function frule!!(\n ::Dual{typeof(LinearAlgebra.mul!)}, Z::Dual{P}, X::Dual{P}, Y::Dual{P}\n) where {P<:Matrix{Float64}}\n\n # Primal computation.\n mul!(Z.primal, X.primal, Y.primal)\n\n # Overwrite tangent of `z` to contain propagated tangent.\n mul!(Z.tangent, X.primal, Y.tangent)\n\n # Add the result of x.tangent * y.primal to `z.tangent`.\n mul!(Z.tangent, X.tangent, Y.primal, 0.0, 1.0) \n return Z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"(In practice we would probably implement a rule for a lower-level function like LinearAlgebra.BLAS.gemm!, rather than mul!).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Derived-Rules","page":"Forwards-Mode Design","title":"Derived Rules","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is the \"automatic\" / \"algorithmic\" bit of AD! This is the second way of producing concrete callable objects which satisfy the Forwards-Rule Interface discussed above. The object which we will ultimately construct is an instance Mooncake.DerivedFRule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Worked-Example:-Julia-Function","page":"Forwards-Mode Design","title":"Worked Example: Julia Function","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Before explaining how derived rules are produced algorithmically, we explain by way of example what a derived rule should look like if we work things through by hand.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"A derived rule for a function such as","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function f(x)\n y = g(x)\n z = h(x, y)\n return z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"should be something of the form","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function rule_for_f(::Dual{typeof(f)}, x::Dual)\n y = rule_for_g(zero_dual(g), x)\n z = rule_for_h(zero_dual(h), x, y)\n return z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Observe that the transformation is simply","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"replace all variables with Dual variables,\nreplace all constants (e.g. g and h) with constant Duals,\nreplace all calls with calls to rules.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In general, all control flow should be identical between primal and rule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Worked-Example:-IRCode","page":"Forwards-Mode Design","title":"Worked Example: IRCode","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The above example is expressed in terms of Julia code, but we will be operating on Julia Compiler.IRCode, so it is helpful to consider how the above example translates into this form. If we call f on a Float64, and suppose that g and h both return Float64s, the primal Compiler.IRCode will look something like the following:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})\n1-element Vector{Any}:\n2 1 ─ %1 = invoke Main.g(_2::Float64)::Float64\n3 │ %2 = invoke Main.h(_2::Float64, %1::Float64)::Float64\n4 └── return %2\n => Float64","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that _2 is the second argument, in this case the Float64, and %1 and %2 are SSAValues. Roughly speaking, the forwards-mode IR for the (ficiticious) function rule_for_f should look something like:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"julia> Base.code_ircode_by_type(Tuple{typeof(rule_for_f), Dual{typeof(f), NoTangent}, Dual{Float64, Float64}})\n1-element Vector{Any}:\n2 1 ─ %1 = invoke rule_for_g($(Dual(Main.g, NoTangent())), _3::Dual{Float64, Float64})::Dual{Float64, Float64}\n3 │ %2 = invoke rule_for_h($(Dual(Main.h, NoTangent())), _3::Dual{Float64, Float64}, %1::Dual{Float64, Float64})::Dual{Float64, Float64}\n4 └── return %2\n => Dual{Float64, Float64}","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Observe that:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"All Arguments have been incremented by 1. i.e. _2 has been replaced with _3. This corresponds to the fact that the arguments to the rule have all been shuffled along by one, and the rule itself is now the first argument.\nEverything has been turned into a Dual.\nConstants such as Dual(Main.g, NoTangent()) appear directly in the code (here as QuoteNodes).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"(In practice it might be that we actually construct the Dualed constants on the lines immediately preceding a call and rely on the compiler to optimise them back into the call directly).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Here, as before, we have not specified exactly what rule_for_f, rule_for_g, and rule_for_h are. This is intentional – they are just callables satisfying the Forwards-Rule Interface. In the following we show how to derive rule_for_f, and show how rule_for_g and rule_for_h might be methods of Mooncake.frule!!, or themselves derived rules.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Rule-Derivation-Outline","page":"Forwards-Mode Design","title":"Rule Derivation Outline","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Equipped with some intuition about what a derived rule ought to look like, we examine how we go about producing it algorithmically.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Rule derivation is implemented via the function Mooncake.build_frule. This function accepts as arguments a context and a signature / Base.MethodInstance / MistyClosure and, roughly speaking, does the following:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Look up the optimised Compiler.IRCode.\nApply a series of standardising transformations to the IRCode.\nTransform each statement according to a set of rules to produce a new IRCode.\nApply standard Julia optimisations to this new IRCode.\nPut this code inside a MistyClosure in order to produce a executable object.\nWrap this MistyClosure in a DerivedFRule to handle various bits of book-keeping around varargs.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In order:","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Looking-up-the-Compiler.IRCode.","page":"Forwards-Mode Design","title":"Looking up the Compiler.IRCode.","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is done using Mooncake.lookup_ir. This function has methods with will return the IRCode associated to:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"signatures (e.g. Tuple{typeof(f), Float64})\nBase.MethodInstances (relevant for :invoke expressions – see Statement Transformation below)\nMistyClosures.MistyClosure objects, which is essential when computing higher order derivatives and Hessians by applying Mooncake.jl to itself.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Standardisation","page":"Forwards-Mode Design","title":"Standardisation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We apply the following transformations to the Julia IR. They can all be found in ir_normalisation.jl:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Mooncake.foreigncall_to_call: convert Expr(:foreigncall, ...) expressions into Expr(:call, Mooncake._foreigncall_, ...) expressions.\nMooncake.new_to_call: convert Expr(:new, ...) expressions to Expr(:call, Mooncake._new_, ...) expressions.\nMooncake.splatnew_to_call: convert Expr(:splatnew, ...) expressions to Expr(:call, Mooncake._splat_new_...) expressions.\nMooncake.intrinsic_to_function: convert Expr(:call, ::IntrinsicFunction, ...) to calls to the corresponding function in Mooncake.IntrinsicsWrappers.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of converting Expr(:foreigncall...), Expr(:new, ...) and Expr(:splatnew, ...) into Expr(:call, ...)s is to enable us to differentiate such expressions by adding methods to frule!!(::Dual{typeof(Mooncake._foreigncall_)}), frule!!(::Dual{typeof(Mooncake._new_)}), and frule!!(::Dual{typeof(Mooncake._splat_new_)}), in exactly the same way that we would for any other regular Julia function.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of translating Expr(:call, ::IntrinsicFunction, ...) is to do with type stability – see the docstring for the Mooncake.IntrinsicsWrappers module for more info.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Statement-Transformation","page":"Forwards-Mode Design","title":"Statement Transformation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Each statment which can appear in the Julia IR is transformed by a method of Mooncake.make_fwds_ad_stmts. Consequently, this transformation phase simply corresponds to iterating through all of the expressions in the IRCode, applying Mooncake.make_fwd_ad_stmts to each to produce new IRCode. To understand how to modify IRCode and insert new instructions, see Oxinabox's Gist.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We provide here a high-level summary of the transformations for the most important Julia IR statements, and refer readers to the methods of Mooncake.make_fwds_ad_stmts for the definitive explanation of what transformation is applied, and the rationale for applying it. In particular there are quite a number more statements which can appear in Julia IR than those listed here and, for those we do list here, there are typically a few edge cases left out.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Expr(:invoke, method_instance, f, x...) and Expr(:call, f, x...)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":":call expressions correspond to dynamic dispatch, while :invoke expressions correspond to static dispatch. That is, if you see an :invoke expression, you know for sure that the compiler knows enough information about the types of f and x to prove exactly which specialisation of which method to call. This specialisation is method_instance. This typically happens when the compiler is able to prove the types of f and x. Conversely, a :call expression typically occurs when the compiler has not been able to deduce the exact types of f and x, and therefore not been able to figure out what to call. It therefore has to wait until runtime to figure out what to call, resulting in dynamic dispatch.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"As we saw earlier, the idea is to translate these kinds of expressions into something vaguely along the lines of","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Expr(:call, rule_for_f, f, x...)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"There are three cases to consider, in order of preference:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Primitives:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"If is_primitive returns true when applied to the signature constructed from the static types of f and x, then we simply replace the expression with Expr(:call, frule!!, f, x...), regardless whether we have an :invoke or :call expression. (Due to the Standardisation steps, it regularly happens that we see :call expressions in which we actually do know enough type information to do this, e.g. for Mooncake._new_ :call expressions).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Static Dispatch:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In the case of :invoke nodes we know for sure at rule compilation time what rule_for_f must be. We derive a rule for the call by passing method_instance to Mooncake.build_frule. (In practice, we might do this lazily, but while retaining enough information to maintain type stability. See the Mooncake.LazyDerivedRule for how this is handled in reverse-mode).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Dynamic Dispatch:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"If we have a :call expression and are not able to prove that is_primitive will return true, we must defer dispatch until runtime. We do this by replacing the :call expression with a call to a DynamicFRule, which simply constructs (or retrieves from a cache) the rule at runtime. Reverse-mode utilises a similar strategy via Mooncake.DynamicDerivedRule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The above was written in terms of f and x. In practice, of course, we encounter various kinds of constants (e.g. Base.sin), Arguments (e.g. _3), and Core.SSAValues (e.g. %5). The translation rules for these are:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"constants are turned into constant duals in which the tangent is zero,\nArguments are incremented by 1.\nSSAValues are left as-is.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.GotoNodes","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"These remain entirely unchanged.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.GotoIfNot","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"These require minor modification. Suppose that a Core.GotoIfNot of the form Core.GotoIfNot(%5, 4) is encountered in the primal. Since %5 will be a Dual in the derived rule, we must pull out the primal field, and pass that to the conditional instead. Therefore, these statments get lowered to two lines in the derived rule. For example, Core.GotoIfNot(%5, 4) would be translated to:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"%n = getfield(%5, :primal)\nCore.GotoIfNot(%n, 4)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNode","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNode looks something like the following in the general case:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.SSAValues are unchanged.\nCore.Arguments are incremented by 1 (as always).\nconstants are translated into constant duals.\nundefined values remain undefined.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"So the above example would be translated into something like","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Optimisation","page":"Forwards-Mode Design","title":"Optimisation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Put-IRCode-in-MistyClosure","page":"Forwards-Mode Design","title":"Put IRCode in MistyClosure","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Put-the-MistyClosure-in-a-DerivedFRule","page":"Forwards-Mode Design","title":"Put the MistyClosure in a DerivedFRule","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the \"rule\" that users get.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Batch-Mode","page":"Forwards-Mode Design","title":"Batch Mode","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of \"batched tangents\". Then, instead of propagating a \"primal-tangent\" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-vs-Reverse-Implementation","page":"Forwards-Mode Design","title":"Forwards vs Reverse Implementation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.\nforwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.","category":"page"},{"location":"developer_documentation/internal_docstrings/#Internal-Docstrings","page":"Internal Docstrings","title":"Internal Docstrings","text":"","category":"section"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.","category":"page"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.","category":"page"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Modules = [Mooncake]\nPublic = false","category":"page"},{"location":"developer_documentation/internal_docstrings/#Mooncake.GLOBAL_INTERPRETER-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.GLOBAL_INTERPRETER","text":"const GLOBAL_INTERPRETER\n\nGlobally cached interpreter. Should only be accessed via get_interpreter.\n\n\n\n\n\n","category":"constant"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Terminator-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Terminator","text":"Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}\n\nA Union of the possible types of a terminator node.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Core.Compiler.IRCode-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Core.Compiler.IRCode","text":"IRCode(bb_code::BBCode)\n\nProduce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.\n\nAll IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.\n\nIn the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ADInfo-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ADInfo","text":"ADInfo\n\nThis data structure is used to hold \"global\" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.\n\ninterp: a MooncakeInterpreter.\nblock_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.\nblock_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.\nentry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.\nshared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.\narg_types: a map from Argument to its static type.\nssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.\narg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.\nssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.\ndebug_mode: if true, run in \"debug mode\" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.\nis_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.\nlazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ADStmtInfo-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ADStmtInfo","text":"ADStmtInfo\n\nData structure which contains the result of make_ad_stmts!. Fields are\n\nline: the ID associated to the primal line from which this is derived\ncomms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.\nfwds: the instructions which run the forwards-pass of AD\nrvs: the instructions which run the reverse-pass of AD / the pullback\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(ir::IRCode)\n\nConvert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.\n\nAll PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.\n\nSee IRCode for conversion back to IRCode.\n\nNote that IRCode(BBCode(ir)) should be equal to the identity function.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-Tuple{Union{Core.Compiler.IRCode, Mooncake.BBCode}, Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})\n\nMake a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(\n blocks::Vector{BBlock}\n argtypes::Vector{Any}\n sptypes::Vector{CC.VarState}\n linetable::Vector{Core.LineInfoNode}\n meta::Vector{Expr}\n)\n\nA BBCode is a data structure which is similar to IRCode, but adds additional structure.\n\nIn particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.\n\nThe consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.\n\nAdditionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.\n\nFinally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBlock-Tuple{Mooncake.ID, Vector{Tuple{Mooncake.ID, Core.Compiler.NewInstruction}}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBlock","text":"BBlock(id::ID, inst_pairs::Vector{IDInstPair})\n\nConvenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBlock-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBlock","text":"BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)\n\nA basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.\n\nEach BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.\n\nThe nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.\n\nNote that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BlockStack-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BlockStack","text":"The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.CannotProduceZeroRDataFromType-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.CannotProduceZeroRDataFromType","text":"CannotProduceZeroRDataFromType()\n\nReturned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Config-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Config","text":"Config(; debug_mode=false, silence_debug_messages=false)\n\nConfiguration struct for use with ADTypes.AutoMooncake.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugPullback-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugPullback","text":"(pb::DebugPullback)(dy)\n\nApply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugPullback-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugPullback","text":"DebugPullback(pb, y, x)\n\nConstruct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function\n\nchecks that dy has the correct rdata type for y, and\nchecks that each element of dx has the correct rdata type for x.\n\nReverse pass counterpart to DebugRRule\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugRRule-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugRRule","text":"(rule::DebugRRule)(x::CoDual...)\n\nApply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugRRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).\n\nLet rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DefaultCtx-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DefaultCtx","text":"struct DefaultCtx end\n\nContext for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DynamicDerivedRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DynamicDerivedRule","text":"DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)\n\nFor internal use only.\n\nA callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.\n\nThis is used to implement dynamic dispatch.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.FData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.FData","text":"FData(data::NamedTuple)\n\nThe component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ID-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ID","text":"ID()\n\nAn ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.\n\nThis counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDGotoIfNot-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDGotoIfNot","text":"IDGotoIfNot(cond::Any, dest::ID)\n\nLike a GotoIfNot, but dest is an ID rather than an Int64.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDGotoNode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDGotoNode","text":"IDGotoNode(label::ID)\n\nLike a GotoNode, but label is an ID rather than an Int64.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDInstPair-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDInstPair","text":"const IDInstPair = Tuple{ID, NewInstruction}\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDPhiNode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDPhiNode","text":"IDPhiNode(edges::Vector{ID}, values::Vector{Any})\n\nLike a PhiNode, but edges are IDs rather than Int32s.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InstVector-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InstVector","text":"const InstVector = Vector{NewInstruction}\n\nNote: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InvalidFDataException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InvalidFDataException","text":"InvalidFDataException(msg::String)\n\nException indicating that there is a problem with the fdata associated to a primal.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InvalidRDataException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InvalidRDataException","text":"InvalidRDataException(msg::String)\n\nException indicating that there is a problem with the rdata associated to a primal.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.LazyDerivedRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.LazyDerivedRule","text":"LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)\n\nFor internal use only.\n\nA type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.\n\nIf debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.\n\nNote: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.LazyZeroRData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.LazyZeroRData","text":"LazyZeroRData{P, Tdata}()\n\nThis type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.\n\nUsers should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.MinimalCtx-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.MinimalCtx","text":"struct MinimalCtx end\n\nFunctions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.NoPullback-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.NoPullback","text":"NoPullback(args::CoDual...)\n\nConstruct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.\n\nThe advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.RRuleZeroWrapper-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.RRuleZeroWrapper","text":"RRuleZeroWrapper(rule)\n\nThis struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.SharedDataPairs-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.SharedDataPairs","text":"SharedDataPairs()\n\nA data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.\n\nThis is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Stack-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Stack","text":"Stack{T}()\n\nA stack specialised for reverse-mode AD.\n\nSemantically equivalent to a usual stack, but never de-allocates memory once allocated.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Switch-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Switch","text":"Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)\n\nA switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:\n\ngoto dests[1] if not conds[1]\ngoto dests[2] if not conds[2]\n...\ngoto dests[N] if not conds[N]\ngoto fallthrough_dest\n\nwhere the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.\n\nSwitch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.UnhandledLanguageFeatureException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.UnhandledLanguageFeatureException","text":"UnhandledLanguageFeatureException(message::String)\n\nAn exception used to indicate that some aspect of the Julia language which AD cannot handle has been encountered.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ZeroRData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ZeroRData","text":"ZeroRData()\n\nSingleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.\n\nIf you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Base.insert!-Tuple{Mooncake.BBlock, Int64, Mooncake.ID, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Base.insert!","text":"Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing\n\nInserts stmt and id into bb immediately before the nth instruction.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__deref_and_zero-Union{Tuple{P}, Tuple{Type{P}, Ref}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__deref_and_zero","text":"__deref_and_zero(::Type{P}, x::Ref) where {P}\n\nHelper, used in concludervsblock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__flatten_varargs-Union{Tuple{nvargs}, Tuple{isva}, Tuple{Val{isva}, Any, Val{nvargs}}} where {isva, nvargs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__flatten_varargs","text":"__flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}\n\nIf isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__get_value-Tuple{Mooncake.ID, Mooncake.IDPhiNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__get_value","text":"__get_value(edge::ID, x::IDPhiNode)\n\nHelper functionality for concludervsblock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__insts_to_instruction_stream-Tuple{Vector{Any}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__insts_to_instruction_stream","text":"__insts_to_instruction_stream(insts::Vector{Any})\n\nProduces an instruction stream whose\n\nstmt (v1.11 and up) / inst (v1.10) field is insts,\ntype field is all Any,\ninfo field is all Core.Compiler.NoCallInfo,\nline field is all Int32(1), and\nflag field is all Core.Compiler.IR_FLAG_REFINED.\n\nAs such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__line_numbers_to_block_numbers!-Tuple{Vector{Any}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__line_numbers_to_block_numbers!","text":"__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)\n\nConverts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.\n\nFor context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.\n\nThis code is copied over directly from the body of Core.Compiler.inflate_ir!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__make_ref-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__make_ref","text":"__make_ref(p::Type{P}) where {P}\n\nHelper for reverse_data_ref_stmts. Constructs a Ref whose element type is the zero_like_rdata_type for P, and whose element is the zero-like rdata for P.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__pop_blk_stack!-Tuple{Mooncake.Stack{Int32}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__pop_blk_stack!","text":"__pop_blk_stack!(block_stack::BlockStack)\n\nEquivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__push_blk_stack!-Tuple{Mooncake.Stack{Int32}, Int32}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__push_blk_stack!","text":"__push_blk_stack!(block_stack::BlockStack, id::Int32)\n\nEquivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__run_rvs_pass!-Union{Tuple{sig}, Tuple{Type, Type{sig}, Any, Ref, Vararg{Any}}} where sig-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__run_rvs_pass!","text":"__run_rvs_pass!(\n P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...\n) where {sig}\n\nUsed in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__switch_case-Tuple{Int32, Int32}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__switch_case","text":"__switch_case(id::Int32, predecessor_id::Int32)\n\nHelper function emitted by make_switch_stmts.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__unflatten_codual_varargs-Union{Tuple{nargs}, Tuple{isva}, Tuple{Val{isva}, Any, Val{nargs}}} where {isva, nargs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__unflatten_codual_varargs","text":"__unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}\n\nIf isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__value_and_gradient!!-Union{Tuple{N}, Tuple{R}, Tuple{R, Vararg{Mooncake.CoDual, N}}} where {R, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__value_and_gradient!!","text":"__value_and_gradient!!(rule, f::CoDual, x::CoDual...)\n\nNote: this is not part of the public Mooncake.jl interface, and may change without warning.\n\nEquivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.\n\n# Set up the problem.\nf(x, y) = sum(x .* y)\nx = [2.0, 2.0]\ny = [1.0, 1.0]\nrule = build_rrule(f, x, y)\n\n# Allocate tangents. These will be written to in-place. You are free to re-use these if you\n# compute gradients multiple times.\ntf = zero_tangent(f)\ntx = zero_tangent(x)\nty = zero_tangent(y)\n\n# Do AD.\nMooncake.__value_and_gradient!!(\n rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)\n)\n# output\n\n(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__value_and_pullback!!-Union{Tuple{T}, Tuple{N}, Tuple{R}, Tuple{R, T, Vararg{Mooncake.CoDual, N}}} where {R, N, T}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__value_and_pullback!!","text":"__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)\n\nNote: this is not part of the public Mooncake.jl interface, and may change without warning.\n\nIn-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._block_nums_to_ids-Tuple{Vector{Core.Compiler.NewInstruction}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._block_nums_to_ids","text":"_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}\n\nAssign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._build_graph_of_cfg-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._build_graph_of_cfg","text":"_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}\n\nBuilds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.\n\nReturns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._compute_all_predecessors-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._compute_all_predecessors","text":"_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}\n\nInternal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._compute_all_successors-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._compute_all_successors","text":"_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}\n\nInternal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._control_flow_graph-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._control_flow_graph","text":"_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG\n\nInternal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._distance_to_entry-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._distance_to_entry","text":"_distance_to_entry(blks::Vector{BBlock})::Vector{Int}\n\nFor each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._find_id_uses!-Tuple{Dict{Mooncake.ID, Bool}, Expr}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._find_id_uses!","text":"_find_id_uses!(d::Dict{ID, Bool}, x)\n\nHelper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.\n\nFor example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._foreigncall_-Union{Tuple{N}, Tuple{calling_convention}, Tuple{nreq}, Tuple{RT}, Tuple{name}, Tuple{Val{name}, Val{RT}, Tuple, Val{nreq}, Val{calling_convention}, Vararg{Any, N}}} where {name, RT, nreq, calling_convention, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._foreigncall_","text":"function _foreigncall_(\n ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...\n) where {name, RT, nreq, calling_convention}\n\n:foreigncall nodes get translated into calls to this function. For example,\n\nExpr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)\n\nbecomes\n\n_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)\n\nPlease consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.\n\nCredit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ids_to_line_numbers-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ids_to_line_numbers","text":"_ids_to_line_numbers(bb_code::BBCode)::InstVector\n\nFor each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._is_reachable-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._is_reachable","text":"_is_reachable(blks::Vector{BBlock})::Vector{Bool}\n\nComputes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._lines_to_blocks-Tuple{Vector{Core.Compiler.NewInstruction}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._lines_to_blocks","text":"_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector\n\nPulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._lower_switch_statements-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._lower_switch_statements","text":"_lower_switch_statements(bb_code::BBCode)\n\nConverts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map-Union{Tuple{N}, Tuple{F}, Tuple{F, Vararg{Any, N}}} where {F, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map","text":"_map(f, x...)\n\nSame as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map_if_assigned!-Union{Tuple{P}, Tuple{F}, Tuple{F, DenseArray, DenseArray{P}, DenseArray}} where {F, P}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map_if_assigned!","text":"_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)\n\nSimilar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.\n\nRequires that y, x1, and x2 have the same size.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map_if_assigned!-Union{Tuple{P}, Tuple{F}, Tuple{F, DenseArray, DenseArray{P}}} where {F, P}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map_if_assigned!","text":"_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}\n\nFor all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.\n\nEquivalent to map!(f, y, x) if P is a bits type as element will always be assigned.\n\nRequires that y and x have the same size.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._new_-Union{Tuple{N}, Tuple{T}, Tuple{Type{T}, Vararg{Any, N}}} where {T, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._new_","text":"_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}\n\nOne-liner which calls the :new instruction with type T with arguments x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._remove_double_edges-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._remove_double_edges","text":"_remove_double_edges(ir::BBCode)::BBCode\n\nIf the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._sort_blocks!-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._sort_blocks!","text":"_sort_blocks!(ir::BBCode)::BBCode\n\nEnsure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.\n\nFor reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.\n\nWARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._splat_new_-Union{Tuple{P}, Tuple{Type{P}, Tuple}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._splat_new_","text":"_splat_new_(::Type{P}, x::Tuple) where {P}\n\nFunction which replaces instances of :splatnew.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ssa_to_ids-Tuple{Dict{Core.SSAValue, Mooncake.ID}, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ssa_to_ids","text":"_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)\n\nProduce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ssas_to_ids-Tuple{Vector{Core.Compiler.NewInstruction}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ssas_to_ids","text":"_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}\n\nAssigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._to_ssas-Tuple{Dict, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._to_ssas","text":"_to_ssas(d::Dict, inst::NewInstruction)\n\nLike _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._typeof-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._typeof","text":"_typeof(x)\n\nCentral definition of typeof, which is specific to the use-required in this package.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ad_stmt_info-Tuple{Mooncake.ID, Union{Nothing, Mooncake.ID}, Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ad_stmt_info","text":"ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)\n\nConvenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data!-Tuple{Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data!","text":"add_data!(info::ADInfo, data)::ID\n\nEquivalent to add_data!(info.shared_data_pairs, data).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data!-Tuple{Mooncake.SharedDataPairs, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data!","text":"add_data!(p::SharedDataPairs, data)::ID\n\nPuts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data_if_not_singleton!-Tuple{Union{Mooncake.ADInfo, Mooncake.SharedDataPairs}, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data_if_not_singleton!","text":"add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)\n\nReturns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.can_produce_zero_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.can_produce_zero_rdata_from_type","text":"can_produce_zero_rdata_from_type(::Type{P}) where {P}\n\nReturns whether or not the zero element of the rdata type for primal type P can be obtained from P alone.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.characterise_unique_predecessor_blocks-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.characterise_unique_predecessor_blocks","text":"characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->\n Tuple{Dict{ID, Bool}, Dict{ID, Bool}}\n\nWe call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.\n\nReturns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.\n\nContext:\n\nThis information is important for optimising AD because knowing that b is a unique predecessor means that\n\non the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and\non the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.\n\nUtilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.characterise_used_ids-Tuple{Vector{Tuple{Mooncake.ID, Core.Compiler.NewInstruction}}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.characterise_used_ids","text":"characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}\n\nFor each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.collect_stmts-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.collect_stmts","text":"collect_stmts(ir::BBCode)::Vector{IDInstPair}\n\nProduce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir. \n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.collect_stmts-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.collect_stmts","text":"collect_stmts(bb::BBlock)::Vector{IDInstPair}\n\nReturns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.comms_channel-Tuple{Mooncake.ADStmtInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.comms_channel","text":"comms_channel(info::ADStmtInfo)\n\nReturn the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.compute_all_predecessors-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.compute_all_predecessors","text":"compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}}\n\nCompute a map from the ID of eachBBlockinir` to its possible predecessors.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.compute_all_successors-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.compute_all_successors","text":"compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}}\n\nCompute a map from the ID of eachBBlockinir` to its possible successors.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.conclude_rvs_block-Tuple{Mooncake.BBlock, Vector{Mooncake.ID}, Bool, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.conclude_rvs_block","text":"conclude_rvs_block(\n blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo\n)\n\nGenerates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.const_ad_stmt-Tuple{Any, Mooncake.ID, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.const_ad_stmt","text":"const_ad_stmt(stmt, line::ID, info::ADInfo)\n\nImplementation of make_ad_stmts! used for constants.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.const_codual-Tuple{Any, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.const_codual","text":"const_codual(stmt, info::ADInfo)\n\nBuild a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.control_flow_graph-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.control_flow_graph","text":"control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG\n\nComputes the Core.Compiler.CFG object associated to this bb_code.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.create_comms_insts!-Tuple{Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.create_comms_insts!","text":"create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)\n\nThis function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.\n\nFor each basic block represented in ADStmts:\n\ncreate a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.\ncreate instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and\ncreate instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.\n\nReturns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fcodual_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fcodual_type","text":"fcodual_type(P::Type)\n\nThe type of the CoDual which contains instances of P and its fdata.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fdata_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fdata_field_type","text":"fdata_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.foreigncall_to_call-Tuple{Any, Dict{Symbol, Core.Compiler.VarState}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.foreigncall_to_call","text":"foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})\n\nIf inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.\n\nsp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.\n\nThe purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.forwards_pass_ir-Tuple{Mooncake.BBCode, Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Any, Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.forwards_pass_ir","text":"forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)\n\nProduce the IR associated to the OpaqueClosure which runs most of the forwards-pass.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fwd_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fwd_ir","text":"fwd_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.gc_preserve-Tuple-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.gc_preserve","text":"gc_preserve(xs...)\n\nA no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.generate_ir-Tuple{Mooncake.MooncakeInterpreter, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.generate_ir","text":"generate_ir(\n interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true\n)\n\nUsed by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_const_primal_value-Tuple{GlobalRef}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_const_primal_value","text":"get_const_primal_value(x::GlobalRef)\n\nGet the value associated to x. For GlobalRefs, verify that x is indeed a constant.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_primal_type-Tuple{Mooncake.ADInfo, Core.Argument}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_primal_type","text":"get_primal_type(info::ADInfo, x)\n\nReturns the static / inferred type associated to x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_rev_data_id-Tuple{Mooncake.ADInfo, Core.Argument}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_rev_data_id","text":"get_rev_data_id(info::ADInfo, x)\n\nReturns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_tangent_field-Union{Tuple{Tfs}, Tuple{Union{MutableTangent{Tfs}, Tangent{Tfs}}, Int64}} where Tfs-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_tangent_field","text":"get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)\n\nGets the ith field of data in t.\n\nHas the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.id_to_line_map-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.id_to_line_map","text":"id_to_line_map(ir::BBCode)\n\nProduces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.inc_args-Tuple{Expr}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.inc_args","text":"inc_args(stmt)\n\nIncrement by 1 the n field of any Arguments present in stmt.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_and_get_rdata!-Union{Tuple{T}, Tuple{NoFData, T, T}} where T<:Union{Float16, Float32, Float64}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_and_get_rdata!","text":"increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)\n\nIncrement fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_field!!-Union{Tuple{i}, Tuple{Tuple, Any, Val{i}}} where i-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_field!!","text":"increment_field!!(x::T, y::V, f) where {T, V}\n\nincrement!! the field f of x by y, and return the updated x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_rdata!!-Union{Tuple{T}, Tuple{T, Any}} where T-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_rdata!!","text":"increment_rdata!!(t::T, r)::T where {T}\n\nIncrement the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.infer_ir!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.infer_ir!","text":"infer_ir!(ir::IRCode) -> IRCode\n\nRuns type inference on ir, which mutates ir, and returns it.\n\nNote: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.insert_before_terminator!-Tuple{Mooncake.BBlock, Mooncake.ID, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.insert_before_terminator!","text":"insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing\n\nIf the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.interpolate_boundschecks!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.interpolate_boundschecks!","text":"interpolate_boundschecks!(ir::IRCode)\n\nFor every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.intrinsic_to_function-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.intrinsic_to_function","text":"intrinsic_to_function(inst)\n\nIf inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.\n\ncglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.\n\nThe purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ircode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ircode","text":"ircode(\n inst::Vector{Any},\n argtypes::Vector{Any},\n sptypes::Vector{CC.VarState}=CC.VarState[],\n) -> IRCode\n\nConstructs an instance of an IRCode. This is useful for constructing test cases with known properties.\n\nNo optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.\n\nEdges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_always_fully_initialised-Tuple{DataType}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_always_fully_initialised","text":"is_always_fully_initialised(P::DataType)::Bool\n\nTrue if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_always_initialised-Tuple{DataType, Int64}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_always_initialised","text":"is_always_initialised(P::DataType, n::Int)::Bool\n\nTrue if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_primitive-Tuple{Type{Mooncake.MinimalCtx}, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_primitive","text":"is_primitive(::Type{Ctx}, sig) where {Ctx}\n\nReturns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.\n\nis_primitive(DefaultCtx, Tuple{typeof(sin), Float64})\n\nwill return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.\n\nObserve that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_reachable_return_node-Tuple{Core.ReturnNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_reachable_return_node","text":"is_reachable_return_node(x::ReturnNode)\n\nDetermine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_unreachable_return_node-Tuple{Core.ReturnNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_unreachable_return_node","text":"is_unreachable_return_node(x::ReturnNode)\n\nDetermine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_used-Tuple{Mooncake.ADInfo, Mooncake.ID}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_used","text":"is_used(info::ADInfo, id::ID)::Bool\n\nReturns true if id is used by any of the lines in the ir, false otherwise.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}\n\nFinds the method associated to sig, and calls is_vararg_and_sparam_names on it.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Core.MethodInstance}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(mi::Core.MethodInstance)\n\nCalls is_vararg_and_sparam_names on mi.def::Method.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Method}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(m::Method)\n\nReturns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lgetfield-Union{Tuple{f}, Tuple{Any, Val{f}}} where f-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lgetfield","text":"lgetfield(x, f::Val)\n\nAn implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.\n\nIt will always be the case that\n\ngetfield(x, :f) === lgetfield(x, Val(:f))\ngetfield(x, 2) === lgetfield(x, Val(2))\n\nThis approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lgetfield-Union{Tuple{order}, Tuple{f}, Tuple{Any, Val{f}, Val{order}}} where {f, order}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lgetfield","text":"lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}\n\nLike getfield, but with the field and access order encoded as types.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lift_gc_preservation-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lift_gc_preservation","text":"lift_gc_preserve(inst)\n\nExpressions of the form\n\ny = GC.@preserve x1 x2 foo(args...)\n\nget lowered to\n\ntoken = Expr(:gc_preserve_begin, x1, x2)\ny = expr\nExpr(:gc_preserve_end, token)\n\nThese expressions guarantee that any memory associated x1 and x2 not be freed until the :gc_preserve_end expression is reached.\n\nIn the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.\n\nTo achieve this, we replace the primal code with\n\n# store `x` in `pb_gc_preserve` to prevent it from being freed.\n_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)\n\n# Differentiate the `:call` expression in the usual way.\ny, foo_pb = rrule!!(zero_fcodual(foo), args...)\n\n# Do not permit the GC to free `x` here.\nnothing\n\nThe pullback should be something along the lines of\n\n# no pullback associated to `nothing`.\nnothing\n\n# Run the pullback associated to `foo` in the usual manner. `x` must be available.\n_, dargs... = foo_pb(dy)\n\n# No-op pullback associated to `gc_preserve`.\npb_gc_preserve(NoRData())\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lift_getfield_and_others-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lift_getfield_and_others","text":"lift_getfield_and_others(inst)\n\nConverts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.\n\nDoes the same for...\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lookup_ir-Tuple{Core.Compiler.AbstractInterpreter, Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lookup_ir","text":"lookup_ir(\n interp::AbstractInterpreter,\n sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},\n)::Tuple{IRCode, T}\n\nGet the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.\n\nReturns a tuple containing the IRCode and its return type.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lsetfield!-Union{Tuple{name}, Tuple{Any, Val{name}, Any}} where name-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lsetfield!","text":"lsetfield!(value, name::Val, x, [order::Val])\n\nThis function is to setfield! what lgetfield is to getfield. It will always hold that\n\nsetfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)\nsetfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.make_ad_stmts!-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.make_ad_stmts!","text":"make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo\n\nEvery line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.\n\nTranslates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.\n\ninfo is a data structure containing various bits of global information that certain types of nodes need access to.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.make_switch_stmts-Tuple{Vector{Mooncake.ID}, Vector{Mooncake.ID}, Bool, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.make_switch_stmts","text":"make_switch_stmts(\n pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo\n)\n\npreds_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of\n\nprev_block = pop!(block_stack)\nnot_pred_was_1 = !(prev_block == ID(1))\nnot_pred_was_2 = !(prev_block == ID(2))\nswitch(\n not_pred_was_1 => ID(1),\n not_pred_was_2 => ID(2),\n ID(3)\n)\n\nIn words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_inst-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_inst","text":"new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction\n\nCreate a NewInstruction with fields:\n\nstmt = stmt\ntype = type\ninfo = CC.NoCallInfo()\nline = Int32(1)\nflag = flag\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_inst_vec-Tuple{Core.Compiler.InstructionStream}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_inst_vec","text":"new_inst_vec(x::CC.InstructionStream)\n\nConvert an Compiler.InstructionStream into a list of Compiler.NewInstructions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_to_call-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_to_call","text":"new_to_call(x)\n\nIf instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.\n\nThe purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.normalise!-Tuple{Core.Compiler.IRCode, Vector{Symbol}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.normalise!","text":"normalise!(ir::IRCode, spnames::Vector{Symbol})\n\nApply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace\n\n:foreigncall Exprs with :calls to Mooncake._foreigncall_,\n:new Exprs with :calls to Mooncake._new_,\n:splatnew Exprs with:calls toMooncake.splatnew_`,\nCore.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,\ngetfield(x, 1) with lgetfield(x, Val(1)), and related transformations,\nmemoryrefget calls to lmemoryrefget calls, and related transformations,\ngc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.\n\nspnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.\n\nUnfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.optimise_ir!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.optimise_ir!","text":"optimise_ir!(ir::IRCode, show_ir=false)\n\nRun a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.phi_nodes-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.phi_nodes","text":"phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}\n\nReturns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.primal_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.primal_ir","text":"primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGet the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).\n\nFor example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:\n\njulia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.pullback_ir-Tuple{Mooncake.BBCode, Any, Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Any, Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.pullback_ir","text":"pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)\n\nProduce the IR associated to the OpaqueClosure which runs most of the pullback.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.pullback_type-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.pullback_type","text":"pullback_type(Trule, arg_types)\n\nGet a bound on the pullback type, given a rule and associated primal types.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rdata_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rdata_field_type","text":"rdata_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.remove_unreachable_blocks!-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.remove_unreachable_blocks!","text":"remove_unreachable_blocks!(ir::BBCode)::BBCode\n\nIf a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:\n\nit has no predecessors and it is not the first block, or\nall of its predecessors are themselves unreachable.\n\nFor example, consider the following IR:\n\njulia> ir = Mooncake.ircode(\n Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],\n Any[Any, Any, Any],\n );\n\nThere is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:\n\njulia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))\n1 1 ─ return nothing\n\nIn the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_captures-Union{Tuple{Tmc}, Tuple{Tmc, Any}} where Tmc<:MistyClosures.MistyClosure-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_captures","text":"replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}\n\nSame as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_captures-Union{Tuple{Toc}, Tuple{Toc, Any}} where Toc<:Core.OpaqueClosure-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_captures","text":"replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}\n\nGiven an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_uses_with!-Tuple{Any, Union{Core.Argument, Core.SSAValue}, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_uses_with!","text":"replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)\n\nReplace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.reverse_data_ref_stmts-Tuple{Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.reverse_data_ref_stmts","text":"reverse_data_ref_stmts(info::ADInfo)\n\nCreate the statements which initialise the reverse-data Refs.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rrule_wrapper-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rrule_wrapper","text":"rrule_wrapper(f::CoDual, args::CoDual...)\n\nUsed to implement rrule!!s via ChainRulesCore.rrule.\n\nGiven a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:\n\nMooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}\nfunction Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)\n return rrule_wrapper(f, args...)\nend\n\nAssumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.\n\nFurthermore, it is essential that\n\nf(args) does not mutate f or args, and\nthe result of f(args) does not alias any data stored in f or args.\n\nSubject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rule_type-Union{Tuple{C}, Tuple{Mooncake.MooncakeInterpreter{C}, Any}} where C-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rule_type","text":"rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}\n\nCompute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rvs_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rvs_ir","text":"rvs_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rvs_phi_block-Tuple{Mooncake.ID, Vector{Mooncake.ID}, Vector{Any}, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rvs_phi_block","text":"rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)\n\nProduces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.\n\nFor example, suppose that we encounter the following collection of PhiNodes at the start of some block:\n\n%6 = φ (#2 => _1, #3 => %5)\n%7 = φ (#2 => 5., #3 => _2)\n\nLet the tangent refs associated to %6, %7, and _1be denotedt%6,t%7, andt1resp., and letpredidbe#2`, then this function will produce a basic block of the form\n\nincrement_ref!(t_1, t%6)\nnothing\ngoto #2\n\nThe call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.\n\nThe same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.seed_id!-Tuple{}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.seed_id!","text":"seed_id!()\n\nSet the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.\n\nThis is akin to setting the random seed associated to a random number generator globally.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.set_tangent_field!-Union{Tuple{Tfields}, Tuple{MutableTangent{Tfields}, Int64, Any}} where Tfields-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.set_tangent_field!","text":"set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}\n\nSets the value of the ith field of the data in t to value x.\n\nHas the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.shared_data_stmts-Tuple{Mooncake.SharedDataPairs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.shared_data_stmts","text":"shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}\n\nProduce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.\n\nFor example, if p.pairs is\n\n[(ID(5), 5.0), (ID(3), \"hello\")]\n\nthen the output of this function is\n\nIDInstPair[\n (ID(5), new_inst(:(getfield(_1, 1)))),\n (ID(3), new_inst(:(getfield(_1, 2)))),\n]\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.shared_data_tuple-Tuple{Mooncake.SharedDataPairs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.shared_data_tuple","text":"shared_data_tuple(p::SharedDataPairs)::Tuple\n\nCreate the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.\n\nFor example, if p.pairs is\n\n[(ID(5), 5.0), (ID(3), \"hello\")]\n\nthen the output of this function is\n\n(5.0, \"hello\")\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.sparam_names-Tuple{Method}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.sparam_names","text":"sparam_names(m::Core.Method)::Vector{Symbol}\n\nReturns the names of all of the static parameters in m.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.splatnew_to_call-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.splatnew_to_call","text":"splatnew_to_call(x)\n\nIf instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.\n\nThe purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.stmt-Tuple{Core.Compiler.InstructionStream}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.stmt","text":"stmt(ir::CC.InstructionStream)\n\nGet the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tangent_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tangent_field_type","text":"tangent_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type that lives in the nth elements of fields in a Tangent / MutableTangent. Will either be the tangent_type of the nth fieldtype of P, or the tangent_type wrapped in a PossiblyUninitTangent. The latter case only occurs if it is possible for the field to be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tangent_test_cases-Tuple{}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tangent_test_cases","text":"tangent_test_cases()\n\nConstructs a Vector of Tuples containing test cases for the tangent infrastructure.\n\nIf the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value\n\ninterface_only is a Bool which will be used to determine which subset of tests to run.\n\nIf the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).\n\nTest cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.terminator-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.terminator","text":"terminator(bb::BBlock)\n\nReturns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.to_cr_tangent-Tuple{Union{Float16, Float32, Float64}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.to_cr_tangent","text":"to_cr_tangent(t)\n\nConvert a Mooncake tangent into a type that ChainRules.jl rrules expect to see.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tuple_map-Union{Tuple{F}, Tuple{F, Tuple}} where F-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tuple_map","text":"tuple_map(f::F, x::Tuple) where {F}\n\nThis function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.\n\nAs a consequence, if x is very long, this function may have very large compile times.\n\ntuple_map(f::F, x::Tuple, y::Tuple) where {F}\n\nBinary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.unhandled_feature-Tuple{String}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.unhandled_feature","text":"unhandled_feature(msg::String)\n\nThrow an UnhandledLanguageFeatureException with message msg.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_codual-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_codual","text":"uninit_codual(x)\n\nEquivalent to CoDual(x, uninit_tangent(x)).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_fcodual-Tuple{P} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_fcodual","text":"uninit_fcodual(x)\n\nLike zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_tangent-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_tangent","text":"uninit_tangent(x)\n\nRelated to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_fdata_type-Tuple{Type, Type}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_fdata_type","text":"verify_fdata_type(P::Type, F::Type)::Nothing\n\nCheck that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.\n\nThis applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_fdata_value-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_fdata_value","text":"verify_fdata_value(p, f)::Nothing\n\nCheck that f cannot be proven to be invalid fdata for p.\n\nThis method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.\n\nPut differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_rdata_type-Tuple{Type, Type}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_rdata_type","text":"verify_rdata_type(P::Type, R::Type)::Nothing\n\nCheck that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.\n\nThis applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_rdata_value-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_rdata_value","text":"verify_rdata_value(p, r)::Nothing\n\nCheck that r cannot be proven to be invalid rdata for p.\n\nThis method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.\n\nPut differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_adjoint-Union{Tuple{N}, Tuple{Mooncake.CoDual, Vararg{Mooncake.CoDual, N}}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_like_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_like_rdata_from_type","text":"zero_like_rdata_from_type(::Type{P}) where {P}\n\nThis is an internal implementation detail – you should generally not use this function.\n\nReturns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData, \n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_like_rdata_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_like_rdata_type","text":"zero_like_rdata_type(::Type{P}) where {P}\n\nIndicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_rdata-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_rdata","text":"zero_rdata(p)\n\nGiven value p, return the zero element associated to its reverse data type.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_rdata_from_type","text":"zero_rdata_from_type(::Type{P}) where {P}\n\nReturns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.\n\nFor example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.\n\nHowever, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.\n\nThese considerations apply recursively to tuples / namedtuples / structs, etc.\n\nIf you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@from_rrule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.\nhas_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@is_primitive-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@is_primitive","text":"@is_primitive context_type signature\n\nCreates a method of is_primitive which always returns true for the context_type and signature provided. For example\n\n@is_primitive MinimalCtx Tuple{typeof(foo), Float64}\n\nis equivalent to\n\nis_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true\n\nYou should implemented more complicated method of is_primitive in the usual way.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@mooncake_overlay-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@zero_adjoint-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) = 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Mooncake.IntrinsicsWrappers","category":"page"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IntrinsicsWrappers","page":"Internal Docstrings","title":"Mooncake.IntrinsicsWrappers","text":"module IntrinsicsWrappers\n\nThe purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.\n\nTo understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that\n\njulia> typeof(Core.Intrinsics.add_float)\nCore.IntrinsicFunction\n\njulia> typeof(Core.Intrinsics.sub_float)\nCore.IntrinsicFunction\n\nWhile we could simply write a rule for Core.IntrinsicFunction, this would (naively) lead to a large list of conditionals of the form\n\nif f === Core.Intrinsics.add_float\n # return add_float and its pullback\nelseif f === Core.Intrinsics.sub_float\n # return add_float and its pullback\nelseif\n ...\nend\n\nwhich has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).\n\nInstead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.\n\nExtended Help\n\nIt is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .\n\n\n\n\n\n","category":"module"},{"location":"developer_documentation/developer_tools/#Developer-Tools","page":"Developer Tools","title":"Developer Tools","text":"","category":"section"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Mooncake.jl offers developers to a few convenience functions which give access to the IR that it generates in order to perform AD. These are lightweight wrappers around internals which save you from having to dig in to the objects created by build_rrule.","category":"page"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Since these provide access to internals, they do not follow the usual rules of semver, and may change without notice!","category":"page"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Mooncake.primal_ir\nMooncake.fwd_ir\nMooncake.rvs_ir","category":"page"},{"location":"developer_documentation/developer_tools/#Mooncake.primal_ir","page":"Developer Tools","title":"Mooncake.primal_ir","text":"primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGet the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).\n\nFor example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:\n\njulia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/developer_tools/#Mooncake.fwd_ir","page":"Developer Tools","title":"Mooncake.fwd_ir","text":"fwd_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/developer_tools/#Mooncake.rvs_ir","page":"Developer Tools","title":"Mooncake.rvs_ir","text":"rvs_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/running_tests_locally/#Running-Tests-Locally","page":"Running Tests Locally","title":"Running Tests Locally","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"There are two workflows for running tests, discussed below.","category":"page"},{"location":"developer_documentation/running_tests_locally/#Main-Testing-Functionality","page":"Running Tests Locally","title":"Main Testing Functionality","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"For all code in src, Mooncake's tests are organised as follows:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Things that are required for most / all test suites are loaded up in test/front_matter.jl.\nThe tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Thus, a workflow that I (Will) find works very well is the following:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.\nstart the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.\nusing TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.\nRun the following: using Pkg; Pkg.activate(\".\"); TestEnv.activate(); include(\"test/front_matter.jl\"); to set up your environment.\ninclude whichever test file you want to run the tests from.\nModify code, and re-include tests to check it has done was you need. Loop this until done.\nMake a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"The purpose of this approach is to:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Avoid restarting the REPL each time you make a change, and\nRun the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.","category":"page"},{"location":"developer_documentation/running_tests_locally/#Extension-and-Integration-Testing","page":"Running Tests Locally","title":"Extension and Integration Testing","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.","category":"page"},{"location":"utilities/debugging_and_mwes/#Debugging-and-MWEs","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"There's a reasonable chance that you'll run into an issue with Mooncake.jl at some point. In order to debug what is going on when this happens, or to produce an MWE, it is helpful to have a convenient way to run Mooncake.jl on whatever function and arguments you have which are causing problems.","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"We recommend making use of Mooncake.jl's testing functionality to generate your test cases:","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"Mooncake.TestUtils.test_rule","category":"page"},{"location":"utilities/debugging_and_mwes/#Mooncake.TestUtils.test_rule","page":"Debugging and MWEs","title":"Mooncake.TestUtils.test_rule","text":"test_rule(\n rng, x...;\n interface_only=false,\n is_primitive::Bool=true,\n perf_flag::Symbol=:none,\n interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),\n debug_mode::Bool=false,\n unsafe_perturb::Bool=false,\n)\n\nRun standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.\n\nThis function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.\n\nArguments\n\nrng::AbstractRNG: a random number generator\nx...: the function (first element) and its arguments (the remainder)\n\nKeyword Arguments\n\ninterface_only::Bool=false: test only that the interface is satisfied, without testing correctness. This should generally be set to false (the default value), and only enabled if the testing infrastructure is unable to test correctness for some reason e.g. the returned value of the function is a Ptr, and appropriate tangents cannot, therefore, be generated for it automatically.\nis_primitive::Bool=true: check whether the thing that you are testing has a hand-written rrule!!. This option is helpful if you are testing a new rrule!!, as it enables you to verify that your method of is_primitive has returned the correct value, and that you are actually testing a method of the rrule!! function – a common mistake when authoring a new rrule!! is to implement is_primitive incorrectly and to accidentally wind up testing a rule which Mooncake has derived, as opposed to the one that you have written. If you are testing something for which you have not hand-written an rrule!!, or which you do not care whether it has a hand-written rrule!! or not, you should set it to false.\nperf_flag::Symbol=:none: the value of this symbol determines what kind of performance tests should be performed. By default, none are performed. If you believe that a rule should be allocation-free (iff the primal is allocation free), set this to :allocs. If you hand-write an rrule!! and believe that your test case should be type stable, set this to :stability (at present we cannot verify whether a derived rule is type stable for technical reasons). If you believe that a hand-written rule should be both allocation-free and type-stable, set this to :stability_and_allocs.\ninterp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(): the abstract interpreter to be used when testing this rule. The default should generally be used.\ndebug_mode::Bool=false: whether or not the rule should be tested in debug mode. Typically this should be left at its default false value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to true in order to get access to additional information and automated testing.\nunsafe_perturb::Bool=false: value passed as the third argument to _add_to_primal. Should usually be left false – consult the docstring for _add_to_primal for more info on when you might wish to set it to true.\n\n\n\n\n\n","category":"function"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"This approach is convenient because it can","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"check whether AD runs at all,\ncheck whether AD produces the correct answers,\ncheck whether AD is performant, and\ncan be used without having to manually generate tangents.","category":"page"},{"location":"utilities/debugging_and_mwes/#Example","page":"Debugging and MWEs","title":"Example","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"DocTestSetup = quote\n using Random, Mooncake\nend","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"For example","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"f(x) = Core.bitcast(Float64, x)\nMooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.","category":"page"},{"location":"utilities/debugging_and_mwes/#Segfaults","page":"Debugging and MWEs","title":"Segfaults","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.jl's-Rule-System","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl's approach to AD is recursive. It has a single specification for what it means to differentiate a Julia callable, and basically two approaches to achieving this. This section of the documentation explains the former.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We take an iterative approach to this explanation, starting at a high-level and adding more depth as we go.","category":"page"},{"location":"understanding_mooncake/rule_system/#10,000-Foot-View","page":"Mooncake.jl's Rule System","title":"10,000 Foot View","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule r(f, x) for a function f(x) \"does reverse mode AD\", and executes in two phases, known as the forwards pass and the reverse pass. In the forwards pass a rule executes the original function, and does some additional book-keeping in preparation for the reverse pass. On the reverse pass it undoes the computation from the forwards pass, \"backpropagates\" the gradient w.r.t. the output of the original function by applying the adjoint of the derivative of the original function to it, and writes the results of this computation to the correct places.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A precise mathematical model for the original function is therefore entirely crucial to this discussion, as it is needed to understand what the adjoint of its derivative is.","category":"page"},{"location":"understanding_mooncake/rule_system/#A-Model-For-A-Julia-Function","page":"Mooncake.jl's Rule System","title":"A Model For A Julia Function","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Since Julia permits the in-place modification / mutation of many data structures, we cannot make a naive translation between a Julia function and a mathematical object. Rather, we will have to model the state of the arguments to a function both before and after execution. Moreover, since a function can allocate new memory as part of execution and return it to the calling scope, we must track that too.","category":"page"},{"location":"understanding_mooncake/rule_system/#Consider-Only-Externally-Visible-Effects-Of-Function-Evaluation","page":"Mooncake.jl's Rule System","title":"Consider Only Externally-Visible Effects Of Function Evaluation","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We wish to treat a given function as a black box – we care about what a function does, not how it does it – so we consider only the externally-visible results of executing it. There are two ways in which changes can be made externally visible.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Return Value","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"(This point hardly requires explanation, but for the sake of completeness we do so anyway.)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The most obvious way in which a result can be made visible outside of a function is via its return value. For example, letting bar(x) = sin(x), consider the function","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n y = bar(x)\n z = bar(y)\n return z\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The communication between the two invocations of bar happen via the value it returns.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Modification of arguments","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In contrast to the above, changes made by one function can be made available to another implicitly if it modifies the values of its arguments, even if it doesn't return anything. For example, consider:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function bar(x::Vector{Float64})\n x .*= 2\n return nothing\nend\n\nfunction foo(x::Vector{Float64})\n bar(x)\n bar(x)\n return x\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second call to bar in foo sees the changes made to x by the first call to bar, despite not being explicitly returned.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"No Global Mutable State","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"functions can in principle also communicate via global mutable state. We make the decision to not support this.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, we assume functions of the following form cannot be encountered:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction bar(x)\n a .+= x\n return nothing\nend\n\nfunction foo(x)\n bar(x)\n return a\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, a is modified by bar, the effect of which is visible to foo.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For a variety of reasons this is very awkward to handle well. Since it's largely considered poor practice anyway, we explicitly outlaw this mode of communication between functions. See Why Support Closures But Not Mutable Globals for more info.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note that this does not preclude the use of closed-over values or callable structs. For example, something like","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar(x)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"is perfectly fine.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Model","page":"Mooncake.jl's Rule System","title":"The Model","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"It is helpful to have a concrete example which uses both of the permissible methods to make results externally visible. To this end, consider the following function:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f(x::Vector{Float64}, y::Vector{Float64}, z::Vector{Float64}, s::Ref{Vector{Float64}})\n z .*= y .* x\n s[] = 2z\n return sum(z)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We draw your attention to a variety of features of this function:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"z is mutated,\ns is mutated to reference freshly allocated memory,\nthe value previously pointed to by s is unmodified, and\nwe allocate a new value and return it (albeit, it is probably allocated on the stack).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The model we adopt for any Julia function f is a function f mathcalX to mathcalX times mathcalA where mathcalX is the real finite Hilbert space associated to the arguments to f prior to execution, and mathcalA is the real finite Hilbert space associated to any newly allocated data during execution which is externally visible after execution – any newly allocated data which is not made visible is of no concern.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, mathcalX = RR^D times RR^D times RR^D times RR^S where D is the length of x / y / z, and S the length of s[] prior to running f. mathcalA = RR^D times RR, where the RR^D component corresponds to the freshly allocated memory that s references, and RR to the return value. Observe that we model Float64s as elements of RR, Vector{Float64}s as elements of RR^D (for some value of D), and Refs with whatever the model for their contents is. The keen-eyed reader will note that these choices abstract away several details which could conceivably be included in the model. In particular, Vector{Float64} is implemented via a memory buffer, a pointer to the start of this buffer, and an integer which indicates the length of this buffer – none of these details are exposed in the model.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, some of the memory allocated during execution is made externally visible by modifying one of the arguments, not just via the return value.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The argument to f is the arguments to f before execution, and the output is the 2-tuple comprising the same arguments after execution and the values associated to any newly allocated / created data. Crucially, observe that we distinguish between the state of the arguments before and after execution.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For our example, the exact form of f is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f((x y z s)) = ((x y x odot y s) (2 x odot y sum_d=1^D x odot y))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that f behaves a little like a transition operator, in the that the first element of the tuple returned is the updated state of the arguments.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This model is good enough for the vast majority of functions. Unfortunately it isn't sufficient to describe a function when arguments alias each other (e.g. consider the way in which this particular model is wrong if y aliases z). Fortunately this is only a problem in a small fraction of all cases of aliasing, so we defer discussion of this until later on.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider now how this approach can be used to model several additional Julia functions, and to obtain their derivatives and adjoints.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"sin(x::Float64)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"mathcalX = RR, mathcalA = RR, f(x) = (x sin(x)).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Thus the derivative is D f x (dotx) = (dotx cos(x) dotx), and its adjoint is D f x^ast (bary) = bary_x + bary_a cos(x), where bary = (bary_x bary_a).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that this result is slightly different to the last example we saw involving sin.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"AD With Mutable Data","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider again","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Our framework is able to accomodate this function, and has essentially the same solution as the last time we saw this example:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Non-Mutating Functions","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A very interesting class of functions are those which do not modify their arguments. These are interesting because they are common, and are all that many AD frameworks like ChainRules.jl / Zygote.jl support – by considering this class of functions, we highlight some key similarities between these distinct rule systems.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As always we can model these kinds of functions with a function f mathcalX to mathcalX times mathcalA, but we additionally have that f must have the form","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x varphi(x))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"for some function varphi mathcalX to mathcalA. The derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x (dotx) = (dotx D varphi x(dotx))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the usual inner product to derive the adjoint:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"beginalign\n langle bary D f x (dotx) rangle = langle (bary_1 bary_2) (dotx D varphi x(dotx)) rangle nonumber \n = langle bary_1 dotx rangle + langle bary_2 D varphi x(dotx) rangle nonumber \n = langle bary_1 dotx rangle + langle D varphi x^ast (bary_2) dotx rangle nonumber quad text(by definition of the adjoint) \n = langle bary_1 + D varphi x^ast (bary_2) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"So the adjoint of the derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast (bary) = bary_1 + D varphi x^ast (bary_2)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We see the correct thing to do is to increment the gradient of the output – bary_1 – by the result of applying the adjoint of the derivative of varphi to bary_2. In a ChainRules.rrule the bary_1 term is always zero, but the D varphi x^ast (bary_2) term is essentially the same.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Rule-Interface-(Round-1)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 1)","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Having explained in principle what it is that a rule must do, we now take a first look at the interface we use to achieve this. A rule for a function foo with signature","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{typeof(foo), Float64} -> Float64","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"must have signature","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{Trule, CoDual{typeof(foo), NoFData}, CoDual{Float64, NoFData}} ->\n Tuple{CoDual{Float64, NoFData}, Trvs_pass}","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, if we call foo(5.0), it rules would be called as rule(CoDual(foo, NoFData()), CoDual(5.0, NoFData())). The precise definition and role of NoFData will be explained shortly, but the general scheme is that to a rule for foo you must pass foo itself, its arguments, and some additional data for book-keeping. foo and each of its arguments are paired with this additional book-keeping data via the CoDual type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The rule returns another CoDual (it propagates book-keeping information forwards), along with a function which runs the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In a little more depth:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Notation: primal","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Throughout the rest of this document, we will refer to the function being differentiated as the \"primal\" computation, and its arguments as the \"primal\" arguments.","category":"page"},{"location":"understanding_mooncake/rule_system/#Forwards-Pass","page":"Mooncake.jl's Rule System","title":"Forwards Pass","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each piece of each input to the primal is paired with shadow data, if it has a fixed address. For example, a Vector{Float64} argument is paired with another Vector{Float64}. The adjoint of f is accumulated into this shadow vector on the reverse pass. However, a Float64 argument gets paired with NoFData(), since it is a bits type and therefore has no fixed address.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must return a Tuple of two things. The first thing must be a CoDual containing the output of the primal computation and its shadow memory (if it has any). The second must be a function which runs the reverse pass of AD – this will usually be a closure of some kind.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"ensure that the state of the primal components of all inputs / the output are as they would have been had the primal computation been run (up to differences due to finite precision arithmetic),\npropagate / construct the shadow memory associated to the output (initialised to zero), and\nconstruct the function to run the reverse pass – typically this will involve storing some quantities computed during the forwards pass.","category":"page"},{"location":"understanding_mooncake/rule_system/#Reverse-Pass","page":"Mooncake.jl's Rule System","title":"Reverse Pass","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second element of the output of a rule is a function which runs the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the output of the primal.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the inputs of the primal.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"undo changes made to primal state on the forwards pass.\napply adjoint of derivative of primal operation, putting the results in the correct place.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This description should leave you with (at least) a couple of questions. What is \"rdata\", and what is \"the correct place\" to put the results of applying the adjoint of the derivative? In order to address these, we need to discuss the types that Mooncake.jl uses to represent the results of AD, and to propagate results backwards on the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/#Representing-Gradients","page":"Mooncake.jl's Rule System","title":"Representing Gradients","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We refer to both inputs and outputs of derivatives D f x mathcalX to mathcalY as tangents, e.g. dotx or doty. Conversely, we refer to both inputs and outputs to the adjoint of this derivative D f x^ast mathcalY to mathcalX as gradients, e.g. bary and barx.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note, however, that the sets involved are the same whether dealing with a derivative or its adjoint. Consequently, we use the same type to represent both.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Representing Gradients","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This package assigns to each type in Julia a unique tangent_type, the purpose of which is to contain the gradients computed during reverse mode AD. The extended docstring for tangent_type provides the best introduction to the types which are used to represent tangents / gradients.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"tangent_type(P)","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.tangent_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.tangent_type","text":"tangent_type(P)\n\nThere must be a single type used to represents tangents of primals of type P, and it must be given by tangent_type(P).\n\nExtended help\n\nThe tangent types which Mooncake.jl uses are quite similar in spirit to ChainRules.jl. For example, tangent \"vectors\" for\n\nFloat64s are Float64s,\nVector{Float64}s are Vector{Float64}s, and\nstructs are other another (special) struct with field types specified recursively.\n\nThere are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the uniquely permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl – it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.\n\nMooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).\n\nConsider some more worked examples.\n\nInt\n\nInt is not a differentiable type, so its tangent type is NoTangent:\n\njulia> tangent_type(Int)\nNoTangent\n\nTuples\n\nThe tangent type of a Tuple is defined recursively based on its field types. For example\n\njulia> tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\nThere is one edge case to be aware of: if all of the field of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,\n\njulia> tangent_type(Tuple{Int, Int})\nNoTangent\n\nStructs\n\nAs with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.\n\nAs with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.\n\nThere are a couple of additional subtleties to consider over Tuples though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. The tangent associated to any field which might possibly not be defined is wrapped in a PossiblyUninitTangent.\n\nFurthermore, structs can have fields whose static type is abstract. For example\n\njulia> struct Foo\n x\n end\n\nIf you ask for the tangent type of Foo, you will see that it is\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x}}\n\nObserve that the field type associated to x is Any. The way to understand this result is to observe that\n\nx could have literally any type at runtime, so we know nothing about what its tangent type must be until runtime, and\nwe require that the tangent type of Foo be unique.\n\nThe consequence of these two considerations is that the tangent type of Foo must be able to contain any type of tangent in its x field. It follows that the fieldtype of the x field of Foos tangent must be Any.\n\nMutable Structs\n\nThe tangent type for mutable structs have the same set of considerations as structs. The only difference is that they must themselves be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.\n\nFor example, if you ask for the tangent_type of\n\njulia> mutable struct Bar\n x::Float64\n end\n\nyou will find that it is\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64}}\n\nPrimitive Types\n\nWe've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.\n\nOne interesting case are Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example\n\njulia> tangent_type(Ptr{Float64})\nPtr{Float64}\n\n\n\n\n\n","category":"method"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"FData and RData","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this \"unique gradient address\" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The following docstring provides the best in-depth explanation.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.fdata_type(T)","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.fdata_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.fdata_type","text":"fdata_type(T)\n\nReturns the type of the forwards data associated to a tangent of type T.\n\nExtended help\n\nRules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.\n\nGiven a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.\n\nGiven a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that\n\ntangent(fdata(t), rdata(t)) === t\n\nThe need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.\n\nInt\n\nInts are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore\n\njulia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))\n(NoFData, NoRData)\n\nFloat64\n\nThe tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so\n\njulia> (fdata_type(Float64), rdata_type(Float64))\n(NoFData, Float64)\n\nVector{Float64}\n\nThe tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so\n\njulia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))\n(Vector{Float64}, NoRData)\n\nTuple{Float64, Vector{Float64}, Int}\n\nThis is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:\n\njulia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\njulia> (fdata_type(T), rdata_type(T))\n(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})\n\nThe zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.\n\nStructs\n\nStructs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as\n\njulia> struct Foo\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nIts fdata and rdata are given by special FData and RData types:\n\njulia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))\n(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})\n\nPractically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.\n\nMutable Structs\n\nThe fdata for a mutable structs is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,\n\njulia> mutable struct Bar\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nand fdata / rdata types\n\njulia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))\n(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)\n\nPrimitive Types\n\nAs with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.\n\n\n\n\n\n","category":"method"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A quick aside: Non-Differentiable Data","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a 0 tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Rule-Interface-(Round-2)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 2)","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = quote\n using Mooncake\n using Mooncake: CoDual\n import Mooncake: rrule!!\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the function","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])\nfoo (generic function with 1 method)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Under the framework introduced above, the model for this function is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x x_1 + sum_n=1^N (x_2)_n)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where the vector in the second element of x is of length N. Now, following our usual steps, the derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x(dotx) = (dotx dotx_1 + sum_n=1^N (dotx_2)_n)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A gradient for this is a tuple (bary_x bary_a) where bary_a in RR and bary_x in RR times RR^N. A quick derivation will show that the adjoint is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast(bary) = ((bary_x)_1 + bary_a (bary_x)_2 + bary_a mathbf1)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where mathbf1 is the vector of length N in which each element is equal to 1. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})\n dx_fdata = x.dx\n function dfoo_adjoint(dy::Float64)\n dx_fdata[2] .+= dy\n dx_1_rdata = dy\n dx_rdata = (dx_1_rdata, NoRData())\n return NoRData(), dx_rdata\n end\n x_p = x.x\n return CoDual(x_p[1] + sum(x_p[2]), NoFData()), dfoo_adjoint\n end;\n","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where dy is the rdata for the output to foo. The rrule!! can be called with the appropriate CoDuals:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> out, pb!! = rrule!!(CoDual(foo, NoFData()), CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])))\n(CoDual{Float64, NoFData}(8.0, NoFData()), var\"#dfoo_adjoint#1\"{Tuple{NoFData, Vector{Float64}}}((NoFData(), [0.0, 0.0])))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"and the pullback with appropriate rdata:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> pb!!(1.0)\n(NoRData(), (1.0, NoRData()))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the forwards pass:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"computes the result of the initial function, and\npulls out the fdata for the Vector{Float64} component of the argument.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As promised, the forwards pass really has nothing to do with the adjoint. It's just book-keeping and running the primal computation.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The reverse pass:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"increments each element of dx_fdata[2] by dy – this corresponds to (bary_x)_2 + bary_a mathbf1 in the adjoint,\nsets dx_1_rdata to dy – this corresponds (bary_x)_1 + bary_a subject to the constraint that (bary_x)_1 = 0,\nconstructs the rdata for x – this is essentially just book-keeping.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each of these items serve to demonstrate more general points. The first that, upon entry into the reverse pass, all fdata values correspond to gradients for the arguments / output of f \"upon exit\" (for the components of these which are identified by their address), and once the reverse-pass finishes running, they must contain the gradients w.r.t. the arguments of f \"upon entry\".","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second that we always assume that the components of bary_x which are identified by their value have zero-rdata.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The third is that the components of the arguments of f which are identified by their value must have rdata passed back explicitly by a rule, while the components of the arguments to f which are identified by their address get their gradients propagated back implicitly (i.e. via the in-place modification of fdata).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Reminder: the first element of the tuple returned by dfoo_adjoint is the rdata associated to foo itself, hence it is NoRData.","category":"page"},{"location":"understanding_mooncake/rule_system/#Testing","page":"Mooncake.jl's Rule System","title":"Testing","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl has an almost entirely automated system for testing rules – Mooncake.TestUtils.test_rule. You should absolutely make use of these when writing rules.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: improve docstring for testing functionality.","category":"page"},{"location":"understanding_mooncake/rule_system/#Summary","page":"Mooncake.jl's Rule System","title":"Summary","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this section we have covered the rule system. Every callable object / function in the Julia language is differentiated using rules with this interface, whether they be hand-written rrule!!s, or rules derived by Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"At this point you should be equipped with enough information to understand what a rule in Mooncake.jl does, and how you can write your own ones. Later sections will explain how Mooncake.jl goes about deriving rules itself in a recursive manner, and introduce you to some of the internals.","category":"page"},{"location":"understanding_mooncake/rule_system/#Asides","page":"Mooncake.jl's Rule System","title":"Asides","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/#Why-Uniqueness-of-Type-For-Tangents-/-FData-/-RData?","page":"Mooncake.jl's Rule System","title":"Why Uniqueness of Type For Tangents / FData / RData?","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Why does Mooncake.jl insist that each primal type P be paired with a single tangent type T, as opposed to being more permissive. There are a few notable reasons:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"To provide a precise interface. Rules pass fdata around on the forwards pass and rdata on the reverse pass – being able to make strong assumptions about the type of the fdata / rdata given the primal type makes implementing rules much easier in practice.\nConditional type stability. We wish to have a high degree of confidence that if the primal code is type-stable, then the AD code will also be. It is straightforward to construct type stable primal codes which have type-unstable forwards and reverse passes if you permit there to be more than one fdata / rdata type for a given primal. So while uniqueness is certainly not sufficient on its own to guarantee conditional type stability, it is probably necessary in general.\nTest-case generation and coverage. There being a unique tangent / fdata / rdata type for each primal makes being confident that a given rule is being tested thoroughly much easier. For a given primal, rather than there being many possible input / output types to consider, there is just one.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This topic, in particular what goes wrong with permissive tangent type systems like those employed by ChainRules, deserves a more thorough treatment – hopefully someone will write something more expansive on this topic at some point.","category":"page"},{"location":"understanding_mooncake/rule_system/#Why-Support-Closures-But-Not-Mutable-Globals","page":"Mooncake.jl's Rule System","title":"Why Support Closures But Not Mutable Globals","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"First consider why closures are straightforward to support. Look at the type of the closure produced by foo:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar\nend\nbar = foo(randn(5))\ntypeof(bar)\n\n# output\nvar\"#bar#1\"{Vector{Float64}}","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the Vector{Float64} that we passed to foo, and closed over in bar, is present in the type. This alludes to the fact that closures are basically just callable structs whose fields are the closed-over variables. Since the function itself is an argument to its rule, everything enters the rule for bar via its arguments, and the rule system developed in this document applies straightforwardly.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"On the other hand, globals do not appear in the functions that they are a part of. For example,","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction g(x)\n a .+= x\n return nothing\nend\n\ntypeof(g)\n\n# output\ntypeof(g) (singleton type of function g, subtype of Function)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.","category":"page"},{"location":"utilities/tools_for_rules/#Tools-for-Rules","page":"Tools for Rules","title":"Tools for Rules","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.","category":"page"},{"location":"utilities/tools_for_rules/#Simplfiying-Code-via-Overlays","page":"Tools for Rules","title":"Simplfiying Code via Overlays","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@mooncake_overlay","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@mooncake_overlay","page":"Tools for Rules","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"utilities/tools_for_rules/#Functions-with-Zero-Adjoint","page":"Tools for Rules","title":"Functions with Zero Adjoint","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@zero_adjoint\nMooncake.zero_adjoint","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@zero_adjoint","page":"Tools for Rules","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) = 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"utilities/tools_for_rules/#Mooncake.zero_adjoint","page":"Tools for Rules","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```\n\n\n\n\n\n","category":"function"},{"location":"utilities/tools_for_rules/#Using-ChainRules.jl","page":"Tools for Rules","title":"Using ChainRules.jl","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"There is enough similarity between these two systems that most of the boilerplate code can be avoided.","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@from_rrule","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@from_rrule","page":"Tools for Rules","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.\nhas_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"utilities/debug_mode/#Debug-Mode","page":"Debug Mode","title":"Debug Mode","text":"","category":"section"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = quote\n using Mooncake, ADTypes\nend","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Problem","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})\n plus_reverse_pass(dz::Real) = NoRData(), dz, dz\n return zero_fcodual(primal(x) + primal(y))\nend","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"and calls","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Solution","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"This is implemented via DebugRRule:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.DebugRRule","category":"page"},{"location":"utilities/debug_mode/#Mooncake.DebugRRule","page":"Debug Mode","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).\n\nLet rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.build_rrule","category":"page"},{"location":"utilities/debug_mode/#Mooncake.build_rrule","page":"Debug Mode","title":"Mooncake.build_rrule","text":"build_rrule(args...; debug_mode=false)\n\nHelper method. Only uses static information from args.\n\n\n\n\n\nbuild_rrule(sig::Type{<:Tuple})\n\nEquivalent to build_rrule(Mooncake.get_interpreter(), sig).\n\n\n\n\n\nbuild_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}\n\nReturns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.\n\nIf debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.\n\n\n\n\n\n","category":"function"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))\nAutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))","category":"page"},{"location":"utilities/debug_mode/#When-Should-You-Use-Debug-Mode?","page":"Debug Mode","title":"When Should You Use Debug Mode?","text":"","category":"section"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Only use debug_mode when debugging a problem. This is because is has substantial performance implications.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/introduction/#Introduction","page":"Introduction","title":"Introduction","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"we recap what AD is, and introduce the mathematics necessary to understand is,\nexplain how this mathematics relates to functions and data structures in Julia, and\nhow this is handled in Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/introduction/#Who-Are-These-Docs-For?","page":"Introduction","title":"Who Are These Docs For?","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.","category":"page"},{"location":"understanding_mooncake/introduction/#Prerequisites-and-Resources","page":"Introduction","title":"Prerequisites and Resources","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and \"transposed Jacobian\". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the \"Matrix Calculus for Machine Learning and Beyond\" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.","category":"page"},{"location":"#Mooncake.jl","page":"Mooncake.jl","title":"Mooncake.jl","text":"","category":"section"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Documentation for Mooncake.jl is on its way!","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the \"Understanding Mooncake.jl\" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!","category":"page"}] +[{"location":"known_limitations/#Known-Limitations","page":"Known Limitations","title":"Known Limitations","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Mooncake.jl has a number of known qualitative limitations, which we document here.","category":"page"},{"location":"known_limitations/#Mutation-of-Global-Variables","page":"Known Limitations","title":"Mutation of Global Variables","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = quote\n using Mooncake\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While great care is taken in this package to prevent silent errors, this is one edge case that we have yet to provide a satisfactory solution for. Consider a function of the form:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> const x = Ref(1.0);\n\njulia> function foo(y::Float64)\n x[] = y\n return x[]\n end\nfoo (generic function with 1 method)","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"x is a global variable (if you refer to it in your code, it appears as a GlobalRef in the AST or lowered code). For some technical reasons that are beyond the scope of this section, this package cannot propagate gradient information through x. foo is the identity function, so it should have gradient 1.0. However, if you differentiate this example, you'll see:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> rule = Mooncake.build_rrule(foo, 2.0);\n\njulia> Mooncake.value_and_gradient!!(rule, foo, 2.0)\n(2.0, (NoTangent(), 0.0))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Observe that while it has correctly computed the identity function, the gradient is zero.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The takehome: do not attempt to differentiate functions which modify global state. Uses of globals which does not involve mutating them is fine though.","category":"page"},{"location":"known_limitations/#Circular-References","page":"Known Limitations","title":"Circular References","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"To a large extent, Mooncake.jl does not presently support circular references in an automatic fashion. It is generally possible to hand-write solutions, so we explain some of the problems here, and the general approach to resolving them.","category":"page"},{"location":"known_limitations/#Tangent-Types","page":"Known Limitations","title":"Tangent Types","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Suppose that you have a type such as:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct A\n x::Float64\n a::A\n function A(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is a fairly canonical example of a self-referential type. There are a couple of things which will not work with it out-of-the-box. tangent_type(A) will produce a stack overflow error. To see this, note that it will in effect try to produce a tangent of type Tangent{Tuple{tangent_type(A)}} – the circular dependency on the tangent_type function causes real problems here.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In order to resolve this, you need to produce a tangent type by hand. You might go with something like","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct TangentForA\n x::Float64 # tangent type for Float64 is Float64\n a::TangentForA\n function TangentForA(x::Float64)\n a = new(x)\n a.a = a\n return a\n end\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The point here is that you can manually resolve the circular dependency using a data structure which mimics the primal type. You will, however, need to implement similar methods for zero_tangent, randn_tangent, etc, and presumably need to implement additional getfield and setfield rules which are specific to this type.","category":"page"},{"location":"known_limitations/#Circular-References-in-General","page":"Known Limitations","title":"Circular References in General","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Consider a type of the form","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"mutable struct Foo\n x\n Foo() = new()\nend","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In this instance, tangent_type will work fine because Foo does not directly reference itself in its definition. Moreover, general uses of Foo will be fine.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"However, it's possible to construct an instance of Foo with a circular reference:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"f = Foo()\nf.x = f","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is actually fine provided we never attempt to call zero_tangent / randn_tangent / similar functionality on f once we've set its x field to itself. If we attempt to call such a function, we'll find ourselves with a stack overflow.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution This is a little tricker to handle. You could specialise zero_tangent etc for Foo, but this is something of a pain. Fortunately, it seems to be incredibly rare that this is ever a problem in practice. If we gain evidence that this is often a problem in practice, we'll look into supporting zero_tangent etc automatically for this case.","category":"page"},{"location":"known_limitations/#Tangent-Generation-and-Pointers","page":"Known Limitations","title":"Tangent Generation and Pointers","text":"","category":"section"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Problem","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"In many use cases, a pointer provides the address of the start of a block of memory which has been allocated to e.g. store an array. However, we cannot get any of this context from the pointer itself – by just looking at a pointer, I cannot know whether its purpose is to refer to the start of a large block of memory, some proportion of the way through a block of memory, or even to keep track of a single address.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Recall that the tangent to a pointer is another pointer:","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"julia> Mooncake.tangent_type(Ptr{Float64})\nPtr{Float64}","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Plainly I cannot implement a method of zero_tangent for Ptr{Float64} because I don't know how much memory to allocate.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is, however, fine if a pointer appears half way through a function, having been derived from another data structure. e.g.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"function foo(x::Vector{Float64})\n p = pointer(x, 2)\n return unsafe_load(p)\nend\n\nrule = build_rrule(Tuple{typeof(foo), Vector{Float64}})\nMooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])\n\n# output\n(4.0, (NoTangent(), [0.0, 1.0]))","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"The Solution","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"This is only really a problem for tangent / fdata / rdata generation functionality, such as zero_tangent. As a work-around, AD testing functionality permits users to pass in CoDuals. So if you are testing something involving a pointer, you will need to construct its tangent yourself, and pass a CoDual to e.g. Mooncake.TestUtils.test_rule.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"While pointers tend to be a low-level implementation detail in Julia code, you could in principle actually be interested in differentiating a function of a pointer. In this case, you will not be able to use Mooncake.value_and_gradient!! as this requires the use of zero_tangent. Instead, you will need to use lower-level (internal) functionality, such as Mooncake.__value_and_gradient!!, or use the rule interface directly.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"Honestly, your best bet is just to avoid differentiating functions whose arguments are pointers if you can.","category":"page"},{"location":"known_limitations/","page":"Known Limitations","title":"Known Limitations","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Algorithmic-Differentiation","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This section introduces the mathematics behind AD. Even if you have worked with AD before, we recommend reading in order to acclimatise yourself to the perspective that Mooncake.jl takes on the subject.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Derivatives","page":"Algorithmic Differentiation","title":"Derivatives","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"A foundation on which all of AD is built the the derivate – we require a fairly general definition of it, which we build up to here.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Scalar-to-Scalar Functions","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider first f RR to RR, which we require to be differentiable at x in RR. Its derivative at x is usually thought of as the scalar alpha in RR such that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = alpha textdx ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Loosely speaking, by this notation we mean that for arbitrary small changes textd x in the input to f, the change in the output textd f is alpha textdx. We refer readers to the first few minutes of the first lecture mentioned before for a more careful explanation.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Vector-to-Vector Functions","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The generalisation of this to Euclidean space should be familiar: if f RR^P to RR^Q is differentiable at a point x in RR^P, then the derivative of f at x is given by the Jacobian matrix at x, denoted Jx in RR^Q times P, such that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = Jx textdx ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is possible to stop here, as all the functions we shall need to consider can in principle be written as functions on some subset RR^P.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"However, when we consider differentiating computer programmes, we will have to deal with complicated nested data structures, e.g. structs inside Tuples inside Vectors etc. While all of these data structures can be mapped onto a flat vector in order to make sense of the Jacobian of a computer programme, this becomes very inconvenient very quickly. To see the problem, consider the Julia function whose input is of type Tuple{Tuple{Float64, Vector{Float64}}, Vector{Float64}, Float64} and whose output is of type Tuple{Vector{Float64}, Float64}. What kind of object might be use to represent the derivative of a function mapping between these two spaces? We certainly can treat these as structured \"view\" into a \"flat\" Vector{Float64}s, and then define a Jacobian, but actually finding this mapping is a tedious exercise, even if it quite obviously exists.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In fact, a more general formulation of the derivative is used all the time in the context of AD – the matrix calculus discussed by [1] and [2] (to name a couple) make use of a generalised form of the derivative in order to work with functions which map to and from matrices (albeit there are slight differences in naming conventions from text to text), without needing to \"flatten\" them into vectors in order to make sense of them.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In general, it will be much easier to avoid \"flattening\" operations wherever possible. In order to do so, we now introduce a generalised notion of the derivative.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Functions Between More General Spaces","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to avoid the difficulties described above, we consider functions f mathcalX to mathcalY, where mathcalX and mathcalY are finite dimensional real Hilbert spaces (read: finite-dimensional vector space with an inner product, and real-valued scalars). This definition includes functions to / from RR, RR^D, but also real-valued matrices, and any other \"container\" for collections of real numbers. Furthermore, we shall see later how we can model all sorts of structured representations of data directly as such spaces.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For such spaces, the derivative of f at x in mathcalX is the linear operator (read: linear function) D f x mathcalX to mathcalY satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"textdf = D f x (textd x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The purpose of this linear operator is to provide a linear approximation to f which is accurate for arguments which are very close to x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Please note that D f x is a single mathematical object, despite the fact that 3 separate symbols are used to denote it – D f x (dotx) denotes the application of the function D f x to argument dotx. Furthermore, the dot-notation (dotx) does not have anything to do with time-derivatives, it is simply common notation used in the AD literature to denote the arguments of derivatives.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So, instead of thinking of the derivative as a number or a matrix, we think about it as a function. We can express the previous notions of the derivative in this language.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the scalar case, rather than thinking of the derivative as being alpha, we think of it is a the linear operator D f x (dotx) = alpha dotx. Put differently, rather than thinking of the derivative as the slope of the tangent to f at x, think of it as the function decribing the tangent itself. Observe that up until now we had only considered inputs to D f x which were small (textd x) – here we extend it to the entire space mathcalX and denote inputs in this space dotx. Inputs dotx should be thoughts of as \"directions\", in the directional derivative sense (why this is true will be discussed later).","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Similarly, if mathcalX = RR^P and mathcalY = RR^Q then this operator can be specified in terms of the Jacobian matrix: D f x (dotx) = Jx dotx – brackets are used to emphasise that D f x is a function, and is being applied to dotx.[note_for_geometers]","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To reiterate, for the rest of this document, we define the derivative to be \"multiply by alpha\" or \"multiply by Jx\", rather than to be alpha or Jx. So whenever you see the word \"derivative\", you should think \"linear function\".","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The Chain Rule","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The chain rule is the result which makes AD work. Fortunately, it applies to this version of the derivative:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = g circ h implies D f x = (D g h(x)) circ (D h x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"By induction, this extends to a collection of N functions f_1 dots f_N:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f = f_N circ dots circ f_1 implies D f x = (D f_N x_N) circ dots circ (D f_1 x_1)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where x_n+1 = f(x_n), and x_1 = x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"An aside: the definition of the Frechet Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This definition of the derivative has a name: the Frechet derivative. It is a generalisation of the Total Derivative. Formally, we say that a function f mathcalX to mathcalY is differentiable at a point x in mathcalX if there exists a linear operator D f x mathcalX to mathcalY (the derivative) satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"lim_textd h to 0 frac f(x + textd h) - f(x) + D f x (textd h) _mathcalY textdh _mathcalX = 0","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where cdot _mathcalX and cdot _mathcalY are the norms associated to Hilbert spaces mathcalX and mathcalY respectively. It is a good idea to consider what this looks like when mathcalX = mathcalY = RR and when mathcalX = mathcalY = RR^D. It is sometimes helpful to refer to this definition to e.g. verify the correctness of the derivative of a function – as with single-variable calculus, however, this is rare.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: what does Forwards-Mode AD compute?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"At this point we have enough machinery to discuss forwards-mode AD. Expressed in the language of linear operators and Hilbert spaces, the goal of forwards-mode AD is the following: given a function f which is differentiable at a point x, compute D f x (dotx) for a given vector dotx. If f RR^P to RR^Q, this is equivalent to computing Jx dotx, where Jx is the Jacobian of f at x. For the interested reader we provide a high-level explanation of how forwards-mode AD does this in How does Forwards-Mode AD work?.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Another aside: notation","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You may have noticed that we typically denote the argument to a derivative with a \"dot\" over it, e.g. dotx. This is something that we will do consistently, and we will use the same notation for the outputs of derivatives. Wherever you see a symbol with a \"dot\" over it, expect it to be an input or output of a derivative / forwards-mode AD.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Reverse-Mode-AD:-*what*-does-it-do?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: what does it do?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to explain what reverse-mode AD does, we first consider the \"vector-Jacobian product\" definition in Euclidean space which will be familiar to many readers. We then generalise.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in Euclidean space?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In this setting, the goal of reverse-mode AD is the following: given a function f RR^P to RR^Q which is differentiable at x in RR^P with Jacobian Jx at x, compute Jx^top bary for any bary in RR^Q. This is useful because we can obtain the gradient from this when Q = 1 by letting bary = 1.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Adjoint Operators","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to generalise this algorithm to work with linear operators, we must first generalise the idea of multiplying a vector by the transpose of the Jacobian. The relevant concept here is that of the adjoint operator. Specifically, the adjoint A^ast of linear operator A is the linear operator satisfying","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle A^ast bary dotx rangle = langle bary A dotx rangle","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where langle cdot cdot rangle denotes the inner-product. The relationship between the adjoint and matrix transpose is: if A (x) = J x for some matrix J, then A^ast (y) = J^top y.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Moreover, just as (A B)^top = B^top A^top when A and B are matrices, (A B)^ast = B^ast A^ast when A and B are linear operators. This result follows in short order from the definition of the adjoint operator – (and is a good exercise!)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Mode AD: what does it do in general?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Equipped with adjoints, we can express reverse-mode AD only in terms of linear operators, dispensing with the need to express everything in terms of Jacobians. The goal of reverse-mode AD is as follows: given a differentiable function f mathcalX to mathcalY, compute D f x^ast (bary) for some bary.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Notation: D f x^ast denotes the single mathematical object which is the adjoint of D f x. It is a linear function from mathcalY to mathcalX. We may occassionally write it as (D f x)^ast if there is some risk of confusion.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We will explain how reverse-mode AD goes about computing this after some worked examples.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Aside: Notation","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"You will have noticed that arguments to adjoints have thus far always had a \"bar\" over them, e.g. bary. This notation is common in the AD literature and will be used throughout. Additionally, this \"bar\" notation will be used for the outputs of adjoints of derivatives. So wherever you see a symbol with a \"bar\" over it, think \"input or output of adjoint of derivative\".","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Some-Worked-Examples","page":"Algorithmic Differentiation","title":"Some Worked Examples","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now present some worked examples in order to prime intuition, and to introduce the important classes of problems that will be encountered when doing AD in the Julia language. We will put all of these problems in a single general framework later on.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#An-Example-with-Matrix-Calculus","page":"Algorithmic Differentiation","title":"An Example with Matrix Calculus","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We have introduced some mathematical abstraction in order to simplify the calculations involved in AD. To this end, we consider differentiating f(X) = X^top X. Results for this and similar operations are given by [1]. A similar operation, but which maps from matrices to RR is discussed in Lecture 4 part 2 of the MIT course mentioned previouly. Both [1] and Lecture 4 part 2 provide approaches to obtaining the derivative of this function.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Following either resource will yield the derivative:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X (dotX) = dotX^top X + X^top dotX","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that this is indeed a linear operator (i.e. it is linear in its argument, dotX). (You can always plug it in to the definition of the Frechet derivative to confirm that it is indeed the derivative.)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In order to perform reverse-mode AD, we need to find the adjoint operator. Using the usual definition of the inner product between matrices,","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle X Y rangle = textrmtr (X^top Y)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"we can rearrange the inner product as follows:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle barY D f X (dotX) rangle = langle barY dotX^top X + X^top dotX rangle nonumber \n = textrmtr (barY^top dotX^top X) + textrmtr(barY^top X^top dotX) nonumber \n = textrmtr ( barY X^top^top dotX) + textrmtr( X barY^top dotX) nonumber \n = langle barY X^top + X barY dotX rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can read off the adjoint operator from the first argument to the inner product:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f X^ast (barY) = barY X^top + X barY","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-of-a-Julia-function:-a-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a trivial example","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We now turn to differentiating Julia functions (we use function to refer to the programming language construct, and function to refer to a more general mathematical concept). The way that Mooncake.jl handles immutable data is very similar to how Zygote / ChainRules do. For example, consider the Julia function","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64) = sin(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"If you've previously worked with ChainRules / Zygote, without thinking too hard about the formalisms we introduced previously (perhaps by considering a variety of partial derivatives) you can probably arrive at the following adjoint for the derivative of f:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> g * cos(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Implicitly, you have performed three steps:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"model f as a differentiable function,\ncompute its derivative, and\ncompute the adjoint of the derivative.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"It is helpful to work through this simple example in detail, as the steps involved apply more generally. The goal is to spell out the steps involved in detail, as this detail becomes helpful in more complicated examples. If at any point this exercise feels pedantic, we ask you to stick with it.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Obviously, we model the Julia function f as the function f RR to RR where","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x) = sin(x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Observe that, we've made (at least) two modelling assumptions here:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"a Float64 is modelled as a real number,\nthe Julia function sin is modelled as the usual mathematical function sin.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As promised we're being quite pedantic. While the first assumption is obvious and will remain true, we will shortly see examples where we have to work a bit harder to obtain a correspondence between a Julia function and a mathematical object.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical model, we can differentiate it:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = cos(x) dotx","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Given the derivative, we can find its adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"langle barf D f x(dotx) rangle = langle barf cos(x) dotx rangle = langle cos(x) barf dotx rangle","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"From here the adjoint can be read off from the first argument to the inner product:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (barf) = cos(x) barf","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-of-a-Julia-function:-a-slightly-less-trivial-example","page":"Algorithmic Differentiation","title":"AD of a Julia function: a slightly less trivial example","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now consider the Julia function","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x::Float64, y::Tuple{Float64, Float64}) = x + y[1] * y[2]","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Its adjoint is going to be something along the lines of","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"g -> (g, (y[2] * g, y[1] * g))","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"As before, we work through in detail.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"There are a couple of aspects of f which require thought:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"it has two arguments – we've only handled single argument functions previously, and\nthe second argument is a Tuple – we've not yet decided how to model this.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"To this end, we define a mathematical notion of a tuple. A tuple is a collection of N elements, each of which is drawn from some set mathcalX_n. We denote by mathcalX = mathcalX_1 times dots times mathcalX_N the set of all N-tuples whose nth element is drawn from mathcalX_n. Provided that each mathcalX_n forms a finite Hilbert space, mathcalX forms a Hilbert space with","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"alpha x = (alpha x_1 dots alpha x_N),\nx + y = (x_1 + y_1 dots x_N + y_N), and\nlangle x y rangle = sum_n=1^N langle x_n y_n rangle.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"We can think of multi-argument functions as single-argument functions of a tuple, so a reasonable mathematical model for f might be a function f RR times RR times RR to RR, where","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"f(x y) = x + y_1 y_2","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Note that while the function is written with two arguments, you should treat them as a single tuple, where we've assigned the name x to the first element, and y to the second.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we have a mathematical object, we can differentiate it:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y(dotx doty) = dotx + doty_1 y_2 + y_1 doty_2","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx y maps RR times RR times RR to RR, so D f x y^ast must map the other way. You should verify that the following follows quickly from the definition of the adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x y^ast (barf) = (barf (barf y_2 barf y_1))","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#AD-with-mutable-data","page":"Algorithmic Differentiation","title":"AD with mutable data","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In the previous two examples there was an obvious mathematical model for the Julia function. Indeed this model was sufficiently obvious that it required little explanation. This is not always the case though, in particular, Julia functions which modify / mutate their inputs require a little more thought.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Consider the following Julia function:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This function squares each element of its input in-place, and returns the sum of the result. So what is an appropriate mathematical model for this function?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 1: Differentiable Mathematical Model","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The trick is to distinguish between the state of x upon entry to / exit from f!. In particular, let phi_textf RR^N to RR^N times RR be given by","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"phi_textf(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"where odot denotes the Hadamard / element-wise product (corresponds to line x .*= x in the above code). The point here is that the inputs to phi_textf are the inputs to x upon entry to f!, and the value returned from phi_textf is a tuple containing the both the inputs upon exit from f! and the value returned by f!.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The remaining steps are straightforward now that we have the model.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 2: Compute Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative of phi_textf is","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x(dotx) = (2 x odot dotx 2 sum_n=1^N x_n dotx_n)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Step 3: Compute Adjoint of Derivative","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The argument to the adjoint of the derivative must be a 2-tuple whose elements are drawn from RR^N times RR . Denote such a tuple as (bary_1 bary_2). Plugging this into an inner product with the derivative and rearranging yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\n langle (bary_1 bary_2) D phi_textf x (dotx) rangle = langle (bary_1 bary_2) (2 x odot dotx 2 sum_n=1^N x_n dotx_n) rangle nonumber \n = langle bary_1 2 x odot dotx rangle + langle bary_2 2 sum_n=1^N x_n dotx_n rangle nonumber \n = langle 2x odot bary_1 dotx rangle + langle 2 bary_2 x dotx rangle nonumber \n = langle 2 (x odot bary_1 + bary_2 x) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"So we can read off the adjoint to be","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D phi_textf x^ast (bary) = 2 (x odot bary_1 + bary_2 x)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Reverse-Mode-AD:-*how*-does-it-do-it?","page":"Algorithmic Differentiation","title":"Reverse-Mode AD: how does it do it?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now that we know what it is that AD computes, we need a rough understanding of how it computes it.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In short: reverse-mode AD breaks down a \"complicated\" function f into the composition of a collection of \"simple\" functions f_1 dots f_N, applies the chain rule, and takes the adjoint.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Specifically, we assume that we can express any function f as f = f_N circ dots circ f_1, and that we can compute the adjoint of the derivative for each f_n. From this, we can obtain the adjoint of f by applying the chain rule to the derivatives and taking the adjoint:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x^ast = (D f_N x_N circ dots circ D f_1 x_1)^ast nonumber \n = D f_1 x_1^ast circ dots circ D f_N x_N^ast nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"For example, suppose that f(x) = sin(cos(texttr(X^top X))). One option to compute its adjoint is to figure it out by hand directly (probably using the chain rule somewhere). Instead, we could notice that f = f_4 circ f_3 circ f_2 circ f_1 where f_4 = sin, f_3 = cos, f_2 = texttr and f_1(X) = X^top X. We could derive the adjoint for each of these functions (a fairly straightforward task), and then compute","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x^ast (bary) = (D f_1 x_1^ast circ D f_2 x_2^ast circ D f_3 x_3^ast circ D f_4 x_4^ast)(1)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"in order to obtain the gradient of f. Reverse-mode AD essentially just does this. Modern systems have hand-written adjoints for (hopefully!) all of the \"simple\" functions you may wish to build a function such as f from (often there are hundreds of these), and composes them to compute the adjoint of f. A sketch of a more generic algorithm is as follows.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-Pass:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"x_1 = x, n = 1\nconstruct D f_n x_n^ast\nlet x_n+1 = f_n (x_n)\nlet n = n + 1\nif n N + 1 then go to step 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Reverse-Pass:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let barx_N+1 = bary\nlet n = n - 1\nlet barx_n = D f_n x_n^ast (barx_n+1)\nif n = 1 return barx_1 else go to step 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"How does this relate to vector-Jacobian products?","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"In Euclidean space, each derivative D f_n x_n(dotx_n) = J_nx_n dotx_n. Applying the chain rule to D f x and substituting this in yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx = J_Nx_N dots J_1x_1 ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Taking the transpose and multiplying from the left by bary yields","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Jx^top bary = Jx_1^top_1 dots Jx_N^top_N bary ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Comparing this with the expression in terms of adjoints and operators, we see that composition of adjoints of derivatives has been replaced with multiplying by transposed Jacobian matrices. This \"vector-Jacobian product\" expression is commonly used to explain AD, and is likely familiar to many readers.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Directional-Derivatives-and-Gradients","page":"Algorithmic Differentiation","title":"Directional Derivatives and Gradients","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Now we turn to using reverse-mode AD to compute the gradient of a function. In short, given a function g mathcalX to RR with derivative D g x at x, its gradient is equal to D g x^ast (1). We explain why in this section.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The derivative discussed here can be used to compute directional derivatives. Consider a function f mathcalX to RR with Frechet derivative D f x mathcalX to RR at x in mathcalX. Then D fx(dotx) returns the directional derivative in direction dotx.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Gradients are closely related to the adjoint of the derivative. Recall that the gradient of f at x is defined to be the vector nabla f (x) in mathcalX such that langle nabla f (x) dotx rangle gives the directional derivative of f at x in direction dotx. Having noted that D fx(dotx) is exactly this directional derivative, we can equivalently say that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D fx(dotx) = langle nabla f (x) dotx rangle ","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The role of the adjoint is revealed when we consider f = mathcall circ g, where g mathcalX to mathcalY, mathcall(y) = langle bary y rangle, and bary in mathcalY is some fixed vector. Noting that D mathcall y(doty) = langle bary doty rangle, we apply the chain rule to obtain","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"beginalign\nD f x (dotx) = (D mathcall g(x)) circ (D g x)(dotx) nonumber \n = langle bary D g x (dotx) rangle nonumber \n = langle D g x^ast (bary) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"from which we conclude that D g x^ast (bary) is the gradient of the composition l circ g at x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"The above shows that if mathcalY = RR and g is the function we wish to compute the gradient of, we can simply set bary = 1 and compute D g x^ast (bary) to obtain the gradient of g at x.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Summary","page":"Algorithmic Differentiation","title":"Summary","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/#Asides","page":"Algorithmic Differentiation","title":"Asides","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/#*How*-does-Forwards-Mode-AD-work?","page":"Algorithmic Differentiation","title":"How does Forwards-Mode AD work?","text":"","category":"section"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"Forwards-mode AD achieves this by breaking down f into the composition f = f_N circ dots circ f_1, where each f_n is a simple function whose derivative (function) D f_n x_n we know for any given x_n. By the chain rule, we have that","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"D f x (dotx) = D f_N x_N circ dots circ D f_1 x_1 (dotx)","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"which suggests the following algorithm:","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"let x_1 = x, dotx_1 = dotx, and n = 1\nlet dotx_n+1 = D f_n x_n (dotx_n)\nlet x_n+1 = f(x_n)\nlet n = n + 1\nif n = N+1 then return dotx_N+1, otherwise go to 2.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"When each function f_n maps between Euclidean spaces, the applications of derivatives D f_n x_n (dotx_n) are given by J_n dotx_n where J_n is the Jacobian of f_n at x_n.","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).\n\n\n\nT. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).\n\n\n\n","category":"page"},{"location":"understanding_mooncake/algorithmic_differentiation/","page":"Algorithmic Differentiation","title":"Algorithmic Differentiation","text":"[note_for_geometers]: in AD we only really need to discuss differentiatiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-Mode-Design","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Disclaimer: this document refers to an as-yet-unimplemented forwards-mode AD. This will disclaimer will be removed once it has been implemented.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of this document is to explain how forwards-mode AD in Mooncake.jl is implemented. It should do so to a sufficient level of depth to enable the interested reader to read to the forwards-mode AD code in Mooncake.jl and understand what is going on.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This document","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"specifies the semantics of a \"rule\" for forwards-mode AD,\nspecifies how to implement rules by-hand for primitives, and\nspecifies how to derive rules from IRCode algorithmically in general.\ndiscusses batched forwards-mode\ndiscusses some notable technical differences between our forwards-mode AD implementation details and reverse-mode AD implementation details, and\nconcludes with a brief comparison with ForwardDiff.jl.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-Rule-Interface","page":"Forwards-Mode Design","title":"Forwards-Rule Interface","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Loosely, a rule for a function simultaneously","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"performs same computation as the original function, and\ncomputes the Frechet derivative.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is best made concrete through a worked example. Consider a function call","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"z = f(x, y)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"where f itself may contain data / state which is modified by executing f. rule_for_f is some callable which claims to be a forwards-rule for f. For rule_for_f to be a valid forwards-rule for f, it must be applicable to Duals as follows:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"z_dz = rule_for_f(Dual(f, df), Dual(x, dx), Dual(y, dy))::Dual","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"where:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"rule_for_f is a callable. It might be written by-hand, or derived algorithmically.\ndf, dx, and dy are tangents for f, x, and y respectively. Before executing rule_for_f, they are inputs to the derivative of (f, x, y). After executing they are outputs of this derivative.\nz_dz is a Dual containing the primal and the component of the derivative of (f, x, y) to (df, dx, dy) associated to z.\nrunning rule_for_f leaves f, x, and y in the same state that running f does.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We refer readers to Algorithmic Differentiation to explain what we mean when we talk about the \"derivative\" above. We also discussed some worked examples shortly.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Note that rule_for_f is an as-yet-unspecified callable which we introduced purely to specify the interface that a forwards-rule must satisfy. In Hand-Written Rules and Derived Rules below, we introduce two concrete ways to produce rules for f.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Tangent-Types","page":"Forwards-Mode Design","title":"Tangent Types","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We will use the type system documented in Representing Gradients. This means that every primal type has a unique tangent type. Moreover, if a Dual is defined as follows:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"struct Dual{P, T}\n primal::P\n tangent::T\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"it must always hold that T = tangent_type(P).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Testing","page":"Forwards-Mode Design","title":"Testing","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Suppose that we have (somehow) produced a supposed forwards-rule. To check that it is correctly implemented, we must","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"all primal state after running the rule is approximately the same as all primal state after running the primal, and\nthe inner product between all tangents (both output and input) and a random tangent vector after running the rule is approximately the same as the estimate of the same quantity produced by finite differencing or reverse-mode AD.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We already have the functionality to do this in a very general way (see Mooncake.TestUtils.test_rule).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Hand-Written-Rules","page":"Forwards-Mode Design","title":"Hand-Written Rules","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Hand-written rules are implemented by writing methods of two functions: is_primitive and frule!!.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#is_primitive","page":"Forwards-Mode Design","title":"is_primitive","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"is_primitive(::Type{<:Union{MinimalForwardsCtx, DefaultForwardsCtx}}, signature::Type{<:Tuple}) should return true if AD must attempt to differentiate a call by passing the arguments to frule!!, and false otherwise. The Mooncake.@is_primitive macro can be used to implement this straightforwardly.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#frule!!","page":"Forwards-Mode Design","title":"frule!!","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Methods of frule!! do the actual differentiation, and must satisfy the Forwards-Rule Interface discussed above.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In what follows, we will refer to frule!!s for signatures. For example, the frule!! for signature Tuple{typeof(sin), Float64} is the rule which would differentiate calls like sin(5.0).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Simple-Scalar-Function","page":"Forwards-Mode Design","title":"Simple Scalar Function","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that for y = sin(x) we have that doty = cos(x) dotx. So the frule!! for signature Tuple{typeof(sin), Float64} is:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function frule!!(::Dual{typeof(sin)}, x::Dual{Float64})\n return Dual(sin(x.primal), cos(x.primal) * x.tangent)\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Pre-allocated-Matrix-Matrix-Multiply","page":"Forwards-Mode Design","title":"Pre-allocated Matrix-Matrix Multiply","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that for Z = X Y we have that dotZ = X dotY + dotX Y. So the frule!! for signature Tuple{typeof(mul!), Matrix{Float64}, Matrix{Float64}, Matrix{Float64}} is:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function frule!!(\n ::Dual{typeof(LinearAlgebra.mul!)}, Z::Dual{P}, X::Dual{P}, Y::Dual{P}\n) where {P<:Matrix{Float64}}\n\n # Primal computation.\n mul!(Z.primal, X.primal, Y.primal)\n\n # Overwrite tangent of `z` to contain propagated tangent.\n mul!(Z.tangent, X.primal, Y.tangent)\n\n # Add the result of x.tangent * y.primal to `z.tangent`.\n mul!(Z.tangent, X.tangent, Y.primal, 0.0, 1.0) \n return Z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"(In practice we would probably implement a rule for a lower-level function like LinearAlgebra.BLAS.gemm!, rather than mul!).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Derived-Rules","page":"Forwards-Mode Design","title":"Derived Rules","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is the \"automatic\" / \"algorithmic\" bit of AD! This is the second way of producing concrete callable objects which satisfy the Forwards-Rule Interface discussed above. The object which we will ultimately construct is an instance Mooncake.DerivedFRule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Worked-Example:-Julia-Function","page":"Forwards-Mode Design","title":"Worked Example: Julia Function","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Before explaining how derived rules are produced algorithmically, we explain by way of example what a derived rule should look like if we work things through by hand.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"A derived rule for a function such as","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function f(x)\n y = g(x)\n z = h(x, y)\n return z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"should be something of the form","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"function rule_for_f(::Dual{typeof(f)}, x::Dual)\n y = rule_for_g(zero_dual(g), x)\n z = rule_for_h(zero_dual(h), x, y)\n return z\nend","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Observe that the transformation is simply","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"replace all variables with Dual variables,\nreplace all constants (e.g. g and h) with constant Duals,\nreplace all calls with calls to rules.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In general, all control flow should be identical between primal and rule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Worked-Example:-IRCode","page":"Forwards-Mode Design","title":"Worked Example: IRCode","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The above example is expressed in terms of Julia code, but we will be operating on Julia Compiler.IRCode, so it is helpful to consider how the above example translates into this form. If we call f on a Float64, and suppose that g and h both return Float64s, the primal Compiler.IRCode will look something like the following:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"julia> Base.code_ircode_by_type(Tuple{typeof(f), Float64})\n1-element Vector{Any}:\n2 1 ─ %1 = invoke Main.g(_2::Float64)::Float64\n3 │ %2 = invoke Main.h(_2::Float64, %1::Float64)::Float64\n4 └── return %2\n => Float64","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Recall that _2 is the second argument, in this case the Float64, and %1 and %2 are SSAValues. Roughly speaking, the forwards-mode IR for the (ficiticious) function rule_for_f should look something like:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"julia> Base.code_ircode_by_type(Tuple{typeof(rule_for_f), Dual{typeof(f), NoTangent}, Dual{Float64, Float64}})\n1-element Vector{Any}:\n2 1 ─ %1 = invoke rule_for_g($(Dual(Main.g, NoTangent())), _3::Dual{Float64, Float64})::Dual{Float64, Float64}\n3 │ %2 = invoke rule_for_h($(Dual(Main.h, NoTangent())), _3::Dual{Float64, Float64}, %1::Dual{Float64, Float64})::Dual{Float64, Float64}\n4 └── return %2\n => Dual{Float64, Float64}","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Observe that:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"All Arguments have been incremented by 1. i.e. _2 has been replaced with _3. This corresponds to the fact that the arguments to the rule have all been shuffled along by one, and the rule itself is now the first argument.\nEverything has been turned into a Dual.\nConstants such as Dual(Main.g, NoTangent()) appear directly in the code (here as QuoteNodes).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"(In practice it might be that we actually construct the Dualed constants on the lines immediately preceding a call and rely on the compiler to optimise them back into the call directly).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Here, as before, we have not specified exactly what rule_for_f, rule_for_g, and rule_for_h are. This is intentional – they are just callables satisfying the Forwards-Rule Interface. In the following we show how to derive rule_for_f, and show how rule_for_g and rule_for_h might be methods of Mooncake.frule!!, or themselves derived rules.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Rule-Derivation-Outline","page":"Forwards-Mode Design","title":"Rule Derivation Outline","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Equipped with some intuition about what a derived rule ought to look like, we examine how we go about producing it algorithmically.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Rule derivation is implemented via the function Mooncake.build_frule. This function accepts as arguments a context and a signature / Base.MethodInstance / MistyClosure and, roughly speaking, does the following:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Look up the optimised Compiler.IRCode.\nApply a series of standardising transformations to the IRCode.\nTransform each statement according to a set of rules to produce a new IRCode.\nApply standard Julia optimisations to this new IRCode.\nPut this code inside a MistyClosure in order to produce a executable object.\nWrap this MistyClosure in a DerivedFRule to handle various bits of book-keeping around varargs.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In order:","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Looking-up-the-Compiler.IRCode.","page":"Forwards-Mode Design","title":"Looking up the Compiler.IRCode.","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"This is done using Mooncake.lookup_ir. This function has methods with will return the IRCode associated to:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"signatures (e.g. Tuple{typeof(f), Float64})\nBase.MethodInstances (relevant for :invoke expressions – see Statement Transformation below)\nMistyClosures.MistyClosure objects, which is essential when computing higher order derivatives and Hessians by applying Mooncake.jl to itself.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Standardisation","page":"Forwards-Mode Design","title":"Standardisation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We apply the following transformations to the Julia IR. They can all be found in ir_normalisation.jl:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Mooncake.foreigncall_to_call: convert Expr(:foreigncall, ...) expressions into Expr(:call, Mooncake._foreigncall_, ...) expressions.\nMooncake.new_to_call: convert Expr(:new, ...) expressions to Expr(:call, Mooncake._new_, ...) expressions.\nMooncake.splatnew_to_call: convert Expr(:splatnew, ...) expressions to Expr(:call, Mooncake._splat_new_...) expressions.\nMooncake.intrinsic_to_function: convert Expr(:call, ::IntrinsicFunction, ...) to calls to the corresponding function in Mooncake.IntrinsicsWrappers.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of converting Expr(:foreigncall...), Expr(:new, ...) and Expr(:splatnew, ...) into Expr(:call, ...)s is to enable us to differentiate such expressions by adding methods to frule!!(::Dual{typeof(Mooncake._foreigncall_)}), frule!!(::Dual{typeof(Mooncake._new_)}), and frule!!(::Dual{typeof(Mooncake._splat_new_)}), in exactly the same way that we would for any other regular Julia function.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The purpose of translating Expr(:call, ::IntrinsicFunction, ...) is to do with type stability – see the docstring for the Mooncake.IntrinsicsWrappers module for more info.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Statement-Transformation","page":"Forwards-Mode Design","title":"Statement Transformation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Each statment which can appear in the Julia IR is transformed by a method of Mooncake.make_fwds_ad_stmts. Consequently, this transformation phase simply corresponds to iterating through all of the expressions in the IRCode, applying Mooncake.make_fwd_ad_stmts to each to produce new IRCode. To understand how to modify IRCode and insert new instructions, see Oxinabox's Gist.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"We provide here a high-level summary of the transformations for the most important Julia IR statements, and refer readers to the methods of Mooncake.make_fwds_ad_stmts for the definitive explanation of what transformation is applied, and the rationale for applying it. In particular there are quite a number more statements which can appear in Julia IR than those listed here and, for those we do list here, there are typically a few edge cases left out.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Expr(:invoke, method_instance, f, x...) and Expr(:call, f, x...)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":":call expressions correspond to dynamic dispatch, while :invoke expressions correspond to static dispatch. That is, if you see an :invoke expression, you know for sure that the compiler knows enough information about the types of f and x to prove exactly which specialisation of which method to call. This specialisation is method_instance. This typically happens when the compiler is able to prove the types of f and x. Conversely, a :call expression typically occurs when the compiler has not been able to deduce the exact types of f and x, and therefore not been able to figure out what to call. It therefore has to wait until runtime to figure out what to call, resulting in dynamic dispatch.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"As we saw earlier, the idea is to translate these kinds of expressions into something vaguely along the lines of","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Expr(:call, rule_for_f, f, x...)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"There are three cases to consider, in order of preference:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Primitives:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"If is_primitive returns true when applied to the signature constructed from the static types of f and x, then we simply replace the expression with Expr(:call, frule!!, f, x...), regardless whether we have an :invoke or :call expression. (Due to the Standardisation steps, it regularly happens that we see :call expressions in which we actually do know enough type information to do this, e.g. for Mooncake._new_ :call expressions).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Static Dispatch:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In the case of :invoke nodes we know for sure at rule compilation time what rule_for_f must be. We derive a rule for the call by passing method_instance to Mooncake.build_frule. (In practice, we might do this lazily, but while retaining enough information to maintain type stability. See the Mooncake.LazyDerivedRule for how this is handled in reverse-mode).","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Dynamic Dispatch:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"If we have a :call expression and are not able to prove that is_primitive will return true, we must defer dispatch until runtime. We do this by replacing the :call expression with a call to a DynamicFRule, which simply constructs (or retrieves from a cache) the rule at runtime. Reverse-mode utilises a similar strategy via Mooncake.DynamicDerivedRule.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The above was written in terms of f and x. In practice, of course, we encounter various kinds of constants (e.g. Base.sin), Arguments (e.g. _3), and Core.SSAValues (e.g. %5). The translation rules for these are:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"constants are turned into constant duals in which the tangent is zero,\nArguments are incremented by 1.\nSSAValues are left as-is.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.GotoNodes","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"These remain entirely unchanged.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.GotoIfNot","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"These require minor modification. Suppose that a Core.GotoIfNot of the form Core.GotoIfNot(%5, 4) is encountered in the primal. Since %5 will be a Dual in the derived rule, we must pull out the primal field, and pass that to the conditional instead. Therefore, these statments get lowered to two lines in the derived rule. For example, Core.GotoIfNot(%5, 4) would be translated to:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"%n = getfield(%5, :primal)\nCore.GotoIfNot(%n, 4)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNode","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNode looks something like the following in the general case:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"φ (#1 => %3, #2 => _2, #3 => 4, #4 => #undef)","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"They map from a collection of basic block numbers (#1, #2, etc) to values. The values can be Core.Arguments, Core.SSAValues, constants (literals and QuoteNodes), or undefined.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.PhiNodes in the primal are mapped to Core.PhiNodes in the rule. They contain exactly the same basic block numbers, and apply the following translation rules to the values:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Core.SSAValues are unchanged.\nCore.Arguments are incremented by 1 (as always).\nconstants are translated into constant duals.\nundefined values remain undefined.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"So the above example would be translated into something like","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"φ (#1 => %3, #2 => _3, #3 => $(CoDual(4, NoTangent())), #4 => #undef)","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Optimisation","page":"Forwards-Mode Design","title":"Optimisation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The IR generated in the previous step will typically be uninferred, and suboptimal in a variety of ways. We fix this up by running inference and optimisation on the generated IRCode. This is implemented by Mooncake.optimise_ir!.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Put-IRCode-in-MistyClosure","page":"Forwards-Mode Design","title":"Put IRCode in MistyClosure","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"Now that we have an optimised IRCode object, we need to turn it into something that can actually be run. This can, in general, be straightforwardly achieved by putting it inside a Core.OpaqueClosure. This works, but Core.OpaqueClosures have the disadvantage that once you've constructed a Core.OpaqueClosure using an IRCode, it is not possible to get it back out. Consequently, we use MistyClosures, in order to keep the IRCode readily accessible if we want to access it later.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Put-the-MistyClosure-in-a-DerivedFRule","page":"Forwards-Mode Design","title":"Put the MistyClosure in a DerivedFRule","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"See the implementation of DerivedRule (used in reverse-mode) for more context on this. This is the \"rule\" that users get.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Batch-Mode","page":"Forwards-Mode Design","title":"Batch Mode","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"So far, we have assumed that we would only apply forwards-mode to a single tangent vector at a time. However, in practice, it is typically best to pass a collection of tangents through at a time.","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"In order to do this, all of the transformation code listed above can remain the same, we will just need to devise a system of \"batched tangents\". Then, instead of propagating a \"primal-tangent\" pairs via Duals, we propagate primal-tangent_batch pairs (perhaps also via Duals).","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Forwards-vs-Reverse-Implementation","page":"Forwards-Mode Design","title":"Forwards vs Reverse Implementation","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"The implementation of forwards-mode AD is quite dramatically simpler than that of reverse-mode AD. Some notable technical differences include:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"forwards-mode AD only makes use of the tangent system, whereas reverse-mode also makes use of the fdata / rdata system.\nforwards-mode AD comprises only line-by-line transformations of the IRCode. In particular, it does not require the insertion of additional basic blocks, nor the modification of the successors / predecessors of any given basic block. Consequently, there is no need to make use of the BBCode infrastructure built up for reverse-mode AD – everything can be straightforwardly done at the Compiler.IRCode level.","category":"page"},{"location":"developer_documentation/forwards_mode_design/#Comparison-with-ForwardDiff.jl","page":"Forwards-Mode Design","title":"Comparison with ForwardDiff.jl","text":"","category":"section"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":"With reference to the limitations of ForwardDiff.jl, there are a few noteworthy differences between ForwardDiff.jl and this implementation:","category":"page"},{"location":"developer_documentation/forwards_mode_design/","page":"Forwards-Mode Design","title":"Forwards-Mode Design","text":":foreigncalls pose much less of a problem for Mooncake's forward-mode than for ForwardDiff.jl, because we can write a rule for any method of any function. In essence, you can only (reliably) write rules for ForwardDiff.jl via dispatch on ForwardDiff.Dual.\nthe target function can be of any arity in Mooncake.jl, but must be unary in ForwardDiff.jl.\nthere are no limitations on the argument type constraints that Mooncake.jl can handle, while ForwardDiff.jl requires that argument type constraints be <:Real or arrays of <:Real.\nNo special storage types are required with Mooncake.jl, while ForwardDiff.jl requires that any container you write to is able to contain ForwardDiff.Duals.","category":"page"},{"location":"developer_documentation/internal_docstrings/#Internal-Docstrings","page":"Internal Docstrings","title":"Internal Docstrings","text":"","category":"section"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Docstrings listed here are not part of the public Mooncake.jl interface. Consequently, they can change between non-breaking changes to Mooncake.jl without warning.","category":"page"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"The purpose of this is to make it easy for developers to find docstrings straightforwardly via the docs, as opposed to having to ctrl+f through Mooncake.jl's source code, or looking at the docstrings via the Julia REPL.","category":"page"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Modules = [Mooncake]\nPublic = false","category":"page"},{"location":"developer_documentation/internal_docstrings/#Mooncake.GLOBAL_INTERPRETER-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.GLOBAL_INTERPRETER","text":"const GLOBAL_INTERPRETER\n\nGlobally cached interpreter. Should only be accessed via get_interpreter.\n\n\n\n\n\n","category":"constant"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Terminator-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Terminator","text":"Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode}\n\nA Union of the possible types of a terminator node.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Core.Compiler.IRCode-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Core.Compiler.IRCode","text":"IRCode(bb_code::BBCode)\n\nProduce an IRCode instance which is equivalent to bb_code. The resulting IRCode shares no memory with bb_code, so can be safely mutated without modifying bb_code.\n\nAll IDPhiNodes, IDGotoIfNots, and IDGotoNodes are converted into PhiNodes, GotoIfNots, and GotoNodes respectively.\n\nIn the resulting bb_code, any Switch nodes are lowered into a semantically-equivalent collection of GotoIfNot nodes.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ADInfo-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ADInfo","text":"ADInfo\n\nThis data structure is used to hold \"global\" information associated to a particular call to build_rrule. It is used as a means of communication between make_ad_stmts! and the codegen which produces the forwards- and reverse-passes.\n\ninterp: a MooncakeInterpreter.\nblock_stack_id: the ID associated to the block stack – the stack which keeps track of which blocks we visited during the forwards-pass, and which is used on the reverse-pass to determine which blocks to visit.\nblock_stack: the block stack. Can always be found at block_stack_id in the forwards- and reverse-passes.\nentry_id: ID associated to the block inserted at the start of execution in the the forwards-pass, and the end of execution in the pullback.\nshared_data_pairs: the SharedDataPairs used to define the captured variables passed to both the forwards- and reverse-passes.\narg_types: a map from Argument to its static type.\nssa_insts: a map from ID associated to lines to the primal NewInstruction. This contains the line of code, its static / inferred type, and some other detailss. See Core.Compiler.NewInstruction for a full list of fields.\narg_rdata_ref_ids: the dict mapping from arguments to the ID which creates and initialises the Ref which contains the reverse data associated to that argument. Recall that the heap allocations associated to this Ref are always optimised away in the final programme.\nssa_rdata_ref_ids: the same as arg_rdata_ref_ids, but for each ID associated to an ssa rather than each argument.\ndebug_mode: if true, run in \"debug mode\" – wraps all rule calls in DebugRRule. This is applied recursively, so that debug mode is also switched on in derived rules.\nis_used_dict: for each ID associated to a line of code, is false if line is not used anywhere in any other line of code.\nlazy_zero_rdata_ref_id: for any arguments whose type doesn't permit the construction of a zero-valued rdata directly from the type alone (e.g. a struct with an abstractly- typed field), we need to have a zero-valued rdata available on the reverse-pass so that this zero-valued rdata can be returned if the argument (or a part of it) is never used during the forwards-pass and consequently doesn't obtain a value on the reverse-pass. To achieve this, we construct a LazyZeroRData for each of the arguments on the forwards-pass, and make use of it on the reverse-pass. This field is the ID that will be associated to this information.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ADStmtInfo-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ADStmtInfo","text":"ADStmtInfo\n\nData structure which contains the result of make_ad_stmts!. Fields are\n\nline: the ID associated to the primal line from which this is derived\ncomms_id: an ID from one of the lines in fwds, whose value will be made available on the reverse-pass in the same ID. Nothing is asserted about how this value is made available on the reverse-pass of AD, so this package is free to do this in whichever way is most efficient, in particular to group these communication ID on a per-block basis.\nfwds: the instructions which run the forwards-pass of AD\nrvs: the instructions which run the reverse-pass of AD / the pullback\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(ir::IRCode)\n\nConvert an ir into a BBCode. Creates a completely independent data structure, so mutating the BBCode returned will not mutate ir.\n\nAll PhiNodes, GotoIfNots, and GotoNodes will be replaced with the IDPhiNodes, IDGotoIfNots, and IDGotoNodes respectively.\n\nSee IRCode for conversion back to IRCode.\n\nNote that IRCode(BBCode(ir)) should be equal to the identity function.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-Tuple{Union{Core.Compiler.IRCode, Mooncake.BBCode}, Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{Block})\n\nMake a new BBCode whose blocks is given by new_blocks, and fresh copies are made of all other fields from ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBCode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBCode","text":"BBCode(\n blocks::Vector{BBlock}\n argtypes::Vector{Any}\n sptypes::Vector{CC.VarState}\n linetable::Vector{Core.LineInfoNode}\n meta::Vector{Expr}\n)\n\nA BBCode is a data structure which is similar to IRCode, but adds additional structure.\n\nIn particular, a BBCode comprises a sequence of basic blocks (BBlocks), each of which comprise a sequence of statements. Moreover, each BBlock has its own unique ID, as does each statment.\n\nThe consequence of this is that new basic blocks can be inserted into a BBCode. This is distinct from IRCode, in which to create a new basic block, one must insert additional statments which you know will create a new basic block – this is generally quite an unreliable process, while inserting a new BBlock into BBCode is entirely predictable. Furthermore, inserting a new BBlock does not change the ID associated to the other blocks, meaning that you can safely assume that references from existing basic block terminators / phi nodes to other blocks will not be modified by inserting a new basic block.\n\nAdditionally, since each statment in each basic block has its own unique ID, new statments can be inserted without changing references between other blocks. IRCode also has some support for this via its new_nodes field, but eventually all statements will be renamed upon compact!ing the IRCode, meaning that the name of any given statement will eventually change.\n\nFinally, note that the basic blocks in a BBCode support the custom Switch statement. This statement is not valid in IRCode, and is therefore lowered into a collection of GotoIfNots and GotoNodes when a BBCode is converted back into an IRCode.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBlock-Tuple{Mooncake.ID, Vector{Tuple{Mooncake.ID, Core.Compiler.NewInstruction}}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBlock","text":"BBlock(id::ID, inst_pairs::Vector{IDInstPair})\n\nConvenience constructor – splits inst_pairs into a Vector{ID} and InstVector in order to build a BBlock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BBlock-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BBlock","text":"BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector)\n\nA basic block data structure (not called BasicBlock to avoid accidental confusion with CC.BasicBlock). Forms a single basic block.\n\nEach BBlock has an ID (a unique name). This makes it possible to refer to blocks in a way that does not change when additional BBlocks are inserted into a BBCode. This differs from the positional block numbering found in IRCode, in which the number associated to a basic block changes when new blocks are inserted.\n\nThe nth line of code in a BBlock is associated to ID stmt_ids[n], and the nth instruction from stmts.\n\nNote that PhiNodes, GotoIfNots, and GotoNodes should not appear in a BBlock – instead an IDPhiNode, IDGotoIfNot, or IDGotoNode should be used.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.BlockStack-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.BlockStack","text":"The block stack is the stack used to keep track of which basic blocks are visited on the forwards pass, and therefore which blocks need to be visited on the reverse pass. There is one block stack per derived rule. By using Int32, we assume that there aren't more than typemax(Int32) unique basic blocks in a given function, which ought to be reasonable.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.CannotProduceZeroRDataFromType-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.CannotProduceZeroRDataFromType","text":"CannotProduceZeroRDataFromType()\n\nReturned by zero_rdata_from_type if is not possible to construct the zero rdata element for a given type. See zero_rdata_from_type for more info.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Config-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Config","text":"Config(; debug_mode=false, silence_debug_messages=false)\n\nConfiguration struct for use with ADTypes.AutoMooncake.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugPullback-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugPullback","text":"(pb::DebugPullback)(dy)\n\nApply type checking to enforce pre- and post-conditions on pb.pb. See the docstring for DebugPullback for details.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugPullback-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugPullback","text":"DebugPullback(pb, y, x)\n\nConstruct a callable which is equivalent to pb, but which enforces type-based pre- and post-conditions to pb. Let dx = pb.pb(dy), for some rdata dy, then this function\n\nchecks that dy has the correct rdata type for y, and\nchecks that each element of dx has the correct rdata type for x.\n\nReverse pass counterpart to DebugRRule\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugRRule-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugRRule","text":"(rule::DebugRRule)(x::CoDual...)\n\nApply type checking to enforce pre- and post-conditions on rule.rule. See the docstring for DebugRRule for details.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DebugRRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).\n\nLet rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DefaultCtx-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DefaultCtx","text":"struct DefaultCtx end\n\nContext for all usually used AD primitives. Anything which is a primitive in a MinimalCtx is a primitive in the DefaultCtx automatically. If you are adding a rule for the sake of performance, it should be a primitive in the DefaultCtx, but not the MinimalCtx.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.DynamicDerivedRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.DynamicDerivedRule","text":"DynamicDerivedRule(interp::MooncakeInterpreter, debug_mode::Bool)\n\nFor internal use only.\n\nA callable data structure which, when invoked, calls an rrule specific to the dynamic types of its arguments. Stores rules in an internal cache to avoid re-deriving.\n\nThis is used to implement dynamic dispatch.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.FData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.FData","text":"FData(data::NamedTuple)\n\nThe component of a struct which is propagated alongside the primal on the forwards-pass of AD. For example, the tangents for Float64s do not need to be propagated on the forwards- pass of reverse-mode AD, so any Float64 fields of Tangent do not need to appear in the associated FData.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ID-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ID","text":"ID()\n\nAn ID (read: unique name) is just a wrapper around an Int32. Uniqueness is ensured via a global counter, which is incremented each time that an ID is created.\n\nThis counter can be reset using seed_id! if you need to ensure deterministic IDs are produced, in the same way that seed for random number generators can be set.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDGotoIfNot-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDGotoIfNot","text":"IDGotoIfNot(cond::Any, dest::ID)\n\nLike a GotoIfNot, but dest is an ID rather than an Int64.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDGotoNode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDGotoNode","text":"IDGotoNode(label::ID)\n\nLike a GotoNode, but label is an ID rather than an Int64.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDInstPair-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDInstPair","text":"const IDInstPair = Tuple{ID, NewInstruction}\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IDPhiNode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.IDPhiNode","text":"IDPhiNode(edges::Vector{ID}, values::Vector{Any})\n\nLike a PhiNode, but edges are IDs rather than Int32s.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InstVector-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InstVector","text":"const InstVector = Vector{NewInstruction}\n\nNote: the CC.NewInstruction type is used to represent instructions because it has the correct fields. While it is only used to represent new instrucdtions in Core.Compiler, it is used to represent all instructions in BBCode.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InvalidFDataException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InvalidFDataException","text":"InvalidFDataException(msg::String)\n\nException indicating that there is a problem with the fdata associated to a primal.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.InvalidRDataException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.InvalidRDataException","text":"InvalidRDataException(msg::String)\n\nException indicating that there is a problem with the rdata associated to a primal.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.LazyDerivedRule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.LazyDerivedRule","text":"LazyDerivedRule(interp, mi::Core.MethodInstance, debug_mode::Bool)\n\nFor internal use only.\n\nA type-stable wrapper around a DerivedRule, which only instantiates the DerivedRule when it is first called. This is useful, as it means that if a rule does not get run, it does not have to be derived.\n\nIf debug_mode is true, then the rule constructed will be a DebugRRule. This is useful when debugging, but should usually be switched off for production code as it (in general) incurs some runtime overhead.\n\nNote: the signature of the primal for which this is a rule is stored in the type. The only reason to keep this around is for debugging – it is very helpful to have this type visible in the stack trace when something goes wrong, as it allows you to trivially determine which bit of your code is the culprit.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.LazyZeroRData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.LazyZeroRData","text":"LazyZeroRData{P, Tdata}()\n\nThis type is a lazy placeholder for zero_like_rdata_from_type. This is used to defer construction of zero data to the reverse pass. Calling instantiate on an instance of this will construct a zero data.\n\nUsers should construct using LazyZeroRData(p), where p is an value of type P. This constructor, and instantiate, are specialised to minimise the amount of data which must be stored. For example, Float64s do not need any data, so LazyZeroRData(0.0) produces an instance of a singleton type, meaning that various important optimisations can be performed in AD.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.MinimalCtx-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.MinimalCtx","text":"struct MinimalCtx end\n\nFunctions should only be primitives in this context if not making them so would cause AD to fail. In particular, do not add primitives to this context if you are writing them for performance only – instead, make these primitives in the DefaultCtx.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.NoPullback-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.NoPullback","text":"NoPullback(args::CoDual...)\n\nConstruct a NoPullback from the arguments passed to an rrule!!. For each argument, extracts the primal value, and constructs a LazyZeroRData. These are stored in a NoPullback which, in the reverse-pass of AD, instantiates these LazyZeroRDatas and returns them in order to perform the reverse-pass of AD.\n\nThe advantage of this approach is that if it is possible to construct the zero rdata element for each of the arguments lazily, the NoPullback generated will be a singleton type. This means that AD can avoid generating a stack to store this pullback, which can result in significant performance improvements.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.RRuleZeroWrapper-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.RRuleZeroWrapper","text":"RRuleZeroWrapper(rule)\n\nThis struct is used to ensure that ZeroRDatas, which are used as placeholder zero elements whenever an actual instance of a zero rdata for a particular primal type cannot be constructed without also having an instance of said type, never reach rules. On the pullback, we increment the cotangent dy by an amount equal to zero. This ensures that if it is a ZeroRData, we instead get an actual zero of the correct type. If it is not a zero rdata, the computation should be elided via inlining + constant prop.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.SharedDataPairs-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.SharedDataPairs","text":"SharedDataPairs()\n\nA data structure used to manage the captured data in the OpaqueClosures which implement the bulk of the forwards- and reverse-passes of AD. An entry (id, data) at element n of the pairs field of this data structure means that data will be available at register id during the forwards- and reverse-passes of AD.\n\nThis is achieved by storing all of the data in the pairs field in the captured tuple which is passed to an OpaqueClosure, and extracting this data into registers associated to the corresponding IDs.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Stack-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Stack","text":"Stack{T}()\n\nA stack specialised for reverse-mode AD.\n\nSemantically equivalent to a usual stack, but never de-allocates memory once allocated.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.Switch-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.Switch","text":"Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID)\n\nA switch-statement node. These can be inserted in the BBCode representation of Julia IR. Switch has the following semantics:\n\ngoto dests[1] if not conds[1]\ngoto dests[2] if not conds[2]\n...\ngoto dests[N] if not conds[N]\ngoto fallthrough_dest\n\nwhere the value associated to each element of conds is a Bool, and dests indicate which block to jump to. If none of the conditions are met, then we go to whichever block is specified by fallthrough_dest.\n\nSwitch statements are lowered into the above sequence of GotoIfNots and GotoNodes when converting BBCode back into IRCode, because Switch statements are not valid nodes in regular Julia IR.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.UnhandledLanguageFeatureException-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.UnhandledLanguageFeatureException","text":"UnhandledLanguageFeatureException(message::String)\n\nAn exception used to indicate that some aspect of the Julia language which AD cannot handle has been encountered.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ZeroRData-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ZeroRData","text":"ZeroRData()\n\nSingleton type indicating zero-valued rdata. This should only ever appear as an intermediate quantity in the reverse-pass of AD when the type of the primal is not fully inferable, or a field of a type is abstractly typed.\n\nIf you see this anywhere in actual code, or if it appears in a hand-written rule, this is an error – please open an issue in such a situation.\n\n\n\n\n\n","category":"type"},{"location":"developer_documentation/internal_docstrings/#Base.insert!-Tuple{Mooncake.BBlock, Int64, Mooncake.ID, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Base.insert!","text":"Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing\n\nInserts stmt and id into bb immediately before the nth instruction.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__deref_and_zero-Union{Tuple{P}, Tuple{Type{P}, Ref}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__deref_and_zero","text":"__deref_and_zero(::Type{P}, x::Ref) where {P}\n\nHelper, used in concludervsblock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__flatten_varargs-Union{Tuple{nvargs}, Tuple{isva}, Tuple{Val{isva}, Any, Val{nvargs}}} where {isva, nvargs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__flatten_varargs","text":"__flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs}\n\nIf isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__get_value-Tuple{Mooncake.ID, Mooncake.IDPhiNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__get_value","text":"__get_value(edge::ID, x::IDPhiNode)\n\nHelper functionality for concludervsblock.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__insts_to_instruction_stream-Tuple{Vector{Any}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__insts_to_instruction_stream","text":"__insts_to_instruction_stream(insts::Vector{Any})\n\nProduces an instruction stream whose\n\nstmt (v1.11 and up) / inst (v1.10) field is insts,\ntype field is all Any,\ninfo field is all Core.Compiler.NoCallInfo,\nline field is all Int32(1), and\nflag field is all Core.Compiler.IR_FLAG_REFINED.\n\nAs such, if you wish to ensure that your IRCode prints nicely, you should ensure that its linetable field has at least one element.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__line_numbers_to_block_numbers!-Tuple{Vector{Any}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__line_numbers_to_block_numbers!","text":"__line_numbers_to_block_numbers!(insts::Vector{Any}, cfg::CC.CFG)\n\nConverts any edges in GotoNodes, GotoIfNots, PhiNodes, and :enter expressions which refer to line numbers into references to block numbers. The cfg provides the information required to perform this conversion.\n\nFor context, CodeInfo objects have references to line numbers, while IRCode uses block numbers.\n\nThis code is copied over directly from the body of Core.Compiler.inflate_ir!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__make_ref-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__make_ref","text":"__make_ref(p::Type{P}) where {P}\n\nHelper for reverse_data_ref_stmts. Constructs a Ref whose element type is the zero_like_rdata_type for P, and whose element is the zero-like rdata for P.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__pop_blk_stack!-Tuple{Mooncake.Stack{Int32}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__pop_blk_stack!","text":"__pop_blk_stack!(block_stack::BlockStack)\n\nEquivalent to pop!(block_stack). Going via this function, rather than just calling pop! directly, makes it easy to figure out how much time is spent popping the block stack when profiling performance, and to know that this function was hit when debugging.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__push_blk_stack!-Tuple{Mooncake.Stack{Int32}, Int32}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__push_blk_stack!","text":"__push_blk_stack!(block_stack::BlockStack, id::Int32)\n\nEquivalent to push!(block_stack, id). Going via this function, rather than just calling push! directly, is helpful for debugging and performance analysis – it makes it very straightforward to figure out much time is spent pushing to the block stack when profiling.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__run_rvs_pass!-Union{Tuple{sig}, Tuple{Type, Type{sig}, Any, Ref, Vararg{Any}}} where sig-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__run_rvs_pass!","text":"__run_rvs_pass!(\n P::Type, ::Type{sig}, pb!!, ret_rev_data_ref::Ref, arg_rev_data_refs...\n) where {sig}\n\nUsed in make_ad_stmts! method for Expr(:call, ...) and Expr(:invoke, ...).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__switch_case-Tuple{Int32, Int32}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__switch_case","text":"__switch_case(id::Int32, predecessor_id::Int32)\n\nHelper function emitted by make_switch_stmts.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__unflatten_codual_varargs-Union{Tuple{nargs}, Tuple{isva}, Tuple{Val{isva}, Any, Val{nargs}}} where {isva, nargs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__unflatten_codual_varargs","text":"__unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs}\n\nIf isva and nargs=2, then inputs (CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0)) are transformed into (CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0))).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__value_and_gradient!!-Union{Tuple{N}, Tuple{R}, Tuple{R, Vararg{Mooncake.CoDual, N}}} where {R, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__value_and_gradient!!","text":"__value_and_gradient!!(rule, f::CoDual, x::CoDual...)\n\nNote: this is not part of the public Mooncake.jl interface, and may change without warning.\n\nEquivalent to __value_and_pullback!!(rule, 1.0, f, x...) – assumes f returns a Float64.\n\n# Set up the problem.\nf(x, y) = sum(x .* y)\nx = [2.0, 2.0]\ny = [1.0, 1.0]\nrule = build_rrule(f, x, y)\n\n# Allocate tangents. These will be written to in-place. You are free to re-use these if you\n# compute gradients multiple times.\ntf = zero_tangent(f)\ntx = zero_tangent(x)\nty = zero_tangent(y)\n\n# Do AD.\nMooncake.__value_and_gradient!!(\n rule, Mooncake.CoDual(f, tf), Mooncake.CoDual(x, tx), Mooncake.CoDual(y, ty)\n)\n# output\n\n(4.0, (NoTangent(), [1.0, 1.0], [2.0, 2.0]))\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.__value_and_pullback!!-Union{Tuple{T}, Tuple{N}, Tuple{R}, Tuple{R, T, Vararg{Mooncake.CoDual, N}}} where {R, N, T}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.__value_and_pullback!!","text":"__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)\n\nNote: this is not part of the public Mooncake.jl interface, and may change without warning.\n\nIn-place version of value_and_pullback!! in which the arguments have been wrapped in CoDuals. Note that any mutable data in f and x will be incremented in-place. As such, if calling this function multiple times with different values of x, should be careful to ensure that you zero-out the tangent fields of x each time.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._block_nums_to_ids-Tuple{Vector{Core.Compiler.NewInstruction}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._block_nums_to_ids","text":"_block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector}\n\nAssign to each basic block in cfg an ID. Replace all integers referencing block numbers in insts with the corresponding ID. Return the IDs and the updated instructions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._build_graph_of_cfg-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._build_graph_of_cfg","text":"_build_graph_of_cfg(blks::Vector{BBlock})::Tuple{SimpleDiGraph, Dict{ID, Int}}\n\nBuilds a SimpleDiGraph, g, representing of the CFG associated to blks, where blks comprises the collection of basic blocks associated to a BBCode. This is a type from Graphs.jl, so constructing g makes it straightforward to analyse the control flow structure of ir using algorithms from Graphs.jl.\n\nReturns a 2-tuple, whose first element is g, and whose second element is a map from the ID associated to each basic block in ir, to the Int corresponding to its node index in g.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._compute_all_predecessors-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._compute_all_predecessors","text":"_compute_all_predecessors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}\n\nInternal method implementing compute_all_predecessors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._compute_all_successors-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._compute_all_successors","text":"_compute_all_successors(blks::Vector{BBlock})::Dict{ID, Vector{ID}}\n\nInternal method implementing compute_all_successors. This method is easier to construct test cases for because it only requires the collection of BBlocks, not all of the other stuff that goes into a BBCode.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._control_flow_graph-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._control_flow_graph","text":"_control_flow_graph(blks::Vector{BBlock})::Core.Compiler.CFG\n\nInternal function, used to implement control_flow_graph. Easier to write test cases for because there is no need to construct an ensure BBCode object, just the BBlocks.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._distance_to_entry-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._distance_to_entry","text":"_distance_to_entry(blks::Vector{BBlock})::Vector{Int}\n\nFor each basic block in blks, compute the distance from it to the entry point (the first block. The distance is typemax(Int) if no path from the entry point to a given node.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._find_id_uses!-Tuple{Dict{Mooncake.ID, Bool}, Expr}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._find_id_uses!","text":"_find_id_uses!(d::Dict{ID, Bool}, x)\n\nHelper function used in characterise_used_ids. For all uses of IDs in x, set the corresponding value of d to true.\n\nFor example, if x = ReturnNode(ID(5)), then this function sets d[ID(5)] = true.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._foreigncall_-Union{Tuple{N}, Tuple{calling_convention}, Tuple{nreq}, Tuple{RT}, Tuple{name}, Tuple{Val{name}, Val{RT}, Tuple, Val{nreq}, Val{calling_convention}, Vararg{Any, N}}} where {name, RT, nreq, calling_convention, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._foreigncall_","text":"function _foreigncall_(\n ::Val{name}, ::Val{RT}, AT::Tuple, ::Val{nreq}, ::Val{calling_convention}, x...\n) where {name, RT, nreq, calling_convention}\n\n:foreigncall nodes get translated into calls to this function. For example,\n\nExpr(:foreigncall, :foo, Tout, (A, B), nreq, :ccall, args...)\n\nbecomes\n\n_foreigncall_(Val(:foo), Val(Tout), (Val(A), Val(B)), Val(nreq), Val(:ccall), args...)\n\nPlease consult the Julia documentation for more information on how foreigncall nodes work, and consult this package's tests for examples.\n\nCredit: Umlaut.jl has the original implementation of this function. This is largely copied over from there.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ids_to_line_numbers-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ids_to_line_numbers","text":"_ids_to_line_numbers(bb_code::BBCode)::InstVector\n\nFor each statement in bb_code, returns a NewInstruction in which every ID is replaced by either an SSAValue, or an Int64 / Int32 which refers to an SSAValue.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._is_reachable-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._is_reachable","text":"_is_reachable(blks::Vector{BBlock})::Vector{Bool}\n\nComputes a Vector whose length is length(blks). The nth element is true iff it is possible for control flow to reach the nth block.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._lines_to_blocks-Tuple{Vector{Core.Compiler.NewInstruction}, Core.Compiler.CFG}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._lines_to_blocks","text":"_instructions_to_blocks(insts::InstVector, cfg::CC.CFG)::InstVector\n\nPulls out the instructions from insts, and calls __line_numbers_to_block_numbers!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._lower_switch_statements-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._lower_switch_statements","text":"_lower_switch_statements(bb_code::BBCode)\n\nConverts all Switchs into a semantically-equivalent collection of GotoIfNots. See the Switch docstring for an explanation of what is going on here.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map-Union{Tuple{N}, Tuple{F}, Tuple{F, Vararg{Any, N}}} where {F, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map","text":"_map(f, x...)\n\nSame as map but requires all elements of x to have equal length. The usual function map doesn't enforce this for Arrays.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map_if_assigned!-Union{Tuple{P}, Tuple{F}, Tuple{F, DenseArray, DenseArray{P}, DenseArray}} where {F, P}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map_if_assigned!","text":"_map_if_assigned!(f::F, y::DenseArray, x1::DenseArray{P}, x2::DenseArray)\n\nSimilar to the other method of _map_if_assigned! – for all n, if x1[n] is assigned, writes f(x1[n], x2[n]) to y[n], otherwise leaves y[n] unchanged.\n\nRequires that y, x1, and x2 have the same size.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._map_if_assigned!-Union{Tuple{P}, Tuple{F}, Tuple{F, DenseArray, DenseArray{P}}} where {F, P}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._map_if_assigned!","text":"_map_if_assigned!(f, y::DenseArray, x::DenseArray{P}) where {P}\n\nFor all n, if x[n] is assigned, then writes the value returned by f(x[n]) to y[n], otherwise leaves y[n] unchanged.\n\nEquivalent to map!(f, y, x) if P is a bits type as element will always be assigned.\n\nRequires that y and x have the same size.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._new_-Union{Tuple{N}, Tuple{T}, Tuple{Type{T}, Vararg{Any, N}}} where {T, N}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._new_","text":"_new_(::Type{T}, x::Vararg{Any, N}) where {T, N}\n\nOne-liner which calls the :new instruction with type T with arguments x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._remove_double_edges-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._remove_double_edges","text":"_remove_double_edges(ir::BBCode)::BBCode\n\nIf the dest field of an IDGotoIfNot node in block n of ir points towards the n+1th block then we have two edges from block n to block n+1. This transformation replaces all such IDGotoIfNot nodes with unconditional IDGotoNodes pointing towards the n+1th block in ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._sort_blocks!-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._sort_blocks!","text":"_sort_blocks!(ir::BBCode)::BBCode\n\nEnsure that blocks appear in order of distance-from-entry-point, where distance the distance from block b to the entry point is defined to be the minimum number of basic blocks that must be passed through in order to reach b.\n\nFor reasons unknown (to me, Will), the compiler / optimiser needs this for inference to succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem there.\n\nWARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic blocks in ir is valid. Notably, this does not hold if you have any IDGotoIfNot nodes in ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._splat_new_-Union{Tuple{P}, Tuple{Type{P}, Tuple}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._splat_new_","text":"_splat_new_(::Type{P}, x::Tuple) where {P}\n\nFunction which replaces instances of :splatnew.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ssa_to_ids-Tuple{Dict{Core.SSAValue, Mooncake.ID}, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ssa_to_ids","text":"_ssa_to_ids(d::SSAToIdDict, inst::NewInstruction)\n\nProduce a new instance of inst in which all instances of SSAValues are replaced with the IDs prescribed by d, all basic block numbers are replaced with the IDs prescribed by d, and GotoIfNot, GotoNode, and PhiNode instances are replaced with the corresponding ID versions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._ssas_to_ids-Tuple{Vector{Core.Compiler.NewInstruction}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._ssas_to_ids","text":"_ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector}\n\nAssigns an ID to each line in stmts, and replaces each instance of an SSAValue in each line with the corresponding ID. For example, a call statement of the form Expr(:call, :f, %4) is be replaced with Expr(:call, :f, id_assigned_to_%4).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._to_ssas-Tuple{Dict, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._to_ssas","text":"_to_ssas(d::Dict, inst::NewInstruction)\n\nLike _ssas_to_ids, but in reverse. Converts IDs to SSAValues / (integers corresponding to ssas).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake._typeof-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake._typeof","text":"_typeof(x)\n\nCentral definition of typeof, which is specific to the use-required in this package.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ad_stmt_info-Tuple{Mooncake.ID, Union{Nothing, Mooncake.ID}, Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ad_stmt_info","text":"ad_stmt_info(line::ID, comms_id::Union{ID, Nothing}, fwds, rvs)\n\nConvenient constructor for ADStmtInfo. If either fwds or rvs is not a vector, __vec promotes it to a single-element Vector.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data!-Tuple{Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data!","text":"add_data!(info::ADInfo, data)::ID\n\nEquivalent to add_data!(info.shared_data_pairs, data).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data!-Tuple{Mooncake.SharedDataPairs, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data!","text":"add_data!(p::SharedDataPairs, data)::ID\n\nPuts data into p, and returns the id associated to it. This id should be assumed to be available during the forwards- and reverse-passes of AD, and it should further be assumed that the value associated to this id is always data.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.add_data_if_not_singleton!-Tuple{Union{Mooncake.ADInfo, Mooncake.SharedDataPairs}, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.add_data_if_not_singleton!","text":"add_data_if_not_singleton!(p::Union{ADInfo, SharedDataPairs}, x)\n\nReturns x if it is a singleton, or the ID of the ssa which will contain it on the forwards- and reverse-passes. The reason for this is that if something is a singleton, it can be inserted directly into the IR.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.can_produce_zero_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.can_produce_zero_rdata_from_type","text":"can_produce_zero_rdata_from_type(::Type{P}) where {P}\n\nReturns whether or not the zero element of the rdata type for primal type P can be obtained from P alone.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.characterise_unique_predecessor_blocks-Tuple{Vector{Mooncake.BBlock}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.characterise_unique_predecessor_blocks","text":"characterise_unique_predecessor_blocks(blks::Vector{BBlock}) ->\n Tuple{Dict{ID, Bool}, Dict{ID, Bool}}\n\nWe call a block b a unique predecessor in the control flow graph associated to blks if it is the only predecessor to all of its successors. Put differently we call b a unique predecessor if, whenever control flow arrives in any of the successors of b, we know for certain that the previous block must have been b.\n\nReturns two Dicts. A value in the first Dict is true if the block associated to its key is a unique precessor, and is false if not. A value in the second Dict is true if it has a single predecessor, and that predecessor is a unique predecessor.\n\nContext:\n\nThis information is important for optimising AD because knowing that b is a unique predecessor means that\n\non the forwards-pass, there is no need to push the ID of b to the block stack when passing through it, and\non the reverse-pass, there is no need to pop the block stack when passing through one of the successors to b.\n\nUtilising this reduces the overhead associated to doing AD. It is quite important when working with cheap loops – loops where the operations performed at each iteration are inexpensive – for which minimising memory pressure is critical to performance. It is also important for single-block functions, because it can be used to entirely avoid using a block stack at all.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.characterise_used_ids-Tuple{Vector{Tuple{Mooncake.ID, Core.Compiler.NewInstruction}}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.characterise_used_ids","text":"characterise_used_ids(stmts::Vector{IDInstPair})::Dict{ID, Bool}\n\nFor each line in stmts, determine whether it is referenced anywhere else in the code. Returns a dictionary containing the results. An element is false if the corresponding ID is unused, and true if is used.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.collect_stmts-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.collect_stmts","text":"collect_stmts(ir::BBCode)::Vector{IDInstPair}\n\nProduce a Vector containing all of the statements in ir. These are returned in order, so it is safe to assume that element n refers to the nth element of the IRCode associated to ir. \n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.collect_stmts-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.collect_stmts","text":"collect_stmts(bb::BBlock)::Vector{IDInstPair}\n\nReturns a Vector containing the IDs and instructions associated to each line in bb. These should be assumed to be ordered.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.comms_channel-Tuple{Mooncake.ADStmtInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.comms_channel","text":"comms_channel(info::ADStmtInfo)\n\nReturn the element of fwds whose ID is the communcation ID. Returns Nothing if comms_id is nothing.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.compute_all_predecessors-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.compute_all_predecessors","text":"compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}}\n\nCompute a map from the ID of eachBBlockinir` to its possible predecessors.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.compute_all_successors-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.compute_all_successors","text":"compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}}\n\nCompute a map from the ID of eachBBlockinir` to its possible successors.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.conclude_rvs_block-Tuple{Mooncake.BBlock, Vector{Mooncake.ID}, Bool, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.conclude_rvs_block","text":"conclude_rvs_block(\n blk::BBlock, pred_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo\n)\n\nGenerates code which is inserted at the end of each counterpart block in the reverse-pass. Handles phi nodes, and choosing the correct next block to switch to.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.const_ad_stmt-Tuple{Any, Mooncake.ID, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.const_ad_stmt","text":"const_ad_stmt(stmt, line::ID, info::ADInfo)\n\nImplementation of make_ad_stmts! used for constants.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.const_codual-Tuple{Any, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.const_codual","text":"const_codual(stmt, info::ADInfo)\n\nBuild a CoDual from stmt, with zero / uninitialised fdata. If the resulting CoDual is a bits type, then it is returned. If it is not, then the CoDual is put into shared data, and the ID associated to it in the forwards- and reverse-passes returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.const_codual_stmt-Tuple{Any, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.const_codual_stmt","text":"const_codual_stmt(stmt, info::ADInfo)\n\nReturns a :call expression which will return a CoDual whose primal is stmt, and whose tangent is whatever uninit_tangent returns.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.control_flow_graph-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.control_flow_graph","text":"control_flow_graph(bb_code::BBCode)::Core.Compiler.CFG\n\nComputes the Core.Compiler.CFG object associated to this bb_code.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.create_comms_insts!-Tuple{Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.create_comms_insts!","text":"create_comms_insts!(ad_stmts_blocks::ADStmts, info::ADInfo)\n\nThis function produces code which can be inserted into the forwards-pass and reverse-pass at specific locations to implement the promise associated to the comms_id field of the ADStmtInfo type – namely that if you assign a value to comms_id on the forwards-pass, the same value will be available at comms_id on the reverse-pass.\n\nFor each basic block represented in ADStmts:\n\ncreate a stack containing a Tuple which can hold all of the values associated to the comms_ids for each statement. Put this stack in shared data.\ncreate instructions which can be inserted at the end of the block generated to perform the forwards-pass (in forwards_pass_ir) which will put all of the data associated to the comms_ids into shared data, and\ncreate instruction which can be inserted at the start of the block generated to perform the reverse-pass (in pullback_ir), which will extract all of the data put into shared data by the instructions generated by the previous point, and assigned them to the comms_ids.\n\nReturns two a Tuple{Vector{IDInstPair}, Vector{IDInstPair}. The nth element of each Vector corresponds to the instructions to be inserted into the forwards- and reverse passes resp. for the nth block in ad_stmts_blocks.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fcodual_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fcodual_type","text":"fcodual_type(P::Type)\n\nThe type of the CoDual which contains instances of P and its fdata.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fdata_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fdata_field_type","text":"fdata_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type of to the nth field of the fdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.foreigncall_to_call-Tuple{Any, Dict{Symbol, Core.Compiler.VarState}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.foreigncall_to_call","text":"foreigncall_to_call(inst, sp_map::Dict{Symbol, CC.VarState})\n\nIf inst is a :foreigncall expression translate it into an equivalent :call expression. If anything else, just return inst. See Mooncake._foreigncall_ for details.\n\nsp_map maps the names of the static parameters to their values. This function is intended to be called in the context of an IRCode, in which case the values of sp_map are given by the sptypes field of said IRCode. The keys should generally be obtained from the Method from which the IRCode is derived. See Mooncake.normalise! for more details.\n\nThe purpose of this transformation is to make it possible to differentiate :foreigncall expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.forwards_pass_ir-Tuple{Mooncake.BBCode, Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Any, Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.forwards_pass_ir","text":"forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)\n\nProduce the IR associated to the OpaqueClosure which runs most of the forwards-pass.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.fwd_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.fwd_ir","text":"fwd_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.gc_preserve-Tuple-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.gc_preserve","text":"gc_preserve(xs...)\n\nA no-op function. Its rrule!! ensures that the memory associated to xs is not freed until the pullback that it returns is run.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.generate_ir-Tuple{Mooncake.MooncakeInterpreter, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.generate_ir","text":"generate_ir(\n interp::MooncakeInterpreter, sig_or_mi; debug_mode=false, do_inline=true\n)\n\nUsed by build_rrule, and the various debugging tools: primalir, fwdsir, adjoint_ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_const_primal_value-Tuple{GlobalRef}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_const_primal_value","text":"get_const_primal_value(x::GlobalRef)\n\nGet the value associated to x. For GlobalRefs, verify that x is indeed a constant.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_primal_type-Tuple{Mooncake.ADInfo, Core.Argument}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_primal_type","text":"get_primal_type(info::ADInfo, x)\n\nReturns the static / inferred type associated to x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_rev_data_id-Tuple{Mooncake.ADInfo, Core.Argument}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_rev_data_id","text":"get_rev_data_id(info::ADInfo, x)\n\nReturns the ID associated to the line in the reverse pass which will contain the reverse data for x. If x is not an Argument or ID, then nothing is returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.get_tangent_field-Union{Tuple{Tfs}, Tuple{Union{MutableTangent{Tfs}, Tangent{Tfs}}, Int64}} where Tfs-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.get_tangent_field","text":"get_tangent_field(t::Union{MutableTangent, Tangent}, i::Int)\n\nGets the ith field of data in t.\n\nHas the same semantics that getfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of getfield for MutableTangent.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.id_to_line_map-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.id_to_line_map","text":"id_to_line_map(ir::BBCode)\n\nProduces a Dict mapping from each ID associated with a line in ir to its line number. This is isomorphic to mapping to its SSAValue in IRCode. Terminators do not have IDs associated to them, so not every line in the original IRCode is mapped to.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.inc_args-Tuple{Expr}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.inc_args","text":"inc_args(stmt)\n\nIncrement by 1 the n field of any Arguments present in stmt.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_and_get_rdata!-Union{Tuple{T}, Tuple{NoFData, T, T}} where T<:Union{Float16, Float32, Float64}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_and_get_rdata!","text":"increment_and_get_rdata!(fdata, zero_rdata, cr_tangent)\n\nIncrement fdata by the fdata component of the ChainRules.jl-style tangent, cr_tangent, and return the rdata component of cr_tangent by adding it to zero_rdata.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_field!!-Union{Tuple{i}, Tuple{Tuple, Any, Val{i}}} where i-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_field!!","text":"increment_field!!(x::T, y::V, f) where {T, V}\n\nincrement!! the field f of x by y, and return the updated x.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.increment_rdata!!-Union{Tuple{T}, Tuple{T, Any}} where T-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.increment_rdata!!","text":"increment_rdata!!(t::T, r)::T where {T}\n\nIncrement the rdata component of tangent t by r, and return the updated tangent. Useful for implementation getfield-like rules for mutable structs, pointers, dicts, etc.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.infer_ir!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.infer_ir!","text":"infer_ir!(ir::IRCode) -> IRCode\n\nRuns type inference on ir, which mutates ir, and returns it.\n\nNote: the compiler will not infer the types of anything where the corrsponding element of ir.stmts.flag is not set to Core.Compiler.IR_FLAG_REFINED. Nor will it attempt to refine the type of the value returned by a :invoke expressions. Consequently, if you find that the types in your IR are not being refined, you may wish to check that neither of these things are happening.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.insert_before_terminator!-Tuple{Mooncake.BBlock, Mooncake.ID, Core.Compiler.NewInstruction}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.insert_before_terminator!","text":"insert_before_terminator!(bb::BBlock, id::ID, inst::NewInstruction)::Nothing\n\nIf the final instruction in bb is a Terminator, insert inst immediately before it. Otherwise, insert inst at the end of the block.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.interpolate_boundschecks!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.interpolate_boundschecks!","text":"interpolate_boundschecks!(ir::IRCode)\n\nFor every x = Expr(:boundscheck, value) in ir, interpolate value into all uses of x. This is only required in order to ensure that literal versions of memoryrefget, memoryrefset!, getfield, and setfield! work effectively. If they are removed through improvements to the way that we handle constant propagation inside Mooncake, then this functionality can be removed.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.intrinsic_to_function-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.intrinsic_to_function","text":"intrinsic_to_function(inst)\n\nIf inst is a :call expression to a Core.IntrinsicFunction, replace it with a call to the corresponding function from Mooncake.IntrinsicsWrappers, else return inst.\n\ncglobal is a special case – it requires that its first argument be static in exactly the same way as :foreigncall. See IntrinsicsWrappers.__cglobal for more info.\n\nThe purpose of this transformation is to make it possible to use dispatch to write rules for intrinsic calls using dispatch in a type-stable way. See IntrinsicsWrappers for more context.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.ircode-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.ircode","text":"ircode(\n inst::Vector{Any},\n argtypes::Vector{Any},\n sptypes::Vector{CC.VarState}=CC.VarState[],\n) -> IRCode\n\nConstructs an instance of an IRCode. This is useful for constructing test cases with known properties.\n\nNo optimisations or type inference are performed on the resulting IRCode, so that the IRCode contains exactly what is intended by the caller. Please make use of infer_types! if you require the types to be inferred.\n\nEdges in PhiNodes, GotoIfNots, and GotoNodes found in inst must refer to lines (as in CodeInfo). In the IRCode returned by this function, these line references are translated into block references.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_always_fully_initialised-Tuple{DataType}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_always_fully_initialised","text":"is_always_fully_initialised(P::DataType)::Bool\n\nTrue if all fields in P are always initialised. Put differently, there are no inner constructors which permit partial initialisation.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_always_initialised-Tuple{DataType, Int64}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_always_initialised","text":"is_always_initialised(P::DataType, n::Int)::Bool\n\nTrue if the nth field of P is always initialised. If the nth fieldtype of P isbitstype, then this is distinct from asking whether the nth field is always defined. An isbits field is always defined, but is not always explicitly initialised.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_primitive-Tuple{Type{Mooncake.MinimalCtx}, Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_primitive","text":"is_primitive(::Type{Ctx}, sig) where {Ctx}\n\nReturns a Bool specifying whether the methods specified by sig are considered primitives in the context of contexts of type Ctx.\n\nis_primitive(DefaultCtx, Tuple{typeof(sin), Float64})\n\nwill return if calling sin(5.0) should be treated as primitive when the context is a DefaultCtx.\n\nObserve that this information means that whether or not something is a primitive in a particular context depends only on static information, not any run-time information that might live in a particular instance of Ctx.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_reachable_return_node-Tuple{Core.ReturnNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_reachable_return_node","text":"is_reachable_return_node(x::ReturnNode)\n\nDetermine whether x is a ReturnNode, and if it is, if it is also reachable. This is purely a function of whether or not its val field is defined or not.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_unreachable_return_node-Tuple{Core.ReturnNode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_unreachable_return_node","text":"is_unreachable_return_node(x::ReturnNode)\n\nDetermine whehter x is a ReturnNode, and if it is, if it is also unreachable. This is purely a function of whether or not its val field is defined or not.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_used-Tuple{Mooncake.ADInfo, Mooncake.ID}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_used","text":"is_used(info::ADInfo, id::ID)::Bool\n\nReturns true if id is used by any of the lines in the ir, false otherwise.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(sig)::Tuple{Bool, Vector{Symbol}}\n\nFinds the method associated to sig, and calls is_vararg_and_sparam_names on it.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Core.MethodInstance}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(mi::Core.MethodInstance)\n\nCalls is_vararg_and_sparam_names on mi.def::Method.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.is_vararg_and_sparam_names-Tuple{Method}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.is_vararg_and_sparam_names","text":"is_vararg_and_sparam_names(m::Method)\n\nReturns a 2-tuple. The first element is true if m is a vararg method, and false if not. The second element contains the names of the static parameters associated to m.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lgetfield-Union{Tuple{f}, Tuple{Any, Val{f}}} where f-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lgetfield","text":"lgetfield(x, f::Val)\n\nAn implementation of getfield in which the the field f is specified statically via a Val. This enables the implementation to be type-stable even when it is not possible to constant-propagate f. Moreover, it enable the pullback to also be type-stable.\n\nIt will always be the case that\n\ngetfield(x, :f) === lgetfield(x, Val(:f))\ngetfield(x, 2) === lgetfield(x, Val(2))\n\nThis approach is identical to the one taken by Zygote.jl to circumvent the same problem. Zygote.jl calls the function literal_getfield, while we call it lgetfield.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lgetfield-Union{Tuple{order}, Tuple{f}, Tuple{Any, Val{f}, Val{order}}} where {f, order}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lgetfield","text":"lgetfield(x, ::Val{f}, ::Val{order}) where {f, order}\n\nLike getfield, but with the field and access order encoded as types.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lift_gc_preservation-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lift_gc_preservation","text":"lift_gc_preserve(inst)\n\nExpressions of the form\n\ny = GC.@preserve x1 x2 foo(args...)\n\nget lowered to\n\ntoken = Expr(:gc_preserve_begin, x1, x2)\ny = expr\nExpr(:gc_preserve_end, token)\n\nThese expressions guarantee that any memory associated x1 and x2 not be freed until the :gc_preserve_end expression is reached.\n\nIn the context of reverse-mode AD, we must ensure that the memory associated to x1, x2 and their fdata is available during the reverse pass code associated to expr. We do this by preventing the memory from being freed until the :gc_preserve_begin is reached on the reverse pass.\n\nTo achieve this, we replace the primal code with\n\n# store `x` in `pb_gc_preserve` to prevent it from being freed.\n_, pb_gc_preserve = rrule!!(zero_fcodual(gc_preserve), x1, x2)\n\n# Differentiate the `:call` expression in the usual way.\ny, foo_pb = rrule!!(zero_fcodual(foo), args...)\n\n# Do not permit the GC to free `x` here.\nnothing\n\nThe pullback should be something along the lines of\n\n# no pullback associated to `nothing`.\nnothing\n\n# Run the pullback associated to `foo` in the usual manner. `x` must be available.\n_, dargs... = foo_pb(dy)\n\n# No-op pullback associated to `gc_preserve`.\npb_gc_preserve(NoRData())\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lift_getfield_and_others-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lift_getfield_and_others","text":"lift_getfield_and_others(inst)\n\nConverts expressions of the form getfield(x, :a) into lgetfield(x, Val(:a)). This has identical semantics, but is performant in the absence of proper constant propagation.\n\nDoes the same for...\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lookup_ir-Tuple{Core.Compiler.AbstractInterpreter, Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lookup_ir","text":"lookup_ir(\n interp::AbstractInterpreter,\n sig_or_mi::Union{Type{<:Tuple}, Core.MethodInstance},\n)::Tuple{IRCode, T}\n\nGet the unique IR associated to sig_or_mi under interp. Throws ArgumentErrors if there is no code found, or if more than one IRCode instance returned.\n\nReturns a tuple containing the IRCode and its return type.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.lsetfield!-Union{Tuple{name}, Tuple{Any, Val{name}, Any}} where name-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.lsetfield!","text":"lsetfield!(value, name::Val, x, [order::Val])\n\nThis function is to setfield! what lgetfield is to getfield. It will always hold that\n\nsetfield!(copy(x), :f, v) == lsetfield!(copy(x), Val(:f), v)\nsetfield!(copy(x), 2, v) == lsetfield(copy(x), Val(2), v)\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.make_ad_stmts!-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.make_ad_stmts!","text":"make_ad_stmts!(inst::NewInstruction, line::ID, info::ADInfo)::ADStmtInfo\n\nEvery line in the primal code is associated to one or more lines in the forwards-pass of AD, and one or more lines in the pullback. This function has method specific to every node type in the Julia SSAIR.\n\nTranslates the instruction inst, associated to line in the primal, into a specification of what should happen for this instruction in the forwards- and reverse-passes of AD, and what data should be shared between the forwards- and reverse-passes. Returns this in the form of an ADStmtInfo.\n\ninfo is a data structure containing various bits of global information that certain types of nodes need access to.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.make_switch_stmts-Tuple{Vector{Mooncake.ID}, Vector{Mooncake.ID}, Bool, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.make_switch_stmts","text":"make_switch_stmts(\n pred_ids::Vector{ID}, target_ids::Vector{ID}, pred_is_unique_pred::Bool, info::ADInfo\n)\n\npreds_ids comprises the IDs associated to all possible predecessor blocks to the primal block under consideration. Suppose its value is [ID(1), ID(2), ID(3)], then make_switch_stmts emits code along the lines of\n\nprev_block = pop!(block_stack)\nnot_pred_was_1 = !(prev_block == ID(1))\nnot_pred_was_2 = !(prev_block == ID(2))\nswitch(\n not_pred_was_1 => ID(1),\n not_pred_was_2 => ID(2),\n ID(3)\n)\n\nIn words: make_switch_stmts emits code which jumps to whichever block preceded the current block during the forwards-pass.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_inst-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_inst","text":"new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction\n\nCreate a NewInstruction with fields:\n\nstmt = stmt\ntype = type\ninfo = CC.NoCallInfo()\nline = Int32(1)\nflag = flag\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_inst_vec-Tuple{Core.Compiler.InstructionStream}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_inst_vec","text":"new_inst_vec(x::CC.InstructionStream)\n\nConvert an Compiler.InstructionStream into a list of Compiler.NewInstructions.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.new_to_call-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.new_to_call","text":"new_to_call(x)\n\nIf instruction x is a :new expression, replace it with a :call to Mooncake._new_. Otherwise, return x.\n\nThe purpose of this transformation is to make it possible to differentiate :new expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.normalise!-Tuple{Core.Compiler.IRCode, Vector{Symbol}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.normalise!","text":"normalise!(ir::IRCode, spnames::Vector{Symbol})\n\nApply a sequence of standardising transformations to ir which leaves its semantics unchanged, but makes AD more straightforward. In particular, replace\n\n:foreigncall Exprs with :calls to Mooncake._foreigncall_,\n:new Exprs with :calls to Mooncake._new_,\n:splatnew Exprs with:calls toMooncake.splatnew_`,\nCore.IntrinsicFunctions with counterparts from Mooncake.IntrinsicWrappers,\ngetfield(x, 1) with lgetfield(x, Val(1)), and related transformations,\nmemoryrefget calls to lmemoryrefget calls, and related transformations,\ngc_preserve_begin / gc_preserve_end exprs so that memory release is delayed.\n\nspnames are the names associated to the static parameters of ir. These are needed when handling :foreigncall expressions, in which it is not necessarily the case that all static parameter names have been translated into either types, or :static_parameter expressions.\n\nUnfortunately, the static parameter names are not retained in IRCode, and the Method from which the IRCode is derived must be consulted. Mooncake.is_vararg_and_sparam_names provides a convenient way to do this.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.optimise_ir!-Tuple{Core.Compiler.IRCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.optimise_ir!","text":"optimise_ir!(ir::IRCode, show_ir=false)\n\nRun a fairly standard optimisation pass on ir. If show_ir is true, displays the IR to stdout at various points in the pipeline – this is sometimes useful for debugging.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.phi_nodes-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.phi_nodes","text":"phi_nodes(bb::BBlock)::Tuple{Vector{ID}, Vector{IDPhiNode}}\n\nReturns all of the IDPhiNodes at the start of bb, along with their IDs. If there are no IDPhiNodes at the start of bb, then both vectors will be empty.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.primal_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.primal_ir","text":"primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGet the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).\n\nFor example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:\n\njulia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.pullback_ir-Tuple{Mooncake.BBCode, Any, Vector{Tuple{Mooncake.ID, Vector{Mooncake.ADStmtInfo}}}, Any, Mooncake.ADInfo, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.pullback_ir","text":"pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data)\n\nProduce the IR associated to the OpaqueClosure which runs most of the pullback.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.pullback_type-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.pullback_type","text":"pullback_type(Trule, arg_types)\n\nGet a bound on the pullback type, given a rule and associated primal types.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rdata_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rdata_field_type","text":"rdata_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type of to the nth field of the rdata type associated to P. Will be a PossiblyUninitTangent if said field can be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.remove_unreachable_blocks!-Tuple{Mooncake.BBCode}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.remove_unreachable_blocks!","text":"remove_unreachable_blocks!(ir::BBCode)::BBCode\n\nIf a basic block in ir cannot possibly be reached during execution, then it can be safely removed from ir without changing its functionality. A block is unreachable if either:\n\nit has no predecessors and it is not the first block, or\nall of its predecessors are themselves unreachable.\n\nFor example, consider the following IR:\n\njulia> ir = Mooncake.ircode(\n Any[Core.ReturnNode(nothing), Expr(:call, sin, 5), Core.ReturnNode(Core.SSAValue(2))],\n Any[Any, Any, Any],\n );\n\nThere is no possible way to reach the second basic block (lines 2 and 3). Applying this function will therefore remove it, yielding the following:\n\njulia> Mooncake.IRCode(Mooncake.remove_unreachable_blocks!(Mooncake.BBCode(ir)))\n1 1 ─ return nothing\n\nIn the blocks which have not been removed, there may be references to blocks which have been removed. For example, the edges in a PhiNode may contain a reference to a removed block. These references are removed in-place from these remaining blocks, so this function will (in general) modify ir.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_captures-Union{Tuple{Tmc}, Tuple{Tmc, Any}} where Tmc<:MistyClosures.MistyClosure-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_captures","text":"replace_captures(mc::Tmc, new_captures) where {Tmc<:MistyClosure}\n\nSame as replace_captures for Core.OpaqueClosures, but returns a new MistyClosure.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_captures-Union{Tuple{Toc}, Tuple{Toc, Any}} where Toc<:Core.OpaqueClosure-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_captures","text":"replace_captures(oc::Toc, new_captures) where {Toc<:OpaqueClosure}\n\nGiven an OpaqueClosure oc, create a new OpaqueClosure of the same type, but with new captured variables. This is needed for efficiency reasons – if build_rrule is called repeatedly with the same signature and intepreter, it is important to avoid recompiling the OpaqueClosures that it produces multiple times, because it can be quite expensive to do so.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.replace_uses_with!-Tuple{Any, Union{Core.Argument, Core.SSAValue}, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.replace_uses_with!","text":"replace_uses_with!(stmt, def::Union{Argument, SSAValue}, val)\n\nReplace all uses of def with val in the single statement stmt. Note: this function is highly incomplete, really only working correctly for a specific function in ir_normalisation.jl. You probably do not want to use it.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.reverse_data_ref_stmts-Tuple{Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.reverse_data_ref_stmts","text":"reverse_data_ref_stmts(info::ADInfo)\n\nCreate the statements which initialise the reverse-data Refs.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rrule_wrapper-Union{NTuple{N, Mooncake.CoDual}, Tuple{N}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rrule_wrapper","text":"rrule_wrapper(f::CoDual, args::CoDual...)\n\nUsed to implement rrule!!s via ChainRulesCore.rrule.\n\nGiven a function foo, argument types arg_types, and a method of ChainRulesCore.rrule which applies to these, you can make use of this function as follows:\n\nMooncake.@is_primitive DefaultCtx Tuple{typeof(foo), arg_types...}\nfunction Mooncake.rrule!!(f::CoDual{typeof(foo)}, args::CoDual...)\n return rrule_wrapper(f, args...)\nend\n\nAssumes that methods of to_cr_tangent and to_mooncake_tangent are defined such that you can convert between the different representations of tangents that Mooncake and ChainRulesCore expect.\n\nFurthermore, it is essential that\n\nf(args) does not mutate f or args, and\nthe result of f(args) does not alias any data stored in f or args.\n\nSubject to some constraints, you can use the @from_rrule macro to reduce the amount of boilerplate code that you are required to write even further.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rule_type-Union{Tuple{C}, Tuple{Mooncake.MooncakeInterpreter{C}, Any}} where C-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rule_type","text":"rule_type(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode) where {C}\n\nCompute the concrete type of the rule that will be returned from build_rrule. This is important for performance in dynamic dispatch, and to ensure that recursion works properly.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rvs_ir-Tuple{Type{<:Tuple}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rvs_ir","text":"rvs_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.rvs_phi_block-Tuple{Mooncake.ID, Vector{Mooncake.ID}, Vector{Any}, Mooncake.ADInfo}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.rvs_phi_block","text":"rvs_phi_block(pred_id::ID, rdata_ids::Vector{ID}, values::Vector{Any}, info::ADInfo)\n\nProduces a BBlock which runs the reverse-pass for the edge associated to pred_id in a collection of IDPhiNodes, and then goes to the block associated to pred_id.\n\nFor example, suppose that we encounter the following collection of PhiNodes at the start of some block:\n\n%6 = φ (#2 => _1, #3 => %5)\n%7 = φ (#2 => 5., #3 => _2)\n\nLet the tangent refs associated to %6, %7, and _1be denotedt%6,t%7, andt1resp., and letpredidbe#2`, then this function will produce a basic block of the form\n\nincrement_ref!(t_1, t%6)\nnothing\ngoto #2\n\nThe call to increment_ref! appears because _1 is the value associated to%6 when the primal code comes from #2. Similarly, the goto #2 statement appears because we came from #2 on the forwards-pass. There is no increment_ref! associated to %7 because 5. is a constant. We emit a nothing statement, which the compiler will happily optimise away later on.\n\nThe same ideas apply if pred_id were #3. The block would end with #3, and there would be two increment_ref! calls because both %5 and _2 are not constants.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.seed_id!-Tuple{}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.seed_id!","text":"seed_id!()\n\nSet the global counter used to ensure ID uniqueness to 0. This is useful when you want to ensure determinism between two runs of the same function which makes use of IDs.\n\nThis is akin to setting the random seed associated to a random number generator globally.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.set_tangent_field!-Union{Tuple{Tfields}, Tuple{MutableTangent{Tfields}, Int64, Any}} where Tfields-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.set_tangent_field!","text":"set_tangent_field!(t::MutableTangent{Tfields}, i::Int, x) where {Tfields}\n\nSets the value of the ith field of the data in t to value x.\n\nHas the same semantics that setfield! would have if the data in the fields field of t were actually fields of t. This is the moral equivalent of setfield! for MutableTangent.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.shared_data_stmts-Tuple{Mooncake.SharedDataPairs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.shared_data_stmts","text":"shared_data_stmts(p::SharedDataPairs)::Vector{IDInstPair}\n\nProduce a sequence of id-statment pairs which will extract the data from shared_data_tuple(p) such that the correct value is associated to the correct ID.\n\nFor example, if p.pairs is\n\n[(ID(5), 5.0), (ID(3), \"hello\")]\n\nthen the output of this function is\n\nIDInstPair[\n (ID(5), new_inst(:(getfield(_1, 1)))),\n (ID(3), new_inst(:(getfield(_1, 2)))),\n]\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.shared_data_tuple-Tuple{Mooncake.SharedDataPairs}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.shared_data_tuple","text":"shared_data_tuple(p::SharedDataPairs)::Tuple\n\nCreate the tuple that will constitute the captured variables in the forwards- and reverse- pass OpaqueClosures.\n\nFor example, if p.pairs is\n\n[(ID(5), 5.0), (ID(3), \"hello\")]\n\nthen the output of this function is\n\n(5.0, \"hello\")\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.sparam_names-Tuple{Method}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.sparam_names","text":"sparam_names(m::Core.Method)::Vector{Symbol}\n\nReturns the names of all of the static parameters in m.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.splatnew_to_call-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.splatnew_to_call","text":"splatnew_to_call(x)\n\nIf instruction x is a :splatnew expression, replace it with a :call to Mooncake._splat_new_. Otherwise return x.\n\nThe purpose of this transformation is to make it possible to differentiate :splatnew expressions in the same way as a primitive :call expression, i.e. via an rrule!!.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.stmt-Tuple{Core.Compiler.InstructionStream}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.stmt","text":"stmt(ir::CC.InstructionStream)\n\nGet the field containing the instructions in ir. This changed name in 1.11 from inst to stmt.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tangent_field_type-Union{Tuple{P}, Tuple{Type{P}, Int64}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tangent_field_type","text":"tangent_field_type(::Type{P}, n::Int) where {P}\n\nReturns the type that lives in the nth elements of fields in a Tangent / MutableTangent. Will either be the tangent_type of the nth fieldtype of P, or the tangent_type wrapped in a PossiblyUninitTangent. The latter case only occurs if it is possible for the field to be undefined.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tangent_test_cases-Tuple{}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tangent_test_cases","text":"tangent_test_cases()\n\nConstructs a Vector of Tuples containing test cases for the tangent infrastructure.\n\nIf the returned tuple has 2 elements, the elements should be interpreted as follows: 1 - interface_only 2 - primal value\n\ninterface_only is a Bool which will be used to determine which subset of tests to run.\n\nIf the returned tuple has 5 elements, then the elements are interpreted as follows: 1 - interface_only 2 - primal value 3, 4, 5 - tangents, where <5> == increment!!(<3>, <4>).\n\nTest cases in the first format make use of zero_tangent / randn_tangent etc to generate tangents, but they're unable to check that increment!! is correct in an absolute sense.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.terminator-Tuple{Mooncake.BBlock}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.terminator","text":"terminator(bb::BBlock)\n\nReturns the terminator associated to bb. If the last instruction in bb isa Terminator then that is returned, otherwise nothing is returned.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.to_cr_tangent-Tuple{Union{Float16, Float32, Float64}}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.to_cr_tangent","text":"to_cr_tangent(t)\n\nConvert a Mooncake tangent into a type that ChainRules.jl rrules expect to see.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.tuple_map-Union{Tuple{F}, Tuple{F, Tuple}} where F-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.tuple_map","text":"tuple_map(f::F, x::Tuple) where {F}\n\nThis function is largely equivalent to map(f, x), but always specialises on all of the element types of x, regardless the length of x. This contrasts with map, in which the number of element types specialised upon is a fixed constant in the compiler.\n\nAs a consequence, if x is very long, this function may have very large compile times.\n\ntuple_map(f::F, x::Tuple, y::Tuple) where {F}\n\nBinary extension of tuple_map. Nearly equivalent to map(f, x, y), but guaranteed to specialise on all element types of x and y. Furthermore, errors if x and y aren't the same length, while map will just produce a new tuple whose length is equal to the shorter of x and y.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.unhandled_feature-Tuple{String}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.unhandled_feature","text":"unhandled_feature(msg::String)\n\nThrow an UnhandledLanguageFeatureException with message msg.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_codual-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_codual","text":"uninit_codual(x)\n\nEquivalent to CoDual(x, uninit_tangent(x)).\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_fcodual-Tuple{P} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_fcodual","text":"uninit_fcodual(x)\n\nLike zero_fcodual, but doesn't guarantee that the value of the fdata is initialised. See implementation for details, as this function is subject to change.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.uninit_tangent-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.uninit_tangent","text":"uninit_tangent(x)\n\nRelated to zero_tangent, but a bit different. Check current implementation for details – this docstring is intentionally non-specific in order to avoid becoming outdated.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_fdata_type-Tuple{Type, Type}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_fdata_type","text":"verify_fdata_type(P::Type, F::Type)::Nothing\n\nCheck that F is a valid type for fdata associated to a primal of type P. Returns nothing if valid, throws an InvalidFDataException if a problem is found.\n\nThis applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_fdata_value-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_fdata_value","text":"verify_fdata_value(p, f)::Nothing\n\nCheck that f cannot be proven to be invalid fdata for p.\n\nThis method attempts to provide some confidence that f is valid fdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.\n\nPut differently, we cannot prove that f is valid fdata, only that it is not obviously invalid.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_rdata_type-Tuple{Type, Type}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_rdata_type","text":"verify_rdata_type(P::Type, R::Type)::Nothing\n\nCheck that R is a valid type for rdata associated to a primal of type P. Returns nothing if valid, throws an InvalidRDataException if a problem is found.\n\nThis applies to both concrete and non-concrete P. For example, if P is the type inferred for a primal q::Q, such that Q <: P, then this method is still applicable.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.verify_rdata_value-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.verify_rdata_value","text":"verify_rdata_value(p, r)::Nothing\n\nCheck that r cannot be proven to be invalid rdata for p.\n\nThis method attempts to provide some confidence that r is valid rdata for p by checking a collection of necessary conditions. We do not guarantee that these amount to a sufficient condition, just that they rule out a variety of common problems.\n\nPut differently, we cannot prove that r is valid rdata, only that it is not obviously invalid.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_adjoint-Union{Tuple{N}, Tuple{Mooncake.CoDual, Vararg{Mooncake.CoDual, N}}} where N-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_like_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_like_rdata_from_type","text":"zero_like_rdata_from_type(::Type{P}) where {P}\n\nThis is an internal implementation detail – you should generally not use this function.\n\nReturns either the zero element of type rdata_type(tangent_type(P)), or a ZeroRData. It is always valid to return a ZeroRData, \n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_like_rdata_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_like_rdata_type","text":"zero_like_rdata_type(::Type{P}) where {P}\n\nIndicates the type which will be returned by zero_like_rdata_from_type. Will be the rdata type for P if we can produce the zero rdata element given only P, and will be the union of R and ZeroRData if an instance of P is needed.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_rdata-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_rdata","text":"zero_rdata(p)\n\nGiven value p, return the zero element associated to its reverse data type.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.zero_rdata_from_type-Union{Tuple{Type{P}}, Tuple{P}} where P-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.zero_rdata_from_type","text":"zero_rdata_from_type(::Type{P}) where {P}\n\nReturns the zero element of rdata_type(tangent_type(P)) if this is possible given only P. If not possible, returns an instance of CannotProduceZeroRDataFromType.\n\nFor example, the zero rdata associated to any primal of type Float64 is 0.0, so for Float64s this function is simple. Similarly, if the rdata type for P is NoRData, that can simply be returned.\n\nHowever, it is not possible to return the zero rdata element for abstract types e.g. Real as the type does not uniquely determine the zero element – the rdata type for Real is Any.\n\nThese considerations apply recursively to tuples / namedtuples / structs, etc.\n\nIf you encounter a type which this function returns CannotProduceZeroRDataFromType, but you believe this is done in error, please open an issue. This kind of problem does not constitute a correctness problem, but can be detrimental to performance, so should be dealt with.\n\n\n\n\n\n","category":"method"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@from_rrule-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.\nhas_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@is_primitive-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@is_primitive","text":"@is_primitive context_type signature\n\nCreates a method of is_primitive which always returns true for the context_type and signature provided. For example\n\n@is_primitive MinimalCtx Tuple{typeof(foo), Float64}\n\nis equivalent to\n\nis_primitive(::Type{MinimalCtx}, ::Type{<:Tuple{typeof(foo), Float64}}) = true\n\nYou should implemented more complicated method of is_primitive in the usual way.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@mooncake_overlay-Tuple{Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/#Mooncake.@zero_adjoint-Tuple{Any, Any}-developer_documentation-internal_docstrings","page":"Internal Docstrings","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) = 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"developer_documentation/internal_docstrings/","page":"Internal Docstrings","title":"Internal Docstrings","text":"Mooncake.IntrinsicsWrappers","category":"page"},{"location":"developer_documentation/internal_docstrings/#Mooncake.IntrinsicsWrappers","page":"Internal Docstrings","title":"Mooncake.IntrinsicsWrappers","text":"module IntrinsicsWrappers\n\nThe purpose of this module is to associate to each function in Core.Intrinsics a regular Julia function.\n\nTo understand the rationale for this observe that, unlike regular Julia functions, each Core.IntrinsicFunction in Core.Intrinsics does not have its own type. Rather, they are instances of Core.IntrinsicFunction. To see this, observe that\n\njulia> typeof(Core.Intrinsics.add_float)\nCore.IntrinsicFunction\n\njulia> typeof(Core.Intrinsics.sub_float)\nCore.IntrinsicFunction\n\nWhile we could simply write a rule for Core.IntrinsicFunction, this would (naively) lead to a large list of conditionals of the form\n\nif f === Core.Intrinsics.add_float\n # return add_float and its pullback\nelseif f === Core.Intrinsics.sub_float\n # return add_float and its pullback\nelseif\n ...\nend\n\nwhich has the potential to cause quite substantial type instabilities. (This might not be true anymore – see extended help for more context).\n\nInstead, we map each Core.IntrinsicFunction to one of the regular Julia functions in Mooncake.IntrinsicsWrappers, to which we can dispatch in the usual way.\n\nExtended Help\n\nIt is possible that owing to improvements in constant propagation in the Julia compiler in version 1.10, we actually could get away with just writing a single method of rrule!! to handle all intrinsics, so this dispatch-based mechanism might be unnecessary. Someone should investigate this. Discussed at https://github.com/compintell/Mooncake.jl/issues/387 .\n\n\n\n\n\n","category":"module"},{"location":"developer_documentation/developer_tools/#Developer-Tools","page":"Developer Tools","title":"Developer Tools","text":"","category":"section"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Mooncake.jl offers developers to a few convenience functions which give access to the IR that it generates in order to perform AD. These are lightweight wrappers around internals which save you from having to dig in to the objects created by build_rrule.","category":"page"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Since these provide access to internals, they do not follow the usual rules of semver, and may change without notice!","category":"page"},{"location":"developer_documentation/developer_tools/","page":"Developer Tools","title":"Developer Tools","text":"Mooncake.primal_ir\nMooncake.fwd_ir\nMooncake.rvs_ir","category":"page"},{"location":"developer_documentation/developer_tools/#Mooncake.primal_ir","page":"Developer Tools","title":"Mooncake.primal_ir","text":"primal_ir(sig::Type{<:Tuple}; interp=get_interpreter())::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGet the Core.Compiler.IRCode associated to sig from which the a rule can be derived. Roughly equivalent to Base.code_ircode_by_type(sig; interp).\n\nFor example, if you wanted to get the IR associated to the call map(sin, randn(10)), you could do one of the following calls:\n\njulia> Mooncake.primal_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.primal_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/developer_tools/#Mooncake.fwd_ir","page":"Developer Tools","title":"Mooncake.fwd_ir","text":"fwd_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the forwards-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the forwards pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.fwd_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.fwd_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/developer_tools/#Mooncake.rvs_ir","page":"Developer Tools","title":"Mooncake.rvs_ir","text":"rvs_ir(\n sig::Type{<:Tuple};\n interp=get_interpreter(), debug_mode::Bool=false, do_inline::Bool=true\n)::IRCode\n\n!!! warning: this is not part of the public interface of Mooncake. As such, it may change as part of a non-breaking release of the package.\n\nGenerate the Core.Compiler.IRCode used to construct the reverse-pass of AD. Take a look at how build_rrule makes use of generate_ir to see exactly how this is used in practice.\n\nFor example, if you wanted to get the IR associated to the reverse pass for the call map(sin, randn(10)), you could do either of the following:\n\njulia> Mooncake.rvs_ir(Tuple{typeof(map), typeof(sin), Vector{Float64}}) isa Core.Compiler.IRCode\ntrue\njulia> Mooncake.rvs_ir(typeof((map, sin, randn(10)))) isa Core.Compiler.IRCode\ntrue\n\nArguments\n\nsig::Type{<:Tuple}: the signature of the call to be differentiated.\n\nKeyword Arguments\n\ninterp: the interpreter to use to obtain the primal IR.\ndebug_mode::Bool: whether the generated IR should make use of Mooncake's debug mode.\ndo_inline::Bool: whether to apply an inlining pass prior to returning the ir generated by this function. This is true by default, but setting it to false can sometimes be helpful if you need to understand what function calls are generated in order to perform AD, before lots of it gets inlined away.\n\n\n\n\n\n","category":"function"},{"location":"developer_documentation/running_tests_locally/#Running-Tests-Locally","page":"Running Tests Locally","title":"Running Tests Locally","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake.jl's test suite is fairly extensive. While you can use Pkg.test to run the test suite in the standard manner, this is not usually optimal in Mooncake.jl, and will not run all of the tests. When editing some code, you typically only want to run the tests associated with it, not the entire test suite.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"There are two workflows for running tests, discussed below.","category":"page"},{"location":"developer_documentation/running_tests_locally/#Main-Testing-Functionality","page":"Running Tests Locally","title":"Main Testing Functionality","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"For all code in src, Mooncake's tests are organised as follows:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Things that are required for most / all test suites are loaded up in test/front_matter.jl.\nThe tests for something in src are located in an identically-named file in test. e.g. the unit tests for src/rrules/new.jl are located in test/rrules/new.jl.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Thus, a workflow that I (Will) find works very well is the following:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Ensure that you have Revise.jl and TestEnv.jl installed in your default environment.\nstart the REPL, dev Mooncake.jl, and navigate to the top level of the Mooncake.jl directory.\nusing TestEnv, Revise. Better still, load both of these in your .julia/config/startup.jl file so that you don't ever forget to load them.\nRun the following: using Pkg; Pkg.activate(\".\"); TestEnv.activate(); include(\"test/front_matter.jl\"); to set up your environment.\ninclude whichever test file you want to run the tests from.\nModify code, and re-include tests to check it has done was you need. Loop this until done.\nMake a PR. This runs the entire test suite – I find that I almost never run the entire test suite locally.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"The purpose of this approach is to:","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Avoid restarting the REPL each time you make a change, and\nRun the smallest bit of the test suite possible when making changes, in order to make development a fast and enjoyable process.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"If you find that this strategy leaves you running more of the test suite than you would like, consider copy + pasting specific tests into the REPL, or commenting out a chunk of tests in the file that you are editing during development (try not to commit this). I find this is rather crude strategy effective in practice.","category":"page"},{"location":"developer_documentation/running_tests_locally/#Extension-and-Integration-Testing","page":"Running Tests Locally","title":"Extension and Integration Testing","text":"","category":"section"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"Mooncake now has quite a lot of package extensions, and a large number of integration tests. Unfortunately, these come with a lot of additional dependencies. To avoid these dependencies causing CI to take much longer to run, we locate all tests for extensions and integration testing in their own environments. These can be found in the test/ext and test/integration_testing directories respectively.","category":"page"},{"location":"developer_documentation/running_tests_locally/","page":"Running Tests Locally","title":"Running Tests Locally","text":"These directories comprise a single .jl file, and a Project.toml. You should run these tests by simply includeing the .jl file. Doing so will activate the environemnt, ensure that the correct version of Mooncake is used, and run the tests.","category":"page"},{"location":"utilities/debugging_and_mwes/#Debugging-and-MWEs","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"There's a reasonable chance that you'll run into an issue with Mooncake.jl at some point. In order to debug what is going on when this happens, or to produce an MWE, it is helpful to have a convenient way to run Mooncake.jl on whatever function and arguments you have which are causing problems.","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"We recommend making use of Mooncake.jl's testing functionality to generate your test cases:","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"Mooncake.TestUtils.test_rule","category":"page"},{"location":"utilities/debugging_and_mwes/#Mooncake.TestUtils.test_rule","page":"Debugging and MWEs","title":"Mooncake.TestUtils.test_rule","text":"test_rule(\n rng, x...;\n interface_only=false,\n is_primitive::Bool=true,\n perf_flag::Symbol=:none,\n interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(),\n debug_mode::Bool=false,\n unsafe_perturb::Bool=false,\n)\n\nRun standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.\n\nThis function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.\n\nArguments\n\nrng::AbstractRNG: a random number generator\nx...: the function (first element) and its arguments (the remainder)\n\nKeyword Arguments\n\ninterface_only::Bool=false: test only that the interface is satisfied, without testing correctness. This should generally be set to false (the default value), and only enabled if the testing infrastructure is unable to test correctness for some reason e.g. the returned value of the function is a Ptr, and appropriate tangents cannot, therefore, be generated for it automatically.\nis_primitive::Bool=true: check whether the thing that you are testing has a hand-written rrule!!. This option is helpful if you are testing a new rrule!!, as it enables you to verify that your method of is_primitive has returned the correct value, and that you are actually testing a method of the rrule!! function – a common mistake when authoring a new rrule!! is to implement is_primitive incorrectly and to accidentally wind up testing a rule which Mooncake has derived, as opposed to the one that you have written. If you are testing something for which you have not hand-written an rrule!!, or which you do not care whether it has a hand-written rrule!! or not, you should set it to false.\nperf_flag::Symbol=:none: the value of this symbol determines what kind of performance tests should be performed. By default, none are performed. If you believe that a rule should be allocation-free (iff the primal is allocation free), set this to :allocs. If you hand-write an rrule!! and believe that your test case should be type stable, set this to :stability (at present we cannot verify whether a derived rule is type stable for technical reasons). If you believe that a hand-written rule should be both allocation-free and type-stable, set this to :stability_and_allocs.\ninterp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(): the abstract interpreter to be used when testing this rule. The default should generally be used.\ndebug_mode::Bool=false: whether or not the rule should be tested in debug mode. Typically this should be left at its default false value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to true in order to get access to additional information and automated testing.\nunsafe_perturb::Bool=false: value passed as the third argument to _add_to_primal. Should usually be left false – consult the docstring for _add_to_primal for more info on when you might wish to set it to true.\n\n\n\n\n\n","category":"function"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"This approach is convenient because it can","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"check whether AD runs at all,\ncheck whether AD produces the correct answers,\ncheck whether AD is performant, and\ncan be used without having to manually generate tangents.","category":"page"},{"location":"utilities/debugging_and_mwes/#Example","page":"Debugging and MWEs","title":"Example","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"DocTestSetup = quote\n using Random, Mooncake\nend","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"For example","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"f(x) = Core.bitcast(Float64, x)\nMooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)","category":"page"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.","category":"page"},{"location":"utilities/debugging_and_mwes/#Segfaults","page":"Debugging and MWEs","title":"Segfaults","text":"","category":"section"},{"location":"utilities/debugging_and_mwes/","page":"Debugging and MWEs","title":"Debugging and MWEs","text":"These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.jl's-Rule-System","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl's approach to AD is recursive. It has a single specification for what it means to differentiate a Julia callable, and basically two approaches to achieving this. This section of the documentation explains the former.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We take an iterative approach to this explanation, starting at a high-level and adding more depth as we go.","category":"page"},{"location":"understanding_mooncake/rule_system/#10,000-Foot-View","page":"Mooncake.jl's Rule System","title":"10,000 Foot View","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule r(f, x) for a function f(x) \"does reverse mode AD\", and executes in two phases, known as the forwards pass and the reverse pass. In the forwards pass a rule executes the original function, and does some additional book-keeping in preparation for the reverse pass. On the reverse pass it undoes the computation from the forwards pass, \"backpropagates\" the gradient w.r.t. the output of the original function by applying the adjoint of the derivative of the original function to it, and writes the results of this computation to the correct places.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A precise mathematical model for the original function is therefore entirely crucial to this discussion, as it is needed to understand what the adjoint of its derivative is.","category":"page"},{"location":"understanding_mooncake/rule_system/#A-Model-For-A-Julia-Function","page":"Mooncake.jl's Rule System","title":"A Model For A Julia Function","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Since Julia permits the in-place modification / mutation of many data structures, we cannot make a naive translation between a Julia function and a mathematical object. Rather, we will have to model the state of the arguments to a function both before and after execution. Moreover, since a function can allocate new memory as part of execution and return it to the calling scope, we must track that too.","category":"page"},{"location":"understanding_mooncake/rule_system/#Consider-Only-Externally-Visible-Effects-Of-Function-Evaluation","page":"Mooncake.jl's Rule System","title":"Consider Only Externally-Visible Effects Of Function Evaluation","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We wish to treat a given function as a black box – we care about what a function does, not how it does it – so we consider only the externally-visible results of executing it. There are two ways in which changes can be made externally visible.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Return Value","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"(This point hardly requires explanation, but for the sake of completeness we do so anyway.)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The most obvious way in which a result can be made visible outside of a function is via its return value. For example, letting bar(x) = sin(x), consider the function","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n y = bar(x)\n z = bar(y)\n return z\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The communication between the two invocations of bar happen via the value it returns.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Modification of arguments","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In contrast to the above, changes made by one function can be made available to another implicitly if it modifies the values of its arguments, even if it doesn't return anything. For example, consider:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function bar(x::Vector{Float64})\n x .*= 2\n return nothing\nend\n\nfunction foo(x::Vector{Float64})\n bar(x)\n bar(x)\n return x\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second call to bar in foo sees the changes made to x by the first call to bar, despite not being explicitly returned.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"No Global Mutable State","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"functions can in principle also communicate via global mutable state. We make the decision to not support this.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, we assume functions of the following form cannot be encountered:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction bar(x)\n a .+= x\n return nothing\nend\n\nfunction foo(x)\n bar(x)\n return a\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, a is modified by bar, the effect of which is visible to foo.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For a variety of reasons this is very awkward to handle well. Since it's largely considered poor practice anyway, we explicitly outlaw this mode of communication between functions. See Why Support Closures But Not Mutable Globals for more info.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note that this does not preclude the use of closed-over values or callable structs. For example, something like","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar(x)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"is perfectly fine.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Model","page":"Mooncake.jl's Rule System","title":"The Model","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"It is helpful to have a concrete example which uses both of the permissible methods to make results externally visible. To this end, consider the following function:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f(x::Vector{Float64}, y::Vector{Float64}, z::Vector{Float64}, s::Ref{Vector{Float64}})\n z .*= y .* x\n s[] = 2z\n return sum(z)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We draw your attention to a variety of features of this function:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"z is mutated,\ns is mutated to reference freshly allocated memory,\nthe value previously pointed to by s is unmodified, and\nwe allocate a new value and return it (albeit, it is probably allocated on the stack).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The model we adopt for any Julia function f is a function f mathcalX to mathcalX times mathcalA where mathcalX is the real finite Hilbert space associated to the arguments to f prior to execution, and mathcalA is the real finite Hilbert space associated to any newly allocated data during execution which is externally visible after execution – any newly allocated data which is not made visible is of no concern.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, mathcalX = RR^D times RR^D times RR^D times RR^S where D is the length of x / y / z, and S the length of s[] prior to running f. mathcalA = RR^D times RR, where the RR^D component corresponds to the freshly allocated memory that s references, and RR to the return value. Observe that we model Float64s as elements of RR, Vector{Float64}s as elements of RR^D (for some value of D), and Refs with whatever the model for their contents is. The keen-eyed reader will note that these choices abstract away several details which could conceivably be included in the model. In particular, Vector{Float64} is implemented via a memory buffer, a pointer to the start of this buffer, and an integer which indicates the length of this buffer – none of these details are exposed in the model.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this example, some of the memory allocated during execution is made externally visible by modifying one of the arguments, not just via the return value.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The argument to f is the arguments to f before execution, and the output is the 2-tuple comprising the same arguments after execution and the values associated to any newly allocated / created data. Crucially, observe that we distinguish between the state of the arguments before and after execution.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For our example, the exact form of f is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f((x y z s)) = ((x y x odot y s) (2 x odot y sum_d=1^D x odot y))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that f behaves a little like a transition operator, in the that the first element of the tuple returned is the updated state of the arguments.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This model is good enough for the vast majority of functions. Unfortunately it isn't sufficient to describe a function when arguments alias each other (e.g. consider the way in which this particular model is wrong if y aliases z). Fortunately this is only a problem in a small fraction of all cases of aliasing, so we defer discussion of this until later on.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider now how this approach can be used to model several additional Julia functions, and to obtain their derivatives and adjoints.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"sin(x::Float64)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"mathcalX = RR, mathcalA = RR, f(x) = (x sin(x)).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Thus the derivative is D f x (dotx) = (dotx cos(x) dotx), and its adjoint is D f x^ast (bary) = bary_x + bary_a cos(x), where bary = (bary_x bary_a).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that this result is slightly different to the last example we saw involving sin.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"AD With Mutable Data","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider again","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function f!(x::Vector{Float64})\n x .*= x\n return sum(x)\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Our framework is able to accomodate this function, and has essentially the same solution as the last time we saw this example:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x odot x sum_n=1^N x_n^2)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Non-Mutating Functions","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A very interesting class of functions are those which do not modify their arguments. These are interesting because they are common, and are all that many AD frameworks like ChainRules.jl / Zygote.jl support – by considering this class of functions, we highlight some key similarities between these distinct rule systems.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As always we can model these kinds of functions with a function f mathcalX to mathcalX times mathcalA, but we additionally have that f must have the form","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x varphi(x))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"for some function varphi mathcalX to mathcalA. The derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x (dotx) = (dotx D varphi x(dotx))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the usual inner product to derive the adjoint:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"beginalign\n langle bary D f x (dotx) rangle = langle (bary_1 bary_2) (dotx D varphi x(dotx)) rangle nonumber \n = langle bary_1 dotx rangle + langle bary_2 D varphi x(dotx) rangle nonumber \n = langle bary_1 dotx rangle + langle D varphi x^ast (bary_2) dotx rangle nonumber quad text(by definition of the adjoint) \n = langle bary_1 + D varphi x^ast (bary_2) dotx rangle nonumber\nendalign","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"So the adjoint of the derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast (bary) = bary_1 + D varphi x^ast (bary_2)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We see the correct thing to do is to increment the gradient of the output – bary_1 – by the result of applying the adjoint of the derivative of varphi to bary_2. In a ChainRules.rrule the bary_1 term is always zero, but the D varphi x^ast (bary_2) term is essentially the same.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Rule-Interface-(Round-1)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 1)","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Having explained in principle what it is that a rule must do, we now take a first look at the interface we use to achieve this. A rule for a function foo with signature","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{typeof(foo), Float64} -> Float64","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"must have signature","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Tuple{Trule, CoDual{typeof(foo), NoFData}, CoDual{Float64, NoFData}} ->\n Tuple{CoDual{Float64, NoFData}, Trvs_pass}","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"For example, if we call foo(5.0), it rules would be called as rule(CoDual(foo, NoFData()), CoDual(5.0, NoFData())). The precise definition and role of NoFData will be explained shortly, but the general scheme is that to a rule for foo you must pass foo itself, its arguments, and some additional data for book-keeping. foo and each of its arguments are paired with this additional book-keeping data via the CoDual type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The rule returns another CoDual (it propagates book-keeping information forwards), along with a function which runs the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In a little more depth:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Notation: primal","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Throughout the rest of this document, we will refer to the function being differentiated as the \"primal\" computation, and its arguments as the \"primal\" arguments.","category":"page"},{"location":"understanding_mooncake/rule_system/#Forwards-Pass","page":"Mooncake.jl's Rule System","title":"Forwards Pass","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each piece of each input to the primal is paired with shadow data, if it has a fixed address. For example, a Vector{Float64} argument is paired with another Vector{Float64}. The adjoint of f is accumulated into this shadow vector on the reverse pass. However, a Float64 argument gets paired with NoFData(), since it is a bits type and therefore has no fixed address.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must return a Tuple of two things. The first thing must be a CoDual containing the output of the primal computation and its shadow memory (if it has any). The second must be a function which runs the reverse pass of AD – this will usually be a closure of some kind.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A rule must","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"ensure that the state of the primal components of all inputs / the output are as they would have been had the primal computation been run (up to differences due to finite precision arithmetic),\npropagate / construct the shadow memory associated to the output (initialised to zero), and\nconstruct the function to run the reverse pass – typically this will involve storing some quantities computed during the forwards pass.","category":"page"},{"location":"understanding_mooncake/rule_system/#Reverse-Pass","page":"Mooncake.jl's Rule System","title":"Reverse Pass","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second element of the output of a rule is a function which runs the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Inputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the output of the primal.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Outputs","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The \"rdata\" associated to the inputs of the primal.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Functionality","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"undo changes made to primal state on the forwards pass.\napply adjoint of derivative of primal operation, putting the results in the correct place.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This description should leave you with (at least) a couple of questions. What is \"rdata\", and what is \"the correct place\" to put the results of applying the adjoint of the derivative? In order to address these, we need to discuss the types that Mooncake.jl uses to represent the results of AD, and to propagate results backwards on the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/#Representing-Gradients","page":"Mooncake.jl's Rule System","title":"Representing Gradients","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"We refer to both inputs and outputs of derivatives D f x mathcalX to mathcalY as tangents, e.g. dotx or doty. Conversely, we refer to both inputs and outputs to the adjoint of this derivative D f x^ast mathcalY to mathcalX as gradients, e.g. bary and barx.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Note, however, that the sets involved are the same whether dealing with a derivative or its adjoint. Consequently, we use the same type to represent both.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Representing Gradients","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This package assigns to each type in Julia a unique tangent_type, the purpose of which is to contain the gradients computed during reverse mode AD. The extended docstring for tangent_type provides the best introduction to the types which are used to represent tangents / gradients.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"tangent_type(P)","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.tangent_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.tangent_type","text":"tangent_type(P)\n\nThere must be a single type used to represents tangents of primals of type P, and it must be given by tangent_type(P).\n\nExtended help\n\nThe tangent types which Mooncake.jl uses are quite similar in spirit to ChainRules.jl. For example, tangent \"vectors\" for\n\nFloat64s are Float64s,\nVector{Float64}s are Vector{Float64}s, and\nstructs are other another (special) struct with field types specified recursively.\n\nThere are, however, some major differences. Firstly, while it is certainly true that the above tangent types are permissible in ChainRules.jl, they are not the uniquely permissible types. For example, ZeroTangent is also a permissible type of tangent for any of them, and Float32 is permissible for Float64. This is a general theme in ChainRules.jl – it intentionally declines to place restrictions on what type can be used to represent the tangent of a given type.\n\nMooncake.jl differs from this. It insists that each primal type is associated to a single tangent type. Furthermore, this type is always given by the function Mooncake.tangent_type(primal_type).\n\nConsider some more worked examples.\n\nInt\n\nInt is not a differentiable type, so its tangent type is NoTangent:\n\njulia> tangent_type(Int)\nNoTangent\n\nTuples\n\nThe tangent type of a Tuple is defined recursively based on its field types. For example\n\njulia> tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\nThere is one edge case to be aware of: if all of the field of a Tuple are non-differentiable, then the tangent type is NoTangent. For example,\n\njulia> tangent_type(Tuple{Int, Int})\nNoTangent\n\nStructs\n\nAs with Tuples, the tangent type of a struct is, by default, given recursively. In particular, the tangent type of a struct type is Tangent. This type contains a NamedTuple containing the tangent to each field in the primal struct.\n\nAs with Tuples, if all field types are non-differentiable, the tangent type of the entire struct is NoTangent.\n\nThere are a couple of additional subtleties to consider over Tuples though. Firstly, not all fields of a struct have to be defined. Fortunately, Julia makes it easy to determine how many of the fields might possibly not be defined. The tangent associated to any field which might possibly not be defined is wrapped in a PossiblyUninitTangent.\n\nFurthermore, structs can have fields whose static type is abstract. For example\n\njulia> struct Foo\n x\n end\n\nIf you ask for the tangent type of Foo, you will see that it is\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x}}\n\nObserve that the field type associated to x is Any. The way to understand this result is to observe that\n\nx could have literally any type at runtime, so we know nothing about what its tangent type must be until runtime, and\nwe require that the tangent type of Foo be unique.\n\nThe consequence of these two considerations is that the tangent type of Foo must be able to contain any type of tangent in its x field. It follows that the fieldtype of the x field of Foos tangent must be Any.\n\nMutable Structs\n\nThe tangent type for mutable structs have the same set of considerations as structs. The only difference is that they must themselves be mutable. Consequently, we use a type called MutableTangent to represent their tangents. It is a mutable struct with the same structure as Tangent.\n\nFor example, if you ask for the tangent_type of\n\njulia> mutable struct Bar\n x::Float64\n end\n\nyou will find that it is\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64}}\n\nPrimitive Types\n\nWe've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.\n\nOne interesting case are Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example\n\njulia> tangent_type(Ptr{Float64})\nPtr{Float64}\n\n\n\n\n\n","category":"method"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"FData and RData","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this \"unique gradient address\" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The following docstring provides the best in-depth explanation.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.fdata_type(T)","category":"page"},{"location":"understanding_mooncake/rule_system/#Mooncake.fdata_type-Tuple{Any}","page":"Mooncake.jl's Rule System","title":"Mooncake.fdata_type","text":"fdata_type(T)\n\nReturns the type of the forwards data associated to a tangent of type T.\n\nExtended help\n\nRules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.\n\nGiven a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.\n\nGiven a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that\n\ntangent(fdata(t), rdata(t)) === t\n\nThe need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.\n\nInt\n\nInts are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore\n\njulia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))\n(NoFData, NoRData)\n\nFloat64\n\nThe tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so\n\njulia> (fdata_type(Float64), rdata_type(Float64))\n(NoFData, Float64)\n\nVector{Float64}\n\nThe tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so\n\njulia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))\n(Vector{Float64}, NoRData)\n\nTuple{Float64, Vector{Float64}, Int}\n\nThis is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:\n\njulia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})\nTuple{Float64, Vector{Float64}, NoTangent}\n\njulia> (fdata_type(T), rdata_type(T))\n(Tuple{NoFData, Vector{Float64}, NoFData}, Tuple{Float64, NoRData, NoRData})\n\nThe zero tangent for (5.0, [5.0]) is t = (0.0, [0.0]). fdata(t) returns (NoFData(), [0.0]), where the second element is === to the second element of t. rdata(t) returns (0.0, NoRData()). In this example, t contains a mixture of data, some of which is identified by its value, and some of which is identified by its address, so there is some fdata and some rdata.\n\nStructs\n\nStructs are handled in more-or-less the same way as Tuples, albeit with the possibility of undefined fields needing to be explicitly handled. For example, a struct such as\n\njulia> struct Foo\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Foo)\nTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nIts fdata and rdata are given by special FData and RData types:\n\njulia> (fdata_type(tangent_type(Foo)), rdata_type(tangent_type(Foo)))\n(Mooncake.FData{@NamedTuple{x::NoFData, y, z::NoFData}}, Mooncake.RData{@NamedTuple{x::Float64, y, z::NoRData}})\n\nPractically speaking, FData and RData both have the same structure as Tangents and are just used in different contexts.\n\nMutable Structs\n\nThe fdata for a mutable structs is its tangent, and it has no rdata. This is because mutable structs have fixed memory addresses, and can therefore be incremented in-place. For example,\n\njulia> mutable struct Bar\n x::Float64\n y\n z::Int\n end\n\nhas tangent type\n\njulia> tangent_type(Bar)\nMutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}\n\nand fdata / rdata types\n\njulia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))\n(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)\n\nPrimitive Types\n\nAs with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.\n\n\n\n\n\n","category":"method"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A quick aside: Non-Differentiable Data","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a 0 tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.","category":"page"},{"location":"understanding_mooncake/rule_system/#The-Rule-Interface-(Round-2)","page":"Mooncake.jl's Rule System","title":"The Rule Interface (Round 2)","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = quote\n using Mooncake\n using Mooncake: CoDual\n import Mooncake: rrule!!\nend","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Consider the function","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])\nfoo (generic function with 1 method)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Under the framework introduced above, the model for this function is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"f(x) = (x x_1 + sum_n=1^N (x_2)_n)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where the vector in the second element of x is of length N. Now, following our usual steps, the derivative is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x(dotx) = (dotx dotx_1 + sum_n=1^N (dotx_2)_n)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"A gradient for this is a tuple (bary_x bary_a) where bary_a in RR and bary_x in RR times RR^N. A quick derivation will show that the adjoint is","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"D f x^ast(bary) = ((bary_x)_1 + bary_a (bary_x)_2 + bary_a mathbf1)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where mathbf1 is the vector of length N in which each element is equal to 1. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})\n dx_fdata = x.dx\n function dfoo_adjoint(dy::Float64)\n dx_fdata[2] .+= dy\n dx_1_rdata = dy\n dx_rdata = (dx_1_rdata, NoRData())\n return NoRData(), dx_rdata\n end\n x_p = x.x\n return CoDual(x_p[1] + sum(x_p[2]), NoFData()), dfoo_adjoint\n end;\n","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"where dy is the rdata for the output to foo. The rrule!! can be called with the appropriate CoDuals:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> out, pb!! = rrule!!(CoDual(foo, NoFData()), CoDual((5.0, [1.0, 2.0]), (NoFData(), [0.0, 0.0])))\n(CoDual{Float64, NoFData}(8.0, NoFData()), var\"#dfoo_adjoint#1\"{Tuple{NoFData, Vector{Float64}}}((NoFData(), [0.0, 0.0])))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"and the pullback with appropriate rdata:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"julia> pb!!(1.0)\n(NoRData(), (1.0, NoRData()))","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the forwards pass:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"computes the result of the initial function, and\npulls out the fdata for the Vector{Float64} component of the argument.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"As promised, the forwards pass really has nothing to do with the adjoint. It's just book-keeping and running the primal computation.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The reverse pass:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"increments each element of dx_fdata[2] by dy – this corresponds to (bary_x)_2 + bary_a mathbf1 in the adjoint,\nsets dx_1_rdata to dy – this corresponds (bary_x)_1 + bary_a subject to the constraint that (bary_x)_1 = 0,\nconstructs the rdata for x – this is essentially just book-keeping.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Each of these items serve to demonstrate more general points. The first that, upon entry into the reverse pass, all fdata values correspond to gradients for the arguments / output of f \"upon exit\" (for the components of these which are identified by their address), and once the reverse-pass finishes running, they must contain the gradients w.r.t. the arguments of f \"upon entry\".","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The second that we always assume that the components of bary_x which are identified by their value have zero-rdata.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"The third is that the components of the arguments of f which are identified by their value must have rdata passed back explicitly by a rule, while the components of the arguments to f which are identified by their address get their gradients propagated back implicitly (i.e. via the in-place modification of fdata).","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Reminder: the first element of the tuple returned by dfoo_adjoint is the rdata associated to foo itself, hence it is NoRData.","category":"page"},{"location":"understanding_mooncake/rule_system/#Testing","page":"Mooncake.jl's Rule System","title":"Testing","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Mooncake.jl has an almost entirely automated system for testing rules – Mooncake.TestUtils.test_rule. You should absolutely make use of these when writing rules.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"TODO: improve docstring for testing functionality.","category":"page"},{"location":"understanding_mooncake/rule_system/#Summary","page":"Mooncake.jl's Rule System","title":"Summary","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"In this section we have covered the rule system. Every callable object / function in the Julia language is differentiated using rules with this interface, whether they be hand-written rrule!!s, or rules derived by Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"At this point you should be equipped with enough information to understand what a rule in Mooncake.jl does, and how you can write your own ones. Later sections will explain how Mooncake.jl goes about deriving rules itself in a recursive manner, and introduce you to some of the internals.","category":"page"},{"location":"understanding_mooncake/rule_system/#Asides","page":"Mooncake.jl's Rule System","title":"Asides","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/#Why-Uniqueness-of-Type-For-Tangents-/-FData-/-RData?","page":"Mooncake.jl's Rule System","title":"Why Uniqueness of Type For Tangents / FData / RData?","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Why does Mooncake.jl insist that each primal type P be paired with a single tangent type T, as opposed to being more permissive. There are a few notable reasons:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"To provide a precise interface. Rules pass fdata around on the forwards pass and rdata on the reverse pass – being able to make strong assumptions about the type of the fdata / rdata given the primal type makes implementing rules much easier in practice.\nConditional type stability. We wish to have a high degree of confidence that if the primal code is type-stable, then the AD code will also be. It is straightforward to construct type stable primal codes which have type-unstable forwards and reverse passes if you permit there to be more than one fdata / rdata type for a given primal. So while uniqueness is certainly not sufficient on its own to guarantee conditional type stability, it is probably necessary in general.\nTest-case generation and coverage. There being a unique tangent / fdata / rdata type for each primal makes being confident that a given rule is being tested thoroughly much easier. For a given primal, rather than there being many possible input / output types to consider, there is just one.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"This topic, in particular what goes wrong with permissive tangent type systems like those employed by ChainRules, deserves a more thorough treatment – hopefully someone will write something more expansive on this topic at some point.","category":"page"},{"location":"understanding_mooncake/rule_system/#Why-Support-Closures-But-Not-Mutable-Globals","page":"Mooncake.jl's Rule System","title":"Why Support Closures But Not Mutable Globals","text":"","category":"section"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"First consider why closures are straightforward to support. Look at the type of the closure produced by foo:","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"function foo(x)\n function bar(y)\n x .+= y\n return nothing\n end\n return bar\nend\nbar = foo(randn(5))\ntypeof(bar)\n\n# output\nvar\"#bar#1\"{Vector{Float64}}","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Observe that the Vector{Float64} that we passed to foo, and closed over in bar, is present in the type. This alludes to the fact that closures are basically just callable structs whose fields are the closed-over variables. Since the function itself is an argument to its rule, everything enters the rule for bar via its arguments, and the rule system developed in this document applies straightforwardly.","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"On the other hand, globals do not appear in the functions that they are a part of. For example,","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"const a = randn(10)\n\nfunction g(x)\n a .+= x\n return nothing\nend\n\ntypeof(g)\n\n# output\ntypeof(g) (singleton type of function g, subtype of Function)","category":"page"},{"location":"understanding_mooncake/rule_system/","page":"Mooncake.jl's Rule System","title":"Mooncake.jl's Rule System","text":"Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.","category":"page"},{"location":"utilities/tools_for_rules/#Tools-for-Rules","page":"Tools for Rules","title":"Tools for Rules","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported. However, this does not always necessitate writing your own rrule!! from scratch. In this section, we detail some useful strategies which can help you avoid having to write rrule!!s in many situations.","category":"page"},{"location":"utilities/tools_for_rules/#Simplfiying-Code-via-Overlays","page":"Tools for Rules","title":"Simplfiying Code via Overlays","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@mooncake_overlay","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@mooncake_overlay","page":"Tools for Rules","title":"Mooncake.@mooncake_overlay","text":"@mooncake_overlay method_expr\n\nDefine a method of a function which only Mooncake can see. This can be used to write versions of methods which can be successfully differentiated by Mooncake if the original cannot be.\n\nFor example, suppose that you have a function\n\njulia> foo(x::Float64) = bar(x)\nfoo (generic function with 1 method)\n\nwhere Mooncake.jl fails to differentiate bar for some reason. If you have access to another function baz, which does the same thing as bar, but does so in a way which Mooncake.jl can differentiate, you can simply write:\n\njulia> Mooncake.@mooncake_overlay foo(x::Float64) = baz(x)\n\n\nWhen looking up the code for foo(::Float64), Mooncake.jl will see this method, rather than the original, and differentiate it instead.\n\nA Worked Example\n\nTo demonstrate how to use @mooncake_overlays in practice, we here demonstrate how the answer that Mooncake.jl gives changes if you change the definition of a function using a @mooncake_overlay. Do not do this in practice – this is just a simple way to demonostrate how to use overlays!\n\nFirst, consider a simple example:\n\njulia> scale(x) = 2x\nscale (generic function with 1 method)\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(10.0, (NoTangent(), 2.0))\n\nWe can use @mooncake_overlay to change the definition which Mooncake.jl sees:\n\njulia> Mooncake.@mooncake_overlay scale(x) = 3x\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(15.0, (NoTangent(), 3.0))\n\nAs can be seen from the output, the result of differentiating using Mooncake.jl has changed to reflect the overlay-ed definition of the method.\n\nAdditionally, it is possible to use the usual multi-line syntax to declare an overlay:\n\njulia> Mooncake.@mooncake_overlay function scale(x)\n return 4x\n end\n\njulia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64});\n\njulia> Mooncake.value_and_gradient!!(rule, scale, 5.0)\n(20.0, (NoTangent(), 4.0))\n\n\n\n\n\n","category":"macro"},{"location":"utilities/tools_for_rules/#Functions-with-Zero-Adjoint","page":"Tools for Rules","title":"Functions with Zero Adjoint","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@zero_adjoint\nMooncake.zero_adjoint","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@zero_adjoint","page":"Tools for Rules","title":"Mooncake.@zero_adjoint","text":"@zero_adjoint ctx sig\n\nDefines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.\n\nFor example:\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo(x) = 5\nfoo (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo), Any}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo), Any})\ntrue\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3.0))[2](NoRData())\n(NoRData(), 0.0)\n\nLimited support for Varargs is also available. For example\n\njulia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive\n\njulia> foo_varargs(x...) = 5\nfoo_varargs (generic function with 1 method)\n\njulia> @zero_adjoint DefaultCtx Tuple{typeof(foo_varargs), Vararg}\n\njulia> is_primitive(DefaultCtx, Tuple{typeof(foo_varargs), Any, Float64, Int})\ntrue\n\njulia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())\n(NoRData(), 0.0, NoRData())\n\nBe aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.\n\nWARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.\n\nAs always, you should use TestUtils.test_rule to ensure that you've not made a mistake.\n\nSignatures Unsupported By This Macro\n\nIf the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.\n\n\n\n\n\n","category":"macro"},{"location":"utilities/tools_for_rules/#Mooncake.zero_adjoint","page":"Tools for Rules","title":"Mooncake.zero_adjoint","text":"zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}\n\nUtility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.\n\nNOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.\n\nYou make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:\n\njulia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual\n\njulia> foo(x::Vararg{Int}) = 5\nfoo (generic function with 1 method)\n\njulia> is_primitive(::Type{DefaultCtx}, ::Type{<:Tuple{typeof(foo), Vararg{Int}}}) = true;\n\njulia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())\n(NoRData(), NoRData(), NoRData())\n\nWARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```\n\n\n\n\n\n","category":"function"},{"location":"utilities/tools_for_rules/#Using-ChainRules.jl","page":"Tools for Rules","title":"Using ChainRules.jl","text":"","category":"section"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"There is enough similarity between these two systems that most of the boilerplate code can be avoided.","category":"page"},{"location":"utilities/tools_for_rules/","page":"Tools for Rules","title":"Tools for Rules","text":"Mooncake.@from_rrule","category":"page"},{"location":"utilities/tools_for_rules/#Mooncake.@from_rrule","page":"Tools for Rules","title":"Mooncake.@from_rrule","text":"@from_rrule ctx sig [has_kwargs=false]\n\nConvenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.\n\nArguments\n\nctx: A Mooncake context type\nsig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.\nhas_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.\n\nExample Usage\n\nA Basic Example\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real) = 5x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), 5Ω\n return foo(x), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat}\n\njulia> rrule!!(zero_fcodual(foo), zero_fcodual(5.0))[2](1.0)\n(NoRData(), 5.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(Xoshiro(123), foo, 5.0; is_primitive=true)\nTest Passed\n\nAn Example with Keyword Arguments\n\njulia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils\n\njulia> using ChainRulesCore\n\njulia> foo(x::Real; cond::Bool) = cond ? 5x : 4x;\n\njulia> function ChainRulesCore.rrule(::typeof(foo), x::Real; cond::Bool)\n foo_pb(Ω::Real) = ChainRulesCore.NoTangent(), cond ? 5Ω : 4Ω\n return foo(x; cond), foo_pb\n end;\n\njulia> @from_rrule DefaultCtx Tuple{typeof(foo), Base.IEEEFloat} true\n\njulia> _, pb = rrule!!(\n zero_fcodual(Core.kwcall),\n zero_fcodual((cond=false, )),\n zero_fcodual(foo),\n zero_fcodual(5.0),\n );\n\njulia> pb(3.0)\n(NoRData(), NoRData(), NoRData(), 12.0)\n\njulia> # Check that the rule works as intended.\n TestUtils.test_rule(\n Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true\n )\nTest Passed\n\nNotice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.\n\nLimitations\n\nIt is your responsibility to ensure that\n\ncalls with signature sig do not mutate their arguments,\nthe output of calls with signature sig does not alias any of the inputs.\n\nAs with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.\n\nArgument Type Constraints\n\nMany methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature\n\nTuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}\n\nThere are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.\n\nSuffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.\n\nConversions Between Different Tangent Type Systems\n\nUnder the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.\n\n\n\n\n\n","category":"macro"},{"location":"utilities/debug_mode/#Debug-Mode","page":"Debug Mode","title":"Debug Mode","text":"","category":"section"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = quote\n using Mooncake, ADTypes\nend","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Problem","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})\n plus_reverse_pass(dz::Real) = NoRData(), dz, dz\n return zero_fcodual(primal(x) + primal(y))\nend","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"and calls","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"The Solution","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"This is implemented via DebugRRule:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.DebugRRule","category":"page"},{"location":"utilities/debug_mode/#Mooncake.DebugRRule","page":"Debug Mode","title":"Mooncake.DebugRRule","text":"DebugRRule(rule)\n\nConstruct a callable which is equivalent to rule, but inserts additional type checking. In particular:\n\ncheck that the fdata in each argument is of the correct type for the primal\ncheck that the fdata in the CoDual returned from the rule is of the correct type for the primal.\n\nThis happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.\n\nSome additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).\n\nLet rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.\n\nNote: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.\n\n\n\n\n\n","category":"type"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Mooncake.build_rrule","category":"page"},{"location":"utilities/debug_mode/#Mooncake.build_rrule","page":"Debug Mode","title":"Mooncake.build_rrule","text":"build_rrule(args...; debug_mode=false)\n\nHelper method. Only uses static information from args.\n\n\n\n\n\nbuild_rrule(sig::Type{<:Tuple})\n\nEquivalent to build_rrule(Mooncake.get_interpreter(), sig).\n\n\n\n\n\nbuild_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}\n\nReturns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.\n\nIf debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.\n\n\n\n\n\n","category":"function"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))\nAutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))","category":"page"},{"location":"utilities/debug_mode/#When-Should-You-Use-Debug-Mode?","page":"Debug Mode","title":"When Should You Use Debug Mode?","text":"","category":"section"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"Only use debug_mode when debugging a problem. This is because is has substantial performance implications.","category":"page"},{"location":"utilities/debug_mode/","page":"Debug Mode","title":"Debug Mode","text":"DocTestSetup = nothing","category":"page"},{"location":"understanding_mooncake/introduction/#Introduction","page":"Introduction","title":"Introduction","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"we recap what AD is, and introduce the mathematics necessary to understand is,\nexplain how this mathematics relates to functions and data structures in Julia, and\nhow this is handled in Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.","category":"page"},{"location":"understanding_mooncake/introduction/#Who-Are-These-Docs-For?","page":"Introduction","title":"Who Are These Docs For?","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.","category":"page"},{"location":"understanding_mooncake/introduction/#Prerequisites-and-Resources","page":"Introduction","title":"Prerequisites and Resources","text":"","category":"section"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.","category":"page"},{"location":"understanding_mooncake/introduction/","page":"Introduction","title":"Introduction","text":"In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and \"transposed Jacobian\". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the \"Matrix Calculus for Machine Learning and Beyond\" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.","category":"page"},{"location":"#Mooncake.jl","page":"Mooncake.jl","title":"Mooncake.jl","text":"","category":"section"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Documentation for Mooncake.jl is on its way!","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (03/10/2024): Various bits of utility functionality are now carefully documented. This includes how to change the code which Mooncake sees, declare that the derivative of a function is zero, make use of existing ChainRules.rrules to quicky create new rules in Mooncake, and more.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (02/07/2024): The first round of documentation has arrived. This is largely targetted at those who are interested in contributing to Mooncake.jl – you can find this work in the \"Understanding Mooncake.jl\" section of the docs. There is more to to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.","category":"page"},{"location":"","page":"Mooncake.jl","title":"Mooncake.jl","text":"Note (29/05/2024): I (Will) am currently actively working on the documentation. It will be merged in chunks over the next month or so as good first drafts of sections are completed. Please don't be alarmed that not all of it is here!","category":"page"}] } diff --git a/previews/PR386/understanding_mooncake/algorithmic_differentiation/index.html b/previews/PR386/understanding_mooncake/algorithmic_differentiation/index.html index 27200380c..73f1047c2 100644 --- a/previews/PR386/understanding_mooncake/algorithmic_differentiation/index.html +++ b/previews/PR386/understanding_mooncake/algorithmic_differentiation/index.html @@ -19,4 +19,4 @@ D f [x] (\dot{x}) &= [(D \mathcal{l} [g(x)]) \circ (D g [x])](\dot{x}) \nonumber \\ &= \langle \bar{y}, D g [x] (\dot{x}) \rangle \nonumber \\ &= \langle D g [x]^\ast (\bar{y}), \dot{x} \rangle, \nonumber -\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.

Summary

This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.

Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD achieves this by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.

[1]
M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).
[2]
T. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).
  • note_for_geometersin AD we only really need to discuss differentiatiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.
+\end{align}\]

from which we conclude that $D g [x]^\ast (\bar{y})$ is the gradient of the composition $l \circ g$ at $x$.

The consequence is that we can always view the computation performed by reverse-mode AD as computing the gradient of the composition of the function in question and an inner product with the argument to the adjoint.

The above shows that if $\mathcal{Y} = \RR$ and $g$ is the function we wish to compute the gradient of, we can simply set $\bar{y} = 1$ and compute $D g [x]^\ast (\bar{y})$ to obtain the gradient of $g$ at $x$.

Summary

This document explains the core mathematical foundations of AD. It explains separately what is does, and how it goes about it. Some basic examples are given which show how these mathematical foundations can be applied to differentiate functions of matrices, and Julia functions.

Subsequent sections will build on these foundations, to provide a more general explanation of what AD looks like for a Julia programme.

Asides

How does Forwards-Mode AD work?

Forwards-mode AD achieves this by breaking down $f$ into the composition $f = f_N \circ \dots \circ f_1$, where each $f_n$ is a simple function whose derivative (function) $D f_n [x_n]$ we know for any given $x_n$. By the chain rule, we have that

\[D f [x] (\dot{x}) = D f_N [x_N] \circ \dots \circ D f_1 [x_1] (\dot{x})\]

which suggests the following algorithm:

  1. let $x_1 = x$, $\dot{x}_1 = \dot{x}$, and $n = 1$
  2. let $\dot{x}_{n+1} = D f_n [x_n] (\dot{x}_n)$
  3. let $x_{n+1} = f(x_n)$
  4. let $n = n + 1$
  5. if $n = N+1$ then return $\dot{x}_{N+1}$, otherwise go to 2.

When each function $f_n$ maps between Euclidean spaces, the applications of derivatives $D f_n [x_n] (\dot{x}_n)$ are given by $J_n \dot{x}_n$ where $J_n$ is the Jacobian of $f_n$ at $x_n$.

[1]
M. Giles. An extended collection of matrix derivative results for forward and reverse mode automatic differentiation. Unpublished (2008).
[2]
T. P. Minka. Old and new matrix algebra useful for statistics. See www. stat. cmu. edu/minka/papers/matrix. html 4 (2000).
  • note_for_geometersin AD we only really need to discuss differentiatiable functions between vector spaces that are isomorphic to Euclidean space. Consequently, a variety of considerations which are usually required in differential geometry are not required here. Notably, the tangent space is assumed to be the same everywhere, and to be the same as the domain of the function. Avoiding these additional considerations helps keep the mathematics as simple as possible.
diff --git a/previews/PR386/understanding_mooncake/introduction/index.html b/previews/PR386/understanding_mooncake/introduction/index.html index e7be1f24e..060796a00 100644 --- a/previews/PR386/understanding_mooncake/introduction/index.html +++ b/previews/PR386/understanding_mooncake/introduction/index.html @@ -1,2 +1,2 @@ -Introduction · Mooncake.jl

Introduction

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.

  1. we recap what AD is, and introduce the mathematics necessary to understand is,
  2. explain how this mathematics relates to functions and data structures in Julia, and
  3. how this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

+Introduction · Mooncake.jl

Introduction

The point of Mooncake.jl is to perform reverse-mode algorithmic differentiation (AD). The purpose of this section is to explain what precisely is meant by this, and how it can be interpreted mathematically.

  1. we recap what AD is, and introduce the mathematics necessary to understand is,
  2. explain how this mathematics relates to functions and data structures in Julia, and
  3. how this is handled in Mooncake.jl.

Since Mooncake.jl supports in-place operations / mutation, these will push beyond what is encountered in Zygote / Diffractor / ChainRules. Consequently, while there is a great deal of overlap with these existing systems, you will need to read through this section of the docs in order to properly understand Mooncake.jl.

Who Are These Docs For?

These are primarily designed for anyone who is interested in contributing to Mooncake.jl. They are also hopefully of interest to anyone how is interested in understanding AD more broadly. If you aren't interested in understanding how Mooncake.jl and AD work, you don't need to have read them in order to make use of this package.

Prerequisites and Resources

This introduction assumes familiarity with the differentiation of vector-valued functions – familiarity with the gradient and Jacobian matrices is a given.

In order to provide a convenient exposition of AD, we need to abstract a little further than this and make use of a slightly more general notion of the derivative, gradient, and "transposed Jacobian". Please note that, fortunately, we only ever have to handle finite dimensional objects when doing AD, so there is no need for any knowledge of functional analysis to understand what is going on here. The required concepts will be introduced here, but I cannot promise that these docs give the best exposition – they're most appropriate as a refresher and to establish notation. Rather, I would recommend a couple of lectures from the "Matrix Calculus for Machine Learning and Beyond" course, which you can find on MIT's OCW website, delivered by Edelman and Johnson (who will be familiar faces to anyone who has spent much time in the Julia world!). It is designed for undergraduates, and is accessible to anyone with some undergraduate-level linear algebra and calculus. While I recommend the whole course, Lecture 1 part 2 and Lecture 4 part 1 are especially relevant to the problems we shall discuss – you can skip to 11:30 in Lecture 4 part 1 if you're in a hurry.

diff --git a/previews/PR386/understanding_mooncake/rule_system/index.html b/previews/PR386/understanding_mooncake/rule_system/index.html index b10de33e1..e45a6fba4 100644 --- a/previews/PR386/understanding_mooncake/rule_system/index.html +++ b/previews/PR386/understanding_mooncake/rule_system/index.html @@ -51,7 +51,7 @@ x::Float64 end

you will find that it is

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64}}

Primitive Types

We've already seen a couple of primitive types (Float64 and Int). The basic story here is that all primitive types require an explicit specification of what their tangent type must be.

One interesting case are Ptr types. The tangent type of a Ptr{P} is Ptr{T}, where T = tangent_type(P). For example

julia> tangent_type(Ptr{Float64})
-Ptr{Float64}
source

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_typeMethod
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
+Ptr{Float64}
source

FData and RData

While tangents are the things used to represent gradients and are what high-level interfaces will return, they are not what gets propagated forwards and backwards by rules during AD.

Rather, during AD, Mooncake.jl makes a fundamental distinction between data which is identified by its address in memory (Arrays, mutable structs, etc), and data which is identified by its value (is-bits types such as Float64, Int, and structs thereof). In particular, memory which is identified by its address gets assigned a unique location in memory in which its gradient lives (that this "unique gradient address" system is essential will become apparent when we discuss aliasing later on). Conversely, the gradient w.r.t. a value type resides in another value type.

The following docstring provides the best in-depth explanation.

Mooncake.fdata_typeMethod
fdata_type(T)

Returns the type of the forwards data associated to a tangent of type T.

Extended help

Rules in Mooncake.jl do not operate on tangents directly. Rather, functionality is defined to split each tangent into two components, that we call fdata (forwards-pass data) and rdata (reverse-pass data). In short, any component of a tangent which is identified by its address (e.g. a mutable structs or an Array) gets passed around on the forwards-pass of AD and is incremented in-place on the reverse-pass, while components of tangents identified by their value get propagated and accumulated only on the reverse-pass.

Given a tangent type T, you can find out what type its fdata and rdata must be with fdata_type(T) and rdata_type(T) respectively. A consequence of this is that there is exactly one valid fdata type and rdata type for each primal type.

Given a tangent t, you can get its fdata and rdata using f = fdata(t) and r = rdata(t) respectively. f and r can be re-combined to recover the original tangent using the binary version of tangent: tangent(f, r). It must always hold that

tangent(fdata(t), rdata(t)) === t

The need for all of this is explained in the docs, but for now it suffices to consider our running examples again, and to see what their fdata and rdata look like.

Int

Ints are non-differentiable types, so there is nothing to pass around on the forwards- or reverse-pass. Therefore

julia> fdata_type(tangent_type(Int)), rdata_type(tangent_type(Int))
 (NoFData, NoRData)

Float64

The tangent type of Float64 is Float64. Float64s are identified by their value / have no fixed address, so

julia> (fdata_type(Float64), rdata_type(Float64))
 (NoFData, Float64)

Vector{Float64}

The tangent type of Vector{Float64} is Vector{Float64}. A Vector{Float64} is identified by its address, so

julia> (fdata_type(Vector{Float64}), rdata_type(Vector{Float64}))
 (Vector{Float64}, NoRData)

Tuple{Float64, Vector{Float64}, Int}

This is an example of a type which has both fdata and rdata. The tangent type for Tuple{Float64, Vector{Float64}, Int} is Tuple{Float64, Vector{Float64}, NoTangent}. Tuples have no fixed memory address, so we interogate each field on its own. We have already established the fdata and rdata types for each element, so we recurse to obtain:

julia> T = tangent_type(Tuple{Float64, Vector{Float64}, Int})
@@ -70,7 +70,7 @@
            z::Int
        end

has tangent type

julia> tangent_type(Bar)
 MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}

and fdata / rdata types

julia> (fdata_type(tangent_type(Bar)), rdata_type(tangent_type(Bar)))
-(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.

source

CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
+(MutableTangent{@NamedTuple{x::Float64, y, z::NoTangent}}, NoRData)

Primitive Types

As with tangents, each primitive type must specify what its fdata and rdata is. See specific examples for details.

source

CoDuals

CoDuals are simply used to bundle together a primal and an associated fdata, depending upon context. Occassionally, they are used to pair together a primal and a tangent.

A quick aside: Non-Differentiable Data

In the introduction to algorithmic differentiation, we assumed that the domain / range of function are the same as that of its derivative. Unfortunately, this story is only partly true. Matters are complicated by the fact that not all data types in Julia can reasonably be thought of as forming a Hilbert space. e.g. the String type.

Consequently we introduce the special type NoTangent, instances of which can be thought of as representing the set containing only a $0$ tangent. Morally speaking, for any non-differentiable data x, x + NoTangent() == x.

Other than non-differentiable data, the model of data in Julia as living in a real-valued finite dimensional Hilbert space is quite reasonable. Therefore, we hope readers will forgive us for largely ignoring the distinction between the domain and range of a function and that of its derivative in mathematical discussions, while simultaneously drawing a distinction when discussing code.

TODO: update this to cast e.g. each possible String as its own vector space containing only the 0 element. This works, even if it seems a little contrived.

The Rule Interface (Round 2)

Now that you've seen what data structures are used to represent gradients, we can describe in more depth the detail of how fdata and rdata are used to propagate gradients backwards on the reverse pass.

Consider the function

julia> foo(x::Tuple{Float64, Vector{Float64}}) = x[1] + sum(x[2])
 foo (generic function with 1 method)

The fdata for x is a Tuple{NoFData, Vector{Float64}}, and its rdata is a Tuple{Float64, NoRData}. The function returns a Float64, which has no fdata, and whose rdata is Float64. So on the forwards pass there is really nothing that needs to happen with the fdata for x.

Under the framework introduced above, the model for this function is

\[f(x) = (x, x_1 + \sum_{n=1}^N (x_2)_n)\]

where the vector in the second element of x is of length $N$. Now, following our usual steps, the derivative is

\[D f [x](\dot{x}) = (\dot{x}, \dot{x}_1 + \sum_{n=1}^N (\dot{x}_2)_n)\]

A gradient for this is a tuple $(\bar{y}_x, \bar{y}_a)$ where $\bar{y}_a \in \RR$ and $\bar{y}_x \in \RR \times \RR^N$. A quick derivation will show that the adjoint is

\[D f [x]^\ast(\bar{y}) = ((\bar{y}_x)_1 + \bar{y}_a, (\bar{y}_x)_2 + \bar{y}_a \mathbf{1})\]

where $\mathbf{1}$ is the vector of length $N$ in which each element is equal to $1$. (Observe that this agrees with the result we derived earlier for functions which don't mutate their arguments).

Now that we know what the adjoint is, we'll write down the rrule!!, and then explain what is going on in terms of the adjoint. This hand-written implementation is to aid your understanding – Mooncake.jl should be relied upon to generate this code automatically in practice.

julia> function rrule!!(::CoDual{typeof(foo)}, x::CoDual{Tuple{Float64, Vector{Float64}}})
            dx_fdata = x.dx
            function dfoo_adjoint(dy::Float64)
@@ -105,4 +105,4 @@
 typeof(g)
 
 # output
-typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.

+typeof(g) (singleton type of function g, subtype of Function)

Neither the value nor type of a are present in g. Since a doesn't enter g via its arguments, it is unclear how it should be handled in general.

diff --git a/previews/PR386/utilities/debug_mode/index.html b/previews/PR386/utilities/debug_mode/index.html index e1eef3fc3..b8b9e9d2f 100644 --- a/previews/PR386/utilities/debug_mode/index.html +++ b/previews/PR386/utilities/debug_mode/index.html @@ -2,5 +2,5 @@ Debug Mode · Mooncake.jl

Debug Mode

The Problem

A major source of potential problems in AD systems is rules returning the wrong type of tangent / fdata / rdata for a given primal value. For example, if someone writes a rule like

function rrule!!(::CoDual{typeof(+)}, x::CoDual{<:Real}, y::CoDual{<:Real})
     plus_reverse_pass(dz::Real) = NoRData(), dz, dz
     return zero_fcodual(primal(x) + primal(y))
-end

and calls

rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.

This is implemented via DebugRRule:

Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source

You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rruleFunction
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

source
build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

source
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

source

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
-AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because is has substantial performance implications.

+end

and calls

rrule(zero_fcodual(+), zero_fcodual(5.0), zero_fcodual(4f0))

then the type of dz on the reverse pass will be Float64 (assuming everything happens correctly), and this rule will return a Float64 as the rdata for y. However, the primal value of y is a Float32, and rdata_type(Float32) is Float32, so returning a Float64 is incorrect. This error might cause the reverse pass to fail loudly immediately, but it might also fail silently. It might cause an error much later in the reverse pass, making it hard to determine that the source of the error was the above rule. Worst of all, in some cases it could plausibly cause a segfault, which is more-or-less the worst kind of outcome possible.

The Solution

Check that the types of the fdata / rdata associated to arguments are exactly what tangent_type / fdata_type / rdata_type require upon entry to / exit from rules and pullbacks.

This is implemented via DebugRRule:

Mooncake.DebugRRuleType
DebugRRule(rule)

Construct a callable which is equivalent to rule, but inserts additional type checking. In particular:

  • check that the fdata in each argument is of the correct type for the primal
  • check that the fdata in the CoDual returned from the rule is of the correct type for the primal.

This happens recursively. For example, each element of a Vector{Any} is compared against each element of the associated fdata to ensure that its type is correct, as this cannot be guaranteed from the static type alone.

Some additional dynamic checks are also performed (e.g. that an fdata array of the same size as its primal).

Let rule return y, pb!!, then DebugRRule(rule) returns y, DebugPullback(pb!!). DebugPullback inserts the same kind of checks as DebugRRule, but on the reverse-pass. See the docstring for details.

Note: at any given point in time, the checks performed by this function constitute a necessary but insufficient set of conditions to ensure correctness. If you find that an error isn't being caught by these tests, but you believe it ought to be, please open an issue or (better still) a PR.

source

You can straightforwardly enable it when building a rule via the debug_mode kwarg in the following:

Mooncake.build_rruleFunction
build_rrule(args...; debug_mode=false)

Helper method. Only uses static information from args.

source
build_rrule(sig::Type{<:Tuple})

Equivalent to build_rrule(Mooncake.get_interpreter(), sig).

source
build_rrule(interp::MooncakeInterpreter{C}, sig_or_mi; debug_mode=false) where {C}

Returns a DerivedRule which is an rrule!! for sig_or_mi in context C. See the docstring for rrule!! for more info.

If debug_mode is true, then all calls to rules are replaced with calls to DebugRRules.

source

When using ADTypes.jl, you can choose whether or not to use it via the debug_mode kwarg:

julia> AutoMooncake(; config=Mooncake.Config(; debug_mode=true))
+AutoMooncake{Mooncake.Config}(Mooncake.Config(true, false))

When Should You Use Debug Mode?

Only use debug_mode when debugging a problem. This is because is has substantial performance implications.

diff --git a/previews/PR386/utilities/debugging_and_mwes/index.html b/previews/PR386/utilities/debugging_and_mwes/index.html index 67471d660..1fa9d0dc4 100644 --- a/previews/PR386/utilities/debugging_and_mwes/index.html +++ b/previews/PR386/utilities/debugging_and_mwes/index.html @@ -7,5 +7,5 @@ interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(), debug_mode::Bool=false, unsafe_perturb::Bool=false, -)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.

Arguments

  • rng::AbstractRNG: a random number generator
  • x...: the function (first element) and its arguments (the remainder)

Keyword Arguments

  • interface_only::Bool=false: test only that the interface is satisfied, without testing correctness. This should generally be set to false (the default value), and only enabled if the testing infrastructure is unable to test correctness for some reason e.g. the returned value of the function is a Ptr, and appropriate tangents cannot, therefore, be generated for it automatically.
  • is_primitive::Bool=true: check whether the thing that you are testing has a hand-written rrule!!. This option is helpful if you are testing a new rrule!!, as it enables you to verify that your method of is_primitive has returned the correct value, and that you are actually testing a method of the rrule!! function – a common mistake when authoring a new rrule!! is to implement is_primitive incorrectly and to accidentally wind up testing a rule which Mooncake has derived, as opposed to the one that you have written. If you are testing something for which you have not hand-written an rrule!!, or which you do not care whether it has a hand-written rrule!! or not, you should set it to false.
  • perf_flag::Symbol=:none: the value of this symbol determines what kind of performance tests should be performed. By default, none are performed. If you believe that a rule should be allocation-free (iff the primal is allocation free), set this to :allocs. If you hand-write an rrule!! and believe that your test case should be type stable, set this to :stability (at present we cannot verify whether a derived rule is type stable for technical reasons). If you believe that a hand-written rule should be both allocation-free and type-stable, set this to :stability_and_allocs.
  • interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(): the abstract interpreter to be used when testing this rule. The default should generally be used.
  • debug_mode::Bool=false: whether or not the rule should be tested in debug mode. Typically this should be left at its default false value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to true in order to get access to additional information and automated testing.
  • unsafe_perturb::Bool=false: value passed as the third argument to _add_to_primal. Should usually be left false – consult the docstring for _add_to_primal for more info on when you might wish to set it to true.
source

This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. can be used without having to manually generate tangents.

Example

For example

f(x) = Core.bitcast(Float64, x)
-Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.

+)

Run standardised tests on the rule for x. The first element of x should be the primal function to test, and each other element a positional argument. In most cases, elements of x can just be the primal values, and randn_tangent can be relied upon to generate an appropriate tangent to test. Some notable exceptions exist though, in partcular Ptrs. In this case, the argument for which randn_tangent cannot be readily defined should be a CoDual containing the primal, and a manually constructed tangent field.

This function uses Mooncake.build_rrule to construct a rule. This will use an rrule!! if one exists, and derive a rule otherwise.

Arguments

  • rng::AbstractRNG: a random number generator
  • x...: the function (first element) and its arguments (the remainder)

Keyword Arguments

  • interface_only::Bool=false: test only that the interface is satisfied, without testing correctness. This should generally be set to false (the default value), and only enabled if the testing infrastructure is unable to test correctness for some reason e.g. the returned value of the function is a Ptr, and appropriate tangents cannot, therefore, be generated for it automatically.
  • is_primitive::Bool=true: check whether the thing that you are testing has a hand-written rrule!!. This option is helpful if you are testing a new rrule!!, as it enables you to verify that your method of is_primitive has returned the correct value, and that you are actually testing a method of the rrule!! function – a common mistake when authoring a new rrule!! is to implement is_primitive incorrectly and to accidentally wind up testing a rule which Mooncake has derived, as opposed to the one that you have written. If you are testing something for which you have not hand-written an rrule!!, or which you do not care whether it has a hand-written rrule!! or not, you should set it to false.
  • perf_flag::Symbol=:none: the value of this symbol determines what kind of performance tests should be performed. By default, none are performed. If you believe that a rule should be allocation-free (iff the primal is allocation free), set this to :allocs. If you hand-write an rrule!! and believe that your test case should be type stable, set this to :stability (at present we cannot verify whether a derived rule is type stable for technical reasons). If you believe that a hand-written rule should be both allocation-free and type-stable, set this to :stability_and_allocs.
  • interp::Mooncake.MooncakeInterpreter=Mooncake.get_interpreter(): the abstract interpreter to be used when testing this rule. The default should generally be used.
  • debug_mode::Bool=false: whether or not the rule should be tested in debug mode. Typically this should be left at its default false value, but if you are finding that the tests are failing for a given rule, you may wish to temporarily set it to true in order to get access to additional information and automated testing.
  • unsafe_perturb::Bool=false: value passed as the third argument to _add_to_primal. Should usually be left false – consult the docstring for _add_to_primal for more info on when you might wish to set it to true.
source

This approach is convenient because it can

  1. check whether AD runs at all,
  2. check whether AD produces the correct answers,
  3. check whether AD is performant, and
  4. can be used without having to manually generate tangents.

Example

For example

f(x) = Core.bitcast(Float64, x)
+Mooncake.TestUtils.test_rule(Random.Xoshiro(123), f, 3; is_primitive=false)

will error. (In this particular case, it is caused by Mooncake.jl preventing you from doing (potentially) unsafe casting. In this particular instance, Mooncake.jl just fails to compile, but in other instances other things can happen.)

In any case, the point here is that Mooncake.TestUtils.test_rule provides a convenient way to produce and report an error.

Segfaults

These are everyone's least favourite kind of problem, and they should be extremely rare in Mooncake.jl. However, if you are unfortunate enough to encounter one, please re-run your problem with the debug_mode kwarg set to true. See Debug Mode for more info. In general, this will catch problems before they become segfaults, at which point the above strategy for debugging and error reporting should work well.

diff --git a/previews/PR386/utilities/tools_for_rules/index.html b/previews/PR386/utilities/tools_for_rules/index.html index f2b3cbc47..4733d98d8 100644 --- a/previews/PR386/utilities/tools_for_rules/index.html +++ b/previews/PR386/utilities/tools_for_rules/index.html @@ -19,7 +19,7 @@ julia> rule = Mooncake.build_rrule(Tuple{typeof(scale), Float64}); julia> Mooncake.value_and_gradient!!(rule, scale, 5.0) -(20.0, (NoTangent(), 4.0))source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
+(20.0, (NoTangent(), 4.0))
source

Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common situation that the adjoint of the derivative of your function is always zero, you can very straightforwardly write a rule by making use of the following:

Mooncake.@zero_adjointMacro
@zero_adjoint ctx sig

Defines is_primitive(context_type, sig) = true, and defines a method of Mooncake.rrule!! which returns zero for all inputs. Users of ChainRules.jl should be familiar with this functionality – it is morally the same as ChainRulesCore.@non_differentiable.

For example:

julia> using Mooncake: @zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive
 
 julia> foo(x) = 5
 foo (generic function with 1 method)
@@ -41,7 +41,7 @@
 true
 
 julia> rrule!!(zero_fcodual(foo_varargs), zero_fcodual(3.0), zero_fcodual(5))[2](NoRData())
-(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjointFunction
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
+(NoRData(), 0.0, NoRData())

Be aware that it is not currently possible to specify any of the type parameters of the Vararg. For example, the signature Tuple{typeof(foo), Vararg{Float64, 5}} will not work with this macro.

WARNING: this is only correct if the output of the function does not alias any fields of the function, or any of its arguments. For example, applying this macro to the function x -> x will yield incorrect results.

As always, you should use TestUtils.test_rule to ensure that you've not made a mistake.

Signatures Unsupported By This Macro

If the signature you wish to apply @zero_adjoint to is not supported, for example because it uses a Vararg with a type parameter, you can still make use of zero_adjoint.

source
Mooncake.zero_adjointFunction
zero_adjoint(f::CoDual, x::Vararg{CoDual, N}) where {N}

Utility functionality for constructing rrule!!s for functions which produce adjoints which always return zero.

NOTE: you should only make use of this function if you cannot make use of the @zero_adjoint macro.

You make use of this functionality by writing a method of Mooncake.rrule!!, and passing all of its arguments (including the function itself) to this function. For example:

julia> import Mooncake: zero_adjoint, DefaultCtx, zero_fcodual, rrule!!, is_primitive, CoDual
 
 julia> foo(x::Vararg{Int}) = 5
 foo (generic function with 1 method)
@@ -51,7 +51,7 @@
 julia> rrule!!(f::CoDual{typeof(foo)}, x::Vararg{CoDual{Int}}) = zero_adjoint(f, x...);
 
 julia> rrule!!(zero_fcodual(foo), zero_fcodual(3), zero_fcodual(2))[2](NoRData())
-(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
+(NoRData(), NoRData(), NoRData())

WARNING: this is only correct if the output of primal(f)(map(primal, x)...) does not alias anything in f or x. This is always the case if the result is a bits type, but more care may be required if it is not. ```

source

Using ChainRules.jl

ChainRules.jl provides a large number of rules for differentiating functions in reverse-mode. These rules are methods of the ChainRulesCore.rrule function. There are some instances where it is most convenient to implement a Mooncake.rrule!! by wrapping an existing ChainRulesCore.rrule.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

Mooncake.@from_rruleMacro
@from_rrule ctx sig [has_kwargs=false]

Convenience functionality to assist in using ChainRulesCore.rrules to write rrule!!s.

Arguments

  • ctx: A Mooncake context type
  • sig: the signature which you wish to assert should be a primitive in Mooncake.jl, and use an existing ChainRulesCore.rrule to implement this functionality.
  • has_kwargs: a Bool state whether or not the function has keyword arguments. This feature has the same limitations as ChainRulesCore.rrule – the derivative w.r.t. all kwargs must be zero.

Example Usage

A Basic Example

julia> using Mooncake: @from_rrule, DefaultCtx, rrule!!, zero_fcodual, TestUtils
 
 julia> using ChainRulesCore
 
@@ -96,4 +96,4 @@
        TestUtils.test_rule(
            Xoshiro(123), Core.kwcall, (cond=false, ), foo, 5.0; is_primitive=true
        )
-Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source
+Test Passed

Notice that, in order to access the kwarg method we must call the method of Core.kwcall, as Mooncake's rrule!! does not itself permit the use of kwargs.

Limitations

It is your responsibility to ensure that

  1. calls with signature sig do not mutate their arguments,
  2. the output of calls with signature sig does not alias any of the inputs.

As with all hand-written rules, you should definitely make use of TestUtils.test_rule to verify correctness on some test cases.

Argument Type Constraints

Many methods of ChainRuleCore.rrule are implemented with very loose type constraints. For example, it would not be surprising to see a method of rrule with the signature

Tuple{typeof(rrule), typeof(foo), Real, AbstractVector{<:Real}}

There are a variety of reasons for this way of doing things, and whether it is a good idea to write rules for such generic objects has been debated at length.

Suffice it to say, you should not write rules for this package which are so generically typed. Rather, you should create rules for the subset of types for which you believe that the ChainRulesCore.rrule will work correctly, and leave this package to derive rules for the rest. For example, it is quite common to be confident that a given rule will work correctly for any Base.IEEEFloat argument, i.e. Union{Float16, Float32, Float64}, but it is usually not possible to know that the rule is correct for all possible subtypes of Real that someone might define.

Conversions Between Different Tangent Type Systems

Under the hood, this functionality relies on two functions: Mooncake.to_cr_tangent, and Mooncake.increment_and_get_rdata!. These two functions handle conversion to / from Mooncake tangent types and ChainRulesCore tangent types. This functionality is known to work well for simple types, but has not been tested to a great extent on complicated composite types. If @from_rrule does not work in your case because the required method of either of these functions does not exist, please open an issue.

source