Skip to content

Commit

Permalink
Restrict to AbstractUnitRanges to avoid repeated indices
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed May 9, 2024
1 parent beb758d commit 3e5c4a2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,20 @@ Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, kj::Vararg{I
@boundscheck checkbounds(A, kj...)
ifelse(kj == A.ind, A.val, zero(T))
end
const VectorInds = Union{AbstractRange{<:Integer}, Integer}
const VectorInds = Union{AbstractUnitRange{<:Integer}, Integer} # no index is repeated for these indices
const VectorIndsWithColon = Union{VectorInds, Colon}
# retain the values from Ainds corresponding to the vector indices in inds
_index_shape(Ainds, inds::Tuple{Integer, Vararg{Any}}) = _index_shape(Base.tail(Ainds), Base.tail(inds))
_index_shape(Ainds, inds::Tuple{AbstractVector, Vararg{Any}}) = (Ainds[1], _index_shape(Base.tail(Ainds), Base.tail(inds))...)
_index_shape(::Tuple{}, ::Tuple{}) = ()
@inline function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorInds,N}) where {T,N}
I = to_indices(A, inds) # handle Bool, and convert to compatible index types (Int usually)
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N}
I = to_indices(A, inds) # handle Bool, and convert to compatible index types
@boundscheck checkbounds(A, I...)
shape = _index_shape(I, I)
nzind = _index_shape(A.ind, I) .- first.(shape) .+ firstindex.(shape)
containsval = all(in.(A.ind, I))
OneElement(getindex_value(A), containsval ? Int.(nzind) : Int.(lastindex.(shape,1)).+1, axes.(shape,1))
end
Base.@propagate_inbounds function Base.getindex(A::OneElement{T,N}, inds::Vararg{VectorIndsWithColon,N}) where {T,N}
getindex(A, to_indices(A, inds)...)
end

"""
nzind(A::OneElement{T,N}) -> CartesianIndex{N}
Expand Down
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2207,6 +2207,12 @@ end
@test @inferred(A[Base.IdentityUnitRange(2:3)]) == OneElement(2,(2,),(Base.IdentityUnitRange(2:3),))
@test A[:,:] == reshape(A, size(A)..., 1)

@test A[reverse(axes(A,1))] == A[collect(reverse(axes(A,1)))]

@testset "repeated indices" begin
@test A[StepRangeLen(2, 0, 3)] == A[fill(2, 3)]
end

B = OneElement(2, (2,), (Base.IdentityUnitRange(-1:4),))
@test @inferred(A[:]) === @inferred(A[axes(A)...]) === A
@test @inferred(A[3:4]) isa OneElement{Int,1}
Expand Down

0 comments on commit 3e5c4a2

Please sign in to comment.