From 80a50134982f982e48db183a857fffcad270cdf0 Mon Sep 17 00:00:00 2001 From: David OK Date: Tue, 5 Dec 2023 17:09:51 +0000 Subject: [PATCH] MAINT: use C++ concept for the Kalman filter. --- .../KalmanFilter/DistributionConcepts.hpp | 56 +++++++++++++++ .../Sara/KalmanFilter/EigenMatrixConcepts.hpp | 50 +++++++++++++ cpp/src/DO/Sara/KalmanFilter/KalmanFilter.hpp | 26 ------- .../Sara/KalmanFilter/ObservationEquation.hpp | 71 +++++++++++++++++++ .../KalmanFilter/StateTransitionModel.hpp | 25 +++++++ 5 files changed, 202 insertions(+), 26 deletions(-) create mode 100644 cpp/src/DO/Sara/KalmanFilter/DistributionConcepts.hpp create mode 100644 cpp/src/DO/Sara/KalmanFilter/EigenMatrixConcepts.hpp delete mode 100644 cpp/src/DO/Sara/KalmanFilter/KalmanFilter.hpp create mode 100644 cpp/src/DO/Sara/KalmanFilter/ObservationEquation.hpp create mode 100644 cpp/src/DO/Sara/KalmanFilter/StateTransitionModel.hpp diff --git a/cpp/src/DO/Sara/KalmanFilter/DistributionConcepts.hpp b/cpp/src/DO/Sara/KalmanFilter/DistributionConcepts.hpp new file mode 100644 index 000000000..43247c75d --- /dev/null +++ b/cpp/src/DO/Sara/KalmanFilter/DistributionConcepts.hpp @@ -0,0 +1,56 @@ +#pragma once + +#include + +#include +#include + + +namespace DO::Sara::KalmanFilter { + + template + concept GaussianDistribution = requires(T dist) + { + typename T::scalar_type; + typename T::mean_type; + typename T::covariance_matrix_type; + // clang-format off + // Constructor. + { T{typename T::mean_type{}, typename T::covariance_matrix_type{}} } -> std::same_as; + // Methods. + { dist.mean() } -> std::same_as; + { dist.covariance_matrix() } -> std::same_as; + // clang-format on + }; + + template + concept ZeroMeanGaussianDistribution = requires(T dist) + { + typename T::scalar_type; + typename T::mean_type; + typename T::covariance_matrix_type; + // clang-format off + { T{typename T::covariance_matrix_type{}} } -> std::same_as; + { dist.covariance_matrix() } -> std::same_as; + // clang-format on + }; + + template + concept StateDistribution = GaussianDistribution; + + template + concept NoiseDistribution = ZeroMeanGaussianDistribution; + + template + concept FixedSizeStateDistribution = // + GaussianDistribution && // + CompileTimeFixedMatrix && + CompileTimeFixedMatrix; + + template + concept FixedSizeNoiseDistribution = // + ZeroMeanGaussianDistribution && + CompileTimeFixedMatrix && + CompileTimeFixedMatrix; + +} // namespace DO::Sara::KalmanFilter diff --git a/cpp/src/DO/Sara/KalmanFilter/EigenMatrixConcepts.hpp b/cpp/src/DO/Sara/KalmanFilter/EigenMatrixConcepts.hpp new file mode 100644 index 000000000..6e8e9ab2f --- /dev/null +++ b/cpp/src/DO/Sara/KalmanFilter/EigenMatrixConcepts.hpp @@ -0,0 +1,50 @@ +#pragma once + +#include + +#include + + +namespace DO::Sara::KalmanFilter { + + template + concept EigenVector = requires(T) + { + typename T::Scalar; + // clang-format off + { T{}.rows() } -> std::same_as; + { T{}.cols() } -> std::same_as; + { T{}(int{}) } -> std::same_as; + // clang-format on + }; + + template + concept EigenMatrix = requires(T) + { + typename T::Scalar; + // clang-format off + { T{}.rows() } -> std::same_as; + { T{}.cols() } -> std::same_as; + { T{}(int{}, int{}) } -> std::same_as; + { T{}.transpose() }; + // clang-format on + }; + + template + concept EigenSquareMatrix = EigenMatrix && requires + { + T::Rows == T::Cols; + }; + + template + concept CompileTimeFixedMatrix = requires + { + typename T::scalar_type; + // clang-format off + { T::Rows } -> std::same_as; + { T::Cols } -> std::same_as; + { T{} } -> std::same_as>; + // clang-format on + }; + +} // namespace DO::Sara::KalmanFilter diff --git a/cpp/src/DO/Sara/KalmanFilter/KalmanFilter.hpp b/cpp/src/DO/Sara/KalmanFilter/KalmanFilter.hpp deleted file mode 100644 index c2a727862..000000000 --- a/cpp/src/DO/Sara/KalmanFilter/KalmanFilter.hpp +++ /dev/null @@ -1,26 +0,0 @@ -#pragma once - - -namespace DO::Sara::KalmanFilter { - - template - struct KalmanFilter - { - using State = typename StateTransitionEquation::State; - using Observation = typename StateTransitionEquation::Observation; - - auto predict(const State& x) -> State - { - return _state_transition_equation.predict(x); - } - - auto update(const State& x_predicted, const Observation& z) -> State - { - return _observation_equation.update(x_predicted, z); - } - - StateTransitionEquation _state_transition_equation; - ObservationEquation _observation_equation; - }; - -} // namespace DO::Sara::KalmanFilter diff --git a/cpp/src/DO/Sara/KalmanFilter/ObservationEquation.hpp b/cpp/src/DO/Sara/KalmanFilter/ObservationEquation.hpp new file mode 100644 index 000000000..bbd4cfd93 --- /dev/null +++ b/cpp/src/DO/Sara/KalmanFilter/ObservationEquation.hpp @@ -0,0 +1,71 @@ +#pragma once + +#include + + +namespace DO::Sara::KalmanFilter { + + template + struct ObservationEquation + { + using T = typename ObservationModelMatrix::Scalar; + using Innovation = Observation; + using KalmanGain = typename Observation::covariance_matrix_type; + + inline static auto observation_model_matrix() -> ObservationModelMatrix + { + auto H = ObservationModelMatrix{}; + + static const auto I = Eigen::Matrix4::Identity(); + static const auto O = Eigen::Matrix::Zero(); + H << I, O; + + return H; + } + + inline auto innovation(const State& x_a_priori, + const Observation& z) const // + -> Innovation + { + return { + z - H * x_a_priori, // + H * x_a_priori.covariance_matrix() * H.transpose() + + v.covariance_matrix() // + }; + } + + inline auto kalman_gain_matrix(const Observation& x_predicted, + const Innovation& S) const // + -> KalmanGain + { + return x_predicted.covariance_matrix() * H.transpose() * + S.covariance_matrix().inverse(); + } + + inline auto update(const State& x_predicted, const Observation& z) -> State + { + const auto y = innovation(x_predicted, z); + const auto K = kalman_gain_matrix(x_predicted, y); + + static const auto I = State::CovarianceMatrix::Identity(); + return { + x_predicted.mean() + K * y.mean(), // + (I - K * H) * x_predicted.covariance_matrix() // + }; + } + + inline auto residual(const Observation& z, + const State& x) const // + -> typename Observation::mean_type + { + return z.mean() - H * x.covariance_matrix(); + } + + const ObservationModelMatrix H; + ObservationNoise v; + }; + +} // namespace DO::Sara::KalmanFilter diff --git a/cpp/src/DO/Sara/KalmanFilter/StateTransitionModel.hpp b/cpp/src/DO/Sara/KalmanFilter/StateTransitionModel.hpp new file mode 100644 index 000000000..cbe44afac --- /dev/null +++ b/cpp/src/DO/Sara/KalmanFilter/StateTransitionModel.hpp @@ -0,0 +1,25 @@ +#pragma once + +#include + + +namespace DO::Sara::KalmanFilter { + + template + struct StateTransitionEquation + { + auto predict(const StateDistribution& x) -> StateDistribution + { + return { + F * x.mean(), // + F * x.covariance_matrix() * F.transpose() + w.covariance_matrix() // + }; + }; + + StateTransitionMatrix F; + ProcessNoiseDistribution w; + }; + +} // namespace DO::Sara::KalmanFilter