Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Containers] use OrderedDict as the data structure for SparseAxisArray #3681

Merged
merged 7 commits into from
Feb 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/Containers/Containers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ necessarily integers.
"""
module Containers

import Base.Meta.isexpr
import OrderedCollections

# Arbitrary typed indices. Linear indexing not supported.
struct IndexAnyCartesian <: Base.IndexStyle end
Expand Down
57 changes: 36 additions & 21 deletions src/Containers/SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,32 @@

"""
struct SparseAxisArray{T,N,K<:NTuple{N, Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
end

`N`-dimensional array with elements of type `T` where only a subset of the
entries are defined. The entry with index `idx = (i1, i2, ..., iN)` in
`keys(data)` has value `data[idx]`. Note that as opposed to
`SparseArrays.AbstractSparseArray`, the missing entries are not assumed to be
`zero(T)`, they are simply not part of the array. This means that the result of
`map(f, sa::SparseAxisArray)` or `f.(sa::SparseAxisArray)` has the same sparsity
structure than `sa` even if `f(zero(T))` is not zero.
`keys(data)` has value `data[idx]`.

Note that, as opposed to `SparseArrays.AbstractSparseArray`, the missing entries
are not assumed to be `zero(T)`, they are simply not part of the array. This
means that the result of `map(f, sa::SparseAxisArray)` or
`f.(sa::SparseAxisArray)` has the same sparsity structure as `sa`, even if
`f(zero(T))` is not zero.

## Example

```jldoctest
julia> dict = Dict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
Dict{Tuple{Symbol, Int64}, Float64} with 3 entries:
julia> using OrderedCollections: OrderedDict

julia> dict = OrderedDict((:a, 2) => 1.0, (:a, 3) => 2.0, (:b, 3) => 3.0)
OrderedDict{Tuple{Symbol, Int64}, Float64} with 3 entries:
(:a, 2) => 1.0
(:a, 3) => 2.0
(:b, 3) => 3.0
(:a, 2) => 1.0

julia> array = Containers.SparseAxisArray(dict)
SparseAxisArray{Float64, 2, Tuple{Symbol, Int64}} with 3 entries:
JuMP.Containers.SparseAxisArray{Float64, 2, Tuple{Symbol, Int64}} with 3 entries:
odow marked this conversation as resolved.
Show resolved Hide resolved
[a, 2] = 1.0
[a, 3] = 2.0
[b, 3] = 3.0
Expand All @@ -36,15 +40,23 @@ julia> array[:b, 3]
```
"""
struct SparseAxisArray{T,N,K<:NTuple{N,Any}} <: AbstractArray{T,N}
data::Dict{K,T}
data::OrderedCollections.OrderedDict{K,T}
names::NTuple{N,Symbol}
end

function SparseAxisArray(d::Dict{K,T}) where {T,N,K<:NTuple{N,Any}}
# Build a `SparseAxisArray` from any `AbstractDict` keyed by `N`-tuples, together
# with one name per axis. The input is converted to an
# `OrderedCollections.OrderedDict{K,T}` before being passed on.
function SparseAxisArray(
    d::AbstractDict{K,T},
    names::NTuple{N,Symbol},
) where {T,N,K<:NTuple{N,Any}}
    return SparseAxisArray(
        convert(OrderedCollections.OrderedDict{K,T}, d),
        names,
    )
end

# Convenience constructor for when no axis names are supplied: axis `n` gets
# the placeholder name `Symbol("#1")`, `Symbol("#2")`, ..., one per dimension.
function SparseAxisArray(d::AbstractDict{K,T}) where {T,N,K<:NTuple{N,Any}}
    default_names = ntuple(n -> Symbol("#$n"), N)
    return SparseAxisArray(d, default_names)
end

SparseAxisArray(d::Dict, ::Nothing) = SparseAxisArray(d)
# Passing `nothing` in place of the names is the same as omitting them.
function SparseAxisArray(d::AbstractDict, ::Nothing)
    return SparseAxisArray(d)
end

# The length of a `SparseAxisArray` is the length of its `data` dictionary.
function Base.length(sa::SparseAxisArray)
    return length(sa.data)
end

Expand All @@ -71,7 +83,7 @@ function Base.similar(
::Type{T},
length::Integer = 0,
) where {S,T,N,K}
d = Dict{K,T}()
d = OrderedCollections.OrderedDict{K,T}()
if !iszero(length)
sizehint!(d, length)
end
Expand Down Expand Up @@ -165,7 +177,7 @@ function Base.getindex(
end
K2 = _sliced_key_type(K, args...)
if K2 !== nothing
new_data = Dict{K2,T}(
new_data = OrderedCollections.OrderedDict{K2,T}(
_sliced_key(k, args) => v for (k, v) in d.data if _filter(k, args)
)
names = _sliced_key_name(K, d.names, args...)
Expand Down Expand Up @@ -293,12 +305,16 @@ end
function Base.copy(
bc::Base.Broadcast.Broadcasted{BroadcastStyle{N,K}},
) where {N,K}
dict = Dict(index => _getindex(bc, index) for index in _indices(bc.args...))
if isempty(dict) && dict isa Dict{Any,Any}
dict = OrderedCollections.OrderedDict(
index => _getindex(bc, index) for index in _indices(bc.args...)
)
if isempty(dict) && dict isa OrderedCollections.OrderedDict{Any,Any}
# If `dict` is empty (e.g., because there are no indices), then
# inference will produce a `Dict{Any,Any}`, and we won't have enough
# type information to call SparseAxisArray(dict). As a work-around, we
# explicitly construct the type of the resulting SparseAxisArray.
# inference will produce an `OrderedCollections.OrderedDict{Any,Any}`,
# and we won't have enough type information to call
# `SparseAxisArray(dict)`. As a work-around, we explicitly construct the
# type of the resulting SparseAxisArray.
#
# For more, see JuMP issue #2867.
return SparseAxisArray{Any,N,K}(dict, ntuple(n -> Symbol("#$n"), N))
end
Expand Down Expand Up @@ -448,7 +464,6 @@ function Base.show(io::IOContext, x::SparseAxisArray)
(i, (key, value)) in enumerate(x.data) if
i < half_screen_rows || i > length(x) - half_screen_rows
]
sort!(key_strings; by = x -> x[1])
odow marked this conversation as resolved.
Show resolved Hide resolved
pad = maximum(length(x[1]) for x in key_strings)
for (i, (key, value)) in enumerate(key_strings)
print(io, " [", rpad(key, pad), "] = ", value)
Expand Down
20 changes: 14 additions & 6 deletions src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,13 +142,21 @@ function container(
end
# Same as `map` but does not allocate the resulting vector.
mappings = Base.Generator(I -> I => f(I...), indices)
# Same as `Dict(mapping)` but it will error if two indices are the same.
# Same as `OrderedCollections.OrderedDict(mapping)`, but it will error if
# two indices are the same.
data = NoDuplicateDict(mappings)
return _sparseaxisarray(data.dict, f, indices, names)
end

# The NoDuplicateDict was able to infer the element type.
_sparseaxisarray(dict::Dict, ::Any, ::Any, names) = SparseAxisArray(dict, names)
function _sparseaxisarray(
    data::OrderedCollections.OrderedDict,
    _generator,
    _index_sets,
    axis_names,
)
    # The second and third arguments are accepted but not used by this method.
    return SparseAxisArray(data, axis_names)
end

# @default_eltype succeeded and inferred a tuple of the appropriate size!
# Use `return_types` to get the value type of the dictionary.
Expand All @@ -159,13 +167,13 @@ function _container_dict(
) where {N}
ret = Base.return_types(f, K)
V = length(ret) == 1 ? first(ret) : Any
return Dict{K,V}()
return OrderedCollections.OrderedDict{K,V}()
end

# @default_eltype bailed and returned Any. Use an NTuple of Any of the
# appropriate size instead.
function _container_dict(::Any, ::Any, K::Type{<:NTuple{N,Any}}) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# @default_eltype bailed and returned Union{}. Use an NTuple of Any of the
Expand All @@ -176,7 +184,7 @@ function _container_dict(
::Function,
K::Type{<:NTuple{N,Any}},
) where {N}
return Dict{K,Any}()
return OrderedCollections.OrderedDict{K,Any}()
end

# Calling `@default_eltype` on `x` isn't sufficient, because the iterator may
Expand All @@ -189,7 +197,7 @@ _default_eltype(x) = Base.@default_eltype x
# best-guess attempt, collect all of the keys excluding the conditional
# statement (these must be defined, because the conditional applies to the
# lowest-level of the index loops), then get the eltype of the result.
function _sparseaxisarray(dict::Dict{Any,Any}, f, indices, names)
function _sparseaxisarray(dict::AbstractDict{Any,Any}, f, indices, names)
@assert isempty(dict)
d = _container_dict(_default_eltype(indices), f, _eltype_or_any(indices))
return SparseAxisArray(d, names)
Expand Down
21 changes: 15 additions & 6 deletions src/Containers/no_duplicate_dict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,47 @@

"""
struct NoDuplicateDict{K, V} <: AbstractDict{K, V}
dict::Dict{K, V}
dict::OrderedCollections.OrderedDict{K, V}
end

Same as `Dict{K, V}` but errors if constructed from an iterator with duplicate
keys.
Same as `OrderedCollections.OrderedDict{K, V}` but errors if constructed from an
iterator with duplicate keys.
"""
struct NoDuplicateDict{K,V} <: AbstractDict{K,V}
dict::Dict{K,V}
NoDuplicateDict{K,V}() where {K,V} = new{K,V}(Dict{K,V}())
dict::OrderedCollections.OrderedDict{K,V}

# Inner constructor: wrap a freshly created, empty ordered dictionary.
function NoDuplicateDict{K,V}() where {K,V}
    storage = OrderedCollections.OrderedDict{K,V}()
    return new{K,V}(storage)
end
end

# Implementation of the `AbstractDict` API.
# `empty` ignores the contents of the given dictionary and returns a new, empty
# `NoDuplicateDict` with the requested key and value types.
function Base.empty(
    ::NoDuplicateDict,
    ::Type{Key},
    ::Type{Value},
) where {Key,Value}
    return NoDuplicateDict{Key,Value}()
end

# Iteration is forwarded unchanged to the wrapped dictionary.
function Base.iterate(d::NoDuplicateDict, args...)
    return iterate(d.dict, args...)
end

# The length is that of the wrapped dictionary.
function Base.length(d::NoDuplicateDict)
    return length(d.dict)
end

# Key lookup is forwarded to the wrapped dictionary.
function Base.haskey(dict::NoDuplicateDict, key)
    return haskey(dict.dict, key)
end

# Value retrieval is forwarded to the wrapped dictionary.
function Base.getindex(dict::NoDuplicateDict, key)
    return getindex(dict.dict, key)
end

# Insert `key => value` into the wrapped dictionary. Unlike an ordinary
# dictionary, an already-present `key` is never overwritten: it raises an error.
function Base.setindex!(dict::NoDuplicateDict, value, key)
    haskey(dict, key) &&
        error("Repeated index ", key, ". Index sets must have unique elements.")
    return setindex!(dict.dict, value, key)
end

# Fill a fresh `NoDuplicateDict{K,V}` from an iterator whose elements
# destructure into a key and a value. Every insertion goes through `setindex!`
# on the `NoDuplicateDict`, so a repeated key raises an error.
function NoDuplicateDict{K,V}(it) where {K,V}
    result = NoDuplicateDict{K,V}()
    for (key, value) in it
        setindex!(result, value, key)
    end
    return result
end

# Untyped constructor: delegate to `Base.dict_with_eltype`, passing the iterator
# and `eltype(it)`. The callback maps a key type `K` and a value type `V` to the
# `NoDuplicateDict{K,V}` type.
function NoDuplicateDict(it)
    return Base.dict_with_eltype(it, eltype(it)) do K, V
        return NoDuplicateDict{K,V}
    end
end
17 changes: 17 additions & 0 deletions test/Containers/test_SparseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ module TestContainersSparseAxisArray
using JuMP.Containers
using Test

import LinearAlgebra

function _util_sparse_test(d, sum_d, d2, d3, dsqr, d_bads)
sqr(x) = x^2
# map
Expand Down Expand Up @@ -359,4 +361,19 @@ function test_multi_arg_eachindex()
return
end

# Check that slices of a two-dimensional `SparseAxisArray` agree with
# one-dimensional containers built directly from the same index sets.
function test_sparseaxisarray_order()
# Index sets for the second axis, one per value of `i`. The entries are not
# consecutive integers (1, 2, 10 and 2, 3, 30).
A = [[1, 2, 10], [2, 3, 30]]
# `x[i, j] = i + j` for `j in A[i]`.
Containers.@container(
x[i in 1:2, j in A[i]],
i + j,
container = SparseAxisArray,
)
# The same values as rows `i = 1` and `i = 2` of `x`, built on their own.
Containers.@container(x1[j in A[1]], 1 + j, container = SparseAxisArray)
Containers.@container(x2[j in A[2]], 2 + j, container = SparseAxisArray)
@test x[1, :] == x1
@test x[2, :] == x2
# 41 == 2 * 1 + 3 * 2 + 11 * 3, i.e. the values of `x[1, :]` paired with 1:3 in
# the order j = 1, 2, 10 -- presumably this relies on iteration following
# insertion order; confirm against the `dot` method that is dispatched here.
@test LinearAlgebra.dot(x[1, :], 1:3) == 41
return
end

end # module
Loading