Perf and Pre-Allocation #396

Merged 11 commits on Nov 28, 2024
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.50"
version = "0.4.51"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2 changes: 1 addition & 1 deletion src/fwds_rvs_data.jl
@@ -784,7 +784,7 @@ tangent_type(::Type{NoFData}, ::Type{R}) where {R<:IEEEFloat} = R
tangent_type(::Type{F}, ::Type{NoRData}) where {F<:Array} = F

# Tuples
function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
@generated function tangent_type(::Type{F}, ::Type{R}) where {F<:Tuple,R<:Tuple}
return Tuple{tuple_map(tangent_type, Tuple(F.parameters), Tuple(R.parameters))...}
end
function tangent_type(::Type{NoFData}, ::Type{R}) where {R<:Tuple}
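The only change to `fwds_rvs_data.jl` is the `@generated` annotation above. A minimal standalone sketch of why this helps, using made-up names (`my_tangent_type`, `my_tuple_tangent_type`) rather than Mooncake's internals: the element-wise type computation now runs at compile time, so the result is a constant and the call is type-stable.

```julia
# Assumed illustrative analogue of the @generated change above, not Mooncake code.
my_tangent_type(::Type{Float64}) = Float64  # stand-in for the leaf rule

@generated function my_tuple_tangent_type(::Type{F}) where {F<:Tuple}
    # Runs once per concrete F during compilation; the returned type becomes a
    # compile-time constant in the generated method body.
    return Tuple{map(my_tangent_type, (F.parameters...,))...}
end

@assert my_tuple_tangent_type(Tuple{Float64,Float64}) == Tuple{Float64,Float64}
```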
126 changes: 110 additions & 16 deletions src/interface.jl
@@ -1,5 +1,5 @@
"""
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...)
__value_and_pullback!!(rule, ȳ, f::CoDual, x::CoDual...; y_cache=nothing)

*Note:* this is not part of the public Mooncake.jl interface, and may change without warning.

@@ -8,13 +8,15 @@
if calling this function multiple times with different values of `x`, should be careful to
ensure that you zero-out the tangent fields of `x` each time.
"""
function __value_and_pullback!!(rule::R, ȳ::T, fx::Vararg{CoDual,N}) where {R,N,T}
function __value_and_pullback!!(
rule::R, ȳ::T, fx::Vararg{CoDual,N}; y_cache=nothing
) where {R,N,T}
fx_fwds = tuple_map(to_fwds, fx)
__verify_sig(rule, fx_fwds)
out, pb!! = rule(fx_fwds...)
@assert _typeof(tangent(out)) == fdata_type(T)
increment!!(tangent(out), fdata(ȳ))
v = copy(primal(out))
v = y_cache === nothing ? copy(primal(out)) : _copy!!(y_cache, primal(out))
return v, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(rdata(ȳ)))
end
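To make the new `y_cache` keyword concrete, here is a minimal sketch of the branch added above (standalone, with a made-up `maybe_copy` helper rather than Mooncake internals): when a cache is supplied, the primal output is written into it with `copy!` instead of being freshly copied, so repeated calls can reuse one output buffer.

```julia
# Hypothetical stand-in for the `y_cache === nothing ? copy(...) : _copy!!(...)` branch.
maybe_copy(y_cache, y) = y_cache === nothing ? copy(y) : copy!(y_cache, y)

y = randn(3)
buffer = zeros(3)
v1 = maybe_copy(nothing, y)  # allocates a fresh copy of the primal output
v2 = maybe_copy(buffer, y)   # writes into the pre-allocated buffer instead
@assert v2 === buffer
```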

@@ -38,6 +40,15 @@
msg::String
end

function throw_val_and_grad_ret_type_error(y)
throw(
ValueAndGradientReturnTypeError(
"When calling __value_and_gradient!!, return value of primal must be a " *
"subtype of IEEEFloat. Instead, found value of type $(typeof(y)).",
),
)
end

"""
__value_and_gradient!!(rule, f::CoDual, x::CoDual...)

@@ -72,17 +83,7 @@
__verify_sig(rule, fx_fwds)
out, pb!! = rule(fx_fwds...)
y = primal(out)
if !(y isa IEEEFloat)
throw(
ValueAndGradientReturnTypeError(
"When calling __value_and_gradient!!, return value of primal must be a " *
"subtype of IEEEFloat. Instead, found value of type $(typeof(y)).",
),
)
end
@assert y isa IEEEFloat
@assert tangent(out) isa NoFData

y isa IEEEFloat || throw_val_and_grad_ret_type_error(y)

return y, tuple_map((f, r) -> tangent(fdata(tangent(f)), r), fx, pb!!(one(y)))
end

@@ -126,8 +127,8 @@
`Float64`.

*Note:* There are lots of subtle ways to mis-use `value_and_pullback!!`, so we generally
recommend using [`value_and_gradient!!`](@ref) (this function) where possible. The docstring for
`value_and_pullback!!` is useful for understanding this function though.
recommend using `Mooncake.value_and_gradient!!` (this function) where possible. The
docstring for `value_and_pullback!!` is useful for understanding this function though.

An example:
```jldoctest
@@ -163,3 +164,96 @@
end
end
end

struct Cache{Trule,Ty_cache,Ttangents<:Tuple}
rule::Trule
y_cache::Ty_cache
tangents::Ttangents
end

_copy!!(dst, src) = copy!(dst, src)
_copy!!(::Number, src::Number) = src
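`Cache` simply bundles the compiled rule, an optional output buffer, and pre-allocated tangents. The `_copy!!` helper above picks between in-place and by-value copying; a small standalone sketch of its semantics, renamed to avoid clashing with the internal function:

```julia
# Assumed behaviour, mirroring the two methods above.
my_copy!!(dst, src) = copy!(dst, src)   # arrays etc.: write into the cache in place
my_copy!!(::Number, src::Number) = src  # numbers are immutable: just pass the value through

dst = zeros(2)
@assert my_copy!!(dst, [1.0, 2.0]) === dst  # same buffer is returned
@assert my_copy!!(0.0, 3.0) === 3.0
```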

"""
prepare_pullback_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a `cache` which can be passed to `value_and_pullback!!`. See the docstring for
`Mooncake.value_and_pullback!!` for more info.
"""
function prepare_pullback_cache(fx...; kwargs...)

# Take a copy before mutating.
fx = deepcopy(fx)

# Construct rule and tangents.
rule = build_rrule(get_interpreter(), Tuple{map(_typeof, fx)...}; kwargs...)
tangents = map(zero_tangent, fx)

# Run the rule forwards -- this should do a decent chunk of pre-allocation.
y, _ = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...)

# Construct cache for output. Check that `copy!`ing appears to work.
y_cache = copy(primal(y))
return Cache(rule, _copy!!(y_cache, primal(y)), tangents)
end

"""
value_and_pullback!!(cache::Cache, ȳ, f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Like other methods of `value_and_pullback!!`, but makes use of the `cache` object in order
to avoid having to re-allocate various tangent objects repeatedly.

You must ensure that `f` and `x` are the same types and sizes as those used to construct
`cache`.

Warning: any mutable components of values returned by `value_and_pullback!!` will be
mutated if you run this function again with different arguments. Therefore, if you need to
keep the values returned by one call around across later calls with the same `cache`, take
a copy of them before calling again.
"""
function value_and_pullback!!(cache::Cache, ȳ, f::F, x::Vararg{Any,N}) where {F,N}
tangents = tuple_map(set_to_zero!!, cache.tangents)
coduals = tuple_map(CoDual, (f, x...), tangents)
return __value_and_pullback!!(cache.rule, ȳ, coduals...; y_cache=cache.y_cache)
end
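As a usage sketch (not part of the PR; the function and inputs below are invented for illustration), the intended pattern is to build the cache once and then reuse it across calls:

```julia
using Mooncake

f(x) = sum(abs2, x)  # made-up example function
x = randn(10)

cache = Mooncake.prepare_pullback_cache(f, x)
for _ in 1:100
    # `1.0` seeds the reverse pass; the same cache (rule, tangents, output buffer)
    # is reused on every iteration, so no fresh tangent objects are allocated.
    v, (df, dx) = Mooncake.value_and_pullback!!(cache, 1.0, f, x)
end
```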

"""
prepare_gradient_cache(f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Returns a `cache` which can be passed to `value_and_gradient!!`. See the docstring for
`Mooncake.value_and_gradient!!` for more info.
"""
function prepare_gradient_cache(fx...; kwargs...)
rule = build_rrule(fx...; kwargs...)
tangents = map(zero_tangent, fx)
y, _ = rule(map((x, dx) -> CoDual(x, fdata(dx)), fx, tangents)...)
primal(y) isa IEEEFloat || throw_val_and_grad_ret_type_error(primal(y))
return Cache(rule, nothing, tangents)
end

"""
value_and_gradient!!(cache::Cache, f, x...)

WARNING: experimental functionality. Interface subject to change without warning!

Like other methods of `value_and_gradient!!`, but makes use of the `cache` object in order
to avoid having to re-allocate various tangent objects repeatedly.

You must ensure that `f` and `x` are the same types and sizes as those used to construct
`cache`.

Warning: any mutable components of values returned by `value_and_gradient!!` will be
mutated if you run this function again with different arguments. Therefore, if you need to
keep the values returned by one call around across later calls with the same `cache`, take
a copy of them before calling again.
"""
function value_and_gradient!!(cache::Cache, f::F, x::Vararg{Any,N}) where {F,N}
coduals = tuple_map(CoDual, (f, x...), tuple_map(set_to_zero!!, cache.tangents))
return __value_and_gradient!!(cache.rule, coduals...)
end
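A matching usage sketch for the gradient path (again with invented inputs, not from the PR): `prepare_gradient_cache` requires the primal to return an `IEEEFloat`, and subsequent calls with the same cache are intended to be allocation-free, mirroring the `count_allocs(...) == 0` checks in the tests below.

```julia
using Mooncake

g(x) = sum(sin, x)  # made-up example; must return an IEEEFloat
x = randn(5)

cache = Mooncake.prepare_gradient_cache(g, x)
v, (dg, dx) = Mooncake.value_and_gradient!!(cache, g, x)
# Remember: dx is owned by the cache and will be overwritten by the next call,
# so copy it if you need to keep it around.
```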
2 changes: 1 addition & 1 deletion src/tangents.jl
@@ -623,7 +623,7 @@ Set `x` to its zero element (`x` should be a tangent, so the zero must exist).
"""
set_to_zero!!(::NoTangent) = NoTangent()
set_to_zero!!(x::Base.IEEEFloat) = zero(x)
set_to_zero!!(x::Union{Tuple,NamedTuple}) = map(set_to_zero!!, x)
set_to_zero!!(x::Union{Tuple,NamedTuple}) = tuple_map(set_to_zero!!, x)
function set_to_zero!!(x::T) where {T<:PossiblyUninitTangent}
return is_init(x) ? T(set_to_zero!!(val(x))) : x
end
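The switch from `map` to `tuple_map` above is presumably about type stability on (possibly heterogeneous) tuple tangents; the observable behaviour, derivable from the methods shown, is unchanged:

```julia
using Mooncake

# Element-wise zeroing of a tuple tangent (behaviour inferred from the methods above).
Mooncake.set_to_zero!!((1.0, 2.0))  # (0.0, 0.0)
```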
42 changes: 42 additions & 0 deletions test/interface.jl
@@ -1,3 +1,9 @@
function count_allocs(fargs::P) where {P<:Tuple}
f, args... = fargs
f(args...) # warmup
return @allocations f(args...)
end
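For context, `count_allocs` is only a test helper; a hypothetical standalone use looks like this (the first call inside the helper warms up compilation, the second is the one measured by `@allocations`):

```julia
count_allocs((sin, 1.0))       # expected to be 0 for a non-allocating call
count_allocs((x -> [x], 1.0))  # allocates the returned array, so expected > 0
```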

@testset "interface" begin
@testset "$(typeof((f, x...)))" for (ȳ, f, x...) in Any[
(1.0, (x, y) -> x * y + sin(x) * cos(y), 5.0, 4.0),
@@ -28,6 +34,7 @@
((x, y) -> x + sin(y), randn(Float64), randn(Float64)),
((x, y) -> x + sin(y), randn(Float32), randn(Float32)),
((x...) -> x[1] + x[2], randn(Float64), randn(Float64)),
(sum, randn(10)),
]
rule = build_rrule(fargs...)
f, args... = fargs
@@ -36,12 +43,47 @@
for (arg, darg) in zip(fargs, dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end

cache = Mooncake.prepare_gradient_cache(fargs...)
_v, _dfargs = value_and_gradient!!(cache, fargs...)
@test _v == v
for (arg, darg) in zip(fargs, _dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end
@test count_allocs((value_and_gradient!!, cache, fargs...)) == 0
end

rule = build_rrule(identity, (5.0, 4.0))
@test_throws(
Mooncake.ValueAndGradientReturnTypeError,
value_and_gradient!!(rule, identity, (5.0, 4.0)),
)
@test_throws(
Mooncake.ValueAndGradientReturnTypeError,
Mooncake.prepare_gradient_cache(identity, (5.0, 4.0)),
)
end
@testset "value_and_pullback!!" begin
@testset "($(typeof(fargs))" for (ȳ, fargs...) in Any[
(randn(10), identity, randn(10)),
(randn(), sin, randn(Float64)),
(randn(), sum, randn(Float64)),
]
rule = build_rrule(fargs...)
f, args... = fargs
v, dfargs = value_and_pullback!!(rule, ȳ, fargs...)
@test v == f(args...)
for (arg, darg) in zip(fargs, dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end

cache = Mooncake.prepare_pullback_cache(fargs...)
_v, _dfargs = value_and_pullback!!(cache, ȳ, fargs...)
@test _v == v
for (arg, darg) in zip(fargs, _dfargs)
@test tangent_type(typeof(arg)) == typeof(darg)
end
@test count_allocs((value_and_pullback!!, cache, ȳ, fargs...)) == 0
end
end
end