Skip to content

Commit

Permalink
Code clean up and refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Dec 14, 2023
1 parent a1d3fcb commit 0c2d7ea
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 123 deletions.
9 changes: 2 additions & 7 deletions src/plugins/intel_cpu/src/memory_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,6 @@ ov::SoPtr<ov::ITensor> VariableStateKVcache::get_state() const {
auto&& dims = actual_internal_desc->getShape().getStaticDims();

auto actual_external_desc = get_external_desc()->cloneWithNewDims(dims);

auto intermed_external_mem =
std::make_shared<Memory>(get_engine(), actual_external_desc->cloneWithNewPrecision(actual_internal_desc->getPrecision()));

auto external_mem = std::make_shared<Memory>(get_engine(), actual_external_desc);

// let's assume 4th rank KV tensors. This may be extended later
Expand All @@ -228,9 +224,8 @@ ov::SoPtr<ov::ITensor> VariableStateKVcache::get_state() const {
//sanity check
OPENVINO_ASSERT(actual_internal_order == m_dense_internal_desc->getOrder());

//TBD very naive implementation
// 1. map m_internal_mem to the intermed_external_mem (the same precision)
// 2. perform precision conversion from intermed_external_mem to external_mem
// Warning, this implementation is very KV cache specific it assumes that S is always a last dimension and it's not

if (m_hidden_state) {
PlainTensor output, pastkv, beam_table;
output.reset(external_mem);
Expand Down
222 changes: 117 additions & 105 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,30 +381,6 @@ MemoryInputBase::MemoryInputBase(const std::string id,
// this is their responsibility to link the input/output nodes properly
}

// Resolves in-place memory for the child edges on output port 0.
// Only requests carrying the Edge::LOOK_UP bit are handled here; any other
// request is delegated unchanged to Node::resolveInPlaceEdges().
void MemoryInputBase::resolveInPlaceEdges(Edge::LOOK look) {
    if (look & Edge::LOOK_UP) {
        auto primDesc = getSelectedPrimitiveDescriptor();
        OPENVINO_ASSERT(primDesc,
                        "MemoryInput ",
                        getName(),
                        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

        // Descriptor of the first output of the selected configuration.
        auto outputDesc = primDesc->getConfig().outConfs.front().getMemDesc();

        // A new proxy manager is created on every call; each memory object built
        // below is constructed with this same manager instance.
        memMngr = std::make_shared<ProxyMemoryMngr>();

        for (auto&& childEdge : getChildEdgesAtPort(0)) {  // always only one child port
            OPENVINO_ASSERT(one_of(childEdge->getStatus(), Edge::Status::Uninitialized, Edge::Status::NotAllocated),
                            " Unexpected inplace resolve call to an allocated edge: ",
                            childEdge->name());

            auto proxiedMem = std::make_shared<Memory>(getEngine(), outputDesc, memMngr);
            childEdge->reuse(proxiedMem);
        }
    } else {
        Node::resolveInPlaceEdges(look);
    }
}

MemoryInputBase::~MemoryInputBase() {
if (outputNode) { outputNode->deregisterSibling(this); }
MemoryNodeVirtualEdge::remove(this, holder);
Expand All @@ -415,71 +391,6 @@ MemoryOutputBase& MemoryInputBase::getOutputNode() {
return *outputNode;
}

// Attaches a variable state to this node.
//
// Steps, in order:
//  1. remember the state's input memory and recompute isExecutableFlag
//     (true only when the node has parent edges AND the state is a reset state);
//  2. build the internal descriptor: for dynamic nodes the base output descriptor
//     is cloned with the state's dims, for static nodes the dims must match exactly;
//  3. point the proxy memory manager at the state's memory manager when the
//     descriptors are compatible, otherwise reset the proxy;
//  4. when the node is not executable, redefine child-edge descriptors (dynamic
//     nodes only) and load the state data into the first child edge memory unless
//     both already report the same data pointer;
//  5. hand the state's output memory and internal descriptor to the output node.
void MemoryInputBase::assignState(MemStatePtr newState) {
    assignedMem = newState->input_mem();

    isExecutableFlag = !getParentEdges().empty() && newState->is_reset_state();

    OPENVINO_ASSERT(assignedMem,
                    "MemoryInput ",
                    getName(),
                    " assigned state has null memory ptr");

    const auto& stateDims = assignedMem->getStaticDims();
    MemoryDescPtr internalDesc;
    if (!isDynamicNode()) {
        // Static node: the state shape has to be exactly the shape of the base descriptor.
        auto baseDims = getBaseMemDescAtOutputPort(0)->getShape().getStaticDims();
        OPENVINO_ASSERT(baseDims == stateDims,
                        "MemoryInput ",
                        getName(),
                        " unexpected state shape: ",
                        vec2str(stateDims),
                        ", while the expected shape: ",
                        vec2str(baseDims));

        internalDesc = getBaseMemDescAtOutputPort(0);
    } else {
        // Dynamic node: adopt the state's dims; the flag tells the clone whether any dim is zero.
        const bool containsZeroDim = std::count(std::begin(stateDims), std::end(stateDims), 0) != 0;
        internalDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(stateDims, containsZeroDim);
    }

    OPENVINO_ASSERT(memMngr,
                    "MemoryInput ",
                    getName(),
                    " has uninitialized memory manager.");

    const bool descsCompatible = internalDesc->isCompatible(assignedMem->getDesc());
    if (!descsCompatible) {
        memMngr->reset();
    } else {
        memMngr->setMemMngr(assignedMem->getMemoryMngr());
    }

    if (!isExecutableFlag) {
        const auto& childEdges = getChildEdgesAtPort(0);
        if (isDynamicNode()) {
            for (auto&& childEdge : childEdges) {
                childEdge->getMemoryPtr()->redefineDesc(internalDesc);
            }
        }

        auto firstOutMem = childEdges.front()->getMemoryPtr();

        // Copy only when the output memory does not already point at the state's data.
        const bool sameBuffer = firstOutMem->getData() == assignedMem->getData();
        if (!sameBuffer) {
            firstOutMem->load(*assignedMem);
        }
    }

    getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc());
}

// Shape inference is requested exactly when the node is currently flagged as
// executable (see isExecutableFlag).
bool MemoryInputBase::needShapeInfer() const {
    const bool shapeInferRequired = isExecutableFlag;
    return shapeInferRequired;
}

// The node is executable only when isExecutableFlag is set and the generic
// Node::isExecutable() check passes as well. The generic check is not
// evaluated when the flag is cleared.
bool MemoryInputBase::isExecutable() const {
    if (!isExecutableFlag) {
        return false;
    }
    return Node::isExecutable();
}

void MemoryInputBase::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
Expand Down Expand Up @@ -511,18 +422,6 @@ void MemoryInputBase::initSupportedPrimitiveDescriptors() {
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}

// Dynamic-shape entry point: simply forwards to execute() with the same stream.
void MemoryInputBase::executeDynamicImpl(dnnl::stream strm) {
    this->execute(strm);
}

// When the node is flagged executable, loads the memory of parent edge 0 into
// the memory of the first child edge on output port 0. Otherwise does nothing.
void MemoryInputBase::execute(dnnl::stream strm) {
    if (isExecutableFlag) {
        auto&& inputMem = getParentEdgeAt(0)->getMemory();
        auto&& outputMem = getChildEdgesAtPort(0).front()->getMemoryPtr();
        outputMem->load(inputMem);
    }
}

void MemoryInputBase::registerOutputNode(MemoryOutputBase* node) {
if (outputNode == node) { return; }
if (outputNode) { outputNode->deregisterSibling(this); }
Expand Down Expand Up @@ -573,6 +472,14 @@ void MemoryNodeVirtualEdge::remove(MemoryNode * node, Holder* holder) {
}
}

// Reports whether shape inference has to run; this mirrors isExecutableFlag.
bool MemoryInput::needShapeInfer() const {
    const bool inferRequired = isExecutableFlag;
    return inferRequired;
}

// Executable only if isExecutableFlag is set and Node::isExecutable() agrees.
// Node::isExecutable() is consulted only when the flag is set.
bool MemoryInput::isExecutable() const {
    if (!isExecutableFlag) {
        return false;
    }
    return Node::isExecutable();
}

void MemoryInput::initOptimalPrimitiveDescriptor() {
// Mimic the child node memory desc to avoid extra reorder
static const Type preferredTypes[] = {
Expand Down Expand Up @@ -629,6 +536,42 @@ void MemoryInput::initOptimalPrimitiveDescriptor() {
selectedPd->setConfig(config);
}

// Dynamic-shape execution shares the implementation of execute().
void MemoryInput::executeDynamicImpl(dnnl::stream strm) {
    this->execute(strm);
}

// Loads the memory of parent edge 0 into the memory of the first child edge on
// output port 0. Nothing happens when isExecutableFlag is cleared.
void MemoryInput::execute(dnnl::stream strm) {
    if (isExecutableFlag) {
        auto&& sourceMem = getParentEdgeAt(0)->getMemory();
        auto&& targetMem = getChildEdgesAtPort(0).front()->getMemoryPtr();
        targetMem->load(sourceMem);
    }
}

// Handles in-place edge resolution for the LOOK_UP direction: every child edge
// on output port 0 receives a Memory object built from the selected output
// descriptor and a newly created ProxyMemoryMngr. Requests without the LOOK_UP
// bit fall through to Node::resolveInPlaceEdges().
void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) {
    if (look & Edge::LOOK_UP) {
        auto selectedDesc = getSelectedPrimitiveDescriptor();
        OPENVINO_ASSERT(selectedDesc,
                        "MemoryInput ",
                        getName(),
                        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

        auto outDesc = selectedDesc->getConfig().outConfs.front().getMemDesc();

        // Replaces any previously held proxy manager with a new instance.
        memMngr = std::make_shared<ProxyMemoryMngr>();

        for (auto&& outEdge : getChildEdgesAtPort(0)) {  // always only one child port
            OPENVINO_ASSERT(one_of(outEdge->getStatus(), Edge::Status::Uninitialized, Edge::Status::NotAllocated),
                            " Unexpected inplace resolve call to an allocated edge: ",
                            outEdge->name());

            auto sharedMem = std::make_shared<Memory>(getEngine(), outDesc, memMngr);
            outEdge->reuse(sharedMem);
        }
    } else {
        Node::resolveInPlaceEdges(look);
    }
}

MemStatePtr MemoryInput::makeState() const {
// assume ov::Tensor is always dense
auto original_desc =
Expand All @@ -651,6 +594,64 @@ MemStatePtr MemoryInput::makeState() const {
original_desc);
}

// Binds a variable state to this MemoryInput node.
//
// The state's input memory is stored, isExecutableFlag is recomputed (set only if
// the node has parent edges and the state is a reset state), the internal memory
// descriptor is derived, the proxy memory manager is either pointed at the state's
// memory manager or reset, and finally the output node receives the state's
// output memory together with its internal descriptor.
void MemoryInput::assignState(MemStatePtr newState) {
    assignedMem = newState->input_mem();

    isExecutableFlag = !getParentEdges().empty() && newState->is_reset_state();

    OPENVINO_ASSERT(assignedMem,
                    "MemoryInput ",
                    getName(),
                    " assigned state has null memory ptr");

    const auto& incomingDims = assignedMem->getStaticDims();
    MemoryDescPtr nodeDesc;
    if (!isDynamicNode()) {
        // Static shapes: the incoming state must match the base output descriptor exactly.
        auto requiredDims = getBaseMemDescAtOutputPort(0)->getShape().getStaticDims();
        OPENVINO_ASSERT(requiredDims == incomingDims,
                        "MemoryInput ",
                        getName(),
                        " unexpected state shape: ",
                        vec2str(incomingDims),
                        ", while the expected shape: ",
                        vec2str(requiredDims));

        nodeDesc = getBaseMemDescAtOutputPort(0);
    } else {
        // Dynamic shapes: clone the base descriptor with the incoming dims and tell
        // the clone whether at least one of those dims is zero.
        const bool anyZeroDim = std::count(std::begin(incomingDims), std::end(incomingDims), 0) != 0;
        nodeDesc = getBaseMemDescAtOutputPort(0)->cloneWithNewDims(incomingDims, anyZeroDim);
    }

    OPENVINO_ASSERT(memMngr,
                    "MemoryInput ",
                    getName(),
                    " has uninitialized memory manager.");

    const bool canShareManager = nodeDesc->isCompatible(assignedMem->getDesc());
    if (!canShareManager) {
        memMngr->reset();
    } else {
        memMngr->setMemMngr(assignedMem->getMemoryMngr());
    }

    if (!isExecutableFlag) {
        const auto& outEdges = getChildEdgesAtPort(0);
        if (isDynamicNode()) {
            for (auto&& outEdge : outEdges) {
                outEdge->getMemoryPtr()->redefineDesc(nodeDesc);
            }
        }

        auto outputMem = outEdges.front()->getMemoryPtr();

        // Skip the copy when the output already reports the state's data pointer.
        const bool alreadyAliased = outputMem->getData() == assignedMem->getData();
        if (!alreadyAliased) {
            outputMem->load(*assignedMem);
        }
    }

    getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc());
}


// Support check for MemoryInput: fully delegated to
// MemoryInputBase::isSupportedOperation(), which receives the same operation
// and the same errorMessage reference.
bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    const bool supportedByBase = MemoryInputBase::isSupportedOperation(op, errorMessage);
    return supportedByBase;
}
Expand Down Expand Up @@ -719,9 +720,15 @@ void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
}

void MemoryInputSDPA::assignState(MemStatePtr newState) {
m_needShapeInfer = !getParentEdges().empty() && newState->is_reset_state();

if (!m_needShapeInfer) {
if (newState->is_reset_state()) {
if (getParentEdges().empty()) {
auto newShape = MemoryDescUtils::makeDummyShape(getBaseMemDescAtOutputPort(0)->getShape(), 0);
redefineOutputMemory({newShape.getStaticDims()});
m_needShapeInfer = false;
} else {
m_needShapeInfer = true;
}
} else {
auto stateMem = newState->input_mem();
OPENVINO_ASSERT(stateMem,
"Internal state mem id: ",
Expand All @@ -730,6 +737,7 @@ void MemoryInputSDPA::assignState(MemStatePtr newState) {
getName());

redefineOutputMemory({stateMem->getStaticDims()});
m_needShapeInfer = false;
}

auto sdpaNode = m_sdpaNode.lock();
Expand Down Expand Up @@ -768,7 +776,11 @@ MemStatePtr MemoryInputSDPA::makeState() const {
}

// Intentionally empty: this node performs no work in execute().
void MemoryInputSDPA::execute(dnnl::stream strm) {
    // nothing to do
}

// Intentionally empty: the dynamic-shape execution path performs no work either.
void MemoryInputSDPA::executeDynamicImpl(dnnl::stream strm) {
    // nothing to do
}

void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) {
Expand Down
25 changes: 14 additions & 11 deletions src/plugins/intel_cpu/src/nodes/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,39 +161,41 @@ class MemoryInputBase : public Input, public MemoryStateNode {
return getType() == Type::MemoryInput;
}

bool needShapeInfer() const override;
bool isExecutable() const override;
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
void executeDynamicImpl(dnnl::stream strm) override;

void resolveInPlaceEdges(Edge::LOOK look) override;

void registerOutputNode(MemoryOutputBase* node);
void deregisterSibling(MemoryOutputBase* node);

void assignState(MemStatePtr newState) override;
MemoryOutputBase& getOutputNode();

private:
/**
* @brief keeps reference to output sibling node
*/
MemoryOutputBase* outputNode = nullptr;
MemoryPtr assignedMem = nullptr;
MemoryNodeVirtualEdge::Holder* holder = nullptr;
ProxyMemoryMngrPtr memMngr = nullptr;
bool isExecutableFlag = true;
};

// Memory input node built on top of MemoryInputBase; the base-class constructors
// are inherited as-is.
class MemoryInput : public MemoryInputBase {
public:
using MemoryInputBase::MemoryInputBase;
// Delegates the support check to MemoryInputBase::isSupportedOperation().
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

// Returns isExecutableFlag.
bool needShapeInfer() const override;
// Returns true only when isExecutableFlag is set and Node::isExecutable() is true.
bool isExecutable() const override;
void initOptimalPrimitiveDescriptor() override;
// When isExecutableFlag is set, loads the memory of parent edge 0 into the memory
// of the first child edge on output port 0; otherwise does nothing.
void execute(dnnl::stream strm) override;
// Forwards to execute().
void executeDynamicImpl(dnnl::stream strm) override;

// For LOOK_UP requests creates a new ProxyMemoryMngr and re-binds the child edges
// of output port 0 to memory objects that use it; other requests go to Node.
void resolveInPlaceEdges(Edge::LOOK look) override;

// Stores the state's input memory, recomputes isExecutableFlag, updates the proxy
// memory manager and passes the state's output memory to the output node.
void assignState(MemStatePtr newState) override;
MemStatePtr makeState() const override;

private:
// Read by needShapeInfer(), isExecutable() and execute(); rewritten on every assignState().
bool isExecutableFlag = true;
// Created in resolveInPlaceEdges(); assignState() asserts that it is non-null.
ProxyMemoryMngrPtr memMngr = nullptr;
// Input memory of the most recently assigned state (set in assignState()).
MemoryPtr assignedMem = nullptr;
};

class MemoryInputSDPA : public MemoryInputBase {
Expand All @@ -218,6 +220,7 @@ class MemoryInputSDPA : public MemoryInputBase {
void initOptimalPrimitiveDescriptor() override;

void execute(dnnl::stream strm) override;
void executeDynamicImpl(dnnl::stream strm) override;

void resolveInPlaceEdges(Edge::LOOK look) override;

Expand All @@ -227,7 +230,7 @@ class MemoryInputSDPA : public MemoryInputBase {
private:
std::weak_ptr<ScaledDotProductAttention> m_sdpaNode;
int m_child_port_idx = -1;
bool m_needShapeInfer = false; // TODO refactor MemoryInputBase in order to better separate responsibilities
bool m_needShapeInfer = false;
};
} // namespace node
} // namespace intel_cpu
Expand Down

0 comments on commit 0c2d7ea

Please sign in to comment.