From d3d32c0ce8e4df3fa413f4be9ec6b5197036f903 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Fri, 22 Mar 2024 15:45:35 +0000 Subject: [PATCH] Source-to-source (#90) * Implement BBCode * Add IR optimisation functionality * Only use type of ctx in TInterp * Factor out ir lookup * Remove redundant invoke-to-call functionality * Remove redundant return node creation code * Improve comment * Import more things during testing * Initial work on statement-specific transformations * Include s2s files * Add more cases in statement translation * Remove tests associated to removed code * Fix docstring * Improve comment * Improve docstrings * Improve comments * Improve comment * Improve comment * Throw informative error message for PhiCNodes * Throw error for UpsilonNode * Disable tests for interpreter during development * Initial pass over call transformation * More work * Basics working * Fix tests * Print debug info * Print debug info * Ignore code coverage statements * Restrict to 1.10 * Fix CI and comment out IR inspection * Enable all tests * Improve performance result display * Make const work * Enable more tests * Tidy up throw_if_not_def * Fix up throw_undef_if_not * Improve throw_undef_if_not test * Add additional GlobalRef test * Improve BBCode documentation * Loosen perf tolerance on handwritten sum * Clean up register transformation * Remove redundant code * Tidy up formatting * More register tests * Tidy up output check * Use new tidier register functionality * Reactivate more tests * Unhandle feature exception and globalref typeof * Improve handling of GlobalRefs * Update literals and QuoteNodes to have stacks when differentiable * Handle PiNode in BBCode * Update register types * Add helper functionality to augmented register * Import ad_stmt_info in front_matter * Reactive more tests * Refactor registers and implement PiNode * Move inc_arg_numbers to utils and unit test * Move around include order * Refactor captures handling * Enable more tests and mostly 
fix PhiNode * Fix up phinode problem * Fix comment * Fix typos * Fix vector elementtype in bbcode * Make varargs + splatting work * Add ReverseDiff to test deps * Move to testing s2s * Fix some lingering bugs * Fix comment and disable bad test case * Use s2s in benchmarks * Helper functiong * Fix comment * Some minor improvements to shared data and compiler * Additional test case * Do not inline stack pushes and pops * Improve compile times * Fix bug * Fix union of registers bug * Readability improvements * Cache oc compilation * Strip code coverage lines * Ignore thing with inlining problem * Do not store input ref stacks if singleton type * Formatting * NoTangent for Tuples * NamedTuple NoTangent * Improve predecessor compute times * Stop printing * Inline tuple constructor * Ignore Base check_reducedims * Make _getter constant * Move tuple_map and extend it * Refine DerivedRule construction * Fix new for NoTangent result types * Fix new * Remove redundant commented lines * Add non-differentiable function to test utils * Add comments, tidy up, rename some things * Add non-differentiable const tester * Rename my_tangent_stack to ret_tangent_stack * Remove redundant line * Support non-constant global refs * Don't verify IR after passes * Support copyast * Use full benchmarking * Tidy up some abstract type edge cases * Align registers with OpaqueClosure type inference * Tidy up types * Inline stuff again * Safer implementation of ipiv * Make memory in Stack constant * Formatting and lgetfield tests * Formatting * lsetfield tests * Move TestResources import around * NoTangent for composites with NoTangent fields * GC preserve stuff * Add NoTangent path to ifelse rule * Enable multiple lines for reverse-pass * Construct arg tangent stack refs in function * Fix basic tests * Use fixed stacks * Exclude some rrules from DiffRules * Use inbounds * Add additional test and tighten performance req * Use fixed-location tangent stacks, and 32-bit block numbering * Display 
which benchmark is running * Fix typo * Fix low_level_rules tests * Fix caching * Functionality to reset global ID counter * Try not inlining block pushes * Improve bbcode documentation * Remove redundant code * Formatting * Formatting * Document unhandled_feature * Use type information in BBCode and update s2s to reflect this * Fix PhiNode inference * Update PhiNode transform unit tests * Improve formatting of test_utils * Ignore local scratch file * Fix Turing hanging bug * Fix derived rule tester * Fix up testS * Use fixed tangent stack for PiNode * Tweak bounds for test utils * Fix Distributions deprecations * Revert PiNode fixed stack update * Try more stuff to fix getrf pullback * Fix test tolerances * Fix typo in comment * Revert attempted LAPACK fix * Restrict primal evals to 1 * Extend preservation to cover ccall * Make copy of ipiv after calculations are run * Improve tuple tangent_type * Inline getfield rules * Force uninit_codual to inline * Force-inline uninit_tangent * NoTangentStack for DataType * Revert change to tangent stack type * Avoid recompilation in dynamic dispatch * Tighten performance bounds * Tighten performance bounds on naive mat mul * Enable all Turing.jl tests * Remove redundant comment * Remove comment and add blank line * Formatting * Improve docstrings and comments in s2s * Remove interpreter timings from Turing integration tests * Remove redundant import in benchmarks * Move value_and_gradient to interface file * Remove redundant arg in benchmarking * Improve interface * Update README * Formatting * Improve README * Do not export increment_field --- .gitignore | 1 + Project.toml | 4 +- README.md | 36 +- bench/run_benchmarks.jl | 45 +- src/Taped.jl | 17 +- src/codual.jl | 2 +- src/interface.jl | 60 ++ src/interpreter/abstract_interpretation.jl | 10 +- src/interpreter/bbcode.jl | 584 +++++++++++ src/interpreter/interpreted_function.jl | 21 +- src/interpreter/ir_utils.jl | 145 ++- src/interpreter/registers.jl | 46 + 
src/interpreter/s2s_reverse_mode_ad.jl | 1033 +++++++++++++++++++ src/rrules/blas.jl | 2 +- src/rrules/builtins.jl | 125 ++- src/rrules/lapack.jl | 46 +- src/rrules/low_level_maths.jl | 2 + src/rrules/misc.jl | 91 +- src/rrules/new.jl | 26 +- src/stack.jl | 37 +- src/tangents.jl | 60 +- src/test_utils.jl | 212 ++-- src/utils.jl | 22 + test/front_matter.jl | 14 +- test/integration_testing/array.jl | 7 +- test/integration_testing/battery_tests.jl | 1 + test/integration_testing/diff_tests.jl | 2 +- test/integration_testing/distributions.jl | 4 +- test/integration_testing/gp.jl | 2 +- test/integration_testing/misc.jl | 2 +- test/integration_testing/turing.jl | 43 +- test/interface.jl | 13 + test/interpreter/abstract_interpretation.jl | 6 +- test/interpreter/bbcode.jl | 52 + test/interpreter/ir_utils.jl | 23 + test/interpreter/registers.jl | 8 + test/interpreter/s2s_reverse_mode_ad.jl | 230 +++++ test/runtests.jl | 9 +- test/tangents.jl | 33 +- test/test_utils.jl | 9 - test/utils.jl | 23 +- 41 files changed, 2737 insertions(+), 371 deletions(-) create mode 100644 src/interface.jl create mode 100644 src/interpreter/bbcode.jl create mode 100644 src/interpreter/registers.jl create mode 100644 src/interpreter/s2s_reverse_mode_ad.jl create mode 100644 test/interface.jl create mode 100644 test/interpreter/bbcode.jl create mode 100644 test/interpreter/registers.jl create mode 100644 test/interpreter/s2s_reverse_mode_ad.jl diff --git a/.gitignore b/.gitignore index 1d192e553..00e220792 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ bench/Manifest.toml analysis_results .vscode profile.pb.gz +scratch.jl diff --git a/Project.toml b/Project.toml index 72a8b5c99..02783efe6 100644 --- a/Project.toml +++ b/Project.toml @@ -9,6 +9,7 @@ ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" DiffRules = "b552c78f-8df3-52c6-915a-8e097449b14b" DiffTests = "de460e47-3fe3-5279-bb4a-814414816d5d" ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04" +Graphs = 
"86223c79-3864-5bf0-83f7-82e725a168b6" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -46,10 +47,11 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" [targets] -test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "Turing"] +test = ["AbstractGPs", "BenchmarkTools", "DiffTests", "Distributions", "FillArrays", "KernelFunctions", "PDMats", "ReverseDiff", "SpecialFunctions", "StableRNGs", "Test", "Turing"] diff --git a/README.md b/README.md index 1d081e877..50d95c504 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ The goal of the `Taped.jl` project is to produce a reverse-mode AD package which # How it works -`Taped.jl` is based around a single function `rrule!!`, which computes vector-Jacobian products (VJPs). +`Taped.jl` is based around a function `rrule!!` (which computes vector-Jacobian products (VJPs)) and a related function `build_rrule` (which builds functions which are semantically identical to `rrule!!`). These VJPs can, for example, be used to compute gradients. `rrule!!` is similar to ChainRules' `rrule` and Zygote's `_pullback`, but supports functions which mutate (modify) their arguments, in addition to those that do not, and immediately increments (co)tangents. It has, perhaps unsurprisingly, wound up looking quite similar to the rule system in Enzyme. 
@@ -18,7 +18,7 @@ For a given function and arguments, it is roughly speaking the case that either 2. no hand-written method of `rrule!!` is applicable. In the first case, we run the `rrule!!`. -In the second, we create an `rrule!!` by "doing AD" -- we decompose the function into a composition of functions which _do_ have hand-written `rrule!!`s. +In the second, we use `build_rrule` to create a function with the same semantics as `rrule!!` by "doing AD" -- we decompose the function into a composition of functions which _do_ have hand-written `rrule!!`s. In general, the goal is to write as few hand-written `rrule!!`s as is necessary, and to "do AD" for the vast majority of functions. @@ -48,9 +48,10 @@ All of our testing is implemented via this (or via another function which calls This contrasts with `Zygote.jl` / `ChainRules.jl`, where the permissive (co)tangent type system complicates both composition of `rrule`s and testing. -Additionally, our approach to AD naturally handles control flow which differs between calls of a function. This contrasts with e.g. `ReverseDiff.jl`'s compiled tape, which can give silent numerical errors if control flow ought to differ between gradient evaluations at different arguments. -~~Additionally, we augment the tape that we construct with additional instructions which throw an error if control flow differs from when the tape was constructed. -This contrasts with `ReverseDiff.jl`, which silently fails in this scenario.~~ +Additionally, our approach to AD naturally handles control flow which differs between multiple calls to the same function. +This contrasts with e.g. `ReverseDiff.jl`'s compiled tape, which can give silent numerical errors if control flow ought to differ between gradient evaluations at different arguments. 
+~~Additionally, we augment the tape that we construct with additional instructions which throw an error if control flow differs from when the tape was constructed.~~ +~~This contrasts with `ReverseDiff.jl`, which silently fails in this scenario.~~ ### Performance @@ -58,7 +59,8 @@ Hand-written `rrule!!`s have excellent performance, provided that they have been Consequently, whether or not the overall AD system has good performance is largely a question of how much overhead is associated to the mechanism by which hand-written `rrules!!`s are algorithmically composed. ~~At present (11/2023), we do _not_ do this in a performant way, but this will change.~~ -At present (01/2024), we do this in a _moderately_ performant way. +~~At present (01/2024), we do this in a _moderately_ performant way.~~ +At present (03/2024), we do this in a _moderately_ performant way (but better than the previous way!) See [Project Status](#project-status) below for more info. Additionally, the strategy of immediately incrementing (co)tangents resolves long-standing performance issues associated with indexing arrays. @@ -85,8 +87,14 @@ The plan is to proceed in three phases: You should take this with a pinch of salt, as it seems highly likely that we will have to revisit some design choices when optimising performance -- we do not, however, anticipate requiring major re-writes to the design as part of performance optimisation. We aim to reach the maintenance phase of the project before 01/06/2024. +*Update: (22/03/2024)* +Phase 2 is now further along. +`Taped.jl` now uses something which could reasonably be described as a source-to-source system to perform AD. +At present the performance of this system is not as good as that of Enzyme, but often beats compiled ReverseDiff, and comfortably beats Zygote in any situations involving dynamic control flow. 
+The present focus is on dealing with some remaining performance limitations that should make `Taped.jl`'s performance much closer to that of Enzyme, and consistently beat ReverseDiff on a range of benchmarks. + *Update: (16/01/2024)* -Phase 2 is now well underway. We now make use of a much faster approach to interpreting / executing Julia code, which yields performance that is comparable with ReverseDiff (when things go well). The current focus is on ironing out performance issues, and simplifying the implementation. +~~Phase 2 is now well underway. We now make use of a much faster approach to interpreting / executing Julia code, which yields performance that is comparable with ReverseDiff (when things go well). The current focus is on ironing out performance issues, and simplifying the implementation.~~ *Update: (06/11/2023)* ~~We are mostly through the first phase.~~ @@ -99,9 +107,8 @@ Phase 2 is now well underway. We now make use of a much faster approach to inter # Trying it out -There is not presently a high-level interface to which we are commiting, but if you want to -compute the gradient of a function, take a look at -`Taped.TestUtils.set_up_gradient_problem` and `Taped.TestUtils.value_and_gradient!!`. +There is not yet a high-level interface to which we are committing, but if you want to compute the gradient of a function, take a look at `value_and_pullback!!` / `value_and_gradient!!`. +They both provide a high-level interface which will let you differentiate things, and their implementation demonstrates how an `rrule!!` / rrule-like function should be used. *Note:* I have found that using a mixture of `PProf` and the `@profview` functionality from Julia's `VSCode` extension essential when profiling code generated by `Taped.jl`. `PProf` provides complete type information on its flame graphs, which is important for figuring out what is getting called, but it doesn't highilght type-instabilities. 
@@ -115,13 +122,14 @@ Noteworthy things which should be work and be performant include: 1. value-dependent control flow 1. mutation of arrays and mutable structs -These are noteworthy in the sense that they are different from ReverseDiff / Zygote. Enzyme is also able to do these things. +These are noteworthy in the sense that they are different from ReverseDiff / Zygote. +Enzyme is also able to do these things. -Please be aware that by "performant" we mean similar performance to ReverseDiff with a compiled tape. +Please be aware that by "performant" we mean similar or better performance than ReverseDiff with a compiled tape, but not as good performance as Enzyme. ### What won't work -While Taped should now work on a very large subset of the language, there remain things that you should expect not to work. A non-exhaustive list of things to bear in mind includes: +While `Taped.jl` should now work on a very large subset of the language, there remain things that you should expect not to work. A non-exhaustive list of things to bear in mind includes: 1. It is always necessary to produce hand-written for `ccall`s (and, more generally, foreigncall nodes). We have rules for many `ccall`s, but not all. If you encounter a foreigncall without a hand-written rule, you should get an informative error message which tells you what is going on and how to deal with it. -1. Builtins which require rules. The vast majority of them have rules now, but some don't. Notably, `apply_iterate` does not have a rule, so Taped cannot currently AD through type-unstable splatting -- someone should resolve this. +1. Builtins which require rules. The vast majority of them have rules now, but some don't. Notably, `apply_iterate` does not have a rule, so `Taped.jl` cannot currently AD through type-unstable splatting -- someone should resolve this. 1. 
Anything involving tasks / threading -- we have no thread safety guarantees and, at the time of writing, I'm not entirely sure what error you will find if you attempt to AD through code which uses Julia's task / thread system. The same applies to distributed computing. These limitations ought to be possible to resolve. diff --git a/bench/run_benchmarks.jl b/bench/run_benchmarks.jl index 9e3561806..7b5d30297 100644 --- a/bench/run_benchmarks.jl +++ b/bench/run_benchmarks.jl @@ -27,7 +27,7 @@ using Taped: TInterp, _typeof -using Taped.TestUtils: _deepcopy, to_benchmark, set_up_gradient_problem +using Taped.TestUtils: _deepcopy, to_benchmark function zygote_to_benchmark(ctx, x::Vararg{Any, N}) where {N} out, pb = Zygote._pullback(ctx, x...) @@ -154,12 +154,8 @@ function generate_inter_framework_tests() ] end -function benchmark_rules!!( - test_case_data, - default_ratios, - include_other_frameworks::Bool, - tune_benchmarks::Bool, -) +function benchmark_rules!!(test_case_data, default_ratios, include_other_frameworks::Bool) + test_cases = reduce(vcat, map(first, test_case_data)) memory = map(x -> x[2], test_case_data) ranges = reduce(vcat, map(x -> x[3], test_case_data)) @@ -167,53 +163,54 @@ function benchmark_rules!!( GC.@preserve memory begin results = map(enumerate(test_cases)) do (n, args) @info "$n / $(length(test_cases))", _typeof(args) - suite = BenchmarkGroup() + suite = Dict() # Benchmark primal. primals = map(x -> x isa CoDual ? primal(x) : x, args) - suite["primal"] = @benchmarkable( + @info "primal" + suite["primal"] = @benchmark( (a[1][])((a[2][])...); setup=(a = (Ref($primals[1]), Ref(_deepcopy($primals[2:end])))), + evals=1, ) # Benchmark AD via Taped. - rule, in_f = set_up_gradient_problem(args...) + @info "taped" + rule = Taped.build_rrule(args...) coduals = map(x -> x isa CoDual ? x : zero_codual(x), args) - suite["taped"] = @benchmarkable( - to_benchmark($rule, zero_codual($in_f), $coduals...); - ) + to_benchmark(rule, coduals...) 
+ suite["taped"] = @benchmark(to_benchmark($rule, $coduals...)) if include_other_frameworks if should_run_benchmark(Val(:zygote), args...) - suite["zygote"] = @benchmarkable( + @info "zygote" + suite["zygote"] = @benchmark( zygote_to_benchmark($(Zygote.Context()), $primals...) ) end if should_run_benchmark(Val(:reverse_diff), args...) + @info "reversediff" tape = ReverseDiff.GradientTape(primals[1], primals[2:end]) compiled_tape = ReverseDiff.compile(tape) result = map(x -> randn(size(x)), primals[2:end]) - suite["rd"] = @benchmarkable( + suite["rd"] = @benchmark( rd_to_benchmark!($result, $compiled_tape, $primals[2:end]) ) end if should_run_benchmark(Val(:enzyme), args...) + @info "enzyme" dup_args = map(x -> Duplicated(x, randn(size(x))), primals[2:end]) - suite["enzyme"] = @benchmarkable( + suite["enzyme"] = @benchmark( autodiff(Reverse, $primals[1], Active, $dup_args...) ) end end - if tune_benchmarks - @info "tuning" - tune!(suite) - end @info "running" - return (args, run(suite; verbose=true)) + return (args, suite) end end return combine_results.(results, tags, ranges, Ref(default_ratios)) @@ -259,7 +256,7 @@ function benchmark_hand_written_rrules!!(rng_ctor) tags = fill(nothing, length(test_cases)) return map(x -> x[4:end], test_cases), memory, ranges, tags end - return benchmark_rules!!(test_case_data, (lb=1e-3, ub=25.0), false, false) + return benchmark_rules!!(test_case_data, (lb=1e-3, ub=25.0), false) end function benchmark_derived_rrules!!(rng_ctor) @@ -271,7 +268,7 @@ function benchmark_derived_rrules!!(rng_ctor) tags = fill(nothing, length(test_cases)) return map(x -> x[4:end], test_cases), memory, ranges, tags end - return benchmark_rules!!(test_case_data, (lb=1e-3, ub=150), false, false) + return benchmark_rules!!(test_case_data, (lb=1e-3, ub=150), false) end function benchmark_inter_framework_rules() @@ -280,7 +277,7 @@ function benchmark_inter_framework_rules() test_cases = map(last, test_case_data) memory = [] ranges = fill(nothing, 
length(test_cases)) - return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=150), true, true) + return benchmark_rules!!([(test_cases, memory, ranges, tags)], (lb=0.1, ub=150), true) end function flag_concerning_performance(ratios) diff --git a/src/Taped.jl b/src/Taped.jl index 74859082e..1485685be 100644 --- a/src/Taped.jl +++ b/src/Taped.jl @@ -6,6 +6,7 @@ using BenchmarkTools, DiffRules, ExprTools, + Graphs, InteractiveUtils, LinearAlgebra, Random, @@ -20,8 +21,8 @@ using Base.Experimental: @opaque using Base.Iterators: product using Core: Intrinsics, bitcast, SimpleVector, svec, ReturnNode, GotoNode, GotoIfNot, PhiNode, - PiNode, SSAValue, Argument -using Core.Compiler: IRCode + PiNode, SSAValue, Argument, OpaqueClosure +using Core.Compiler: IRCode, NewInstruction using Core.Intrinsics: pointerref, pointerset using LinearAlgebra.BLAS: @blasfunc, BlasInt, trsm! using LinearAlgebra.LAPACK: getrf!, getrs!, getri!, trtrs!, potrf!, potrs! @@ -35,11 +36,14 @@ include("codual.jl") include("stack.jl") include(joinpath("interpreter", "contexts.jl")) +include(joinpath("interpreter", "abstract_interpretation.jl")) +include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) -include(joinpath("interpreter", "abstract_interpretation.jl")) +include(joinpath("interpreter", "registers.jl")) include(joinpath("interpreter", "interpreted_function.jl")) include(joinpath("interpreter", "reverse_mode_ad.jl")) +include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) include("test_utils.jl") @@ -54,13 +58,13 @@ include(joinpath("rrules", "misc.jl")) include(joinpath("rrules", "new.jl")) include("chain_rules_macro.jl") +include("interface.jl") export primal, tangent, randn_tangent, increment!!, - increment_field!!, NoTangent, Tangent, MutableTangent, @@ -74,6 +78,9 @@ export _dot, zero_codual, codual_type, - rrule!! 
+ rrule!!, + build_rrule, + value_and_gradient!!, + value_and_pullback!! end diff --git a/src/codual.jl b/src/codual.jl index 8114350bf..2e92c4a3e 100644 --- a/src/codual.jl +++ b/src/codual.jl @@ -24,7 +24,7 @@ zero_codual(x) = CoDual(x, zero_tangent(x)) See implementation for details, as this function is subject to change. """ -uninit_codual(x) = CoDual(x, uninit_tangent(x)) +@inline uninit_codual(x::P) where {P} = CoDual(x, uninit_tangent(x)) """ codual_type(P::Type) diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 000000000..35c487464 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,60 @@ +""" + value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...) + +In-place version of `value_and_pullback!!` in which the arguments have been wrapped in +`CoDual`s. Note that any mutable data in `f` and `x` will be incremented in-place. As such, +if calling this function multiple times with different values of `x`, should be careful to +ensure that you zero-out the tangent fields of `x` each time. +""" +function value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual, N}) where {R, N, T} + out, pb!! = rule(fx...) + @assert _typeof(tangent(out)) == T + ty = increment!!(tangent(out), ȳ) + return primal(out), pb!!(ty, map(tangent, fx)...) +end + +""" + value_and_gradient!!(rule, f::CoDual, x::CoDual...) + +Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`. +""" +function value_and_gradient!!(rule::R, fx::Vararg{CoDual, N}) where {R, N} + return value_and_pullback!!(rule, 1.0, fx...) +end + +""" + value_and_pullback!!(rule, ȳ, f, x...) + +Compute the value and pullback of `f(x...)`. + +`rule` should be constructed using `build_rrule`. + +*Note:* If calling `value_and_pullback!!` multiple times for various values of `x`, you +should use the same instance of `rule` each time. + +*Note:* It is your responsibility to ensure that there is no aliasing in `f` and `x`. 
+For example, +```julia +X = randn(5, 5) +rule = build_rrule(dot, X, X) +value_and_pullback!!(rule, 1.0, dot, X, X) +``` +will yield the wrong result. + +*Note:* This method of `value_and_pullback!!` has to first call `zero_codual` on all of its +arguments. This may cause some additional allocations. If this is a problem in your +use-case, consider pre-allocating the `CoDual`s and calling the other method of this +function. +""" +function value_and_pullback!!(rule::R, ȳ, fx::Vararg{Any, N}) where {R, N} + return value_and_pullback!!(rule, ȳ, map(zero_codual, fx)...) +end + +""" + value_and_gradient!!(rule, f, x...) + +Equivalent to `value_and_pullback(rule, 1.0, f, x...)` -- assumes `f` returns a `Float64`. +""" +function value_and_gradient!!(rule::R, fx::Vararg{Any, N}) where {R, N} + return value_and_gradient!!(rule, map(zero_codual, fx)...) +end diff --git a/src/interpreter/abstract_interpretation.jl b/src/interpreter/abstract_interpretation.jl index cd9c732ca..dad50c93b 100644 --- a/src/interpreter/abstract_interpretation.jl +++ b/src/interpreter/abstract_interpretation.jl @@ -12,15 +12,15 @@ end TICache() = TICache(IdDict{Core.MethodInstance, Core.CodeInstance}()) struct TapedInterpreter{C} <: CC.AbstractInterpreter - ctx::C meta # additional information world::UInt inf_params::CC.InferenceParams opt_params::CC.OptimizationParams inf_cache::Vector{CC.InferenceResult} code_cache::TICache + oc_cache::Dict{Any, Any} function TapedInterpreter( - ctx::C=DefaultCtx(); + ::Type{C}; meta=nothing, world::UInt=Base.get_world_counter(), inf_params::CC.InferenceParams=CC.InferenceParams(), @@ -28,10 +28,12 @@ struct TapedInterpreter{C} <: CC.AbstractInterpreter inf_cache::Vector{CC.InferenceResult}=CC.InferenceResult[], code_cache::TICache=TICache(), ) where {C} - return new{C}(ctx, meta, world, inf_params, opt_params, inf_cache, code_cache) + return new{C}(meta, world, inf_params, opt_params, inf_cache, code_cache, Dict()) end end +TapedInterpreter() = 
TapedInterpreter(DefaultCtx) + const TInterp = TapedInterpreter CC.InferenceParams(interp::TInterp) = interp.inf_params @@ -82,3 +84,5 @@ function CC.inlining_policy( argtypes::Vector{Any}, ) end + +context_type(::TInterp{C}) where {C} = C diff --git a/src/interpreter/bbcode.jl b/src/interpreter/bbcode.jl new file mode 100644 index 000000000..78d1d658c --- /dev/null +++ b/src/interpreter/bbcode.jl @@ -0,0 +1,584 @@ +_id_count::Int32 = 0 + +""" + ID() + +An `ID` (read: unique name) is just a wrapper around an `Int32`. Uniqueness is ensured via a +global counter, which is incremented each time that an `ID` is created. + +This counter can be reset using `seed_id!` if you need to ensure deterministic `ID`s are +produced, in the same way that seed for random number generators can be set. +""" +struct ID + id::Int32 + function ID() + global _id_count += 1 + return new(_id_count) + end +end + +Base.copy(id::ID) = id + +""" + seed_id!() + +Set the global counter used to ensure ID uniqueness to 0. This is useful when you want to +ensure determinism between two runs of the same function which makes use of `ID`s. + +This is akin to setting the random seed associated to a random number generator globally. +""" +function seed_id!() + global _id_count = 0 +end + +""" + IDPhiNode(edges::Vector{ID}, values::Vector{Any}) + +Like a `PhiNode`, but `edges` are `ID`s rather than `Int32`s. +""" +struct IDPhiNode + edges::Vector{ID} + values::Vector{Any} +end + +Base.copy(node::IDPhiNode) = IDPhiNode(copy(node.edges), copy(node.values)) + +""" + IDGotoNode(label::ID) + +Like a `GotoNode`, but `label` is an `ID` rather than an `Int64`. +""" +struct IDGotoNode + label::ID +end + +Base.copy(node::IDGotoNode) = IDGotoNode(copy(node.label)) + +""" + IDGotoIfNot(cond::Any, dest::ID) + +Like a `GotoIfNot`, but `dest` is an `ID` rather than an `Int64`. 
+""" +struct IDGotoIfNot + cond::Any + dest::ID +end + +Base.copy(node::IDGotoIfNot) = IDGotoIfNot(copy(node.cond), copy(node.dest)) + +""" + Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + +A switch-statement node. These can be inserted in the `BBCode` representation of Julia IR. +`Switch` has the following semantics: +```julia +goto dests[1] if not conds[1] +goto dests[2] if not conds[2] +... +goto dests[N] if not conds[N] +goto fallthrough_dest +``` +where the value associated to each element of `conds` is a `Bool`, and `dests` indicate +which block to jump to. If none of the conditions are met, then we go to whichever block is +specified by `fallthrough_dest`. + +`Switch` statements are lowered into the above sequence of `GotoIfNot`s and `GotoNode`s +when converting `BBCode` back into `IRCode`, because `Switch` statements are not valid +nodes in regular Julia IR. +""" +struct Switch + conds::Vector{Any} + dests::Vector{ID} + fallthrough_dest::ID + function Switch(conds::Vector{Any}, dests::Vector{ID}, fallthrough_dest::ID) + @assert length(conds) == length(dests) + return new(conds, dests, fallthrough_dest) + end +end + +""" + Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} + +A Union of the possible types of a terminator node. +""" +const Terminator = Union{Switch, IDGotoIfNot, IDGotoNode, ReturnNode} + +""" + const InstVector = Vector{NewInstruction} + +Note: the `CC.NewInstruction` type is used to represent instructions because it has the +correct fields. While it is only used to represent new instrucdtions in `Core.Compiler`, it +is used to represent all instructions in `BBCode`. +""" +const InstVector = Vector{NewInstruction} + +""" + BBlock(id::ID, stmt_ids::Vector{ID}, stmts::InstVector) + +A basic block data structure (not called `BasicBlock` to avoid accidental confusion with +`CC.BasicBlock`). Forms a single basic block. + +Each `BBlock` has an `ID` (a unique name). 
This makes it possible to refer to blocks in a +way that does not change when additional `BBlocks` are inserted into a `BBCode`. +This differs from the positional block numbering found in `IRCode`, in which the number +associated to a basic block changes when new blocks are inserted. + +The `n`th line of code in a `BBlock` is associated to `ID` `stmt_ids[n]`, and the `n`th +instruction from `stmts`. + +Note that `PhiNode`s, `GotoIfNot`s, and `GotoNode`s should not appear in a `BBlock` -- +instead an `IDPhiNode`, `IDGotoIfNot`, or `IDGotoNode` should be used. +""" +mutable struct BBlock + id::ID + inst_ids::Vector{ID} + insts::InstVector + function BBlock(id::ID, inst_ids::Vector{ID}, insts::InstVector) + @assert length(inst_ids) == length(insts) + return new(id, inst_ids, insts) + end +end + +""" + BBlock(id::ID, inst_pairs::Vector{Tuple{ID, NewInstruction}}) + +Convenience constructor -- splits `inst_pairs` into a `Vector{ID}` and `InstVector` in order +to build a `BBlock`. +""" +function BBlock(id::ID, inst_pairs::Vector{Tuple{ID, NewInstruction}}) + return BBlock(id, first.(inst_pairs), last.(inst_pairs)) +end + +Base.length(bb::BBlock) = length(bb.inst_ids) + +Base.copy(bb::BBlock) = BBlock(bb.id, copy(bb.inst_ids), copy(bb.insts)) + +""" + Base.insert!(bb::BBlock, n::Int, id::ID, stmt::CC.NewInstruction)::Nothing + +Inserts `stmt` and `id` into `bb` immediately before the `n`th instruction. +""" +function Base.insert!(bb::BBlock, n::Int, id::ID, inst::NewInstruction)::Nothing + insert!(bb.inst_ids, n, id) + insert!(bb.insts, n, inst) + return nothing +end + +""" + terminator(bb::BBlock) + +Returns the terminator associated to `bb`. If the last instruction in `bb` isa +`Terminator` then that is returned, otherwise `nothing` is returned. +""" +terminator(bb::BBlock) = isa(bb.insts[end].stmt, Terminator) ? 
bb.insts[end].stmt : nothing + +""" + collect_stmts(bb::BBlock)::Vector{Tuple{ID, NewInstruction}} + +Returns a `Vector` containing the `ID`s and instructions associated to each line in `bb`. +These should be assumed to be ordered. +""" +function collect_stmts(bb::BBlock)::Vector{Tuple{ID, NewInstruction}} + return collect(zip(bb.inst_ids, bb.insts)) +end + +""" + BBCode( + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} + ) + +A `BBCode` is a data structure which is similar to `IRCode`, but adds additional structure. + +In particular, a `BBCode` comprises a sequence of basic blocks (`BBlock`s), each of which +comprises a sequence of statements. Moreover, each `BBlock` has its own unique `ID`, as does +each statement. + +The consequence of this is that new basic blocks can be inserted into a `BBCode`. This is +distinct from `IRCode`, in which to create a new basic block, one must insert additional +statements which you know will create a new basic block -- this is generally quite an +unreliable process, while inserting a new `BBlock` into `BBCode` is entirely predictable. +Furthermore, inserting a new `BBlock` does not change the `ID` associated to the other +blocks, meaning that you can safely assume that references from existing basic block +terminators / phi nodes to other blocks will not be modified by inserting a new basic block. + +Additionally, since each statement in each basic block has its own unique `ID`, new +statements can be inserted without changing references between other blocks. `IRCode` also +has some support for this via its `new_nodes` field, but eventually all statements will be +renamed upon `compact!`ing the `IRCode`, meaning that the name of any given statement will +eventually change. + +Finally, note that the basic blocks in a `BBCode` support the custom `Switch` statement. 
+This statement is not valid in `IRCode`, and is therefore lowered into a collection of +`GotoIfNot`s and `GotoNode`s when a `BBCode` is converted back into an `IRCode`. +""" +struct BBCode + blocks::Vector{BBlock} + argtypes::Vector{Any} + sptypes::Vector{CC.VarState} + linetable::Vector{Core.LineInfoNode} + meta::Vector{Expr} +end + +""" + BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{BBlock}) + +Make a new `BBCode` whose `blocks` is given by `new_blocks`, and fresh copies are made of +all other fields from `ir`. +""" +function BBCode(ir::Union{IRCode, BBCode}, new_blocks::Vector{BBlock}) + return BBCode( + new_blocks, + CC.copy(ir.argtypes), + CC.copy(ir.sptypes), + CC.copy(ir.linetable), + CC.copy(ir.meta), + ) +end + +# Makes use of the above outer constructor for `BBCode`. +Base.copy(ir::BBCode) = BBCode(ir, copy(ir.blocks)) + +""" + compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} + +Compute a map from the `ID` of each `BBlock` in `ir` to its possible successors. +""" +function compute_all_successors(ir::BBCode)::Dict{ID, Vector{ID}} + succs = map(enumerate(ir.blocks)) do (n, blk) + return successors(terminator(blk), n, ir, n == length(ir.blocks)) + end + return Dict{ID, Vector{ID}}(zip(map(b -> b.id, ir.blocks), succs)) +end + +function successors(::Nothing, n::Int, ir::BBCode, is_final_block::Bool) + return is_final_block ? ID[] : ID[ir.blocks[n+1].id] +end +successors(t::IDGotoNode, ::Int, ::BBCode, ::Bool) = [t.label] +function successors(t::IDGotoIfNot, n::Int, ir::BBCode, is_final_block::Bool) + return is_final_block ? ID[t.dest] : ID[t.dest, ir.blocks[n + 1].id] +end +successors(::ReturnNode, ::Int, ::BBCode, ::Bool) = ID[] +successors(t::Switch, ::Int, ::BBCode, ::Bool) = vcat(t.dests, t.fallthrough_dest) + +""" + compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} + +Compute a map from the `ID` of each `BBlock` in `ir` to its possible predecessors. 
+""" +function compute_all_predecessors(ir::BBCode)::Dict{ID, Vector{ID}} + + successor_map = compute_all_successors(ir) + + # Initialise predecessor map to be empty. + ks = collect(keys(successor_map)) + predecessor_map = Dict{ID, Vector{ID}}(zip(ks, map(_ -> ID[], ks))) + + # Find all predecessors by iterating through the successor map. + for (k, succs) in successor_map + for succ in succs + push!(predecessor_map[succ], k) + end + end + + return predecessor_map +end + +""" + collect_stmts(ir::BBCode)::Vector{Tuple{ID, CC.NewInstruction}} + +Produce a `Vector` containing all of the statements in `ir`. These are returned in +order, so it is safe to assume that element `n` refers to the `nth` element of the `IRCode` +associated to `ir`. +""" +function collect_stmts(ir::BBCode)::Vector{Tuple{ID, NewInstruction}} + return reduce(vcat, map(collect_stmts, ir.blocks)) +end + +""" + id_to_line_map(ir::BBCode) + +Produces a `Dict` mapping from each `ID` associated with a line in `ir` to its line number. +This is isomorphic to mapping to its `SSAValue` in `IRCode`. Terminators do not have `ID`s +associated to them, so not every line in the original `IRCode` is mapped to. +""" +function id_to_line_map(ir::BBCode) + lines = collect_stmts(ir) + lines_and_line_numbers = collect(zip(lines, eachindex(lines))) + ids_and_line_numbers = map(x -> (x[1][1], x[2]), lines_and_line_numbers) + return Dict(ids_and_line_numbers) +end + +concatenate_ids(bb_code::BBCode) = reduce(vcat, map(b -> b.inst_ids, bb_code.blocks)) +concatenate_stmts(bb_code::BBCode) = reduce(vcat, map(b -> b.insts, bb_code.blocks)) + +# +# Converting from IRCode to BBCode +# + +""" + BBCode(ir::IRCode) + +Convert an `ir` into a `BBCode`. Creates a completely independent data structure, so +mutating the `BBCode` returned will not mutate `ir`. + +All `PhiNode`s, `GotoIfNot`s, and `GotoNode`s will be replaced with the `IDPhiNode`s, +`IDGotoIfNot`s, and `IDGotoNode`s respectively. 
+ +See `IRCode` for conversion back to `IRCode`. + +Note that `IRCode(BBCode(ir))` should be equal to the identity function. +""" +function BBCode(ir::IRCode) + + # Produce a new set of statements with `IDs` rather than `SSAValues` and block numbers. + insts = new_inst_vec(ir.stmts) + ssa_ids, stmts = _ssas_to_ids(insts) + block_ids, stmts = _block_nums_to_ids(stmts, ir.cfg) + + # Chop up the new statements into `BBlocks`, according to the `CFG` in `ir`. + blocks = map(zip(ir.cfg.blocks, block_ids)) do (bb, id) + return BBlock(id, ssa_ids[bb.stmts], stmts[bb.stmts]) + end + return BBCode(ir, blocks) +end + +# Convert an InstructionStream into a list of `NewInstruction`s. +function new_inst_vec(x::CC.InstructionStream) + return map((v..., ) -> NewInstruction(v...), x.inst, x.type, x.info, x.line, x.flag) +end + +# Maps from positional names (SSAValues for nodes, Integers for basic blocks) to IDs. +const SSAToIdDict = Dict{SSAValue, ID} +const BlockNumToIdDict = Dict{Integer, ID} + +# Assigns an ID to each line in `insts`, and replaces each instance of an `SSAValue` in each +# line with the corresponding `ID`. For example, a call statement of the form +# `Expr(:call, :f, %4)` is replaced with `Expr(:call, :f, id_assigned_to_%4)`. +function _ssas_to_ids(insts::InstVector)::Tuple{Vector{ID}, InstVector} + ids = map(_ -> ID(), insts) + val_id_map = SSAToIdDict(zip(SSAValue.(eachindex(insts)), ids)) + return ids, map(Base.Fix1(_ssa_to_ids, val_id_map), insts) +end + +# Produce a new instance of `x` in which all instances of `SSAValue`s are replaced with +# the `ID`s prescribed by `d`. Basic block numbers are left as they are: `GotoIfNot`, +# `GotoNode`, and `PhiNode` instances are only converted into the corresponding `ID` +# versions later on, by `_block_num_to_ids`. +function _ssa_to_ids(d::SSAToIdDict, inst::NewInstruction) + return NewInstruction(inst; stmt=_ssa_to_ids(d, inst.stmt)) +end +function _ssa_to_ids(d::SSAToIdDict, x::ReturnNode) + return isdefined(x, :val) ? 
ReturnNode(get(d, x.val, x.val)) : x +end +_ssa_to_ids(d::SSAToIdDict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_ssa_to_ids(d::SSAToIdDict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_ssa_to_ids(d::SSAToIdDict, x::QuoteNode) = x +_ssa_to_ids(d::SSAToIdDict, x) = x +function _ssa_to_ids(d::SSAToIdDict, x::PhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return PhiNode(x.edges, new_values) +end +_ssa_to_ids(d::SSAToIdDict, x::GotoNode) = x +_ssa_to_ids(d::SSAToIdDict, x::GotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), x.dest) + +# Replace all integers corresponding to references to blocks with IDs. +function _block_nums_to_ids(insts::InstVector, cfg::CC.CFG)::Tuple{Vector{ID}, InstVector} + ids = map(_ -> ID(), cfg.blocks) + block_num_id_map = BlockNumToIdDict(zip(eachindex(cfg.blocks), ids)) + return ids, map(Base.Fix1(_block_num_to_ids, block_num_id_map), insts) +end + +function _block_num_to_ids(d::BlockNumToIdDict, x::NewInstruction) + return NewInstruction(x; stmt=_block_num_to_ids(d, x.stmt)) +end +function _block_num_to_ids(d::BlockNumToIdDict, x::PhiNode) + return IDPhiNode(ID[d[e] for e in x.edges], x.values) +end +_block_num_to_ids(d::BlockNumToIdDict, x::GotoNode) = IDGotoNode(d[x.label]) +_block_num_to_ids(d::BlockNumToIdDict, x::GotoIfNot) = IDGotoIfNot(x.cond, d[x.dest]) +_block_num_to_ids(d::BlockNumToIdDict, x) = x + +# +# Converting from BBCode to IRCode +# + +""" + IRCode(bb_code::BBCode) + +Produce an `IRCode` instance which is equivalent to `bb_code`. The resulting `IRCode` +shares no memory with `bb_code`, so can be safely mutated without modifying `bb_code`. + +All `IDPhiNode`s, `IDGotoIfNot`s, and `IDGotoNode`s are converted into `PhiNode`s, +`GotoIfNot`s, and `GotoNode`s respectively. 
+ +In the resulting `bb_code`, any `Switch` nodes are lowered into a semantically-equivalent +collection of `GotoIfNot` nodes. +""" +function CC.IRCode(bb_code::BBCode) + bb_code = _lower_switch_statements(bb_code) + bb_code = _remove_double_edges(bb_code) + insts = _ids_to_line_positions(bb_code) + cfg = _compute_basic_blocks(insts) + insts = _lines_to_blocks(insts, cfg) + return IRCode( + CC.InstructionStream( + map(x -> x.stmt, insts), + map(x -> x.type, insts), + map(x -> x.info, insts), + map(x -> x.line, insts), + map(x -> x.flag, insts), + ), + cfg, + CC.copy(bb_code.linetable), + CC.copy(bb_code.argtypes), + CC.copy(bb_code.meta), + CC.copy(bb_code.sptypes), + ) +end + +# Converts all `Switch`s into a semantically-equivalent collection of `GotoIfNot`s. See the +# `Switch` docstring for an explanation of what is going on here. +function _lower_switch_statements(bb_code::BBCode) + new_blocks = Vector{BBlock}(undef, 0) + for block in bb_code.blocks + t = terminator(block) + if t isa Switch + + # Create new block without the `Switch`. + bb = BBlock(block.id, block.inst_ids[1:end-1], block.insts[1:end-1]) + push!(new_blocks, bb) + + # Create new blocks for each `GotoIfNot` from the `Switch`. + foreach(t.conds, t.dests) do cond, dest + blk = BBlock(ID(), [ID()], [new_inst(IDGotoIfNot(cond, dest), Any)]) + push!(new_blocks, blk) + end + + # Create a new block for the fallthrough dest. + fallthrough_inst = new_inst(IDGotoNode(t.fallthrough_dest), Any) + push!(new_blocks, BBlock(ID(), [ID()], [fallthrough_inst])) + else + push!(new_blocks, block) + end + end + return BBCode(bb_code, new_blocks) +end + +# Returns a `Vector{Any}` of statements in which each `ID` has been replaced by either an +# `SSAValue`, or an `Int64` / `Int32` which refers to an `SSAValue`. +function _ids_to_line_positions(bb_code::BBCode)::InstVector + + # Construct map from `ID`s to `SSAValue`s. 
+ block_ids = [b.id for b in bb_code.blocks] + block_lengths = map(length, bb_code.blocks) + block_start_ssas = SSAValue.(vcat(1, cumsum(block_lengths)[1:end-1] .+ 1)) + line_ids = concatenate_ids(bb_code) + line_ssas = SSAValue.(eachindex(line_ids)) + id_to_ssa_map = Dict(zip(vcat(block_ids, line_ids), vcat(block_start_ssas, line_ssas))) + + # Apply map. + return [_to_ssas(id_to_ssa_map, stmt) for stmt in concatenate_stmts(bb_code)] +end + +# Like `_to_ids`, but converts IDs to SSAValues / (integers corresponding to ssas). +_to_ssas(d::Dict, inst::NewInstruction) = NewInstruction(inst; stmt=_to_ssas(d, inst.stmt)) +_to_ssas(d::Dict, x::ReturnNode) = isdefined(x, :val) ? ReturnNode(get(d, x.val, x.val)) : x +_to_ssas(d::Dict, x::Expr) = Expr(x.head, map(a -> get(d, a, a), x.args)...) +_to_ssas(d::Dict, x::PiNode) = PiNode(get(d, x.val, x.val), get(d, x.typ, x.typ)) +_to_ssas(d::Dict, x::QuoteNode) = x +_to_ssas(d::Dict, x) = x +function _to_ssas(d::Dict, x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = get(d, x.values[n], x.values[n]) + end + end + return PhiNode(map(e -> Int32(getindex(d, e).id), x.edges), new_values) +end +_to_ssas(d::Dict, x::IDGotoNode) = GotoNode(d[x.label].id) +_to_ssas(d::Dict, x::IDGotoIfNot) = GotoIfNot(get(d, x.cond, x.cond), d[x.dest].id) + +# Compute the CFG associated to `insts`. All references to blocks must be references to +# the SSAValue associated to the first statement in the block. +function _compute_basic_blocks(insts::InstVector) + return CC.compute_basic_blocks(Any[inst.stmt for inst in insts]) +end + +# Replaces references to blocks by line-number with references to block numbers. 
+function _lines_to_blocks(insts::InstVector, cfg::CC.CFG) + return map(inst -> __lines_to_blocks(cfg, inst), insts) +end + +function __lines_to_blocks(cfg::CC.CFG, inst::NewInstruction) + return NewInstruction(inst; stmt=__lines_to_blocks(cfg, inst.stmt)) +end +function __lines_to_blocks(cfg::CC.CFG, stmt::GotoNode) + return GotoNode(CC.block_for_inst(cfg, stmt.label)) +end +function __lines_to_blocks(cfg::CC.CFG, stmt::GotoIfNot) + return GotoIfNot(stmt.cond, CC.block_for_inst(cfg, stmt.dest)) +end +function __lines_to_blocks(cfg::CC.CFG, stmt::PhiNode) + return PhiNode(Int32[CC.block_for_inst(cfg, Int(e)) for e in stmt.edges], stmt.values) +end +function __lines_to_blocks(cfg::CC.CFG, stmt::Expr) + Meta.isexpr(stmt, :enter) && throw(error("Cannot handle enter yet")) + return stmt +end +__lines_to_blocks(::CC.CFG, stmt) = stmt + +# If the `dest` field of an `IDGotoIfNot` terminator points towards the next block, then both +# of its edges lead to the same block, so replace it with an `IDGotoNode`. The final block +# has no next block, so it is always left untouched. +function _remove_double_edges(ir::BBCode) + new_blks = map(enumerate(ir.blocks)) do (n, blk) + t = terminator(blk) + if t isa IDGotoIfNot && n < length(ir.blocks) && t.dest == ir.blocks[n+1].id + # `t` is only the statement, so rebuild the instruction which holds it with the + # new statement swapped in (its other fields are carried over unchanged). + goto_inst = NewInstruction(blk.insts[end]; stmt=IDGotoNode(t.dest)) + new_insts = vcat(blk.insts[1:end-1], goto_inst) + return BBlock(blk.id, blk.inst_ids, new_insts) + else + return blk + end + end + return BBCode(ir, new_blks) +end + +#= + _sort_blocks!(ir::BBCode) + +Ensure that blocks appear in order of distance-from-entry-point, where the +distance from block b to the entry point is defined to be the minimum number of basic +blocks that must be passed through in order to reach b. + +For reasons unknown (to me, Will), the compiler / optimiser needs this for inference to +succeed. Since we do quite a lot of re-ordering on the reverse-pass of AD, this is a problem +there. + +WARNING: use with care. Only use if you are confident that arbitrary re-ordering of basic +blocks in `ir` is valid. Notably, this does not hold if you have any `IDGotoIfNot` nodes in +`ir`. 
+=# +function _sort_blocks!(ir::BBCode) + + node_ints = collect(eachindex(ir.blocks)) + id_to_int = Dict(zip(map(blk -> blk.id, ir.blocks), node_ints)) + ps = compute_all_predecessors(ir) + direct_predecessors = map(ir.blocks) do blk + return map(b -> Edge(id_to_int[b], id_to_int[blk.id]), ps[blk.id]) + end + g = SimpleDiGraph(reduce(vcat, direct_predecessors)) + + d = dijkstra_shortest_paths(g, id_to_int[ir.blocks[1].id]).dists + I = sortperm(d) + ir.blocks .= ir.blocks[I] + return ir +end diff --git a/src/interpreter/interpreted_function.jl b/src/interpreter/interpreted_function.jl index 6aa5cba3f..40fc3eb5a 100644 --- a/src/interpreter/interpreted_function.jl +++ b/src/interpreter/interpreted_function.jl @@ -477,13 +477,7 @@ explosion of types. Moreover, type-stability it maintained. function InterpretedFunction(ctx::C, sig::Type{<:Tuple}, interp) where {C} # Grab code associated to this function. - output = Base.code_ircode_by_type(sig; interp) - if isempty(output) - throw(ArgumentError("No methods found for signature $sig")) - elseif length(output) > 1 - throw(ArgumentError("$(length(output)) methods found for signature $sig")) - end - ir, Treturn = only(output) + ir, Treturn = lookup_ir(interp, sig) # Slot into which the output of this function will be placed. return_slot = SlotRef{Treturn}() @@ -553,19 +547,6 @@ function __barrier(in_f::Tf) where {Tf<:InterpretedFunction} return in_f.return_slot[] end -""" - tuple_map(f::F, x::Tuple) where {F} - -This function is semantically equivalent to `map(f, x)`, but always specialises on all of -the element types of `x`, regardless the length of `x`. This contrasts with `map`, in which -the number of element types specialised upon is a fixed constant in the compiler. - -As a consequence, if `x` is very long, this function may have very large compile times. -""" -@generated function tuple_map(f::F, x::Tuple) where {F} - return Expr(:call, :tuple, map(n -> :(f(x[$n])), eachindex(x.parameters))...) 
-end - # Produce a `Dict` mapping from block numbers to line number of their first statement. function block_map(cfg::CC.CFG) line_to_blk_maps = map(((n, blk),) -> tuple.(blk.stmts, n), enumerate(cfg.blocks)) diff --git a/src/interpreter/ir_utils.jl b/src/interpreter/ir_utils.jl index 167c763c5..176949860 100644 --- a/src/interpreter/ir_utils.jl +++ b/src/interpreter/ir_utils.jl @@ -119,11 +119,65 @@ function __infer_ir!(ir, interp::CC.AbstractInterpreter, mi::CC.MethodInstance) irsv = CC.IRInterpretationState( interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world ) - ir.stmts.flag .|= CC.IR_FLAG_REFINED rt = CC._ir_abstract_constant_propagation(interp, irsv) return ir end +# In automatically generated code, it is meaningless to include code coverage effects. +# Moreover, it seems to cause some serious inference problems. Consequently, it makes sense +# to remove such effects before optimising IRCode. +function __strip_coverage!(ir::IRCode) + for n in eachindex(ir.stmts.inst) + if Meta.isexpr(ir.stmts.inst[n], :code_coverage_effect) + ir.stmts.inst[n] = nothing + end + end + return ir +end + +""" + optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) + +Run a fairly standard optimisation pass on `ir`. If `show_ir` is `true`, displays the IR +to `stdout` at various points in the pipeline -- this is sometimes useful for debugging. 
+""" +function optimise_ir!(ir::IRCode; show_ir=false, do_inline=true) + if show_ir + println("Pre-optimization") + display(ir) + println() + end + CC.verify_ir(ir) + ir = __strip_coverage!(ir) + ir = CC.compact!(ir) + local_interp = CC.NativeInterpreter() + mi = __get_toplevel_mi_from_ir(ir, @__MODULE__); + ir = __infer_ir!(ir, local_interp, mi) + if show_ir + println("Post-inference") + display(ir) + println() + end + inline_state = CC.InliningState(local_interp) + CC.verify_ir(ir) + if do_inline + ir = CC.ssa_inlining_pass!(ir, inline_state, #=propagate_inbounds=#true) + ir = CC.compact!(ir) + end + ir = __strip_coverage!(ir) + ir = CC.sroa_pass!(ir, inline_state) + ir = CC.adce_pass!(ir, inline_state) + ir = CC.compact!(ir) + # CC.verify_ir(ir, true, false, CC.optimizer_lattice(local_interp)) + CC.verify_linetable(ir.linetable, true) + if show_ir + println("Post-optimization") + display(ir) + println() + end + return ir +end + """ replace_all_uses_with!(ir::IRCode, value::SSAValue, new_value::Any) -> IRCode @@ -173,3 +227,92 @@ end # Return new_value if val equals current_val. _replace(val::SSAValue, new_val, current_val) = val == current_val ? new_val : current_val + +""" + lookup_ir(interp::AbstractInterpreter, sig::Type{<:Tuple})::Tuple{IRCode, T} + +Get the unique IR associated to `sig` under `interp`. Throws `ArgumentError`s if there is +no code found, or if more than one `IRCode` instance is returned. + +Returns a tuple containing the `IRCode` and its return type. +""" +function lookup_ir(interp::CC.AbstractInterpreter, sig::Type{<:Tuple}) + output = Base.code_ircode_by_type(sig; interp) + if isempty(output) + throw(ArgumentError("No methods found for signature $sig")) + elseif length(output) > 1 + throw(ArgumentError("$(length(output)) methods found for signature $sig")) + end + return only(output) +end + +""" + is_reachable_return_node(x) + +Determine whether `x` is a `ReturnNode`, and if it is, if it is also reachable. 
This is +purely a function of whether or not its `val` field is defined. +""" +is_reachable_return_node(x::ReturnNode) = isdefined(x, :val) +is_reachable_return_node(x) = false + +""" + globalref_type(x::GlobalRef) + +Returns the static type of the value referred to by `x`. +""" +globalref_type(x::GlobalRef) = isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty + +""" + UnhandledLanguageFeatureException(msg::String) + +An exception used to indicate that some aspect of the Julia language which AD cannot handle +has been encountered. +""" +struct UnhandledLanguageFeatureException <: Exception + msg::String +end + +""" + unhandled_feature(msg::String) + +Throw an `UnhandledLanguageFeatureException` with message `msg`. +""" +unhandled_feature(msg::String) = throw(UnhandledLanguageFeatureException(msg)) + +""" + inc_args(stmt) + +Increment by `1` the `n` field of any `Argument`s present in `stmt`. +""" +inc_args(x::Expr) = Expr(x.head, map(__inc, x.args)...) +inc_args(x::ReturnNode) = isdefined(x, :val) ? 
ReturnNode(__inc(x.val)) : x +inc_args(x::IDGotoIfNot) = IDGotoIfNot(__inc(x.cond), x.dest) +inc_args(x::IDGotoNode) = x +function inc_args(x::IDPhiNode) + new_values = Vector{Any}(undef, length(x.values)) + for n in eachindex(x.values) + if isassigned(x.values, n) + new_values[n] = __inc(x.values[n]) + end + end + return IDPhiNode(x.edges, new_values) +end +inc_args(::Nothing) = nothing +inc_args(x::GlobalRef) = x + +__inc(x::Argument) = Argument(x.n + 1) +__inc(x) = x + +""" + new_inst(stmt, type=Any, flag=CC.IR_FLAG_REFINED)::NewInstruction + +Create a `NewInstruction` with fields: +- `stmt` = `stmt` +- `type` = `type` +- `info` = `CC.NoCallInfo()` +- `line` = `Int32(1)` +- `flag` = `flag` +""" +function new_inst(@nospecialize(stmt), @nospecialize(type)=Any, flag=CC.IR_FLAG_REFINED) + return NewInstruction(stmt, type, CC.NoCallInfo(), Int32(1), flag) +end diff --git a/src/interpreter/registers.jl b/src/interpreter/registers.jl new file mode 100644 index 000000000..68e708662 --- /dev/null +++ b/src/interpreter/registers.jl @@ -0,0 +1,46 @@ +""" + AugmentedRegister(codual::CoDual, tangent_stack) + +A wrapper data structure for bundling together a codual and a tangent stack. These appear +in the code associated to active values in the primal. + +For example, a statement in the primal such as +```julia +%5 = sin(%4)::Float64 +``` +which provably returns a `Float64` in the primal, would return a `register_type(Float64)` +in the forwards-pass, where `register_type` will return an `AugmentedRegister` when the +primal type is `Float64`. +""" +struct AugmentedRegister{T<:CoDual, V} + codual::T + tangent_ref::V +end + +@inline primal(reg::AugmentedRegister) = primal(reg.codual) + +""" + register_type(::Type{P}) where {P} + +If `P` is the type associated to a primal register, the corresponding register in the +forwards-pass must be a `register_type(P)`. 
+""" +function register_type(::Type{P}) where {P} + P == DataType && return Any + P == UnionAll && return Any + P isa Union && return __union_register_type(P) + if isconcretetype(P) + return AugmentedRegister{codual_type(P), tangent_ref_type_ub(P)} + else + return AugmentedRegister + end +end + +# Specialised method for unions. +function __union_register_type(::Type{P}) where {P} + if P isa Union + CC.tmerge(AugmentedRegister{codual_type(P.a)}, __union_register_type(P.b)) + else + return AugmentedRegister{codual_type(P)} + end +end diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl new file mode 100644 index 000000000..b4b5c63b8 --- /dev/null +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -0,0 +1,1033 @@ +#= + SharedDataPairs() + +A data structure used to manage the captured data in the `OpaqueClosures` which implement +the bulk of the forwards- and reverse-passes of AD. An entry `(id, data)` at element `n` +of the `pairs` field of this data structure means that `data` will be available at register +`id` during the forwards- and reverse-passes of `AD`. + +This is achieved by storing all of the data in the `pairs` field in the captured tuple which +is passed to an `OpaqueClosure`, and extracting this data into registers associated to the +corresponding `ID`s. +=# +struct SharedDataPairs + pairs::Vector{Tuple{ID, Any}} + SharedDataPairs() = new(Tuple{ID, Any}[]) +end + +#= + add_data!(p::SharedDataPairs, data)::ID + +Puts `data` into `p`, and returns the `id` associated to it. This `id` should be assumed to +be available during the forwards- and reverse-passes of AD, and it should further be assumed +that the value associated to this `id` is always `data`. 
+=# +function add_data!(p::SharedDataPairs, data)::ID + id = ID() + push!(p.pairs, (id, data)) + return id +end + +#= + shared_data_tuple(p::SharedDataPairs)::Tuple + +Create the tuple that will constitute the captured variables in the forwards- and reverse- +pass `OpaqueClosure`s. + +For example, if `p.pairs` is +```julia +[(ID(5), 5.0), (ID(3), "hello")] +``` +then the output of this function is +```julia +(5.0, "hello") +``` +=# +shared_data_tuple(p::SharedDataPairs)::Tuple = tuple(map(last, p.pairs)...) + +#= + shared_data_stmts(p::SharedDataPairs)::Vector{Tuple{ID, NewInstruction}} + +Produce a sequence of id-statement pairs which will extract the data from +`shared_data_tuple(p)` such that the correct value is associated to the correct `ID`. + +For example, if `p.pairs` is +```julia +[(ID(5), 5.0), (ID(3), "hello")] +``` +then the output of this function is +```julia +Tuple{ID, NewInstruction}[ + (ID(5), new_inst(:(getfield(_1, 1)))), + (ID(3), new_inst(:(getfield(_1, 2)))), +] +``` +=# +function shared_data_stmts(p::SharedDataPairs)::Vector{Tuple{ID, NewInstruction}} + return map(enumerate(p.pairs)) do (n, p) + return (p[1], new_inst(Expr(:call, getfield, Argument(1), n))) + end +end + +#= + ADInfo + +This data structure is used to hold "global" information associated to a particular call to +`build_rrule`. It is used as a means of communication between `make_ad_stmts!` and the +codegen which produces the forwards- and reverse-passes. + +- `interp`: a `TapedInterpreter`. +- `block_stack_id`: the ID associated to the block stack -- the stack which keeps track of + which blocks we visited during the forwards-pass, and which is used on the reverse-pass + to determine which blocks to visit. +- `block_stack`: the block stack. Can always be found at `block_stack_id` in the forwards- + and reverse-passes. +- `entry_id`: special ID associated to the block inserted at the start of execution in the + forwards-pass, and the end of execution in the pullback. 
+- `shared_data_pairs`: the `SharedDataPairs` used to define the captured variables passed + to both the forwards- and reverse-passes.. +- `arg_types`: a map from `Argument` to its static type. +- `ssa_insts`: a map from `ID` associated to lines to the primal `NewInstruction`. +- `arg_tangent_stacks`: a map from primal `Argument`s to their tangent stacks. If the stack + associated to an `Argument` is a bits type then this will just be the tangent stack. + Otherwise, it will be the `ID` associated to the stack, and the stack itself will be put + in the `shared_data_pairs`. +- `tangent_stacks`: a map from `ID` to tangent stacks. If the tangent stack associated to + the `ID` is a bits type, then this will actually be the tangent stack. Otherwise it will + be the `ID` associated to the stack, and the stack itself will be put in the + `shared_data_pairs`. +=# +struct ADInfo + interp::TInterp + block_stack_id::ID + block_stack::Stack{Int32} + entry_id::ID + shared_data_pairs::SharedDataPairs + arg_types::Dict{Argument, Any} + ssa_insts::Dict{ID, NewInstruction} + arg_tangent_stacks::Dict{Argument, Any} + tangent_stacks::Dict{ID, Any} +end + +# The constructor that you should use for ADInfo. +function ADInfo( + interp::TInterp, + arg_types::Dict{Argument, Any}, + ssa_insts::Dict{ID, NewInstruction}, + arg_tangent_stacks, +) + shared_data_pairs = SharedDataPairs() + block_stack = Stack{Int32}() + return ADInfo( + interp, + add_data!(shared_data_pairs, block_stack), + block_stack, + ID(), + shared_data_pairs, + arg_types, + ssa_insts, + make_arg_tangent_stacks!(shared_data_pairs, arg_tangent_stacks), + make_tangent_stacks!(shared_data_pairs, ssa_insts), + ) +end + +function __log_data(p::Union{ADInfo, SharedDataPairs}, x) + return Base.issingletontype(_typeof(x)) ? x : add_data!(p, x) +end + +# Construct a map from primal `Argument`s to the location of its tangent stack in the +# forwards-pass and pullback. 
If tangent stack is a singleton, just yields the tangent +# stack itself. +function make_arg_tangent_stacks!(p::SharedDataPairs, arg_tangent_stacks) + arguments = Argument.(eachindex(arg_tangent_stacks)) + stack_ids = map(Base.Fix1(__log_data, p), arg_tangent_stacks) + return Dict{Argument, Any}(zip(arguments, stack_ids)) +end + +# Construct a map from primal `ID`s corresponding to lines in the IR, to the location of +# their tangent stacks in the forwards-pass and pullback. If tangent stacks is a singleton, +# just yields the tangent stack itself. +function make_tangent_stacks!(p::SharedDataPairs, ssa_insts::Dict{ID, NewInstruction}) + tangent_stacks = Dict{ID, Any}() + for (k, inst) in ssa_insts + Meta.isexpr(inst.stmt, :call) || Meta.isexpr(inst.stmt, :invoke) || continue + tangent_stacks[k] = __log_data(p, make_tangent_stack(_get_type(inst.type))) + end + return tangent_stacks +end + +# Shortcut for `add_data!(info.shared_data_pairs, data)`. +add_data!(info::ADInfo, data) = add_data!(info.shared_data_pairs, data) + +# Returns the static / inferred type associated to `x`. +get_primal_type(info::ADInfo, x::Argument) = info.arg_types[x] +get_primal_type(info::ADInfo, x::ID) = _get_type(info.ssa_insts[x].type) +get_primal_type(::ADInfo, x::QuoteNode) = _typeof(x.value) +get_primal_type(::ADInfo, x) = _typeof(x) +function get_primal_type(::ADInfo, x::GlobalRef) + return isconst(x) ? _typeof(getglobal(x.mod, x.name)) : x.binding.ty +end + +#= + ADStmtInfo + +Data structure which contains the result of `make_ad_stmts!`. Fields are +- `line`: the ID associated to the primal line from which this is derived +- `fwds`: the instructions which run the forwards-pass of AD +- `rvs`: the instructions which run the reverse-pass of AD / the pullback +=# +struct ADStmtInfo + line::ID + fwds::Vector{Tuple{ID, NewInstruction}} + rvs::Vector{Tuple{ID, NewInstruction}} +end + +# Convenient constructor for `ADStmtInfo`. 
If either `fwds` or `rvs` is not a vector, +# `__vec` promotes it to a single-element `Vector`. +ad_stmt_info(line::ID, fwds, rvs) = ADStmtInfo(line, __vec(line, fwds), __vec(line, rvs)) + +__vec(line::ID, x::Any) = __vec(line, new_inst(x)) +__vec(line::ID, x::NewInstruction) = Tuple{ID, NewInstruction}[(line, x)] +__vec(line::ID, x::Vector{Tuple{ID, Any}}) = throw(error("boooo")) +__vec(line::ID, x::Vector{Tuple{ID, NewInstruction}}) = x + +#= + make_ad_stmts!(stmt, line::ID, info::ADInfo)::ADStmtInfo + +Every line in the primal code is associated to one or more lines in the forwards-pass of AD, +and one or more lines in the pullback. This function has a method specific to every +node type in the Julia SSAIR. + +Translates the statement `stmt`, associated to `line` in the primal, into a specification +of what should happen for this statement in the forwards- and reverse-passes of AD, and +what data should be shared between the forwards- and reverse-passes. Returns this in the +form of an `ADStmtInfo`. + +`info` is a data structure containing various bits of global information that certain types +of nodes need access to. +=# +function make_ad_stmts! end + +# `nothing` as a statement in Julia IR indicates the presence of a line which will later be +# removed. We emit a no-op on both the forwards- and reverse-passes. No shared data. +make_ad_stmts!(::Nothing, line::ID, ::ADInfo) = ad_stmt_info(line, nothing, nothing) + +# `ReturnNode`s have a single field, `val`, for which there are three cases to consider: +# +# 1. `val isa Union{Argument, ID}`: this is an active bit of data. Consequently, we know +# that it will be an `AugmentedRegister` already, and can just return it. Therefore `stmt` +# is returned as the forwards-pass (with any `Argument`s incremented), and nothing happens +# in the pullback. +# 2. `val` is undefined: this `ReturnNode` is unreachable. Consequently, we'll never hit the +# associated statements on the forwards-pass or pullback. 
# We just return the original
# statement on the forwards-pass, and `nothing` on the reverse-pass.
# 3. `val` is defined, but not a `Union{Argument, ID}`: in this case we're returning a
# constant -- build a constant register and return that.
function make_ad_stmts!(stmt::ReturnNode, line::ID, info::ADInfo)
    if !isdefined(stmt, :val) || is_active(stmt.val)
        # Cases 1 and 2: re-emit the statement after applying `inc_args`.
        return ad_stmt_info(line, inc_args(stmt), nothing)
    else
        # Case 3: return a constant register built from the literal value.
        return ad_stmt_info(line, ReturnNode(const_register(stmt.val, info)), nothing)
    end
end

# The statement is re-emitted on the forwards-pass after applying `inc_args`. No-op
# reverse. No shared data.
function make_ad_stmts!(stmt::IDGotoNode, line::ID, ::ADInfo)
    return ad_stmt_info(line, inc_args(stmt), nothing)
end

# Forwards-pass: `inc_args` is applied to `stmt`. If the resulting condition is an
# `Argument` or `ID`, a call to `primal` on it is emitted under a fresh `ID`, and the
# branch is re-issued on `line` with that fresh `ID` as its condition. Any other condition
# is emitted unchanged. No-op reverse. No shared data.
function make_ad_stmts!(stmt::IDGotoIfNot, line::ID, ::ADInfo)
    stmt = inc_args(stmt)
    if stmt.cond isa Union{Argument, ID}
        # If cond refers to a register, then the primal must be extracted.
        cond_id = ID()
        fwds = [
            (cond_id, new_inst(Expr(:call, primal, stmt.cond))),
            (line, new_inst(IDGotoIfNot(cond_id, stmt.dest), Any)),
        ]
        return ad_stmt_info(line, fwds, nothing)
    else
        # If something other than a register, then there is nothing to do.
        return ad_stmt_info(line, stmt, nothing)
    end
end

# Forwards-pass: a new `IDPhiNode` with the same edges, in which every assigned value is
# either passed through `__inc` (when `is_active`) or replaced by the result of
# `const_register`. Entries which are unassigned in the original are left unassigned.
# No-op reverse. Whether anything lands in the shared data is decided by `const_register`.
function make_ad_stmts!(stmt::IDPhiNode, line::ID, info::ADInfo)
    vals = stmt.values
    new_vals = Vector{Any}(undef, length(vals))
    for n in eachindex(vals)
        isassigned(vals, n) || continue
        new_vals[n] = is_active(vals[n]) ? __inc(vals[n]) : const_register(vals[n], info)
    end

    # It turns out to be really very important to do type inference correctly for PhiNodes.
    # For some reason, type inference really doesn't like it when you encounter mutually-
    # dependent PhiNodes whose types are unknown and for which you set the flag to
    # CC.IR_FLAG_REFINED.
    # Hence the register type is supplied explicitly, and the flag of the primal
    # instruction is carried over.
    new_type = register_type(get_primal_type(info, line))
    _inst = new_inst(IDPhiNode(stmt.edges, new_vals), new_type, info.ssa_insts[line].flag)
    return ad_stmt_info(line, _inst, nothing)
end

# `PiNode`s are only supported when `val` is an `Argument` or `ID`; anything else is passed
# to `unhandled_feature`. Two statements are emitted on the forwards-pass (a re-typed
# `PiNode` under a fresh `ID`, then a call to `__pi_fwds!` on `line`), and a call to
# `__pi_rvs!` on the reverse-pass. A tangent stack and a tangent ref stack are added to the
# shared data.
function make_ad_stmts!(stmt::PiNode, line::ID, info::ADInfo)
    isa(stmt.val, Union{Argument, ID}) || unhandled_feature("PiNode: $stmt")

    # Create line which sharpens the register type as much as possible.
    # NOTE(review): `_get_type` ends up applied twice to `stmt.typ` below -- presumably it
    # is idempotent; confirm.
    sharp_primal_type = _get_type(stmt.typ)
    sharpened_register_type = AugmentedRegister{codual_type(_get_type(sharp_primal_type))}
    new_pi_line = ID()
    new_pi = PiNode(__inc(stmt.val), sharpened_register_type)

    # Create a statement which moves data from the loosely-typed register to a more
    # strictly typed one, which is possible because of the `PiNode`.
    tangent_stack = make_tangent_stack(sharp_primal_type)
    tangent_stack_id = add_data!(info, tangent_stack)
    val_type = get_primal_type(info, stmt.val)
    tangent_ref_stack = make_tangent_ref_stack(tangent_ref_type_ub(val_type))
    tangent_ref_stack_id = add_data!(info, tangent_ref_stack)
    new_line = Expr(:call, __pi_fwds!, tangent_stack_id, tangent_ref_stack_id, new_pi_line)

    # Assemble the above lines and construct reverse-pass.
    return ad_stmt_info(
        line,
        [(new_pi_line, new_inst(new_pi)), (line, new_inst(new_line))],
        Expr(:call, __pi_rvs!, tangent_stack_id, tangent_ref_stack_id),
    )
end

# Forwards-pass helper for `PiNode`s. Pushes the tangent ref of `reg` onto
# `tangent_ref_stack` and the tangent of its codual onto `tangent_stack`, then returns a
# register with the same codual whose tangent ref is `top_ref(tangent_stack)`.
@inline function __pi_fwds!(tangent_stack, tangent_ref_stack, reg::AugmentedRegister)
    push!(tangent_ref_stack, reg.tangent_ref)
    push!(tangent_stack, tangent(reg.codual))
    return AugmentedRegister(reg.codual, top_ref(tangent_stack))
end

# Reverse-pass helper for `PiNode`s. Pops one entry from each stack, and passes the popped
# ref and the popped tangent to `increment_ref!`.
@inline function __pi_rvs!(tangent_stack, tangent_ref_stack)
    increment_ref!(pop!(tangent_ref_stack), pop!(tangent_stack))
    return nothing
end

# Constant GlobalRefs are handled. See const_register. Non-constant
# GlobalRefs are handled by assuming that they are constant, and creating a register with
# the value.
# We then check at run-time that the value has not changed.
function make_ad_stmts!(stmt::GlobalRef, line::ID, info::ADInfo)
    if isconst(stmt)
        return const_ad_stmt(stmt, line, info)
    else
        # Snapshot the current value of the global into a register now. On the
        # forwards-pass, the global is re-read under a fresh `ID` and compared against the
        # snapshot by `__verify_const`, whose result is assigned to `line`.
        reg = const_register(getglobal(stmt.mod, stmt.name), info)
        gref_id = ID()
        fwds = [
            (gref_id, new_inst(stmt)),
            (line, new_inst(Expr(:call, __verify_const, gref_id, reg))),
        ]
        return ad_stmt_info(line, fwds, nothing)
    end
end

# Helper used by `make_ad_stmts!` for `GlobalRef`. Returns `stored_value` if `global_ref`
# still compares equal (`==`) to the primal held in `stored_value`, and throws an
# `AssertionError` otherwise.
#
# The check is written out explicitly rather than via `@assert`: the Julia documentation
# states that assertions may be disabled at some optimisation levels, and this check must
# always run.
@noinline function __verify_const(global_ref, stored_value)
    if !(global_ref == primal(stored_value))
        throw(AssertionError(
            "The value of a non-constant global differs from the value it had when the " *
            "rule was derived.",
        ))
    end
    return stored_value
end

# QuoteNodes are constant. See const_register for details.
make_ad_stmts!(stmt::QuoteNode, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)

# Literal constant. See const_register for details.
make_ad_stmts!(stmt, line::ID, info::ADInfo) = const_ad_stmt(stmt, line, info)

# `make_ad_stmts!` for constants. If `const_register` yields an `ID` (i.e. the register was
# put in the shared data), the forwards-pass statement is a call to `identity` on that
# `ID`; otherwise the register itself is the statement. No-op reverse.
function const_ad_stmt(stmt, line::ID, info::ADInfo)
    reg = const_register(stmt, info)
    return ad_stmt_info(line, reg isa ID ? Expr(:call, identity, reg) : reg, nothing)
end

# Build an `AugmentedRegister` from `stmt`, which will be checked to ensure that its value
# is constant. If the resulting register is a bits type, then it is returned. If it is not,
# then the register is put into shared data, and the ID associated to it in the forwards-
# and reverse-passes returned.
function const_register(stmt, info::ADInfo)
    reg = build_const_reg(stmt)
    return isbitstype(_typeof(reg)) ? reg : add_data!(info, reg)
end

# Create a constant augmented register which lives in the shared data. Returns the `ID`
# which will be associated to this data in the forwards- and reverse-passes.
shared_data_const_reg(stmt, info::ADInfo) = add_data!(info, build_const_reg(stmt))

# Create an `AugmentedRegister` containing the values associated to `stmt`, a zero tangent.
+# Pushes a single element onto the stack, and puts a reference to that stack in the +# register. +function build_const_reg(stmt) + primal_value = get_const_primal_value(stmt) + tangent_stack = make_tangent_stack(_typeof(primal_value)) + tangent = uninit_tangent(primal_value) + push!(tangent_stack, tangent) + return AugmentedRegister(CoDual(primal_value, tangent), top_ref(tangent_stack)) +end + +# Get the value associated to `x`. For `GlobalRef`s, verify that `x` is indeed a constant, +# and error if it is not. +function get_const_primal_value(x::GlobalRef) + isconst(x) || unhandled_feature("Non-constant GlobalRef not supported: $x") + return getglobal(x.mod, x.name) +end +get_const_primal_value(x::QuoteNode) = x.value +get_const_primal_value(x) = x + +# Taped does not yet handle `PhiCNode`s. Throw an error if one is encountered. +function make_ad_stmts!(stmt::Core.PhiCNode, ::ID, ::ADInfo) + unhandled_feature("Encountered PhiCNode: $stmt") +end + +# Taped does not yet handle `UpsilonNode`s. Throw an error if one is encountered. +function make_ad_stmts!(stmt::Core.UpsilonNode, ::ID, ::ADInfo) + unhandled_feature("Encountered UpsilonNode: $stmt") +end + +# There are quite a number of possible `Expr`s that can be encountered. Each case has its +# own comment, explaining what is going on. +function make_ad_stmts!(stmt::Expr, line::ID, info::ADInfo) + is_invoke = Meta.isexpr(stmt, :invoke) + if Meta.isexpr(stmt, :call) || is_invoke + + # Find the types of all arguments to this call / invoke. + args = ((is_invoke ? stmt.args[2:end] : stmt.args)..., ) + arg_types = map(arg -> get_primal_type(info, arg), args) + + # Construct signature, and determine how the rrule is to be computed. + sig = Tuple{arg_types...} + rule = if is_primitive(context_type(info.interp), sig) + rrule!! 
# intrinsic / builtin / thing we provably have rule for + elseif is_invoke + LazyDerivedRule(info.interp, sig) # Static dispatch + else + DynamicDerivedRule(info.interp) # Dynamic dispatch + end + + # If the rule is `rrule!!` (i.e. `sig` is primitive), then don't bother putting + # the rule into shared data, because it's safe to put it directly into the code. + rule_ref = __log_data(info, rule) + + # Tangent stacks are allocated in build_rrule, and stored in the `info`. Just + # retrieve the stack associated to the tangent returned from this line. + ret_tangent_stack_id = info.tangent_stacks[line] + + # If the type of the pullback is a singleton type, then there is no need to store it + # in the shared data, it can be interpolated directly into the generated IR. + pb_stack = build_pb_stack(_typeof(rule), arg_types) + pb_stack_id = __log_data(info, pb_stack) + + # if the pullback is a `NoPullback`, then there is no need to log the references to + # the tangent stacks associated to the inputs to this call, because there will never + # need to be any incrementing done. There are functions called within + # `__fwds_pass!` and `__rvs_pass!` that specialise on the type of the pullback to + # avoid ever using the arg tangent ref stacks, so we just need to create a default + # value here (`nothing`), as it will never be used. + arg_tangent_ref_stacks_id = ID() + if pb_stack isa SingletonStack{NoPullback} + arg_tangent_ref_stacks = nothing + else + ref_stacks = map(arg_types, args) do arg_type, arg + stack = __make_arg_tangent_ref_stack(arg_type, arg) + if Base.issingletontype(_typeof(stack)) + return stack + elseif haskey(info.tangent_stacks, arg) + return info.tangent_stacks[arg] + elseif arg isa Argument + return info.arg_tangent_stacks[arg] + else + return add_data!(info, stack) + end + end + arg_tangent_ref_stacks = Expr(:call, __tangent_ref_stacks, ref_stacks...) 
+ end + + # Create calls to `__fwds_pass!` and `__rvs_pass!`, which run the forwards pass and + # pullback associated to a call / invoke. + fwds_pass_call = Expr( + :call, + __fwds_pass!, + arg_tangent_ref_stacks_id, + rule_ref, + ret_tangent_stack_id, + pb_stack_id, + register_type(get_primal_type(info, line)), + map(__inc, args)..., + ) + + rvs_pass_call = Expr( + :call, __rvs_pass!, arg_tangent_ref_stacks_id, ret_tangent_stack_id, pb_stack_id + ) + + fwds = [ + (arg_tangent_ref_stacks_id, new_inst(arg_tangent_ref_stacks)), + (line, new_inst(fwds_pass_call)), + ] + rvs = Tuple{ID, NewInstruction}[ + (arg_tangent_ref_stacks_id, new_inst(arg_tangent_ref_stacks)), + (line, new_inst(rvs_pass_call)), + ] + return ad_stmt_info(line, fwds, rvs) + + elseif Meta.isexpr(stmt, :boundscheck) + # For some reason the compiler cannot handle boundscheck statements when we run it + # again. Consequently, emit `true` to be safe. Ideally we would handle this in a + # more natural way, but I'm not sure how to do that. + tmp = AugmentedRegister(zero_codual(true), NoTangentStack()) + return ad_stmt_info(line, tmp, nothing) + + elseif Meta.isexpr(stmt, :code_coverage_effect) + # Code coverage irrelevant for derived code. + return ad_stmt_info(line, nothing, nothing) + + elseif Meta.isexpr(stmt, :copyast) + # Get constant out and shove it in shared storage. + reg = const_register(stmt.args[1], info) + return ad_stmt_info(line, Expr(:call, identity, reg), nothing) + + elseif Meta.isexpr(stmt, :loopinfo) + # Cannot pass loopinfo back through the optimiser for some reason. + # At the time of writing, I am unclear why this is not possible. + return ad_stmt_info(line, nothing, nothing) + + elseif stmt.head in [ + :enter, + :gc_preserve_begin, + :gc_preserve_end, + :leave, + :pop_exception, + :throw_undef_if_not, + ] + # Expressions which do not require any special treatment. + return ad_stmt_info(line, stmt, nothing) + else + # Encountered an expression that we've not seen before. 
+ throw(error("Unrecognised expression $stmt")) + end +end + +# Used in `make_ad_stmts!` for call and invoke exprs. If an argument to the stmt is active, +# then we grab its tangnet ref stack. If it's inactive (a constant of some kind -- really +# anything that's not an `Argument` or an `ID`), then we create a dummy stack that will get +# optimised away. +function __make_arg_tangent_ref_stack(arg_type, arg) + is_active(arg) || return InactiveStack(InactiveRef(__zero_tangent(arg))) + return make_tangent_ref_stack(tangent_ref_type_ub(arg_type)) +end + +@inline function __tangent_ref_stacks(args::Vararg{Any, N}) where {N} + return tuple_map(___tangent_ref_stacks_helper, args) +end + +# Distinguish between tangent stacks and tangent ref stacks based on their type. If we see +# a type which looks like a tangent ref stack, just return it. If we see any other type, +# assume it is a tangent stack, meaning that the tangent stack is fixed. +# This is bit of a hack -- ideally we would get the code construction in `make_ad_stmts!` to +# determine this, as doing this based on type is potentially flakey. It will have to do for +# now though. +@inline @generated function ___tangent_ref_stacks_helper(arg::P) where {P} + if P <: Union{InactiveStack, Stack{<:Ref}, NoTangentRefStack} && !(P <: Stack{<:Ptr}) + return :(arg) + else + return :(FixedStackTangentRefStack(arg)) + end +end + +is_active(::Union{Argument, ID}) = true +is_active(::Any) = false + +__zero_tangent(arg) = zero_tangent(arg) +__zero_tangent(arg::GlobalRef) = zero_tangent(getglobal(arg.mod, arg.name)) +__zero_tangent(arg::QuoteNode) = zero_tangent(arg.value) + +# Build a stack to contain the pullback. Specialises on whether the pullback is a singleton, +# and whether we get to know the concrete type of the pullback or not. +function build_pb_stack(Trule, arg_types) + T_pb!! = Core.Compiler.return_type(Tuple{Trule, map(codual_type, arg_types)...}) + if T_pb!! <: Tuple && T_pb!! !== Union{} && !(T_pb!! 
isa Union) + F = T_pb!!.parameters[2] + return Base.issingletontype(F) ? SingletonStack{F}() : Stack{F}() + else + return Stack{Any}() + end +end + +# Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. +@inline function __fwds_pass!( + arg_tangent_ref_stacks, + rule, + ret_tangent_stack, + pb_stack, + ::Type{R}, + f::F, + raw_args::Vararg{Any, N}, +) where {R, F, N} + + raw_args = (f, raw_args...) + __log_tangent_refs!(pb_stack, raw_args, arg_tangent_ref_stacks) + + # Run the rule. + args = tuple_map(x -> isa(x, AugmentedRegister) ? x.codual : uninit_codual(x), raw_args) + out, pb!! = rule(args...) + + # Log the results and return. + __push_tangent_stack!(ret_tangent_stack, tangent(out)) + __push_pb_stack!(pb_stack, pb!!) + return AugmentedRegister(out, top_ref(ret_tangent_stack))::R +end + +@inline function __log_tangent_refs!(::Any, raw_args, arg_tangent_ref_stacks) + tangent_refs = map(x -> isa(x, AugmentedRegister) ? x.tangent_ref : nothing, raw_args) + tuple_map(__push_ref_stack, arg_tangent_ref_stacks, tangent_refs) +end + +@inline __log_tangent_refs!(::SingletonStack{NoPullback}, ::Any, ::Any) = nothing + +@inline __push_ref_stack(tangent_ref_stack, ref) = push!(tangent_ref_stack, ref) +@inline __push_ref_stack(::InactiveStack, ref) = nothing +@inline __push_ref_stack(::NoTangentRefStack, ref) = nothing + +@inline __push_tangent_stack!(stack, t) = push!(stack, t) +@inline __push_pb_stack!(stack, pb!!) = push!(stack, pb!!) + +# Used in `make_ad_stmts!` method for `Expr(:call, ...)` and `Expr(:invoke, ...)`. +@inline function __rvs_pass!(arg_tangent_ref_stacks, ret_tangent_stack, pb_stack)::Nothing + pb = __pop_pb_stack!(pb_stack) + tngt = __pop_tangent_stack!(ret_tangent_stack) + __execute_reverse_pass!(pb, tngt, arg_tangent_ref_stacks) +end + +@inline __execute_reverse_pass!(::NoPullback, ::Any, ::Any) = nothing + +@inline function __execute_reverse_pass!(pb!!, dout, arg_tangent_ref_stacks) + # Get the tangent w.r.t. 
each argument of the primal. + tangent_refs = tuple_map(pop!, arg_tangent_ref_stacks) + + # Run the pullback and increment the argument tangents. + dargs = tuple_map(set_immutable_to_zero ∘ getindex, tangent_refs) + new_dargs = pb!!(dout, dargs...) + tuple_map(increment_ref!, tangent_refs, new_dargs) + return nothing +end + +@inline __pop_pb_stack!(stack) = pop!(stack) +@inline __pop_tangent_stack!(tangent_stack) = pop!(tangent_stack) + +# +# Runners for generated code. +# + +struct Pullback{Tpb, Tret_ref, Targ_tangent_stacks, Tisva, Tnargs} + pb_oc::Tpb + ret_ref::Tret_ref + arg_tangent_stacks::Targ_tangent_stacks + isva::Tisva + nargs::Tnargs +end + +@inline function (pb::Pullback{P, Q})(dy, dargs::Vararg{Any, N}) where {P, Q, N} + unflattened_dargs = __unflatten_varargs(pb.isva, dargs, pb.nargs) + map(setindex!, map(top_ref, pb.arg_tangent_stacks), unflattened_dargs) + increment_ref!(pb.ret_ref, dy) + pb.pb_oc(dy, unflattened_dargs...) + out = __flatten_varargs(pb.isva, map(pop!, pb.arg_tangent_stacks), nvargs(length(dargs), pb.nargs)) + return out::_typeof(dargs) +end + +@inline nvargs(n_flat, ::Val{nargs}) where {nargs} = Val(n_flat - nargs + 1) + +struct DerivedRule{Tfwds_oc, Targ_tangent_stacks, Tpb_oc, Tisva<:Val, Tnargs<:Val} + fwds_oc::Tfwds_oc + pb_oc::Tpb_oc + arg_tangent_stacks::Targ_tangent_stacks + block_stack::Stack{Int32} + isva::Tisva + nargs::Tnargs +end + +@inline function (fwds::DerivedRule{P, Q, S})(args::Vararg{CoDual, N}) where {P, Q, S, N} + + # Load arguments in to stacks, and create tuples. + args = __unflatten_codual_varargs(fwds.isva, args, fwds.nargs) + args_with_tangent_stacks = map(args, fwds.arg_tangent_stacks) do arg, arg_tangent_stack + push!(arg_tangent_stack, tangent(arg)) + return AugmentedRegister(arg, top_ref(arg_tangent_stack)) + end + + # Run forwards-pass. + reg = fwds.fwds_oc(args_with_tangent_stacks...)::AugmentedRegister + + # Extract result and assemble pullback. + pb!! 
= Pullback(fwds.pb_oc, reg.tangent_ref, fwds.arg_tangent_stacks, fwds.isva, fwds.nargs) + return reg.codual, pb!! +end + +# Compute the concrete type of the rule that will be returned from `build_rrule`. This is +# important for performance in dynamic dispatch, and to ensure that recursion works +# properly. +function rule_type(interp::TapedInterpreter{C}, ::Type{sig}) where {C, sig} + is_primitive(C, sig) && return typeof(rrule!!) + + ir, _ = lookup_ir(interp, sig) + Treturn = Base.Experimental.compute_ir_rettype(ir) + isva, _ = is_vararg_sig_and_sparam_names(sig) + + arg_types = map(_get_type, ir.argtypes) + arg_tangent_types = map(tangent_type, arg_types) + Targ_registers = Tuple{map(tangent_stack_type ∘ _get_type, ir.argtypes)...} + Treturn_register = register_type(Treturn) + if isconcretetype(Treturn_register) + return DerivedRule{ + Core.OpaqueClosure{Tuple{map(register_type, arg_types)...}, Treturn_register}, + Targ_registers, + Core.OpaqueClosure{Tuple{tangent_type(Treturn), arg_tangent_types...}, Nothing}, + Val{isva}, + Val{length(ir.argtypes)}, + } + else + return DerivedRule{ + Core.OpaqueClosure{Tuple{map(register_type, arg_types)...}, T} where {T<:Treturn_register}, + Targ_registers, + Core.OpaqueClosure{Tuple{tangent_type(Treturn), arg_tangent_types...}, Nothing}, + Val{isva}, + Val{length(ir.argtypes)}, + } + end +end + +# if isva and nargs=2, then inputs (5.0, 4.0, 3.0) are transformed into (5.0, (4.0, 3.0)). +function __unflatten_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} + isva || return args + if all(t -> t isa NoTangent, args[nargs:end]) + return (args[1:nargs-1]..., NoTangent()) + else + return (args[1:nargs-1]..., args[nargs:end]) + end +end + +# If isva, inputs (5.0, (4.0, 3.0)) are transformed into (5.0, 4.0, 3.0). +function __flatten_varargs(::Val{isva}, args, ::Val{nvargs}) where {isva, nvargs} + isva || return args + if args[end] isa NoTangent + return (args[1:end-1]..., ntuple(n -> NoTangent(), nvargs)...) 
+ else + return (args[1:end-1]..., args[end]...) + end +end + +# If isva and nargs=2, then inputs `(CoDual(5.0, 0.0), CoDual(4.0, 0.0), CoDual(3.0, 0.0))` +# are transformed into `(CoDual(5.0, 0.0), CoDual((5.0, 4.0), (0.0, 0.0)))`. +function __unflatten_codual_varargs(::Val{isva}, args, ::Val{nargs}) where {isva, nargs} + isva || return args + group_primal = map(primal, args[nargs:end]) + if tangent_type(_typeof(group_primal)) == NoTangent + grouped_args = zero_codual(group_primal) + else + grouped_args = CoDual(group_primal, map(tangent, args[nargs:end])) + end + return (args[1:nargs-1]..., grouped_args) +end + +""" + build_rrule(args...) + +Helper method. Only uses static information from `args`. +""" +function build_rrule(args...) + return build_rrule(TapedInterpreter(), _typeof(TestUtils.__get_primals(args))) +end + +""" + build_rrule(interp::TInterp{C}, sig::Type{<:Tuple}) where {C} + +Returns a `DerivedRule` which is an `rrule!!` for `sig` in context `C`. See the docstring +for `rrule!!` for more info. +""" +function build_rrule(interp::TInterp{C}, sig::Type{<:Tuple}) where {C} + + # Reset id count. This ensures that everything in this function is deterministic. + seed_id!() + + # If we have a hand-coded rule, just use that. + is_primitive(C, sig) && return rrule!! + + # Grab code associated to the primal. + ir, _ = lookup_ir(interp, sig) + Treturn = Base.Experimental.compute_ir_rettype(ir) + + # Normalise the IR, and generated BBCode version of it. + isva, spnames = is_vararg_sig_and_sparam_names(sig) + ir = normalise!(ir, spnames) + primal_ir = BBCode(ir) + + # Compute global info. 
+ arg_types = Dict{Argument, Any}( + map(((n, t),) -> (Argument(n) => _get_type(t)), enumerate(ir.argtypes)) + ) + insts = new_inst_vec(ir.stmts) + ssa_types = Dict{ID, NewInstruction}(zip(concatenate_ids(primal_ir), insts)) + arg_tangent_stacks = (map(make_tangent_stack ∘ _get_type, primal_ir.argtypes)..., ) + info = ADInfo(interp, arg_types, ssa_types, arg_tangent_stacks) + + # For each block in the fwds and pullback BBCode, translate all statements. + ad_stmts_blocks = map(primal_ir.blocks) do primal_blk + ids = primal_blk.inst_ids + primal_stmts = map(x -> x.stmt, primal_blk.insts) + return (primal_blk.id, make_ad_stmts!.(primal_stmts, ids, Ref(info))) + end + + # Make shared data, and construct BBCode for forwards-pass and pullback. + shared_data = shared_data_tuple(info.shared_data_pairs) + # display(sig) + # @show length(shared_data) + # @show length(ir.stmts.inst) + # display(collect(_typeof(shared_data).parameters)) + + # If we've already derived the OpaqueClosures and info, do not re-derive, just create a + # copy and pass in new shared data. 
+ if !haskey(interp.oc_cache, sig) + fwds_ir = forwards_pass_ir(primal_ir, ad_stmts_blocks, info, _typeof(shared_data)) + pb_ir = pullback_ir(primal_ir, Treturn, ad_stmts_blocks, info, _typeof(shared_data)) + optimised_fwds_ir = optimise_ir!(IRCode(fwds_ir); do_inline=true) + optimised_pb_ir = optimise_ir!(IRCode(pb_ir); do_inline=true) + # @show length(optimised_fwds_ir.stmts.inst) + # @show length(optimised_pb_ir.stmts.inst) + # display(ir) + # display(optimised_fwds_ir) + # display(optimised_pb_ir) + fwds_oc = OpaqueClosure(optimised_fwds_ir, shared_data...; do_compile=true) + pb_oc = OpaqueClosure(optimised_pb_ir, shared_data...; do_compile=true) + interp.oc_cache[sig] = (fwds_oc, pb_oc) + else + existing_fwds_oc, existing_pb_oc = interp.oc_cache[sig] + fwds_oc = replace_captures(existing_fwds_oc, shared_data) + pb_oc = replace_captures(existing_pb_oc, shared_data) + end + + return rule_type(interp, sig)( + fwds_oc, + pb_oc, + arg_tangent_stacks, + info.block_stack, + Val(isva), + Val(length(ir.argtypes)), + ) +end + +# Given an `OpaqueClosure` `oc`, create a new `OpaqueClosure` of the same type, but with new +# captured variables. This is needed for efficiency reasons -- if `build_rrule` is called +# repeatedly with the same signature and intepreter, it is important to avoid recompiling +# the `OpaqueClosure`s that it produces multiple times, because it can be quite expensive to +# do so. +@eval function replace_captures(oc::Toc, new_captures) where {Toc<:Core.OpaqueClosure} + return $(Expr( + :new, :(Toc), :new_captures, :(oc.world), :(oc.source), :(oc.invoke), :(oc.specptr) + )) +end + +const ADStmts = Vector{Tuple{ID, Vector{ADStmtInfo}}} + +#= + forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + +Produce the IR associated to the `OpaqueClosure` which runs most of the forwards-pass. 
+=# +function forwards_pass_ir(ir::BBCode, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + + # Insert a block at the start which extracts all items from the captures field of the + # `OpaqueClosure`, which contains all of the data shared between the forwards- and + # reverse-passes. These are assigned to the `ID`s given by the `SharedDataPairs`. + # Additionally, push the entry id onto the block stack. + push_block_stack_stmt = Expr(:call, push!, info.block_stack_id, info.entry_id.id) + entry_stmts = vcat( + shared_data_stmts(info.shared_data_pairs), + (ID(), new_inst(push_block_stack_stmt)), + ) + entry_block = BBlock(info.entry_id, entry_stmts) + + # Construct augmented version of each basic block from the primal. For each block: + # 1. pull the translated basic block statements from ad_stmts_blocks. + # 2. insert a statement which logs the ID of the current block to the block stack. + # 3. construct and return a BBlock. + blocks = map(ad_stmts_blocks) do (block_id, ad_stmts) + fwds_stmts = reduce(vcat, map(x -> x.fwds, ad_stmts)) + ins_loc = length(fwds_stmts) + (isa(fwds_stmts[end][2].stmt, Terminator) ? 0 : 1) + ins_stmt = Expr(:call, __push_blk_stack!, info.block_stack_id, block_id.id) + ins_inst = (ID(), new_inst(ins_stmt)) + return BBlock(block_id, insert!(fwds_stmts, ins_loc, ins_inst)) + end + + # Create and return the `BBCode` for the forwards-pass. + arg_types = vcat(Tshared_data, map(register_type ∘ _get_type, ir.argtypes)) + return BBCode(vcat(entry_block, blocks), arg_types, ir.sptypes, ir.linetable, ir.meta) +end + +@noinline __push_blk_stack!(block_stack::Stack{Int32}, id::Int32) = push!(block_stack, id) + +#= + pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + +Produce the IR associated to the `OpaqueClosure` which runs most of the pullback. +=# +function pullback_ir(ir::BBCode, Tret, ad_stmts_blocks::ADStmts, info::ADInfo, Tshared_data) + + # Compute the argument types associated to the reverse-pass. 
+ darg_types = map(tangent_type ∘ _get_type, ir.argtypes) + arg_types = vcat(Tshared_data, tangent_type(Tret), darg_types) + + # Compute the blocks which return in the primal. + primal_exit_blocks_inds = findall(is_reachable_return_node ∘ terminator, ir.blocks) + + # + # Short-circuit for non-terminating primals -- applies to a tiny fraction of primals: + # + + # If there are no blocks which successfully return in the primal, then the primal never + # terminates without throwing, meaning that if AD hits this function, it definitely + # won't succeed on the forwards-pass. As such, the reverse-pass can just be a no-op. + if isempty(primal_exit_blocks_inds) + blocks = [BBlock(ID(), [(ID(), new_inst(ReturnNode(nothing)))])] + return BBCode(blocks, arg_types, ir.sptypes, ir.linetable, ir.meta) + end + + # + # Standard path pullback generation -- applied to 99% of primals: + # + + # Create entry block, which pops the block_stack, and switches to whichever block we + # were in at the end of the forwards-pass. + exit_blocks_ids = map(n -> ir.blocks[n].id, primal_exit_blocks_inds) + data_stmts = shared_data_stmts(info.shared_data_pairs) + switch_stmts = make_switch_stmts(exit_blocks_ids, info) + entry_block = BBlock(ID(), vcat(data_stmts, switch_stmts)) + + # For each basic block in the primal: + # 1. pull the translated basic block statements from ad_stmts_blocks + # 2. reverse the statements + # 3. pop block stack to get the predecessor block + # 4. insert a switch statement to determine which block to jump to. Restrict blocks + # considered to only those which are predecessors of this one. If in the first block, + # check whether or not the block stack is empty. If empty, jump to the exit block. + ps = compute_all_predecessors(ir) + main_blocks = map(ad_stmts_blocks, enumerate(ir.blocks)) do (blk_id, ad_stmts), (n, blk) + rvs_stmts = reduce(vcat, [x.rvs for x in reverse(ad_stmts)]) + pred_ids = vcat(ps[blk.id], n == 1 ? 
[info.entry_id] : ID[]) + switch_stmts = make_switch_stmts(pred_ids, info) + return BBlock(blk_id, vcat(rvs_stmts, switch_stmts)) + end + + # Create an exit block. Simply returns nothing. + exit_block = BBlock(info.entry_id, [(ID(), new_inst(ReturnNode(nothing)))]) + + # Create and return `BBCode` for the pullback. + blks = vcat(entry_block, main_blocks, exit_block) + return _sort_blocks!(BBCode(blks, arg_types, ir.sptypes, ir.linetable, ir.meta)) +end + +#= + make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) + +`preds_ids` comprises the `ID`s associated to all possible predecessor blocks to the primal +block under consideration. Suppose its value is `[ID(1), ID(2), ID(3)]`, then +`make_switch_stmts` emits code along the lines of + +```julia +prev_block = pop!(block_stack) +not_pred_was_1 = !(prev_block == ID(1)) +not_pred_was_2 = !(prev_block == ID(2)) +switch( + not_pred_was_1 => ID(1), + not_pred_was_2 => ID(2), + ID(3) +) +``` + +In words: `make_switch_stmts` emits code which jumps to whichever block preceded the current +block during the forwards-pass. +=# +function make_switch_stmts(pred_ids::Vector{ID}, info::ADInfo) + + # If there are no predecessors, then we can't possible have hit this block. This can + # happen when all of the statements in a block have been eliminated, but the Julia + # optimiser has not removed the block entirely from the `IRCode`. This often presents as + # a block containing only a single `nothing` statement. + # Consequently, we just direct this block back towards the entry node. This is safe, as + # this block will never get hit, and ensures that the block is safe under re-ordering. + isempty(pred_ids) && return [(ID(), new_inst(IDGotoNode(info.entry_id)))] + + # Get the predecessor that we actually had in the primal. + prev_blk_id = ID() + prev_blk = new_inst(Expr(:call, __pop_blk_stack!, info.block_stack_id)) + + # Compare predecessor from primal with all possible predecessors. 
+ conds = map(pred_ids[1:end-1]) do id + return (ID(), new_inst(Expr(:call, __switch_case, id.id, prev_blk_id))) + end + + # Switch statement to change to the predecessor. + switch_stmt = Switch(Any[c[1] for c in conds], pred_ids[1:end-1], pred_ids[end]) + switch = (ID(), new_inst(switch_stmt)) + + return vcat((prev_blk_id, prev_blk), conds, switch) +end + +@noinline __pop_blk_stack!(block_stack::Stack{Int32}) = pop!(block_stack) + +# Helper function emitted by `make_switch_stmts`. +__switch_case(id::Int32, predecessor_id::Int32) = !(id === predecessor_id) + + +#= + DynamicDerivedRule(interp::TapedInterpreter) + +For internal use only. + +A callable data structure which, when invoked, calls an rrule specific to the dynamic types +of its arguments. Stores rules in an internal cache to avoid re-deriving. + +This is used to implement dynamic dispatch. +=# +struct DynamicDerivedRule{T, V} + interp::T + cache::V +end + +DynamicDerivedRule(interp::TapedInterpreter) = DynamicDerivedRule(interp, Dict{Any, Any}()) + +function (dynamic_rule::DynamicDerivedRule)(args::Vararg{Any, N}) where {N} + sig = Tuple{map(_typeof, map(primal, args))...} + is_primitive(context_type(dynamic_rule.interp), sig) && return rrule!!(args...) + rule = get(dynamic_rule.cache, sig, nothing) + if rule === nothing + rule = build_rrule(dynamic_rule.interp, sig) + dynamic_rule.cache[sig] = rule + end + return rule(args...) +end + +#= + LazyDerivedRule(interp, sig) + +For internal use only. + +A type-stable wrapper around a `DerivedRule`, which only instantiates the `DerivedRule` +when it is first called. This is useful, as it means that if a rule does not get run, it +does not have to be derived. 
+=# +mutable struct LazyDerivedRule{Trule, T, V} + interp::T + sig::V + rule::Trule + function LazyDerivedRule(interp::T, sig::V) where {T<:TInterp, V<:Type{<:Tuple}} + return new{rule_type(interp, sig), T, V}(interp, sig) + end +end + +function (rule::LazyDerivedRule)(args::Vararg{Any, N}) where {N} + if !isdefined(rule, :rule) + rule.rule = build_rrule(rule.interp, rule.sig) + end + return rule.rule(args...) +end diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index f1600c1b7..ba48cba2b 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -600,7 +600,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) end, )), - # trmm! + # trsm! vec(reduce( vcat, map( diff --git a/src/rrules/builtins.jl b/src/rrules/builtins.jl index e4f20745c..8a0b77165 100644 --- a/src/rrules/builtins.jl +++ b/src/rrules/builtins.jl @@ -489,6 +489,15 @@ function rrule!!(::CoDual{typeof(Core.ifelse)}, cond, a, b) return ifelse(_cond, a, b), ifelse_pullback!! end +function rrule!!( + ::CoDual{typeof(Core.ifelse)}, + cond, + a::CoDual{<:Any, NoTangent}, + b::CoDual{<:Any, NoTangent}, +) + return ifelse(primal(cond), a, b), NoPullback() +end + # Core.set_binding_type! 
function rrule!!(::CoDual{typeof(Core.sizeof)}, x) @@ -519,6 +528,10 @@ function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual) return y, getfield_pullback end +@inline function rrule!!(::CoDual{typeof(getfield)}, value::CoDual{<:Any, NoTangent}, name::CoDual) + return uninit_codual(getfield(primal(value), primal(name))), NoPullback() +end + function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual, order::CoDual) _name = primal(name) _order = primal(order) @@ -534,6 +547,12 @@ function rrule!!(::CoDual{typeof(getfield)}, value::CoDual, name::CoDual, order: return y, getfield_pullback end +@inline function rrule!!( + ::CoDual{typeof(getfield)}, value::CoDual{<:Any, NoTangent}, name::CoDual, order::CoDual +) + return uninit_codual(getfield(primal(value), primal(name), primal(order))), NoPullback() +end + _get_tangent_field(_, tangent, f...) = getfield(tangent, f...) function _get_tangent_field(_, tangent::Union{Tangent, MutableTangent}, f...) return _value(getfield(tangent.fields, f...)) @@ -585,13 +604,34 @@ function rrule!!(::CoDual{typeof(setfield!)}, value, name, x) return y, setfield!_pullback end +function rrule!!( + ::CoDual{typeof(setfield!)}, value::CoDual{<:Any, NoTangent}, name, x +) + _name = primal(name) + save = isdefined(primal(value), _name) + old_x = save ? getfield(primal(value), _name) : nothing + function setfield!_pullback(dy, df, dvalue, ::NoTangent, dx) + old_x !== nothing && setfield!(primal(value), _name, old_x) + return df, dvalue, NoTangent(), dx + end + y = CoDual(setfield!(primal(value), _name, primal(x)), NoTangent()) + return y, setfield!_pullback +end + # swapfield! # throw -function rrule!!(::CoDual{typeof(tuple)}, args...) - y = CoDual(tuple(map(primal, args)...), tuple(map(tangent, args)...)) - tuple_pullback(dy, ::NoTangent, dargs...) = NoTangent(), map(increment!!, dargs, dy)... - return y, tuple_pullback +@inline function tuple_pullback(dy, ::NoTangent, dargs...) 
+ return NoTangent(), tuple_map(increment!!, dargs, dy)... +end + +function rrule!!(::CoDual{typeof(tuple)}, args::Vararg{Any, N}) where {N} + primal_output = tuple(map(primal, args)...) + if tangent_type(_typeof(primal_output)) == NoTangent + return zero_codual(primal_output), NoPullback() + else + return CoDual(primal_output, tuple(map(tangent, args)...)), tuple_pullback + end end function rrule!!(::CoDual{typeof(typeassert)}, x, type) @@ -754,6 +794,8 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) # Core.get_binding_type -- NEEDS IMPLEMENTING AND TESTING [false, :none, nothing, Core.ifelse, true, randn(5), 1], [false, :none, nothing, Core.ifelse, false, randn(5), 2], + (false, :stability, nothing, Core.ifelse, true, 5, 4), + (false, :stability, nothing, Core.ifelse, false, true, false), [false, :stability, nothing, Core.ifelse, false, 1.0, 2.0], [false, :stability, nothing, Core.ifelse, true, 1.0, 2.0], [false, :stability, nothing, Core.ifelse, false, randn(5), randn(3)], @@ -796,21 +838,27 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) [false, :stability, nothing, applicable, sin, Type], [false, :stability, nothing, applicable, +, Type, Float64], [false, :stability, nothing, applicable, +, Float64, Float64], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, TestResources.StructFoo, :a], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, TestResources.StructFoo, :b], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, TestResources.MutableFoo, :a], - [false, :stability, (lb=1e-3, ub=20.0), fieldtype, TestResources.MutableFoo, :b], - [true, :none, _range, getfield, TestResources.StructFoo(5.0), :a], - [false, :none, _range, getfield, TestResources.StructFoo(5.0, randn(5)), :a], - [false, :none, _range, getfield, TestResources.StructFoo(5.0, randn(5)), :b], - [true, :none, _range, getfield, TestResources.StructFoo(5.0), 1], - [false, :none, _range, getfield, TestResources.StructFoo(5.0, 
randn(5)), 1], - [false, :none, _range, getfield, TestResources.StructFoo(5.0, randn(5)), 2], - [true, :none, _range, getfield, TestResources.MutableFoo(5.0), :a], - [false, :none, _range, getfield, TestResources.MutableFoo(5.0, randn(5)), :b], + [false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :a], + [false, :stability, (lb=1e-3, ub=20.0), fieldtype, StructFoo, :b], + [false, :stability, (lb=1e-3, ub=20.0), fieldtype, MutableFoo, :a], + [false, :stability, (lb=1e-3, ub=20.0), fieldtype, MutableFoo, :b], + [true, :none, _range, getfield, StructFoo(5.0), :a], + [false, :none, _range, getfield, StructFoo(5.0, randn(5)), :a], + [false, :none, _range, getfield, StructFoo(5.0, randn(5)), :b], + [true, :none, _range, getfield, StructFoo(5.0), 1], + [false, :none, _range, getfield, StructFoo(5.0, randn(5)), 1], + [false, :none, _range, getfield, StructFoo(5.0, randn(5)), 2], + [true, :none, _range, getfield, MutableFoo(5.0), :a], + [false, :none, _range, getfield, MutableFoo(5.0, randn(5)), :b], [false, :none, _range, getfield, UnitRange{Int}(5:9), :start], [false, :none, _range, getfield, UnitRange{Int}(5:9), :stop], [false, :none, _range, getfield, (5.0, ), 1, false], + (false, :none, _range, getfield, (1, ), 1, false), + (false, :none, _range, getfield, (1, 2), 1), + (false, :none, _range, getfield, (a=5, b=4), 1), + (false, :none, _range, getfield, (a=5, b=4), 2), + (false, :none, _range, getfield, (a=5.0, b=4), 1), + (false, :none, _range, getfield, (a=5.0, b=4), 2), [false, :none, _range, getfield, UInt8, :name], [false, :none, _range, getfield, UInt8, :super], [true, :none, _range, getfield, UInt8, :layout], @@ -820,38 +868,27 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins}) # invoke -- NEEDS IMPLEMENTING AND TESTING [false, :stability, nothing, isa, 5.0, Float64], [false, :stability, nothing, isa, 1, Float64], - [false, :stability, nothing, isdefined, TestResources.MutableFoo(5.0, randn(5)), :sim], - [false, :stability, 
nothing, isdefined, TestResources.MutableFoo(5.0, randn(5)), :a], + [false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :sim], + [false, :stability, nothing, isdefined, MutableFoo(5.0, randn(5)), :a], # modifyfield! -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, nfields, TestResources.MutableFoo], - [false, :stability, nothing, nfields, TestResources.StructFoo], + [false, :stability, nothing, nfields, MutableFoo], + [false, :stability, nothing, nfields, StructFoo], # replacefield! -- NEEDS IMPLEMENTING AND TESTING - [false, :none, _range, setfield!, TestResources.MutableFoo(5.0, randn(5)), :a, 4.0], - [ - false, - :none, - nothing, - setfield!, - TestResources.MutableFoo(5.0, randn(5)), - :b, - randn(5), - ], - [false, :none, _range, setfield!, TestResources.MutableFoo(5.0, randn(5)), 1, 4.0], - [ - false, - :none, - _range, - setfield!, - TestResources.MutableFoo(5.0, randn(5)), - 2, - randn(5), - ], + (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), :a, 4.0), + (false, :none, nothing, setfield!, MutableFoo(5.0, randn(5)), :b, randn(5)), + (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), 1, 4.0), + (false, :none, _range, setfield!, MutableFoo(5.0, randn(5)), 2, randn(5)), + (false, :stability, _range, setfield!, NonDifferentiableFoo(5, false), 1, 4), + (false, :stability, _range, setfield!, NonDifferentiableFoo(5, true), 2, false), # swapfield! 
-- NEEDS IMPLEMENTING AND TESTING # throw -- NEEDS IMPLEMENTING AND TESTING - [false, :stability, nothing, tuple, 5.0, 4.0], - [false, :stability, nothing, tuple, randn(5), 5.0], - [false, :stability, nothing, tuple, randn(5), randn(4)], - [false, :stability, nothing, tuple, 5.0, randn(1)], + [false, :stability_and_allocs, nothing, tuple, 5.0, 4.0], + [false, :stability_and_allocs, nothing, tuple, randn(5), 5.0], + [false, :stability_and_allocs, nothing, tuple, randn(5), randn(4)], + [false, :stability_and_allocs, nothing, tuple, 5.0, randn(1)], + (false, :stability_and_allocs, nothing, tuple), + (false, :stability_and_allocs, nothing, tuple, 1), + (false, :stability_and_allocs, nothing, tuple, 1, 5), [false, :stability, nothing, typeassert, 5.0, Float64], [false, :stability, nothing, typeassert, randn(5), Vector{Float64}], [false, :stability, nothing, typeof, 5.0], diff --git a/src/rrules/lapack.jl b/src/rrules/lapack.jl index 8bf138363..c036263ab 100644 --- a/src/rrules/lapack.jl +++ b/src/rrules/lapack.jl @@ -15,25 +15,31 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) _INFO::CoDual{$TInt}, # some info of some kind args..., ) - # Extract names. - M, N, LDA, IPIV, INFO = map(primal, (_M, _N, _LDA, _IPIV, _INFO)) - M_val = unsafe_load(M) - N_val = unsafe_load(N) - LDA_val = unsafe_load(LDA) - data_len = LDA_val * N_val - A, dA = primal(_A), tangent(_A) - - @assert M_val === N_val - - # Store the initial state. - A_mat = wrap_ptr_as_view(A, LDA_val, M_val, N_val) - A_store = copy(A_mat) - - # Run the primal. - ccall( - $(blas_name(fname)), Cvoid, ($TInt, $TInt, Ptr{$elty}, $TInt, $TInt, $TInt), - M, N, A, LDA, IPIV, INFO, - ) + GC.@preserve args begin + # Extract names. 
+ M, N, LDA, IPIV, INFO = map(primal, (_M, _N, _LDA, _IPIV, _INFO)) + M_val = unsafe_load(M) + N_val = unsafe_load(N) + LDA_val = unsafe_load(LDA) + data_len = LDA_val * N_val + A, dA = primal(_A), tangent(_A) + + # This implementation is currently limited to square matrices, but should be + # extended when someone can find the time to do so. + @assert M_val === N_val + + # Store the initial state. + A_mat = wrap_ptr_as_view(A, LDA_val, M_val, N_val) + A_store = copy(A_mat) + + # Run the primal. + ccall( + $(blas_name(fname)), Cvoid, ($TInt, $TInt, Ptr{$elty}, $TInt, $TInt, $TInt), + M, N, A, LDA, IPIV, INFO, + ) + + ipiv_vec = copy(unsafe_wrap(Array, IPIV, N_val)) + end # Zero out the tangent. foreach(n -> unsafe_store!(dA, zero($elty), n), 1:data_len) @@ -47,7 +53,7 @@ for (fname, elty) in ((:dgetrf_, :Float64), (:sgetrf_, :Float32)) dL, dU = tril(dA_mat, -1), UpperTriangular(dA_mat) # Figure out the pivot matrix used. - p = LinearAlgebra.ipiv2perm(unsafe_wrap(Array, IPIV, N_val), N_val) + p = LinearAlgebra.ipiv2perm(ipiv_vec, N_val) # Compute pullback using Seth's method. 
__dF = tril(L'dL, -1) + UpperTriangular(dU * U') diff --git a/src/rrules/low_level_maths.jl b/src/rrules/low_level_maths.jl index d80176c8f..7c358e073 100644 --- a/src/rrules/low_level_maths.jl +++ b/src/rrules/low_level_maths.jl @@ -5,6 +5,7 @@ for (M, f, arity) in DiffRules.diffrules(; filter_modules=nothing) continue # Skip rules for methods not defined in the current scope end (f == :rem2pi || f == :ldexp) && continue # not designed for Float64s + (f == :+ || f == :*) && continue # use intrinsics instead if arity == 1 dx = DiffRules.diffrule(M, f, :x) pb_name = Symbol("$(M).$(f)_pb!!") @@ -60,6 +61,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:low_level_mat end arity > 2 && return (f == :rem2pi || f == :ldexp || f == :(^)) && return + (f == :+ || f == :*) && return # use intrinsics instead f = @eval $M.$f push!(test_cases, Any[false, :stability, nothing, f, rand_inputs(rng, f, arity)...]) end diff --git a/src/rrules/misc.jl b/src/rrules/misc.jl index b37b592b5..226f60e65 100644 --- a/src/rrules/misc.jl +++ b/src/rrules/misc.jl @@ -26,6 +26,8 @@ for name in [ :(LinearAlgebra.chkstride1), :(Threads.nthreads), :(Base.depwarn), + :(Base.reduced_indices), + :(Base.check_reducedims), ] @eval @is_primitive DefaultCtx Tuple{typeof($name), Vararg} @eval function rrule!!(::CoDual{_typeof($name)}, args::CoDual...) @@ -64,15 +66,35 @@ function rrule!!(::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{Val{f}}) where return y, lgetfield_pb!! end +# Specialise for non-differentiable arguments. 
+function rrule!!( + ::CoDual{typeof(lgetfield)}, x::CoDual{<:Any, NoTangent}, ::CoDual{Val{f}} +) where {f} + return uninit_codual(getfield(primal(x), f)), NoPullback() +end + lgetfield(x, ::Val{f}, ::Val{order}) where {f, order} = getfield(x, f, order) @is_primitive MinimalCtx Tuple{typeof(lgetfield), Any, Any, Any} -function rrule!!(::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{Val{f}}, ::CoDual{Val{order}}) where {f, order} - lgetfield_pb!!(dy, df, dx, dsym, dorder) = df, increment_field!!(dx, dy, Val{f}()), dsym, dorder +function rrule!!( + ::CoDual{typeof(lgetfield)}, x::CoDual, ::CoDual{Val{f}}, ::CoDual{Val{order}} +) where {f, order} + function lgetfield_pb!!(dy, df, dx, dsym, dorder) + return df, increment_field!!(dx, dy, Val{f}()), dsym, dorder + end y = CoDual(getfield(primal(x), f), _get_tangent_field(primal(x), tangent(x), f)) return y, lgetfield_pb!! end +function rrule!!( + ::CoDual{typeof(lgetfield)}, + x::CoDual{<:Any, NoTangent}, + ::CoDual{Val{f}}, + ::CoDual{Val{order}}, +) where {f, order} + return uninit_codual(getfield(primal(x), f)), NoPullback() +end + """ lsetfield!(value, name::Val, x, [order::Val]) @@ -105,6 +127,22 @@ function rrule!!( return y, setfield!_pullback end +function rrule!!( + ::CoDual{typeof(lsetfield!)}, + value::CoDual{<:Any, NoTangent}, + ::CoDual{Val{name}}, + x::CoDual, +) where {name} + save = isdefined(primal(value), name) + old_x = save ? getfield(primal(value), name) : nothing + function setfield!_pullback(dy, df, dvalue, dname, dx) + old_x !== nothing && lsetfield!(primal(value), Val(name), old_x) + return df, dvalue, dname, dx + end + y = CoDual(lsetfield!(primal(value), Val(name), primal(x)), NoTangent()) + return y, setfield!_pullback +end + function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) # Data which needs to not be GC'd. @@ -115,9 +153,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) test_cases = Any[ # Rules to avoid pointer type conversions. 
( - true, - :stability, - nothing, + true, :stability, nothing, +, CoDual( bitcast(Ptr{Float64}, pointer_from_objref(_x)), @@ -134,8 +170,14 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) # Performance-rules that would ideally be completely removed. (false, :stability_and_allocs, nothing, size, randn(5, 4)), - (false, :stability_and_allocs, nothing, LinearAlgebra.lapack_size, 'N', randn(5, 4)), - (false, :stability_and_allocs, nothing, Base.require_one_based_indexing, randn(2, 3), randn(2, 1)), + ( + false, :stability_and_allocs, nothing, + LinearAlgebra.lapack_size, 'N', randn(5, 4), + ), + ( + false, :stability_and_allocs, nothing, + Base.require_one_based_indexing, randn(2, 3), randn(2, 1), + ), (false, :stability_and_allocs, nothing, in, 5.0, randn(4)), (false, :stability_and_allocs, nothing, iszero, 5.0), (false, :stability_and_allocs, nothing, isempty, randn(5)), @@ -143,36 +185,41 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:misc}) (false, :stability_and_allocs, nothing, sizeof, Float64), (false, :stability_and_allocs, nothing, promote_type, Float64, Float64), (false, :stability_and_allocs, nothing, LinearAlgebra.chkstride1, randn(3, 3)), - (false, :stability_and_allocs, nothing, LinearAlgebra.chkstride1, randn(3, 3), randn(2, 2)), + ( + false, :stability_and_allocs, nothing, + LinearAlgebra.chkstride1, randn(3, 3), randn(2, 2), + ), (false, :allocs, nothing, Threads.nthreads), # Literal replacements for getfield. 
(false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(1)), (false, :stability_and_allocs, nothing, lgetfield, (5.0, 4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, (1, 4), Val(2)), + (false, :stability_and_allocs, nothing, lgetfield, ((), 4), Val(2)), (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(1)), (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(2)), (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:a)), (false, :stability_and_allocs, nothing, lgetfield, (a=5.0, b=4), Val(:b)), + (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:start)), + (false, :stability_and_allocs, nothing, lgetfield, 1:5, Val(:stop)), # Literal replacement for setfield!. ( - false, - :stability_and_allocs, - nothing, - lsetfield!, - TestResources.MutableFoo(5.0, [1.0, 2.0]), - Val(:a), - 4.0, + false, :stability_and_allocs, nothing, + lsetfield!, MutableFoo(5.0, [1.0, 2.0]), Val(:a), 4.0, ), ( - false, - :stability_and_allocs, - nothing, - lsetfield!, - TestResources.FullyInitMutableStruct(5.0, [1.0, 2.0]), - Val(:y), - [1.0, 3.0, 4.0], + false, :stability_and_allocs, nothing, + lsetfield!, FullyInitMutableStruct(5.0, [1.0, 2.0]), Val(:y), [1.0, 3.0, 4.0], ), + ( + false, :stability_and_allocs, nothing, + lsetfield!, NonDifferentiableFoo(5, false), Val(:x), 4, + ), + ( + false, :stability_and_allocs, nothing, + lsetfield!, NonDifferentiableFoo(5, false), Val(:y), true, + ) ] return test_cases, memory end diff --git a/src/rrules/new.jl b/src/rrules/new.jl index a4bfd2302..f2497011e 100644 --- a/src/rrules/new.jl +++ b/src/rrules/new.jl @@ -10,11 +10,15 @@ for N in 0:32 ) return d_new_, d_T, map(increment!!, dx, Tuple(dy))... end + @eval function _new_pullback!!(::NoTangent, d_new_, d_T, dx::Vararg{Any, $N}) + return d_new_, NoTangent(), dx... 
+ end @eval function rrule!!( ::CoDual{typeof(_new_)}, ::CoDual{Type{P}}, x::Vararg{CoDual, $N} ) where {P} y = $(Expr(:new, :P, map(n -> :(primal(x[$n])), 1:N)...)) - dy = build_tangent(P, tuple_map(tangent, x)...) + T = tangent_type(P) + dy = T == NoTangent ? NoTangent() : build_tangent(P, tuple_map(tangent, x)...) return CoDual(y, dy), _new_pullback!! end end @@ -23,22 +27,30 @@ end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:new}) test_cases = Any[ - (false, :stability, nothing, _new_, Tuple{Float64, Int}, 5.0, 4), - (false, :stability, nothing, _new_, Tuple{Float64, Float64}, 5.0, 4.0), - (false, :stability, nothing, _new_, @NamedTuple{y::Float64}, 5.0), - (false, :stability, nothing, _new_, @NamedTuple{y::Float64, x::Int}, 5.0, 4), + (false, :stability_and_allocs, nothing, _new_, Tuple{}), + (false, :stability_and_allocs, nothing, _new_, Tuple{Float64, Int}, 5.0, 4), + (false, :stability_and_allocs, nothing, _new_, Tuple{Float64, Float64}, 5.0, 4.0), + (false, :stability_and_allocs, nothing, _new_, Tuple{Int, Int}, 5, 5), + (false, :stability_and_allocs, nothing, _new_, @NamedTuple{}), + (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Float64}, 5.0), + ( + false, :stability_and_allocs, nothing, + _new_, @NamedTuple{y::Float64, x::Int}, 5.0, 4, + ), + (false, :stability_and_allocs, nothing, _new_, @NamedTuple{y::Int, x::Int}, 5, 4), ( - false, :stability, nothing, + false, :stability_and_allocs, nothing, _new_, TestResources.TypeStableStruct{Float64}, 5, 4.0, ), ( - false, :stability, nothing, + false, :stability_and_allocs, nothing, _new_, TestResources.TypeStableMutableStruct{Float64}, 5.0, 4.0, ), ( false, :none, nothing, _new_, TestResources.TypeStableMutableStruct{Any}, 5.0, 4.0, ), + (false, :stability_and_allocs, nothing, _new_, UnitRange{Int64}, 5, 4), ] memory = Any[] return test_cases, memory diff --git a/src/stack.jl b/src/stack.jl index 7e2737a80..4d29a2682 100644 --- a/src/stack.jl +++ b/src/stack.jl @@ -6,7 
+6,7 @@ A stack specialised for reverse-mode AD. Semantically equivalent to a usual stack, but never de-allocates memory once allocated. """ mutable struct Stack{T} - memory::Vector{T} + const memory::Vector{T} position::Int Stack{T}() where {T} = new{T}(Vector{T}(undef, 0), 0) end @@ -19,7 +19,7 @@ end Stack(x::T) where {T} = Stack{T}(x) -function Base.push!(x::Stack{T}, val::T) where {T} +@inline function Base.push!(x::Stack{T}, val::T) where {T} position = x.position + 1 memory = x.memory x.position = position @@ -27,12 +27,12 @@ function Base.push!(x::Stack{T}, val::T) where {T} @inbounds memory[position] = val return nothing else - push!(memory, val) + @noinline push!(memory, val) return nothing end end -function Base.pop!(x::Stack) +@inline function Base.pop!(x::Stack) position = x.position val = x.memory[position] x.position = position - 1 @@ -62,7 +62,7 @@ end Base.eltype(::Stack{T}) where {T} = T -top_ref(x::Stack) = Ref(x.memory, x.position) +top_ref(x::Stack) = Ref(getfield(x, :memory), getfield(x, :position)) """ NoTangentStack() @@ -118,3 +118,30 @@ function tangent_ref_type_ub(::Type{P}) where {P} end tangent_ref_type_ub(::Type{Type{P}}) where {P} = NoTangentRef + +struct InactiveStack{T} + zero_tangent::T +end + +Base.pop!(s::InactiveStack{T}) where {T} = s.zero_tangent + +struct InactiveRef{T} + x::T +end + +Base.getindex(x::InactiveRef{T}) where {T} = x.x + +increment_ref!(::InactiveRef{T}, ::T) where {T} = nothing + +top_ref(::Nothing) = InactiveRef(nothing) + +increment_ref!(::InactiveRef{Nothing}, ::T) where {T} = nothing + + +struct FixedStackTangentRefStack{T} + x::Stack{T} +end + +Base.push!(x::FixedStackTangentRefStack, t) = nothing + +Base.pop!(x::FixedStackTangentRefStack) = top_ref(x.x) diff --git a/src/tangents.jl b/src/tangents.jl index 8fef76147..c8b740bf3 100644 --- a/src/tangents.jl +++ b/src/tangents.jl @@ -128,7 +128,19 @@ end end function build_tangent(::Type{P}, fields::Vararg{Any, N}) where {P<:Union{Tuple, NamedTuple}, N} - 
return tangent_type(P)(fields) + T = tangent_type(P) + if T == NoTangent + return NoTangent() + elseif isconcretetype(P) + return T(fields) + else + return __tangent_from_non_concrete(P, fields) + end +end + +__tangent_from_non_concrete(::Type{P}, fields) where {P<:Tuple} = Tuple(fields) +function __tangent_from_non_concrete(::Type{P}, fields) where {names, P<:NamedTuple{names}} + return NamedTuple{names}(fields) end _value(v::PossiblyUninitTangent) = val(v) @@ -195,12 +207,28 @@ tangent_type(::Type{Core.TypeName}) = NoTangent tangent_type(::Type{Core.MethodTable}) = NoTangent -@generated function tangent_type(::Type{T}) where {T<:Tuple} - return isconcretetype(T) ? Tuple{map(tangent_type, fieldtypes(T))...} : Tuple +@generated function tangent_type(::Type{P}) where {P<:Tuple} + isa(P, Union) && return Union{tangent_type(P.a), tangent_type(P.b)} + isempty(P.parameters) && return NoTangent + isa(last(P.parameters), Core.TypeofVararg) && return Any + all(p -> tangent_type(p) == NoTangent, P.parameters) && return NoTangent + return Tuple{map(tangent_type, fieldtypes(P))...} end @generated function tangent_type(::Type{NamedTuple{N, T}}) where {N, T<:Tuple} - return NamedTuple{N, tangent_type(T)} + if tangent_type(T) == NoTangent + return NoTangent + elseif isconcretetype(tangent_type(T)) + return NamedTuple{N, tangent_type(T)} + else + return Any + end +end + +function backing_type(::Type{P}) where {P} + tangent_field_types = map(n -> tangent_field_type(P, n), 1:fieldcount(P)) + all(==(NoTangent), tangent_field_types) && return NoTangent + return NamedTuple{fieldnames(P), Tuple{tangent_field_types...}} end @generated function tangent_type(::Type{P}) where {P} @@ -222,12 +250,8 @@ end Base.issingletontype(P) && return NoTangent # Derive tangent type. - return (ismutabletype(P) ? 
MutableTangent : Tangent){backing_type(P)} -end - -function backing_type(::Type{P}) where {P} - tangent_field_types = map(n -> tangent_field_type(P, n), 1:fieldcount(P)) - return NamedTuple{fieldnames(P), Tuple{tangent_field_types...}} + bt = backing_type(P) + return bt == NoTangent ? bt : (ismutabletype(P) ? MutableTangent : Tangent){bt} end """ @@ -312,7 +336,9 @@ end @inline function zero_tangent(x::Array{P, N}) where {P, N} return _map_if_assigned!(zero_tangent, Array{tangent_type(P), N}(undef, size(x)...), x) end -@inline zero_tangent(x::Union{Tuple, NamedTuple}) = map(zero_tangent, x) +@inline function zero_tangent(x::P) where {P<:Union{Tuple, NamedTuple}} + return tangent_type(P) == NoTangent ? NoTangent() : map(zero_tangent, x) +end @generated function zero_tangent(x::P) where {P} tangent_type(P) == NoTangent && return NoTangent() @@ -344,8 +370,8 @@ end Related to `zero_tangent`, but a bit different. Check current implementation for details -- this docstring is intentionally non-specific in order to avoid becoming outdated. """ -uninit_tangent(x) = zero_tangent(x) -uninit_tangent(x::Ptr{P}) where {P} = bitcast(Ptr{tangent_type(P)}, x) +@inline uninit_tangent(x) = zero_tangent(x) +@inline uninit_tangent(x::Ptr{P}) where {P} = bitcast(Ptr{tangent_type(P)}, x) """ randn_tangent(rng::AbstractRNG, x::T) where {T} @@ -364,8 +390,8 @@ function randn_tangent(rng::AbstractRNG, x::SimpleVector) return randn_tangent(rng, x[n]) end end -function randn_tangent(rng::AbstractRNG, x::Union{Tuple, NamedTuple}) - return map(x -> randn_tangent(rng, x), x) +function randn_tangent(rng::AbstractRNG, x::P) where {P <: Union{Tuple, NamedTuple}} + return tangent_type(P) == NoTangent ? 
NoTangent() : map(x -> randn_tangent(rng, x), x) end function randn_tangent(rng::AbstractRNG, x::T) where {T<:Union{Tangent, MutableTangent}} return T(randn_tangent(rng, x.fields)) @@ -604,7 +630,9 @@ end function _diff(p::P, q::P) where {P<:SimpleVector} return Any[_diff(a, b) for (a, b) in zip(p, q)] end -_diff(p::P, q::P) where {P<:Union{Tuple, NamedTuple}} = _map(_diff, p, q) +function _diff(p::P, q::P) where {P<:Union{Tuple, NamedTuple}} + return tangent_type(P) == NoTangent ? NoTangent() : _map(_diff, p, q) +end function _containerlike_diff(p::P, q::P) where {P} diffed_fields = map(fieldnames(P)) do f diff --git a/src/test_utils.jl b/src/test_utils.jl index 50be2804f..5e3e0a7d0 100644 --- a/src/test_utils.jl +++ b/src/test_utils.jl @@ -176,6 +176,7 @@ function populate_address_map!(m::AddressMap, primal::P, tangent::T) where {P, T end function populate_address_map!(m::AddressMap, p::P, t) where {P<:Union{Tuple, NamedTuple}} + t isa NoTangent && return m foreach(n -> populate_address_map!(m, getfield(p, n), getfield(t, n)), fieldnames(P)) return m end @@ -227,12 +228,12 @@ function test_rrule_numerical_correctness(rng::AbstractRNG, f_f̄, x_x̄...; rul y_primal = f(x_primal...) # Use finite differences to estimate vjps - ẋ = randn_tangent(rng, x) + ẋ = map(_x -> randn_tangent(rng, _x), x) ε = 1e-7 x′ = _add_to_primal(x, _scale(ε, ẋ)) y′ = f(x′...) ẏ = _scale(1 / ε, _diff(y′, y_primal)) - ẋ_post = _scale(1 / ε, _diff(x′, x_primal)) + ẋ_post = map((_x′, _x_p) -> _scale(1 / ε, _diff(_x′, _x_p)), x′, x_primal) # Run `rrule!!` on copies of `f` and `x`. We use randomly generated tangents so that we # can later verify that non-zero values do not get propagated by the rule. @@ -273,7 +274,7 @@ get_address(x) = ismutable(x) ? 
pointer_from_objref(x) : nothing _deepcopy(x) = deepcopy(x) _deepcopy(x::Module) = x -rrule_output_type(::Type{Ty}) where {Ty} = Tuple{CoDual{Ty, tangent_type(Ty)}, Any} +rrule_output_type(::Type{Ty}) where {Ty} = Tuple{codual_type(Ty), Any} function test_rrule_interface(f_f̄, x_x̄...; is_primitive, ctx::C, rule) where {C} @nospecialize f_f̄ x_x̄ @@ -444,6 +445,11 @@ function test_interpreted_rrule!!(rng::AbstractRNG, x...; interp, kwargs...) test_rrule!!(rng, in_f, x...; rule, kwargs...) end +function test_derived_rule(rng::AbstractRNG, x...; interp, kwargs...) + rule = Taped.build_rrule(interp, _typeof(__get_primals(x))) + test_rrule!!(rng, x...; rule, kwargs...) +end + # # Test that some basic operations work on a given type. # @@ -883,7 +889,7 @@ function run_derived_rrule!!_test_cases(rng_ctor, v::Val) test_cases, memory = Taped.generate_derived_rrule!!_test_cases(rng_ctor, v) GC.@preserve memory @testset "$f, $(typeof(x))" for (interface_only, perf_flag, _, f, x...) in test_cases - test_interpreted_rrule!!( + test_derived_rule( rng_ctor(123), f, x...; interp, interface_only, perf_flag, is_primitive=false ) end @@ -921,8 +927,7 @@ Optionally, an interpreter may be provided via the `interp` kwarg. See also: `Taped.TestUtils.value_and_gradient!!`. """ function set_up_gradient_problem(fargs...; interp=Taped.TInterp()) - primals = map(x -> x isa CoDual ? primal(x) : x, fargs) - sig = _typeof(primals) + sig = _typeof(__get_primals(fargs)) if Taped.is_primitive(DefaultCtx, sig) return rrule!!, Taped._eval else @@ -931,63 +936,7 @@ function set_up_gradient_problem(fargs...; interp=Taped.TInterp()) end end -""" - value_and_gradient!!(rule, in_f::CoDual{<:InterpretedFunction}, f::CoDual, x::CoDual...) - -In-place version of `value_and_gradient!!` in which the arguments have been wrapped in -`CoDual`s. Note that any mutable data in `f` and `x` will be incremented in-place. 
As such, -if calling this function multiple times with different values of `x`, should be careful to -ensure that you zero-out the tangent fields of `x` each time. See -`Taped.TestUtils.zero_out_arguments!` for more details. -""" -function value_and_gradient!!( - rule::R, in_f::CoDual{<:Taped.InterpretedFunction}, codual_fargs::Vararg{CoDual, N} -) where {R, N} - out, pb!! = rule(in_f, codual_fargs...) - @assert out isa CoDual{Float64, Float64} - return primal(out), pb!!(1.0, NoTangent(), map(tangent, codual_fargs)...)[2:end] -end - -""" - value_and_gradient!!(rule, in_f::InterpretedFunction, f, args...) - -Compute the value and gradient of `f(args...)`. - -`rule` and `in_f` should be constructed using `set_up_gradient_problem`. - -*Note:* If calling `value_and_gradient!!` multiple times for various values of `args`, you -should use the same `rule` and `in_f` each time, as there is no need to re-build them each -time. - -*Note:* It is your responsibility to ensure that there is no aliasing in `f` and `args`. -For example, -```julia -X = randn(5, 5) -rule, in_f = set_up_gradient_problem(*, X, X) -value_and_gradient!!(rule, in_f, *, X, X) -``` -will yield the wrong result. - -*Note:* This method of `value_and_gradient!!` has to first call `zero_codual` on all of its -arguments. This may cause some additional allocations. If this is a problem in your -use-case, consider pre-allocating the `CoDual`s and calling the other method of this -function. See `Taped.TestUtils.zero_out_arguments!!` for a helper function that you may find -useful if setting up your code this way. -""" -function value_and_gradient!!( - rule::R, in_f::Taped.InterpretedFunction, fargs::Vararg{Any, N} -) where {R, N} - return value_and_gradient!!(rule, zero_codual(in_f), map(zero_codual, fargs)...) -end - -""" - zero_out_arguments!!(codualed_fargs::CoDual...) - -Set the tangent component of a collection of `CoDuals` to zero. Mutates where possible. 
-""" -function zero_out_arguments!!(codualed_fargs::CoDual...) - return map(x -> CoDual(primal(x), Taped.set_to_zero!!(tangent(x))), codualed_fargs) -end +__get_primals(xs) = map(x -> x isa CoDual ? primal(x) : x, xs) end @@ -1041,6 +990,11 @@ end Base.:(==)(a::MutableFoo, b::MutableFoo) = equal_field(a, b, :a) && equal_field(a, b, :b) +mutable struct NonDifferentiableFoo + x::Int + y::Bool +end + mutable struct TypeStableMutableStruct{T} a::Float64 b::T @@ -1083,6 +1037,15 @@ end return z end +# A function in which everything is non-differentiable and has no branching. Ideally, the +# reverse-pass of this function would be a no-op, and there would be no use of the block +# stack anywhere. +function non_differentiable_foo(x::Int) + y = 5x + z = y + x + return 10z +end + function bar(x, y) x1 = sin(x) x2 = cos(y) @@ -1092,6 +1055,8 @@ function bar(x, y) return x5 end +const_tester_non_differentiable() = 1 + const_tester() = cos(5.0) intrinsic_tester(x) = 5x @@ -1122,9 +1087,12 @@ end type_stable_getfield_tester_1(x::StableFoo) = x.x type_stable_getfield_tester_2(x::StableFoo) = x.y -__x_for_gref_test = 5.0 +const __x_for_gref_test = 5.0 @eval globalref_tester() = $(GlobalRef(@__MODULE__, :__x_for_gref_test)) +const __y_for_gref_test = false +@eval globalref_tester_bool() = $(GlobalRef(@__MODULE__, :__y_for_gref_test)) + function globalref_tester_2(use_gref::Bool) v = use_gref ? 
__x_for_gref_test : 1 return sin(v) @@ -1133,9 +1101,14 @@ end const __x_for_gref_tester_3 = 5.0 @eval globalref_tester_3() = $(GlobalRef(@__MODULE__, :__x_for_gref_tester_3)) -__x_for_gref_tester_4::Float64 = 3.0 +const __x_for_gref_tester_4::Float64 = 3.0 @eval globalref_tester_4() = $(GlobalRef(@__MODULE__, :__x_for_gref_tester_4)) +__x_for_gref_tester_5 = 5.0 +@eval globalref_tester_5() = $(GlobalRef(@__MODULE__, :__x_for_gref_tester_5)) + +type_unstable_tester_0(x::Ref{Any}) = x[] + type_unstable_tester(x::Ref{Any}) = cos(x[]) type_unstable_tester_2(x::Ref{Real}) = cos(x[]) @@ -1293,7 +1266,7 @@ function test_while_loop(x) end function test_for_loop(x) - for _ in 1:5 + for _ in 1:500 x = sin(x) end return x @@ -1421,15 +1394,33 @@ function _sum(x) return z end +function test_handwritten_sum(x::AbstractArray{<:Real}) + y = 0.0 + n = 0 + while n < length(x) + n += 1 + y += x[n] + end + return y +end + +function test_map(x::Vector{Float64}, y::Vector{Float64}) + return map((x, y) -> sin(cos(exp(x)) + exp(y) * sin(y)), x, y) +end + +test_getfield_of_tuple_of_types(n::Int) = getfield((Float64, Float64), n) + function generate_test_functions() return Any[ (false, :allocs, nothing, const_tester), + (false, :allocs, nothing, const_tester_non_differentiable), (false, :allocs, nothing, identity, 5.0), (false, :allocs, nothing, foo, 5.0), + (false, :allocs, nothing, non_differentiable_foo, 5), (false, :allocs, nothing, bar, 5.0, 4.0), (false, :none, nothing, type_unstable_argument_eval, sin, 5.0), - (false, :none, (lb=1, ub=500), pi_node_tester, Ref{Any}(5.0)), - (false, :none, (lb=1, ub=500), pi_node_tester, Ref{Any}(5)), + (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=1_000), pi_node_tester, Ref{Any}(5)), (false, :allocs, nothing, intrinsic_tester, 5.0), (false, :allocs, nothing, goto_tester, 5.0), (false, :allocs, nothing, new_tester, 5.0, :hello), @@ -1438,14 +1429,17 @@ function generate_test_functions() (false, :allocs, 
nothing, type_stable_getfield_tester_1, StableFoo(5.0, :hi)), (false, :allocs, nothing, type_stable_getfield_tester_2, StableFoo(5.0, :hi)), (false, :none, nothing, globalref_tester), - # (false, :stability, nothing, globalref_tester_2, true), - # (false, :stability, nothing, globalref_tester_2, false), + (false, :none, nothing, globalref_tester_bool), + (false, :none, nothing, globalref_tester_2, true), + (false, :none, nothing, globalref_tester_2, false), (false, :allocs, nothing, globalref_tester_3), (false, :allocs, nothing, globalref_tester_4), + (false, :none, (lb=1, ub=500), globalref_tester_5), + (false, :none, (lb=1, ub=1_000), type_unstable_tester_0, Ref{Any}(5.0)), (false, :none, nothing, type_unstable_tester, Ref{Any}(5.0)), (false, :none, nothing, type_unstable_tester_2, Ref{Real}(5.0)), (false, :none, (lb=1, ub=1000), type_unstable_tester_3, Ref{Any}(5.0)), - (false, :none, (lb=1, ub=1000), test_primitive_dynamic_dispatch, Any[5.0, false]), + (false, :none, (lb=1, ub=10_000), test_primitive_dynamic_dispatch, Any[5.0, false]), (false, :none, nothing, type_unstable_function_eval, Ref{Any}(sin), 5.0), (false, :allocs, nothing, phi_const_bool_tester, 5.0), (false, :allocs, nothing, phi_const_bool_tester, -5.0), @@ -1461,7 +1455,7 @@ function generate_test_functions() randn(5), 1, Base.Slice(Base.OneTo(1)), - ), # fun PhiNode example to do with not assigning values + ), # fun PhiNode example (false, :allocs, nothing, avoid_throwing_path_tester, 5.0), (false, :allocs, nothing, simple_foreigncall_tester, randn(5)), (false, :none, nothing, simple_foreigncall_tester_2, randn(6), (2, 3)), @@ -1485,24 +1479,19 @@ function generate_test_functions() # (false, :stability, nothing, unstable_splatting_tester, Ref{Any}(5.0)), # known failure case -- no rrule for _apply_iterate # (false, :stability, nothing, unstable_splatting_tester, Ref{Any}((5.0, 4.0))), # known failure case -- no rrule for _apply_iterate # (false, :stability, nothing, unstable_splatting_tester, 
Ref{Any}((5.0, 4.0, 3.0))), # known failure case -- no rrule for _apply_iterate - (false, :none, nothing, inferred_const_tester, Ref{Any}(nothing)), + (false, :none, (lb=1, ub=1_000), inferred_const_tester, Ref{Any}(nothing)), (false, :none, (lb=1, ub=1_000), datatype_slot_tester, 1), (false, :none, (lb=1, ub=1_000), datatype_slot_tester, 2), (false, :none, (lb=1, ub=100_000_000), test_union_of_arrays, randn(5), true), ( - false, - :none, - nothing, - test_union_of_types, - Ref{Union{Type{Float64}, Type{Int}}}(Float64), + false, :none, (lb=1, ub=500), + test_union_of_types, Ref{Union{Type{Float64}, Type{Int}}}(Float64), ), (false, :allocs, nothing, test_self_reference, 1.1, 1.5), (false, :allocs, nothing, test_self_reference, 1.5, 1.1), (false, :none, nothing, test_recursive_sum, randn(2)), ( - false, - :none, - (lb=1, ub=1_000), + false, :none, (lb=1, ub=1_000), LinearAlgebra._modify!, LinearAlgebra.MulAddMul(5.0, 4.0), 5.0, @@ -1512,26 +1501,15 @@ function generate_test_functions() (false, :allocs, nothing, getfield_tester, (5.0, 5)), (false, :allocs, nothing, getfield_tester_2, (5.0, 5)), ( - false, - :allocs, - nothing, - setfield_tester_left!, - FullyInitMutableStruct(5.0, randn(3)), - 4.0, + false, :allocs, nothing, + setfield_tester_left!, FullyInitMutableStruct(5.0, randn(3)), 4.0, ), ( - false, - :none, - nothing, - setfield_tester_right!, - FullyInitMutableStruct(5.0, randn(3)), - randn(5), + false, :none, nothing, + setfield_tester_right!, FullyInitMutableStruct(5.0, randn(3)), randn(5), ), - ( - false, :none, (lb=100, ub=100_000_000), - mul!, transpose(randn(3, 5)), randn(5, 5), randn(5, 3), 4.0, 3.0, - ), # static_parameter, - (false, :none, (lb=100, ub=100_000_000), Xoshiro, 123456), + (false, :none, nothing, mul!, randn(3, 5)', randn(5, 5), randn(5, 3), 4.0, 3.0), + (false, :none, nothing, Xoshiro, 123456), (false, :none, (lb=1, ub=100_000), *, randn(250, 500), randn(500, 250)), (false, :allocs, nothing, test_sin, 1.0), (false, :allocs, nothing, 
test_cos_sin, 2.0), @@ -1543,8 +1521,8 @@ function generate_test_functions() (false, :allocs, nothing, test_isbits_multiple_usage_phi, false, 1.1), (false, :allocs, nothing, test_isbits_multiple_usage_phi, true, 1.1), (false, :allocs, nothing, test_multiple_call_non_primitive, 5.0), - (false, :none, (lb=1, ub=500), test_multiple_pi_nodes, Ref{Any}(5.0)), - (false, :none, (lb=1, ub=500), test_multi_use_pi_node, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=1500), test_multiple_pi_nodes, Ref{Any}(5.0)), + (false, :none, (lb=1, ub=1500), test_multi_use_pi_node, Ref{Any}(5.0)), (false, :allocs, nothing, test_getindex, [1.0, 2.0]), (false, :allocs, nothing, test_mutation!, [1.0, 2.0]), (false, :allocs, nothing, test_while_loop, 2.0), @@ -1556,47 +1534,37 @@ function generate_test_functions() (false, :none, nothing, test_struct_partial_init, 3.5), (false, :none, nothing, test_mutable_partial_init, 3.3), ( - false, - :allocs, - (lb=100, ub=2_000), + false, :allocs, nothing, test_naive_mat_mul!, randn(100, 50), randn(100, 30), randn(30, 50), ), ( - false, - :allocs, - (lb=100, ub=2_000), + false, :allocs, nothing, (A, C) -> test_naive_mat_mul!(C, A, A), randn(100, 100), randn(100, 100), ), (false, :allocs, (lb=10, ub=1_000), sum, randn(30)), - (false, :none, (lb=100, ub=10_000), test_diagonal_to_matrix, Diagonal(randn(30))), + (false, :none, (lb=10, ub=1_000), test_diagonal_to_matrix, Diagonal(randn(30))), ( - false, - :allocs, - (lb=100, ub=5_000), + false, :allocs, (lb=100, ub=1_000), ldiv!, randn(20, 20), Diagonal(rand(20) .+ 1), randn(20, 20), ), ( - false, - :allocs, - (lb=100, ub=10_000), - kron!, randn(400, 400), randn(20, 20), randn(20, 20), + false, :allocs, (lb=10, ub=500), + LinearAlgebra._kron!, randn(400, 400), randn(20, 20), randn(20, 20), ), ( - false, - :allocs, - (lb=100, ub=10_000), + false, :allocs, (lb=10, ub=500), kron!, randn(400, 400), Diagonal(randn(20)), randn(20, 20), ), ( - false, - :none, - nothing, + false, :none, nothing, test_mlp, randn(sr(1), 
500, 200), randn(sr(2), 700, 500), randn(sr(3), 300, 700), ), + (false, :allocs, (lb=1.0, ub=150), test_handwritten_sum, randn(1024 * 1024)), (false, :none, nothing, _sum, randn(1024)), + (false, :none, nothing, test_map, randn(1024), randn(1024)), ] end @@ -1703,8 +1671,12 @@ const DIFFTESTS_FUNCTIONS = vcat( ), ) +export MutableFoo, StructFoo, NonDifferentiableFoo, FullyInitMutableStruct + end function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:test_utils}) return TestResources.generate_test_functions(), Any[] end + +using .TestResources diff --git a/src/utils.jl b/src/utils.jl index 41dbfc2ea..07d37c978 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,3 +6,25 @@ Central definition of typeof, which is specific to the use-required in this pack _typeof(x) = Base._stable_typeof(x) _typeof(x::Tuple) = Tuple{map(_typeof, x)...} _typeof(x::NamedTuple{names}) where {names} = NamedTuple{names, _typeof(Tuple(x))} + +""" + tuple_map(f::F, x::Tuple) where {F} + +This function is semantically equivalent to `map(f, x)`, but always specialises on all of +the element types of `x`, regardless of the length of `x`. This contrasts with `map`, in which +the number of element types specialised upon is a fixed constant in the compiler. + +As a consequence, if `x` is very long, this function may have very large compile times. + + tuple_map(f::F, x::Tuple, y::Tuple) where {F} + +Binary extension of `tuple_map`. Equivalent to `map(f, x, y)`, but guaranteed to specialise +on all element types of `x` and `y`. +""" +@inline @generated function tuple_map(f::F, x::Tuple) where {F} + return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n))), eachindex(x.parameters))...) +end + +@inline @generated function tuple_map(f::F, x::Tuple, y::Tuple) where {F} + return Expr(:call, :tuple, map(n -> :(f(getfield(x, $n), getfield(y, $n))), eachindex(x.parameters))...)
+end diff --git a/test/front_matter.jl b/test/front_matter.jl index cf6560c2e..75436e41b 100644 --- a/test/front_matter.jl +++ b/test/front_matter.jl @@ -43,7 +43,19 @@ using Taped: get_codual, get_tangent_stack, top_ref, - NoTangentRef + NoTangentRef, + BBCode, + ID, + IDPhiNode, + IDGotoNode, + IDGotoIfNot, + BBlock, + make_ad_stmts!, + ADStmtInfo, + ad_stmt_info, + ADInfo, + SharedDataPairs, + increment_field!! using .TestUtils: test_rrule!!, diff --git a/test/integration_testing/array.jl b/test/integration_testing/array.jl index 18edac147..4eee88cd8 100644 --- a/test/integration_testing/array.jl +++ b/test/integration_testing/array.jl @@ -1,4 +1,4 @@ -_getter = () -> 5.0 +_getter() = 5.0 @testset "array" begin test_cases = vcat( Any[ @@ -505,9 +505,8 @@ _getter = () -> 5.0 interp = Taped.TInterp() @testset for (interface_only, f, x...) in test_cases @info _typeof((f, x...)) - TestUtils.test_interpreted_rrule!!( - sr(123456), f, x...; - interp, perf_flag=:none, interface_only, is_primitive=false, + TestUtils.test_derived_rule( + sr(123456), f, x...; interp, perf_flag=:none, interface_only, is_primitive=false ) end end diff --git a/test/integration_testing/battery_tests.jl b/test/integration_testing/battery_tests.jl index c00c11c9b..3bafce60c 100644 --- a/test/integration_testing/battery_tests.jl +++ b/test/integration_testing/battery_tests.jl @@ -15,6 +15,7 @@ "hello", ], randn(Float64, 5), + [1, 2, 3], # randn(Float32, 5), # randn(Float16, 5), [ diff --git a/test/integration_testing/diff_tests.jl b/test/integration_testing/diff_tests.jl index ed36cdb3f..934877c07 100644 --- a/test/integration_testing/diff_tests.jl +++ b/test/integration_testing/diff_tests.jl @@ -7,7 +7,7 @@ TestResources.DIFFTESTS_FUNCTIONS[91:end], # SKIPPING SPARSE_LDIV )) @info "$n: $(_typeof((f, x...)))" - TestUtils.test_interpreted_rrule!!( + TestUtils.test_derived_rule( sr(123456), f, x...; interp, perf_flag=:none, interface_only=false, is_primitive=false, ) diff --git 
a/test/integration_testing/distributions.jl b/test/integration_testing/distributions.jl index 69827af44..34c9aced4 100644 --- a/test/integration_testing/distributions.jl +++ b/test/integration_testing/distributions.jl @@ -216,7 +216,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I) (false, LKJ(5, 1.1), rand(sr(123456), LKJ(5, 1.1))), ] @info "$(map(typeof, (d, x)))" - TestUtils.test_interpreted_rrule!!( + TestUtils.test_derived_rule( sr(123456), logpdf, d, x; interp, perf_flag=:none, interface_only, is_primitive=false, ) @@ -263,7 +263,7 @@ _pdmat(A) = PDMat(_sym(A) + 5I) ), ] @info "$name" - TestUtils.test_interpreted_rrule!!( + TestUtils.test_derived_rule( sr(123456), f, x...; interp, perf_flag=:none, interface_only, is_primitive=false, ) diff --git a/test/integration_testing/gp.jl b/test/integration_testing/gp.jl index f7abec86a..7c0ef1896 100644 --- a/test/integration_testing/gp.jl +++ b/test/integration_testing/gp.jl @@ -40,7 +40,7 @@ using AbstractGPs, KernelFunctions (rand, Xoshiro(123456), fx), (logpdf, fx, rand(fx)), ] - TestUtils.test_interpreted_rrule!!( + TestUtils.test_derived_rule( sr(123456), rand, Xoshiro(123456), GP(k)(x1, 1.1); interp, perf_flag=:none, interface_only=true, is_primitive=false, ) diff --git a/test/integration_testing/misc.jl b/test/integration_testing/misc.jl index c5eaca6b5..8ce7362bd 100644 --- a/test/integration_testing/misc.jl +++ b/test/integration_testing/misc.jl @@ -86,7 +86,7 @@ end), ) @info "$(_typeof((f, x...)))" - TestUtils.test_interpreted_rrule!!( + TestUtils.test_derived_rule( Xoshiro(123456), f, x...; interp, perf_flag=:none, interface_only, is_primitive=false, ) diff --git a/test/integration_testing/turing.jl b/test/integration_testing/turing.jl index efd9bbf92..17a636c08 100644 --- a/test/integration_testing/turing.jl +++ b/test/integration_testing/turing.jl @@ -1,10 +1,10 @@ using Turing +using ReverseDiff # using CSV, DataFrames, ReverseDiff # turing_bench_results = DataFrame( # :name => String[], # :primal => [], -# :interp 
=> [], # :gradient => [], # :reversediff => [], # ) @@ -25,8 +25,8 @@ end end @model broadcast_demo(x) = begin - μ ~ TruncatedNormal(1, 2, 0.1, 10) - σ ~ TruncatedNormal(1, 2, 0.1, 10) + μ ~ truncated(Normal(1, 2), 0.1, 10) + σ ~ truncated(Normal(1, 2), 0.1, 10) x .~ LogNormal(μ, σ) end @@ -107,43 +107,28 @@ end ], Any[ (false, "demo_$n", m, Turing.DynamicPPL.TestUtils.rand_prior_true(m)) for - (n, m) in enumerate(Turing.DynamicPPL.TestUtils.DEMO_MODELS[1:11]) + (n, m) in enumerate(Turing.DynamicPPL.TestUtils.DEMO_MODELS) ], ) @info name rng = sr(123) f, x = build_turing_problem(rng, model, ex) - - in_f = Taped.InterpretedFunction(DefaultCtx(), _typeof((f, x)), interp); - if interface_only - in_f(f, deepcopy(x)) - else - @test has_equal_data(in_f(f, deepcopy(x)), f(deepcopy(x))) - end - - TestUtils.test_rrule!!( - sr(123456), in_f, f, x; - perf_flag=:none, interface_only=true, is_primitive=false, + TestUtils.test_derived_rule( + sr(123456), f, x; + perf_flag=:none, interface_only=true, is_primitive=false, interp ) - rule, in_f = TestUtils.set_up_gradient_problem(f, x); - codualed_args = map(zero_codual, (in_f, f, x)); - TestUtils.value_and_gradient!!(rule, codualed_args...) - - # @profview run_many_times(1_000, TestUtils.value_and_gradient!!, rule, codualed_args...) + # rule = build_rrule(interp, _typeof((f, x))) + # codualed_args = map(zero_codual, (f, x)) + # TestUtils.to_benchmark(rule, codualed_args...) 
# primal = @benchmark $f($x) - # interpreted = @benchmark $in_f($f, $x) - # gradient = @benchmark(TestUtils.value_and_gradient!!($rule, $codualed_args...)) + # gradient = @benchmark(TestUtils.to_benchmark($rule, $codualed_args...)) # println("primal") # display(primal) # println() - # println("interpreted") - # display(interpreted) - # println() - # println("gradient") # display(gradient) # println() @@ -158,11 +143,14 @@ end # println("ReverseDiff") # display(revdiff) # println() + # @show time(revdiff) / time(primal) # catch # display("revdiff failed") # end - # push!(turing_bench_results, (name, primal, interpreted, gradient, revdiff)) + # @show time(gradient) / time(primal) + + # push!(turing_bench_results, (name, primal, gradient, revdiff)) end end @@ -170,7 +158,6 @@ end # out_df = DataFrame( # :name => df.name, # :primal => map(time, df.primal), -# :interp => map(time, df.interp), # :gradient => map(time, df.gradient), # :reversediff => map(time, df.reversediff), # ) diff --git a/test/interface.jl b/test/interface.jl new file mode 100644 index 000000000..e5409f084 --- /dev/null +++ b/test/interface.jl @@ -0,0 +1,13 @@ +@testset "interface" begin + f = (x, y) -> x * y + sin(x) * cos(y) + x = 5.0 + y = 4.0 + rule = build_rrule(f, x, y) + v, grad = value_and_gradient!!(rule, f, x, y) + @test v ≈ f(x, y) + @test grad isa Tuple{NoTangent, Float64, Float64} + + v, grad2 = value_and_pullback!!(rule, 1.0, f, x, y) + @test v ≈ f(x, y) + @test grad == grad2 +end diff --git a/test/interpreter/abstract_interpretation.jl b/test/interpreter/abstract_interpretation.jl index 1b546bbc0..4229e9f82 100644 --- a/test/interpreter/abstract_interpretation.jl +++ b/test/interpreter/abstract_interpretation.jl @@ -25,7 +25,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should continue to inline away under AD compilation. 
- interp = Taped.TapedInterpreter(Taped.DefaultCtx()) + interp = Taped.TapedInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) @@ -42,7 +42,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. - interp = Taped.TapedInterpreter(Taped.DefaultCtx()) + interp = Taped.TapedInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :a_primitive) @@ -61,7 +61,7 @@ contains_primitive_behind_call(x) = @inline contains_primitive(x) @assert usual_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :sin) # Should not inline away under AD compilation. 
- interp = Taped.TapedInterpreter(Taped.DefaultCtx()) + interp = Taped.TapedInterpreter(DefaultCtx) ad_ir = Base.code_ircode_by_type(sig; interp)[1][1] invoke_line = findfirst(x -> Meta.isexpr(x, :invoke), ad_ir.stmts.inst) @test ad_ir.stmts.inst[invoke_line].args[2] == GlobalRef(Main, :a_primitive) diff --git a/test/interpreter/bbcode.jl b/test/interpreter/bbcode.jl new file mode 100644 index 000000000..415dfb2a1 --- /dev/null +++ b/test/interpreter/bbcode.jl @@ -0,0 +1,52 @@ +module BBCodeTestCases + test_phi_node(x::Ref{Union{Float32, Float64}}) = sin(x[]) +end + +@testset "bbcode" begin + @testset "ID" begin + id1 = ID() + id2 = ID() + @test id1 == id1 + @test id1 != id2 + end + @testset "BBlock" begin + bb = BBlock( + ID(), + ID[ID(), ID()], + CC.NewInstruction[ + CC.NewInstruction(IDPhiNode([ID(), ID()], Any[true, false]), Any), + CC.NewInstruction(:(println("hello")), Any), + ], + ) + @test bb isa BBlock + @test length(bb) == 2 + + insert!(bb, 1, ID(), CC.NewInstruction(nothing, Nothing)) + @test length(bb) == 3 + @test bb.insts[1].stmt === nothing + + bb_copy = copy(bb) + @test bb_copy.inst_ids !== bb.inst_ids + + @test Taped.terminator(bb) === nothing + end + @testset "BBCode $f" for (f, P) in [ + (TestResources.test_while_loop, Tuple{Float64}), + (sin, Tuple{Float64}), + (BBCodeTestCases.test_phi_node, Tuple{Ref{Union{Float32, Float64}}}), + ] + ir = Base.code_ircode(f, P)[1][1] + bb_code = BBCode(ir) + @test bb_code isa BBCode + @test length(bb_code.blocks) == length(ir.cfg.blocks) + new_ir = Taped.IRCode(bb_code) + @test length(new_ir.stmts.inst) == length(ir.stmts.inst) + @test all(map(==, ir.stmts.inst, new_ir.stmts.inst)) + @test all(map(==, ir.stmts.type, new_ir.stmts.type)) + @test all(map(==, ir.stmts.info, new_ir.stmts.info)) + @test all(map(==, ir.stmts.line, new_ir.stmts.line)) + @test all(map(==, ir.stmts.flag, new_ir.stmts.flag)) + @test length(Taped.collect_stmts(bb_code)) == length(ir.stmts.inst) + @test Taped.id_to_line_map(bb_code) isa 
Dict{ID, Int} + end +end diff --git a/test/interpreter/ir_utils.jl b/test/interpreter/ir_utils.jl index 8da61a7cf..c746d3135 100644 --- a/test/interpreter/ir_utils.jl +++ b/test/interpreter/ir_utils.jl @@ -1,3 +1,10 @@ +module IRUtilsGlobalRefs + __x_1 = 5.0 + const __x_2 = 5.0 + __x_3::Float64 = 5.0 + const __x_4::Float64 = 5.0 +end + @testset "ir_utils" begin @testset "ircode $(typeof(fargs))" for fargs in Any[ (sin, 5.0), (cos, 1.0), @@ -72,4 +79,20 @@ @test !isassigned(result.values, 1) end end + @testset "globalref_type" begin + @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_1)) == Any + @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_2)) == Float64 + @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_3)) == Float64 + @test Taped.globalref_type(GlobalRef(IRUtilsGlobalRefs, :__x_4)) == Float64 + end + @testset "unhandled_feature" begin + @test_throws Taped.UnhandledLanguageFeatureException Taped.unhandled_feature("foo") + end + @testset "inc_args" begin + @test Taped.inc_args(Expr(:call, sin, Argument(4))) == Expr(:call, sin, Argument(5)) + @test Taped.inc_args(ReturnNode(Argument(2))) == ReturnNode(Argument(3)) + id = ID() + @test Taped.inc_args(IDGotoIfNot(Argument(1), id)) == IDGotoIfNot(Argument(2), id) + @test Taped.inc_args(IDGotoNode(id)) == IDGotoNode(id) + end end diff --git a/test/interpreter/registers.jl b/test/interpreter/registers.jl new file mode 100644 index 000000000..4ef420927 --- /dev/null +++ b/test/interpreter/registers.jl @@ -0,0 +1,8 @@ +@testset "registers" begin + @test Taped.register_type(Float64) <: Taped.AugmentedRegister{CoDual{Float64, Float64}} + @test Taped.register_type(Bool) <: Taped.AugmentedRegister{CoDual{Bool, NoTangent}} + @test Taped.register_type(Any) == Taped.AugmentedRegister + @test Taped.register_type(Real) == Taped.AugmentedRegister + @test ==(Taped.register_type(Union{Float64, Float32}), Taped.AugmentedRegister) + @test Taped.register_type(Union{Float64, Bool}) <: 
Union{Taped.AugmentedRegister, Bool} +end diff --git a/test/interpreter/s2s_reverse_mode_ad.jl b/test/interpreter/s2s_reverse_mode_ad.jl new file mode 100644 index 000000000..ebfa4f909 --- /dev/null +++ b/test/interpreter/s2s_reverse_mode_ad.jl @@ -0,0 +1,230 @@ +module S2SGlobals + non_const_global = 5.0 + const const_float = 5.0 + const const_int = 5 + const const_bool = true +end + +@testset "s2s_reverse_mode_ad" begin + @testset "SharedDataPairs" begin + m = SharedDataPairs() + id = Taped.add_data!(m, 5.0) + @test length(m.pairs) == 1 + @test m.pairs[1][1] == id + @test m.pairs[1][2] == 5.0 + end + @testset "ADInfo" begin + arg_types = Dict{Argument, Any}(Argument(1) => Float64, Argument(2) => Int) + id_ssa_1 = ID() + id_ssa_2 = ID() + ssa_insts = Dict{ID, CC.NewInstruction}( + id_ssa_1 => CC.NewInstruction(nothing, Float64), + id_ssa_2 => CC.NewInstruction(nothing, Any), + ) + info = ADInfo(Taped.TInterp(), arg_types, ssa_insts, Any[]) + + # Verify that we can access the interpreter and terminator block ID. + @test info.interp isa Taped.TInterp + + # Verify that we can get the type associated to Arguments, IDs, and others. + global ___x = 5.0 + global ___y::Float64 = 5.0 + @test Taped.get_primal_type(info, Argument(1)) == Float64 + @test Taped.get_primal_type(info, Argument(2)) == Int + @test Taped.get_primal_type(info, id_ssa_1) == Float64 + @test Taped.get_primal_type(info, GlobalRef(Base, :sin)) == typeof(sin) + @test Taped.get_primal_type(info, GlobalRef(Main, :___x)) == Any + @test Taped.get_primal_type(info, GlobalRef(Main, :___y)) == Float64 + @test Taped.get_primal_type(info, 5) == Int + @test Taped.get_primal_type(info, QuoteNode(:hello)) == Symbol + end + @testset "make_ad_stmts!" begin + + # Set up ADInfo -- this state is required by `make_ad_stmts!`, and the + # `LineToADDataMap` object can be mutated. 
+ id_line_1 = ID() + id_line_2 = ID() + info = ADInfo( + Taped.TInterp(), + Dict{Argument, Any}(Argument(1) => typeof(sin), Argument(2) => Float64), + Dict{ID, CC.NewInstruction}( + id_line_1 => CC.NewInstruction(Expr(:invoke, nothing, cos, Argument(2)), Float64), + id_line_2 => CC.NewInstruction(nothing, Any), + ), + Any[Taped.NoTangentStack(), Stack{Float64}()], + ) + + @testset "Nothing" begin + line = ID() + @test TestUtils.has_equal_data( + make_ad_stmts!(nothing, line, info), + ad_stmt_info(line, nothing, nothing), + ) + end + @testset "ReturnNode" begin + line = ID() + @testset "unreachable" begin + @test TestUtils.has_equal_data( + make_ad_stmts!(ReturnNode(), line, info), + ad_stmt_info(line, ReturnNode(), nothing), + ) + end + @testset "Argument" begin + val = Argument(4) + @test TestUtils.has_equal_data( + make_ad_stmts!(ReturnNode(Argument(2)), line, info), + ad_stmt_info(line, ReturnNode(Argument(3)), nothing), + ) + end + @testset "literal" begin + stmt_info = make_ad_stmts!(ReturnNode(5.0), line, info) + @test stmt_info isa ADStmtInfo + @test stmt_info.fwds[1][2].stmt isa ReturnNode + end + @testset "GlobalRef" begin + node = ReturnNode(GlobalRef(S2SGlobals, :const_float)) + stmt_info = make_ad_stmts!(node, line, info) + @test stmt_info isa ADStmtInfo + @test stmt_info.fwds[1][2].stmt isa ReturnNode + end + end + @testset "IDGotoNode" begin + line = ID() + stmt = IDGotoNode(ID()) + @test TestUtils.has_equal_data( + make_ad_stmts!(stmt, line, info), ad_stmt_info(line, stmt, nothing) + ) + end + @testset "IDGotoIfNot" begin + line = ID() + cond_id = ID() + stmt = IDGotoIfNot(cond_id, ID()) + ad_stmts = make_ad_stmts!(stmt, line, info) + @test ad_stmts isa ADStmtInfo + @test ad_stmts.rvs[1][2].stmt === nothing + fwds = ad_stmts.fwds + @test fwds[1][1] == fwds[2][2].stmt.cond + @test Meta.isexpr(fwds[1][2].stmt, :call) + @test fwds[2][2].stmt isa IDGotoIfNot + @test fwds[2][2].stmt.dest == stmt.dest + end + @testset "IDPhiNode" begin + stmt = 
IDPhiNode(ID[ID(), ID()], Any[ID(), 5.0]) + ad_stmts = make_ad_stmts!(stmt, id_line_1, info) + @test ad_stmts isa ADStmtInfo + end + @testset "PiNode" begin + @testset "unhandled case" begin + @test_throws( + Taped.UnhandledLanguageFeatureException, + make_ad_stmts!(PiNode(5.0, Float64), ID(), info), + ) + end + @testset "sharpen type of ID" begin + line = id_line_1 + val = id_line_2 + stmt_info = make_ad_stmts!(PiNode(val, Float64), line, info) + @test stmt_info isa ADStmtInfo + end + end + @testset "GlobalRef" begin + @testset "non-const" begin + global_ref = GlobalRef(S2SGlobals, :non_const_global) + stmt_info = make_ad_stmts!(global_ref, ID(), info) + @test stmt_info isa Taped.ADStmtInfo + @test Meta.isexpr(last(stmt_info.fwds)[2].stmt, :call) + @test last(stmt_info.fwds)[2].stmt.args[1] == Taped.__verify_const + end + @testset "differentiable const globals" begin + stmt_info = make_ad_stmts!(GlobalRef(S2SGlobals, :const_float), ID(), info) + @test stmt_info isa Taped.ADStmtInfo + @test Meta.isexpr(only(stmt_info.fwds)[2].stmt, :call) + @test only(stmt_info.fwds)[2].stmt.args[1] == identity + end + end + @testset "PhiCNode" begin + @test_throws( + Taped.UnhandledLanguageFeatureException, + make_ad_stmts!(Core.PhiCNode(Any[]), ID(), info), + ) + end + @testset "UpsilonNode" begin + @test_throws( + Taped.UnhandledLanguageFeatureException, + make_ad_stmts!(Core.UpsilonNode(5), ID(), info), + ) + end + @testset "Expr" begin + @testset "invoke" begin + stmt = Expr(:invoke, nothing, cos, Argument(2)) + ad_stmts = make_ad_stmts!(stmt, id_line_1, info) + fwds_stmt = ad_stmts.fwds[2][2].stmt + @test Meta.isexpr(fwds_stmt, :call) + @test fwds_stmt.args[1] == Taped.__fwds_pass! + @test Meta.isexpr(ad_stmts.rvs[2][2].stmt, :call) + @test ad_stmts.rvs[2][2].stmt.args[1] == Taped.__rvs_pass! 
+ end + @testset "copyast" begin + stmt = Expr(:copyast, QuoteNode(:(hi))) + ad_stmts = make_ad_stmts!(stmt, ID(), info) + @test ad_stmts isa Taped.ADStmtInfo + @test Meta.isexpr(ad_stmts.fwds[1][2].stmt, :call) + @test ad_stmts.fwds[1][2].stmt.args[1] == identity + end + @testset "throw_undef_if_not" begin + cond_id = ID() + line = ID() + @test TestUtils.has_equal_data( + make_ad_stmts!(Expr(:throw_undef_if_not, :x, cond_id), line, info), + ad_stmt_info(line, Expr(:throw_undef_if_not, :x, cond_id), nothing), + ) + end + @testset "$stmt" for stmt in [ + Expr(:gc_preserve_begin), + ] + line = ID() + @test TestUtils.has_equal_data( + make_ad_stmts!(stmt, line, info), + ad_stmt_info(line, stmt, nothing), + ) + end + end + end + + interp = Taped.TInterp() + @testset "$(_typeof((f, x...)))" for (n, (interface_only, perf_flag, bnds, f, x...)) in + collect(enumerate(TestResources.generate_test_functions())) + + sig = _typeof((f, x...)) + @info "$n: $sig" + TestUtils.test_derived_rule( + Xoshiro(123456), f, x...; interp, perf_flag, interface_only, is_primitive=false + ) + + # codual_args = map(zero_codual, (f, x...)) + # rule = Taped.build_rrule(interp, sig) + # out, pb!! = rule(codual_args...) + # # @code_warntype optimize=true rule(codual_args...) + # # @code_warntype optimize=true pb!!(tangent(out), map(tangent, codual_args)...) + + # primal_time = @benchmark $f($(Ref(x))[]...) + # s2s_time = @benchmark $rule($codual_args...)[2]($(tangent(out)), $(map(tangent, codual_args))...) + # in_f = in_f = Taped.InterpretedFunction(DefaultCtx(), sig, interp); + # __rrule!! = Taped.build_rrule!!(in_f); + # df = zero_codual(in_f); + # codual_x = map(zero_codual, (f, x...)); + # interp_time = @benchmark TestUtils.to_benchmark($__rrule!!, $df, $codual_x...) 
+ + # display(primal_time) + # display(s2s_time) + # display(interp_time) + # s2s_ratio = time(s2s_time) / time(primal_time) + # interp_ratio = time(interp_time) / time(primal_time) + # println("s2s ratio ratio: $(s2s_ratio)") + # println("interp ratio: $(interp_ratio)") + + # f(rule, codual_args, out) = rule(codual_args...)[2](tangent(out), map(tangent, codual_args)...) + # f(rule, codual_args, out) + # @profview(run_many_times(1_000, f, rule, codual_args, out)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 67c1443a6..60c9ee804 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,11 +8,14 @@ include("front_matter.jl") include("stack.jl") @testset "interpreter" begin include(joinpath("interpreter", "contexts.jl")) + include(joinpath("interpreter", "abstract_interpretation.jl")) + include(joinpath("interpreter", "bbcode.jl")) include(joinpath("interpreter", "ir_utils.jl")) include(joinpath("interpreter", "ir_normalisation.jl")) - include(joinpath("interpreter", "abstract_interpretation.jl")) - include(joinpath("interpreter", "interpreted_function.jl")) - include(joinpath("interpreter", "reverse_mode_ad.jl")) + include(joinpath("interpreter", "registers.jl")) + # include(joinpath("interpreter", "interpreted_function.jl")) + # include(joinpath("interpreter", "reverse_mode_ad.jl")) + include(joinpath("interpreter", "s2s_reverse_mode_ad.jl")) end elseif test_group == "rrules" include("test_utils.jl") diff --git a/test/tangents.jl b/test/tangents.jl index e3751a023..71e39e610 100644 --- a/test/tangents.jl +++ b/test/tangents.jl @@ -39,13 +39,16 @@ (4.0, [4.0, 3.0]), (9.0, [7.0, 7.0]), ), + ((), NoTangent(), NoTangent(), NoTangent()), + ((1,), NoTangent(), NoTangent(), NoTangent()), + ((2, 3), NoTangent(), NoTangent(), NoTangent()), ( (a=6.0, b=[1.0, 2.0]), (a=5.0, b=[3.0, 4.0]), (a=4.0, b=[4.0, 3.0]), (a=9.0, b=[7.0, 7.0]), ), - ((;), (;), (;), (;)), + ((;), NoTangent(), NoTangent(), NoTangent()), ( TypeStableMutableStruct{Float64}(5.0, 3.0), 
build_tangent(TypeStableMutableStruct{Float64}, 5.0, 4.0), @@ -76,12 +79,7 @@ build_tangent(MutableFoo, 4.0), build_tangent(MutableFoo, 9.0), ), - ( - UnitRange{Int}(5, 7), - build_tangent(UnitRange{Int}, NoTangent(), NoTangent()), - build_tangent(UnitRange{Int}, NoTangent(), NoTangent()), - build_tangent(UnitRange{Int}, NoTangent(), NoTangent()), - ), + (UnitRange{Int}(5, 7), NoTangent(), NoTangent(), NoTangent()), ], map([ LowerTriangular{Float64, Matrix{Float64}}, @@ -198,6 +196,27 @@ end end end + + # The main tangent testing functionality really needs refactoring -- currently it's + # not possible to properly test tangents because you can't separately specify the + # static and dynamic types. + @testset "extra tuple tests" begin + @test tangent_type(Tuple) == Any + @test tangent_type(Tuple{}) == NoTangent + @test tangent_type(Tuple{Float64, Vararg}) == Any + @test tangent_type(Tuple{Float64}) == Tuple{Float64} + @test tangent_type(Tuple{Float64, Int}) == Tuple{Float64, NoTangent} + @test tangent_type(Tuple{Int, Int}) == NoTangent + @test tangent_type(Tuple{DataType, Type{Float64}}) == NoTangent + @test ==( + tangent_type(Union{Tuple{Float64}, Tuple{Int}}), + Union{Tuple{Float64}, NoTangent}, + ) + @test ==( + tangent_type(Union{Tuple{Float64}, Tuple{Int}, Tuple{Float64, Int}}), + Union{Tuple{Float64}, NoTangent, Tuple{Float64, NoTangent}}, + ) + end end diff --git a/test/test_utils.jl b/test/test_utils.jl index a02a3f9bd..0265aac9d 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -92,13 +92,4 @@ z̄ = (x̄..., tangent(y_ȳ)) @test_throws AssertionError populate_address_map(z, z̄) end - @testset "toy API" begin - f = (x, y) -> x * y + sin(x) * cos(y) - x = 5.0 - y = 4.0 - rule, in_f = TestUtils.set_up_gradient_problem(f, x, y) - v, grad = TestUtils.value_and_gradient!!(rule, in_f, f, x, y) - @test v ≈ f(x, y) - @test grad isa Tuple{tangent_type(typeof(f)), Float64, Float64} - end end diff --git a/test/utils.jl b/test/utils.jl index d075a4653..ae13667a7 
100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,9 +1,18 @@ @testset "utils" begin - @test _typeof(5.0) == Float64 - @test _typeof(randn(1)) == Vector{Float64} - @test _typeof(Float64) == Type{Float64} - @test _typeof(Vector{Int}) == Type{Vector{Int}} - @test _typeof(Vector{T} where {T}) == Type{Vector} - @test _typeof((5.0, Float64)) == Tuple{Float64, Type{Float64}} - @test _typeof((a=5.0, b=Float64)) == @NamedTuple{a::Float64, b::Type{Float64}} + @testset "_typeof" begin + @test _typeof(5.0) == Float64 + @test _typeof(randn(1)) == Vector{Float64} + @test _typeof(Float64) == Type{Float64} + @test _typeof(Vector{Int}) == Type{Vector{Int}} + @test _typeof(Vector{T} where {T}) == Type{Vector} + @test _typeof((5.0, Float64)) == Tuple{Float64, Type{Float64}} + @test _typeof((a=5.0, b=Float64)) == @NamedTuple{a::Float64, b::Type{Float64}} + end + @testset "tuple_map" begin + @test map(sin, (5.0, 4.0)) == Taped.tuple_map(sin, (5.0, 4.0)) + @test ==( + map(*, (5, 4.0, 3), (5.0, 4, 3.0)), + Taped.tuple_map(*, (5, 4.0, 3), (5.0, 4, 3.0)), + ) + end end