From 985077f6480ee3f9e4cf976f240360c9396c0470 Mon Sep 17 00:00:00 2001 From: Monica Dessole <36501030+mdessole@users.noreply.github.com> Date: Fri, 24 May 2024 11:38:08 +0200 Subject: [PATCH] [matrix] Fast element setter method (#15606) * [matrix] Add SetElement method * [matrix] Add SetElement test --- math/matrix/inc/TMatrixT.h | 18 ++++++++++++++++++ math/matrix/inc/TMatrixTSym.h | 19 +++++++++++++++++++ math/matrix/test/testMatrixT.cxx | 16 ++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/math/matrix/inc/TMatrixT.h b/math/matrix/inc/TMatrixT.h index a99449c9efc41..b174cf50cbd75 100644 --- a/math/matrix/inc/TMatrixT.h +++ b/math/matrix/inc/TMatrixT.h @@ -31,6 +31,7 @@ #include "Rtypes.h" #include "TError.h" +#include template class TMatrixTSym; template class TMatrixTSparse; @@ -107,6 +108,8 @@ template class TMatrixT : public TMatrixTBase { void MultT(const TMatrixTSym &a,const TMatrixT &b); void MultT(const TMatrixTSym &a,const TMatrixTSym &b) { Mult(a,b); } + inline void SetElement(Int_t rown, Int_t coln, Element val); + const Element *GetMatrixArray () const override; Element *GetMatrixArray () override; const Int_t *GetRowIndexArray() const override { return nullptr; } @@ -283,6 +286,21 @@ template inline Element &TMatrixT::operator()(Int_t row return (fElements[arown*this->fNcols+acoln]); } +//////////////////////////////////////////////////////////////////////////////// +/// Efficiently sets element (rown,coln) equal to val +/// Index bound checks can be deactivated by defining NDEBUG + +template +inline void TMatrixT::SetElement(Int_t rown, Int_t coln, Element val) +{ + assert(this->IsValid()); + rown = rown - this->fRowLwb; + coln = coln - this->fColLwb; + assert((rown < this->fNrows && rown >= 0) && "SetElement() error: row index outside matrix range"); + assert((coln < this->fNcols && coln >= 0) && "SetElement() error: column index outside matrix range"); + fElements[rown * this->fNcols + coln] = val; +} + inline namespace TMatrixTAutoloadOps { template TMatrixT operator+ (const TMatrixT &source1,const TMatrixT &source2); diff --git a/math/matrix/inc/TMatrixTSym.h b/math/matrix/inc/TMatrixTSym.h index 7754b0d71e935..b22984f66a975 100644 --- a/math/matrix/inc/TMatrixTSym.h +++ b/math/matrix/inc/TMatrixTSym.h @@ -27,6 +27,8 @@ #include "TMatrixTBase.h" #include "TMatrixTUtils.h" +#include + templateclass TMatrixT; templateclass TMatrixTSymLazy; templateclass TVectorT; @@ -79,6 +81,8 @@ template class TMatrixTSym : public TMatrixTBase { void Plus (const TMatrixTSym &a,const TMatrixTSym &b); void Minus(const TMatrixTSym &a,const TMatrixTSym &b); + inline void SetElement(Int_t rown, Int_t coln, Element val); + const Element *GetMatrixArray () const override; Element *GetMatrixArray () override; const Int_t *GetRowIndexArray() const override { return nullptr; } @@ -235,6 +239,21 @@ template inline Element &TMatrixTSym::operator()(Int_t return (fElements[arown*this->fNcols+acoln]); } +//////////////////////////////////////////////////////////////////////////////// +/// Efficiently sets element (rown,coln) equal to val +/// Index bound checks can be deactivated by defining NDEBUG + +template +inline void TMatrixTSym::SetElement(Int_t rown, Int_t coln, Element val) +{ + assert(this->IsValid()); + rown = rown - this->fRowLwb; + coln = coln - this->fColLwb; + assert((rown < this->fNrows && rown >= 0) && "SetElement() error: row index outside matrix range"); + assert((coln < this->fNcols && coln >= 0) && "SetElement() error: column index outside matrix range"); + fElements[rown * this->fNcols + coln] = val; +} + template Bool_t operator== (const TMatrixTSym &source1,const TMatrixTSym &source2); template TMatrixTSym operator+ (const TMatrixTSym &source1,const TMatrixTSym &source2); template TMatrixTSym operator+ (const TMatrixTSym &source1, Element val); diff --git a/math/matrix/test/testMatrixT.cxx b/math/matrix/test/testMatrixT.cxx index b61089347d933..1b871cae31629 100644 --- a/math/matrix/test/testMatrixT.cxx +++ b/math/matrix/test/testMatrixT.cxx @@ -160,6 +160,22 @@ TYPED_TEST(testMatrix, Invert) CompareTMatrix(c, TestFixture::eye); } +TYPED_TEST(testMatrix, SetElement) +{ + TypeParam b(n, n); + TypeParam c(n, n); + + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + b(i, j) = n * i + j + 1 * (i == j); + + for (int i = 0; i < n; i++) + for (int j = 0; j < n; j++) + c.SetElement(i, j, n * i + j + 1 * (i == j)); + + CompareTMatrix(b, c); +} + class testMatrixD : public testing::Test { protected: void SetUp() override