Skip to content

Commit

Permalink
smaller computational graph
Browse files Browse the repository at this point in the history
  • Loading branch information
ChrisRackauckas committed Aug 29, 2024
1 parent 2febf9f commit a816377
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ logabsdet(xs::TrackedArray) = track(logabsdet, xs)
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)

Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
Base.zero(x::Tracker.TrackedArray) = zero.(x)
Base.zero(x::Tracker.TrackedArray) = TrackedArray(zero(x.data))

@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
repeat(data(xs), inner = inner, outer = outer), function (Δ)
Expand Down

0 comments on commit a816377

Please sign in to comment.