Skip to content

Commit

Permalink
Makes apply_iterate work (#152)
Browse files Browse the repository at this point in the history
* Bump patch

* Sketch of implementation

* Fix bug in array construction rule

* Implementation of apply_iterate

* Test edge case for unsafe_copyto

* Update README to point out that we now handle `_apply_iterate`

* Undo double comment

* Bump patch
  • Loading branch information
willtebbutt authored May 15, 2024
1 parent 5514cdb commit 2843542
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.12"
version = "0.2.13"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,5 @@ Please be aware that by "performant" we mean similar or better performance than

While `Tapir.jl` should now work on a very large subset of the language, there remain things that you should expect not to work. A non-exhaustive list of things to bear in mind includes:
1. It is always necessary to produce hand-written rules for `ccall`s (and, more generally, foreigncall nodes). We have rules for many `ccall`s, but not all. If you encounter a foreigncall without a hand-written rule, you should get an informative error message which tells you what is going on and how to deal with it.
1. Builtins which require rules. The vast majority of them have rules now, but some don't. Notably, `apply_iterate` does not have a rule, so `Tapir.jl` cannot currently AD through type-unstable splatting -- someone should resolve this.
1. Builtins which require rules. The vast majority of them have rules now, but some don't. ~~Notably, `apply_iterate` does not have a rule, so `Tapir.jl` cannot currently AD through type-unstable splatting -- someone should resolve this.~~ `Core._apply_iterate` should now work correctly.
1. Anything involving tasks / threading -- we have no thread safety guarantees and, at the time of writing, I'm not entirely sure what error you will find if you attempt to AD through code which uses Julia's task / thread system. The same applies to distributed computing. These limitations ought to be possible to resolve.
71 changes: 70 additions & 1 deletion src/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,69 @@ function rrule!!(f::CoDual{typeof(===)}, x, y)
end

# Core._abstracttype

#
# Core._apply_iterate
#
# Core._apply_iterate is a tricky case, and requires calling back into AD to handle
# properly. The basic strategy is to differentiate a function which is semantically
# identical to Core._apply_iterate, but whose components we know how to differentiate.
#

# A function with the same semantics as `Core._apply_iterate`, but which is differentiable.
#
# Arguments:
# - `itr`: the iteration function handed to `Core._apply_iterate`. It is not used here:
#   every element of `args` is traversed via `collect` (presumably `itr` is always
#   `Base.iterate` -- TODO confirm against callers).
# - `f`: the function to which the splatted arguments are applied.
# - `args`: the collections whose elements are splatted, in order, into the call to `f`.
#
# Returns the result of calling `f` on the concatenated elements of `args`.
function _apply_iterate_equivalent(itr, f::F, args::Vararg{Any, N}) where {F, N}
    # Seed the reduction with an empty `Vector{Any}`. This
    # (1) pins the element type of the concatenated vector to `Any`, so that `vcat` cannot
    #     promote elements to a common type (e.g. turn the `3` in `f(3, (4.0, )...)` into
    #     `3.0`),
    # (2) guarantees that the result is a `Vector` even when `args` holds a single
    #     zero-dimensional collection such as `collect(5.0)`, and
    # (3) makes the zero-argument case `f()` well-defined, rather than a reduction over an
    #     empty tuple.
    vec_args = reduce(vcat, (Any[], tuple_map(collect, args)...))
    tuple_args = __vec_to_tuple(vec_args)
    return __barrier(f, tuple_args)
end

# A primitive used to avoid exposing `_apply_iterate_equivalent` to `Core._apply_iterate`.
# Converts the vector `v` into a tuple holding the same elements in the same order.
function __vec_to_tuple(v::Vector)
    return Tuple(v)
end

@is_primitive MinimalCtx Tuple{typeof(__vec_to_tuple), Vector}

# Hand-written rule for `__vec_to_tuple`.
#
# Forwards pass: the output primal is `Tuple(primal(v))`, and its fdata is obtained by
# applying `fdata` to the tuple built from the tangent of `v`.
#
# Reverse pass: if the rdata `dy` is a `Tuple`, each of its elements is accumulated into the
# corresponding element of the tangent of `v` via `increment_rdata!!`. If it is `NoRData`,
# nothing is accumulated. `NoRData()` is returned for both the function and `v`.
function rrule!!(::CoDual{typeof(__vec_to_tuple)}, v::CoDual{<:Vector})
    tangent_vec = tangent(v)
    primal_tuple = Tuple(primal(v))
    tuple_fdata = fdata(Tuple(tangent_vec))

    function vec_to_tuple_pb!!(dy::Union{Tuple, NoRData})
        # Only tuple-valued rdata carries anything to accumulate.
        if !(dy isa Tuple)
            return NoRData(), NoRData()
        end
        for n in eachindex(dy)
            tangent_vec[n] = increment_rdata!!(tangent_vec[n], dy[n])
        end
        return NoRData(), NoRData()
    end

    return CoDual(primal_tuple, tuple_fdata), vec_to_tuple_pb!!
end

# Calls `f` with the elements of the tuple `args` splatted in as its arguments. Marked
# `@noinline` so that this call is not inlined into its caller.
@noinline function __barrier(f::F, args::Tuple) where {F}
    return f(args...)
end

# Override the default definition of `is_primitive` for builtins: calls to
# `Core._apply_iterate` are declared not to be primitives in `MinimalCtx`.
function is_primitive(
    ::Type{MinimalCtx}, ::Type{<:Tuple{typeof(Core._apply_iterate), Vararg}}
)
    return false
end

# Rule for `Core._apply_iterate`. Wraps `rule`, which is invoked with
# `_apply_iterate_equivalent` in place of `Core._apply_iterate`.
struct ApplyIterateRule{R}
    rule::R
end

# Calling an `ApplyIterateRule` forwards `args` to the wrapped rule, replacing the
# `Core._apply_iterate` codual with `zero_fcodual(_apply_iterate_equivalent)`.
function (self::ApplyIterateRule)(::CoDual{typeof(Core._apply_iterate)}, args::CoDual...)
    inner_rule = self.rule
    equivalent_f = zero_fcodual(_apply_iterate_equivalent)
    return inner_rule(equivalent_f, args...)
end

# Builds the rule for a call to `Core._apply_iterate` with signature `sig`. The first
# parameter of `sig` is swapped for `typeof(_apply_iterate_equivalent)`, a rule is built for
# the resulting signature (forwarding `kwargs`), and that rule is wrapped in an
# `ApplyIterateRule`.
function build_rrule(
    interp::TapirInterpreter, sig::Type{<:Tuple{typeof(Core._apply_iterate), Vararg}};
    kwargs...
)
    arg_types = sig.parameters[2:end]
    equivalent_sig = Tuple{typeof(_apply_iterate_equivalent), arg_types...}
    inner_rule = build_rrule(interp, equivalent_sig; kwargs...)
    return ApplyIterateRule(inner_rule)
end

# Computes the rule type for a call to `Core._apply_iterate` with signature `sig`: an
# `ApplyIterateRule` parameterised by the rule type of the signature obtained by swapping
# the first parameter of `sig` for `typeof(_apply_iterate_equivalent)`.
function rule_type(
    interp::TapirInterpreter{C}, sig::Type{<:Tuple{typeof(Core._apply_iterate), Vararg}}
) where {C}
    arg_types = sig.parameters[2:end]
    equivalent_sig = Tuple{typeof(_apply_iterate_equivalent), arg_types...}
    inner_rule_type = rule_type(interp, equivalent_sig)
    return ApplyIterateRule{inner_rule_type}
end

# Core._apply_pure
# Core._call_in_world
# Core._call_in_world_total
Expand Down Expand Up @@ -770,7 +832,9 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:builtins})

# Non-intrinsic built-ins:
# Core._abstracttype -- NEEDS IMPLEMENTING AND TESTING
# Core._apply_iterate -- NEEDS IMPLEMENTING AND TESTING
(false, :none, nothing, __vec_to_tuple, [1.0]),
(false, :none, nothing, __vec_to_tuple, Any[1.0]),
(false, :none, nothing, __vec_to_tuple, Any[[1.0]]),
# Core._apply_pure -- NEEDS IMPLEMENTING AND TESTING
# Core._call_in_world -- NEEDS IMPLEMENTING AND TESTING
# Core._call_in_world_total -- NEEDS IMPLEMENTING AND TESTING
Expand Down Expand Up @@ -919,6 +983,11 @@ end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:builtins})
test_cases = Any[
(false, :none, nothing, Core._apply_iterate, Base.iterate, *, 5.0, 4.0),
(false, :none, nothing, Core._apply_iterate, Base.iterate, *, (5.0, 4.0)),
(false, :none, nothing, Core._apply_iterate, Base.iterate, *, [5.0, 4.0]),
(false, :none, nothing, Core._apply_iterate, Base.iterate, *, [5.0], (4.0, )),
(false, :none, nothing, Core._apply_iterate, Base.iterate, *, 3, (4.0, )),
(
false, :none, nothing,
(
Expand Down
23 changes: 19 additions & 4 deletions src/rrules/foreigncall.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,12 @@ function rrule!!(
return zero_fcodual(Array{T, N}(undef, map(primal, m)...)), NoPullback(f, u, m...)
end

# Rule for constructing an uninitialised zero-dimensional array from an empty tuple of
# dimensions. The new array is wrapped using `zero_fcodual`, and the reverse pass is a
# `NoPullback` built from the arguments.
function rrule!!(
    f::CoDual{Type{Array{T, 0}}}, u::CoDual{typeof(undef)}, m::CoDual{Tuple{}}
) where {T}
    new_array = Array{T, 0}(undef)
    return zero_fcodual(new_array), NoPullback(f, u, m)
end

@is_primitive MinimalCtx Tuple{Type{<:Array{T, N}}, typeof(undef), NTuple{N}} where {T, N}
function rrule!!(
::CoDual{<:Type{<:Array{T, N}}}, ::CoDual{typeof(undef)}, m::CoDual{NTuple{N}},
Expand Down Expand Up @@ -349,8 +355,8 @@ function rrule!!(
_soffs = primal(soffs)
pdest = primal(dest)
ddest = tangent(dest)
dest_copy = primal(dest)[dest_idx]
ddest_copy = tangent(dest)[dest_idx]
dest_copy = pdest[dest_idx]
ddest_copy = ddest[dest_idx]

# Run primal computation.
dsrc = tangent(src)
Expand All @@ -364,8 +370,11 @@ function rrule!!(
dsrc[src_idx] .= increment!!.(view(dsrc, src_idx), view(ddest, dest_idx))

# Restore initial state.
pdest[dest_idx] .= dest_copy
ddest[dest_idx] .= ddest_copy
@inbounds for n in eachindex(dest_copy)
isassigned(dest_copy, n) || continue
pdest[dest_idx[n]] = dest_copy[n]
ddest[dest_idx[n]] = ddest_copy[n]
end

return NoRData(), NoRData(), NoRData(), NoRData(), NoRData(), NoRData()
end
Expand Down Expand Up @@ -522,11 +531,13 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall})
test_cases = Any[
(false, :stability, nothing, Base.allocatedinline, Float64),
(false, :stability, nothing, Base.allocatedinline, Vector{Float64}),
(true, :stability, nothing, Array{Float64, 0}, undef),
(true, :stability, nothing, Array{Float64, 1}, undef, 5),
(true, :stability, nothing, Array{Float64, 2}, undef, 5, 4),
(true, :stability, nothing, Array{Float64, 3}, undef, 5, 4, 3),
(true, :stability, nothing, Array{Float64, 4}, undef, 5, 4, 3, 2),
(true, :stability, nothing, Array{Float64, 5}, undef, 5, 4, 3, 2, 1),
(true, :stability, nothing, Array{Float64, 0}, undef, ()),
(true, :stability, nothing, Array{Float64, 4}, undef, (2, 3, 4, 5)),
(true, :stability, nothing, Array{Float64, 5}, undef, (2, 3, 4, 5, 6)),
(false, :stability, nothing, copy, randn(5, 4)),
Expand Down Expand Up @@ -574,6 +585,10 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:foreigncall})
false, :stability, nothing,
unsafe_copyto!, [rand(3) for _ in 1:5], 2, [rand(4) for _ in 1:4], 1, 3,
),
(
false, :none, nothing,
unsafe_copyto!, Vector{Any}(undef, 5), 2, Any[rand() for _ in 1:4], 1, 3,
),
(false, :stability, nothing, deepcopy, 5.0),
(false, :stability, nothing, deepcopy, randn(5)),
(false, :none, nothing, deepcopy, TestResources.MutableFoo(5.0, randn(5))),
Expand Down
13 changes: 13 additions & 0 deletions test/rrules/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,19 @@
Tapir.rrule!!(CoDual(IntrinsicsWrappers.sub_ptr, NoTangent()), 5.0, 4.0),
)

# Check that `_apply_iterate_equivalent` agrees with `Core._apply_iterate` for a range of
# argument layouts: scalars, a tuple, a vector, mixtures thereof, and an empty tuple.
@testset "_apply_iterate_equivalent with $(typeof(args))" for args in Any[
    (*, 5.0, 4.0),
    (*, (5.0, 4.0)),
    (*, [1.0, 2.0]),
    (*, 1.0, [2.0]),
    (*, [1.0, 2.0], ()),
]
    expected = Core._apply_iterate(Base.iterate, args...)
    actual = Tapir._apply_iterate_equivalent(Base.iterate, args...)
    @test expected == actual
end

TestUtils.run_rrule!!_test_cases(StableRNG, Val(:builtins))

@testset "Disable bitcast to differentiable type" begin
Expand Down

2 comments on commit 2843542

@willtebbutt
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/106944

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.2.13 -m "<description of version>" 28435424138501c0044e9ffe72dcdf39d3c2ce49
git push origin v0.2.13

Please sign in to comment.