Skip to content

Commit

Permalink
FIxed is_on_constant_path() using in all places (openvinotoolkit#19239)
Browse files Browse the repository at this point in the history
* Fixed matmul weights check in snippets_mark_skipped

* fix

* ConvertMatMulToFC: is_on_constant_path fix

* [TESTS] added SplitMatMulConcat subgraph test

* MarkDequantizationSubgraph: is_on_constant_path fix
  • Loading branch information
antonvor authored Aug 18, 2023
1 parent 24ddf1b commit 4f29e60
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
ov::matcher_pass_callback callback = [=](pattern::Matcher& m) -> bool {
const auto& pattern_map = m.get_pattern_value_map();
auto convert = pattern_map.at(convert_pattern).get_node_shared_ptr();
auto input = pattern_map.at(input_pattern).get_node_shared_ptr();
auto input = pattern_map.at(input_pattern);
const auto multiply = m.get_match_root();

if (transformation_callback(multiply)) {
Expand All @@ -48,12 +48,12 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
if (node && std::find(precisions.begin(), precisions.end(), node->get_input_element_type(0)) !=
precisions.end()) {
convert = node;
input = convert->get_input_node_shared_ptr(0);
input = convert->input_value(0);
}
}
}

const auto& input_precision = input->get_output_element_type(0);
const auto& input_precision = input.get_element_type();
// validation by Convert operation input precisions
if (std::find(precisions.begin(), precisions.end(), input_precision) == precisions.end()) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
MATCHER_SCOPE(ConvertMatMulToFC);
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto weights_path = [](const ov::Output<ov::Node>& output) {
return ov::op::util::is_on_constant_path(output.get_node_shared_ptr());
return ov::op::util::is_on_constant_path(output);
};
auto weights_m = ngraph::pattern::any_input(weights_path);
auto matmul_m = ngraph::pattern::wrap_type<ngraph::op::v0::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
ov::PartialShape matmul_shape;
for (const auto &parent_out : node->input_values()) {
const auto parent = parent_out.get_node_shared_ptr();
if (ov::op::util::is_on_constant_path(parent)) {
if (ov::op::util::is_on_constant_path(parent_out)) {
bias_shape = parent_out.get_shape();
num_non_const_inputs++;
} else {
Expand All @@ -265,7 +265,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
// first check that weights are constant and both activations and weights have static shape
if (grandparents.size() == 2 &&
grandparents[1].get_partial_shape().is_static() &&
(ov::op::util::is_on_constant_path(grandparents[1].get_node_shared_ptr()))) {
(ov::op::util::is_on_constant_path(grandparents[1]))) {
auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"

using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;

namespace SubgraphTestsDefinitions {

/*
---------------
| Input |
---------------
|
---------------
|VariadicSplit|
---------------
| |
--------- |
|MatMul | |
--------- |
| |
---------------
| Concat |
---------------
|
---------------
| Output |
---------------
*/

using SplitMatMulConcatParams = std::tuple<
std::vector<InputShape>, // input shapes
std::pair<bool, bool> // transposeA, transposeB
>;

class SplitMatMulConcatTest : public testing::WithParamInterface<SplitMatMulConcatParams>,
virtual public SubgraphBaseTest, public CPUTestsBase {
public:
static std::string getTestCaseName(testing::TestParamInfo<SplitMatMulConcatParams> obj) {
std::vector<InputShape> inputShapes;
std::pair<bool, bool> transpose;

std::tie(inputShapes, transpose) = obj.param;

std::ostringstream result;
for (const auto& shape : inputShapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
if (!shape.second.empty()) {
auto itr = shape.second.begin();
do {
result << ov::test::utils::vec2str(*itr);
} while (++itr != shape.second.end() && result << "_");
}
result << ")_";
}
result << "transpose_a=" << transpose.first << "_";
result << "transpose_b=" << transpose.second << "_";

return result.str();
}

protected:
template<typename T>
void transposeShape(T& shape) {
IE_ASSERT(shape.size() > 1);
std::swap(*(shape.end() - 1), *(shape.end() - 2));
}

void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;

std::vector<InputShape> inputShapes;
std::pair<bool, bool> transpose;

std::tie(inputShapes, transpose) = this->GetParam();

init_input_shapes(inputShapes);

bool transpA = transpose.first;
bool transpB = transpose.second;

if (transpA) {
transposeShape(inputDynamicShapes[0]);
for (auto& shapes : targetStaticShapes) {
transposeShape(shapes[0]);
}
}
if (transpB) {
transposeShape(inputDynamicShapes[1]);
for (auto& shapes : targetStaticShapes) {
transposeShape(shapes[1]);
}
}

const auto& inShapeA = inputDynamicShapes[0];
const auto& inShapeB = inputDynamicShapes[1];

auto params = builder::makeDynamicParams(ElementType::f32, {inShapeA});
auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<opset1::Parameter>(params));
std::shared_ptr<Node> inputB = builder::makeConstant<float>(ElementType::f32, inShapeB.get_shape(), {}, true);

auto split = builder::makeVariadicSplit(paramOuts[0], {1, 1}, 0);

auto matMul = builder::makeMatMul(split->output(0), inputB, transpA, transpB);

auto concat = builder::makeConcat({matMul, split->output(1)}, 0);

function = CPUTestsBase::makeNgraphFunction(ElementType::f32, params, concat, "FullyConnected");
}
};

TEST_P(SplitMatMulConcatTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
run();
}

namespace {

const std::vector<std::pair<bool, bool>> transposeParams = {
{false, true},
};

const std::vector<std::vector<InputShape>> inputShapes2D = {
static_shapes_to_test_representation({{2, 3}, {3, 3}}),
};

const auto testParams2D_FP32_smoke = ::testing::Combine(
::testing::ValuesIn(inputShapes2D),
::testing::ValuesIn(transposeParams));

INSTANTIATE_TEST_SUITE_P(smoke_FC_2D_FP32, SplitMatMulConcatTest, testParams2D_FP32_smoke,
SplitMatMulConcatTest::getTestCaseName);

} // namespace

} // namespace SubgraphTestsDefinitions

0 comments on commit 4f29e60

Please sign in to comment.