diff --git a/Project.toml b/Project.toml index d9f5d9634..63df80fb0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Mooncake" uuid = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.4.55" +version = "0.4.56" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index cd2d96466..02f9f153b 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -21,7 +21,9 @@ function tri!(A, u::Char, d::Char) return u == 'L' ? tril!(A, d == 'U' ? -1 : 0) : triu!(A, d == 'U' ? 1 : 0) end -const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,Matrix{T}}} +const MatrixOrView{T} = Union{Matrix{T},SubArray{T,2,<:Array{T}}} +const VecOrView{T} = Union{Vector{T},SubArray{T,1,<:Array{T}}} +const BlasRealFloat = Union{Float32,Float64} # # Utility @@ -132,82 +134,67 @@ end # LEVEL 2 # -for (gemv, elty) in ((:dgemv_, :Float64), (:sgemv_, :Float32)) - @eval @inline function rrule!!( - ::CoDual{typeof(_foreigncall_)}, - ::CoDual{Val{$(blas_name(gemv))}}, - ::CoDual, - ::CoDual, - ::CoDual, - ::CoDual, - _tA::CoDual{Ptr{UInt8}}, - _M::CoDual{Ptr{BLAS.BlasInt}}, - _N::CoDual{Ptr{BLAS.BlasInt}}, - _alpha::CoDual{Ptr{$elty}}, - _A::CoDual{Ptr{$elty}}, - _lda::CoDual{Ptr{BLAS.BlasInt}}, - _x::CoDual{Ptr{$elty}}, - _incx::CoDual{Ptr{BLAS.BlasInt}}, - _beta::CoDual{Ptr{$elty}}, - _y::CoDual{Ptr{$elty}}, - _incy::CoDual{Ptr{BLAS.BlasInt}}, - args::Vararg{Any,Nargs}, - ) where {Nargs} - GC.@preserve args begin +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.gemv!),Char,P,MatrixOrView{P},VecOrView{P},P,VecOrView{P} + } where {P<:BlasRealFloat}, +) - # Load in data. - tA = Char(unsafe_load(primal(_tA))) - M, N, lda, incx, incy = map(unsafe_load ∘ primal, (_M, _N, _lda, _incx, _incy)) - alpha = unsafe_load(primal(_alpha)) - beta = unsafe_load(primal(_beta)) +@inline function rrule!!( + ::CoDual{typeof(BLAS.gemv!)}, + _tA::CoDual{Char}, + _alpha::CoDual{P}, + _A::CoDual{<:MatrixOrView{P}}, + _x::CoDual{<:VecOrView{P}}, + _beta::CoDual{P}, + _y::CoDual{<:VecOrView{P}}, +) where {P<:BlasRealFloat} + + # Pull out primals and tangents (the latter only where necessary). + trans = _tA.x + alpha = _alpha.x + A, dA = viewify(_A) + x, dx = viewify(_x) + beta = _beta.x + y, dy = viewify(_y) + + # Take copies before adding. + y_copy = copy(y) - # Run primal. - A = wrap_ptr_as_view(primal(_A), lda, M, N) - Nx = tA == 'N' ? N : M - Ny = tA == 'N' ? M : N - x = view(unsafe_wrap(Vector{$elty}, primal(_x), incx * Nx), 1:incx:(incx * Nx)) - y = view(unsafe_wrap(Vector{$elty}, primal(_y), incy * Ny), 1:incy:(incy * Ny)) - y_copy = copy(y) + # Run primal. + BLAS.gemv!(trans, alpha, A, x, beta, y) - BLAS.gemv!(tA, alpha, A, x, beta, y) + function gemv!_pb!!(::NoRData) - dalpha = tangent(_alpha) - dbeta = tangent(_beta) - _dA = tangent(_A) - _dx = tangent(_x) - _dy = tangent(_y) + # Increment fdata. + if trans == 'N' + dalpha = dot(dy, A, x) + dA .+= alpha .* dy .* x' + BLAS.gemv!('T', alpha, A, dy, one(eltype(A)), dx) + else + dalpha = dot(dy, A', x) + dA .+= alpha .* x .* dy' + BLAS.gemv!('N', alpha, A, dy, one(eltype(A)), dx) end + dbeta = dot(y_copy, dy) + dy .*= beta - function gemv_pb!!(::NoRData) - GC.@preserve args begin - - # Load up the tangents. - dA = wrap_ptr_as_view(_dA, lda, M, N) - dx = view(unsafe_wrap(Vector{$elty}, _dx, incx * Nx), 1:incx:(incx * Nx)) - dy = view(unsafe_wrap(Vector{$elty}, _dy, incy * Ny), 1:incy:(incy * Ny)) - - # Increment the tangents. - unsafe_store!(dalpha, unsafe_load(dalpha) + dot(dy, _trans(tA, A), x)) - dA .+= _trans(tA, alpha * dy * x') - dx .+= alpha * _trans(tA, A)'dy - unsafe_store!(dbeta, unsafe_load(dbeta) + dot(y_copy, dy)) - dy .*= beta - - # Restore the original value of `y`. - y .= y_copy - end + # Restore primal. + copyto!(y, y_copy) - return tuple_fill(NoRData(), Val(17 + Nargs)) - end - return zero_fcodual(Cvoid()), gemv_pb!! + # Return rdata. + return NoRData(), NoRData(), dalpha, NoRData(), NoRData(), dbeta, NoRData() end + + return _y, gemv!_pb!! end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.symv!),Char,T,MatrixOrView{T},Vector{T},T,Vector{T} - } where {T<:Union{Float32,Float64}}, + typeof(BLAS.symv!),Char,T,MatrixOrView{T},VecOrView{T},T,VecOrView{T} + } where {T<:BlasRealFloat}, ) function rrule!!( @@ -215,10 +202,10 @@ function rrule!!( uplo::CoDual{Char}, alpha::CoDual{T}, A_dA::CoDual{<:MatrixOrView{T}}, - x_dx::CoDual{Vector{T}}, + x_dx::CoDual{<:VecOrView{T}}, beta::CoDual{T}, - y_dy::CoDual{Vector{T}}, -) where {T<:Union{Float32,Float64}} + y_dy::CoDual{<:VecOrView{T}}, +) where {T<:BlasRealFloat} # Extract primals. ul = primal(uplo) @@ -342,8 +329,8 @@ end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.gemm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,Matrix{T} - } where {T<:Union{Float32,Float64}}, + typeof(BLAS.gemm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,MatrixOrView{T} + } where {T<:BlasRealFloat}, ) function rrule!!( @@ -354,8 +341,8 @@ function rrule!!( A::CoDual{<:MatrixOrView{T}}, B::CoDual{<:MatrixOrView{T}}, beta::CoDual{T}, - C::CoDual{Matrix{T}}, -) where {T<:Union{Float32,Float64}} + C::CoDual{<:MatrixOrView{T}}, +) where {T<:BlasRealFloat} tA = primal(transA) tB = primal(transB) a = primal(alpha) @@ -374,7 +361,7 @@ function rrule!!( else tmp = BLAS.gemm(tA, tB, one(T), p_A, p_B) tmp_ref[] = tmp - BLAS.axpby!(a, tmp, b, p_C) + p_C .= a .* tmp .+ b .* p_C end function gemm!_pb!!(::NoRData) @@ -386,7 +373,7 @@ function rrule!!( BLAS.copyto!(p_C, p_C_copy) # Compute pullback w.r.t. beta. - db = BLAS.dot(dC, p_C) + db = dot(dC, p_C) # Increment cotangents. if tA == 'N' @@ -399,7 +386,7 @@ function rrule!!( else BLAS.gemm!('T', tA == 'N' ? 'N' : 'T', a, dC, p_A, one(T), dB) end - BLAS.scal!(b, dC) + dC .*= b return NoRData(), NoRData(), NoRData(), da, NoRData(), NoRData(), db, NoRData() end @@ -499,8 +486,8 @@ end @is_primitive( MinimalCtx, Tuple{ - typeof(BLAS.symm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,Matrix{T} - } where {T<:Union{Float32,Float64}}, + typeof(BLAS.symm!),Char,Char,T,MatrixOrView{T},MatrixOrView{T},T,MatrixOrView{T} + } where {T<:BlasRealFloat}, ) function rrule!!( @@ -511,8 +498,8 @@ function rrule!!( A_dA::CoDual{<:MatrixOrView{T}}, B_dB::CoDual{<:MatrixOrView{T}}, beta::CoDual{T}, - C_dC::CoDual{Matrix{T}}, -) where {T<:Union{Float32,Float64}} + C_dC::CoDual{<:MatrixOrView{T}}, +) where {T<:BlasRealFloat} # Extract primals. s = primal(side) @@ -533,7 +520,7 @@ function rrule!!( else tmp = BLAS.symm(s, ul, one(T), A, B) tmp_ref[] = tmp - BLAS.axpby!(α, tmp, β, C) + C .= α .* tmp .+ β .* C end function symm!_adjoint(::NoRData) @@ -569,7 +556,7 @@ function rrule!!( dβ = dot(dC, C) # gradient w.r.t. C. - BLAS.scal!(β, dC) + dC .*= β return NoRData(), NoRData(), NoRData(), dα, NoRData(), NoRData(), dβ, NoRData() end @@ -801,29 +788,66 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32)) end end +function blas_matrices(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int, q::Int) + Xs = Any[ + randn(rng, P, p, q), + view(randn(rng, P, p + 5, 2q), 3:(p + 2), 1:2:(2q)), + view(randn(rng, P, 3p, 3, 2q), (p + 1):(2p), 2, 1:2:(2q)), + ] + @assert all(X -> size(X) == (p, q), Xs) + @assert all(Base.Fix2(isa, AbstractMatrix{P}), Xs) + return Xs +end + +function blas_vectors(rng::AbstractRNG, P::Type{<:BlasRealFloat}, p::Int) + xs = Any[ + randn(rng, P, p), + view(randn(rng, P, p + 5), 3:(p + 2)), + view(randn(rng, P, 3p), 1:2:(2p)), + view(randn(rng, P, 3p, 3), 1:2:(2p), 2), + ] + @assert all(x -> length(x) == p, xs) + @assert all(Base.Fix2(isa, AbstractVector{P}), xs) + return xs +end + function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) t_flags = ['N', 'T', 'C'] alphas = [1.0, -0.25] betas = [0.0, 0.33] + rng = rng_ctor(123456) test_cases = vcat( + # gemv! + vec( + reduce( + vcat, + map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N) + t = tA == 'N' + As = blas_matrices(rng, Float64, t ? M : N, t ? N : M) + xs = blas_vectors(rng, Float64, N) + ys = blas_vectors(rng, Float64, M) + flags = (false, :stability, (lb=1e-3, ub=5.0)) + return map(product(As, xs, ys)) do (A, x, y) + return (flags..., BLAS.gemv!, tA, randn(), A, x, randn(), y) + end + end, + ), + ), + # symv! vec( reduce( vcat, - vec( - map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) - A = randn(5, 5) - vA = view(randn(15, 15), 1:5, 1:5) - x = randn(5) - y = randn(5) - return Any[ - (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y), - (false, :stability, nothing, BLAS.symv!, uplo, α, vA, x, β, y), - ] - end, - ), + map(product(['L', 'U'], alphas, betas)) do (uplo, α, β) + As = blas_matrices(rng, Float64, 5, 5) + ys = blas_vectors(rng, Float64, 5) + xs = blas_vectors(rng, Float64, 5) + return map(product(As, xs, ys)) do (A, x, y) + (false, :stability, nothing, BLAS.symv!, uplo, α, A, x, β, y) + end + end, ), ), @@ -831,26 +855,14 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) vec( reduce( vcat, - vec( - map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b) - A = tA == 'N' ? randn(3, 4) : randn(4, 3) - B = tB == 'N' ? randn(4, 5) : randn(5, 4) - As = if tA == 'N' - [randn(3, 4), view(randn(15, 15), 2:4, 3:6)] - else - [randn(4, 3), view(randn(15, 15), 2:5, 3:5)] - end - Bs = if tB == 'N' - [randn(4, 5), view(randn(15, 15), 1:4, 2:6)] - else - [randn(5, 4), view(randn(15, 15), 1:5, 3:6)] - end - C = randn(3, 5) - return map(product(As, Bs)) do (A, B) - (false, :stability, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C) - end - end, - ), + map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b) + As = blas_matrices(rng, Float64, tA == 'N' ? 3 : 4, tA == 'N' ? 4 : 3) + Bs = blas_matrices(rng, Float64, tB == 'N' ? 4 : 5, tB == 'N' ? 5 : 4) + Cs = blas_matrices(rng, Float64, 3, 5) + return map(product(As, Bs, Cs)) do (A, B, C) + (false, :none, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C) + end + end, ), ), @@ -858,46 +870,15 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) vec( reduce( vcat, - vec( - map( - product(['L', 'R'], ['L', 'U'], alphas, betas) - ) do (side, uplo, α, β) - nA = side == 'L' ? 5 : 7 - A = randn(nA, nA) - vA = view(randn(15, 15), 1:nA, 1:nA) - B = randn(5, 7) - vB = view(randn(15, 15), 1:5, 1:7) - C = randn(5, 7) - return Any[ - ( - false, - :stability, - nothing, - BLAS.symm!, - side, - uplo, - α, - A, - B, - β, - C, - ), - ( - false, - :stability, - nothing, - BLAS.symm!, - side, - uplo, - α, - vA, - vB, - β, - C, - ), - ] - end, - ), + map(product(['L', 'R'], ['L', 'U'], alphas, betas)) do (side, uplo, α, β) + nA = side == 'L' ? 5 : 7 + As = blas_matrices(rng, Float64, nA, nA) + Bs = blas_matrices(rng, Float64, 5, 7) + Cs = blas_matrices(rng, Float64, 5, 7) + return map(product(As, Bs, Cs)) do (A, B, C) + (false, :stability, nothing, BLAS.symm!, side, uplo, α, A, B, β, C) + end + end, ), ), ) @@ -934,29 +915,6 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) # BLAS LEVEL 2 # - # gemv! - vec( - reduce( - vcat, - map(product(t_flags, [1, 3], [1, 2])) do (tA, M, N) - t = tA == 'N' - As = [ - t ? randn(M, N) : randn(N, M), - view( - randn(15, 15), - t ? (3:(M + 2)) : (2:(N + 1)), - t ? (2:(N + 1)) : (3:(M + 2)), - ), - ] - xs = [randn(N), view(randn(15), 3:(N + 2)), view(randn(30), 1:2:(2N))] - ys = [randn(M), view(randn(15), 2:(M + 1)), view(randn(30), 2:2:(2M))] - return map(Iterators.product(As, xs, ys)) do (A, x, y) - (false, :none, nothing, BLAS.gemv!, tA, randn(), A, x, randn(), y) - end - end, - ), - ), - # trmv! vec( reduce( @@ -989,7 +947,7 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) map(product(['U', 'L'], t_flags)) do (uplo, t) A = t == 'N' ? randn(3, 4) : randn(4, 3) C = randn(3, 3) - Any[false, :none, nothing, BLAS.syrk!, uplo, t, randn(), A, randn(), C] + return (false, :none, nothing, BLAS.syrk!, uplo, t, randn(), A, randn(), C) end, ), @@ -1004,11 +962,9 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) R = side == 'L' ? M : N As = [randn(R, R), view(randn(15, 15), 3:(R + 2), 4:(R + 3))] Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] + flags = (false, :none, nothing) return map(product(As, Bs)) do (A, B) - alpha = randn() - Any[ - false, :none, nothing, BLAS.trmm!, side, ul, tA, dA, alpha, A, B - ] + (flags..., BLAS.trmm!, side, ul, tA, dA, randn(), A, B) end end, ), @@ -1025,11 +981,9 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas}) R = side == 'L' ? M : N As = [randn(R, R) + 5I, view(randn(15, 15), 3:(R + 2), 4:(R + 3)) + 5I] Bs = [randn(M, N), view(randn(15, 15), 2:(M + 1), 5:(N + 4))] + flags = (false, :none, nothing) return map(product(As, Bs)) do (A, B) - alpha = randn() - Any[ - false, :none, nothing, BLAS.trsm!, side, ul, tA, dA, alpha, A, B - ] + (flags..., BLAS.trsm!, side, ul, tA, dA, randn(), A, B) end end, ),