Skip to content

Commit

Permalink
Fix type instability when setindex!! (#549)
Browse files Browse the repository at this point in the history
* Add `BangBang.possible` for general arrays

* Remove redundant BangBang.possible

* add comments rt JuliaFold2 PR

* copy tor's fix to JuliaFolds here

* add BangBang prefix to the functions

* remove tests of possible, no longer needed

* Add Tor's tests

* Import `setindex!!`
  • Loading branch information
sunxd3 authored Nov 2, 2023
1 parent 04b03cd commit 2e940aa
Show file tree
Hide file tree
Showing 2 changed files with 145 additions and 26 deletions.
56 changes: 30 additions & 26 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -538,36 +538,40 @@ function remove_parent_lens(vn_parent::VarName{sym}, vn_child::VarName{sym}) whe
end

# HACK: All of these are related to https://github.com/JuliaFolds/BangBang.jl/issues/233
# and https://github.com/JuliaFolds/BangBang.jl/pull/238.
# HACK(torfjelde): Avoids type-instability in `dot_assume` for `SimpleVarInfo`.
function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::Colon, ::Integer
) where {C<:AbstractMatrix,T<:AbstractVector}
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end
function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractPPL.ConcretizedSlice, ::Integer
) where {C<:AbstractMatrix,T<:AbstractVector}
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
end
# HACK: Makes it possible to use ranges, etc. for setting a vector.
# For example, without this hack, BangBang.jl will consider
# and https://github.com/JuliaFolds/BangBang.jl/pull/238, https://github.com/JuliaFolds2/BangBang.jl/pull/16.
# This avoids type-instability in `dot_assume` for `SimpleVarInfo`.
# The following code a copy from https://github.com/JuliaFolds2/BangBang.jl/pull/16 authored by torfjelde
# Default implementation for `_setindex!` with `AbstractArray`.
# But this will return `false` even in cases such as
#
# setindex!!([1, 2, 3], [4, 5, 6], :)
#
# because `promote_type(eltype(C), T) <: eltype(C)` is `false`.
# To address this, we specialize on the case where `T<:AbstractArray`.
# In addition, we need to support a wide range of indexing behaviors:
#
# x[1:2] = [1, 2]
# We also need to ensure that the dimensionality of the index is
# valid, i.e. that we're not returning `true` in cases such as
#
# as NOT supported. This results is calling the immutable
# `BangBang.setindex` instead, which also ends up expanding the
# type of the containing array (`x` in the above scenario) to
# have element type `Any`.
# The below code just, correctly, marks this as possible and
# thus we hit the mutable `setindex!` instead.
# setindex!!([1, 2, 3], [4, 5], 1)
#
# which should return `false`.
_index_dimension(::Any) = 0
_index_dimension(::Colon) = 1
_index_dimension(::AbstractVector) = 1
_index_dimension(indices::Tuple) = sum(map(_index_dimension, indices))

function BangBang.possible(
::typeof(BangBang._setindex!), ::C, ::T, ::AbstractVector{<:Integer}
) where {C<:AbstractVector,T<:AbstractVector}
::typeof(BangBang._setindex!), ::C, ::T, indices::Vararg
) where {M,C<:AbstractArray{<:Real},T<:AbstractArray{<:Real,M}}
return BangBang.implements(setindex!, C) &&
promote_type(eltype(C), eltype(T)) <: eltype(C)
promote_type(eltype(C), eltype(T)) <: eltype(C) &&
# This will still return `false` for scenarios such as
#
# setindex!!([1, 2, 3], [4, 5, 6], :, 1)
#
# which are in fact valid. However, this cases are rare.
(_index_dimension(indices) == M || _index_dimension(indices) == 1)
end

# HACK(torfjelde): This makes it so it works on iterators, etc. by default.
Expand Down
115 changes: 115 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,119 @@
x = rand(dist)
@test vectorize(dist, x) == vec(x.UL)
end

@testset "BangBang.possible" begin
using DynamicPPL.BangBang: setindex!!

# Some utility methods for testing `setindex!`.
test_linear_index_only(::Tuple, ::AbstractArray) = false
test_linear_index_only(inds::NTuple{1}, ::AbstractArray) = true
test_linear_index_only(inds::NTuple{1}, ::AbstractVector) = false

function replace_colon_with_axis(inds::Tuple, x)
ntuple(length(inds)) do i
inds[i] isa Colon ? axes(x, i) : inds[i]
end
end
function replace_colon_with_vector(inds::Tuple, x)
ntuple(length(inds)) do i
inds[i] isa Colon ? collect(axes(x, i)) : inds[i]
end
end
function replace_colon_with_range(inds::Tuple, x)
ntuple(length(inds)) do i
inds[i] isa Colon ? (1:size(x, i)) : inds[i]
end
end
function replace_colon_with_booleans(inds::Tuple, x)
ntuple(length(inds)) do i
inds[i] isa Colon ? trues(size(x, i)) : inds[i]
end
end

function replace_colon_with_range_linear(inds::NTuple{1}, x::AbstractArray)
return inds[1] isa Colon ? (1:length(x),) : inds
end

@testset begin
@test setindex!!((1, 2, 3), :two, 2) === (1, :two, 3)
@test setindex!!((a=1, b=2, c=3), :two, :b) === (a=1, b=:two, c=3)
@test setindex!!([1, 2, 3], :two, 2) == [1, :two, 3]
@test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 10, :a) ==
Dict(:a => 10, :b => 2)
@test setindex!!(Dict{Symbol,Int}(:a => 1, :b => 2), 3, "c") ==
Dict(:a => 1, :b => 2, "c" => 3)
end

@testset "mutation" begin
@testset "without type expansion" begin
for args in [([1, 2, 3], 20, 2), (Dict(:a => 1, :b => 2), 10, :a)]
@test setindex!!(args...) === args[1]
end
end

@testset "with type expansion" begin
@test setindex!!([1, 2, 3], [4, 5], 1) == [[4, 5], 2, 3]
@test setindex!!([1, 2, 3], [4, 5, 6], :, 1) == [4, 5, 6]
end
end

@testset "slices" begin
@testset "$(typeof(x)) with $(src_idx)" for (x, src_idx) in [
# Vector.
(randn(2), (:,)),
(randn(2), (1:2,)),
# Matrix.
(randn(2, 3), (:,)),
(randn(2, 3), (:, 1)),
(randn(2, 3), (:, 1:3)),
# 3D array.
(randn(2, 3, 4), (:, 1, :)),
(randn(2, 3, 4), (:, 1:3, :)),
(randn(2, 3, 4), (1, 1:3, :)),
]
# Base case.
@test @inferred(setindex!!(x, x[src_idx...], src_idx...)) === x

# If we have `Colon` in the index, we replace this with other equivalent indices.
if any(Base.Fix2(isa, Colon), src_idx)
if test_linear_index_only(src_idx, x)
# With range instead of `Colon`.
@test @inferred(
setindex!!(
x,
x[src_idx...],
replace_colon_with_range_linear(src_idx, x)...,
)
) === x
else
# With axis instead of `Colon`.
@test @inferred(
setindex!!(
x, x[src_idx...], replace_colon_with_axis(src_idx, x)...
)
) === x
# With range instead of `Colon`.
@test @inferred(
setindex!!(
x, x[src_idx...], replace_colon_with_range(src_idx, x)...
)
) === x
# With vectors instead of `Colon`.
@test @inferred(
setindex!!(
x, x[src_idx...], replace_colon_with_vector(src_idx, x)...
)
) === x
# With boolean index instead of `Colon`.
@test @inferred(
setindex!!(
x, x[src_idx...], replace_colon_with_booleans(src_idx, x)...
)
) === x
end
end
end
end
end
end

0 comments on commit 2e940aa

Please sign in to comment.