Skip to content

Commit

Permalink
Merge pull request #3100 from DhairyaLGandhi/dg/crc3
Browse files Browse the repository at this point in the history
Allow arrays in `MTKParameters` pullback
  • Loading branch information
ChrisRackauckas authored Oct 7, 2024
2 parents d7fa2b9 + 91e5e70 commit 284e463
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion ext/MTKChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ import ChainRulesCore: Tangent, ZeroTangent, NoTangent, zero_tangent, unthunk
function ChainRulesCore.rrule(::Type{MTK.MTKParameters}, tunables, args...)
function mtp_pullback(dt)
dt = unthunk(dt)
(NoTangent(), dt.tunable[1:length(tunables)],
dtunables = dt isa AbstractArray ? dt : dt.tunable
(NoTangent(), dtunables[1:length(tunables)],
ntuple(_ -> NoTangent(), length(args))...)
end
MTK.MTKParameters(tunables, args...), mtp_pullback
Expand Down

0 comments on commit 284e463

Please sign in to comment.