Skip to content

Commit

Permalink
Try a different method
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Dec 10, 2024
1 parent 971bb9a commit 04d8c31
Showing 1 changed file with 18 additions and 27 deletions.
45 changes: 18 additions & 27 deletions src/rrules/lapack.jl
Original file line number Diff line number Diff line change
@@ -1,28 +1,22 @@
@is_primitive(
MinimalCtx,
Tuple{typeof(LAPACK.getrf!), AbstractMatrix{<:BlasRealFloat}, AbstractVector{Int}},
)
@is_primitive(MinimalCtx, Tuple{typeof(LAPACK.getrf!), AbstractMatrix{<:BlasRealFloat}})
function rrule!!(
::CoDual{typeof(LAPACK.getrf!)},
_A::CoDual{<:AbstractMatrix{P}},
_ipiv::CoDual{<:AbstractVector{Int}},
::CoDual{typeof(LAPACK.getrf!)}, _A::CoDual{<:AbstractMatrix{P}},
) where {P<:BlasRealFloat}
A, dA = viewify(_A)
ipiv = primal(_ipiv)
ipiv_copy = copy(ipiv)
A_copy = copy(A)

# Run the primal.
_, _, code = LAPACK.getrf!(A, ipiv)
_, ipiv, code = LAPACK.getrf!(A)

# Zero out the tangent.
dA .= zero(P)

function getrf_pb!!(::NoRData)
_getrf_pb!(A, dA, ipiv, A_copy, ipiv_copy)
return NoRData(), NoRData(), NoRData()
_getrf_pb!(A, dA, ipiv, A_copy)
return NoRData(), NoRData()
end
return CoDual((_A.x, ipiv, code), (_A.dx, _ipiv.dx, NoFData())), getrf_pb!!
dipiv = zero_tangent(ipiv)
return CoDual((_A.x, ipiv, code), (_A.dx, dipiv, NoFData())), getrf_pb!!
end

@is_primitive(
Expand All @@ -32,36 +26,33 @@ end
NamedTuple,
typeof(LAPACK.getrf!),
AbstractMatrix{<:BlasRealFloat},
AbstractVector{Int},
},
)
function rrule!!(
::CoDual{typeof(Core.kwcall)},
_kwargs::CoDual{<:NamedTuple},
::CoDual{typeof(getrf!)},
_A::CoDual{<:AbstractMatrix{P}},
_ipiv::CoDual{<:AbstractVector{Int}},
) where {P<:BlasRealFloat}
check = _kwargs.x.check
A, dA = viewify(_A)
ipiv = primal(_ipiv)
ipiv_copy = copy(ipiv)
A_copy = copy(A)

# Run the primal.
_, _, code = LAPACK.getrf!(A, ipiv; check)
_, ipiv, code = LAPACK.getrf!(A; check)

# Zero out the tangent.
dA .= zero(P)

function getrf_pb!!(::NoRData)
_getrf_pb!(A, dA, ipiv, A_copy, ipiv_copy)
return NoRData(), NoRData(), NoRData(), NoRData(), NoRData()
_getrf_pb!(A, dA, ipiv, A_copy)
return NoRData(), NoRData(), NoRData(), NoRData()
end
return CoDual((_A.x, ipiv, code), (_A.dx, _ipiv.dx, NoFData())), getrf_pb!!
dipiv = zero_tangent(ipiv)
return CoDual((_A.x, ipiv, code), (_A.dx, dipiv, NoFData())), getrf_pb!!
end

function _getrf_pb!(A, dA, ipiv, A_copy, ipiv_copy)
function _getrf_pb!(A, dA, ipiv, A_copy)

# Run reverse-pass.
L = UnitLowerTriangular(A)
Expand All @@ -77,7 +68,7 @@ function _getrf_pb!(A, dA, ipiv, A_copy, ipiv_copy)
dA .= (inv(L') * _dF * inv(U'))[invperm(p), :]

# Restore initial state.
ipiv .= ipiv_copy
# ipiv .= ipiv_copy
A .= A_copy

return nothing
Expand Down Expand Up @@ -341,7 +332,7 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
As = blas_matrices(rng, P, 5, 5)
ipiv = Vector{Int}(undef, 5)
return map(As) do A
(false, :none, nothing, getrf!, A, ipiv)
(false, :none, nothing, getrf!, A)
end
end...,

Expand Down Expand Up @@ -406,13 +397,13 @@ end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:lapack})
rng = rng_ctor(123)
getrf_wrapper!(x, ipiv, check) = getrf!(x, ipiv; check)
getrf_wrapper!(x, check) = getrf!(x; check)
test_cases = vcat(
map_prod([false, true], [Float64, Float32]) do (check, P)
As = blas_matrices(rng, P, 5, 5)
ipiv = Vector{Int}(undef, 5)
# ipiv = Vector{Int}(undef, 5)
return map(As) do A
(false, :none, nothing, getrf_wrapper!, A, ipiv, check)
(false, :none, nothing, getrf_wrapper!, A, check)
end
end...,
)
Expand Down

0 comments on commit 04d8c31

Please sign in to comment.