Skip to content

Commit

Permalink
Introduce arbitrary order desc creator
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Nov 27, 2023
1 parent ad723b5 commit 504b6e9
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 34 deletions.
1 change: 0 additions & 1 deletion src/plugins/intel_cpu/src/memory_desc/cpu_memory_desc.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ enum class LayoutType : unsigned {
ncsp, // general planar
nCsp8c, // general channels blocked by 8
nCsp16c, // general channels blocked by 16
cabd
};

class MemoryDesc {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "arbitrary_order_desc_creator.h"

namespace ov {
namespace intel_cpu {

ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) :
m_order(std::move(order)) {
OPENVINO_ASSERT(std::adjacent_find(m_order.begin(), m_order.end()) == m_order.end(),
"Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements",
vec2str(m_order));
}

CpuBlockedMemoryDesc
ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, const Shape& srcShape) const {
auto&& dims = srcShape.getDims();
OPENVINO_ASSERT(dims.size() == m_order.size(),
"Couldn't create a tensor descriptor, shape and order size mismatch. Shape: ",
vec2str(dims),
" order: ",
vec2str(m_order));

VectorDims blkDims(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
blkDims[i] = dims[m_order[i]];
}

return CpuBlockedMemoryDesc(precision, srcShape, blkDims, m_order);
}

size_t ArbitraryOrderDescCreator::getMinimalRank() const {
return m_order.size();
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "blocked_desc_creator.h"

namespace ov {
namespace intel_cpu {

class ArbitraryOrderDescCreator : public BlockedDescCreator {
public:
ArbitraryOrderDescCreator(VectorDims order);

CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override;
size_t getMinimalRank() const override;

private:
VectorDims m_order;
};

} // namespace intel_cpu
} // namespace ov
24 changes: 1 addition & 23 deletions src/plugins/intel_cpu/src/nodes/common/blocked_desc_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,35 +70,13 @@ class ChannelBlockedCreator : public BlockedDescCreator {
size_t _blockSize;
};

class CABDCreator : public BlockedDescCreator {
public:
CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override {
SizeVector order(srcShape.getRank());
std::iota(order.begin(), order.end(), 0);
SizeVector blkDims = srcShape.getDims();
if (srcShape.getRank() > 2) {
auto moveElementFront = [](SizeVector& vector, size_t indx) {
auto itr = vector.begin() + indx;
std::rotate(vector.begin(), itr, vector.end() - 1);
};

moveElementFront(order, srcShape.getRank() - 1 - 1);
moveElementFront(blkDims, srcShape.getRank() - 1 - 1);
}

return CpuBlockedMemoryDesc(precision, srcShape, blkDims, order);
}
size_t getMinimalRank() const override { return 3lu; }
};

} // namespace

const BlockedDescCreator::CreatorsMap& BlockedDescCreator::getCommonCreators() {
static const CreatorsMap map{ { LayoutType::nspc, CreatorConstPtr(new PerChannelCreator) },
{ LayoutType::nCsp8c, CreatorConstPtr(new ChannelBlockedCreator(8)) },
{ LayoutType::nCsp16c, CreatorConstPtr(new ChannelBlockedCreator(16)) },
{ LayoutType::ncsp, CreatorConstPtr(new PlainFormatCreator) },
{ LayoutType::cabd, CreatorConstPtr(new CABDCreator) } };
{ LayoutType::ncsp, CreatorConstPtr(new PlainFormatCreator) } };
return map;
}

Expand Down
4 changes: 3 additions & 1 deletion src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/ngraph_utils.hpp"
#include "shape_inference/shape_inference_pass_through.hpp"
#include "common/arbitrary_order_desc_creator.h"

using namespace dnnl;
using namespace InferenceEngine;
Expand Down Expand Up @@ -547,11 +548,12 @@ void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {

// Since this is a very specialized implementation, lets mimic SDPA precision and set cabd layout
precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});

PortConfig outPortConfig;
outPortConfig.inPlace(0);
outPortConfig.constant(false);
outPortConfig.setMemDesc(descCreators.at(LayoutType::cabd)->createSharedDesc(precision, shape));
outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
Expand Down
21 changes: 12 additions & 9 deletions src/plugins/intel_cpu/src/nodes/scaled_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/plain_tensor.hpp"
#include <openvino/op/scaled_dot_product_attention.hpp>
#include "common/arbitrary_order_desc_creator.h"

#ifdef OV_CPU_WITH_MLAS
# include "mlas/sgemm.hpp"
Expand Down Expand Up @@ -786,24 +787,26 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
ov::element::f32, getInputShapeAtPort(nextPortIdx)));
}

if (m_config.config.fuse_concat) {
config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc(
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});

config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc(
config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
}

config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getOutputShapeAtPort(0)));

if (m_config.config.fuse_concat) {
config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc(
config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getOutputShapeAtPort(1)));
config.outConfs[1].inPlace(orginSDPInputNumber + 0);
config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::cabd)->createSharedDesc(
config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getOutputShapeAtPort(2)));
config.outConfs[2].inPlace(orginSDPInputNumber + 1);
}

config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getOutputShapeAtPort(0)));

supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
// may fallback to abcd without inplace
if (m_config.config.fuse_concat) {
Expand Down

0 comments on commit 504b6e9

Please sign in to comment.