From 9ce7e438b3fc25bd20c46dde701e30e2e571ddda Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 28 Nov 2023 09:09:29 +1300 Subject: [PATCH] [Containers] fix multi-arg eachindex (#3587) --- src/Containers/DenseAxisArray.jl | 23 ++++++++++++++++++++++- src/Containers/SparseAxisArray.jl | 17 +++++++++++++++++ test/Containers/test_DenseAxisArray.jl | 12 ++++++++++++ test/Containers/test_SparseAxisArray.jl | 16 ++++++++++++++++ 4 files changed, 67 insertions(+), 1 deletion(-) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index 7124835bfd7..1b23ee4f057 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -343,7 +343,28 @@ if VERSION >= v"1.9.0-DEV" end end -Base.eachindex(A::DenseAxisArray) = CartesianIndices(size(A.data)) +Base.eachindex(A::DenseAxisArray) = eachindex(IndexStyle(A), A) + +function Base.eachindex(::IndexCartesian, A::DenseAxisArray) + return CartesianIndices(size(A.data)) +end + +function Base.eachindex( + ::IndexCartesian, + A::DenseAxisArray, + B::DenseAxisArray..., +) + ret = eachindex(A) + for b in B + if eachindex(b) != ret + err = DimensionMismatch( + "incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))", + ) + throw(err) + end + end + return ret +end # Use recursion over tuples to ensure the return-type of functions like # `Base.to_index` are type-stable. diff --git a/src/Containers/SparseAxisArray.jl b/src/Containers/SparseAxisArray.jl index 0bfb020cc85..64aebd67fdf 100644 --- a/src/Containers/SparseAxisArray.jl +++ b/src/Containers/SparseAxisArray.jl @@ -219,6 +219,23 @@ end Base.eachindex(d::SparseAxisArray) = keys(d.data) +function Base.eachindex( + ::IndexCartesian, + A::SparseAxisArray, + B::SparseAxisArray..., +) + ret = eachindex(A) + for b in B + if eachindex(b) != ret + err = DimensionMismatch( + "incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))", + ) + throw(err) + end + end + return ret +end + ################ # Broadcasting # ################ diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 546a16169f1..fd780720a35 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -862,4 +862,16 @@ function test_sum_dims() return end +function test_multi_arg_eachindex() + Containers.@container(x[i = 2:3], i) + Containers.@container(y[i = 2:3], i) + Containers.@container(z[i = 2:4, j = 1:2], i + j) + @test eachindex(x) == CartesianIndices((2,)) + @test eachindex(y) == CartesianIndices((2,)) + @test eachindex(z) == CartesianIndices((3, 2)) + @test eachindex(x, y) == CartesianIndices((2,)) + @test_throws DimensionMismatch eachindex(x, z) + return +end + end # module diff --git a/test/Containers/test_SparseAxisArray.jl b/test/Containers/test_SparseAxisArray.jl index 914f71b0e58..94a30762b33 100644 --- a/test/Containers/test_SparseAxisArray.jl +++ b/test/Containers/test_SparseAxisArray.jl @@ -343,4 +343,20 @@ function test_containers_sparseaxisarray_kwarg_setindex() return end +function test_multi_arg_eachindex() + Containers.@container(x[i = 2:3], i, container = SparseAxisArray) + Containers.@container(y[i = 2:3], i, container = SparseAxisArray) + Containers.@container( + z[i = 2:4, j = 1:2], + i + j, + container = SparseAxisArray, + ) + @test eachindex(x) == keys(x.data) + @test eachindex(y) == keys(y.data) + @test eachindex(z) == keys(z.data) + @test eachindex(x, y) == eachindex(x) + @test_throws DimensionMismatch eachindex(x, z) + return +end + end # module