From a0f6b99a19a2aeeff11874738dcd817acf5fb5af Mon Sep 17 00:00:00 2001 From: Alejandro Armas Date: Tue, 17 May 2022 22:17:36 -0700 Subject: [PATCH] Added softmax dispatch helpers. --- include/m_algorithms_utilities.h | 40 +++++++++++++++++++++----------- 1 file changed, 26 insertions(+), 14 deletions(-) diff --git a/include/m_algorithms_utilities.h b/include/m_algorithms_utilities.h index 4f866cd..9b26aaf 100644 --- a/include/m_algorithms_utilities.h +++ b/include/m_algorithms_utilities.h @@ -41,50 +41,62 @@ namespace Matrix { */ struct Stringify { constexpr std::string_view operator()( - const Unary::ReLU& _) { + const Unary::ReLU&) { return "ReLU"; } constexpr std::string_view operator()( - const Binary::HadamardProduct::Std& _) { + const Unary::SoftMax&) { + return "Softmax"; } + constexpr std::string_view operator()( + const Binary::HadamardProduct::Std&) { return "HadamardProduct"; } constexpr std::string_view operator()( - const Binary::Multiplication::ParallelDNC& _) { + const Binary::Multiplication::ParallelDNC&) { return "MatrixMultiply"; } constexpr std::string_view operator()( - const Binary::Multiplication::Naive& _) { + const Binary::Multiplication::Naive&) { return "MatrixMultiply"; } constexpr std::string_view operator()( - const Binary::Multiplication::Square& _) { + const Binary::Multiplication::Square&) { return "MatrixMultiply"; } constexpr std::string_view operator()( - const Binary::Addition::Std& _) { + const Binary::Addition::Std&) { return "Addition"; } constexpr std::string_view operator()( - const Binary::OuterProduct::Naive& _) { + const Binary::OuterProduct::Naive&) { return "OuterProduct"; } + constexpr std::string_view operator()( + const Metric::CrossEntropy&) { + return "CrossEntropy"; } }; struct Codify { constexpr Code operator()( - const Unary::ReLU& _) + const Unary::ReLU&) { return Code::ReLU; } constexpr Code operator()( - const Binary::HadamardProduct::Std& _) + const Unary::SoftMax&) + { return Code::SoftMax; } + constexpr Code operator()( + const Binary::HadamardProduct::Std&) { return Code::HADAMARD; } constexpr Code operator()( - const Binary::Multiplication::ParallelDNC& _) + const Binary::Multiplication::ParallelDNC&) { return Code::MULTIPLY; } constexpr Code operator()( - const Binary::Multiplication::Naive& _) + const Binary::Multiplication::Naive&) { return Code::MULTIPLY; } constexpr Code operator()( - const Binary::Multiplication::Square& _) + const Binary::Multiplication::Square&) { return Code::MULTIPLY; } constexpr Code operator()( - const Binary::Addition::Std& _) + const Binary::Addition::Std&) { return Code::PLUS; } constexpr Code operator()( - const Binary::OuterProduct::Naive& _) + const Binary::OuterProduct::Naive&) { return Code::OUTER_PRODUCT; } + constexpr Code operator()( + const Metric::CrossEntropy&) { + return Code::CROSS_ENTROPY; } };