Further ChainRulesCore.rrule Integration #254

Merged Oct 8, 2024 (68 commits; changes shown from 61 commits)

Commits
cef527f
Bump patch version
willtebbutt Sep 12, 2024
b9c3f65
Fix usage with benchmarktools
willtebbutt Sep 12, 2024
1f49d85
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 13, 2024
8f0f75d
Initial pass
willtebbutt Sep 13, 2024
e791cef
Bump patch
willtebbutt Sep 13, 2024
f45456e
Unit test to_tapir_tangent and to_cr_tangent
willtebbutt Sep 13, 2024
bec9f06
Make use of macro
willtebbutt Sep 13, 2024
d037101
More testing and tidying up
willtebbutt Sep 13, 2024
54947f0
Add some basic type checking and a test
willtebbutt Sep 13, 2024
bc88483
Improve formatting and commenting
willtebbutt Sep 13, 2024
f29b8f3
Formatting
willtebbutt Sep 13, 2024
50d7dd8
Improve documentation
willtebbutt Sep 13, 2024
1788c07
Explain how not to use rrule functionality
willtebbutt Sep 13, 2024
b4e80bc
Add rules for BLAS utilities
willtebbutt Sep 13, 2024
4a2b8e0
Initial NNlib integration
willtebbutt Sep 13, 2024
d1d9fae
Thunks and batched_mul
willtebbutt Sep 13, 2024
6f036ad
More rules + kwargs + rename
willtebbutt Sep 13, 2024
e225a0a
Fix link in docs
willtebbutt Sep 13, 2024
3bba38e
Rename chain_rules_macro to chain_rules_interop
willtebbutt Sep 13, 2024
619f0ce
Complete rename of chain rules interop file
willtebbutt Sep 16, 2024
345c46a
Refactor chain rules interop
willtebbutt Sep 16, 2024
8e87d11
Add more nnlib functionality
willtebbutt Sep 16, 2024
d345978
Remove old tests
willtebbutt Sep 16, 2024
0f3fe90
Some work
willtebbutt Sep 16, 2024
ae93a27
Remove errant show statment
willtebbutt Sep 17, 2024
82ecd82
Remove redundant test
willtebbutt Sep 17, 2024
ca93535
Support where
willtebbutt Sep 17, 2024
fc6c00f
Make use of where params
willtebbutt Sep 17, 2024
473bc02
Improve kwarg interface
willtebbutt Sep 17, 2024
1cfbfcc
Default kwargs test
willtebbutt Sep 17, 2024
8ac2903
Improve docstring
willtebbutt Sep 17, 2024
f60ca36
Merge in main
willtebbutt Sep 19, 2024
ce5afd9
Some work
willtebbutt Sep 25, 2024
3539d46
Merge in main
willtebbutt Sep 29, 2024
8a80218
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 29, 2024
ccdef0b
Merge branch 'main' into wct/actually-improve-rrule-integration
willtebbutt Sep 29, 2024
6edc9a4
Some work
willtebbutt Sep 30, 2024
f66cc9c
Better conv support in nnlib rules
willtebbutt Oct 1, 2024
f865fde
More LuxLib rules
willtebbutt Oct 1, 2024
149e7b4
Permit :meta nodes in IR
willtebbutt Oct 1, 2024
2dcd535
Remove redundant test
willtebbutt Oct 1, 2024
0933f37
Uncomment some tests
willtebbutt Oct 1, 2024
d217102
Rename chain rules doc
willtebbutt Oct 1, 2024
c6f8cf0
Add notes to docs on rule writing strategies
willtebbutt Oct 1, 2024
d12afa4
Add mooncake_overlay
willtebbutt Oct 2, 2024
fe1999d
Add simpler method of build_rrule
willtebbutt Oct 2, 2024
e9dce9d
Fix dispatch problem
willtebbutt Oct 2, 2024
6149fd5
Tidy up
willtebbutt Oct 2, 2024
d386cab
Tidy up build_rrule calls
willtebbutt Oct 2, 2024
4cf73de
Improve zero_adjoint docs
willtebbutt Oct 2, 2024
6ffc076
Improve documentation of from_rrule
willtebbutt Oct 2, 2024
d2e0764
Fix formatting
willtebbutt Oct 3, 2024
36ec276
Explain what is new
willtebbutt Oct 3, 2024
ee838f4
Improve from_rrule documentation
willtebbutt Oct 3, 2024
98d4840
Formatting
willtebbutt Oct 3, 2024
45f1a38
Fix formatting
willtebbutt Oct 3, 2024
4b1fff1
Add compat for ChainRulesCore in docs
willtebbutt Oct 3, 2024
3a98c22
Tidy up mooncake_method_table usage
willtebbutt Oct 3, 2024
2d54e51
Add extra luxlib test
willtebbutt Oct 3, 2024
31daf96
Add another luxlib test
willtebbutt Oct 3, 2024
40c8f17
Bump patch version
willtebbutt Oct 3, 2024
672bf69
Update ext/MooncakeNNlibExt.jl
willtebbutt Oct 3, 2024
5ab5752
Merge in main
willtebbutt Oct 8, 2024
b9be673
Restrict CI to 1.10 for now
willtebbutt Oct 8, 2024
36b3cba
Apply suggestions from code review
willtebbutt Oct 8, 2024
1918431
Restrict version consistently
willtebbutt Oct 8, 2024
57fac48
Fix typo in docstring
willtebbutt Oct 8, 2024
62ba1f2
Shove all testing functionality inside module
willtebbutt Oct 8, 2024
13 changes: 11 additions & 2 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Mooncake"
uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.4.5"
version = "0.4.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -22,13 +22,17 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"

[extensions]
MooncakeCUDAExt = "CUDA"
MooncakeDynamicPPLExt = "DynamicPPL"
MooncakeJETExt = "JET"
MooncakeLogDensityProblemsADExt = "LogDensityProblemsAD"
MooncakeLuxLibExt = "LuxLib"
MooncakeNNlibExt = "NNlib"
MooncakeSpecialFunctionsExt = "SpecialFunctions"

[compat]
@@ -46,7 +50,9 @@ FillArrays = "1"
Graphs = "1"
JET = "0.9"
LogDensityProblemsAD = "1"
LuxLib = "1.2"
MistyClosures = "1"
NNlib = "0.9"
PDMats = "0.11"
Setfield = "1"
SpecialFunctions = "2"
@@ -66,11 +72,14 @@ FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
TemporalGPs = "e155a3c4-0841-43e1-8b83-a0e4f03cc18f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
test = ["AbstractGPs", "BenchmarkTools", "CUDA", "DiffTests", "Distributions", "Documenter", "DynamicPPL", "FillArrays", "KernelFunctions", "JET", "LogDensityProblemsAD", "Lux", "LuxLib", "NNlib", "PDMats", "SpecialFunctions", "StableRNGs", "Test", "TemporalGPs"]
2 changes: 2 additions & 0 deletions docs/Project.toml
@@ -1,9 +1,11 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
DocumenterCitations = "daee34ce-89f3-4625-b898-19384cb65244"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"

[compat]
ChainRulesCore = "1"
Documenter = "1"
Mooncake = "0.4.0"
7 changes: 5 additions & 2 deletions docs/make.jl
@@ -32,9 +32,12 @@ makedocs(
"Algorithmic Differentiation" => "algorithmic_differentiation.md",
"Mooncake.jl's Rule System" => "mathematical_interpretation.md",
],
"Utilities" => [
"Tools for Rules" => "tools_for_rules.md",
"Debug Mode" => "debug_mode.md",
"Debugging and MWEs" => "debugging_and_mwes.md",
],
"Known Limitations" => "known_limitations.md",
"Debug Mode" => "debug_mode.md",
"Debugging and MWEs" => "debugging_and_mwes.md",
]
)

5 changes: 5 additions & 0 deletions docs/src/index.md
@@ -2,6 +2,11 @@

Documentation for Mooncake.jl is on its way!

Note (03/10/2024): Various bits of utility functionality are now carefully documented. This
includes how to change the code which Mooncake sees, declare that the derivative of a
function is zero, make use of existing `ChainRules.rrule`s to quickly create new rules in
Mooncake, and more.

Note (02/07/2024): The first round of documentation has arrived.
This is largely targeted at those who are interested in contributing to Mooncake.jl -- you can find this work in the "Understanding Mooncake.jl" section of the docs.
There is more to do, but it should be sufficient to understand how AD works in principle, and the core abstractions underlying Mooncake.jl.
2 changes: 1 addition & 1 deletion docs/src/known_limitations.md
@@ -131,7 +131,7 @@ function foo(x::Vector{Float64})
return unsafe_load(p)
end

rule = build_rrule(get_interpreter(), Tuple{typeof(foo), Vector{Float64}})
rule = build_rrule(Tuple{typeof(foo), Vector{Float64}})
Mooncake.value_and_gradient!!(rule, foo, [5.0, 4.0])

# output
33 changes: 33 additions & 0 deletions docs/src/tools_for_rules.md
@@ -0,0 +1,33 @@
# Tools for Rules

Most of the time, Mooncake.jl can just differentiate your code, but you will need to intervene if you make use of a language feature which is unsupported.
However, this does not always necessitate writing your own `rrule!!` from scratch.
In this section, we detail some useful strategies which can help you avoid having to write `rrule!!`s in many situations.

## Simplifying Code via Overlays

```@docs
Mooncake.@mooncake_overlay
```
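
For example, a minimal sketch of an overlay might look like the following (the function `my_accelerated_sum` and its replacement body are hypothetical):
```julia
using Mooncake

# Hypothetical function whose original implementation uses something Mooncake cannot handle.
my_accelerated_sum(x::Vector{Float64}) = sum(x)

# Ask Mooncake to differentiate this simpler, numerically-equivalent body instead of the
# original method whenever it encounters a call to `my_accelerated_sum`.
Mooncake.@mooncake_overlay function my_accelerated_sum(x::Vector{Float64})
    s = 0.0
    for v in x
        s += v
    end
    return s
end
```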

## Functions with Zero Adjoint

If the above strategy does not work, but you find yourself in the surprisingly common
situation that the adjoint of the derivative of your function is always zero, you can very
straightforwardly write a rule by making use of the following:
```@docs
Mooncake.@zero_adjoint
Mooncake.zero_adjoint
```
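
For example, a minimal sketch of such a declaration might look like the following (the predicate `is_strictly_positive` is hypothetical; the pattern mirrors the declaration made for `DynamicPPL.istrans`):
```julia
using Mooncake

# Hypothetical boolean check -- its output does not depend smoothly on `x`, so its adjoint
# with respect to `x` is zero everywhere.
is_strictly_positive(x::Float64) = x > 0.0

# Declare that any call matching this signature has a zero adjoint, so Mooncake generates
# a trivial rule for it rather than differentiating the comparison.
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(is_strictly_positive), Float64}
```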

## Using ChainRules.jl

[ChainRules.jl](https://github.com/JuliaDiff/ChainRules.jl) provides a large number of rules for differentiating functions in reverse-mode.
These rules are methods of the `ChainRulesCore.rrule` function.
There are some instances where it is most convenient to implement a `Mooncake.rrule!!` by wrapping an existing `ChainRulesCore.rrule`.

There is enough similarity between these two systems that most of the boilerplate code can be avoided.

```@docs
Mooncake.@from_rrule
```
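
For example, a minimal sketch of the wrapping workflow might look like the following (the function `scale_by_two` and its `rrule` are hypothetical; the `@from_rrule` call follows the pattern used in the NNlib and LuxLib extensions):
```julia
using Mooncake, ChainRulesCore
using Mooncake: @from_rrule, DefaultCtx

# Hypothetical primal for which a ChainRulesCore.rrule already exists.
scale_by_two(x::Vector{Float64}) = 2.0 .* x

function ChainRulesCore.rrule(::typeof(scale_by_two), x::Vector{Float64})
    scale_by_two_pullback(dy) = (ChainRulesCore.NoTangent(), 2.0 .* ChainRulesCore.unthunk(dy))
    return scale_by_two(x), scale_by_two_pullback
end

# Generate a Mooncake rule for this signature by wrapping the rrule above.
@from_rrule(DefaultCtx, Tuple{typeof(scale_by_two), Vector{Float64}})
```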
14 changes: 3 additions & 11 deletions ext/MooncakeDynamicPPLExt.jl
@@ -1,17 +1,9 @@
module MooncakeDynamicPPLExt

if isdefined(Base, :get_extension)
using DynamicPPL: DynamicPPL, istrans
using Mooncake: Mooncake
else
using ..DynamicPPL: DynamicPPL, istrans
using ..Mooncake: Mooncake
end

using Mooncake: DefaultCtx, CoDual, simple_zero_adjoint
using DynamicPPL: DynamicPPL, istrans
using Mooncake: Mooncake

# This is purely an optimisation.
Mooncake.@is_primitive DefaultCtx Tuple{typeof(istrans), Vararg}
Mooncake.rrule!!(f::CoDual{typeof(istrans)}, x::CoDual...) = simple_zero_adjoint(f, x...)
Mooncake.@zero_adjoint Mooncake.DefaultCtx Tuple{typeof(istrans), Vararg}

end # module
173 changes: 173 additions & 0 deletions ext/MooncakeLuxLibExt.jl
@@ -0,0 +1,173 @@
module MooncakeLuxLibExt

using LuxLib, Random, Mooncake
using Base: IEEEFloat

import LuxLib: Impl
import LuxLib.Utils: static_training_mode_check
import Mooncake:
@from_rrule,
DefaultCtx,
@mooncake_overlay,
CoDual

@from_rrule(DefaultCtx, Tuple{typeof(Impl.matmul), Array{P}, Array{P}} where {P<:IEEEFloat})
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.matmuladd), Array{P}, Array{P}, Vector{P}} where {P<:IEEEFloat},
)
@from_rrule(
DefaultCtx,
Tuple{typeof(Impl.batched_matmul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
)

# Re-implement a bunch of methods to ensure that Mooncake can differentiate them.
@mooncake_overlay function LuxLib.Impl.fused_dense(
opmode,
act::F,
weight::AbstractMatrix,
x::AbstractMatrix,
b::LuxLib.Optional{<:AbstractVector},
) where {F}
return bias_activation(act, Impl.matmul(weight, x), b)
end

@mooncake_overlay function LuxLib.Impl.bias_activation_loop!(
y::AbstractArray{yT, 3}, σ::F, x::AbstractArray{xT, 3}, bias::AbstractVector
) where {F, xT, yT}
return LuxLib.Impl.bias_activation_simd_loop!(y, σ, x, bias)
end

@mooncake_overlay function LuxLib.Impl.activation_loop!(
y::AbstractArray, σ::F, x::AbstractArray
) where {F}
return LuxLib.Impl.activation_simd_loop!(y, σ, x)
end

@mooncake_overlay function LuxLib.Impl.fused_conv(
::LuxLib.Impl.AbstractInternalArrayOpMode,
act::F,
weight::AbstractArray{wT, N},
x::AbstractArray{xT, N},
bias::LuxLib.Optional{<:AbstractVector},
cdims::LuxLib.Impl.ConvDims,
) where {F, wT, xT, N}
return LuxLib.Impl.bias_activation(act, LuxLib.Impl.conv(x, weight, cdims), bias)
end

for f in [
Impl.SLEEFActivations.sigmoid_fast,
Impl.SLEEFActivations.softplus,
Impl.SLEEFActivations.logsigmoid,
Impl.SLEEFActivations.swish,
Impl.SLEEFActivations.lisht,
Impl.SLEEFActivations.tanh,
Impl.SLEEFActivations.tanh_fast,
]
@from_rrule DefaultCtx Tuple{typeof(f), IEEEFloat}
@from_rrule(
DefaultCtx,
Tuple{typeof(Broadcast.broadcasted), typeof(f), Union{IEEEFloat, Array{<:IEEEFloat}}},
)
end

Mooncake.@zero_adjoint DefaultCtx Tuple{typeof(static_training_mode_check), Vararg}

# This is a really horrible hack that we need to do until Mooncake is able to support the
# call-back-into-ad interface that ChainRules exposes.

import LuxLib.Impl:
safe_eltype,
batchnorm_affine_normalize_internal,
batchnorm_affine_normalize_internal!,
∇batchnorm_affine_normalize,
AbstractInternalArrayOpMode

import ChainRulesCore as CRC

function CRC.rrule(
::typeof(batchnorm_affine_normalize_internal),
opmode::AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{T, N},
μ::AbstractVector,
σ²::AbstractVector,
γ::LuxLib.Optional{<:AbstractVector},
β::LuxLib.Optional{<:AbstractVector},
ϵ::Real,
) where {T, N}
y = similar(
x,
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
)
γ′ = similar(
x, promote_type(safe_eltype(γ), safe_eltype(σ²), safe_eltype(ϵ)), size(x, N - 1)
)

batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ, γ′)

𝒫x, 𝒫μ, 𝒫σ² = CRC.ProjectTo(x), CRC.ProjectTo(μ), CRC.ProjectTo(σ²)
𝒫γ = γ === nothing ? identity : CRC.ProjectTo(γ)
𝒫β = β === nothing ? identity : CRC.ProjectTo(β)

∇batchnorm_affine_normalize_internal = LuxLib.Impl.@closure Δ -> begin
∂x, ∂μ, ∂σ², ∂γ, ∂β = ∇batchnorm_affine_normalize(opmode, Δ, x, μ, σ², γ, β, ϵ, γ′)
∂∅ = CRC.NoTangent()
return ∂∅, ∂∅, ∂∅, 𝒫x(∂x), 𝒫μ(∂μ), 𝒫σ²(∂σ²), 𝒫γ(∂γ), 𝒫β(∂β), ∂∅
end

return y, ∇batchnorm_affine_normalize_internal
end

@from_rrule(
DefaultCtx,
Tuple{
typeof(batchnorm_affine_normalize_internal),
AbstractInternalArrayOpMode,
typeof(identity),
AbstractArray,
AbstractVector,
AbstractVector,
LuxLib.Optional{<:AbstractVector},
LuxLib.Optional{<:AbstractVector},
Real,
},
)

@mooncake_overlay function batchnorm_affine_normalize_internal(
opmode::LuxLib.AbstractInternalArrayOpMode,
act::F,
x::AbstractArray{xT, 3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
ϵ::Real,
) where {F, xT}
y = batchnorm_affine_normalize_internal(opmode, identity, x, μ, σ², γ, β, ϵ)
LuxLib.Impl.activation!(y, opmode, act, y)
return y
end

@mooncake_overlay function batchnorm_affine_normalize_internal(
opmode::LuxLib.AbstractInternalArrayOpMode,
::typeof(identity),
x::AbstractArray{xT, 3},
μ::AbstractVector,
σ²::AbstractVector,
γ::Union{Nothing, AbstractVector},
β::Union{Nothing, AbstractVector},
ϵ::Real,
) where {xT}
y = similar(x,
promote_type(
safe_eltype(x), safe_eltype(μ), safe_eltype(σ²), safe_eltype(γ), safe_eltype(β)
)
)
batchnorm_affine_normalize_internal!(y, opmode, identity, x, μ, σ², γ, β, ϵ)
return y
end

end
66 changes: 66 additions & 0 deletions ext/MooncakeNNlibExt.jl
@@ -0,0 +1,66 @@
module MooncakeNNlibExt

using NNlib, Random, Mooncake
using Base: IEEEFloat
using NNlib: dropout

using NNlib: conv, depthwiseconv
import Mooncake: @from_rrule, DefaultCtx, MinimalCtx

@from_rrule(
MinimalCtx,
Tuple{typeof(batched_mul), Array{P, 3}, Array{P, 3}} where {P<:IEEEFloat},
)
@from_rrule(
MinimalCtx,
Tuple{typeof(dropout), AbstractRNG, Array{P}, P} where {P<:IEEEFloat},
true,
)
@from_rrule(MinimalCtx, Tuple{typeof(softmax), Array{<:IEEEFloat}}, true)
@from_rrule(MinimalCtx, Tuple{typeof(logsoftmax), Array{<:IEEEFloat}}, true)
@from_rrule(MinimalCtx, Tuple{typeof(logsumexp), Array{<:IEEEFloat}}, true)
@from_rrule(
MinimalCtx,
Tuple{typeof(upsample_nearest), Array{<:IEEEFloat}, NTuple{N, Int} where {N}},
)
@from_rrule(
MinimalCtx,
Tuple{
typeof(NNlib.fold), Array{<:IEEEFloat}, NTuple{N, Int} where {N}, DenseConvDims,
},
)
@from_rrule(
MinimalCtx, Tuple{typeof(NNlib.unfold), Array{<:IEEEFloat}, DenseConvDims}
)
@from_rrule(
MinimalCtx,
Tuple{typeof(NNlib.scatter), Any, Array, Array{<:Union{Integer, Tuple}}},
true,
)
for conv in [:conv, :depthwiseconv]
local ∇conv_data, ∇conv_filter = Symbol.(:∇, conv, [:_data, :_filter])

@eval @from_rrule(
MinimalCtx,
Tuple{typeof($conv), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
true,
)
@eval @from_rrule(
MinimalCtx,
Tuple{typeof($∇conv_data), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
true,
)
end
@eval @from_rrule(
MinimalCtx,
Tuple{typeof(∇conv_filter), Array{P}, Array{P}, ConvDims} where {P<:IEEEFloat},
true,
)
for pool in [:maxpool, :meanpool]
@eval @from_rrule(
MinimalCtx, Tuple{typeof($pool), Array{<:IEEEFloat}, PoolDims}, true
)
end
@from_rrule(MinimalCtx, Tuple{typeof(pad_constant), Array, Any, Any}, true)

end