From 413fee2a86d50d13a1324da289e3f835a489a21c Mon Sep 17 00:00:00 2001 From: Mang Guo Date: Mon, 24 Jan 2022 20:07:35 +0800 Subject: [PATCH] [CPU] Enable cache for MVN (#9549) --- .../intel_cpu/src/nodes/mkldnn_lrn_node.cpp | 2 +- .../intel_cpu/src/nodes/mkldnn_mvn_node.cpp | 408 +++++--- .../intel_cpu/src/nodes/mkldnn_mvn_node.h | 83 +- .../plugin/cpu/single_layer_tests/mvn.cpp | 910 ++++++++++-------- .../ngraph_functions/src/mvn.cpp | 2 +- 5 files changed, 825 insertions(+), 580 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_lrn_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_lrn_node.cpp index 3038c159323e6f..ad46d6e71cfd76 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_lrn_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_lrn_node.cpp @@ -50,7 +50,7 @@ bool LrnKey::operator==(const LrnKey &rhs) const { retVal = retVal && inp0 && rhs.inp0 && inp0->getDnnlDesc() == rhs.inp0->getDnnlDesc(); } - retVal = retVal && implType == rhs.implType && alg == rhs.alg && alg == rhs.alg && size == rhs.size && k == rhs.k && + retVal = retVal && implType == rhs.implType && alg == rhs.alg && size == rhs.size && k == rhs.k && alpha == rhs.alpha && beta == rhs.beta; return retVal; } diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.cpp b/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.cpp index 495e10a11598d9..b8f4e3ee81c9c1 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.cpp +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.cpp @@ -36,6 +36,56 @@ using namespace Xbyak; #define GET_OFF(field) offsetof(jit_mvn_call_args, field) +namespace { +struct MVNKey { + MKLDNNMVNNode::MVNAttrs mvnAttrs; + mkldnn::primitive_attr attr; + + size_t hash() const; + bool operator==(const MVNKey& rhs) const; +}; + +size_t MVNKey::hash() const { + using namespace dnnl::impl; + using namespace dnnl::impl::primitive_hashing; + + size_t seed = 0; + + seed = hash_combine(seed, std::get<0>(mvnAttrs.shape5D)); + seed = hash_combine(seed, std::get<1>(mvnAttrs.shape5D)); + seed = hash_combine(seed, std::get<2>(mvnAttrs.shape5D)); + seed = hash_combine(seed, std::get<3>(mvnAttrs.shape5D)); + seed = hash_combine(seed, std::get<4>(mvnAttrs.shape5D)); + seed = hash_combine(seed, mvnAttrs.initAcrossChannels_); + seed = hash_combine(seed, mvnAttrs.execAcrossChannels_); + seed = hash_combine(seed, mvnAttrs.normalizeVariance_); + seed = hash_combine(seed, mvnAttrs.epsValue_); + seed = hash_combine(seed, mvnAttrs.epsMode_); + seed = hash_combine(seed, mvnAttrs.src_prc.getPrecVal()); + seed = hash_combine(seed, mvnAttrs.dst_prc.getPrecVal()); + seed = hash_combine(seed, mvnAttrs.planar_layout); + seed = hash_combine(seed, mvnAttrs.is_nhwc); + seed = hash_combine(seed, get_attr_hash(*attr.get())); + return seed; +} + +bool MVNKey::operator==(const MVNKey& rhs) const { + bool retVal = true; + retVal = retVal && mvnAttrs.shape5D == rhs.mvnAttrs.shape5D && + mvnAttrs.initAcrossChannels_ == rhs.mvnAttrs.initAcrossChannels_ && + mvnAttrs.execAcrossChannels_ == rhs.mvnAttrs.execAcrossChannels_ && + mvnAttrs.normalizeVariance_ == rhs.mvnAttrs.normalizeVariance_ && + mvnAttrs.epsValue_ == rhs.mvnAttrs.epsValue_ && + mvnAttrs.epsMode_ == rhs.mvnAttrs.epsMode_ && + mvnAttrs.src_prc == rhs.mvnAttrs.src_prc && + mvnAttrs.dst_prc == rhs.mvnAttrs.dst_prc && + mvnAttrs.is_nhwc == rhs.mvnAttrs.is_nhwc && + mvnAttrs.planar_layout == mvnAttrs.planar_layout; + retVal = retVal && *attr.get() == *rhs.attr.get(); + return retVal; +} +} // namespace + // some utility functions static inline bool isFloatCompatible(Precision prc) { return Precision::FP32 == prc || Precision::BF16 == prc; @@ -389,6 +439,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator this->preamble(); + mov(reg_post_ops_data, ptr[reg_params + GET_OFF(post_op_data)]); mov(reg_src, ptr[reg_params + GET_OFF(src)]); mov(reg_mean, ptr[reg_params + GET_OFF(mean)]); if (jcp_.normalize_variance) @@ -505,6 +556,7 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator Xbyak::Reg64 reg_oc_off = rax; Xbyak::Reg64 reg_d_weights = rbx; Xbyak::Reg64 reg_d_bias = rdx; + Xbyak::Reg64 reg_post_ops_data = rsi; Xbyak::Reg64 reg_load_table = r15; Xbyak::Reg64 reg_load_store_mask = rbp; @@ -570,16 +622,19 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator int eltwise_inj_idx = 0; int depthwise_inj_idx = 0; int quantization_inj_idx = 0; + int post_ops_data_offset = 0; for (int i = 0; i < p.len(); i++) { auto& post_op = p.entry_[i]; if (post_op.is_eltwise()) { eltwise_injectors[eltwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1); eltwise_inj_idx++; } else if (post_op.is_depthwise()) { - mov(reg_d_weights, reinterpret_cast(post_op.depthwise.weights_data)); - mov(reg_d_bias, reinterpret_cast(post_op.depthwise.biases_data)); + mov(reg_d_weights, ptr[reg_post_ops_data + post_ops_data_offset]); add(reg_d_weights, reg_oc_off); + post_ops_data_offset += sizeof(float*); + mov(reg_d_bias, ptr[reg_post_ops_data + post_ops_data_offset]); add(reg_d_bias, reg_oc_off); + post_ops_data_offset += sizeof(float*); depthwise_injectors[depthwise_inj_idx]->compute_vector_range(vmm_val.getIdx(), vmm_val.getIdx() + 1, reg_d_weights, reg_d_bias, is_broadcast); depthwise_inj_idx++; } else if (post_op.is_quantization()) { @@ -587,15 +642,16 @@ struct jit_uni_mvn_kernel_f32 : public jit_uni_mvn_kernel, public jit_generator bool do_rounding = do_dequantization || isFloatCompatible(dst_prc) || i != p.len() - 1; int s_idx = vmm_val.getIdx(); - quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_oc_off); + quantization_injectors[quantization_inj_idx]->init_crop_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); quantization_injectors[quantization_inj_idx]->compute_crop(s_idx, s_idx + 1, 0, 0, is_broadcast); - quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_oc_off); + quantization_injectors[quantization_inj_idx]->init_input_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); quantization_injectors[quantization_inj_idx]->compute_input_scale_shift(s_idx, s_idx + 1, 0, do_rounding, 0, is_broadcast); - quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_oc_off); + quantization_injectors[quantization_inj_idx]->init_output_scale_shift_ptrs(reg_post_ops_data + post_ops_data_offset, reg_oc_off); quantization_injectors[quantization_inj_idx]->compute_output_scale_shift(s_idx, s_idx + 1, 0, 0, is_broadcast); + post_ops_data_offset += quantization_injectors[quantization_inj_idx]->memoryStep(); quantization_inj_idx++; } } @@ -676,24 +732,24 @@ MKLDNNMVNNode::MKLDNNMVNNode(const std::shared_ptr& op, const mkld IE_THROW(NotImplemented) << errorMessage; } - epsMode_ = INSIDE_SQRT; + mvnAttrs.epsMode_ = INSIDE_SQRT; if (auto mvnOp = ngraph::as_type_ptr(op)) { - normalizeVariance_ = mvnOp->get_normalize_variance(); - epsValue_ = mvnOp->get_eps(); + mvnAttrs.normalizeVariance_ = mvnOp->get_normalize_variance(); + mvnAttrs.epsValue_ = mvnOp->get_eps(); if (mvnOp->get_eps_mode() == ngraph::op::MVNEpsMode::OUTSIDE_SQRT) { - epsMode_ = OUTSIDE_SQRT; + mvnAttrs.epsMode_ = OUTSIDE_SQRT; } - initAcrossChannels_ = false; + mvnAttrs.initAcrossChannels_ = false; const auto& inDataShapeSize = getInputShapeAtPort(0).getRank(); if (inDataShapeSize == mvnOp->input_value(1).get_shape()[0] + 1 || inDataShapeSize == 1) - initAcrossChannels_ = true; + mvnAttrs.initAcrossChannels_ = true; } else if (auto mvnOp = ngraph::as_type_ptr(op)) { - normalizeVariance_ = mvnOp->get_normalize_variance(); - epsValue_ = mvnOp->get_eps(); - initAcrossChannels_ = mvnOp->get_across_channels(); + mvnAttrs.normalizeVariance_ = mvnOp->get_normalize_variance(); + mvnAttrs.epsValue_ = mvnOp->get_eps(); + mvnAttrs.initAcrossChannels_ = mvnOp->get_across_channels(); } - execAcrossChannels_ = initAcrossChannels_; + mvnAttrs.execAcrossChannels_ = mvnAttrs.initAcrossChannels_; } void MKLDNNMVNNode::getSupportedDescriptors() {} @@ -718,11 +774,8 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() { inputPrecision = outputPrecision = Precision::FP32; } - src_data_size = inputPrecision.size(); - dst_data_size = outputPrecision.size(); - // TODO [DS]: inplace - bool canBeInplace = !isDynamicNode() && (src_data_size == dst_data_size) && + bool canBeInplace = !isDynamicNode() && (inputPrecision.size() == outputPrecision.size()) && (getParentEdgeAt(0)->getParent()->getChildEdges().size() == 1) && !getParentEdgeAt(0)->getParent()->isConstant(); @@ -781,6 +834,77 @@ void MKLDNNMVNNode::initSupportedPrimitiveDescriptors() { pushDesc(LayoutType::ncsp, impl_type); } +MKLDNNMVNNode::MVNExecutor::MVNExecutor(const MVNAttrs& mvnAttrs) + : mvnAttrs(mvnAttrs), + src_data_size(mvnAttrs.src_prc.size()), + dst_data_size(mvnAttrs.dst_prc.size()) {} + +MKLDNNMVNNode::MVNJitExecutor::MVNJitExecutor(const MVNAttrs& mvnAttrs, + const mkldnn::primitive_attr& attr): + MVNExecutor(mvnAttrs) { + auto jcp = jit_mvn_config_params(); + jcp.src_prc = mvnAttrs.src_prc; + jcp.dst_prc = mvnAttrs.dst_prc; + jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc)); + jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc)); + jcp.planar_layout = mvnAttrs.planar_layout; + jcp.normalize_variance = mvnAttrs.normalizeVariance_; + jcp.across_channels = mvnAttrs.execAcrossChannels_; + int N = 0; + std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = mvnAttrs.shape5D; + if (mayiuse(cpu::x64::avx512_common)) { + mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); + jcp.normalize_variance = false; + mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + if (mvnAttrs.normalizeVariance_) { + jcp.normalize_variance = true; + mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + } + } else if (mayiuse(cpu::x64::avx2)) { + mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); + jcp.normalize_variance = false; + mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + if (mvnAttrs.normalizeVariance_) { + jcp.normalize_variance = true; + mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + } + } else if (mayiuse(cpu::x64::sse41)) { + mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); + jcp.normalize_variance = false; + mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + if (mvnAttrs.normalizeVariance_) { + jcp.normalize_variance = true; + mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); + } + } else { + IE_THROW() << "Can't create jit MVN kernel"; + } + + if (mvn_kernel) + mvn_kernel->create_ker(); + if (mvn_mean_kernel) + mvn_mean_kernel->create_ker(); + if (mvn_variance_kernel) + mvn_variance_kernel->create_ker(); +} + +void MKLDNNMVNNode::MVNJitExecutor::exec(const uint8_t *src_data, uint8_t *dst_data, const void *post_ops_data_) { + if (!mvn_mean_kernel || (mvnAttrs.normalizeVariance_ && !mvn_variance_kernel) || !mvn_kernel) { + IE_THROW() << "MVN layer doesn't create kernel to execute on sse41 above platform."; + } + if (mvnAttrs.planar_layout) { + mvn_pln(src_data, dst_data, post_ops_data_); + } else { + mvn_blk(src_data, dst_data, post_ops_data_); + } +} + +MKLDNNMVNNode::MVNRefExecutor::MVNRefExecutor(const MVNAttrs& mvnAttrs):MVNExecutor(mvnAttrs) {} + +void MKLDNNMVNNode::MVNRefExecutor::exec(const uint8_t *src_data, uint8_t *dst_data, const void *post_ops_data_) { + mvn_ref(src_data, dst_data); +} + void MKLDNNMVNNode::prepareParams() { auto& dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); auto& srcMemPtr = getParentEdgeAt(0)->getMemoryPtr(); @@ -794,59 +918,48 @@ void MKLDNNMVNNode::prepareParams() { const SizeVector in_dims = srcMemPtr->getStaticDims(); transformTo5DCase(in_dims); - setPostOps(attr, true); - if (mayiuse(cpu::x64::sse41)) { auto selectedPD = getSelectedPrimitiveDescriptor(); - auto jcp = jit_mvn_config_params(); - jcp.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision(); - jcp.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision(); - jcp.src_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.src_prc)); - jcp.dst_data_size = MKLDNNExtensionUtils::sizeOfDataType(MKLDNNExtensionUtils::IEPrecisionToDataType(jcp.dst_prc)); - jcp.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp); - jcp.normalize_variance = normalizeVariance_; - jcp.across_channels = execAcrossChannels_; - int N = 0; - std::tie(N, jcp.C, jcp.D, jcp.H, jcp.W) = shape5D; - - if (mayiuse(cpu::x64::avx512_common)) { - mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); - - jcp.normalize_variance = false; - mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - if (normalizeVariance_) { - jcp.normalize_variance = true; - mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - } - } else if (mayiuse(cpu::x64::avx2)) { - mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); - - jcp.normalize_variance = false; - mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - if (normalizeVariance_) { - jcp.normalize_variance = true; - mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - } - } else if (mayiuse(cpu::x64::sse41)) { - mvn_kernel.reset(new jit_uni_mvn_kernel_f32(jcp, *attr.get())); - - jcp.normalize_variance = false; - mvn_mean_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - if (normalizeVariance_) { - jcp.normalize_variance = true; - mvn_variance_kernel.reset(new jit_uni_mvn_mean_variance_kernel_f32(jcp)); - } - } - - if (mvn_kernel) - mvn_kernel->create_ker(); + mvnAttrs.src_prc = selectedPD->getConfig().inConfs[0].desc->getPrecision(); + mvnAttrs.dst_prc = selectedPD->getConfig().outConfs[0].desc->getPrecision(); + mvnAttrs.planar_layout = selectedPD->getConfig().inConfs[0].desc->hasLayoutType(LayoutType::ncsp); + mvnAttrs.is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc); + } - if (mvn_mean_kernel) - mvn_mean_kernel->create_ker(); + MVNKey key = {mvnAttrs, mkldnn::primitive_attr()}; + setPostOps(key.attr, true); + + postOpsDataPtrs.clear(); + auto &postOps = (*key.attr.get()).post_ops_; + for (int i = 0; i < postOps.len(); ++i) { + auto &postOp = postOps.entry_[i]; + if (postOp.is_quantization()) { + auto &data = postOp.quantization.data; + postOpsDataPtrs.insert(postOpsDataPtrs.end(), std::begin(data), std::end(data)); + memset(data, 0, sizeof(data)); + } else if (postOp.is_depthwise()) { + auto &weights = postOp.depthwise.weights_data; + auto &biases = postOp.depthwise.biases_data; + postOpsDataPtrs.push_back(weights); + postOpsDataPtrs.push_back(biases); + weights = 0; + biases = 0; + } + } - if (mvn_variance_kernel) - mvn_variance_kernel->create_ker(); + auto builder = [&](const MVNKey& key) -> std::shared_ptr { + std::shared_ptr executor; + if (mayiuse(cpu::x64::sse41)) { + executor = std::make_shared(key.mvnAttrs, key.attr); + } else { + executor = std::make_shared(key.mvnAttrs); } + return executor; + }; + + auto cache = getRuntimeCache(); + auto result = cache->getOrCreate(key, builder); + execPtr = result.first; } void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) { @@ -854,26 +967,26 @@ void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) { // for 1 and 2 rank, if initAcrossChannels_ is true, adjust shape to fully vectorize under unified 5d procedure. // otherwise there are not enough data in spatial dimension to process in one kernel. case 1 : // C - if (initAcrossChannels_) { - shape5D = std::make_tuple(1, 1, 1, 1, shape[0]); - execAcrossChannels_ = false; + if (mvnAttrs.initAcrossChannels_) { + mvnAttrs.shape5D = std::make_tuple(1, 1, 1, 1, shape[0]); + mvnAttrs.execAcrossChannels_ = false; break; } else { - shape5D = std::make_tuple(1, shape[0], 1, 1, 1); + mvnAttrs.shape5D = std::make_tuple(1, shape[0], 1, 1, 1); break; } case 2 : // NC - if (initAcrossChannels_) { - shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1); - execAcrossChannels_ = false; + if (mvnAttrs.initAcrossChannels_) { + mvnAttrs.shape5D = std::make_tuple(1, shape[0], 1, shape[1], 1); + mvnAttrs.execAcrossChannels_ = false; break; } else { - shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1); + mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, 1, 1); break; } - case 3 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; } - case 4 : { shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; } - case 5 : { shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; } + case 3 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], 1); break; } + case 4 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], 1, shape[2], shape[3]); break; } + case 5 : { mvnAttrs.shape5D = std::make_tuple(shape[0], shape[1], shape[2], shape[3], shape[4]); break; } default : { IE_THROW() << "MVN layer with name '" << getName() << "' doesn't support planar layout with rank: " << shape.size(); } } } @@ -881,7 +994,7 @@ void MKLDNNMVNNode::transformTo5DCase(const SizeVector& shape) { void MKLDNNMVNNode::setPostOps(mkldnn::primitive_attr &attr, bool initWeights) { mkldnn::post_ops ops; VectorDims postOpDims(5); - std::tie(postOpDims[0], postOpDims[1], postOpDims[2], postOpDims[3], postOpDims[4]) = shape5D; + std::tie(postOpDims[0], postOpDims[1], postOpDims[2], postOpDims[3], postOpDims[4]) = mvnAttrs.shape5D; for (auto &node : fusedWith) { auto* fakeQuantizeNode = dynamic_cast(node.get()); if (fakeQuantizeNode) { @@ -904,27 +1017,18 @@ void MKLDNNMVNNode::executeDynamicImpl(mkldnn::stream strm) { } void MKLDNNMVNNode::execute(mkldnn::stream strm) { + if (!execPtr) { + IE_THROW() << "Can't execute MVN node. Primitive didn't created"; + } auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr(); auto &srcMemPtr = getParentEdgeAt(0)->getMemoryPtr(); uint8_t *dst_data = reinterpret_cast(dstMemPtr->GetPtr()); uint8_t *src_data = reinterpret_cast(srcMemPtr->GetPtr()); - - if (mayiuse(cpu::x64::sse41)) { - if (!mvn_mean_kernel || (normalizeVariance_ && !mvn_variance_kernel) || !mvn_kernel) { - IE_THROW() << "MVN layer with name '" << getName() << "' doesn't create kernel to execute on sse41 above platform."; - } - if (getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::ncsp)) { - mvn_pln(src_data, dst_data); - } else { - mvn_blk(src_data, dst_data); - } - } else { - mvn_ref(src_data, dst_data); - } + execPtr->exec(src_data, dst_data, postOpsDataPtrs.data()); } -void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { +void MKLDNNMVNNode::MVNJitExecutor::mvn_pln(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) { size_t blk_size = 1; // blk size in vmm if (mayiuse(cpu::x64::avx512_common)) { blk_size = 16; @@ -935,7 +1039,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { } size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0; - std::tie(N, C, D, H, W) = shape5D; + std::tie(N, C, D, H, W) = mvnAttrs.shape5D; size_t C1 = H * W; size_t C2 = C1 * D; @@ -946,7 +1050,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { for (size_t b = 0lu; b < N; b++) { size_t cb = b * C3; - if (execAcrossChannels_) { + if (mvnAttrs.execAcrossChannels_) { // Calculate mean value for one instance in batch // Parallel sum for each channel float C3inv = 1.f / static_cast(C3); @@ -959,6 +1063,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { arg.sum = static_cast(&mean_internal); arg.src_stride = src_stride_size; arg.work_amount = static_cast(C2 / blk_size); // for vector part + arg.post_op_data = post_ops_data_; (*mvn_mean_kernel)(&arg); return mean_internal; }); @@ -967,7 +1072,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { // calculate variance value for one instance in batch // parallel sum for each channel - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { float variance_temp = 0.0f; variance_temp = parallel_sum(C, variance_temp, [&](size_t c)->float { float variance_internal = 0.0f; @@ -978,15 +1083,16 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { arg.variance = static_cast(&variance_internal); arg.src_stride = src_stride_size; arg.work_amount = static_cast(C2 / blk_size); // vector part + arg.post_op_data = post_ops_data_; (*mvn_variance_kernel)(&arg); return variance_internal; }); float variance = 1.f; - if (epsMode_ == INSIDE_SQRT) - variance /= sqrtf(variance_temp * C3inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance /= sqrtf(variance_temp * C3inv) + epsValue_; + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance /= sqrtf(variance_temp * C3inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance /= sqrtf(variance_temp * C3inv) + mvnAttrs.epsValue_; // mvn for one instance in batch parallel_for(C, [&](int c) { size_t cc = cb + c * C2; @@ -999,6 +1105,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(C2 / blk_size); // work amount for vector part arg.oc_off = sizeof(float) * c; + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); }); } else { @@ -1013,6 +1120,7 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(C2 / blk_size); arg.oc_off = sizeof(float) * c; + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); }); } @@ -1031,21 +1139,22 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(C2 / blk_size); arg.oc_off = static_cast(c * sizeof(float)); + arg.post_op_data = post_ops_data_; (*mvn_mean_kernel)(&arg); mean *= C2inv; - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { // variance for this channel float variance = 0.f; arg.mean = static_cast(&mean); arg.variance = static_cast(&variance); (*mvn_variance_kernel)(&arg); - if (epsMode_ == INSIDE_SQRT) - variance = 1.f / sqrtf(variance * C2inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance = 1.f / (sqrtf(variance * C2inv) + epsValue_); + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance = 1.f / sqrtf(variance * C2inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance = 1.f / (sqrtf(variance * C2inv) + mvnAttrs.epsValue_); // mvn for this channel (*mvn_kernel)(&arg); @@ -1059,11 +1168,11 @@ void MKLDNNMVNNode::mvn_pln(const uint8_t* src_data, uint8_t* dst_data) { } } -void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { +void MKLDNNMVNNode::MVNRefExecutor::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { const float *src_data_ptr = reinterpret_cast(src_data); float *dst_data_ptr = reinterpret_cast(dst_data); size_t N = 0; size_t C = 0; size_t D = 0; size_t H = 0; size_t W = 0; - std::tie(N, C, D, H, W) = shape5D; + std::tie(N, C, D, H, W) = mvnAttrs.shape5D; size_t C1 = H * W; size_t C2 = C1 * D; @@ -1071,7 +1180,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { for (size_t b = 0lu; b < N; b++) { size_t cb = b * C3; - if (execAcrossChannels_) { + if (mvnAttrs.execAcrossChannels_) { // Parallel sum for each channel for mean float C3inv = 1.f / static_cast(C3); float mean_temp = 0.0f; @@ -1087,7 +1196,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { float mean = mean_temp * C3inv; - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { // parallel sum for each channel for variance float variance_temp = 0.0f; variance_temp = parallel_sum(C, variance_temp, [&](size_t c)->float { @@ -1100,10 +1209,10 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { }); float variance = 1.f; - if (epsMode_ == INSIDE_SQRT) - variance = 1.f / sqrtf(variance_temp * C3inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance = 1.f / (sqrtf(variance_temp * C3inv) + epsValue_); + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance = 1.f / sqrtf(variance_temp * C3inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance = 1.f / (sqrtf(variance_temp * C3inv) + mvnAttrs.epsValue_); parallel_for(C, [&](int c) { size_t cc = cb + c * C2; @@ -1130,17 +1239,17 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { } mean *= C2inv; - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { // variance for this channel float variance = 0.f; for (size_t sp = 0lu; sp < C2; sp++) { variance += (src_data_ptr[cc + sp] - mean) * (src_data_ptr[cc + sp] - mean); } - if (epsMode_ == INSIDE_SQRT) - variance = 1.f / sqrtf(variance * C2inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance = 1.f / (sqrtf(variance * C2inv) + epsValue_); + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance = 1.f / sqrtf(variance * C2inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance = 1.f / (sqrtf(variance * C2inv) + mvnAttrs.epsValue_); // mvn for this channel for (size_t sp = 0lu; sp < C2; sp++) { @@ -1157,7 +1266,7 @@ void MKLDNNMVNNode::mvn_ref(const uint8_t* src_data, uint8_t* dst_data) { } } -void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { +void MKLDNNMVNNode::MVNJitExecutor::mvn_blk(const uint8_t* src_data, uint8_t* dst_data, const void *post_ops_data_) { size_t blk_size = 1; // channel blk for memory layout if (mayiuse(cpu::x64::avx512_common)) { blk_size = 16; @@ -1166,34 +1275,32 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { } size_t N = 1; size_t C = 1; size_t D = 1; size_t H = 1; size_t W = 1; - std::tie(N, C, D, H, W) = shape5D; - - bool is_nhwc = getParentEdgeAt(0)->getMemory().getDesc().hasLayoutType(LayoutType::nspc); + std::tie(N, C, D, H, W) = mvnAttrs.shape5D; size_t CB = div_up(C, blk_size); - size_t C0 = is_nhwc ? W * C : W * blk_size; + size_t C0 = mvnAttrs.is_nhwc ? W * C : W * blk_size; size_t C1 = C0 * H; size_t C2 = C1 * D; size_t C3 = C2 * CB; size_t C5 = C * D * H * W; size_t threads_num = parallel_get_num_threads(); - size_t aux_buffer_size = execAcrossChannels_ ? blk_size : rnd_up(C, blk_size); + size_t aux_buffer_size = mvnAttrs.execAcrossChannels_ ? blk_size : rnd_up(C, blk_size); std::vector mean_buffer(aux_buffer_size * threads_num); std::vector variance_buffer(aux_buffer_size * threads_num); - size_t src_stride_size = is_nhwc ? static_cast(C * src_data_size) : static_cast(blk_size * src_data_size); - size_t dst_stride_size = is_nhwc ? static_cast(C * dst_data_size) : static_cast(blk_size * dst_data_size); + size_t src_stride_size = mvnAttrs.is_nhwc ? static_cast(C * src_data_size) : static_cast(blk_size * src_data_size); + size_t dst_stride_size = mvnAttrs.is_nhwc ? static_cast(C * dst_data_size) : static_cast(blk_size * dst_data_size); for (size_t b = 0lu; b < N; b++) { - size_t b_offset = is_nhwc ? b * C5 : b * C3; - if (execAcrossChannels_) { + size_t b_offset = mvnAttrs.is_nhwc ? b * C5 : b * C3; + if (mvnAttrs.execAcrossChannels_) { // mean for this instance in batch float C5inv = 1.f / static_cast(C5); float mean_temp = 0.0f; mean_temp = parallel_sum3d(CB, D, H, mean_temp, [&](size_t cb, size_t d, size_t h)->float { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; float mean_internal = 0.0f; @@ -1225,11 +1332,11 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { }); float mean = mean_temp * C5inv; - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { // variance: sum((x-mean)*(x-mean)) for one instance in batch float variance_temp = 0.0f; variance_temp = parallel_sum3d(CB, D, H, variance_temp, [&](size_t cb, size_t d, size_t h)->float { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; float variance_internal = 0.0f; @@ -1244,6 +1351,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.src_stride = src_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_variance_kernel)(&arg); size_t min_cb = (std::min)(blk_size, C - cb * blk_size); @@ -1253,13 +1361,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { }); float variance = 1.f; - if (epsMode_ == INSIDE_SQRT) - variance /= sqrtf(variance_temp * C5inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance /= sqrtf(variance_temp * C5inv) + epsValue_; + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance /= sqrtf(variance_temp * C5inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance /= sqrtf(variance_temp * C5inv) + mvnAttrs.epsValue_; // mvn for one instance in batch parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto arg = jit_mvn_call_args(); arg.src = src_data + src_offset * src_data_size; @@ -1270,12 +1378,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); }); } else { // mvn for one instance in batch parallel_for3d(CB, D, H, [&](size_t cb, size_t d, size_t h) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto arg = jit_mvn_call_args(); arg.src = src_data + src_offset * src_data_size; @@ -1285,6 +1394,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); }); } @@ -1297,7 +1407,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { // keep the compute order the same as planar parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb + aux_buffer_size * thr_idx]; @@ -1307,6 +1417,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.src_stride = src_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_mean_kernel)(&arg); } }); @@ -1318,13 +1429,13 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { for (size_t c = 0; c < C; c++) mean_buffer[c] *= size_inv; - if (normalizeVariance_) { + if (mvnAttrs.normalizeVariance_) { for (int i = 0; i < variance_buffer.size(); i++) variance_buffer[i] = 0.f; parallel_for2d(D, H, [&](size_t thr_idx, size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb]; auto variance_buffer_ptr = &variance_buffer[blk_size * cb + aux_buffer_size * thr_idx]; @@ -1336,6 +1447,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.src_stride = src_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_variance_kernel)(&arg); } }); @@ -1344,15 +1456,15 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { variance_buffer[c] += variance_buffer[c + aux_buffer_size * i]; } for (size_t c = 0; c < C; c++) { - if (epsMode_ == INSIDE_SQRT) - variance_buffer[c] = 1.f / sqrtf(variance_buffer[c] * size_inv + epsValue_); - else if (epsMode_ == OUTSIDE_SQRT) - variance_buffer[c] = 1.f / (sqrtf(variance_buffer[c] * size_inv) + epsValue_); + if (mvnAttrs.epsMode_ == INSIDE_SQRT) + variance_buffer[c] = 1.f / sqrtf(variance_buffer[c] * size_inv + mvnAttrs.epsValue_); + else if (mvnAttrs.epsMode_ == OUTSIDE_SQRT) + variance_buffer[c] = 1.f / (sqrtf(variance_buffer[c] * size_inv) + mvnAttrs.epsValue_); } parallel_for2d(D, H, [&](size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb]; auto variance_buffer_ptr = &variance_buffer[blk_size * cb]; @@ -1366,6 +1478,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); } }); @@ -1373,7 +1486,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { // normalizeVariance_ == false parallel_for2d(D, H, [&](size_t d, size_t h) { for (size_t cb = 0; cb < CB; cb++) { - size_t src_offset = is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size + size_t src_offset = mvnAttrs.is_nhwc ? b_offset + d * C1 + h * C0 + cb * blk_size : b_offset + cb * C2 + d * C1 + h * C0; auto mean_buffer_ptr = &mean_buffer[blk_size * cb]; @@ -1385,6 +1498,7 @@ void MKLDNNMVNNode::mvn_blk(const uint8_t* src_data, uint8_t* dst_data) { arg.dst_stride = dst_stride_size; arg.work_amount = static_cast(W); arg.oc_off = cb * blk_size * sizeof(float); + arg.post_op_data = post_ops_data_; (*mvn_kernel)(&arg); } }); @@ -1404,7 +1518,7 @@ bool MKLDNNMVNNode::canFuse(const MKLDNNNodePtr& node) const { EltwiseSwish, EltwiseHswish, EltwiseMish, EltwiseHsigmoid, EltwiseRoundHalfToEven, EltwiseRoundHalfAwayFromZero, EltwiseAbs, EltwiseSqrt, EltwiseSoftRelu); if ((inputRank == 1 && !unaryEltwise) || - (inputRank == 2 && !unaryEltwise && initAcrossChannels_)) { + (inputRank == 2 && !unaryEltwise && mvnAttrs.initAcrossChannels_)) { return false; } diff --git a/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.h b/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.h index 755961ea0808c4..0af92924b041e2 100644 --- a/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.h +++ b/src/plugins/intel_cpu/src/nodes/mkldnn_mvn_node.h @@ -35,6 +35,7 @@ struct jit_mvn_call_args { size_t dst_stride; size_t work_amount; size_t oc_off; + const void* post_op_data; }; struct jit_uni_mvn_mean_variance_kernel { @@ -85,50 +86,82 @@ class MKLDNNMVNNode : public MKLDNNNode { } inline bool getAcrossChannels() const { - return initAcrossChannels_; + return mvnAttrs.initAcrossChannels_; } inline bool getNormalizeVariance() const { - return normalizeVariance_; + return mvnAttrs.normalizeVariance_; } bool canFuse(const MKLDNNNodePtr& node) const override; - void prepareParams() override; + // Defines way to add epsilon: inside sqrt or outside. + enum MVNEpsMode { + INSIDE_SQRT, + OUTSIDE_SQRT + }; + struct MVNAttrs { + bool planar_layout; + std::tuple shape5D; + bool initAcrossChannels_; + bool execAcrossChannels_; + bool normalizeVariance_; + float epsValue_; + MVNEpsMode epsMode_; + bool is_nhwc; + InferenceEngine::Precision src_prc; + InferenceEngine::Precision dst_prc; + }; + private: - void mvn_pln(const uint8_t *src_data, uint8_t *dst_data); + void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false); - void mvn_blk(const uint8_t *src_data, uint8_t *dst_data); + void transformTo5DCase(const InferenceEngine::SizeVector& shape); - void mvn_ref(const uint8_t *src_data, uint8_t *dst_data); + std::vector postOpsDataPtrs; - void setPostOps(mkldnn::primitive_attr &attr, bool initWeights = false); + MVNAttrs mvnAttrs; - void transformTo5DCase(const InferenceEngine::SizeVector& shape); + class MVNExecutor { + public: + MVNExecutor(const MVNAttrs& mvnAttrs); + virtual void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) = 0; + virtual ~MVNExecutor() = default; - std::tuple shape5D; + protected: + MVNAttrs mvnAttrs; + size_t src_data_size = 0; + size_t dst_data_size = 0; + }; - bool initAcrossChannels_ = false; - bool execAcrossChannels_ = false; - bool normalizeVariance_ = true; - float epsValue_ = 1e-9f; - // Defines way to add epsilon: inside sqrt or outside. - enum MVNEpsMode { - INSIDE_SQRT, - OUTSIDE_SQRT + std::shared_ptr execPtr = nullptr; + + class MVNJitExecutor : public MVNExecutor { + public: + MVNJitExecutor(const MVNAttrs& mvnAttrs, + const mkldnn::primitive_attr &attr); + + void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override; + + private: + void mvn_pln(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_); + void mvn_blk(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_); + + std::shared_ptr mvn_mean_kernel; + std::shared_ptr mvn_variance_kernel; + std::shared_ptr mvn_kernel; }; - MVNEpsMode epsMode_; - InferenceEngine::Precision input_prec, output_prec; - size_t src_data_size = 0; - size_t dst_data_size = 0; + class MVNRefExecutor : public MVNExecutor { + public: + MVNRefExecutor(const MVNAttrs& mvnAttrs); - mkldnn::primitive_attr attr; + void exec(const uint8_t *in_ptr_, uint8_t *out_ptr_, const void *post_ops_data_) override; - std::shared_ptr mvn_mean_kernel; - std::shared_ptr mvn_variance_kernel; - std::shared_ptr mvn_kernel; + private: + void mvn_ref(const uint8_t *in_ptr_, uint8_t *out_ptr_); + }; }; } // namespace MKLDNNPlugin diff --git a/src/tests/functional/plugin/cpu/single_layer_tests/mvn.cpp b/src/tests/functional/plugin/cpu/single_layer_tests/mvn.cpp index 0a6cb3a35f62d6..e0c0cdac9921c9 100644 --- a/src/tests/functional/plugin/cpu/single_layer_tests/mvn.cpp +++ b/src/tests/functional/plugin/cpu/single_layer_tests/mvn.cpp @@ -1,406 +1,504 @@ -//// Copyright (C) 2018-2022 Intel Corporation -//// SPDX-License-Identifier: Apache-2.0 -//// -// -//#include -//#include "ngraph_functions/builders.hpp" -//#include "test_utils/cpu_test_utils.hpp" -//#include "test_utils/fusing_test_utils.hpp" -// -//using namespace InferenceEngine; -//using namespace CPUTestUtils; -// -//namespace CPULayerTestsDefinitions { -// -//using basicCpuMvnParams = std::tuple< -// std::pair, std::vector>, // Input shapes -// InferenceEngine::Precision, // Input precision -// ngraph::AxisSet, // Reduction axes -// bool, // Across channels -// bool, // Normalize variance -// double>; // Epsilon -// -//typedef std::tuple< -// basicCpuMvnParams, -// CPUSpecificParams, -// fusingSpecificParams, -// Precision, // CNNNetwork input precision -// Precision> // CNNNetwork output precision -//MvnLayerCPUTestParamSet; -// -//class MvnLayerCPUTest : public testing::WithParamInterface, -// virtual public LayerTestsUtils::LayerTestsCommon, public CpuTestWithFusing { -//public: -// static std::string getTestCaseName(testing::TestParamInfo obj) { -// basicCpuMvnParams basicParamsSet; -// CPUSpecificParams cpuParams; -// fusingSpecificParams fusingParams; -// Precision inputPrecision, outputPrecision; -// std::tie(basicParamsSet, cpuParams, fusingParams, inputPrecision, outputPrecision) = obj.param; -// -// std::pair, std::vector> inputShapes; -// InferenceEngine::Precision netPrecision; -// ngraph::AxisSet axes; -// bool acrossChanels, normalizeVariance; -// double eps; -// std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet; -// -// std::ostringstream result; -// if (!inputShapes.first.empty()) { -// result << "IS=" << CommonTestUtils::partialShape2str(inputShapes.first) << "_"; -// } -// result << "TS="; -// for (const auto& shape : inputShapes.second) { -// result << "(" << CommonTestUtils::vec2str(shape) << ")_"; -// } -// result << "Precision=" << netPrecision.name() << "_"; -// if (!axes.empty()) { -// result << "ReductionAccess=" << CommonTestUtils::vec2str(axes.to_vector()) << "_"; -// } else { -// result << "AcrossChannels=" << (acrossChanels ? "TRUE" : "FALSE") << "_"; -// } -// result << "NormalizeVariance=" << (normalizeVariance ? "TRUE" : "FALSE") << "_"; -// result << "Epsilon=" << eps; -// result << "_" << "CNNInpPrc=" << inputPrecision.name(); -// result << "_" << "CNNOutPrc=" << outputPrecision.name(); -// -// result << CPUTestsBase::getTestCaseName(cpuParams); -// -// result << CpuTestWithFusing::getTestCaseName(fusingParams); -// -// return result.str(); -// } -//protected: -// void SetUp() override { -// targetDevice = CommonTestUtils::DEVICE_CPU; -// -// basicCpuMvnParams basicParamsSet; -// CPUSpecificParams cpuParams; -// fusingSpecificParams fusingParams; -// std::tie(basicParamsSet, cpuParams, fusingParams, inPrc, outPrc) = this->GetParam(); -// -// std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; -// std::tie(postOpMgrPtr, fusedOps) = fusingParams; -// -// std::pair, std::vector> inputShapes; -// InferenceEngine::Precision netPrecision; -// ngraph::AxisSet axes; -// bool acrossChanels, normalizeVariance; -// double eps; -// std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet; -// -// for (size_t i = 0; i < inputShapes.second.size(); i++) { -// targetStaticShapes.push_back({inputShapes.second[i]}); -// } -// inputDynamicShapes = inputShapes.first; -// -// auto netPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(netPrecision); -// auto param = ngraph::builder::makeParams(netPrc, {targetStaticShapes[0].front()}); -// auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(param)); -// auto mvn = ngraph::builder::makeMVN(paramOuts[0], acrossChanels, normalizeVariance, eps); -// if (!axes.empty()) { -// mvn = ngraph::builder::makeMVN(paramOuts[0], axes, normalizeVariance, eps); -// } -// -// selectedType = getPrimitiveType() + "_" + inPrc.name(); -// -// threshold = 0.015f; -// function = makeNgraphFunction(netPrc, param, mvn, "mvn"); -// } -//}; -// -//TEST_P(MvnLayerCPUTest, CompareWithRefs) { -// SKIP_IF_CURRENT_TEST_IS_DISABLED() -// -// Run(); -// CheckPluginRelatedResults(executableNetwork, "MVN"); -//} -// -//namespace { -// -//const std::vector, std::vector>> inputShapes_1D = { -// { {}, {{5}}}, -// { {}, {{16}}}, -// { -// // dynamic -// {{-1}}, -// // target -// { -// {2}, -// {16}, -// {1} -// } -// }, -// { -// // dynamic -// {{{1, 20}}}, -// // target -// { -// {1}, -// {16}, -// {4} -// } -// } -//}; -// -//const std::vector, std::vector>> inputShapes_2D = { -// { {}, {{1, 32}}}, -// { {}, {{16, 64}}}, -// { -// // dynamic -// {{-1, -1}}, -// // target -// { -// {2, 16}, -// {4, 16}, -// {1, 16} -// } -// }, -// { -// // dynamic -// {{{1, 5}, {1, 20}}}, -// // target -// { -// {1, 1}, -// {2, 16}, -// {4, 16} -// } -// } -//}; -// -//const std::vector, std::vector>> inputShapes_3D = { -// { {}, {{1, 32, 17}}}, -// { {}, {{1, 37, 9}}}, -// { {}, {{1, 16, 4}}}, -// { -// // dynamic -// {{-1, -1, -1}}, -// // target -// { -// {2, 16, 6}, -// {4, 16, 2}, -// {1, 16, 4} -// } -// }, -// { -// // dynamic -// {{{1, 5}, {1, 20}, {1, 7}}}, -// // target -// { -// {1, 1, 1}, -// {2, 16, 6}, -// {4, 16, 2} -// } -// } -//}; -// -//const std::vector, std::vector>> inputShapes_4D = { -// { {}, {{1, 16, 5, 8}}}, -// { {}, {{2, 19, 5, 10}}}, -// { {}, {{7, 32, 2, 8}}}, -// { {}, {{5, 8, 3, 5}}}, -// { {}, {{1, 2, 7, 5}}}, -// { {}, {{1, 4, 5, 5}}}, -// { {}, {{1, 7, 3, 5}}}, -// { {}, {{1, 15, 9, 5}}}, -// { {}, {{4, 41, 6, 9}}}, -// { -// // dynamic -// {{-1, -1, -1, -1}}, -// // target -// { -// {2, 16, 10, 6}, -// {4, 16, 2, 2}, -// {1, 16, 8, 4} -// } -// }, -// { -// // dynamic -// {{{1, 5}, {1, 20}, {1, 10}, {1, 7}}}, -// // target -// { -// {1, 1, 1, 1}, -// {2, 16, 10, 6}, -// {4, 16, 2, 2} -// } -// } -//}; -// -//const std::vector, std::vector>> inputShapes_5D = { -// { {}, {{1, 32, 8, 1, 6}}}, -// { {}, {{1, 9, 1, 15, 9}}}, -// { {}, {{6, 64, 6, 1, 18}}}, -// { {}, {{2, 31, 2, 9, 1}}}, -// { {}, {{10, 16, 5, 10, 6}}}, -// { -// // dynamic -// {{-1, -1, -1, -1, -1}}, -// // target -// { -// {2, 16, 5, 10, 6}, -// {4, 16, 7, 2, 2}, -// {1, 16, 11, 8, 4} -// } -// }, -// { -// // dynamic -// {{{1, 5}, {1, 20}, {1, 7}, {1, 10}, {1, 7}}}, -// // target -// { -// {1, 1, 1, 1, 1}, -// {2, 16, 5, 10, 6}, -// {4, 16, 7, 2, 2} -// } -// } -//}; -// -//const std::vector acrossChannels = { -// true, -// false -//}; -// -//const std::vector normalizeVariance = { -// true, -// false -//}; -// -//const std::vector epsilon = { -// 0.000000001 -//}; -// -//const std::vector emptyReductionAxes = {{}}; -// -//std::vector inpPrc = {Precision::I8, Precision::BF16, Precision::FP32}; -//std::vector outPrc = {Precision::BF16, Precision::FP32}; -// -//std::vector cpuParams_4D = { -// CPUSpecificParams({nhwc}, {nhwc}, {}, {}), -// CPUSpecificParams({nChw16c}, {nChw16c}, {}, {}), -// CPUSpecificParams({nchw}, {nchw}, {}, {}) -//}; -// -//std::vector cpuParams_5D = { -// CPUSpecificParams({ndhwc}, {ndhwc}, {}, {}), -// CPUSpecificParams({nCdhw16c}, {nCdhw16c}, {}, {}), -// CPUSpecificParams({ncdhw}, {ncdhw}, {}, {}) -//}; -// -//std::vector fusingParamsSet { -// emptyFusingSpec, -// /* activations */ -// fusingRelu, -// fusingElu, -// fusingTanh, -// fusingSwish, -// /* FQ */ -// fusingFakeQuantizePerChannel, -// fusingFakeQuantizePerChannelRelu, -// fusingFakeQuantizePerTensorRelu, -// /* another patterns */ -// fusingScaleShift, -// fusingAddPerTensor -//}; -// -//const auto Mvn3D = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_3D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::ValuesIn(acrossChannels), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::Values(emptyCPUSpec), -// ::testing::ValuesIn(fusingParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D, MvnLayerCPUTest, Mvn3D, MvnLayerCPUTest::getTestCaseName); -// -//const auto Mvn4D = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_4D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::ValuesIn(acrossChannels), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)), -// ::testing::ValuesIn(fusingParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D, MvnLayerCPUTest, Mvn4D, MvnLayerCPUTest::getTestCaseName); -// -//const auto Mvn5D = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_5D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::ValuesIn(acrossChannels), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)), -// ::testing::ValuesIn(fusingParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D, MvnLayerCPUTest, Mvn5D, MvnLayerCPUTest::getTestCaseName); -// -//// 1D 2D case -//std::vector fusingUnaryEltwiseParamsSet { -// /* activations */ -// fusingRelu, -// fusingElu, -// fusingTanh, -// fusingSwish, -//}; -// -//const auto Mvn1D = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_1D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::ValuesIn(acrossChannels), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::Values(emptyCPUSpec), -// ::testing::ValuesIn(fusingUnaryEltwiseParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn1D, MvnLayerCPUTest, Mvn1D, MvnLayerCPUTest::getTestCaseName); -// -//// 2D no transformed -//const auto Mvn2D = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_2D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::Values(false), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::Values(emptyCPUSpec), -// ::testing::ValuesIn(fusingParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2D, MvnLayerCPUTest, Mvn2D, MvnLayerCPUTest::getTestCaseName); -// -//// 2d transformed -//const auto Mvn2DTrans = ::testing::Combine( -// ::testing::Combine( -// ::testing::ValuesIn(inputShapes_2D), -// ::testing::Values(InferenceEngine::Precision::FP32), -// ::testing::ValuesIn(emptyReductionAxes), -// ::testing::Values(true), -// ::testing::ValuesIn(normalizeVariance), -// ::testing::ValuesIn(epsilon)), -// ::testing::Values(emptyCPUSpec), -// ::testing::ValuesIn(fusingUnaryEltwiseParamsSet), -// ::testing::ValuesIn(inpPrc), -// ::testing::ValuesIn(outPrc)); -// -//INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_MVN2DTrans, MvnLayerCPUTest, Mvn2DTrans, MvnLayerCPUTest::getTestCaseName); -// -//} // namespace -//} // namespace CPULayerTestsDefinitions \ No newline at end of file +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include "ngraph_functions/builders.hpp" +#include "test_utils/cpu_test_utils.hpp" +#include "test_utils/fusing_test_utils.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" + +using namespace InferenceEngine; +using namespace CPUTestUtils; +using namespace ov::test; + +namespace CPULayerTestsDefinitions { + +using basicCpuMvnParams = std::tuple< + InputShape, // Input shapes + ElementType, // Input precision + ngraph::AxisSet, // Reduction axes + bool, // Across channels + bool, // Normalize variance + double>; // Epsilon + +using MvnLayerCPUTestParamSet = std::tuple< + basicCpuMvnParams, + CPUSpecificParams, + fusingSpecificParams, + ElementType, // CNNNetwork input precision + ElementType>; // CNNNetwork output precision + +class MvnLayerCPUTest : public testing::WithParamInterface, + virtual public SubgraphBaseTest, public CpuTestWithFusing { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + basicCpuMvnParams basicParamsSet; + CPUSpecificParams cpuParams; + fusingSpecificParams fusingParams; + ElementType inputPrecision, outputPrecision; + std::tie(basicParamsSet, cpuParams, fusingParams, inputPrecision, outputPrecision) = obj.param; + + InputShape inputShapes; + ElementType netPrecision; + ngraph::AxisSet axes; + bool acrossChanels, normalizeVariance; + double eps; + std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet; + + std::ostringstream result; + result << "IS=" << CommonTestUtils::partialShape2str({inputShapes.first}) << "_"; + result << "TS="; + for (const auto& shape : inputShapes.second) { + result << "(" << CommonTestUtils::vec2str(shape) << ")_"; + } + result << "Precision=" << netPrecision << "_"; + if (!axes.empty()) { + result << "ReductionAccess=" << CommonTestUtils::vec2str(axes.to_vector()) << "_"; + } else { + result << "AcrossChannels=" << (acrossChanels ? "TRUE" : "FALSE") << "_"; + } + result << "NormalizeVariance=" << (normalizeVariance ? "TRUE" : "FALSE") << "_"; + result << "Epsilon=" << eps; + result << "_" << "CNNInpPrc=" << inputPrecision; + result << "_" << "CNNOutPrc=" << outputPrecision; + + result << CPUTestsBase::getTestCaseName(cpuParams); + + result << CpuTestWithFusing::getTestCaseName(fusingParams); + + return result.str(); + } +protected: + void SetUp() override { + targetDevice = CommonTestUtils::DEVICE_CPU; + + basicCpuMvnParams basicParamsSet; + CPUSpecificParams cpuParams; + fusingSpecificParams fusingParams; + ElementType inPrc; + ElementType outPrc; + std::tie(basicParamsSet, cpuParams, fusingParams, inPrc, outPrc) = this->GetParam(); + + std::tie(inFmts, outFmts, priority, selectedType) = cpuParams; + std::tie(postOpMgrPtr, fusedOps) = fusingParams; + + InputShape inputShapes; + ElementType netPrecision; + ngraph::AxisSet axes; + bool acrossChanels, normalizeVariance; + double eps; + std::tie(inputShapes, netPrecision, axes, acrossChanels, normalizeVariance, eps) = basicParamsSet; + + init_input_shapes({inputShapes}); + + auto param = ngraph::builder::makeDynamicParams(netPrecision, inputDynamicShapes); + auto paramOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes(param)); + auto mvn = ngraph::builder::makeMVN(paramOuts[0], acrossChanels, normalizeVariance, eps); + if (!axes.empty()) { + mvn = ngraph::builder::makeMVN(paramOuts[0], axes, normalizeVariance, eps); + } + + selectedType = getPrimitiveType(); + selectedType = makeSelectedTypeStr(selectedType, netPrecision); + + rel_threshold = 0.015f; + function = makeNgraphFunction(netPrecision, param, mvn, "mvn"); + } +}; + +TEST_P(MvnLayerCPUTest, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + + run(); + CheckPluginRelatedResults(executableNetwork, "MVN"); +} + +namespace { + +const std::vector inputShapes_1D = { + { {}, {{5}}}, + { {}, {{16}}}, + { + // dynamic + {-1}, + // target + { + {2}, + {16}, + {1}, + {2} + } + }, + { + // dynamic + {{1, 20}}, + // target + { + {1}, + {16}, + {4}, + {16} + } + } +}; + +const std::vector inputShapes_2D = { + { {}, {{1, 32}}}, + { {}, {{16, 64}}}, + + { + // dynamic + {-1, -1}, + // target + { + {2, 16}, + {4, 16}, + {1, 16}, + {4, 16} + } + }, + { + // dynamic + {{1, 5}, {1, 20}}, + // target + { + {1, 1}, + {2, 16}, + {4, 16}, + {2, 16} + } + } +}; + +const std::vector inputShapes_3D = { + { {}, {{1, 32, 17}}}, + { {}, {{1, 37, 9}}}, + { {}, {{1, 16, 4}}}, + { + // dynamic + {-1, -1, -1}, + // target + { + {2, 16, 6}, + {4, 16, 2}, + {2, 16, 6}, + {4, 16, 2} + } + }, + { + // dynamic + {{1, 5}, {1, 20}, {1, 7}}, + // target + { + {1, 1, 1}, + {2, 16, 6}, + {4, 16, 2}, + {2, 16, 6} + } + } +}; + +const std::vector inputShapes_4D = { + { {}, {{1, 16, 5, 8}}}, + { {}, {{2, 19, 5, 10}}}, + { {}, {{7, 32, 2, 8}}}, + { {}, {{5, 8, 3, 5}}}, + { {}, {{1, 2, 7, 5}}}, + { {}, {{1, 4, 5, 5}}}, + { {}, {{1, 7, 3, 5}}}, + { {}, {{1, 15, 9, 5}}}, + { {}, {{4, 41, 6, 9}}}, + { + // dynamic + {-1, -1, -1, -1}, + // target + { + {2, 16, 10, 6}, + {4, 16, 2, 2}, + {2, 16, 10, 6}, + {4, 16, 2, 2} + } + }, + { + // dynamic + {{1, 5}, {1, 20}, {1, 10}, {1, 7}}, + // target + { + {1, 1, 1, 1}, + {2, 16, 10, 6}, + {4, 16, 2, 2}, + {2, 16, 10, 6} + } + } +}; + +const std::vector inputShapes_5D = { + { {}, {{1, 32, 8, 1, 6}}}, + { {}, {{1, 9, 1, 15, 9}}}, + { {}, {{6, 64, 6, 1, 18}}}, + { {}, {{2, 31, 2, 9, 1}}}, + { {}, {{10, 16, 5, 10, 6}}}, + { + // dynamic + {-1, -1, -1, -1, -1}, + // target + { + {2, 16, 5, 10, 6}, + {4, 16, 7, 2, 2}, + {2, 16, 5, 10, 6}, + {4, 16, 7, 2, 2} + } + }, + { + // dynamic + {{1, 5}, {1, 20}, {1, 7}, {1, 10}, {1, 7}}, + // target + { + {1, 1, 1, 1, 1}, + {2, 16, 5, 10, 6}, + {4, 16, 7, 2, 2}, + {2, 16, 5, 10, 6} + } + } +}; + +const std::vector acrossChannels = { + true, + false +}; + +const std::vector normalizeVariance = { + true, + false +}; + +const std::vector epsilon = { + 0.000000001 +}; + +const std::vector emptyReductionAxes = {{}}; + +std::vector inpPrc = {ElementType::i8, ElementType::bf16, ElementType::f32}; +std::vector outPrc = {ElementType::bf16, ElementType::f32}; + +std::vector cpuParams_4D = { + CPUSpecificParams({nhwc}, {nhwc}, {}, {}), + CPUSpecificParams({nChw16c}, {nChw16c}, {}, {}), + CPUSpecificParams({nchw}, {nchw}, {}, {}) +}; + +std::vector cpuParams_5D = { + CPUSpecificParams({ndhwc}, {ndhwc}, {}, {}), + CPUSpecificParams({nCdhw16c}, {nCdhw16c}, {}, {}), + CPUSpecificParams({ncdhw}, {ncdhw}, {}, {}) +}; + +std::vector fusingParamsSet { + emptyFusingSpec, + /* activations */ + fusingRelu, + fusingElu, + fusingTanh, + fusingSwish, + /* FQ */ + fusingFakeQuantizePerTensorRelu, + /* another patterns */ + fusingAddPerTensor +}; + +std::vector fusingParamsSetStaticShape { + /* FQ */ + fusingFakeQuantizePerChannel, + fusingFakeQuantizePerChannelRelu, + /* another patterns */ + fusingScaleShift, +}; + +const auto Mvn3D = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_3D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D, MvnLayerCPUTest, Mvn3D, MvnLayerCPUTest::getTestCaseName); + +const auto Mvn4D = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_4D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)), + ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D, MvnLayerCPUTest, Mvn4D, MvnLayerCPUTest::getTestCaseName); + +const auto Mvn5D = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_5D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)), + ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D, MvnLayerCPUTest, Mvn5D, MvnLayerCPUTest::getTestCaseName); + +// 1D 2D case +std::vector fusingUnaryEltwiseParamsSet { + /* activations */ + fusingRelu, + fusingElu, + fusingTanh, + fusingSwish, +}; + +const auto Mvn1D = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_1D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingUnaryEltwiseParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn1D, MvnLayerCPUTest, Mvn1D, MvnLayerCPUTest::getTestCaseName); + +// 2D no transformed +const auto Mvn2D = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_2D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::Values(false), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2D, MvnLayerCPUTest, Mvn2D, MvnLayerCPUTest::getTestCaseName); + +// 2d transformed +const auto Mvn2DTrans = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapes_2D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::Values(true), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingUnaryEltwiseParamsSet), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn2DTrans, MvnLayerCPUTest, Mvn2DTrans, MvnLayerCPUTest::getTestCaseName); + +// Static shape test for some specific fusing parameters in fusingParamsSetStaticShape + +const std::vector inputShapesStatic_2D = { + {1}, + {16}, + {4} +}; + +const std::vector inputShapesStatic_3D = { + {2, 16, 6}, + {4, 16, 2}, + {1, 16, 4} +}; + +const std::vector inputShapesStatic_4D = { + {1, 7, 3, 5}, + {1, 15, 9, 5}, + {4, 41, 6, 9} +}; + +const std::vector inputShapesStatic_5D = { + {1, 32, 8, 1, 6}, + {1, 9, 1, 15, 9}, + {6, 64, 6, 1, 18} +}; + +const auto Mvn2DStatic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inputShapesStatic_2D), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::Values(false), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingParamsSetStaticShape), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +const auto Mvn3DStatic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_3D)), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::Values(emptyCPUSpec), + ::testing::ValuesIn(fusingParamsSetStaticShape), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn3D_Static, MvnLayerCPUTest, Mvn3DStatic, MvnLayerCPUTest::getTestCaseName); + +const auto Mvn4DStatic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_4D)), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_4D)), + ::testing::ValuesIn(fusingParamsSetStaticShape), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn4D_Static, MvnLayerCPUTest, Mvn4DStatic, MvnLayerCPUTest::getTestCaseName); + +const auto Mvn5DStatic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesStatic_5D)), + ::testing::Values(ElementType::f32), + ::testing::ValuesIn(emptyReductionAxes), + ::testing::ValuesIn(acrossChannels), + ::testing::ValuesIn(normalizeVariance), + ::testing::ValuesIn(epsilon)), + ::testing::ValuesIn(filterCPUSpecificParams(cpuParams_5D)), + ::testing::ValuesIn(fusingParamsSetStaticShape), + ::testing::ValuesIn(inpPrc), + ::testing::ValuesIn(outPrc)); + +INSTANTIATE_TEST_SUITE_P(smoke_CompareWithRefs_Mvn5D_Static, MvnLayerCPUTest, Mvn5DStatic, MvnLayerCPUTest::getTestCaseName); + +} // namespace +} // namespace CPULayerTestsDefinitions \ No newline at end of file diff --git a/src/tests/ngraph_helpers/ngraph_functions/src/mvn.cpp b/src/tests/ngraph_helpers/ngraph_functions/src/mvn.cpp index 3e4801a692704e..6fa430627bfaa8 100644 --- a/src/tests/ngraph_helpers/ngraph_functions/src/mvn.cpp +++ b/src/tests/ngraph_helpers/ngraph_functions/src/mvn.cpp @@ -16,7 +16,7 @@ std::shared_ptr makeMVN(const ngraph::Output &in, // Ngraph MVN implementation implicitly adds 0th dimension to reduction axes set which is not valid behavior ngraph::AxisSet axes; const size_t startAxis = acrossChannels ? 1 : 2; - const size_t numOfDims = in.get_shape().size(); + const size_t numOfDims = in.get_partial_shape().size(); for (size_t i = startAxis; i < numOfDims; i++) axes.insert(i); mvnNode->set_reduction_axes(axes);