Skip to content

Commit

Permalink
Update distance interface
Browse files Browse the repository at this point in the history
  • Loading branch information
baggepinnen committed May 25, 2020
1 parent 9ba271d commit f235e11
Show file tree
Hide file tree
Showing 11 changed files with 534 additions and 366 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DynamicAxisWarping"
uuid = "aaaaaaaa-4a10-5553-b683-e707b00e83ce"
authors = ["Fredrik Bagge Carlson <[email protected]>"]
version = "0.1.5"
version = "0.2.0"

[deps]
BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ da = ReverseDiff.gradient(a->soft_dtw_cost(a,b; γ=1), a)
Zygote.jl will not work due to the array-mutation limitation.
See also function `soft_dtw_cost_matrix`.

The following [example](https://github.com/baggepinnen/DynamicAxisWarping.jl/blob/master/examples/softDTW.jl) illustrates how to calculate a barycenter using Soft DTW and [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), the result is shown below.
The following [example](https://github.com/baggepinnen/DynamicAxisWarping.jl/blob/master/examples/softDTW.jl) illustrates how to calculate a barycenter (generalized average) using Soft DTW and [Optim.jl](https://github.com/JuliaNLSolvers/Optim.jl), the result is shown below, together with three instances of the input series

![barycenter](examples/barycenter.svg)

Expand Down Expand Up @@ -160,9 +160,10 @@ See the file [`frequency_warping.jl`](https://github.com/baggepinnen/DynamicAxis
## Distances.jl interface

```julia
d = DTWDistance(method=DTW(radius), dist=SqEuclidean())
d = DTW(radius=radius, dist=SqEuclidean()) # Or FastDTW / SoftDTW
d(a,b)
```
`method` can be either of `DTW()` or `FastDTW(radius)`.


## Acknowledgements

Expand Down
656 changes: 397 additions & 259 deletions examples/barycenter.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
15 changes: 9 additions & 6 deletions examples/softDTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@ function dtwmeancost(data, γ)
end
end

input = c1[1,:]
input = mean(c1v) # Our initial guess will be the Euclidean mean
costfun = dtwmeancost(c1v, 1)
costfun(input)

cfg = ReverseDiff.GradientConfig(input)
tape = ReverseDiff.GradientTape(costfun, input)
ctape = ReverseDiff.compile(tape)
results = DiffResults.GradientResult(similar(input))
ReverseDiff.CompiledGradient


function fg!(F,G,x)
if G != nothing
Expand Down Expand Up @@ -55,10 +55,13 @@ res = Optim.optimize(
x_tol = 1e-3,
f_tol = 1e-4,
g_tol = 1e-4,
f_calls_limit = 0,
g_calls_limit = 0,
),
)

plot(c1', legend=false)
plot!(res.minimizer, l=(4, :red))
##
using Plots, Plots.PlotMeasures
f1 = plot(c1', lab="", axis=false, legend=:bottom)
plot!(input, l=(4, :red), lab="Euclidean mean")
plot!(res.minimizer, l=(4, :green), lab="Soft-DTW mean")
f2 = plot(c1[1:3,:]', layout=(1,3), legend=false, axis=false, margin=-5mm)
plot(f1,f2, layout=(2,1))
2 changes: 1 addition & 1 deletion src/DynamicAxisWarping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ export dtw,
dtw_cost,
soft_dtw_cost,
DTWDistance,
DTWMethod,
DTW,
SoftDTW,
FastDTW,
distpath,
dba,
Expand Down
11 changes: 3 additions & 8 deletions src/dba.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ mutable struct DBAResult
end

"""
avgseq, results = dba(sequences, [dist=SqEuclidean()]; kwargs...)
avgseq, results = dba(sequences, dist::DTWDistance; kwargs...)
Perfoms DTW Barycenter Averaging (DBA) given a collection of `sequences`
and the current estimate of the average sequence.
Expand All @@ -23,12 +23,11 @@ Example usage:
x = [1., 2., 2., 3., 3., 4.]
y = [1., 3., 4.]
z = [1., 2., 2., 4.]
avg,result = dba([x,y,z])
avg,result = dba([x,y,z], DTW(3))
"""
function dba(
sequences::AbstractVector,
method::DTWMethod,
dist::SemiMetric = SqEuclidean();
dtwdist::DTWDistance;
init_center = rand(sequences),
iterations::Int = 1000,
rtol::Float64 = 1e-5,
Expand All @@ -38,9 +37,6 @@ function dba(
i2max::AbstractVector = [],
)

# method for computing dtw
dtwdist = DTWDistance(method, dist)

# initialize dbavg as a random sample from the dataset
nseq = length(sequences)
dbavg = deepcopy(init_center)
Expand Down Expand Up @@ -107,7 +103,6 @@ end


"""
newavg, cost = dba_iteration(dbavg, sequences, dist)
Performs one iteration of DTW Barycenter Averaging (DBA) given a collection of
`sequences` and the current estimate of the average sequence, `dbavg`. Returns
Expand Down
101 changes: 49 additions & 52 deletions src/dbaclust.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,69 +13,73 @@ mutable struct DBAclustResult{T}
end

"""
avgseq, results = dba(sequences, [dist=SqEuclidean()]; kwargs...)
Perfoms DTW Barycenter Averaging (DBA) given a collection of `sequences`
and the current estimate of the average sequence.
dbaclust(
sequences,
nclust::Int,
dtwdist::DTWDistance;
n_init::Int = 1,
iterations::Int = 100,
inner_iterations::Int = 10,
rtol::Float64 = 1e-4,
rtol_inner::Float64 = rtol,
n_jobs::Int = 1,
show_progress::Bool = true,
store_trace::Bool = true,
i2min::AbstractVector = [],
i2max::AbstractVector = [],
)
Example usage:
x = [1,2,2,3,3,4]
y = [1,3,4]
z = [1,2,2,4]
avg,result = dba([x,y,z])
# Arguments:
- `nclust`: Number of clsuters
- `n_init`: Number of initialization tries
- `inner_iterations`: Number of iterations in the inner alg.
- `i2min`: Bounds on the warping path
"""
function dbaclust(
sequences,
nclust::Int,
_method::DTWMethod,
_dist::SemiMetric = SqEuclidean();
dtwdist::DTWDistance;
n_init::Int = 1,
iterations::Int = 100,
inner_iterations::Int = 10,
rtol::Float64 = 1e-4,
rtol_inner::Float64 = rtol,
n_jobs::Int = 1,
show_progress::Bool = true,
store_trace::Bool = true,
i2min::AbstractVector = [],
i2max::AbstractVector = [],
)

if n_jobs == 1
best_result = []
best_cost = []
for i = 1:n_init
dbaclust_result = dbaclust_single(
sequences,
nclust,
_method,
_dist;
iterations = iterations,
inner_iterations = inner_iterations,
rtol = rtol,
rtol_inner = rtol_inner,
show_progress = show_progress,
store_trace = store_trace,
i2min = i2min,
i2max = i2max,
)
if isempty(best_cost) || dbaclust_result.dbaresult.cost < best_cost
best_result = deepcopy(dbaclust_result)
best_cost = best_result.dbaresult.cost
end
end #1:n_init
else
#results = Array{DBAclustResult}(n)_
error("parallelism for dbaclust not implemented yet")

end # n_jobs
best_result = []
best_cost = []
for i = 1:n_init
dbaclust_result = dbaclust_single(
sequences,
nclust,
dtwdist;
iterations = iterations,
inner_iterations = inner_iterations,
rtol = rtol,
rtol_inner = rtol_inner,
show_progress = show_progress,
store_trace = store_trace,
i2min = i2min,
i2max = i2max,
)
if isempty(best_cost) || dbaclust_result.dbaresult.cost < best_cost
best_result = deepcopy(dbaclust_result)
best_cost = best_result.dbaresult.cost
end
end #1:n_init

return best_result
end


"""
avgseq, results = dbaclust_single(sequences, [dist=SqEuclidean()]; kwargs...)
avgseq, results = dbaclust_single(sequences, dist; kwargs...)
Perfoms a single DTW Barycenter Averaging (DBA) given a collection of `sequences`
and the current estimate of the average sequence.
Expand All @@ -85,18 +89,16 @@ Example usage:
x = [1,2,2,3,3,4]
y = [1,3,4]
z = [1,2,2,4]
avg,result = dba([x,y,z])
avg,result = dba([x,y,z], DTW(3))
"""
function dbaclust_single(
sequences::AbstractVector,
nclust::Int,
_method::DTWMethod,
_dist::SemiMetric = SqEuclidean();
dtwdist::DTWDistance;
init_centers::AbstractVector = dbaclust_initial_centers(
sequences,
nclust,
_method,
_dist,
dtwdist
),
iterations::Int = 100,
inner_iterations::Int = 10,
Expand All @@ -122,8 +124,6 @@ function dbaclust_single(
nseq = length(sequences)
maxseqlen = maximum([length(s) for s in sequences])

# initialize procedure for computing DTW
dtwdist = DTWDistance(_method, _dist)

# TODO switch to ntuples?
counts = [zeros(Int, N) for _ = 1:nclust]
Expand Down Expand Up @@ -300,20 +300,17 @@ end


"""
dbaclust_initial_centers(sequences, nclust, dist)
dbaclust_initial_centers(sequences, nclust, dtwdist::DTWDistance)
Uses kmeans++ (but with dtw distance) to initialize the centers
for dba clustering.
"""
function dbaclust_initial_centers(
sequences::AbstractVector,
nclust::Int,
_method::DTWMethod,
_dist::Union{SemiMetric, Function} = SqEuclidean();
dtwdist::DTWDistance,
)

# procedure for calculating dtw
dtwdist = DTWDistance(_method, _dist)
# number of sequences in dataset
nseq = length(sequences)
# distance of each datapoint to each center
Expand Down
64 changes: 47 additions & 17 deletions src/distance_interface.jl
Original file line number Diff line number Diff line change
@@ -1,35 +1,65 @@
# methods for estimating dtw #
abstract type DTWMethod end

struct DTW <: DTWMethod
abstract type DTWDistance{D <: Union{Function, Distances.PreMetric}} end


"""
struct DTW{D} <: DTWDistance{D}
# Arguments:
- `radius`: The maximum allowed deviation of the matching path from the diagonal
- `dist`: Inner distance
- `transportcost` If >1, an additional penalty factor for non-diagonal moves is added.
"""
Base.@kwdef struct DTW{D} <: DTWDistance{D}
"The maximum allowed deviation of the matching path from the diagonal"
radius::Int
dist::D = SqEuclidean()
"If >1, an additional penalty factor for non-diagonal moves is added."
transportcost::Float64
DTW(r,transportcost=1) = new(r,transportcost)
transportcost::Float64 = 1.0
DTW(r,dist=SqEuclidean(),transportcost=1) = new{typeof(dist)}(r,dist,transportcost)
end

"""
struct SoftDTW{D, T} <: DTWDistance{D}
struct FastDTW <: DTWMethod
radius::Int
# Arguments:
- `γ`: smoothing parameter
- `dist`
- `transportcost`
"""
Base.@kwdef struct SoftDTW{D,T} <: DTWDistance{D}
γ::T
"The maximum allowed deviation of the matching path from the diagonal"
dist::D = SqEuclidean()
"If >1, an additional penalty factor for non-diagonal moves is added."
transportcost::Float64 = 1.0
SoftDTW=1.0, dist=SqEuclidean(),transportcost=1) = new{typeof(dist), typeof(γ)}(γ,dist,transportcost)
end

# distance interface #
Base.@kwdef struct DTWDistance{M<:DTWMethod,D<:SemiMetric} <: SemiMetric
method::M

"""
struct FastDTW{D} <: DTWDistance{D}
- `radius`
- `dist` inner distance
"""
Base.@kwdef struct FastDTW{D} <: DTWDistance{D}
radius::Int
dist::D = SqEuclidean()
end

DTWDistance(m::DTWMethod) = DTWDistance(m, SqEuclidean())

Distances.evaluate(d::DTWDistance{DTW}, x, y) = dtw_cost(x, y, d.dist, d.method.radius)
Distances.evaluate(d::DTWDistance{FastDTW}, x, y) =
fastdtw(x, y, d.dist, d.method.radius)[1]

distpath(d::DTWDistance{DTW}, x, y) = dtw(x, y, d.dist)
distpath(d::DTWDistance{DTW}, x, y, i2min::AbstractVector, i2max::AbstractVector) =
Distances.evaluate(d::DTW, x, y) = dtw_cost(x, y, d.dist, d.radius)
Distances.evaluate(d::SoftDTW, x, y) = soft_dtw_cost(x, y, d.dist, γ=d.γ)
Distances.evaluate(d::FastDTW, x, y) =
fastdtw(x, y, d.dist, d.radius)[1]

distpath(d::DTW, x, y) = dtw(x, y, d.dist)
distpath(d::DTW, x, y, i2min::AbstractVector, i2max::AbstractVector) =
dtw(x, y, i2min, i2max, d.dist)
distpath(d::DTWDistance{FastDTW}, x, y) = fastdtw(x, y, d.dist, d.method.radius)
distpath(d::FastDTW, x, y) = fastdtw(x, y, d.dist, d.radius)

(d::DTWDistance)(x,y) = Distances.evaluate(d,x,y)

Expand All @@ -43,6 +73,6 @@ function distance_profile(d::DTWDistance, Q::AbstractArray{S}, T::AbstractArray{
n = lastlength(T)
n >= m || throw(ArgumentError("Q cannot be longer than T"))
l = n-m+1
res = dtwnn(Q, Y, d.dist, d.method.radius; saveall=true, kwargs...)
res = dtwnn(Q, Y, d.dist, d.radius; saveall=true, kwargs...)
res.dists
end
6 changes: 3 additions & 3 deletions src/matrix_profile.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@

function MatrixProfile.matrix_profile(T, m::Int, dist::DTWDistance; showprogress=true, normalizer=Val(Nothing))
function MatrixProfile.matrix_profile(T, m::Int, dist::DTW; showprogress=true, normalizer=Val(Nothing))
n = lastlength(T)
l = n-m+1
r = dist.method.radius
r = dist.radius
# n > 2m+1 || throw(ArgumentError("Window length too long, maximum length is $((n+1)÷2)"))
P = Vector{floattype(T)}(undef, l)
I = Vector{Int}(undef, l)
prog = Progress((l - 1) ÷ 5, dt=1, desc="Matrix profile", barglyphs = BarGlyphs("[=> ]"), color=:blue)
bsf = typemax(floattype(T))
@inbounds for i = 1:l
Ti = getwindow(T,m,i)
res = dtwnn(Ti, T, dist.dist, r, transportcost=dist.method.transportcost, avoid=i-r:i+r, normalizer=normalizer)
res = dtwnn(Ti, T, dist.dist, r, transportcost=dist.transportcost, avoid=i-r:i+r, normalizer=normalizer)
I[i] = res.loc
P[i] = res.cost
showprogress && i % 5 == 0 && next!(prog)
Expand Down
Loading

0 comments on commit f235e11

Please sign in to comment.