Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overload NNlib.within_gradient #136

Merged
merged 4 commits into from
Jan 6, 2023
Merged

Conversation

mcabbott
Copy link
Member

@mcabbott mcabbott commented Jan 6, 2023

Uses FluxML/NNlib.jl#434

Does not solve #133 but the error is different:

julia> using Flux, Zygote, Tracker

julia> bn = BatchNorm(2); x = rand(Float32, 2, 3);

julia> sum(deepcopy(bn)(x))
2.1757803f0

julia> bn.active === nothing
true

julia> bn.active = false
false

julia> _bn = deepcopy(bn); Zygote.withgradient(sum_bn, x)[1]
2.1757803f0

julia> _bn = deepcopy(bn); Tracker.withgradient(sum_bn, x)[1]
2.1757803f0

julia> bn.active = true
true

julia> _bn = deepcopy(bn); Zygote.withgradient(sum_bn, x)[1]
-2.0489097f-7

julia> _bn = deepcopy(bn); Tracker.gradient(sum_bn, x)[1]
ERROR: MethodError: no method matching Float32(::Tracker.TrackedReal{Float32})
...
Stacktrace:
  [1] convert(#unused#::Type{Float32}, x::Tracker.TrackedReal{Float32})
    @ Base ./number.jl:7
  [2] setindex!
    @ ./array.jl:971 [inlined]
...
  [9] materialize!
    @ ./broadcast.jl:881 [inlined]
 [10] accum!(x::Matrix{Float32}, Δ::Matrix{Tracker.TrackedReal{Float32}})
    @ Tracker ~/.julia/packages/Tracker/a9oj5/src/back.jl:45
 [11] back(x::Tracker.Tracked{Matrix{Float32}}, Δ::Matrix{Tracker.TrackedReal{Float32}}, once::Bool)
    @ Tracker ~/.julia/packages/Tracker/a9oj5/src/back.jl:48
...
 [22] back(x::Tracker.Tracked{Matrix{Tracker.TrackedReal{Float32}}}, Δ::Matrix{Tracker.TrackedReal{Float32}}, once::Bool)

Does not make the example from FluxML/Flux.jl#2122 work, but perhaps the failure is unrelated:

julia> let
       using Flux
       using Random
       Random.seed!(123)

       model = Chain(
                 Conv((3, 3), 3 => 5, pad=1, bias=false), 
                 BatchNorm(5, relu), 
                 Conv((3, 3), 5 => 3, stride=16),
               )
       image = rand(Float32, 224, 224, 3, 1);
       @show sum(model(image));

       loss(m, x) = sum(m(x))

       opt = Flux.setup(Flux.Adam(0.001f0,  (0.9f0, 0.999f0), 1.1920929f-7), model)

       val, grads = Tracker.withgradient(model) do m
           loss(m, image)
       end

       Flux.update!(opt, model, grads[1])
       @show loss(model, image);
       end;
sum(model(image)) = -0.33076355f0
ERROR: UndefRefError: access to undefined reference
Stacktrace:
  [1] getindex
    @ ./essentials.jl:14 [inlined]
  [2] conv_direct!(y::Array{Tracker.TrackedReal{Float32}, 5}, x::Array{Tracker.TrackedReal{Float32}, 5}, w::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}, ::Val{(3, 3, 1)}, ::Val{3}, ::Val{(0, 0, 0, 0, 0, 0)}, ::Val{(1, 1, 1)}, ::Val{(16, 16, 1)}, fk::Val{false}; alpha::Tracker.TrackedReal{Float32}, beta::Bool)
    @ NNlib ~/.julia/packages/NNlib/QJIIj/src/impl/conv_direct.jl:111
  [3] kwcall(::NamedTuple{(:alpha, :beta), Tuple{Tracker.TrackedReal{Float32}, Bool}}, ::typeof(NNlib.conv_direct!), y::Array{Tracker.TrackedReal{Float32}, 5}, x::Array{Tracker.TrackedReal{Float32}, 5}, w::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}, ::Val{(3, 3, 1)}, ::Val{3}, ::Val{(0, 0, 0, 0, 0, 0)}, ::Val{(1, 1, 1)}, ::Val{(16, 16, 1)}, fk::Val{false})
    @ NNlib ~/.julia/packages/NNlib/QJIIj/src/impl/conv_direct.jl:59
  [4] conv_direct!(y::Array{Tracker.TrackedReal{Float32}, 5}, x::Array{Tracker.TrackedReal{Float32}, 5}, w::Array{Float32, 5}, cdims::DenseConvDims{3, 3, 3, 6, 3}; alpha::Tracker.TrackedReal{Float32}, beta::Bool)
    @ NNlib ~/.julia/packages/NNlib/QJIIj/src/impl/conv_direct.jl:50
  [5] conv_direct!
    @ ~/.julia/packages/NNlib/QJIIj/src/impl/conv_direct.jl:47 [inlined]
  [6] #conv!#301
    @ ~/.julia/packages/NNlib/QJIIj/src/conv.jl:288 [inlined]
  [7] conv!
    @ ~/.julia/packages/NNlib/QJIIj/src/conv.jl:280 [inlined]
  [8] conv!(y::Array{Tracker.TrackedReal{Float32}, 4}, x::Array{Tracker.TrackedReal{Float32}, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/QJIIj/src/conv.jl:145
  [9] conv!
    @ ~/.julia/packages/NNlib/QJIIj/src/conv.jl:140 [inlined]
 [10] conv(x::Array{Tracker.TrackedReal{Float32}, 4}, w::Array{Float32, 4}, cdims::DenseConvDims{2, 2, 2, 4, 2}; kwargs::Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}})
    @ NNlib ~/.julia/packages/NNlib/QJIIj/src/conv.jl:88
 [11] conv
    @ ~/.julia/packages/NNlib/QJIIj/src/conv.jl:83 [inlined]
 [12] #_forward#636
    @ ~/.julia/packages/Tracker/a9oj5/src/lib/array.jl:520 [inlined]
 [13] _forward
    @ ./none:0 [inlined]
 [14] #track#1
    @ ~/.julia/packages/Tracker/a9oj5/src/Tracker.jl:58 [inlined]
 [15] track
    @ ~/.julia/packages/Tracker/a9oj5/src/Tracker.jl:57 [inlined]
 [16] #conv#633
    @ ~/.julia/packages/Tracker/a9oj5/src/lib/array.jl:516 [inlined]
 [17] conv
    @ ~/.julia/packages/Tracker/a9oj5/src/lib/array.jl:516 [inlined]
 [18] (::Conv{2, 4, typeof(identity), TrackedArray{…,Array{Float32, 4}}, TrackedArray{…,Vector{Float32}}})(x::TrackedArray{…,Array{Tracker.TrackedReal{Float32}, 4}})
    @ Flux ~/.julia/dev/Flux/src/layers/conv.jl:200
 [19] macro expansion
    @ ~/.julia/dev/Flux/src/layers/basic.jl:53 [inlined]
 [20] _applychain(layers::Tuple{Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, Bool}, BatchNorm{typeof(relu), TrackedArray{,Vector{Float32}}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, TrackedArray{,Vector{Float32}}}}, x::Array{Float32, 4})
    @ Flux ~/.julia/dev/Flux/src/layers/basic.jl:53
 [21] Chain
    @ ~/.julia/dev/Flux/src/layers/basic.jl:51 [inlined]
 [22] (::var"#loss#31")(m::Chain{Tuple{Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, Bool}, BatchNorm{typeof(relu), TrackedArray{,Vector{Float32}}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, TrackedArray{,Vector{Float32}}}}}, x::Array{Float32, 4})
    @ Main ./REPL[28]:14
 [23] (::var"#30#32"{var"#loss#31", Array{Float32, 4}})(m::Chain{Tuple{Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, Bool}, BatchNorm{typeof(relu), TrackedArray{,Vector{Float32}}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), TrackedArray{,Array{Float32, 4}}, TrackedArray{,Vector{Float32}}}}})
    @ Main ./REPL[28]:19
 [24] withgradient(f::Function, xs::Chain{Tuple{Conv{2, 4, typeof(identity), Array{Float32, 4}, Bool}, BatchNorm{typeof(relu), Vector{Float32}, Float32, Vector{Float32}}, Conv{2, 4, typeof(identity), Array{Float32, 4}, Vector{Float32}}}})
    @ Tracker ~/.julia/packages/Tracker/a9oj5/src/back.jl:218
 [25] top-level scope
    @ REPL[28]:18

Edit: fixed, closes #133, closes #137

@coveralls
Copy link

coveralls commented Jan 6, 2023

Pull Request Test Coverage Report for Build 3852665337

  • 13 of 16 (81.25%) changed or added relevant lines in 2 files are covered.
  • 1 unchanged line in 1 file lost coverage.
  • Overall coverage increased (+0.09%) to 72.439%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/lib/real.jl 0 1 0.0%
src/lib/array.jl 13 15 86.67%
Files with Coverage Reduction New Missed Lines %
src/lib/array.jl 1 67.49%
Totals Coverage Status
Change from base Build 3010278444: 0.09%
Covered Lines: 502
Relevant Lines: 693

💛 - Coveralls

@ToucheSir
Copy link
Member

For the first error, guessing it needs something like https://github.com/FluxML/Flux.jl/blob/master/src/layers/normalise.jl#L278-L280.

For the second, this is FluxML/NNlib.jl#405. Although we probably ought to make the NNlib side more robust, the conv layer and conv functions shouldn't be seeing Array{<:TrackedReal} in the first place!

@mcabbott
Copy link
Member Author

mcabbott commented Jan 6, 2023

For the first, I think that after f098dd9 the error is not from within _track_stats!.

For the second, indeed the error is exactly that, thanks. But why an array of tracked numbers is made I don't know. A smaller example is this:

julia> Flux.trainmode!(BatchNorm(3))(param(rand(2,3,1)))
Tracked 2×3×1 Array{Tracker.TrackedReal{Float64}, 3}:
[:, :, 1] =
 -0.994696  -0.999959   0.912711
  0.994696   0.999959  -0.912711

julia> Flux.testmode!(BatchNorm(3))(param(rand(2,3,1)))
Tracked 2×3×1 Array{Float64, 3}:
[:, :, 1] =
 0.664405  0.684316  0.321995
 0.834291  0.463685  0.370812

... which is #137.

Copy link
Member

@ToucheSir ToucheSir left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Only remaining test failures on nightly are the known kron ones.

@mcabbott mcabbott merged commit 00393e2 into FluxML:master Jan 6, 2023
@mcabbott mcabbott deleted the within_grad branch January 6, 2023 05:40
@coveralls
Copy link

Pull Request Test Coverage Report for Build 3852093460

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 12 of 15 (80.0%) changed or added relevant lines in 2 files are covered.
  • 38 unchanged lines in 1 file lost coverage.
  • Overall coverage decreased (-0.03%) to 72.328%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/lib/real.jl 0 1 0.0%
src/lib/array.jl 12 14 85.71%
Files with Coverage Reduction New Missed Lines %
src/lib/array.jl 38 67.14%
Totals Coverage Status
Change from base Build 3010278444: -0.03%
Covered Lines: 494
Relevant Lines: 683

💛 - Coveralls

@coveralls
Copy link

coveralls commented Nov 5, 2024

Pull Request Test Coverage Report for Build 3852630410

Warning: This coverage report may be inaccurate.

This pull request's base commit is no longer the HEAD commit of its target branch. This means it includes changes from outside the original pull request, including, potentially, unrelated coverage changes.

Details

  • 12 of 14 (85.71%) changed or added relevant lines in 2 files are covered.
  • 30 unchanged lines in 1 file lost coverage.
  • Overall coverage increased (+0.2%) to 72.518%

Changes Missing Coverage Covered Lines Changed/Added Lines %
src/lib/array.jl 12 13 92.31%
src/lib/real.jl 0 1 0.0%
Files with Coverage Reduction New Missed Lines %
src/lib/array.jl 30 67.67%
Totals Coverage Status
Change from base Build 3010278444: 0.2%
Covered Lines: 504
Relevant Lines: 695

💛 - Coveralls

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Statistics.var makes Array{TrackedReal} Wrong forward results on some Flux layers
3 participants