framework for recurrent neural networks #254

Open. Wants to merge 190 commits into base: main.
Changes from 126 commits

Commits (190 total)
6903f19
plagiarise datasetclient for now
lewardo Aug 22, 2023
46ecaf3
rename object to dataseries
lewardo Aug 22, 2023
9993a1e
functional clone of dataset
lewardo Aug 22, 2023
0aa7d8b
dataseries CTORs
lewardo Aug 23, 2023
1bd2e93
addSeries member
lewardo Aug 23, 2023
12e52c4
getSeries member
lewardo Aug 23, 2023
f54ce93
addFrame/getFrame members
lewardo Aug 23, 2023
725342b
get member function
lewardo Aug 23, 2023
8ffcfd3
modify member getters
lewardo Aug 23, 2023
8a9a92d
change container type to vector of tensors
lewardo Aug 23, 2023
c8aa44c
modify initFromData for vector change
lewardo Aug 23, 2023
a1b89a8
removeSeries/Frame and updateSeries/Frame member functions
lewardo Aug 23, 2023
af564f9
remove superfluous dataset members and convert names
lewardo Aug 23, 2023
db247f6
remove kFrameLen paramter
lewardo Aug 23, 2023
2774b29
add const qualifiers to input views
lewardo Aug 23, 2023
d875bd4
add `DataSeries` to libmanipulation
lewardo Aug 23, 2023
9e116fa
fix constness point setting issues
lewardo Aug 23, 2023
66897f9
xvalue views for casting issues
lewardo Aug 23, 2023
534fd36
printing shenanigains
lewardo Aug 23, 2023
9536952
series const getter
lewardo Aug 23, 2023
4576eb9
fix json (de)serialisation
lewardo Aug 23, 2023
2145f25
deft solution for tensor vector to tensorview vector casting
lewardo Aug 23, 2023
6e93be4
acc incrementing the right iterator might help
lewardo Aug 23, 2023
14a9d8e
const casting shenanigains
lewardo Aug 23, 2023
eab2162
change buffer frame return paradigm
lewardo Aug 23, 2023
4459cf9
rename messages
lewardo Aug 23, 2023
adb599d
fix assert crash by casting new points properly
lewardo Aug 23, 2023
440bfa6
printing now neat and tidyy
lewardo Aug 23, 2023
5503969
fixed read/write operations
lewardo Aug 23, 2023
5a34df7
replace asserts with errors
lewardo Aug 23, 2023
327ed07
consistent whitespace formatting
lewardo Aug 23, 2023
08ec779
deleteSeries message
lewardo Aug 23, 2023
9e41fb0
view converting methods for buffer `float` shenanigains
lewardo Aug 24, 2023
409d3cd
fix printing issues
lewardo Aug 24, 2023
99015b8
remove superflous reference creation
lewardo Aug 24, 2023
06692a0
convert to template view function calls
lewardo Aug 24, 2023
a6a23ef
addSeries message
lewardo Aug 24, 2023
77251f5
register new series-level messages
lewardo Aug 24, 2023
2845d37
regroup member functions
lewardo Aug 24, 2023
784fff6
`getSeries` message
lewardo Aug 24, 2023
2c46d51
`setSeries` message
lewardo Aug 24, 2023
6128d91
`updateSeries` message
lewardo Aug 24, 2023
54d48c6
`getDataSet` message
lewardo Aug 24, 2023
3ab06c2
actually pointing to the right member might be helpful
lewardo Aug 24, 2023
7d15aaf
fix `deleteframe` case with single frame in series
lewardo Aug 24, 2023
ce0efe8
add `toBuffer` and `fromBuffer` message aliases
lewardo Aug 28, 2023
892b3a6
use custom allocator on std containers
lewardo Aug 30, 2023
1e82ed9
formatting
lewardo Aug 30, 2023
dde8232
convert `dataseries` template operations for consistency
lewardo Aug 30, 2023
f1ede98
fix templature correction
lewardo Aug 31, 2023
220aac9
squash merge `data-series` into `recurrent-networks`
lewardo Sep 4, 2023
840646a
lstm file structure and somewhat boilerplate
lewardo Sep 1, 2023
84ebecc
eigen maps for matrix maths
lewardo Sep 1, 2023
63569ab
namespace shenanigains
lewardo Sep 1, 2023
46ec185
move initialisation to constructor because objects are annoying
lewardo Sep 1, 2023
14bb3ee
acc matrix maths with the maps
lewardo Sep 1, 2023
c633382
random initialisation
lewardo Sep 1, 2023
045b77a
rethink structure of recurrent networks
lewardo Sep 1, 2023
116ed65
rename for consistency
lewardo Sep 1, 2023
309a0bd
lstmstate object
lewardo Sep 1, 2023
a0bc915
begin lstmcell object
lewardo Sep 1, 2023
cc7b6a3
both array and matrix views for fun
lewardo Sep 1, 2023
01779e4
process frame
lewardo Sep 1, 2023
e0dc072
finish renaming with various views and maps
lewardo Sep 1, 2023
0e952d8
derivative matrix maps
lewardo Sep 1, 2023
4644bda
remove mDXH from state
lewardo Sep 1, 2023
6b1942c
backwards frame
lewardo Sep 1, 2023
00504ff
construct state using params
lewardo Sep 1, 2023
e036a09
rename {lstmcell => lstm}.hpp ('twas false advertising)
lewardo Sep 1, 2023
785b055
replace reference with shared/weak ptr locking
lewardo Sep 1, 2023
b8625a5
plural names less confusing
lewardo Sep 1, 2023
007c4da
Recur object members
lewardo Sep 1, 2023
3aa76ee
infer parameter class from Cell class
lewardo Sep 1, 2023
961a798
DataClient boilerplate
lewardo Sep 1, 2023
2c35ca5
lstm json conversions
lewardo Sep 1, 2023
44a09cc
paramtype and hide shared_ptr constructor
lewardo Sep 1, 2023
46c1066
process function
lewardo Sep 1, 2023
868da89
initial algo tests
lewardo Sep 3, 2023
8a57b8a
getstate getter
lewardo Sep 3, 2023
41b8f25
typename specifier and correct member names
lewardo Sep 4, 2023
59e1b3d
acc randomise all the weights and biases
lewardo Sep 4, 2023
c6797ac
actually writing to the output view might be smart
lewardo Sep 4, 2023
b837f24
add backward frame output derivatives
lewardo Sep 4, 2023
2d02060
matrix and vector parameter view containers
lewardo Sep 4, 2023
fd984fa
make inheritant parameters rather than wrapper
lewardo Sep 4, 2023
480fba0
convert cryptic names to eigen-esque `.matrix()` and `.array()`
lewardo Sep 4, 2023
9cbddd7
add size members to state
lewardo Sep 4, 2023
c1bfc4d
typedefing for consistentcy
lewardo Sep 4, 2023
065aedc
actually setting the previous state helps with derivatives
lewardo Sep 4, 2023
05f3d0b
methods accept other states rather than raw vectors
lewardo Sep 4, 2023
9dadf65
fix ambiguous constructor initialiser
lewardo Sep 4, 2023
e8cb5dc
remove superfluous typedefs
lewardo Sep 4, 2023
2c5cc75
added cell typedefs
lewardo Sep 4, 2023
39d6b2f
fit
lewardo Sep 4, 2023
49082dd
process
lewardo Sep 4, 2023
203ac59
cleaner iteration
lewardo Sep 4, 2023
f0347f0
elementary, my dear watson, you's not used the correct damn equation
lewardo Sep 4, 2023
b847494
consistent printing terminology
lewardo Sep 5, 2023
c372c22
check if id exists before resizing buffer
lewardo Sep 5, 2023
2385095
squash merge `data-series` into `lstm-rnn`
lewardo Sep 5, 2023
4a6b1d4
fix stale pointer deallocation issue
lewardo Sep 5, 2023
68c2c68
best practice housekeeping
lewardo Sep 5, 2023
ab23cf7
remove all this `MatrixParam` nonesense, `asEigen` works just fine
lewardo Sep 6, 2023
e7ae9c5
stateful processing
lewardo Sep 6, 2023
538638a
weight normalisation
lewardo Sep 6, 2023
d9d466b
rename for consistency
lewardo Sep 6, 2023
1296c9b
single frame processing for rt applications
lewardo Sep 6, 2023
27cfa81
defer process to processframe
lewardo Sep 6, 2023
ac9b913
rename client file
lewardo Sep 6, 2023
279c0c3
update recur member order
lewardo Sep 6, 2023
a3859db
sgd for N to M and N to 1
lewardo Sep 6, 2023
137d981
RecurSGD helper
lewardo Sep 6, 2023
6c5ced5
also reassign pointer, dont just reset
lewardo Sep 6, 2023
94635cc
helper input output processFrame
lewardo Sep 6, 2023
c5e3f59
initial classifier structure
lewardo Sep 6, 2023
3365ad5
_deep_ lstm: two layers for more control on regression ability
lewardo Sep 7, 2023
1bf2ca0
recur copy constructor
lewardo Sep 7, 2023
e8f1951
better practice for resetting ptrs
lewardo Sep 7, 2023
e908d37
deep recurrence saving
lewardo Sep 7, 2023
cdb4c57
make prediction cells members
lewardo Sep 7, 2023
423cdf0
recursgd asserts
lewardo Sep 7, 2023
8ae448a
correct data helps amrite
lewardo Sep 7, 2023
bab5b14
hardcode edge cases rather than allocating extra memory
lewardo Sep 7, 2023
9525e84
classifierdata object
lewardo Sep 7, 2023
6269e7c
classifier parameters
lewardo Sep 7, 2023
9f743bb
clear, reset
lewardo Sep 7, 2023
20dd738
`fit`
lewardo Sep 7, 2023
c2417d5
`predictpoint`
lewardo Sep 7, 2023
6cd4fd9
`predict`
lewardo Sep 7, 2023
5a2a22a
add messages to object
lewardo Sep 7, 2023
3bf0714
regressor boilerplate (basically classifier)
lewardo Sep 7, 2023
663bd3f
resetting issues for training (need to reset state between `fit`s)
lewardo Sep 7, 2023
e0ceb9e
add classifier and regressor to libmanipulation
lewardo Sep 7, 2023
cb22d0a
move constructor and copy assignment op
lewardo Sep 7, 2023
0199cce
predictpoint into buffer rather than as message
lewardo Sep 8, 2023
9159090
lstmforecast boilerplate
lewardo Sep 8, 2023
2fc0a19
actually use mTrained
lewardo Sep 8, 2023
a0ca860
`std::unique_ptr` already has move semantics brev
lewardo Sep 8, 2023
2d39806
add state-setting process
lewardo Sep 8, 2023
80a29df
saner parameter defaults
lewardo Sep 8, 2023
4895443
proper buffer checks
lewardo Sep 8, 2023
fe445be
forecase parameters
lewardo Sep 8, 2023
58e7612
forecast predictpoint
lewardo Sep 8, 2023
09bc882
forecast predict
lewardo Sep 8, 2023
9122aae
code housekeeping
lewardo Sep 8, 2023
6173b9c
remove `toBuffer` and `fromBuffer` aliases to unconfuse
lewardo Sep 9, 2023
c800a96
added proper dataseries error messages
lewardo Sep 9, 2023
69265ac
negative frame indexing
lewardo Sep 10, 2023
035638a
`getdataset` argument order
lewardo Sep 10, 2023
5972b50
remove `toBuffer` and `fromBuffer` aliases to unconfuse
lewardo Sep 9, 2023
0358053
added proper dataseries error messages
lewardo Sep 9, 2023
33132be
negative frame indexing
lewardo Sep 10, 2023
3bbdd71
`getdataset` argument order
lewardo Sep 10, 2023
7ad4195
add `LSTMForecast` to libmanipulation
lewardo Sep 11, 2023
96b50f4
fix reading/writing recur objects
lewardo Sep 11, 2023
0401b52
typedefs and correct std find algorithm
lewardo Sep 11, 2023
7ab7bdc
its a temporary anyway lets make it clearer
lewardo Sep 11, 2023
9223097
pre-`write`/`dump` checks
lewardo Sep 11, 2023
d87a3bd
you never actually declared that typedef silly
lewardo Sep 11, 2023
09e3b7b
copy/paste only works if all three clients are identical
lewardo Sep 11, 2023
ac22784
iTs cAsE SeNsiTiVE
lewardo Sep 11, 2023
0cf7377
remove bptt.hpp, not acc used in the end
lewardo Sep 11, 2023
30e6388
Merge branch 'main' into lstm-rnn
tremblap Sep 25, 2023
d92ee9e
typo in the client name in the cmake
tremblap Sep 25, 2023
33924d6
update to tip of DataSeriesClient
tremblap Sep 25, 2023
9aeec81
fix temporary xvalue binding issues
lewardo Sep 26, 2023
955969e
predictPoint is predictSeries
tremblap Sep 27, 2023
a6f6e5b
rename `LSTMForecast` to `LSTMForecaster` for consistency
lewardo Sep 29, 2023
52f099a
update recur members
lewardo Oct 1, 2023
e4ffd05
fix inteface semantics
lewardo Oct 1, 2023
4f61647
constructors and assignment
lewardo Oct 1, 2023
81c0232
make all the fitting deep
lewardo Oct 1, 2023
6586138
deep process
lewardo Oct 1, 2023
53a3792
add interface parameters
lewardo Oct 1, 2023
15acbd7
update for deep interface
lewardo Oct 1, 2023
ba4baab
deep lstm saving and loading
lewardo Oct 1, 2023
adb2156
xvalues apparently cant bind to lvalue references
lewardo Oct 2, 2023
f93a40e
the correct identifier usually helps brev
lewardo Oct 2, 2023
579f6c9
added validation to training
lewardo Oct 3, 2023
dca1cfe
momentum implementation and formatting
lewardo Oct 3, 2023
017d2f7
bug get size
tremblap Dec 23, 2023
25dab1f
bug fix forecaster - must enter end of context in output
tremblap Dec 23, 2023
890985f
temporary fix to reload error - json is alphabetical - TODO: potentia…
tremblap Jan 8, 2024
c4bef86
Merge branch 'data-series' into lstm-rnn
tremblap Jan 8, 2024
86a43e8
added back the commented stuff
Jan 23, 2024
a8ef19b
Merge branch 'main' into lstm-rnn
tremblap Feb 29, 2024
09eaf7e
printseries fix print 2nd boundary counter
tremblap Feb 29, 2024
53e3f29
Merge branch 'main' into lstm-rnn
tremblap Mar 7, 2024
b0c466c
remove the T in dataset printing and json dumping
tremblap Mar 7, 2024
15180b1
correct padding per series instead
tremblap Mar 7, 2024
34 changes: 19 additions & 15 deletions FlucomaClients.cmake
@@ -142,19 +142,23 @@ add_client(Transients clients/rt/TransientClient.hpp CLASS RTTransientClient )

#lib manipulation client group
add_client(DataSet clients/nrt/DataSetClient.hpp CLASS NRTThreadedDataSetClient GROUP MANIPULATION)
add_client(DataSetQuery clients/nrt/DataSetQueryClient.hpp CLASS NRTThreadedDataSetQueryClient GROUP MANIPULATION)
add_client(DataSeries clients/nrt/DataSeriesClient.hpp CLASS NRTThreadedDataSeriesClient GROUP MANIPULATION)
# add_client(DataSetQuery clients/nrt/DataSetQueryClient.hpp CLASS NRTThreadedDataSetQueryClient GROUP MANIPULATION)
add_client(LabelSet clients/nrt/LabelSetClient.hpp CLASS NRTThreadedLabelSetClient GROUP MANIPULATION)
add_client(KDTree clients/nrt/KDTreeClient.hpp CLASS NRTThreadedKDTreeClient GROUP MANIPULATION)
add_client(KMeans clients/nrt/KMeansClient.hpp CLASS NRTThreadedKMeansClient GROUP MANIPULATION)
add_client(SKMeans clients/nrt/SKMeansClient.hpp CLASS NRTThreadedSKMeansClient GROUP MANIPULATION)
add_client(KNNClassifier clients/nrt/KNNClassifierClient.hpp CLASS NRTThreadedKNNClassifierClient GROUP MANIPULATION)
add_client(KNNRegressor clients/nrt/KNNRegressorClient.hpp CLASS NRTThreadedKNNRegressorClient GROUP MANIPULATION)
add_client(Normalize clients/nrt/NormalizeClient.hpp CLASS NRTThreadedNormalizeClient GROUP MANIPULATION)
add_client(RobustScale clients/nrt/RobustScaleClient.hpp CLASS NRTThreadedRobustScaleClient GROUP MANIPULATION)
add_client(Standardize clients/nrt/StandardizeClient.hpp CLASS NRTThreadedStandardizeClient GROUP MANIPULATION)
add_client(PCA clients/nrt/PCAClient.hpp CLASS NRTThreadedPCAClient GROUP MANIPULATION)
add_client(MDS clients/nrt/MDSClient.hpp CLASS NRTThreadedMDSClient GROUP MANIPULATION)
add_client(UMAP clients/nrt/UMAPClient.hpp CLASS NRTThreadedUMAPClient GROUP MANIPULATION)
add_client(MLPRegressor clients/nrt/MLPRegressorClient.hpp CLASS NRTThreadedMLPRegressorClient GROUP MANIPULATION)
add_client(MLPClassifier clients/nrt/MLPClassifierClient.hpp CLASS NRTThreadedMLPClassifierClient GROUP MANIPULATION)
add_client(Grid clients/nrt/GridClient.hpp CLASS NRTThreadedGridClient GROUP MANIPULATION)
# add_client(KDTree clients/nrt/KDTreeClient.hpp CLASS NRTThreadedKDTreeClient GROUP MANIPULATION)
# add_client(KMeans clients/nrt/KMeansClient.hpp CLASS NRTThreadedKMeansClient GROUP MANIPULATION)
# add_client(SKMeans clients/nrt/SKMeansClient.hpp CLASS NRTThreadedSKMeansClient GROUP MANIPULATION)
# add_client(KNNClassifier clients/nrt/KNNClassifierClient.hpp CLASS NRTThreadedKNNClassifierClient GROUP MANIPULATION)
# add_client(KNNRegressor clients/nrt/KNNRegressorClient.hpp CLASS NRTThreadedKNNRegressorClient GROUP MANIPULATION)
# add_client(Normalize clients/nrt/NormalizeClient.hpp CLASS NRTThreadedNormalizeClient GROUP MANIPULATION)
# add_client(RobustScale clients/nrt/RobustScaleClient.hpp CLASS NRTThreadedRobustScaleClient GROUP MANIPULATION)
# add_client(Standardize clients/nrt/StandardizeClient.hpp CLASS NRTThreadedStandardizeClient GROUP MANIPULATION)
# add_client(PCA clients/nrt/PCAClient.hpp CLASS NRTThreadedPCAClient GROUP MANIPULATION)
# add_client(MDS clients/nrt/MDSClient.hpp CLASS NRTThreadedMDSClient GROUP MANIPULATION)
# add_client(UMAP clients/nrt/UMAPClient.hpp CLASS NRTThreadedUMAPClient GROUP MANIPULATION)
# add_client(MLPRegressor clients/nrt/MLPRegressorClient.hpp CLASS NRTThreadedMLPRegressorClient GROUP MANIPULATION)
# add_client(MLPClassifier clients/nrt/MLPClassifierClient.hpp CLASS NRTThreadedMLPClassifierClient GROUP MANIPULATION)
# add_client(Grid clients/nrt/GridClient.hpp CLASS NRTThreadedGridClient GROUP MANIPULATION)
add_client(LSTMClassifier clients/nrt/LSTMClassifierClient.hpp CLASS NRTThreadedLSTMClassifierClient GROUP MANIPULATION)
add_client(LSTMRegressor clients/nrt/LSTMRegressorClient.hpp CLASS NRTThreadedLSTMRegressorClient GROUP MANIPULATION)
add_client(LSTMForecaster clients/nrt/LSTMForecasterClient.hpp CLASS NRTThreadedLSTMForecasterClient GROUP MANIPULATION)
238 changes: 238 additions & 0 deletions include/algorithms/public/DTW.hpp
@@ -0,0 +1,238 @@
/*
Part of the Fluid Corpus Manipulation Project (http://www.flucoma.org/)
Copyright University of Huddersfield.
Licensed under the BSD-3 License.
See license.md file in the project root for full license information.
This project has received funding from the European Research Council (ERC)
under the European Union’s Horizon 2020 research and innovation programme
(grant agreement No 725899).
*/

#pragma once

#include "../util/FluidEigenMappings.hpp"
#include "../../data/FluidDataSet.hpp"
#include "../../data/FluidIndex.hpp"
#include "../../data/FluidMemory.hpp"
#include "../../data/FluidTensor.hpp"
#include "../../data/TensorTypes.hpp"
#include <Eigen/Core>
#include <cstddef>
#include <iterator>
#include <random>

namespace fluid {
namespace algorithm {


enum class DTWConstraint { kUnconstrained, kIkatura, kSakoeChiba };

// debt of gratitude to the wonderful article at
// https://rtavenar.github.io/blog/dtw.html, a better explanation of DTW than
// any other I've seen
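//
// the class computes the standard DTW recurrence over the pairwise frame
// distances d(r, c) = ||x1.row(r) - x2.row(c)||_p^p :
//
//   D(r, c) = d(r, c) + min(D(r - 1, c), D(r, c - 1), D(r - 1, c - 1))
//
// and returns D(rows - 1, cols - 1)^(1/p), optionally restricting (r, c) to an
// Itakura (kIkatura) or Sakoe-Chiba band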

class DTW
{
struct Constraint;

public:
explicit DTW() = default;
~DTW() = default;

void init() {}
void clear() {}

index size() const { return mPNorm; }
constexpr index dims() const { return 0; }
constexpr index initialized() const { return true; }

double process(InputRealMatrixView x1, InputRealMatrixView x2,
DTWConstraint constr = DTWConstraint::kUnconstrained,
index param = 2,
Allocator& alloc = FluidDefaultAllocator()) const
{
ScopedEigenMap<Eigen::VectorXd> x1r(x1.cols(), alloc),
x2r(x2.cols(), alloc);
Constraint constraint(constr, x1.rows(), x2.rows(), param);

mDistanceMetrics.resize(x1.rows(), x2.rows());
mDistanceMetrics.fill(std::numeric_limits<double>::max());

constraint.iterate([&, this](index r, index c) {
x1r = _impl::asEigen<Eigen::Matrix>(x1.row(r));
x2r = _impl::asEigen<Eigen::Matrix>(x2.row(c));

mDistanceMetrics(r, c) = differencePNormToTheP(x1r, x2r);

if (r > 0 || c > 0)
{
double minimum = std::numeric_limits<double>::max();

if (r > 0) minimum = std::min(minimum, mDistanceMetrics(r - 1, c));
if (c > 0) minimum = std::min(minimum, mDistanceMetrics(r, c - 1));
if (r > 0 && c > 0)
minimum = std::min(minimum, mDistanceMetrics(r - 1, c - 1));

mDistanceMetrics(r, c) += minimum;
}
});

return std::pow(mDistanceMetrics(x1.rows() - 1, x2.rows() - 1),
1.0 / mPNorm);
}

private:
mutable RealMatrix mDistanceMetrics;
index mPNorm{2};

// P-Norm of the difference vector, raised to the power p
// Lp{vec} = (|vec[0]|^p + |vec[1]|^p + ... + |vec[n-1]|^p + |vec[n]|^p)^(1/p)
// i.e., the 2-norm of a vector is the Euclidean distance from the origin,
// the 1-norm is the sum of the absolute values of the elements
// We keep the value to the power p because these terms are summed along the
// warping path and the 1/p root is only taken once, at the very end
// (norms of norms collapse into a single norm, normception)
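// e.g. with p = 2, v1 = (1, 2) and v2 = (4, 6):
//   |1 - 4|^2 + |2 - 6|^2 = 9 + 16 = 25, the squared Euclidean distance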
inline double
differencePNormToTheP(const Eigen::Ref<const Eigen::VectorXd>& v1,
const Eigen::Ref<const Eigen::VectorXd>& v2) const
{
// assert(v1.size() == v2.size());
return (v1.array() - v2.array()).abs().pow(mPNorm).sum();
}

// fun little fold operation to do a variadic minimum
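// (a comma fold initialises m to the last argument, then m is min'd against
// each argument in turn; the expression's value is the final m)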
template <typename... Args>
inline static auto min(Args&&... args)
{
auto m = (args, ...);
return ((m = std::min(m, args)), ...);
}

// filter for minimum chaining, if cond evaluates to false then the value
// isn't used (never will be the minimum if its the numeric maximum)
template <typename T>
inline static T useIf(bool cond, T val)
{
return cond ? val : std::numeric_limits<T>::max();
}

struct Constraint
{
Constraint(DTWConstraint c, index rows, index cols, float param)
: mType{c}, mRows{rows}, mCols{cols}, mParam{param}
{
// if the gradient is no steeper than the diagonal's slope, constrain to the
// diagonal itself (Sakoe-Chiba with radius 0)
if (c == DTWConstraint::kIkatura)
{
float big = std::max(mRows, mCols), smol = std::min(mRows, mCols);

if (mParam <= big / smol)
{
mType = DTWConstraint::kSakoeChiba;
mParam = 0;
}
}
};

void iterate(std::function<void(index, index)> f)
{
index first, last;

for (index r = 0; r < mRows; ++r)
{
first = firstCol(r);
last = lastCol(r);

for (index c = first; c <= last; ++c) f(r, c);
}
};

private:
DTWConstraint mType;
index mRows, mCols;
float mParam; // mParam is either radius (SC) or gradient (Ik)
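
// firstCol/lastCol return, for a given row, the inclusive range of columns
// allowed by the active constraint; the rasterLine helpers rasterise the
// band's boundary lines (Sakoe-Chiba: two lines parallel to the diagonal,
// offset by the radius; Itakura: a parallelogram bounded by lines of slope
// mParam and 1/mParam through the two corner points)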

inline static index rasterLineMinY(index x1, index y1, float dydx, index x)
{
return std::round(y1 + (x - x1) * dydx);
}

inline static index rasterLineMinY(index x1, index y1, index x2, index y2,
index x)
{
float dy = y2 - y1, dx = x2 - x1;
return rasterLineMinY(x1, y1, dy / dx, x);
}

inline static index rasterLineMaxY(index x1, index y1, float dydx, index x)
{
if (dydx > 1)
return rasterLineMinY(x1, y1, dydx, x + 1) - 1;
else
return rasterLineMinY(x1, y1, dydx, x);
}

inline static index rasterLineMaxY(index x1, index y1, index x2, index y2,
index x)
{
float dy = y2 - y1, dx = x2 - x1;
return rasterLineMaxY(x1, y1, dy / dx, x);
}

index firstCol(index row)
{
switch (mType)
{
case DTWConstraint::kUnconstrained: return 0;

case DTWConstraint::kIkatura: {
index colNorm = rasterLineMinY(mRows - 1, mCols - 1, mParam, row);
index colInv = rasterLineMinY(0, 0, 1 / mParam, row);

index col = std::max(colNorm, colInv);

return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}

case DTWConstraint::kSakoeChiba: {
index col = rasterLineMinY(mParam, -mParam, mRows - 1 + mParam,
mCols - 1 - mParam, row);

return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}
}

return 0;
};

index lastCol(index row)
{
switch (mType)
{
case DTWConstraint::kUnconstrained: return mCols - 1;

case DTWConstraint::kIkatura: {
index colNorm = rasterLineMaxY(0, 0, mParam, row);
index colInv = rasterLineMaxY(mRows - 1, mCols - 1, 1 / mParam, row);

index col = std::min(colNorm, colInv);

return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}

case DTWConstraint::kSakoeChiba: {
index col = rasterLineMaxY(-mParam, mParam, mRows - 1 - mParam,
mCols - 1 + mParam, row);

return col < 0 ? 0 : col > mCols - 1 ? mCols - 1 : col;
}
}

return mCols - 1;
};
}; // struct Constraint
};

} // namespace algorithm
} // namespace fluid
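
For orientation, a minimal standalone sketch of the recurrence that DTW::process implements above, using plain std::vector<double> series (scalar frames, unconstrained, p = 2) instead of the FluCoMa matrix views and allocators:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// unconstrained DTW between two scalar series, p = 2 (illustrative sketch)
inline double dtwDistance(const std::vector<double>& a,
                          const std::vector<double>& b)
{
  std::vector<std::vector<double>> D(
      a.size(),
      std::vector<double>(b.size(), std::numeric_limits<double>::max()));

  for (std::size_t r = 0; r < a.size(); ++r)
    for (std::size_t c = 0; c < b.size(); ++c)
    {
      double d = std::pow(std::abs(a[r] - b[c]), 2.0); // |a_r - b_c|^p

      if (r == 0 && c == 0)
      {
        D[r][c] = d;
        continue;
      }

      double minimum = std::numeric_limits<double>::max();
      if (r > 0) minimum = std::min(minimum, D[r - 1][c]);
      if (c > 0) minimum = std::min(minimum, D[r][c - 1]);
      if (r > 0 && c > 0) minimum = std::min(minimum, D[r - 1][c - 1]);

      D[r][c] = d + minimum; // cheapest warping path ending at (r, c)
    }

  return std::sqrt(D[a.size() - 1][b.size() - 1]); // the 1/p root, p = 2
}

// e.g. dtwDistance({0.0, 1.0, 2.0, 3.0}, {0.0, 0.0, 1.0, 2.0, 2.0, 3.0})
// returns 0: the second series is a time-stretched copy of the first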