diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index a09ecf56993233..897167713d9ef6 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -214,10 +214,6 @@ ov::SoPtr VariableStateKVcache::get_state() const { auto&& dims = actual_internal_desc->getShape().getStaticDims(); auto actual_external_desc = get_external_desc()->cloneWithNewDims(dims); - - auto intermed_external_mem = - std::make_shared(get_engine(), actual_external_desc->cloneWithNewPrecision(actual_internal_desc->getPrecision())); - auto external_mem = std::make_shared(get_engine(), actual_external_desc); // let's assume 4th rank KV tensors. This may be extended later @@ -228,9 +224,8 @@ ov::SoPtr VariableStateKVcache::get_state() const { //sanity check OPENVINO_ASSERT(actual_internal_order == m_dense_internal_desc->getOrder()); - //TBD very naive implementation - // 1. map m_internal_mem to the intermed_external_mem (the same precision) - // 2. perform precision conversion from intermed_external_mem to external_mem + // Warning, this implementation is very KV cache specific it assumes that S is always a last dimension and it's not + if (m_hidden_state) { PlainTensor output, pastkv, beam_table; output.reset(external_mem); diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp index 4e3319e675e7da..65645c071841f7 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.cpp +++ b/src/plugins/intel_cpu/src/nodes/memory.cpp @@ -381,30 +381,6 @@ MemoryInputBase::MemoryInputBase(const std::string id, // this is their responsibility to link the input/output nodes properly } -void MemoryInputBase::resolveInPlaceEdges(Edge::LOOK look) { - if (!(look & Edge::LOOK_UP)) { - Node::resolveInPlaceEdges(look); - return; - } - - auto selected_pd = getSelectedPrimitiveDescriptor(); - OPENVINO_ASSERT(selected_pd, - "MemoryInput ", - getName(), - " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); - - auto memDesc = selected_pd->getConfig().outConfs.front().getMemDesc(); - memMngr = std::make_shared(); - - for (auto&& edge : getChildEdgesAtPort(0)) { // always only one child port - OPENVINO_ASSERT(one_of(edge->getStatus(), Edge::Status::Uninitialized, Edge::Status::NotAllocated), - " Unexpected inplace resolve call to an allocated edge: ", edge->name()); - - auto edgeMem = std::make_shared(getEngine(), memDesc, memMngr); - edge->reuse(edgeMem); - } -} - MemoryInputBase::~MemoryInputBase() { if (outputNode) { outputNode->deregisterSibling(this); } MemoryNodeVirtualEdge::remove(this, holder); @@ -415,71 +391,6 @@ MemoryOutputBase& MemoryInputBase::getOutputNode() { return *outputNode; } -void MemoryInputBase::assignState(MemStatePtr newState) { - assignedMem = newState->input_mem(); - - isExecutableFlag = !getParentEdges().empty() && newState->is_reset_state(); - - OPENVINO_ASSERT(assignedMem, - "MemoryInput ", - getName(), - " assigned state has null memory ptr"); - - const auto& newDims = assignedMem->getStaticDims(); - MemoryDescPtr internDesc; - if (isDynamicNode()) { - const bool hasZeroDims = std::count(std::begin(newDims), std::end(newDims), 0) > 0; - internDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(newDims, hasZeroDims); - } else { - auto expectedDims = getBaseMemDescAtOutputPort(0)->getShape().getStaticDims(); - OPENVINO_ASSERT(expectedDims == newDims, - "MemoryInput ", - getName(), - " unexpected state shape: ", - vec2str(newDims), - ", while the expected shape: ", - vec2str(expectedDims)); - - internDesc = getBaseMemDescAtOutputPort(0); - } - - OPENVINO_ASSERT(memMngr, - "MemoryInput ", - getName(), - " has uninitialized memory manager."); - - if (internDesc->isCompatible(assignedMem->getDesc())) { - memMngr->setMemMngr(assignedMem->getMemoryMngr()); - } else { - memMngr->reset(); - } - - if (!isExecutableFlag) { - const auto& edges = getChildEdgesAtPort(0); - if (isDynamicNode()) { - for (auto&& edge : edges) { - edge->getMemoryPtr()->redefineDesc(internDesc); - } - } - - auto outMem = edges.front()->getMemoryPtr(); - - if (outMem->getData() != assignedMem->getData()) { - outMem->load(*assignedMem); - } - } - - getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc()); -} - -bool MemoryInputBase::needShapeInfer() const { - return isExecutableFlag; -} - -bool MemoryInputBase::isExecutable() const { - return isExecutableFlag && Node::isExecutable(); -} - void MemoryInputBase::initSupportedPrimitiveDescriptors() { if (!supportedPrimitiveDescriptors.empty()) return; @@ -511,18 +422,6 @@ void MemoryInputBase::initSupportedPrimitiveDescriptors() { supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); } -void MemoryInputBase::executeDynamicImpl(dnnl::stream strm) { - execute(strm); -} - -void MemoryInputBase::execute(dnnl::stream strm) { - if (!isExecutableFlag) return; - - auto&& src = getParentEdgeAt(0)->getMemory(); - auto&& dst = getChildEdgesAtPort(0).front()->getMemoryPtr(); - dst->load(src); -} - void MemoryInputBase::registerOutputNode(MemoryOutputBase* node) { if (outputNode == node) { return; } if (outputNode) { outputNode->deregisterSibling(this); } @@ -573,6 +472,14 @@ void MemoryNodeVirtualEdge::remove(MemoryNode * node, Holder* holder) { } } +bool MemoryInput::needShapeInfer() const { + return isExecutableFlag; +} + +bool MemoryInput::isExecutable() const { + return isExecutableFlag && Node::isExecutable(); +} + void MemoryInput::initOptimalPrimitiveDescriptor() { // Mimic the child node memory desc to avoid extra reorder static const Type preferredTypes[] = { @@ -629,6 +536,42 @@ void MemoryInput::initOptimalPrimitiveDescriptor() { selectedPd->setConfig(config); } +void MemoryInput::executeDynamicImpl(dnnl::stream strm) { + execute(strm); +} + +void MemoryInput::execute(dnnl::stream strm) { + if (!isExecutableFlag) return; + + auto&& src = getParentEdgeAt(0)->getMemory(); + auto&& dst = getChildEdgesAtPort(0).front()->getMemoryPtr(); + dst->load(src); +} + +void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) { + if (!(look & Edge::LOOK_UP)) { + Node::resolveInPlaceEdges(look); + return; + } + + auto selected_pd = getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(selected_pd, + "MemoryInput ", + getName(), + " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + auto memDesc = selected_pd->getConfig().outConfs.front().getMemDesc(); + memMngr = std::make_shared(); + + for (auto&& edge : getChildEdgesAtPort(0)) { // always only one child port + OPENVINO_ASSERT(one_of(edge->getStatus(), Edge::Status::Uninitialized, Edge::Status::NotAllocated), + " Unexpected inplace resolve call to an allocated edge: ", edge->name()); + + auto edgeMem = std::make_shared(getEngine(), memDesc, memMngr); + edge->reuse(edgeMem); + } +} + MemStatePtr MemoryInput::makeState() const { // assume ov::Tensor is always dense auto original_desc = @@ -651,6 +594,64 @@ MemStatePtr MemoryInput::makeState() const { original_desc); } +void MemoryInput::assignState(MemStatePtr newState) { + assignedMem = newState->input_mem(); + + isExecutableFlag = !getParentEdges().empty() && newState->is_reset_state(); + + OPENVINO_ASSERT(assignedMem, + "MemoryInput ", + getName(), + " assigned state has null memory ptr"); + + const auto& newDims = assignedMem->getStaticDims(); + MemoryDescPtr internDesc; + if (isDynamicNode()) { + const bool hasZeroDims = std::count(std::begin(newDims), std::end(newDims), 0) > 0; + internDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(newDims, hasZeroDims); + } else { + auto expectedDims = getBaseMemDescAtOutputPort(0)->getShape().getStaticDims(); + OPENVINO_ASSERT(expectedDims == newDims, + "MemoryInput ", + getName(), + " unexpected state shape: ", + vec2str(newDims), + ", while the expected shape: ", + vec2str(expectedDims)); + + internDesc = getBaseMemDescAtOutputPort(0); + } + + OPENVINO_ASSERT(memMngr, + "MemoryInput ", + getName(), + " has uninitialized memory manager."); + + if (internDesc->isCompatible(assignedMem->getDesc())) { + memMngr->setMemMngr(assignedMem->getMemoryMngr()); + } else { + memMngr->reset(); + } + + if (!isExecutableFlag) { + const auto& edges = getChildEdgesAtPort(0); + if (isDynamicNode()) { + for (auto&& edge : edges) { + edge->getMemoryPtr()->redefineDesc(internDesc); + } + } + + auto outMem = edges.front()->getMemoryPtr(); + + if (outMem->getData() != assignedMem->getData()) { + outMem->load(*assignedMem); + } + } + + getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc()); +} + + bool MemoryInput::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { return MemoryInputBase::isSupportedOperation(op, errorMessage); } @@ -719,9 +720,15 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() { } void MemoryInputSDPA::assignState(MemStatePtr newState) { - m_needShapeInfer = !getParentEdges().empty() && newState->is_reset_state(); - - if (!m_needShapeInfer) { + if (newState->is_reset_state()) { + if (getParentEdges().empty()) { + auto newShape = MemoryDescUtils::makeDummyShape(getBaseMemDescAtOutputPort(0)->getShape(), 0); + redefineOutputMemory({newShape.getStaticDims()}); + m_needShapeInfer = false; + } else { + m_needShapeInfer = true; + } + } else { auto stateMem = newState->input_mem(); OPENVINO_ASSERT(stateMem, "Internal state mem id: ", @@ -730,6 +737,7 @@ void MemoryInputSDPA::assignState(MemStatePtr newState) { getName()); redefineOutputMemory({stateMem->getStaticDims()}); + m_needShapeInfer = false; } auto sdpaNode = m_sdpaNode.lock(); @@ -768,7 +776,11 @@ MemStatePtr MemoryInputSDPA::makeState() const { } void MemoryInputSDPA::execute(dnnl::stream strm) { - return; + //nothing to do +} + +void MemoryInputSDPA::executeDynamicImpl(dnnl::stream strm) { + //nothing to do } void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) { diff --git a/src/plugins/intel_cpu/src/nodes/memory.hpp b/src/plugins/intel_cpu/src/nodes/memory.hpp index e7997f96e2f6df..7d648e42c632ae 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.hpp +++ b/src/plugins/intel_cpu/src/nodes/memory.hpp @@ -161,18 +161,11 @@ class MemoryInputBase : public Input, public MemoryStateNode { return getType() == Type::MemoryInput; } - bool needShapeInfer() const override; - bool isExecutable() const override; void initSupportedPrimitiveDescriptors() override; - void execute(dnnl::stream strm) override; - void executeDynamicImpl(dnnl::stream strm) override; - - void resolveInPlaceEdges(Edge::LOOK look) override; void registerOutputNode(MemoryOutputBase* node); void deregisterSibling(MemoryOutputBase* node); - void assignState(MemStatePtr newState) override; MemoryOutputBase& getOutputNode(); private: @@ -180,10 +173,7 @@ class MemoryInputBase : public Input, public MemoryStateNode { * @brief keeps reference to output sibling node */ MemoryOutputBase* outputNode = nullptr; - MemoryPtr assignedMem = nullptr; MemoryNodeVirtualEdge::Holder* holder = nullptr; - ProxyMemoryMngrPtr memMngr = nullptr; - bool isExecutableFlag = true; }; class MemoryInput : public MemoryInputBase { @@ -191,9 +181,21 @@ class MemoryInput : public MemoryInputBase { using MemoryInputBase::MemoryInputBase; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + bool needShapeInfer() const override; + bool isExecutable() const override; void initOptimalPrimitiveDescriptor() override; + void execute(dnnl::stream strm) override; + void executeDynamicImpl(dnnl::stream strm) override; + + void resolveInPlaceEdges(Edge::LOOK look) override; + void assignState(MemStatePtr newState) override; MemStatePtr makeState() const override; + +private: + bool isExecutableFlag = true; + ProxyMemoryMngrPtr memMngr = nullptr; + MemoryPtr assignedMem = nullptr; }; class MemoryInputSDPA : public MemoryInputBase { @@ -218,6 +220,7 @@ class MemoryInputSDPA : public MemoryInputBase { void initOptimalPrimitiveDescriptor() override; void execute(dnnl::stream strm) override; + void executeDynamicImpl(dnnl::stream strm) override; void resolveInPlaceEdges(Edge::LOOK look) override; @@ -227,7 +230,7 @@ class MemoryInputSDPA : public MemoryInputBase { private: std::weak_ptr m_sdpaNode; int m_child_port_idx = -1; - bool m_needShapeInfer = false; // TODO refactor MemoryInputBase in order to better separate responsibilities + bool m_needShapeInfer = false; }; } // namespace node } // namespace intel_cpu