Skip to content

Commit

Permalink
Generic types in gdtw
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed Jun 11, 2020
1 parent 62c8af2 commit 30d8119
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 22 deletions.
84 changes: 64 additions & 20 deletions src/gdtw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,27 @@ function update_τ!(τ, t, M, l, u)
end

"""
gdtw(x, y; M::Int=100, N=nothing, t=range(0, stop=1, length=N), λcum=0.01,
λinst=0.01, η=1 / 8, max_iters=3, metric=(x,y) -> norm(x-y),
Rcum=u -> u * u, smin::Real=0.001, smax::Real=5.0,
Rinst=u -> smin <= u <= smax ? u^2 : Inf,
verbose=false,
cache=GDTWWorkspace(M, N), warp=zeros(N))
prepare_gdtw(
x,
y,
::Type{T} = Float64;
M::Int = 100,
N = 100,
t = range(T(0), stop = T(1), length = N),
λcum = T(0.01),
λinst = T(0.01),
η = T(1 / 8),
max_iters = 3,
metric = (x, y) -> norm(x - y),
Rcum = u -> u^2,
smin::Real = T(0.001),
smax::Real = T(5.0),
Rinst = u -> smin <= u <= smax ? u^2 : typemax(T),
verbose = false,
warp = zeros(length(t)),
callback = nothing,
cache::GDTWWorkspace{T} = GDTWWorkspace(M, length(t)),
)
Computes a general DTW distance following [DB19](https://arxiv.org/abs/1905.12893). The parameters are:
Expand All @@ -78,7 +93,7 @@ Computes a general DTW distance following [DB19](https://arxiv.org/abs/1905.1289
The following may be pre-allocated and reused between distance computations with the same `M` and `N` (or `length(t)`).
* `cache`: a cache of matrices and vectors, generated by `GDTW.GDTWWorkspace{Float64}(N,M)`
* `cache`: a cache of matrices and vectors, generated by `GDTW.GDTWWorkspace{T}(N,M)`
"""
gdtw(args...; kwargs...) = iterative_gdtw!(prepare_gdtw(args...; kwargs...))
Expand All @@ -89,13 +104,27 @@ gdtw(args...; kwargs...) = iterative_gdtw!(prepare_gdtw(args...; kwargs...))
Creates a NamedTuple of parameters, using the same keyword argments as `dist`.
A preprocessing step before calling `iterative_gdtw!`.
"""
function prepare_gdtw(x, y; M::Int=100, N=100, t = range(0, stop=1, length=N), λcum=0.01,
λinst=0.01, η=1 / 8, max_iters=3, metric=(x,y) -> norm(x-y),
Rcum=u -> u^2, smin::Real=0.001, smax::Real=5.0,
Rinst=u -> smin <= u <= smax ? u^2 : Inf,
verbose=false,
cache=GDTWWorkspace(M, length(t)), warp=zeros(length(t)),
callback=nothing)
function prepare_gdtw(
x,
y,
::Type{T} = Float64;
M::Int = 100,
N = 100,
t = range(T(0), stop = T(1), length = N),
cache::GDTWWorkspace = GDTWWorkspace{T}(M, length(t)),
λcum = T(0.01),
λinst = T(0.01),
η = T(1 / 8),
max_iters = 3,
metric = (x, y) -> norm(x - y),
Rcum = u -> u^2,
smin::Real = T(0.001),
smax::Real = T(5.0),
Rinst = u -> smin <= u <= smax ? u^2 : typemax(T),
verbose = false,
warp = zeros(length(t)),
callback = nothing,
) where T
N = length(t)

if !(M > N / smax)
Expand All @@ -110,15 +139,30 @@ function prepare_gdtw(x, y; M::Int=100, N=100, t = range(0, stop=1, length=N),
node_weight(j, s) = metric(x(τ[j, s]), y(t[s])) + λcum * Rcum(τ[j, s] - t[s])

@inline function edge_weight((j, s), (k, s2))
s + 1 s2 && return Inf
s + 1 s2 && return typemax(T)
u = (τ[k, s+1] - τ[j, s]) / (t[s+1] - t[s])
λinst * Rinst(u)
end

return (N=N, M=M, τ=τ, node_weight=node_weight, edge_weight=edge_weight,
l₀=l₀, u₀=u₀, η=η, max_iters=max_iters, t=t, smin=smin, smax=smax,
callback=callback, verbose=verbose, metric=metric, cache=cache,
warp=warp)
return (
N = N,
M = M,
τ = τ,
node_weight = node_weight,
edge_weight = edge_weight,
l₀ = l₀,
u₀ = u₀,
η = η,
max_iters = max_iters,
t = t,
smin = smin,
smax = smax,
callback = callback,
verbose = verbose,
metric = metric,
cache = cache,
warp = warp,
)
end


Expand Down Expand Up @@ -150,7 +194,7 @@ function iterative_gdtw!(data)
update_τ!(τ, t, M, l, u)
cost = single_gdtw!(data)
if callback !== nothing
callback((iter=iter, t=t, τ=τ, warp=warp, cost=cost))
callback((iter = iter, t = t, τ = τ, warp = warp, cost = cost))
end
iter += 1
end
Expand Down
8 changes: 6 additions & 2 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ using Distances, Plots
@testset "DynamicAxisWarping" begin
@info "Testing DynamicAxisWarping"

include("test_gdtw.jl")
@testset "GDTW" begin
@info "Testing GDTW"
include("test_gdtw.jl")
end

@testset "LinearInterpolation" begin
@info "Testing LinearInterpolation"
# Test arrays
x = rand(20, 20, 100)
x_interp = LinearInterpolation(x)
Expand All @@ -27,7 +31,7 @@ using Distances, Plots
@test x_interp(1) == x[end]
@test x_interp(4/99) == x[5]
end

@testset "Normalizers" begin
@info "Testing Normalizers"
a = randn(2,100)
Expand Down

0 comments on commit 30d8119

Please sign in to comment.