Skip to content

Commit

Permalink
init sdpa code for beam search
Browse files Browse the repository at this point in the history
  • Loading branch information
luo-cheng2021 committed Dec 7, 2023
1 parent 96f017d commit d01830f
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 121 deletions.
10 changes: 10 additions & 0 deletions src/plugins/intel_cpu/src/memory_state.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,5 +313,15 @@ void VariableStateKVcache::assign_hidden_state(const MemoryPtr& mem) {
m_hidden_state = mem;
}

// Placeholder for initializing the KV-cache state from an externally provided memory object.
// Not implemented yet: the assertion below always fails, so every call raises an
// OpenVINO assertion error carrying the message "Implement VariableStateKVcache::init_from".
//
// init_val: memory holding the initial value of the state. MemoryInputSDPA::execute passes
//           the memory of its parent edge 0 here when the state is in reset and an init
//           subgraph is connected.
void VariableStateKVcache::init_from(const MemoryPtr& init_val) {
// TODO: implement the actual initialization of the state from init_val
OPENVINO_ASSERT(false, "Implement VariableStateKVcache::init_from");
}

// Placeholder for the beam-search update of the stored past key/value data.
// Not implemented yet: the assertion below always fails, so every call raises an
// OpenVINO assertion error carrying the message
// "Implement VariableStateKVcache::gather_concat_pastkv".
//
// cur_kv:   presumably the key/value data produced for the current step — TODO confirm.
// beam_idx: presumably the beam indices used to gather the past key/value entries
//           before cur_kv is concatenated — TODO confirm once implemented.
void VariableStateKVcache::gather_concat_pastkv(MemoryPtr cur_kv, MemoryPtr beam_idx) {
// TODO: implement gather-by-beam-index followed by concatenation
OPENVINO_ASSERT(false, "Implement VariableStateKVcache::gather_concat_pastkv");
}

} // namespace intel_cpu
} // namespace ov
3 changes: 3 additions & 0 deletions src/plugins/intel_cpu/src/memory_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,9 @@ class VariableStateKVcache : public VariableStateBase {
MemoryPtr hidden_state_mem() const;
void assign_hidden_state(const MemoryPtr& mem);

void init_from(const MemoryPtr& init_val);
void gather_concat_pastkv(MemoryPtr cur_kv, MemoryPtr beam_idx);

private:
//ov::intel_cpu::VariableStateBase
void set_state_impl(const ov::SoPtr<ov::ITensor>& state) override;
Expand Down
96 changes: 79 additions & 17 deletions src/plugins/intel_cpu/src/nodes/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/ngraph_utils.hpp"
#include "shape_inference/shape_inference_pass_through.hpp"
#include "shape_inference/shape_inference_internal_dyn.hpp"
#include "common/arbitrary_order_desc_creator.h"

using namespace dnnl;
Expand Down Expand Up @@ -668,20 +669,73 @@ MemoryInputSDPA::MemoryInputSDPA(const std::string id,
const ov::optional<Shape>& input_shape,
const ov::optional<ov::element::Type>& input_prc,
const std::shared_ptr<ScaledDotProductAttention>& sdpaNode) :
MemoryInputBase(id, name, type, output_shape, output_prc, context, input_shape, input_prc), m_sdpaNode(sdpaNode) {}
MemoryInputBase(id, name, type, output_shape, output_prc, context, input_shape, input_prc), m_sdpaNode(sdpaNode) {
if (isDynamic) {
// 2 scenarios:
// 1, after reset(first token)
// a, if there is init-subgraph, the shape should be got from init subgraph
// b, if there is no init-subgraph, the shape will be computed from state
// 2, second token: the shape will be computed from state
// Since the source of the shape depends on the condition above, use InternalDynShapeInferFactory
// to obtain the shape dynamically.
shapeInference = InternalDynShapeInferFactory().makeShapeInfer();
}
}

void MemoryInputSDPA::createPrimitive() {
MemoryInputBase::createPrimitive();
// determine the output node idx
// child_port_idx =
// determine the output node idx
auto memDesc = getBaseMemDescAtOutputPort(0);
auto sdpaNode = m_sdpaNode.lock();
for (auto&& edge : getChildEdgesAtPort(0)) { // always only one child port
auto node = edge->getChild();
if (node == sdpaNode) {
child_port_idx = edge->getOutputNum();
break;
}
}
OPENVINO_ASSERT(child_port_idx != -1, getName(), " should connect to SDPA node.");
}

void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;

auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
NodeConfig config;
if (!getParentEdges().empty()) {
PortConfig inPortConfig;
inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.inConfs.push_back(std::move(inPortConfig));
}

auto node = m_sdpaNode.lock();
// retrieve the internal precision and axis order from the SDPA node
auto kv_precision = node->getKVCachePrecision();
VectorDims order = {0, 1, 2, 3};
if (!node->getKVCacheOrder().empty())
order = node->getKVCacheOrder();
ArbitraryOrderDescCreator cabdDescCreator(order);

PortConfig outPortConfig;
// output edge will be a fake memory obj, real memory is stored in state
outPortConfig.inPlace(-1);
outPortConfig.constant(false);
outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(kv_precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}

// Binds a new variable state to this node and forwards it to the SDPA node.
//
// newState must be a VariableStateKVcache; it is stored in m_sdpaState and handed to
// the SDPA node together with child_port_idx. Assertions fail when the SDPA node is
// no longer alive or when newState is not a VariableStateKVcache.
//
// The state is cast and forwarded exactly once; the previous text performed the cast
// twice and invoked sdpaNode->assignState twice.
void MemoryInputSDPA::assignState(MemStatePtr newState) {
    auto sdpaNode = m_sdpaNode.lock();
    OPENVINO_ASSERT(sdpaNode);

    // Cast into a local first so m_sdpaState keeps its previous value if the cast fails.
    auto kvState = std::dynamic_pointer_cast<VariableStateKVcache>(newState);
    OPENVINO_ASSERT(kvState);
    m_sdpaState = std::move(kvState);

    sdpaNode->assignState(m_sdpaState, child_port_idx);
}

MemStatePtr MemoryInputSDPA::makeState() const {
Expand All @@ -699,25 +753,33 @@ MemStatePtr MemoryInputSDPA::makeState() const {
state_name = state_name.substr(0, suffix_idx);
}

// somehow retrieve the internal precision and axis order from the SDPA node
// m_sdpa->get_kv_precision();
// m_sdpa->get_kv_axis_order();

auto kv_precision = element::bf16;
auto node = m_sdpaNode.lock();
// retrieve the internal precision and axis order from the SDPA node
auto kv_precision = node->getKVCachePrecision();
VectorDims order = {0, 1, 2, 3};
if (!node->getKVCacheOrder().empty())
order = node->getKVCacheOrder();

auto internal_desc = ArbitraryOrderDescCreator(order).createSharedDesc(kv_precision, outputShapes.at(0));

return std::make_shared<VariableStateKVcache>(state_name, original_desc, internal_desc);
}

// Reports whether this node has to be executed; always returns false.
// NOTE(review): MemoryInputSDPA::execute in this file initializes the state from the init
// subgraph and redefines the output memory — confirm that returning false here is still
// intended, since it presumably prevents execute() from being scheduled.
bool MemoryInputSDPA::isExecutable() const {
// this node is mostly a proxy to transfer the state to the SDPA
// so the SDPA itself should handle the reset state as it handles the memory manipulation
return false;
}

// Prepares the KV-cache state for the current inference and publishes its shape.
//
// When the state is in reset and an init subgraph is connected (a parent edge exists),
// the state is initialized from the memory of parent edge 0. Afterwards the output
// memory at port 0 is redefined with the static dims of the state's internal descriptor.
//
// strm: not used by this implementation.
// Fails an assertion when no state has been assigned via assignState yet.
void MemoryInputSDPA::execute(dnnl::stream strm) {
    // m_sdpaState is only set in assignState; guard against use before assignment
    // instead of dereferencing a null pointer.
    OPENVINO_ASSERT(m_sdpaState, getName(), " has no KV-cache state assigned.");

    // A parent edge means there is an init subgraph providing the initial value.
    if (m_sdpaState->is_reset_state() && !getParentEdges().empty()) {
        auto input = getParentEdgeAt(0)->getMemoryPtr();
        m_sdpaState->init_from(input);
    }

    // 1. If in reset:
    //    a. with an init subgraph, the state shape is defined after init_from;
    //    b. without an init subgraph, the state shape is defined in VariableStateKVcache::reset.
    // 2. If not in reset, the state is updated in the sdpa::infer call.
    // Propagate the shape to the fake memory object so that ShapeOf can get the correct shape.
    this->redefineOutputMemory(0, m_sdpaState->internal_desc()->getShape().getStaticDims());
}

Expand Down
6 changes: 4 additions & 2 deletions src/plugins/intel_cpu/src/nodes/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,20 @@ class MemoryInputSDPA : public MemoryInputBase {
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;

void createPrimitive() override;
void initSupportedPrimitiveDescriptors() override;

bool isExecutable() const override;
void execute(dnnl::stream strm) override;

void resolveInPlaceEdges(Edge::LOOK look) override;

void assignState(MemStatePtr newState) override;
MemStatePtr makeState() const override;
bool needShapeInfer() const override { return false; }

private:
std::weak_ptr<ScaledDotProductAttention> m_sdpaNode;
int child_port_idx;
std::shared_ptr<VariableStateKVcache> m_sdpaState;
int child_port_idx = -1;
};
} // namespace node
} // namespace intel_cpu
Expand Down
Loading

0 comments on commit d01830f

Please sign in to comment.