diff --git a/src/oneelement.jl b/src/oneelement.jl index 9011a3c3..ec9791cb 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -47,23 +47,20 @@ Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{I @boundscheck checkbounds(A, kj...) ifelse(kj == A.ind, A.val, zero(T)) end -const VectorInds = Union{AbstractRange{<:Integer}, Integer} +const VectorInds = Union{AbstractUnitRange{<:Integer}, Integer} # no index is repeated for these indices const VectorIndsWithColon = Union{VectorInds, Colon} # retain the values from Ainds corresponding to the vector indices in inds _index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds)) _index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...) _index_shape(::Tuple{}, ::Tuple{}) = () -@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N} - I = to_indices(A, inds) # handle Bool, and convert to compatible index types (Int usually) +Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} + I = to_indices(A, inds) # handle Bool, and convert to compatible index types @boundscheck checkbounds(A, I...) shape = _index_shape(I, I) nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape) containsval = all(in.(A.ind, I)) OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1)) end -Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N} - getindex(A, to_indices(A, inds)...) -end """ nzind(A::OneElement{T,N}) -> CartesianIndex{N} diff --git a/test/runtests.jl b/test/runtests.jl index 8098e378..ab39c8b1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2207,6 +2207,12 @@ end @test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),)) @test A[:,:] == reshape(A, size(A)..., 1) + @test A[reverse(axes(A,1))] == A[collect(reverse(axes(A,1)))] + + @testset "repeated indices" begin + @test A[StepRangeLen(2, 0, 3)] == A[fill(2, 3)] + end + B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),)) @test @inferred(A[:]) === @inferred(A[axes(A)...]) === A @test @inferred(A[3:4]) isa OneElement{Int,1}