diff --git a/Project.toml b/Project.toml index 103aae3..715a381 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ParameterHandling" uuid = "2412ca09-6db7-441c-8e3a-88d5709968c5" authors = ["Invenia Technical Computing Corporation"] -version = "0.4.8" +version = "0.4.9" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/flatten.jl b/src/flatten.jl index 1414cb0..08c8f23 100644 --- a/src/flatten.jl +++ b/src/flatten.jl @@ -79,16 +79,20 @@ function flatten(::Type{T}, x::SparseMatrixCSC) where {T<:Real} end function flatten(::Type{T}, x::Tuple) where {T<:Real} - x_vecs_and_backs = map(val -> flatten(T, val), x) - x_vecs, x_backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs) - lengths = map(length, x_vecs) - sz = _cumsum(lengths) + vec1, back1 = flatten(T, first(x)) + vec2, back2 = flatten(T, Base.tail(x)) + l1 = length(vec1) + l2 = length(vec2) function unflatten_to_Tuple(v::Vector{T}) - map(x_backs, lengths, sz) do x_back, l, s - return x_back(v[(s - l + 1):s]) - end + return (back1(v[1:l1]), back2(v[(l1 + 1):(l1 + l2)])...) end - return reduce(vcat, x_vecs), unflatten_to_Tuple + return vcat(vec1, vec2), unflatten_to_Tuple +end + +function flatten(::Type{T}, x::Tuple{}) where {T<:Real} + v = T[] + unflatten_to_empty_Tuple(::Vector{T}) = x + return v, unflatten_to_empty_Tuple end function flatten(::Type{T}, x::NamedTuple) where {T<:Real} diff --git a/test/flatten.jl b/test/flatten.jl index 6b15e42..975cc5e 100644 --- a/test/flatten.jl +++ b/test/flatten.jl @@ -39,6 +39,13 @@ test_flatten_interface((1.0, 2.0); check_inferred=tuple_infers) test_flatten_interface((1.0, (2.0, 3.0), randn(5)); check_inferred=tuple_infers) + + # Prevent regression of PR #67 + @testset "Type stability of unflatten" begin + θ = (1.0, ((2.0, 3.0), 4.0)) + x, unflatten = flatten(θ) + @test (@inferred unflatten(x)) == θ + end end @testset "NamedTuple" begin