Skip to content

Commit

Permalink
[Containers] fix multi-arg eachindex (#3587)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 27, 2023
1 parent f7fb42b commit 9ce7e43
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 1 deletion.
23 changes: 22 additions & 1 deletion src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,28 @@ if VERSION >= v"1.9.0-DEV"
end
end

Base.eachindex(A::DenseAxisArray) = CartesianIndices(size(A.data))
Base.eachindex(A::DenseAxisArray) = eachindex(IndexStyle(A), A)

function Base.eachindex(::IndexCartesian, A::DenseAxisArray)
return CartesianIndices(size(A.data))
end

function Base.eachindex(
::IndexCartesian,
A::DenseAxisArray,
B::DenseAxisArray...,
)
ret = eachindex(A)
for b in B
if eachindex(b) != ret
err = DimensionMismatch(
"incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))",
)
throw(err)
end
end
return ret
end

# Use recursion over tuples to ensure the return-type of functions like
# `Base.to_index` are type-stable.
Expand Down
17 changes: 17 additions & 0 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,23 @@ end

Base.eachindex(d::SparseAxisArray) = keys(d.data)

function Base.eachindex(
::IndexCartesian,
A::SparseAxisArray,
B::SparseAxisArray...,
)
ret = eachindex(A)
for b in B
if eachindex(b) != ret
err = DimensionMismatch(
"incompatible dimensions in eachindex. Got $(eachindex.((A, B...)))",
)
throw(err)
end
end
return ret
end

################
# Broadcasting #
################
Expand Down
12 changes: 12 additions & 0 deletions test/Containers/test_DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -862,4 +862,16 @@ function test_sum_dims()
return
end

function test_multi_arg_eachindex()
Containers.@container(x[i = 2:3], i)
Containers.@container(y[i = 2:3], i)
Containers.@container(z[i = 2:4, j = 1:2], i + j)
@test eachindex(x) == CartesianIndices((2,))
@test eachindex(y) == CartesianIndices((2,))
@test eachindex(z) == CartesianIndices((3, 2))
@test eachindex(x, y) == CartesianIndices((2,))
@test_throws DimensionMismatch eachindex(x, z)
return
end

end # module
16 changes: 16 additions & 0 deletions test/Containers/test_SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -343,4 +343,20 @@ function test_containers_sparseaxisarray_kwarg_setindex()
return
end

function test_multi_arg_eachindex()
Containers.@container(x[i = 2:3], i, container = SparseAxisArray)
Containers.@container(y[i = 2:3], i, container = SparseAxisArray)
Containers.@container(
z[i = 2:4, j = 1:2],
i + j,
container = SparseAxisArray,
)
@test eachindex(x) == keys(x.data)
@test eachindex(y) == keys(y.data)
@test eachindex(z) == keys(z.data)
@test eachindex(x, y) == eachindex(x)
@test_throws DimensionMismatch eachindex(x, z)
return
end

end # module

0 comments on commit 9ce7e43

Please sign in to comment.