Overload NNlib.within_gradient #136

Merged: 4 commits, Jan 6, 2023
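In brief, this PR overloads NNlib.within_gradient to return true for Tracker's TrackedArray and TrackedReal, adds a gradient rule for Statistics.var (moving the mean/maximum/minimum overloads alongside it), defines ForwardDiff.value for TrackedReal, and raises the NNlib compat bound to 0.8.14, presumably the first release to provide within_gradient.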
4 changes: 2 additions & 2 deletions Project.toml
@@ -1,6 +1,6 @@
 name = "Tracker"
 uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-version = "0.2.22"
+version = "0.2.23"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -26,7 +26,7 @@ Functors = "0.3.0"
 ForwardDiff = "0.10"
 LogExpFunctions = "0.3"
 MacroTools = "0.5"
-NNlib = "0.8"
+NNlib = "0.8.14"
 NaNMath = "1"
 Optimisers = "0.2.9"
 Requires = "1.0"
1 change: 1 addition & 0 deletions src/Tracker.jl
@@ -4,6 +4,7 @@ using MacroTools
 using MacroTools: @q, @forward
 
 using DiffRules
+using ForwardDiff
 import LogExpFunctions
 import NaNMath
 import SpecialFunctions
30 changes: 24 additions & 6 deletions src/lib/array.jl
@@ -372,11 +372,6 @@ end
 
 Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
 
-Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
-
-Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
-Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
-
 import LinearAlgebra: dot
 
 dot(xs::TrackedArray, ys::TrackedArray) = track(dot, xs, ys)
@@ -390,13 +385,33 @@ Statistics.std(x::TrackedArray; dims = :, mean = Statistics.mean(x, dims = dims)
 _std(x::TrackedArray, mean, dims, corrected) = sqrt.(sum((x .- mean).^2, dims = dims) ./ (mapreduce(i -> size(x,i),*, dims) - corrected))
 _std(x::TrackedArray, mean, ::Colon, corrected) = sqrt.(sum((x .- mean).^2) ./ (length(x) - corrected))
 
+Statistics.var(x::TrackedArray; dims=:, mean=Statistics.mean(data(x); dims), corrected::Bool=true) =
+  track(Statistics.var, x; dims, mean=data(mean), corrected)
+# from https://github.com/JuliaDiff/ChainRules.jl/blob/main/src/rulesets/Statistics/statistics.jl
+@grad function Statistics.var(x; dims, mean, corrected)
+  y = Statistics.var(data(x); corrected, mean, dims)
+  function variance_back(dy)
+    dx = 2 .* dy .* (data(x) .- mean) ./ (_denom(x, dims) - corrected)
+    (dx,)
+  end
+  y, variance_back
+end
+_denom(x, dims::Integer) = size(x, dims)
+_denom(x, dims::Colon) = length(x)
+_denom(x, dims) = prod(i -> size(x, i), unique(dims), init=1)
+
 LinearAlgebra.norm(x::TrackedArray{T}, p::Real = 2) where T =
   (sum(abs.(x).^p) + eps(T))^(oneunit(T) / p) # avoid d(sqrt(x))/dx == Inf at 0
 
+Statistics.mean(xs::TrackedArray; dims = :) = track(mean, xs, dims = dims)
+
 @grad mean(xs; dims = :) = mean(data(xs), dims=dims), Δ -> (_backmean(xs,Δ,dims),)
 _backmean(xs, Δ, ::Colon) = zero(xs) .+ Δ ./ length(xs)
 _backmean(xs, Δ, dims) = zero(xs) .+ Δ ./ mapreduce(i -> size(data(xs),i),*,dims)
 
+Base.maximum(xs::TrackedArray; dims = :) = track(maximum, xs, dims = dims)
+Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)
+
 @grad function maximum(xs; dims = dims)
   maximum(data(xs), dims = dims), function (Δ)
     Δ′ = zero(xs)
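A note on the rule above: the pullback implements the textbook derivative ∂var(x)/∂xᵢ = 2(xᵢ - mean)/(N - corrected), with N = _denom(x, dims). It treats mean as a constant, which gives the same answer as differentiating through the mean as well, because that extra term is proportional to the sum of (xᵢ - mean) and vanishes whenever mean really is the mean of x. A quick numerical check (a sketch, assuming this branch so the new rule is active):

    using Tracker, Statistics

    x = rand(5)
    g = Tracker.gradient(v -> var(v), x)[1]     # corrected variance, so N - 1
    g ≈ 2 .* (x .- mean(x)) ./ (length(x) - 1)  # true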
@@ -472,7 +487,10 @@ Base.:*(x::TrackedMatrix, y::Diagonal) = track(*, x, y)
 
 using NNlib
 import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, ∇conv_data, depthwiseconv, maxpool, meanpool
-import NNlib: DenseConvDims, DepthwiseConvDims, PoolDims
+import NNlib: DenseConvDims, DepthwiseConvDims, PoolDims, within_gradient
+
+within_gradient(::TrackedArray) = true
+within_gradient(::TrackedReal) = true
 
 softmax(xs::TrackedArray; dims=1) = track(softmax, xs; dims=dims)
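These two methods are the point of the PR: NNlib.within_gradient(x) returns false by default, and downstream code (for example Flux's dropout and normalisation layers) branches on it to choose training versus inference behaviour. With the overloads above, the training branch is taken whenever the argument is a tracked value. A minimal sketch of the effect, mirroring the new test below:

    using Tracker, NNlib

    f(x) = NNlib.within_gradient(x) ? 10x : x

    f(1.0)                       # 1.0: a plain call takes the false branch
    Tracker.gradient(f, 1.0)[1]  # 10.0: inside Tracker the argument is tracked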
2 changes: 2 additions & 0 deletions src/lib/real.jl
@@ -8,6 +8,8 @@ TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
 data(x::TrackedReal) = x.data
 tracker(x::TrackedReal) = x.tracker
 
+ForwardDiff.value(x::TrackedReal) = x.data
+
 track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
 
 function back!(x::TrackedReal; once = true)
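ForwardDiff.value is what ForwardDiff's API uses to strip a Dual down to its plain value; without this method the generic fallback value(x::Real) = x would hand back the TrackedReal itself. Defining it here (hence the new `using ForwardDiff` in src/Tracker.jl above) lets generic code written against that API see through Tracker's scalar wrapper too. A quick check, assuming `param` to build a tracked scalar:

    using Tracker, ForwardDiff

    x = Tracker.param(3.0)   # a TrackedReal{Float64}
    ForwardDiff.value(x)     # 3.0, the wrapped data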
7 changes: 7 additions & 0 deletions test/runtests.jl
@@ -33,4 +33,11 @@ Optimisers.trainable(x::TwoThirds) = (a = x.a,)
   @test g == (val = 6.0, grad = ((a = [1.0], b = nothing, c = nothing),))
 end
 
+using NNlib
+@testset "NNlib.within_gradient" begin
+  f_good(x) = NNlib.within_gradient(x) ? 10x : x
+  @test gradient(f_good, 1.0)[1] == 10
+  @test gradient(x -> sum(f_good(x)), [1.0])[1] == [10]
+end
+
 end # overall @testset
10 changes: 8 additions & 2 deletions test/tracker.jl
@@ -4,7 +4,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv
 using PDMats
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet, I, Diagonal
-using Statistics: mean, std
+using Statistics: mean, std, var
 using Random
 # using StatsBase
 
@@ -137,7 +137,7 @@ end
   @test hcat(1, param([1 2 3;])) isa TrackedArray
   @test vcat(param(1), 2) isa TrackedArray
 end
 
 @testset "ambiguities" begin
   @test vcat(param([1, 2, 3]), [2,3]) isa TrackedArray
   @test vcat(param([1, 2, 3]), [2.0, 3.0]) isa TrackedArray
@@ -230,6 +230,12 @@ end
 @test gradtest(x -> std(x, dims = 1), rand(5,5))
 @test gradtest(x -> std(x, dims = 1, corrected = false), rand(5,5))
 
+@test gradtest(x -> var(x), rand(5,5))
+@test gradtest(x -> var(x, dims = 1), rand(5,5))
+@test gradtest(x -> var(x, dims = 1, corrected = false), rand(5,5))
+x55 = rand(5,5)
+@test gradtest(x -> var(x, dims = 2, mean = mean(x55; dims=2)), x55)
+
 @test gradtest((x, y) -> x .* y, rand(5), rand(5))
 @test gradtest(dot, rand(5), rand(5))
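gradtest (defined elsewhere in the test suite) presumably checks Tracker's gradient for a function against a reference gradient. A hand-rolled version of that check for one of the new var cases, using ForwardDiff as the reference (a sketch, not the suite's actual helper):

    using Tracker, ForwardDiff, Statistics

    f(x) = sum(var(x, dims = 1))
    x55 = rand(5, 5)
    Tracker.gradient(f, x55)[1] ≈ ForwardDiff.gradient(f, x55)  # true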