From 74b65f7d1e3c32c12408efe4f2de43222a423549 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 27 Nov 2023 16:01:59 +1300 Subject: [PATCH 1/3] [Containers] fix multi-arg eachindex --- src/Containers/DenseAxisArray.jl | 27 +++++++++++++++++++++---- src/Containers/SparseAxisArray.jl | 17 ++++++++++++++++ test/Containers/test_DenseAxisArray.jl | 13 ++++++++++++ test/Containers/test_SparseAxisArray.jl | 13 ++++++++++++ 4 files changed, 66 insertions(+), 4 deletions(-) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index 7124835bfd7..61aac337788 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -343,7 +343,28 @@ if VERSION >= v"1.9.0-DEV" end end -Base.eachindex(A::DenseAxisArray) = CartesianIndices(size(A.data)) +Base.eachindex(A::DenseAxisArray) = eachindex(IndexStyle(A), A) + +function Base.eachindex(::IndexCartesian, A::DenseAxisArray) + return CartesianIndices(size(A.data)) +end + +function Base.eachindex( + ::IndexCartesian, + A::DenseAxisArray, + B::DenseAxisArray..., +) + ret = eachindex(A) + for b in B + if eachindex(b) != ret + err = DimensionMismatch( + "incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))", + ) + throw(err) + end + end + return ret +end # Use recursion over tuples to ensure the return-type of functions like # `Base.to_index` are type-stable. @@ -513,9 +534,7 @@ function _broadcast_axes_check(x::NTuple{N}) where {N} axes = first(x) for i in 2:N if x[i][1] != axes[1] - error( - "Unable to broadcast over DenseAxisArrays with different axes.", - ) + throw(DimensionMismatch("DenseAxisArrays have different axes.")) end end return axes diff --git a/src/Containers/SparseAxisArray.jl b/src/Containers/SparseAxisArray.jl index 0bfb020cc85..64aebd67fdf 100644 --- a/src/Containers/SparseAxisArray.jl +++ b/src/Containers/SparseAxisArray.jl @@ -219,6 +219,23 @@ end Base.eachindex(d::SparseAxisArray) = keys(d.data) +function Base.eachindex( + ::IndexCartesian, + A::SparseAxisArray, + B::SparseAxisArray..., +) + ret = eachindex(A) + for b in B + if eachindex(b) != ret + err = DimensionMismatch( + "incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))", + ) + throw(err) + end + end + return ret +end + ################ # Broadcasting # ################ diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 546a16169f1..976784e5c2f 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -862,4 +862,17 @@ function test_sum_dims() return end +function test_multi_arg_eachindex() + model = Model() + @variable(model, x[2:3]) + @variable(model, y[2:3]) + @variable(model, z[2:4, 1:2]) + @test eachindex(x) == CartesianIndices((2,)) + @test eachindex(y) == CartesianIndices((2,)) + @test eachindex(z) == CartesianIndices((3, 2)) + @test eachindex(x, y) == CartesianIndices((2,)) + @test_throws DimensionMismatch eachindex(x, z) + return +end + end # module diff --git a/test/Containers/test_SparseAxisArray.jl b/test/Containers/test_SparseAxisArray.jl index 914f71b0e58..ad5dd5ca4b5 100644 --- a/test/Containers/test_SparseAxisArray.jl +++ b/test/Containers/test_SparseAxisArray.jl @@ -343,4 +343,17 @@ function test_containers_sparseaxisarray_kwarg_setindex() return end +function test_multi_arg_eachindex() + model = Model() + @variable(model, x[2:3], container = SparseAxisArray) + @variable(model, y[2:3], container = SparseAxisArray) + @variable(model, z[2:4, 1:2], container = SparseAxisArray) + @test eachindex(x) == keys(x.data) + @test eachindex(y) == keys(y.data) + @test eachindex(z) == keys(z.data) + @test eachindex(x, y) == eachindex(x) + @test_throws DimensionMismatch eachindex(x, z) + return +end + end # module From adbcca55bca9d347be8c24e3aacddaa12a74ebd6 Mon Sep 17 00:00:00 2001 From: odow Date: Mon, 27 Nov 2023 16:09:52 +1300 Subject: [PATCH 2/3] Update --- test/Containers/test_DenseAxisArray.jl | 7 +++---- test/Containers/test_SparseAxisArray.jl | 11 +++++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/Containers/test_DenseAxisArray.jl b/test/Containers/test_DenseAxisArray.jl index 976784e5c2f..fd780720a35 100644 --- a/test/Containers/test_DenseAxisArray.jl +++ b/test/Containers/test_DenseAxisArray.jl @@ -863,10 +863,9 @@ function test_sum_dims() end function test_multi_arg_eachindex() - model = Model() - @variable(model, x[2:3]) - @variable(model, y[2:3]) - @variable(model, z[2:4, 1:2]) + Containers.@container(x[i = 2:3], i) + Containers.@container(y[i = 2:3], i) + Containers.@container(z[i = 2:4, j = 1:2], i + j) @test eachindex(x) == CartesianIndices((2,)) @test eachindex(y) == CartesianIndices((2,)) @test eachindex(z) == CartesianIndices((3, 2)) diff --git a/test/Containers/test_SparseAxisArray.jl b/test/Containers/test_SparseAxisArray.jl index ad5dd5ca4b5..94a30762b33 100644 --- a/test/Containers/test_SparseAxisArray.jl +++ b/test/Containers/test_SparseAxisArray.jl @@ -344,10 +344,13 @@ function test_containers_sparseaxisarray_kwarg_setindex() end function test_multi_arg_eachindex() - model = Model() - @variable(model, x[2:3], container = SparseAxisArray) - @variable(model, y[2:3], container = SparseAxisArray) - @variable(model, z[2:4, 1:2], container = SparseAxisArray) + Containers.@container(x[i = 2:3], i, container = SparseAxisArray) + Containers.@container(y[i = 2:3], i, container = SparseAxisArray) + Containers.@container( + z[i = 2:4, j = 1:2], + i + j, + container = SparseAxisArray, + ) @test eachindex(x) == keys(x.data) @test eachindex(y) == keys(y.data) @test eachindex(z) == keys(z.data) From 8a920b4daf5bdf45484e13f6b8a9e4a73301590d Mon Sep 17 00:00:00 2001 From: Oscar Dowson Date: Tue, 28 Nov 2023 07:59:55 +1300 Subject: [PATCH 3/3] Update DenseAxisArray.jl --- src/Containers/DenseAxisArray.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Containers/DenseAxisArray.jl b/src/Containers/DenseAxisArray.jl index 61aac337788..1b23ee4f057 100644 --- a/src/Containers/DenseAxisArray.jl +++ b/src/Containers/DenseAxisArray.jl @@ -534,7 +534,9 @@ function _broadcast_axes_check(x::NTuple{N}) where {N} axes = first(x) for i in 2:N if x[i][1] != axes[1] - throw(DimensionMismatch("DenseAxisArrays have different axes.")) + error( + "Unable to broadcast over DenseAxisArrays with different axes.", + ) end end return axes