Skip to content

Commit

Permalink
Merge pull request #169 from FluxML/zero
Browse files Browse the repository at this point in the history
Fix Base.zero type output
  • Loading branch information
ChrisRackauckas authored Aug 29, 2024
2 parents e384881 + a816377 commit fa1ca6d
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Tracker"
uuid = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
version = "0.2.34"
version = "0.2.35"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
13 changes: 7 additions & 6 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,15 +99,15 @@ Base.getindex(xs::TrackedArray, i...; kwargs...) = track(getindex, xs, i...; kwa

@grad function getindex(xs::AbstractArray, i...; kwargs...)
getindex(data(xs), i...; kwargs...), function (Δ)
Δ′ = zero(xs)
Δ′ = zero(data(xs))
setindex!(Δ′, data(Δ), i...; kwargs...)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
end

@grad function getindex(xs::AbstractArray, i::Array...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′ = zero(data(xs))
@views Δ′[i...] .+= data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
Expand All @@ -117,7 +117,7 @@ Base.view(x::TrackedArray, inds...; kwargs...) = track(Base.view, x, inds...; kw

@grad function view(x::AbstractArray, inds...; kwargs...)
view(data(x), inds...; kwargs...), function (Δ)
grad_output = zero(x)
grad_output = zero(data(x))
subgrad = view(grad_output, inds...; kwargs...)
subgrad[:] = data(Δ)
(nobacksies(:view, grad_output), map(_->nothing, inds)...)
Expand All @@ -144,10 +144,11 @@ logabsdet(xs::TrackedArray) = track(logabsdet, xs)
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)

Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
Base.zero(x::Tracker.TrackedArray) = TrackedArray(zero(x.data))

@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
Δ′ = zero(data(xs))
S = size(xs)

# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
Expand Down Expand Up @@ -433,7 +434,7 @@ Base.minimum(xs::TrackedArray; dims = :) = track(minimum, xs, dims = dims)

@grad function maximum(xs; dims = dims)
maximum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
Δ′ = zero(data(xs))
_, i = findmax(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),)
Expand All @@ -442,7 +443,7 @@ end

@grad function minimum(xs; dims = dims)
minimum(data(xs), dims = dims), function (Δ)
Δ′ = zero(xs)
Δ′ = zero(data(xs))
_, i = findmin(data(xs), dims = dims)
Δ′[i] = data(Δ)
return (nobacksies(:minimum, Δ′),)
Expand Down
5 changes: 5 additions & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ RNG = NNlib.Random.MersenneTwister(1)

end # @testset gradtests

@testset "zero" begin
@test zero(TrackedArray(rand(2))) isa TrackedArray
@test gradtest(x-> zero(x) .* x, (2,))
end

@testset "indexing & slicing" begin
@test gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end
Expand Down

0 comments on commit fa1ca6d

Please sign in to comment.