Perf and Pre-Allocation (#396)
* Use generated function for complicated tangent method

* Bump patch version

* Make it work

* Upgrades for DI

* Add experimental warning

* Add a couple of extra tests

* Formatting

* Remove redundant error handling

* Docstrings and perf

* Fix formatting

* Fix docstrings
willtebbutt authored Nov 28, 2024
1 parent 0f37c07 commit d2d97a2
Showing 4 changed files with 154 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/fwds_rvs_data.jl
@@ -784,7 +784,7 @@ tangent_type(::Type{NoFData}, ::Type{R}) where {R<:IEEEFloat} = R
tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F

# Tuples
function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
@generated function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...}
end
function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple}
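For context: the change above turns the tuple method into a `@generated` function, so the mapping over `F.parameters` and `R.parameters` happens once at compile time, and the resulting `Tuple{...}` is a constant from the caller's point of view rather than being rebuilt by runtime reflection. A minimal sketch of the same pattern, using hypothetical single-argument names rather than Mooncake's actual definitions:

```julia
# Sketch of the @generated pattern used above (hypothetical names).
elem_tangent_type(::Type{Float64}) = Float64
elem_tangent_type(::Type{Int}) = Nothing

@generated function tuple_tangent_type(::Type{T}) where {T<:Tuple}
    # The generator runs once per concrete T, during compilation, so the
    # caller sees a compile-time constant rather than a runtime computation.
    return Tuple{map(elem_tangent_type, Tuple(T.parameters))...}
end

tuple_tangent_type(Tuple{Float64,Int})  # Tuple{Float64, Nothing}
```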
126 changes: 110 additions & 16 deletions src/interface.jl
@@ -1,5 +1,5 @@
"""
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)
*Note:* this is not part of the public Mooncake.jl interface, and may change without warning.
@@ -8,13 +8,15 @@ In-place version of `value_and_pullback!!` in which the arguments have been wrap
if calling this function multiple times with different values of `x`, you should be careful to
ensure that you zero-out the tangent fields of `x` each time.
"""
function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual,N}) where {R,N,T}
function __value_and_pullback!!(
rule::R, ȳ::T, fx::Vararg{CoDual,N}; y_cache=nothing
) where {R,N,T}
fx_fwds = tuple_map(to_fwds, fx)
__verify_sig(rule, fx_fwds)
out, pb!! = rule(fx_fwds...)
@assert _typeof(tangent(out)) == fdata_type(T)
increment!!(tangent(out), fdata(ȳ))
v = copy(primal(out))
v = y_cache === nothing ? copy(primal(out)) : _copy!!(y_cache, primal(out))
return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
end
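The new `y_cache=nothing` keyword is the heart of the pre-allocation change: when no cache is supplied, the primal output is freshly `copy`ed as before, while a supplied cache receives the value via `_copy!!` (defined later in this diff), so no new output storage is allocated. The idiom in isolation, with a hypothetical helper name:

```julia
# The cache-or-copy idiom in isolation (hypothetical helper, not Mooncake's).
cached_copy(::Nothing, y) = copy(y)                     # no cache: allocate a fresh copy
cached_copy(cache::AbstractArray, y) = copy!(cache, y)  # cache: overwrite and reuse

y = [1.0, 2.0, 3.0]
buf = similar(y)
cached_copy(nothing, y)  # allocates a new array
cached_copy(buf, y)      # writes into buf and returns it
```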

@@ -38,6 +40,15 @@ struct ValueAndGradientReturnTypeError <: Exception
msg::String
end

function throw_val_and_grad_ret_type_error(y)
throw(
ValueAndGradientReturnTypeError(
"When calling __value_and_gradient!!, return value of primal must be a " *
"subtype of IEEEFloat. Instead, found value of type $(typeof(y)).",
),
)
end

"""
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)
@@ -72,17 +83,7 @@ function __value_and_gradient!!(rule::R, fx::Vararg{CoDual,N}) where {R,N}
__verify_sig(rule, fx_fwds)
out, pb!! = rule(fx_fwds...)
y = primal(out)
if !(y isa IEEEFloat)
throw(
ValueAndGradientReturnTypeError(
"When calling __value_and_gradient!!, return value of primal must be a " *
"subtype of IEEEFloat. Instead, found value of type $(typeof(y)).",
),
)
end
@assert y isa IEEEFloat
@assert tangent(out) isa NoFData

y isa IEEEFloat || throw_val_and_grad_ret_type_error(y)
return y, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(one(y)))
end
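Factoring the `throw` into `throw_val_and_grad_ret_type_error` (added above) is a common Julia performance idiom: the error-construction code moves out of the hot function, leaving a short-circuiting `y isa IEEEFloat || ...` check that keeps the body small and inline-friendly. A generic sketch of the pattern with hypothetical names (Mooncake's helper carries no annotation, though `@noinline` is often used here):

```julia
# Generic throw-helper pattern (hypothetical example, not from this diff).
@noinline throw_negative(x) = throw(DomainError(x, "expected a non-negative value"))

function checked_sqrt(x::Float64)
    x >= 0 || throw_negative(x)  # the cold path lives in the helper
    return sqrt(x)
end
```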

@@ -126,8 +127,8 @@ Equivalent to `value_and_pullback!!(rule, 1.0, f, x...)`, and assumes `f` returns
`Float64`.
*Note:* There are lots of subtle ways to mis-use `value_and_pullback!!`, so we generally
recommend using [`value_and_gradient!!`](@ref) (this function) where possible. The docstring for
`value_and_pullback!!` is useful for understanding this function though.
recommend using `Mooncake.value_and_gradient!!` (this function) where possible. The
docstring for `value_and_pullback!!` is useful for understanding this function though.
An example:
```jldoctest
@@ -163,3 +164,96 @@ function __create_coduals(args)
end
end
end

struct Cache{Trule,Ty_cache,Ttangents<:Tuple}
rule::Trule
y_cache::Ty_cache
tangents::Ttangents
end
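The `Cache` struct bundles everything that can be reused across calls: the compiled `rule`, storage for the primal output (`y_cache`, which is `nothing` when the output is an immutable scalar, as in `prepare_gradient_cache` below), and pre-allocated `tangents` for `f` and each argument.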

_copy!!(dst, src) = copy!(dst, src)
_copy!!(::Number, src::Number) = src
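These two methods let the caching code treat mutable and immutable outputs uniformly: `copy!`-able containers such as arrays are overwritten in place, while numbers, being immutable, are simply passed through:

```julia
# Behaviour of the two methods above, following their definitions in this diff:
_copy!!([0.0, 0.0], [1.0, 2.0])  # copy!s into the first argument and returns it
_copy!!(0.0, 5.0)                # numbers are immutable, so the source (5.0) is returned
```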

"""
prepare_pullback_cache(f, x...)
WARNING: experimental functionality. Interface subject to change without warning!
Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
`Mooncake.value_and_gradient!!` for more info.
"""
function prepare_pullback_cache(fx...; kwargs...)

# Take a copy before mutating.
fx = deepcopy(fx)

# Construct rule and tangents.
rule = build_rrule(get_interpreter(), Tuple{map(_typeof, fx)...}; kwargs...)
tangents = map(zero_tangent, fx)

# Run the rule forwards -- this should do a decent chunk of pre-allocation.
y, _ = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...)

# Construct cache for output. Check that `copy!`ing appears to work.
y_cache = copy(primal(y))
return Cache(rule, _copy!!(y_cache, primal(y)), tangents)
end

"""
value_and_pullback!!(cache::Cache, ȳ, f, x...)
WARNING: experimental functionality. Interface subject to change without warning!
Like other methods of `value_and_pullback!!`, but makes use of the `cache` object in order
to avoid having to re-allocate various tangent objects repeatedly.
You must ensure that `f` and `x` are the same types and sizes as those used to construct
`cache`.
Warning: any mutable components of values returned by `value_and_pullback!!` will be mutated
if you run this function again with different arguments. Therefore, if you need to keep the
values returned by this function around over multiple calls to this function with the same
`cache`, you should take a copy of them before calling again.
"""
function value_and_pullback!!(cache::Cache, ȳ, f::F, x::Vararg{Any,N}) where {F,N}
tangents = tuple_map(set_to_zero!!, cache.tangents)
coduals = tuple_map(CoDual, (f, x...), tangents)
return __value_and_pullback!!(cache.rule, ȳ, coduals...; y_cache=cache.y_cache)
end
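Putting the pieces together, a usage sketch for the cached pullback path. This assumes the experimental interface exactly as introduced in this diff; `f` and `x` are illustrative:

```julia
# Usage sketch for the cached pullback (experimental interface from this diff).
using Mooncake

f(x) = 2 .* x
x = randn(4)
cache = Mooncake.prepare_pullback_cache(f, x)

ȳ = ones(4)
y, (df, dx) = Mooncake.value_and_pullback!!(cache, ȳ, f, x)
# Later calls with same-typed arguments reuse cache.tangents and cache.y_cache,
# mutating previously returned values, so copy anything you need to keep.
```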

"""
prepare_gradient_cache(f, x...)
WARNING: experimental functionality. Interface subject to change without warning!
Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
`Mooncake.value_and_gradient!!` for more info.
"""
function prepare_gradient_cache(fx...; kwargs...)
rule = build_rrule(fx...; kwargs...)
tangents = map(zero_tangent, fx)
y, _ = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...)
primal(y) isa IEEEFloat || throw_val_and_grad_ret_type_error(primal(y))
return Cache(rule, nothing, tangents)
end

"""
value_and_gradient!!(cache::Cache, fx::Vararg{Any, N}) where {N}
WARNING: experimental functionality. Interface subject to change without warning!
Like other methods of `value_and_gradient!!`, but makes use of the `cache` object in order
to avoid having to re-allocate various tangent objects repeatedly.
You must ensure that `f` and `x` are the same types and sizes as those used to construct
`cache`.
Warning: any mutable components of values returned by `value_and_gradient!!` will be mutated
if you run this function again with different arguments. Therefore, if you need to keep the
values returned by this function around over multiple calls to this function with the same
`cache`, you should take a copy of them before calling again.
"""
function value_and_gradient!!(cache::Cache, f::F, x::Vararg{Any,N}) where {F,N}
coduals = tuple_map(CoDual, (f, x...), tuple_map(set_to_zero!!, cache.tangents))
return __value_and_gradient!!(cache.rule, coduals...)
end
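And the analogous sketch for the cached gradient path, under the same assumptions:

```julia
# Usage sketch for the cached gradient (experimental interface from this diff).
g(x) = sum(abs2, x)
x = randn(8)
cache = Mooncake.prepare_gradient_cache(g, x)

v, (dg, dx) = Mooncake.value_and_gradient!!(cache, g, x)
# v == g(x), and dx holds the gradient 2x. As the docstring warns, a second
# call with this cache mutates dx in place.
```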
2 changes: 1 addition & 1 deletion src/tangents.jl
@@ -623,7 +623,7 @@ Set `x` to its zero element (`x` should be a tangent, so the zero must exist).
"""
set_to_zero!!(::NoTangent) = NoTangent()
set_to_zero!!(x::Base.IEEEFloat) = zero(x)
set_to_zero!!(x::Union{Tuple,NamedTuple}) = map(set_to_zero!!, x)
set_to_zero!!(x::Union{Tuple,NamedTuple}) = tuple_map(set_to_zero!!, x)
function set_to_zero!!(x::T) where {T<:PossiblyUninitTangent}
return is_init(x) ? T(set_to_zero!!(val(x))) : x
end
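For context on swapping `map` for `tuple_map`: a recursively-defined tuple map is guaranteed to unroll fully on heterogeneous tuples, which helps keep `set_to_zero!!` type-stable and allocation-free; `Base.map` usually achieves the same but does not guarantee it. Mooncake's actual `tuple_map` may differ in detail; the standard recursive pattern looks like this:

```julia
# Minimal sketch of a recursively-unrolled tuple map (Mooncake's may differ).
my_tuple_map(f::F, ::Tuple{}) where {F} = ()
my_tuple_map(f::F, x::Tuple) where {F} = (f(first(x)), my_tuple_map(f, Base.tail(x))...)

my_tuple_map(zero, (1.0, 2.0f0, 3))  # (0.0, 0.0f0, 0)
```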
42 changes: 42 additions & 0 deletions test/interface.jl
@@ -1,3 +1,9 @@
function count_allocs(fargs::P) where {P<:Tuple}
f, args... = fargs
f(args...) # warmup call, so compilation is not counted in the allocation measurement
return @allocations f(args...)
end

@testset "interface" begin
@testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[
(1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0),
@@ -28,6 +34,7 @@
((x, y) -> x + sin(y), randn(Float64), randn(Float64)),
((x, y) -> x + sin(y), randn(Float32), randn(Float32)),
((x...) -> x[1] + x[2], randn(Float64), randn(Float64)),
(sum, randn(10)),
]
rule = build_rrule(fargs...)
f, args... = fargs
@@ -36,12 +43,47 @@
for (arg, darg) in zip(fargs, dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end

cache = Mooncake.prepare_gradient_cache(fargs...)
_v, _dfargs = value_and_gradient!!(cache, fargs...)
@test _v == v
for (arg, darg) in zip(fargs, _dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end
@test count_allocs((value_and_gradient!!, cache, fargs...)) == 0
end

rule = build_rrule(identity, (5.0, 4.0))
@test_throws(
Mooncake.ValueAndGradientReturnTypeError,
value_and_gradient!!(rule, identity, (5.0, 4.0)),
)
@test_throws(
Mooncake.ValueAndGradientReturnTypeError,
Mooncake.prepare_gradient_cache(identity, (5.0, 4.0)),
)
end
@testset "value_and_pullback!!" begin
@testset "($(typeof(fargs))" for (ȳ, fargs...) in Any[
(randn(10), identity, randn(10)),
(randn(), sin, randn(Float64)),
(randn(), sum, randn(Float64)),
]
rule = build_rrule(fargs...)
f, args... = fargs
v, dfargs = value_and_pullback!!(rule, ȳ, fargs...)
@test v == f(args...)
for (arg, darg) in zip(fargs, dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end

cache = Mooncake.prepare_pullback_cache(fargs...)
_v, _dfargs = value_and_pullback!!(cache, ȳ, fargs...)
@test _v == v
for (arg, darg) in zip(fargs, _dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end
@test count_allocs((value_and_pullback!!, cache, ȳ, fargs...)) == 0
end
end
end

2 comments on commit d2d97a2

@willtebbutt (Member, Author)
@JuliaRegistrator register()

@JuliaRegistrator
Error while trying to register: Version 0.4.51 already exists
