From 4a5d062fed070a74043a0312649b320f09e941f4 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 16 Oct 2024 13:04:05 +0530 Subject: [PATCH] Change more broadcasts over broadcast args to maps (#211) * Change more broadcasts over broadcast args to maps * Bump version to v0.13.2 --- Project.toml | 2 +- src/broadcast.jl | 26 +++++++++++++++++++++----- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 5f7c6df..0feff49 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "BlockBandedMatrices" uuid = "ffab5731-97b5-5995-9138-79e8c1846df0" -version = "0.13.1" +version = "0.13.2" [deps] ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" diff --git a/src/broadcast.jl b/src/broadcast.jl index 48b6d72..cfc54fb 100644 --- a/src/broadcast.jl +++ b/src/broadcast.jl @@ -58,8 +58,14 @@ function _broadcast_blockbandwidths((l,u), A::AbstractArray, (ax1,ax2)) end -blockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(*)}) = - min.(_broadcast_blockbandwidths.(Ref(_blockbnds(bc)), bc.args, Ref(axes(bc)))...) +function blockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(*)}) + ax = axes(bc) + blkinds = _blockbnds(bc) + t = map(bc.args) do x + _broadcast_blockbandwidths(blkinds, x, ax) + end + min.(t...) +end blockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(/)}) = _broadcast_blockbandwidths(_blockbnds(bc), first(bc.args), axes(bc)) blockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(\)}) = _broadcast_blockbandwidths(_blockbnds(bc), last(bc.args), axes(bc)) @@ -76,8 +82,13 @@ function _broadcast_subblockbandwidths((l,u), A::AbstractArray) subblockbandwidths(A) # need to special case vector broadcasting end -subblockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(*)}) = - min.(_broadcast_subblockbandwidths.(Ref(_subblockbnds(bc)), bc.args)...) +function subblockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(*)}) + subbbw = _subblockbnds(bc) + t = map(bc.args) do x + _broadcast_subblockbandwidths(subbbw, x) + end + min.(t...) +end subblockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(/)}) = _broadcast_subblockbandwidths(_subblockbnds(bc), first(bc.args)) subblockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof(\)}) = _broadcast_subblockbandwidths(_subblockbnds(bc), last(bc.args)) @@ -85,7 +96,12 @@ subblockbandwidths(bc::Broadcasted{<:Union{Nothing,BroadcastStyle},<:Any,typeof( function subblockbandwidths(bc::Broadcasted) (a,b) = size(bc) bnds = (a-1,b-1) - _isweakzero(bc.f, bc.args...) && return min.(bnds, max.(_broadcast_subblockbandwidths.(Ref(bnds), bc.args)...)) + if _isweakzero(bc.f, bc.args...) + t = map(bc.args) do x + _broadcast_subblockbandwidths(bnds, x) + end + return min.(bnds, max.(t...)) + end bnds end