
Commit

Optimize MemoryInput/Output for empty shapes
maxnick committed Oct 11, 2024
1 parent 3976549 commit 93b1fae
Showing 2 changed files with 41 additions and 23 deletions.
9 changes: 4 additions & 5 deletions src/plugins/intel_cpu/src/node.cpp
@@ -655,7 +655,7 @@ std::vector<EdgePtr> Node::getChildEdgesAtPort(int inputNum) const {
         if (!edge)
             OPENVINO_THROW("Node ", getName(), " contains dead weak ptr");
         if (edge->getInputNum() == inputNum)
-            res.push_back(edge);
+            res.emplace_back(std::move(edge));
     }
     return res;
 }
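
Note on the node.cpp change above: edge is a shared_ptr freshly obtained from lock(), so moving it into the result vector skips the atomic reference-count increment that the old push_back(edge) copy paid for on every matching edge. A minimal standalone sketch of the same pattern (the Edge type and collect function below are illustrative stand-ins, not the plugin's classes):

    #include <memory>
    #include <utility>
    #include <vector>

    struct Edge {};  // stand-in for the plugin's Edge class
    using EdgePtr = std::shared_ptr<Edge>;

    std::vector<EdgePtr> collect(const std::vector<std::weak_ptr<Edge>>& weak_edges) {
        std::vector<EdgePtr> res;
        for (const auto& w : weak_edges) {
            auto edge = w.lock();  // a temporary shared_ptr that we already own
            if (edge)
                res.emplace_back(std::move(edge));  // transfer ownership, no refcount bump
        }
        return res;
    }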
@@ -793,11 +793,10 @@ void Node::redefineOutputMemory(const std::vector<VectorDims> &newOutputShapes)
 void Node::redefineOutputMemory(const size_t port, const VectorDims& new_output_shape) {
     const auto edges = getChildEdgesAtPort(port);
 
+    static const VectorDims single_element_shape = {1};
+
     // avoid 0D shape incompatible
-    auto new_shape = new_output_shape;
-    if (new_shape.empty()) {
-        new_shape.push_back(1);
-    }
+    const auto& new_shape = new_output_shape.empty() ? single_element_shape : new_output_shape;
 
     const auto& curr_desc = edges[0]->getMemory().getDesc();
     if (curr_desc.getShape().isStatic() && curr_desc.getShape().getStaticDims() == new_shape) {
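
The redefineOutputMemory change replaces a per-call copy of the dims vector with a reference to a function-local static fallback, so scalar (0-D) outputs no longer allocate a temporary vector just to substitute {1}. A hedged sketch of the pattern in isolation (normalize_shape is an illustrative name, not a plugin function):

    #include <cstddef>
    #include <vector>

    using VectorDims = std::vector<size_t>;  // same alias the plugin uses

    const VectorDims& normalize_shape(const VectorDims& requested) {
        static const VectorDims single_element_shape = {1};  // shared 0-D fallback, built once
        return requested.empty() ? single_element_shape : requested;
    }

Returning a reference is safe here because both candidates outlive the caller's use of the result: the static lives for the whole program, and the non-empty argument is owned by the caller.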
55 changes: 37 additions & 18 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
@@ -300,21 +300,27 @@ void MemoryOutput::runStatic(dnnl::stream strm) {
 void MemoryOutput::runDynamic(dnnl::stream strm) {
     //first we have to resize the output memory
     auto inputMem = getSrcMemoryAtPort(0);
-    const auto& newDims = inputMem->getStaticDims();
-    OPENVINO_ASSERT(extMemDesc,
-        "MemoryOutput ",
-        getName(),
-        " uninitialized assigned memory");
-
-    auto newExternDesc = extMemDesc->cloneWithNewDims(newDims);
-
     OPENVINO_ASSERT(assignedMem,
         "MemoryOutput ",
         getName(),
         " uninitialized assigned memory");
-    assignedMem->redefineDesc(newExternDesc);
 
-    runStatic(strm);
+    const auto& newShape = inputMem->getShape();
+    const auto& stateShape = assignedMem->getShape();
+
+    if (stateShape.isDynamic() || stateShape.getStaticDims() != newShape.getStaticDims()) {
+        OPENVINO_ASSERT(extMemDesc,
+            "MemoryOutput ",
+            getName(),
+            " uninitialized assigned memory");
+        auto newExternDesc = extMemDesc->cloneWithNewDims(newShape.getStaticDims());
+        assignedMem->redefineDesc(newExternDesc);
+    }
+
+    if (!newShape.hasZeroDims()) { // no need to copy data for empty tensor
+        runStatic(strm);
+    }
 }
 
 bool MemoryOutputStub::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
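
The rewritten MemoryOutput::runDynamic only touches the external descriptor when the state shape is still dynamic or actually differs from the new input shape, and it skips runStatic entirely when the shape contains a zero dimension, since an empty tensor has no data to copy. A simplified sketch of the shape-change guard under assumed, illustrative types (StateStub and update_state_shape are not the plugin's API):

    #include <cstddef>
    #include <vector>

    using Dims = std::vector<size_t>;

    struct StateStub {
        bool has_static_shape = false;  // starts out dynamic, like a freshly created variable state
        Dims dims;
        void redefine(const Dims& new_dims) {  // stand-in for redefineDesc(cloneWithNewDims(...))
            dims = new_dims;
            has_static_shape = true;
        }
    };

    void update_state_shape(StateStub& state, const Dims& new_dims) {
        // rebuild the descriptor only when the shape actually changed, so repeated
        // inferences with a stable shape avoid the clone on every call
        if (!state.has_static_shape || state.dims != new_dims) {
            state.redefine(new_dims);
        }
    }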
@@ -593,31 +599,44 @@ void MemoryInput::runDynamic(dnnl::stream strm) {
         getName(),
         " assigned state has null memory ptr");
 
-    // check whether we can share memory block
-    const auto& stateDims = assignedMem->getStaticDims();
-    const bool hasZeroDims = std::count(std::begin(stateDims), std::end(stateDims), 0) > 0;
-    auto internDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(stateDims, hasZeroDims);
-
     OPENVINO_ASSERT(memBlock,
         "MemoryInput ",
         getName(),
         " has uninitialized memory block.");
 
+    // check whether we can share memory block
+    const auto& shape = assignedMem->getShape();
+    const bool hasZeroDims = shape.hasZeroDims();
+    const bool processInitGraph = needInitGraphProcessing();
+    const auto& stateDims = shape.getStaticDims();
+
+    if (hasZeroDims && !processInitGraph) {
+        // fast track as we don't really need to share memory and transfer any data for empty tensors
+        memBlock->reset();
+        redefineOutputMemory(0, stateDims);
+        return;
+    }
+
+    auto dst = getDstMemoryAtPort(0);
+    auto currentOutputDesc = dst->getDescPtr();
+
+    auto internDesc = currentOutputDesc->isDefined() && (currentOutputDesc->getShape().getStaticDims() == stateDims)
+        ? currentOutputDesc
+        : getBaseMemDescAtOutputPort(0)->cloneWithNewDims(stateDims, hasZeroDims);
+
     if (internDesc->isCompatible(assignedMem->getDesc())) {
         memBlock->setMemBlock(assignedMem->getMemoryBlock());
     } else {
         memBlock->reset();
     }
 
-    const bool processInitGraph = needInitGraphProcessing();
     //reshape output
     const auto& newDims = processInitGraph ? getSrcMemoryAtPort(0)->getStaticDims() : stateDims;
 
-    redefineOutputMemory({newDims});
+    redefineOutputMemory(0, newDims);
 
     //copy data when necessary
     auto src = processInitGraph ? getSrcMemoryAtPort(0) : assignedMem;
-    auto dst = getDstMemoryAtPort(0);
     if (src->getData() != dst->getData()) {
         dst->load(*src);
     }
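
MemoryInput::runDynamic gains two shortcuts: an early return for empty state tensors (reset the memory block, publish the zero-element shape, and skip both descriptor sharing and the copy), and reuse of the already-defined output descriptor when its static dims match the state, which avoids an extra cloneWithNewDims allocation. A self-contained sketch of the empty-tensor fast track; every name below is a placeholder rather than the plugin's API:

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    using Dims = std::vector<size_t>;

    struct MemoryBlockStub {
        void reset() {}  // stand-in for dropping a previously shared block
    };

    // mirrors the idea behind Shape::hasZeroDims(): a tensor is empty when any dimension is zero
    bool has_zero_dims(const Dims& dims) {
        return std::any_of(dims.begin(), dims.end(), [](size_t d) { return d == 0; });
    }

    // returns true when the caller may return early, i.e. nothing has to be shared or copied
    bool empty_state_fast_track(const Dims& state_dims,
                                bool process_init_graph,
                                MemoryBlockStub& block,
                                Dims& output_dims) {
        if (has_zero_dims(state_dims) && !process_init_graph) {
            block.reset();
            output_dims = state_dims;  // the output becomes an empty tensor of the same shape
            return true;
        }
        return false;  // fall through to the regular share-or-copy path
    }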
@@ -847,6 +866,6 @@ void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) {
     }
 }
 
-} // namespace node
+} // namespace node
 } // namespace intel_cpu
 } // namespace ov
