diff --git a/src/gdtw.jl b/src/gdtw.jl index a4c081e..fbb1b8a 100644 --- a/src/gdtw.jl +++ b/src/gdtw.jl @@ -57,12 +57,27 @@ function update_τ!(τ, t, M, l, u) end """ - gdtw(x, y; M::Int=100, N=nothing, t=range(0, stop=1, length=N), λcum=0.01, - λinst=0.01, η=1 / 8, max_iters=3, metric=(x,y) -> norm(x-y), - Rcum=u -> u * u, smin::Real=0.001, smax::Real=5.0, - Rinst=u -> smin <= u <= smax ? u^2 : Inf, - verbose=false, - cache=GDTWWorkspace(M, N), warp=zeros(N)) + prepare_gdtw( + x, + y, + ::Type{T} = Float64; + M::Int = 100, + N = 100, + t = range(T(0), stop = T(1), length = N), + λcum = T(0.01), + λinst = T(0.01), + η = T(1 / 8), + max_iters = 3, + metric = (x, y) -> norm(x - y), + Rcum = u -> u^2, + smin::Real = T(0.001), + smax::Real = T(5.0), + Rinst = u -> smin <= u <= smax ? u^2 : typemax(T), + verbose = false, + warp = zeros(length(t)), + callback = nothing, + cache::GDTWWorkspace{T} = GDTWWorkspace(M, length(t)), + ) Computes a general DTW distance following [DB19](https://arxiv.org/abs/1905.12893). The parameters are: @@ -78,7 +93,7 @@ Computes a general DTW distance following [DB19](https://arxiv.org/abs/1905.1289 The following may be pre-allocated and reused between distance computations with the same `M` and `N` (or `length(t)`). -* `cache`: a cache of matrices and vectors, generated by `GDTW.GDTWWorkspace{Float64}(N,M)` +* `cache`: a cache of matrices and vectors, generated by `GDTW.GDTWWorkspace{T}(N,M)` """ gdtw(args...; kwargs...) = iterative_gdtw!(prepare_gdtw(args...; kwargs...)) @@ -89,13 +104,27 @@ gdtw(args...; kwargs...) = iterative_gdtw!(prepare_gdtw(args...; kwargs...)) Creates a NamedTuple of parameters, using the same keyword argments as `dist`. A preprocessing step before calling `iterative_gdtw!`. """ -function prepare_gdtw(x, y; M::Int=100, N=100, t = range(0, stop=1, length=N), λcum=0.01, - λinst=0.01, η=1 / 8, max_iters=3, metric=(x,y) -> norm(x-y), - Rcum=u -> u^2, smin::Real=0.001, smax::Real=5.0, - Rinst=u -> smin <= u <= smax ? u^2 : Inf, - verbose=false, - cache=GDTWWorkspace(M, length(t)), warp=zeros(length(t)), - callback=nothing) +function prepare_gdtw( + x, + y, + ::Type{T} = Float64; + M::Int = 100, + N = 100, + t = range(T(0), stop = T(1), length = N), + cache::GDTWWorkspace = GDTWWorkspace{T}(M, length(t)), + λcum = T(0.01), + λinst = T(0.01), + η = T(1 / 8), + max_iters = 3, + metric = (x, y) -> norm(x - y), + Rcum = u -> u^2, + smin::Real = T(0.001), + smax::Real = T(5.0), + Rinst = u -> smin <= u <= smax ? u^2 : typemax(T), + verbose = false, + warp = zeros(length(t)), + callback = nothing, +) where T N = length(t) if !(M > N / smax) @@ -110,15 +139,30 @@ function prepare_gdtw(x, y; M::Int=100, N=100, t = range(0, stop=1, length=N), node_weight(j, s) = metric(x(τ[j, s]), y(t[s])) + λcum * Rcum(τ[j, s] - t[s]) @inline function edge_weight((j, s), (k, s2)) - s + 1 ≠ s2 && return Inf + s + 1 ≠ s2 && return typemax(T) u = (τ[k, s+1] - τ[j, s]) / (t[s+1] - t[s]) λinst * Rinst(u) end - return (N=N, M=M, τ=τ, node_weight=node_weight, edge_weight=edge_weight, - l₀=l₀, u₀=u₀, η=η, max_iters=max_iters, t=t, smin=smin, smax=smax, - callback=callback, verbose=verbose, metric=metric, cache=cache, - warp=warp) + return ( + N = N, + M = M, + τ = τ, + node_weight = node_weight, + edge_weight = edge_weight, + l₀ = l₀, + u₀ = u₀, + η = η, + max_iters = max_iters, + t = t, + smin = smin, + smax = smax, + callback = callback, + verbose = verbose, + metric = metric, + cache = cache, + warp = warp, + ) end @@ -150,7 +194,7 @@ function iterative_gdtw!(data) update_τ!(τ, t, M, l, u) cost = single_gdtw!(data) if callback !== nothing - callback((iter=iter, t=t, τ=τ, warp=warp, cost=cost)) + callback((iter = iter, t = t, τ = τ, warp = warp, cost = cost)) end iter += 1 end diff --git a/test/runtests.jl b/test/runtests.jl index 217a6e5..9f4d4a4 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,9 +5,13 @@ using Distances, Plots @testset "DynamicAxisWarping" begin @info "Testing DynamicAxisWarping" - include("test_gdtw.jl") + @testset "GDTW" begin + @info "Testing GDTW" + include("test_gdtw.jl") + end @testset "LinearInterpolation" begin + @info "Testing LinearInterpolation" # Test arrays x = rand(20, 20, 100) x_interp = LinearInterpolation(x) @@ -27,7 +31,7 @@ using Distances, Plots @test x_interp(1) == x[end] @test x_interp(4/99) == x[5] end - + @testset "Normalizers" begin @info "Testing Normalizers" a = randn(2,100)