diff --git a/include/m_algorithms.h b/include/m_algorithms.h index b01ff5a..e237a8b 100644 --- a/include/m_algorithms.h +++ b/include/m_algorithms.h @@ -23,42 +23,23 @@ namespace Matrix { }; - class BaseInterface { - - public: - virtual ~BaseInterface() = default; - virtual std::unique_ptr operator()( - const std::unique_ptr& l, - const std::unique_ptr& r = nullptr) = 0; - virtual Code get_operation_code() = 0; - - }; - + namespace Unary { template - class UnaryAdapter : public BaseInterface { + class UnaryAdapter { public: std::unique_ptr operator()( - const std::unique_ptr& l, - const std::unique_ptr& r = nullptr) override { - + const std::unique_ptr& l) { if (!l) { throw std::invalid_argument("Left operand not referencing a matrix."); } - - - if (r != nullptr) { - throw std::invalid_argument("Unary Operation needs one operand."); - } return Impl().operate(l); }; ~UnaryAdapter() = default; - Code get_operation_code() override { return Impl().get_code(); }; - private: Implementation& Impl() { return *static_cast(this); } friend Implementation; @@ -74,8 +55,6 @@ namespace Matrix { public: std::unique_ptr operate( const std::unique_ptr& m); - Code get_code() { return Code::ReLU; }; - }; static_assert(MatrixOperatable); @@ -87,13 +66,13 @@ namespace Matrix { template - class BaseOp : public BaseInterface { + class BaseOp { public: BaseOp() = default; virtual std::unique_ptr operator()( const std::unique_ptr& l, - const std::unique_ptr& r) override { + const std::unique_ptr& r) { if (!l) { throw std::invalid_argument("Left operand not referencing a matrix."); @@ -104,9 +83,7 @@ namespace Matrix { return Impl().operate(l, r); }; - virtual ~BaseOp() = default; - Code get_operation_code() override { return Impl().get_code(); }; - private: + virtual ~BaseOp() = default; private: Implementation& Impl() { return *static_cast(this); } friend Implementation; @@ -125,8 +102,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::PLUS; }; - }; } @@ -143,8 +118,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::OUTER_PRODUCT; }; - }; @@ -166,8 +139,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::HADAMARD; }; - }; @@ -177,8 +148,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::HADAMARD; }; - }; @@ -208,8 +177,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::MULTIPLY; }; - }; @@ -220,8 +187,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r) ; - Code get_code() { return Code::MULTIPLY; }; - }; @@ -231,8 +196,6 @@ namespace Matrix { std::unique_ptr operate( const std::unique_ptr& l, const std::unique_ptr& r); - Code get_code() { return Code::MULTIPLY; }; - }; diff --git a/include/matrix_benchmark.h b/include/matrix_benchmark.h index 2073c73..4583570 100644 --- a/include/matrix_benchmark.h +++ b/include/matrix_benchmark.h @@ -13,29 +13,6 @@ namespace Matrix { namespace Operations { - - template - class Benchmark { - - public: - Benchmark(std::unique_ptr _m) : matrix_operation(std::move(_m)) {} - Code get_operation_code() { return matrix_operation->get_operation_code(); }; - - protected: - Implementation* Impl() { return static_cast(this);} - std::unique_ptr operator()( - const std::unique_ptr& l, - const std::unique_ptr& r = nullptr) { - return Impl()->operator()(l, r); - }; - ~Benchmark() = default; - std::unique_ptr matrix_operation; - - - }; - - - /* DESCRIPTION: @@ -58,43 +35,29 @@ namespace Matrix { std::cout << "Performed in " << mul_bm_r.get_computation_duration_ms() << " ms." << std::endl; */ - class Timer : public Benchmark { + template + class Timer { public: - // Timer() : Benchmark(std::move( - // std::make_unique())) {} - Timer(std::unique_ptr _m) : - Benchmark(std::move(_m)) {} + Timer(Operator _m) : + matrix_operation(_m) {} std::unique_ptr operator()( const std::unique_ptr& l, const std::unique_ptr& r = nullptr); - int get_computation_duration_ms() { - return std::chrono::duration_cast>(end - start).count(); } - - std::chrono::steady_clock::time_point get_start() { return start; } - std::chrono::steady_clock::time_point get_end() { return end; } - + int get_computation_duration_ms() { + return std::chrono::duration_cast>(end - start).count(); } + std::chrono::steady_clock::time_point get_start() { return start; } + std::chrono::steady_clock::time_point get_end() { return end; } + private: + Operator matrix_operation; std::chrono::steady_clock::time_point start; std::chrono::steady_clock::time_point end; }; - - // #ifdef CILKSCALE - // class ParallelMeasurer : public Benchmark { - - // public: - // ParallelMeasurer(std::unique_ptr _m) : - // Benchmark(std::move(_m)) {} - // std::unique_ptr operator()( - // const std::unique_ptr& l, - // const std::unique_ptr& r); - - // }; - // #endif - + } diff --git a/matrix_benchmark.cpp b/matrix_benchmark.cpp index 30911f9..a378306 100644 --- a/matrix_benchmark.cpp +++ b/matrix_benchmark.cpp @@ -11,17 +11,37 @@ namespace Matrix { - std::unique_ptr Operations::Timer::operator()(const std::unique_ptr& l, + template + std::unique_ptr Operations::Timer::operator()(const std::unique_ptr& l, const std::unique_ptr& r) { - start = std::chrono::steady_clock::now(); - std::unique_ptr mc = this->matrix_operation->operator()(l, r); + std::unique_ptr mc; + + start = std::chrono::steady_clock::now(); + if constexpr (Matrix::Operations::UnaryMatrixOperatable) { + mc = this->matrix_operation(l); + } + else if constexpr (Matrix::Operations::BinaryMatrixOperatable) { + mc = this->matrix_operation(l, r); + } + end = std::chrono::steady_clock::now(); return mc; } + + template class Operations::Timer; + template class Operations::Timer; + template class Operations::Timer; + template class Operations::Timer; + template class Operations::Timer; + template class Operations::Timer; + template class Operations::Timer; + + + // #ifdef CILKSCALE /* Cilkscale's command-line output includes work and span measurements for the Cilk program in terms of empirically measured times.