Skip to content

Commit

Permalink
Add memory stub and permutation processing
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Dec 6, 2023
1 parent badff01 commit bd4955c
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 24 deletions.
61 changes: 44 additions & 17 deletions src/plugins/intel_cpu/src/memory_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,39 @@ void VariableStateSingleBuffer::commit_impl() {
}


VariableStateKVcache::VariableStateKVcache(const std::string& name, const MemoryDescPtr& external_desc) :
VariableStateBase(name, external_desc) {
VariableStateKVcache::VariableStateKVcache(
const std::string& name,
const MemoryDescPtr& external_desc,
const BlockedMemoryDescPtr& dense_internal_desc) :
VariableStateBase(name, external_desc), m_dense_internal_desc(dense_internal_desc) {
auto&& shape = external_desc->getShape();

OPENVINO_ASSERT(shape.isDynamic(), "VariableStateKVcache is unexpectedly initalized with a static tensor");
}

ov::SoPtr<ov::ITensor> VariableStateKVcache::get_state() const {
//TBD
return {};
auto actual_internal_desc = m_internal_mem->getDescWithType<BlockedMemoryDesc>();
auto&& dims = actual_internal_desc->getShape().getStaticDims();

auto actual_external_desc = get_external_desc()->cloneWithNewDims(dims);

auto intermed_external_mem =
std::make_shared<Memory>(get_engine(), actual_external_desc->cloneWithNewPrecision(actual_internal_desc->getPrecision()));

auto external_mem = std::make_shared<Memory>(get_engine(), actual_external_desc);

// let's assume 4th rank KV tensors. This may be extended later
OPENVINO_ASSERT(actual_internal_desc->getShape().getRank() == 4);
OPENVINO_ASSERT(actual_external_desc->getShape().getRank() == 4);

auto&& actual_internal_order = actual_internal_desc->getOrder();
//sanity check
OPENVINO_ASSERT(actual_internal_order == m_dense_internal_desc->getOrder());

//TBD very naive implementation
// 1. map m_internal_mem to the intermed_external_mem (the same precision)
// 2. perform precision conversion from intermed_external_mem to external_mem
return std::make_shared<Tensor>(external_mem);
}

void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) {
Expand All @@ -214,17 +237,19 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) {
auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state);

//May be optimized by reusing the state tensor underlining memory pointer, but corner cases should be considered
m_internal_mem = std::make_shared<Memory>(get_engine(), state_desc);
auto src = m_state->data();
auto dst = m_internal_mem->getData();
if (src && dst) {
std::memcpy(dst, src, m_state->get_byte_size());
}
auto dense_internal_desc = m_dense_internal_desc->cloneWithNewDims(state_desc->getShape().getStaticDims());

m_internal_mem = std::make_shared<Memory>(get_engine(), dense_internal_desc);
Memory external_mem(get_engine(), state_desc, m_state->data());

m_internal_mem->load(external_mem);

//2. Reset the beam search table
auto&& stateDims = state_desc->getShape().getStaticDims();
const size_t size_B = stateDims[axis_B];
const size_t size_L = stateDims[axis_L];
auto&& state_dims = dense_internal_desc->getShape().getStaticDims();
auto&& order = m_dense_internal_desc->getOrder();

const size_t size_B = state_dims[order.at(0)];
const size_t size_L = state_dims[order.at(2)];
auto mem_desc =
std::make_shared<CpuBlockedMemoryDesc>(ov::element::i32, Shape{size_B, size_L});

Expand All @@ -239,14 +264,16 @@ void VariableStateKVcache::set_state_impl(const ov::SoPtr<ov::ITensor>& state) {

void VariableStateKVcache::reset_impl() {
// 1. reset internal state
auto internal_state_desc = to_static(get_external_desc());
auto internal_state_desc = to_static(m_dense_internal_desc);
m_internal_mem = std::make_shared<Memory>(get_engine(), internal_state_desc);
m_internal_mem->nullify();

// 2. reset hidden state
auto&& stateDims = internal_state_desc->getShape().getStaticDims();
const size_t size_B = stateDims[axis_B];
const size_t size_L = stateDims[axis_L];
auto&& state_dims = internal_state_desc->getShape().getStaticDims();
auto&& order = m_dense_internal_desc->getOrder();

const size_t size_B = state_dims[order.at(0)];
const size_t size_L = state_dims[order.at(2)];
auto hidden_state_desc =
std::make_shared<CpuBlockedMemoryDesc>(ov::element::i32, Shape{size_B, size_L});

Expand Down
9 changes: 5 additions & 4 deletions src/plugins/intel_cpu/src/memory_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,9 @@ class VariableStateSingleBuffer : public VariableStateBase {

class VariableStateKVcache : public VariableStateBase {
public:
VariableStateKVcache(const std::string& name, const MemoryDescPtr& external_desc);
VariableStateKVcache(const std::string& name,
const MemoryDescPtr& external_desc,
const BlockedMemoryDescPtr& dense_internal_desc);

//ov::IVariableState
ov::SoPtr<ov::ITensor> get_state() const override;
Expand All @@ -148,9 +150,8 @@ class VariableStateKVcache : public VariableStateBase {
MemoryPtr m_internal_mem; // kv cache
MemoryPtr m_hidden_state; // beam access table

//TODO: how is it better to pass these values to the state object?
size_t axis_B = 0;
size_t axis_L = 2;
// this desc stores the internal prc and axis permutation
BlockedMemoryDescPtr m_dense_internal_desc;
};

using MemStatePtr = std::shared_ptr<IVariableState>;
Expand Down
82 changes: 79 additions & 3 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,65 @@ namespace ov {
namespace intel_cpu {
namespace node {

namespace {
// Non-allocating placeholder implementation of IMemory. It carries only a
// memory descriptor (i.e. shape/layout metadata) so that graph edges can be
// resolved before any real buffer exists — see the resolveInPlaceEdges()
// callers below, which use it "simply to transfer the shape". Any attempt to
// touch actual data (getData/load/getPrimitive/getMemoryMngr) throws, turning
// accidental data access through the stub into an explicit error.
class MemoryStub : public IMemory {
public:
    MemoryStub(const dnnl::engine& eng, const MemoryDescPtr& pMemDesc) : m_eng(eng), m_pMemDesc(pMemDesc) {}

    // Always reports "allocated" so edge/graph bookkeeping accepts the stub
    // as a ready memory object even though no buffer is backing it.
    bool isAllocated() const noexcept override {
        return true;
    }

    const MemoryDesc& getDesc() const override {
        return *m_pMemDesc;
    }

    MemoryDescPtr getDescPtr() const override {
        return m_pMemDesc;
    }

    // No data pointer exists — dereferencing a stub is a programming error.
    void* getData() const override {
        OPENVINO_THROW("Unexpected call MemoryStub::getData()");
    }

    // Zero byte size: the stub owns no storage.
    size_t getSize() const override {
        return 0;
    }

    const Shape& getShape() const override {
        return m_pMemDesc->getShape();
    }

    const VectorDims& getStaticDims() const override {
        return m_pMemDesc->getShape().getStaticDims();
    }

    // Shape propagation is the stub's whole purpose, so redefining the
    // descriptor is allowed (unlike data access).
    void redefineDesc(MemoryDescPtr desc) override {
        m_pMemDesc = desc;
    }

    // Copying data into a stub is meaningless — fail loudly.
    void load(const IMemory& src, bool ftz = true) const override {
        OPENVINO_THROW("Unexpected call MemoryStub::load()");
    }

    MemoryMngrPtr getMemoryMngr() const override {
        OPENVINO_THROW("Unexpected call MemoryStub::getMemoryMngr()");
    }

    dnnl::memory getPrimitive() const override {
        OPENVINO_THROW("Unexpected call MemoryStub::getPrimitive()");
    }

    // Nothing to zero out — intentionally a no-op.
    void nullify() override {
        // nothing to do
    }

private:
    // Engine is stored to mirror a real Memory object's construction; it is
    // not referenced by any stub method in this class.
    dnnl::engine m_eng;
    MemoryDescPtr m_pMemDesc;
};
} // namespace

std::mutex MemoryNodeVirtualEdge::holderMutex;

MemoryNode::MemoryNode(const std::shared_ptr<ov::Node>& op) {
Expand Down Expand Up @@ -260,7 +319,8 @@ void MemoryOutputStub::resolveInPlaceEdges(Edge::LOOK look) {

auto memDesc = selected_pd->getConfig().inConfs.front().getMemDesc();
// make a fake memory
//parentEdge->reuse(edgeMem);
auto edgeMem = std::make_shared<MemoryStub>(getEngine(), memDesc);
parentEdge->reuse(edgeMem);
}

void MemoryOutputStub::assignExtMemory(const MemoryPtr& mem, const MemoryDescPtr& memDesc) {
Expand Down Expand Up @@ -639,7 +699,16 @@ MemStatePtr MemoryInputSDPA::makeState() const {
state_name = state_name.substr(0, suffix_idx);
}

return std::make_shared<VariableStateKVcache>(state_name, original_desc);
// somehow retrieve the internal precision and axis order from the SDPA node
// m_sdpa->get_kv_precision();
// m_sdpa->get_kv_axis_order();

auto kv_precision = element::bf16;
VectorDims order = {0, 1, 2, 3};

auto internal_desc = ArbitraryOrderDescCreator(order).createSharedDesc(kv_precision, outputShapes.at(0));

return std::make_shared<VariableStateKVcache>(state_name, original_desc, internal_desc);
}

bool MemoryInputSDPA::isExecutable() const {
Expand All @@ -656,7 +725,14 @@ void MemoryInputSDPA::resolveInPlaceEdges(Edge::LOOK look) {
if (getParentEdgeAt(0)) {
Node::resolveInPlaceEdges(look);
} else {
// place some kind of fake memory simply to transfer the shape
auto memDesc = getBaseMemDescAtOutputPort(0);
for (auto&& edge : getChildEdgesAtPort(0)) { // always only one child port
OPENVINO_ASSERT(one_of(edge->getStatus(), Edge::Status::Uninitialized, Edge::Status::NotAllocated),
" Unexpected inplace resolve call to an allocated edge: ", edge->name());

auto edgeMem = std::make_shared<MemoryStub>(getEngine(), memDesc);
edge->reuse(edgeMem);
}
}
}

Expand Down

0 comments on commit bd4955c

Please sign in to comment.