Skip to content

Commit

Permalink
Rename mem manager to mem block
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Aug 5, 2024
1 parent d29948c commit f211b01
Show file tree
Hide file tree
Showing 21 changed files with 278 additions and 342 deletions.
173 changes: 65 additions & 108 deletions src/plugins/intel_cpu/src/cpu_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ namespace {
Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bool pads_zeroing) :
m_eng(eng),
m_pMemDesc(desc),
m_mgrHandle(std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrWithReuse>()), this),
m_blockHandle(std::make_shared<DnnlMemoryBlock>(make_unique<MemoryBlockWithReuse>()), this),
dnnlMemHandle(this) {
if (desc->getPrecision() == element::string) {
OPENVINO_THROW("[CPU] Memory object cannot be created for string data.");
Expand All @@ -77,18 +77,18 @@ Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, const void* data, bo
Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, const void* data, bool pads_zeroing) :
Memory::Memory(eng, desc.clone(), data, pads_zeroing) {}

Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, MemoryMngrPtr mngr) :
m_eng(eng), m_pMemDesc(desc), m_mgrHandle(mngr, this), dnnlMemHandle(this) {
Memory::Memory(const dnnl::engine& eng, MemoryDescPtr desc, MemoryBlockPtr block) :
m_eng(eng), m_pMemDesc(desc), m_blockHandle(block, this), dnnlMemHandle(this) {
if (desc->getPrecision() == element::string) {
OPENVINO_THROW("[CPU] Memory object can't be created for string data.");
}
bool memAllocated = m_mgrHandle->getRawPtr();
bool memAllocated = m_blockHandle->getRawPtr();

create(desc, nullptr, !memAllocated);
}

Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, MemoryMngrPtr mngr) :
Memory::Memory(eng, desc.clone(), mngr) {}
Memory::Memory(const dnnl::engine& eng, const MemoryDesc& desc, MemoryBlockPtr block) :
Memory::Memory(eng, desc.clone(), block) {}

size_t Memory::getSize() const {
auto size = getDesc().getCurrentMemSize();
Expand All @@ -112,9 +112,9 @@ void Memory::create(MemoryDescPtr desc, const void* data, bool pads_zeroing) {
}
auto memSize = m_pMemDesc->getCurrentMemSize();
if (nullptr != data) {
m_mgrHandle->setExtBuff(const_cast<void*>(data), memSize);
m_blockHandle->setExtBuff(const_cast<void*>(data), memSize);
} else {
m_mgrHandle->resize(memSize);
m_blockHandle->resize(memSize);
}
}

Expand Down Expand Up @@ -145,7 +145,7 @@ void Memory::redefineDesc(MemoryDescPtr desc) {
void Memory::update() {
if (dnnlMemHandle.isInit()) {
auto prim = dnnlMemHandle.getPrim();
prim.set_data_handle(m_mgrHandle->getRawPtr());
prim.set_data_handle(m_blockHandle->getRawPtr());
}
}

Expand Down Expand Up @@ -185,7 +185,7 @@ dnnl::memory Memory::DnnlMemPrimHandle::getPrim() const {
}

bool Memory::isAllocated() const noexcept {
if (m_mgrHandle->getRawPtr()) {
if (m_blockHandle->getRawPtr()) {
return true;
}
if (!m_pMemDesc) {
Expand All @@ -209,17 +209,17 @@ void* Memory::getData() const {
return data;
}

void* MemoryMngrWithReuse::getRawPtr() const noexcept {
void* MemoryBlockWithReuse::getRawPtr() const noexcept {
return m_data.get();
}

void MemoryMngrWithReuse::setExtBuff(void *ptr, size_t size) {
void MemoryBlockWithReuse::setExtBuff(void *ptr, size_t size) {
m_useExternalStorage = true;
m_memUpperBound = size;
m_data = decltype(m_data)(ptr, release);
}

bool MemoryMngrWithReuse::resize(size_t size) {
bool MemoryBlockWithReuse::resize(size_t size) {
constexpr int cacheLineSize = 64;
bool sizeChanged = false;
if (size > m_memUpperBound) {
Expand All @@ -234,63 +234,20 @@ bool MemoryMngrWithReuse::resize(size_t size) {

if (numa_node >= 0) {
if (!mbind_move(ptr, size, numa_node)) {
DEBUG_LOG("MemoryMngrWithReuse move_memory to node ", numa_node, " failed\n");
DEBUG_LOG("MemoryBlockWithReuse move_memory to node ", numa_node, " failed\n");
}
}
}
return sizeChanged;
}

bool MemoryMngrWithReuse::hasExtBuffer() const noexcept {
bool MemoryBlockWithReuse::hasExtBuffer() const noexcept {
return m_useExternalStorage;
}

void MemoryMngrWithReuse::release(void *ptr) {}
void MemoryBlockWithReuse::release(void *ptr) {}

void MemoryMngrWithReuse::destroy(void *ptr) {
dnnl::impl::free(ptr);
}

void* MemoryMngrRealloc::getRawPtr() const noexcept {
return m_data.get();
}

void MemoryMngrRealloc::setExtBuff(void *ptr, size_t size) {
m_useExternalStorage = true;
m_memUpperBound = size;
m_data = decltype(m_data)(ptr, release);
}

bool MemoryMngrRealloc::resize(size_t size) {
constexpr int cacheLineSize = 64;
constexpr size_t growFactor = 2;
bool sizeChanged = false;
if (size > m_memUpperBound) {
size *= growFactor;
void *ptr = dnnl::impl::malloc(size, cacheLineSize);
if (!ptr) {
OPENVINO_THROW("Failed to allocate ", size, " bytes of memory");
}

if (auto src = m_data.get()) {
std::memcpy(ptr, src, m_memUpperBound);
}

m_memUpperBound = size;
m_useExternalStorage = false;
m_data = decltype(m_data)(ptr, destroy);
sizeChanged = true;
}
return sizeChanged;
}

bool MemoryMngrRealloc::hasExtBuffer() const noexcept {
return m_useExternalStorage;
}

void MemoryMngrRealloc::release(void *ptr) {}

void MemoryMngrRealloc::destroy(void *ptr) {
void MemoryBlockWithReuse::destroy(void *ptr) {
dnnl::impl::free(ptr);
}

Expand All @@ -301,7 +258,7 @@ StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc
OPENVINO_THROW("[CPU] StringMemory supports String type only.");
}

m_manager = std::make_shared<StringMemoryMngr>();
m_memoryBlock = std::make_shared<StringMemoryBlock>();

if (!m_mem_desc->isDefined()) {
return;
Expand All @@ -311,9 +268,9 @@ StringMemory::StringMemory(const dnnl::engine& engine, const MemoryDescPtr& desc

if (data != nullptr) {
auto not_const_data = const_cast<void *>(data);
m_manager->setExtBuff(reinterpret_cast<OvString *>(not_const_data), string_size);
m_memoryBlock->setExtBuff(reinterpret_cast<OvString *>(not_const_data), string_size);
} else {
m_manager->resize(string_size);
m_memoryBlock->resize(string_size);
}
}

Expand All @@ -326,7 +283,7 @@ void StringMemory::load(const IMemory& src, bool ftz) const {
}

void* StringMemory::getData() const {
return m_manager->getRawPtr();
return m_memoryBlock->getRawPtr();
}

void StringMemory::redefineDesc(MemoryDescPtr desc) {
Expand All @@ -339,13 +296,13 @@ void StringMemory::redefineDesc(MemoryDescPtr desc) {

m_mem_desc = desc;
const auto string_size = m_mem_desc->getShape().getElementsCount();
m_manager->resize(string_size);
m_memoryBlock->resize(string_size);
}

void StringMemory::nullify() {
auto data_ptr = m_manager->getStringPtr();
auto data_ptr = m_memoryBlock->getStringPtr();
if (data_ptr != nullptr) {
std::fill(data_ptr, data_ptr + m_manager->getStrLen(), OvString());
std::fill(data_ptr, data_ptr + m_memoryBlock->getStrLen(), OvString());
}
}

Expand Down Expand Up @@ -373,25 +330,25 @@ size_t StringMemory::getSize() const { // In bytes
return size;
}

MemoryMngrPtr StringMemory::getMemoryMngr() const {
OPENVINO_THROW("Unexpected call of StringMemory::getMemoryMngr()");
MemoryBlockPtr StringMemory::getMemoryBlock() const {
OPENVINO_THROW("Unexpected call of StringMemory::getMemoryBlock()");
}

dnnl::memory StringMemory::getPrimitive() const {
OPENVINO_THROW("Unexpected call of StringMemory::getPrimitive()");
}

void StringMemory::StringMemoryMngr::setExtBuff(OvString* ptr, size_t size) {
void StringMemory::StringMemoryBlock::setExtBuff(OvString* ptr, size_t size) {
m_use_external_storage = true;
m_str_upper_bound = size;
m_data = decltype(m_data)(ptr, release);
}

StringMemory::OvString* StringMemory::StringMemoryMngr::getStringPtr() const noexcept {
StringMemory::OvString* StringMemory::StringMemoryBlock::getStringPtr() const noexcept {
return m_data.get();
}

bool StringMemory::StringMemoryMngr::resize(size_t size) {
bool StringMemory::StringMemoryBlock::resize(size_t size) {
bool sizeChanged = false;
if (size > m_str_upper_bound) {
if (size > PTRDIFF_MAX) {
Expand All @@ -410,58 +367,58 @@ bool StringMemory::StringMemoryMngr::resize(size_t size) {
return sizeChanged;
}

bool StringMemory::StringMemoryMngr::hasExtBuffer() const noexcept {
bool StringMemory::StringMemoryBlock::hasExtBuffer() const noexcept {
return m_use_external_storage;
}

size_t StringMemory::StringMemoryMngr::getStrLen() const noexcept {
size_t StringMemory::StringMemoryBlock::getStrLen() const noexcept {
return m_str_upper_bound;
}

void StringMemory::StringMemoryMngr::destroy(OvString* ptr) {
void StringMemory::StringMemoryBlock::destroy(OvString* ptr) {
delete[] ptr;
}

void* StringMemory::StringMemoryMngr::getRawPtr() const noexcept {
void* StringMemory::StringMemoryBlock::getRawPtr() const noexcept {
return reinterpret_cast<void *>(m_data.get());
}

/////////////// DnnlMemoryMngr ///////////////
/////////////// DnnlMemoryBlock ///////////////

void* DnnlMemoryMngr::getRawPtr() const noexcept {
return m_pMemMngr->getRawPtr();
void* DnnlMemoryBlock::getRawPtr() const noexcept {
return m_pMemBlock->getRawPtr();
}

void DnnlMemoryMngr::setExtBuff(void *ptr, size_t size) {
m_pMemMngr->setExtBuff(ptr, size);
void DnnlMemoryBlock::setExtBuff(void *ptr, size_t size) {
m_pMemBlock->setExtBuff(ptr, size);
notifyUpdate();
}

bool DnnlMemoryMngr::resize(size_t size) {
bool sizeChanged = m_pMemMngr->resize(size);
bool DnnlMemoryBlock::resize(size_t size) {
bool sizeChanged = m_pMemBlock->resize(size);
if (sizeChanged) {
notifyUpdate();
}
return sizeChanged;
}

bool DnnlMemoryMngr::hasExtBuffer() const noexcept {
return m_pMemMngr->hasExtBuffer();
bool DnnlMemoryBlock::hasExtBuffer() const noexcept {
return m_pMemBlock->hasExtBuffer();
}

void DnnlMemoryMngr::registerMemory(Memory* memPtr) {
void DnnlMemoryBlock::registerMemory(Memory* memPtr) {
if (memPtr) {
m_setMemPtrs.insert(memPtr);
}
}

void DnnlMemoryMngr::unregisterMemory(Memory* memPtr) {
void DnnlMemoryBlock::unregisterMemory(Memory* memPtr) {
if (memPtr) {
m_setMemPtrs.erase(memPtr);
}
}

void DnnlMemoryMngr::notifyUpdate() {
void DnnlMemoryBlock::notifyUpdate() {
for (auto& item : m_setMemPtrs) {
if (item) {
item->update();
Expand All @@ -481,9 +438,9 @@ StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const vo
m_size = m_pMemDesc->getCurrentMemSize();

if (data) {
m_pMemMngr = std::make_shared<StaticMemoryMngr>(const_cast<void*>(data), m_size);
m_pMemBlock = std::make_shared<StaticMemoryBlock>(const_cast<void*>(data), m_size);
} else {
m_pMemMngr = std::make_shared<StaticMemoryMngr>(m_size);
m_pMemBlock = std::make_shared<StaticMemoryBlock>(m_size);
}

try {
Expand All @@ -494,7 +451,7 @@ StaticMemory::StaticMemory(const dnnl::engine& eng, MemoryDescPtr desc, const vo
m_prim = dnnl::memory(dnnl_desc->getDnnlDesc(), m_eng, DNNL_MEMORY_NONE);
//
// ========================
m_prim.set_data_handle(m_pMemMngr->getRawPtr());
m_prim.set_data_handle(m_pMemBlock->getRawPtr());
}
catch (const std::exception& exc) {
dnnlErrorCtx = exc.what();
Expand All @@ -517,7 +474,7 @@ MemoryDescPtr StaticMemory::getDescPtr() const {
}

void* StaticMemory::getData() const {
return m_pMemMngr->getRawPtr();
return m_pMemBlock->getRawPtr();
}

size_t StaticMemory::getSize() const {
Expand All @@ -543,8 +500,8 @@ void StaticMemory::load(const IMemory& src, bool ftz) const {
transferData(src, *this, ftz);
}

MemoryMngrPtr StaticMemory::getMemoryMngr() const {
return m_pMemMngr;
MemoryBlockPtr StaticMemory::getMemoryBlock() const {
return m_pMemBlock;
}

//oneDNN specifics for backward compatibility
Expand All @@ -561,38 +518,38 @@ void StaticMemory::nullify() {
memset(dataPtr, 0, getSize());
}

StaticMemory::StaticMemoryMngr::StaticMemoryMngr(size_t size) : m_size(size) {
memMngrImpl.resize(m_size);
StaticMemory::StaticMemoryBlock::StaticMemoryBlock(size_t size) : m_size(size) {
memBlockImpl.resize(m_size);
}

StaticMemory::StaticMemoryMngr::StaticMemoryMngr(void* data, size_t size) : m_size(size) {
memMngrImpl.setExtBuff(data, m_size);
StaticMemory::StaticMemoryBlock::StaticMemoryBlock(void* data, size_t size) : m_size(size) {
memBlockImpl.setExtBuff(data, m_size);
}

void* StaticMemory::StaticMemoryMngr::getRawPtr() const noexcept {
return memMngrImpl.getRawPtr();
void* StaticMemory::StaticMemoryBlock::getRawPtr() const noexcept {
return memBlockImpl.getRawPtr();
}

void StaticMemory::StaticMemoryMngr::setExtBuff(void* ptr, size_t size) {
OPENVINO_THROW("Unexpected: StaticMemoryMngr may not be modified");
void StaticMemory::StaticMemoryBlock::setExtBuff(void* ptr, size_t size) {
OPENVINO_THROW("Unexpected: StaticMemoryBlock may not be modified");
}

bool StaticMemory::StaticMemoryMngr::resize(size_t size) {
bool StaticMemory::StaticMemoryBlock::resize(size_t size) {
if (size != m_size) {
OPENVINO_THROW("Unexpected: StaticMemoryMngr may not resize the memory");
OPENVINO_THROW("Unexpected: StaticMemoryBlock may not resize the memory");
}
return false;
}

bool StaticMemory::StaticMemoryMngr::hasExtBuffer() const noexcept {
return memMngrImpl.hasExtBuffer();
bool StaticMemory::StaticMemoryBlock::hasExtBuffer() const noexcept {
return memBlockImpl.hasExtBuffer();
}

void StaticMemory::StaticMemoryMngr::registerMemory(Memory* memPtr) {
void StaticMemory::StaticMemoryBlock::registerMemory(Memory* memPtr) {
//do nothing
}

void StaticMemory::StaticMemoryMngr::unregisterMemory(Memory* memPtr) {
void StaticMemory::StaticMemoryBlock::unregisterMemory(Memory* memPtr) {
//do nothing
}

Expand Down
Loading

0 comments on commit f211b01

Please sign in to comment.