Skip to content

Commit

Permalink
Improve type-inference in blockbandwidths for BlockBandedMatrix (#210)
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub authored Oct 15, 2024
1 parent 7467a90 commit cfc2e58
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
8 changes: 7 additions & 1 deletion src/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,13 @@ import BandedMatrices: _isweakzero
function blockbandwidths(bc::Broadcasted)
(a,b) = size(bc)
bnds = (a-1,b-1)
_isweakzero(bc.f, bc.args...) && return min.(bnds, max.(_broadcast_blockbandwidths.(Ref(bnds), bc.args, Ref(axes(bc)))...))
if _isweakzero(bc.f, bc.args...)
ax = axes(bc)
t = map(bc.args) do x
_broadcast_blockbandwidths(bnds, x, ax)
end
return min.(bnds, max.(t...))
end
bnds
end

Expand Down
8 changes: 8 additions & 0 deletions test/test_broadcasting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,14 @@ import Base: oneto
@test C == A + A
end
end

@testset "blockbandwidths" begin
B = BlockArray(ones(6,6), 1:3, 1:3)
BB = BlockBandedMatrix(B, (1,1))
bc = Broadcast.broadcasted(+, BB, BB)
bbw = @inferred blockbandwidths(bc)
@test bbw == blockbandwidths(BB)
end
end

end # module

0 comments on commit cfc2e58

Please sign in to comment.