From 5c7f92d392f1d832725aaadd6ffee82b494609bb Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Sat, 7 Mar 2020 12:20:39 -0700 Subject: [PATCH 1/3] Fix DBA to work with Julia 1.x --- Project.toml | 3 ++- src/dba.jl | 16 +++++++++------- test/runtests.jl | 8 ++++++++ 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 9c4c01e..7a514c6 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ uuid = "c3fb160f-4a10-5553-b683-e707b00e83ce" [deps] BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Parameters = "d96e819e-fc66-5662-9728-84c9c7592b0a" Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" @@ -13,4 +14,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] -julia = "≥ 0.7.0" \ No newline at end of file +julia = "≥ 0.7.0" diff --git a/src/dba.jl b/src/dba.jl index f748153..8e37066 100644 --- a/src/dba.jl +++ b/src/dba.jl @@ -1,3 +1,5 @@ +using LinearAlgebra + """ DBAResult(cost,converged,iterations,cost_trace) @@ -18,14 +20,14 @@ and the current estimate of the average sequence. Example usage: - x = [1,2,2,3,3,4] - y = [1,3,4] - z = [1,2,2,4] + x = Sequence([1., 2., 2., 3., 3., 4.]) + y = Sequence([1., 3., 4.]) + z = Sequence([1., 2., 2., 4.]) avg,result = dba([x,y,z]) """ function dba( sequences::AbstractVector{T}, - method::DTWMethod, + method::DTWMethod = ClassicDTW(), dist::SemiMetric = SqEuclidean(); init_center::T = rand(sequences), iterations::Int = 1000, @@ -111,13 +113,13 @@ function dba_iteration!( total_cost = 0.0 # store stats for barycenter averages - scale!(counts,0) - scale!(newavg,0) + rmul!(counts,0) + rmul!(newavg,0) # main ploop for seq in sequences # time warp signal versus average - # if one of the two is empty, use unconstrained window. If both are nonempty, but not the same lenght, distpath will throw error + # if one of the two is empty, use unconstrained window. If both are nonempty, but not the same length, distpath will throw error if isempty(i2min) && isempty(i2max) cost, i1, i2 = distpath(d, oldavg, seq) else diff --git a/test/runtests.jl b/test/runtests.jl index ead2b07..b80ede0 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -298,3 +298,11 @@ end @test px == qx @test py == qy end + +@testset "DBA" begin + x = Sequence([1., 2., 2., 3., 3., 4.]) + y = Sequence([1., 3., 4.]) + z = Sequence([1., 2., 2., 4.]) + avg, _ = dba([x, y, z], init_center=z) + @test avg == [1.0, 1.75, 2.75, 4.0] +end From 78c36f8f9f04aab45f94edb36a5e5d284bef4d54 Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Sun, 8 Mar 2020 15:31:00 -0600 Subject: [PATCH 2/3] Update dba to work with multi-dim Sequences --- src/dba.jl | 3 ++- src/sequence.jl | 12 ++---------- 2 files changed, 4 insertions(+), 11 deletions(-) diff --git a/src/dba.jl b/src/dba.jl index 8e37066..5cc954d 100644 --- a/src/dba.jl +++ b/src/dba.jl @@ -46,7 +46,8 @@ function dba( dbavg = deepcopy(init_center) # storage for each iteration - newavg = Sequence(zeros(length(dbavg))) +# newavg = Sequence(zeros(size(dbavg))) + newavg = Sequence(zeros(size(dbavg.val))) counts = zeros(Int, length(dbavg)) # variables storing optimization progress diff --git a/src/sequence.jl b/src/sequence.jl index 1e95f42..5354b33 100644 --- a/src/sequence.jl +++ b/src/sequence.jl @@ -37,14 +37,6 @@ end :( x.val[@ntuple($N, (n-> n==$N ? i : Colon()))...] = val ) end -# convert a sequence back into an array -# in case sequences have different lenght, throw error -function seq_to_array(seq::AbstractVector{T}) where T <: Sequence - len_seq = length(seq[1]) - arr = zeros(typeof(seq[1][1]),len_seq,length(seq)) - for i=1:length(seq) - length(seq[i]) != len_seq ? error("Sequences do not have the same length, cannot construct array") : nothing - arr[:,i] = seq[i][:] - end - return arr +function seq_to_array(seq::Sequence{N}) where N + return cat(seq..., dims=N) end From 5b6b67f6936ea2dfbd12c32dc26703909879bff2 Mon Sep 17 00:00:00 2001 From: Matthew Kelley Date: Sun, 8 Mar 2020 15:43:40 -0600 Subject: [PATCH 3/3] Remove commented line --- src/dba.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/dba.jl b/src/dba.jl index 5cc954d..89cb115 100644 --- a/src/dba.jl +++ b/src/dba.jl @@ -46,7 +46,6 @@ function dba( dbavg = deepcopy(init_center) # storage for each iteration -# newavg = Sequence(zeros(size(dbavg))) newavg = Sequence(zeros(size(dbavg.val))) counts = zeros(Int, length(dbavg))