Skip to content

Commit

Permalink
Pass names (index_vars) to container extensions (#3088)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellemo authored Sep 30, 2022
1 parent 2b2c321 commit 1af86eb
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 5 deletions.
11 changes: 10 additions & 1 deletion src/Containers/container.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ a `VectorizedProductIterator` and the function returns
function default_container end

"""
container(f::Function, indices, c::Type{C}, names)
Create a container of type `C` with index names `names`, indices `indices` and
values at given indices given by `f`. If this method is not specialized on
`Type{C}`, it falls back to calling `container(f, indices, c)` for backwards
compatibility with containers not supporting index names.
container(f::Function, indices, ::Type{C})
Create a container of type `C` with indices `indices` and values at given
Expand Down Expand Up @@ -60,7 +67,9 @@ SparseAxisArray{Int64,2,Tuple{Int64,Int64}} with 5 entries:
[1, 3] = 4
```
"""
function container end
function container(f::Function, indices, D, names)
return container(f, indices, D)
end

function container(f::Function, indices)
return container(f, indices, default_container(indices))
Expand Down
23 changes: 19 additions & 4 deletions src/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -307,16 +307,31 @@ function container_code(
if requested_container == :Auto
return :(Containers.container($f, $indices))
elseif requested_container == :DenseAxisArray
return :(Containers.container($f, $indices, Containers.DenseAxisArray))
return :(Containers.container(
$f,
$indices,
Containers.DenseAxisArray,
$index_vars,
))
elseif requested_container == :SparseAxisArray
return :(Containers.container($f, $indices, Containers.SparseAxisArray))
return :(Containers.container(
$f,
$indices,
Containers.SparseAxisArray,
$index_vars,
))
elseif requested_container == :Array
return :(Containers.container($f, $indices, Array))
return :(Containers.container($f, $indices, Array, $index_vars))
else
# This is a symbol or expression from outside JuMP, so we need to escape
# it.
requested_container = esc(requested_container)
return :(Containers.container($f, $indices, $requested_container))
return :(Containers.container(
$f,
$indices,
$requested_container,
$index_vars,
))
end
end

Expand Down
23 changes: 23 additions & 0 deletions test/Containers/macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,26 @@ end
@test length(z) == 4
@test z[1, 2] == 3
end

# Test containers that use subindex names
struct _MyContainer2
names::Any
d::Any
end

function Containers.container(
f::Function,
indices,
::Type{_MyContainer2},
names,
)
key(i::Tuple) = i
key(i::Tuple{T}) where {T} = i[1]
return _MyContainer2(names, Dict(key(i) => f(i...) for i in indices))
end

@testset "_MyContainer2" begin
Containers.@container(v[i = 1:3], sin(i), container = _MyContainer2)
@test v.d isa Dict{Int,Float64}
@test v.names == [:i]
end

0 comments on commit 1af86eb

Please sign in to comment.