Skip to content

Commit

Permalink
[tmva][sofie] Add Sin/Cos operators
Browse files Browse the repository at this point in the history
Add Sin and Cos operators as new Unary operators.
Add also tests, taken from Vedant's PR  root-project#16809
  • Loading branch information
lmoneta committed Dec 12, 2024
1 parent d462aa0 commit 812964f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tmva/sofie/inc/TMVA/ROperator_BasicUnary.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace TMVA {
namespace Experimental {
namespace SOFIE {

enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog };
enum class EBasicUnaryOperator { kReciprocal, kSqrt , kNeg, kExp, kLog, kSin, kCos };

template <typename T, EBasicUnaryOperator Op>
struct UnaryOpTraits {
Expand Down Expand Up @@ -45,6 +45,18 @@ struct UnaryOpTraits<T, EBasicUnaryOperator::kLog> {
static std::string Op(const std::string &X) { return "std::log(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kSin> {
static std::string Name() { return "Sin"; }
static std::string Op(const std::string &X) { return "std::sin(" + X + ")"; }
};

template <typename T>
struct UnaryOpTraits<T, EBasicUnaryOperator::kCos> {
static std::string Name() { return "Cos"; }
static std::string Op(const std::string &X) { return "std::cos(" + X + ")"; }
};

template <typename T, EBasicUnaryOperator Op>
class ROperator_BasicUnary final : public ROperator {
private:
Expand Down
50 changes: 50 additions & 0 deletions tmva/sofie/test/TestCustomModelsFromONNX.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@

#include "Where_FromONNX.hxx"

#include "Sin_FromONNX.hxx"

#include "Cos_FromONNX.hxx"

#include "gtest/gtest.h"

constexpr float DEFAULT_TOLERANCE = 1e-3f;
Expand Down Expand Up @@ -2937,4 +2941,50 @@ TEST(ONNX, Where) {
for (size_t i = 0; i < output.size(); i++) {
EXPECT_EQ(output[i], correct[i]);
}
}
float outputs[] = {0.406200, 0.111242, 0.770231, 0.940162, 0.260436, -0.258742,
0.304129, 0.999899, 0.256423, 0.410855, 0.843406, 0.862500};

TEST(ONNX, Sin)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing some random input
std::vector<float> input({
-0.786738,-0.197796,-0.187787,0.142758,0.876096,-0.653239,0.145444,-1.107658,2.259171,-0.947054,-0.506689,1.801250
});

TMVA_SOFIE_Sin::Session s("Sin_FromONNX.dat");

std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), input.size());

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - std::sin(input[i])), TOLERANCE);
}
}

TEST(ONNX, Cos)
{
constexpr float TOLERANCE = DEFAULT_TOLERANCE;

// Preparing the random input
std::vector<float> input({
1.152504,-1.459324,0.691594,0.347690,-1.307323,1.832516,-1.261772,0.014224,1.311477,1.147405,-0.567206,-0.530606
});

TMVA_SOFIE_Cos::Session s("Cos_FromONNX.dat");

std::vector<float> output = s.infer(input.data());

// Checking output size
EXPECT_EQ(output.size(), input.size());

// Checking every output value, one by one
for (size_t i = 0; i < output.size(); ++i) {
EXPECT_LE(std::abs(output[i] - std::cos(input[i])), TOLERANCE);
}
}
12 changes: 12 additions & 0 deletions tmva/sofie/test/input_models/Cos.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

 cos_example:S

inputoutput"CosCosGraphZ
input


b
output


B
12 changes: 12 additions & 0 deletions tmva/sofie/test/input_models/Sin.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

 onnx-example:S

inputoutput"Sinsin_testZ
input


b
output


B

0 comments on commit 812964f

Please sign in to comment.