Skip to content

Commit

Permalink
2 states test plus tests minor refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Nov 2, 2023
1 parent 14f1bc9 commit 028deb2
Showing 1 changed file with 152 additions and 64 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,37 @@ using namespace ov::test;

namespace SubgraphTestsDefinitions {

class StatefulModelTest : public SubgraphBaseTest {
public:
static constexpr ov::element::Type_t testPrc = ov::element::Type_t::f32;

public:
void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}

void reset_state() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}

static void float_compare(const float* expected_res, const float* actual_res, size_t size) {
constexpr float rel_diff_threshold = 1e-4f;
for (size_t i = 0; i < size; ++i) {
const float expected_val = expected_res[i];
const float actual_val = actual_res[i];
if (0.f == expected_val) {
ASSERT_TRUE(abs(actual_val) < rel_diff_threshold);
} else {
ASSERT_TRUE(abs(actual_val / expected_val - 1.f) < rel_diff_threshold);
}
}
}
};

// ┌────────┐ ┌───────┐
// │ Param1 │ │ Const │
// └───┬────┘ └───┬───┘
Expand All @@ -34,10 +65,7 @@ namespace SubgraphTestsDefinitions {
// │ Result │
// └───────────┘

class StaticShapeStatefulModel : public SubgraphBaseTest {
public:
static constexpr ov::element::Type_t testPrc = ov::element::Type_t::f32;

class StaticShapeStatefulModel : public StatefulModelTest {
public:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down Expand Up @@ -79,12 +107,6 @@ class StaticShapeStatefulModel : public SubgraphBaseTest {
return result;
}

void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}

void run_test() {
auto& input_vals = get_inputs();
for (size_t i = 0; i < input_vals.size(); ++i) {
Expand All @@ -101,25 +123,18 @@ class StaticShapeStatefulModel : public SubgraphBaseTest {
auto outputTensor = inferRequest.get_output_tensor(0);
ASSERT_TRUE(outputTensor);
inferRequest.infer();
constexpr float rel_diff_threshold = 1e-4f;
const auto& expected_res = calc_refs().first;
const float expected_val = expected_res[i];
const float actual_val = outputTensor.data<ov::element_type_traits<ov::element::f32>::value_type>()[0];
ASSERT_TRUE(abs(actual_val - expected_val) / abs(expected_val) < rel_diff_threshold);
float_compare(&expected_val, &actual_val, 1);
auto states = inferRequest.query_state();
ASSERT_FALSE(states.empty());
auto mstate = states.front().get_state();
ASSERT_TRUE(mstate);
const auto& expected_states = calc_refs().second;
const float expected_state_val = expected_states[i];
const float actual_state_val = mstate.data<ov::element_type_traits<ov::element::f32>::value_type>()[0];
ASSERT_TRUE(abs(expected_state_val - actual_state_val) / abs(expected_state_val) < rel_diff_threshold);
}
}

void reset_state() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
float_compare(&expected_state_val, &actual_state_val, 1);
}
}
};
Expand All @@ -131,6 +146,122 @@ TEST_F(StaticShapeStatefulModel, smoke_Run_Stateful_Static) {
run_test();
}

// ┌────────┐ ┌───────┐
// │ Param1 │ │ Const ├────────┐
// └───┬────┘ └───┬───┘ │
// │ │ │
// │ ┌────┴──────┐ │
// .......│.........│ ReadValue │ │
// . │ └────┬──────┘ │
// . │ │ │
// . │ ┌─────┐ │ │
// . └───┤ Add ├────┘ ┌────┴──────┐
// . └──┬──┘ │ ReadValue │..
// . │ └────┬──────┘ .
// . │ │ .
// . ┌────────┐ │ ┌─────┐ │ .
// ..│ Assign ├──┴────┤ Add ├─────────┘ .
// └────────┘ └─────┘ .
// / \ .
// / \ .
// ┌───────────┐ ┌───────────┐ .
// │ Result │ │ Assign │......
// └───────────┘ └───────────┘

class StaticShapeTwoStatesModel : public StatefulModelTest {
public:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
ov::element::Type netPrc = testPrc;

const ov::Shape inpShape = {1, 1};
targetStaticShapes = {{inpShape}};

auto arg = std::make_shared<ov::op::v0::Parameter>(netPrc, ov::Shape{1, 1});
auto init_const = ov::op::v0::Constant::create(netPrc, ov::Shape{1, 1}, {5.f});

// The ReadValue/Assign operations must be used in pairs in the model.
// For each such a pair, its own variable object must be created.
auto variable0 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape::dynamic(), ov::element::dynamic, "variable0"});

auto variable1 = std::make_shared<ov::op::util::Variable>(
ov::op::util::VariableInfo{ov::PartialShape::dynamic(), ov::element::dynamic, "variable1"});

// Creating ov::Model
auto read0 = std::make_shared<ov::op::v6::ReadValue>(init_const, variable0);
auto add = ngraph::builder::makeEltwise(arg, read0, ngraph::helpers::EltwiseTypes::ADD);
auto assign0 = std::make_shared<ov::op::v6::Assign>(add, variable0);
auto read1 = std::make_shared<ov::op::v6::ReadValue>(init_const, variable1);
auto add2 = ngraph::builder::makeEltwise(add, read1, ngraph::helpers::EltwiseTypes::ADD);
auto assign1 = std::make_shared<ov::op::v6::Assign>(add2, variable1);
auto res = std::make_shared<ov::op::v0::Result>(add2);
function = std::make_shared<ov::Model>(
ov::ResultVector({res}),
ov::SinkVector({assign0, assign1}),
ov::ParameterVector({arg}));
}

const std::vector<float>& get_inputs() const {
static const std::vector<float> input_vals =
{6.06f, 5.75f, 1.92f, 1.61f, 7.78f, 7.47f, 3.64f, 3.33f, 9.5f, 9.19f};
return input_vals;
}

const std::pair<std::vector<float>, std::vector<float>>& calc_refs() const {
static const std::pair<std::vector<float>, std::vector<float>> result = {
{11.06f, 16.81f, 18.73f, 20.34f, 28.12f, 35.59f, 39.23f, 42.56f, 52.06f, 61.25f}, // state0
{16.06f, 32.87f, 51.60f, 71.94f, 100.06f, 135.65f, 174.88f, 217.44f, 269.50f, 330.75f} // state1 and result
};
return result;
}

void run_test() {
auto& input_vals = get_inputs();
for (size_t i = 0; i < input_vals.size(); ++i) {
inputs.clear();
const auto& funcInputs = function->inputs();
const auto& funcInput = funcInputs.front();
auto tensor = ov::runtime::Tensor{testPrc, funcInput.get_shape()};
auto inputData = tensor.data<ov::element_type_traits<testPrc>::value_type>();
inputData[0] = input_vals[i];
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
for (const auto& input : inputs) {
inferRequest.set_tensor(input.first, input.second);
}
auto outputTensor = inferRequest.get_output_tensor(0);
ASSERT_TRUE(outputTensor);
inferRequest.infer();
std::vector<float> expected_state0;
std::vector<float> expected_results;
std::tie(expected_state0, expected_results) = calc_refs();

auto states = inferRequest.query_state();
ASSERT_FALSE(states.empty());
ov::Tensor state0;
ov::Tensor state1;
for (auto&& state : states) {
if ("variable0" == state.get_name()) {
state0 = state.get_state();
}
if ("variable1" == state.get_name()) {
state1 = state.get_state();
}
}
ASSERT_TRUE(state0);
ASSERT_TRUE(state1);
auto actual_result = outputTensor.data<ov::element_type_traits<testPrc>::value_type>();
float_compare(&expected_state0[i], state0.data<ov::element_type_traits<testPrc>::value_type>(), 1);
float_compare(&expected_results[i], state1.data<ov::element_type_traits<testPrc>::value_type>(), 1);
float_compare(&expected_results[i], actual_result, 1);
}
}
};

TEST_F(StaticShapeTwoStatesModel, smoke_Run_Static_Two_States) {
prepare();
run_test();
}

// ┌─────────┐ ┌───────────┐
// │ Param1 │--->│ ReadValue │..
Expand All @@ -150,23 +281,7 @@ TEST_F(StaticShapeStatefulModel, smoke_Run_Stateful_Static) {
// │ Result │ │ Assign │.....
// └────────┘ └────────┘

static void float_compare(const float* expected_res, const float* actual_res, size_t size) {
constexpr float rel_diff_threshold = 1e-4f;
for (size_t i = 0; i < size; ++i) {
const float expected_val = expected_res[i];
const float actual_val = actual_res[i];
if (0.f == expected_val) {
ASSERT_TRUE(abs(actual_val) < rel_diff_threshold);
} else {
ASSERT_TRUE(abs(actual_val / expected_val - 1.f) < rel_diff_threshold);
}
}
}

class DynamicShapeStatefulModel : public SubgraphBaseTest {
public:
static constexpr ov::element::Type_t testPrc = ov::element::Type_t::f32;

class DynamicShapeStatefulModel : public StatefulModelTest {
public:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down Expand Up @@ -214,12 +329,6 @@ class DynamicShapeStatefulModel : public SubgraphBaseTest {
return result;
}

void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}

void run_test() {
std::vector<float> vec_state = {0};

Expand Down Expand Up @@ -265,12 +374,6 @@ class DynamicShapeStatefulModel : public SubgraphBaseTest {
float_compare(vec_state.data(), actual_state, vec_state.size());
}
}

void reset_state() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}
};

TEST_F(DynamicShapeStatefulModel, smoke_Run_Stateful_Dynamic) {
Expand Down Expand Up @@ -308,10 +411,7 @@ TEST_F(DynamicShapeStatefulModel, smoke_Run_Stateful_Dynamic) {
// │ Result2 │ │ Assign │.....
// └─────────┘ └────────┘

class DynamicShapeStatefulModelStateAsInp : public SubgraphBaseTest {
public:
static constexpr ov::element::Type_t testPrc = ov::element::Type_t::f32;

class DynamicShapeStatefulModelStateAsInp : public StatefulModelTest {
public:
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
Expand Down Expand Up @@ -375,12 +475,6 @@ class DynamicShapeStatefulModelStateAsInp : public SubgraphBaseTest {
return {result1, result2};
}

void prepare() {
compile_model();
inferRequest = compiledModel.create_infer_request();
ASSERT_TRUE(inferRequest);
}

void run_test() {
std::vector<float> vec_state = {0.f};

Expand Down Expand Up @@ -439,12 +533,6 @@ class DynamicShapeStatefulModelStateAsInp : public SubgraphBaseTest {
}
}

void reset_state() {
for (auto&& state : inferRequest.query_state()) {
state.reset();
}
}

private:
float const_val = 0.0f;
};
Expand Down

0 comments on commit 028deb2

Please sign in to comment.