diff --git a/include/tvm/relay/vm/vm.h b/include/tvm/relay/vm/vm.h index d791b8c2ae8e..2a6cdfe60b88 100644 --- a/include/tvm/relay/vm/vm.h +++ b/include/tvm/relay/vm/vm.h @@ -6,8 +6,10 @@ #ifndef TVM_RELAY_RUNTIME_H_ #define TVM_RELAY_RUNTIME_H_ +#include +#include #include -#include +#include namespace tvm { namespace relay { @@ -161,6 +163,8 @@ struct VirtualMachine { const Instruction* code; size_t pc; size_t bp; + + std::vector ctxs; // Interface debugging. std::unordered_map global_map; @@ -177,7 +181,10 @@ struct VirtualMachine { functions(), frames(), stack(), func_index(0), code(nullptr), pc(0), bp(0) {} - static VirtualMachine FromModule(const Module& module); + void Init(const std::vector& ctxs); + + static VirtualMachine FromModule(const Module& module, + const std::vector& ctxs); }; VirtualMachine CompileModule(const Module& mod); diff --git a/include/tvm/runtime/memory_manager.h b/include/tvm/runtime/memory_manager.h new file mode 100644 index 000000000000..b9bebe4498f6 --- /dev/null +++ b/include/tvm/runtime/memory_manager.h @@ -0,0 +1,75 @@ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file tvm/runtime/memory_manager.h + * \brief Abstract device memory management API + */ +#ifndef TVM_RUNTIME_MEMORY_MANAGER_H_ +#define TVM_RUNTIME_MEMORY_MANAGER_H_ + +#include +#include +#include +#include +#include "c_runtime_api.h" + +namespace std { +template<> +struct hash { + std::size_t operator()(const TVMContext& ctx) const { + return ((ctx.device_id << 8) | ctx.device_type); + } +}; + +template<> +struct equal_to { + bool operator()(const TVMContext& lhs, const TVMContext& rhs) const { + return (lhs.device_type == rhs.device_type && + lhs.device_id == rhs.device_id); + } +}; + +} // namespace std + +namespace tvm { +namespace runtime { + +struct Buffer { + // data pointer + void* data{nullptr}; + // Buffer size in bytes + size_t size{0}; + // TVM Context + TVMContext ctx; +}; + +class Allocator { + public: + explicit Allocator(TVMContext ctx) : ctx_(ctx) {} + + virtual Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) = 0; + virtual void Free(const Buffer& buffer) = 0; + virtual size_t UsedMemory() = 0; + virtual ~Allocator() = default; + + protected: + TVMContext ctx_; +}; + +class MemoryManager { + public: + static MemoryManager* Global(); + + Allocator* GetAllocator(TVMContext ctx); + + private: + MemoryManager() {} + + private: + std::mutex mu_; + std::unordered_map> allocators_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_MEMORY_MANAGER_H_ diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 2b9674301607..4ae25d367cfa 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -10,6 +10,7 @@ #include #include #include "c_runtime_api.h" +#include "memory_manager.h" #include "serializer.h" namespace tvm { @@ -149,7 +150,8 @@ class NDArray { */ TVM_DLL static NDArray Empty(std::vector shape, DLDataType dtype, - DLContext ctx); + DLContext ctx, + Allocator* allocator = nullptr); /*! 
* \brief Create a NDArray backed by a dlpack tensor. * @@ -291,6 +293,19 @@ class NDArray::Container { } } } + + private: + friend class NDArray; + friend class RPCWrappedFunc; + /*! + * \brief The shape container, + * can be used for shape data. + */ + std::vector shape_; + /*! \brief The internal array object */ + std::atomic ref_counter_{0}; + /*! \brief Buffer allocated by allocator; nullptr when no allocator was used */ + Buffer* buffer_{nullptr}; }; // implementations of inline functions diff --git a/python/tvm/relay/vm.py b/python/tvm/relay/vm.py index ab67c0542f0f..d5481ba4ffd2 100644 --- a/python/tvm/relay/vm.py +++ b/python/tvm/relay/vm.py @@ -50,7 +50,7 @@ def convert(args): _convert(arg, cargs) return cargs -def eval_vm(expr_or_mod, *args): +def eval_vm(expr_or_mod, ctx, *args): if isinstance(expr_or_mod, Expr): mod = Module.from_expr(expr_or_mod) else: @@ -67,4 +67,4 @@ def eval_vm(expr_or_mod, *args): cargs = convert(list(args)) import pdb; pdb.set_trace() - return _evaluate_vm(mod, cargs) + return _evaluate_vm(mod, ctx.device_type, ctx.device_id, cargs) diff --git a/src/relay/vm/vm.cc b/src/relay/vm/vm.cc index e6e159414545..58a36e145d15 100644 --- a/src/relay/vm/vm.cc +++ b/src/relay/vm/vm.cc @@ -4,14 +4,18 @@ * \brief Abstract device memory management API */ +#include #include #include #include #include "../backend/compile_engine.h" +#include "../../runtime/naive_allocator.h" #include #include +using namespace tvm::runtime; + namespace tvm { namespace relay { namespace vm { @@ -254,7 +258,7 @@ struct VMCompiler : ExprFunctor { auto it = this->context->global_map.find(global); CHECK(it != this->context->global_map.end()); CHECK(it->second < 5); - std::cout << "Invoke with: " << it->second; + std::cout << "Invoke with: " << global->name_hint << "(func idx" << it->second << ")" << std::endl; Emit(Invoke(it->second)); } @@ -474,6 +478,8 @@ void VirtualMachine::InvokeGlobal(const VMFunction& func, const std::vector& args) { InvokeGlobal(func, args); Run(); + auto alloc = 
MemoryManager::Global()->GetAllocator(ctxs[0]); + std::cout << "Memory used: " << alloc->UsedMemory() << " B\n"; // std::cout << "final stack size: " << stack.size() << "bp: " << bp << std::endl; return stack.back(); } @@ -504,6 +510,10 @@ void InvokePacked(const PackedFunc& func, size_t arg_count, std::vector& ctxs) { + this->ctxs = ctxs; +} + static int trip_counter = 0; void VirtualMachine::Run() { @@ -573,12 +583,10 @@ void VirtualMachine::Run() { } case Opcode::AllocTensor: { const auto& ti = instr.tensor_info; - DLContext ctx; - ctx.device_type = DLDeviceType::kDLCPU; - ctx.device_id = 0; auto shape = std::vector(ti.ndim); shape.assign(ti.shape, ti.shape + ti.ndim); - auto data = NDArray::Empty(shape, ti.dtype, ctx); + auto allocator = MemoryManager::Global()->GetAllocator(ctxs[0]); + auto data = NDArray::Empty(shape, ti.dtype, ctxs[0], allocator); stack.push_back(VMTensor(data)); pc++; goto main_loop; @@ -607,8 +615,11 @@ void VirtualMachine::Run() { } } -VirtualMachine VirtualMachine::FromModule(const Module& module) { - return CompileModule(module); +VirtualMachine VirtualMachine::FromModule(const Module& module, + const std::vector& ctxs) { + auto vm = CompileModule(module); + vm.Init(ctxs); + return vm; } /*! \brief Convert from an array of relay.Value into VM compatible objects. 
@@ -648,8 +659,9 @@ Value ConvertVMToValue(VMObject obj) { } } -VMObject EvaluateModule(const Module& module, const std::vector& vm_args) { - VirtualMachine vm = VirtualMachine::FromModule(module); +VMObject EvaluateModule(const Module& module, const std::vector& ctxs, + const std::vector& vm_args) { + VirtualMachine vm = VirtualMachine::FromModule(module, ctxs); std::cout << "--------------------------" << std::endl; VMFunctionPrint(vm.functions[0]); std::cout << "--------------------------" << std::endl; @@ -659,6 +671,10 @@ VMObject EvaluateModule(const Module& module, const std::vector& vm_ar TVM_REGISTER_API("relay._vm._evaluate_vm") .set_body([](TVMArgs args, TVMRetValue* ret) { NodeRef to_compile = args[0]; + TVMContext ctx; + int dev_type = args[1]; + ctx.device_type = static_cast(dev_type); + ctx.device_id = args[2]; Module module; if (to_compile.as()) { @@ -670,8 +686,8 @@ TVM_REGISTER_API("relay._vm._evaluate_vm") LOG(FATAL) << "expected function or module"; } - std::vector vm_args = ConvertArgsToVM(args[1]); - auto result = EvaluateModule(module, vm_args); + std::vector vm_args = ConvertArgsToVM(args[3]); + auto result = EvaluateModule(module, {ctx}, vm_args); *ret = ConvertVMToValue(result); }); diff --git a/src/runtime/memory_manager.cc b/src/runtime/memory_manager.cc new file mode 100644 index 000000000000..d7d889a9087f --- /dev/null +++ b/src/runtime/memory_manager.cc @@ -0,0 +1,26 @@ +#include +#include "naive_allocator.h" +#include "pooled_allocator.h" + +namespace tvm { +namespace runtime { + +MemoryManager* MemoryManager::Global() { + static MemoryManager memory_manager; + return &memory_manager; +} + +Allocator* MemoryManager::GetAllocator(TVMContext ctx) { + std::lock_guard lock(mu_); + if (allocators_.find(ctx) == allocators_.end()) { + LOG(INFO) << "New allocator for " << DeviceName(ctx.device_type) << "(" + << ctx.device_id << ")"; + std::unique_ptr alloc(new NaiveAllocator(ctx)); + //std::unique_ptr alloc(new PooledAllocator(ctx, 128)); + 
allocators_.emplace(ctx, std::move(alloc)); + } + return allocators_.at(ctx).get(); +} + +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/naive_allocator.h b/src/runtime/naive_allocator.h new file mode 100644 index 000000000000..2225c9a1a995 --- /dev/null +++ b/src/runtime/naive_allocator.h @@ -0,0 +1,47 @@ +/*! + * Copyright (c) 2019 by Contributors + * \file runtime/naive_allocator.h + */ +#ifndef TVM_RUNTIME_NAIVE_ALLOCATOR_H_ +#define TVM_RUNTIME_NAIVE_ALLOCATOR_H_ + +#include +#include +#include + +namespace tvm { +namespace runtime { + +class NaiveAllocator final : public Allocator { + public: + explicit NaiveAllocator(TVMContext ctx) : Allocator(ctx), used_memory_(0) {} + + Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override { + Buffer buf; + buf.ctx = ctx_; + buf.size = nbytes; + buf.data = DeviceAPI::Get(ctx_)->AllocDataSpace( + ctx_, nbytes, alignment, type_hint); + used_memory_.fetch_add(nbytes, std::memory_order_relaxed); + LOG(INFO) << "allocate " << nbytes << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + DeviceAPI::Get(ctx_)->FreeDataSpace(buffer.ctx, buffer.data); + used_memory_.fetch_sub(buffer.size, std::memory_order_relaxed); + LOG(INFO) << "free " << buffer.size << " B, used memory " << used_memory_ << " B"; + } + + size_t UsedMemory() override { + return used_memory_.load(std::memory_order_relaxed); + } + + private: + std::atomic used_memory_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_NAIVE_ALLOCATOR_H_ diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc index 0ffa4c174544..5d8bd84c1462 100644 --- a/src/runtime/ndarray.cc +++ b/src/runtime/ndarray.cc @@ -7,6 +7,7 @@ #include #include #include +#include #include "runtime_base.h" // deleter for arrays used by DLPack exporter @@ -57,15 +58,27 @@ struct NDArray::Internal { } delete ptr; } + + static void BufferDeleter(NDArray::Container* ptr) { + 
CHECK(ptr->buffer_ != nullptr); + MemoryManager::Global()->GetAllocator(ptr->buffer_->ctx)-> + Free(*(ptr->buffer_)); + delete ptr->buffer_; + delete ptr; + } // Local create function which allocates tensor metadata // but does not allocate space for the data. static NDArray Create(std::vector shape, DLDataType dtype, - DLContext ctx) { + DLContext ctx, bool with_allocator = false) { VerifyDataType(dtype); // critical zone NDArray::Container* data = new NDArray::Container(); - data->deleter = DefaultDeleter; + if (with_allocator) { + data->deleter = BufferDeleter; + } else { + data->deleter = DefaultDeleter; + } NDArray ret(data); ret.data_ = data; // RAII now in effect @@ -123,14 +136,21 @@ DLManagedTensor* NDArray::ToDLPack() const { NDArray NDArray::Empty(std::vector shape, DLDataType dtype, - DLContext ctx) { - NDArray ret = Internal::Create(shape, dtype, ctx); + DLContext ctx, + Allocator* allocator) { + NDArray ret = Internal::Create(shape, dtype, ctx, (allocator != nullptr)); // setup memory content size_t size = GetDataSize(ret.data_->dl_tensor); size_t alignment = GetDataAlignment(ret.data_->dl_tensor); - ret.data_->dl_tensor.data = - DeviceAPI::Get(ret->ctx)->AllocDataSpace( - ret->ctx, size, alignment, ret->dtype); + if (allocator == nullptr) { + ret.data_->dl_tensor.data = + DeviceAPI::Get(ret->ctx)->AllocDataSpace( + ret->ctx, size, alignment, ret->dtype); + } else { + ret.data_->buffer_ = new Buffer; + *ret.data_->buffer_ = allocator->Alloc(size, alignment, ret->dtype); + ret.data_->dl_tensor.data = ret.data_->buffer_->data; + } return ret; } diff --git a/src/runtime/pooled_allocator.h b/src/runtime/pooled_allocator.h new file mode 100644 index 000000000000..7be38dc3e962 --- /dev/null +++ b/src/runtime/pooled_allocator.h @@ -0,0 +1,85 @@ +/*! 
+ * Copyright (c) 2019 by Contributors + * \file runtime/pooled_allocator.h + */ +#ifndef TVM_RUNTIME_POOLED_ALLOCATOR_H_ +#define TVM_RUNTIME_POOLED_ALLOCATOR_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { + +class PooledAllocator final : public Allocator { + public: + static constexpr size_t kDefaultPageSize = 4096; + + explicit PooledAllocator(TVMContext ctx, size_t page_size = kDefaultPageSize) : + Allocator(ctx), page_size_(page_size), used_memory_(0) {} + + ~PooledAllocator() { + ReleaseAll(); + } + + Buffer Alloc(size_t nbytes, size_t alignment, TVMType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.ctx = ctx_; + buf.size = size; + buf.data = DeviceAPI::Get(ctx_)->AllocDataSpace( + ctx_, size, alignment, type_hint); + used_memory_.fetch_add(size, std::memory_order_relaxed); + LOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + LOG(INFO) << "reclaim buffer " << buffer.size; + } + + size_t UsedMemory() override { + return used_memory_.load(std::memory_order_relaxed); + } + + private: + void ReleaseAll() { + std::lock_guard lock(mu_); + for (auto const& it : memory_pool_) { + auto const& pool = it.second; + for (auto const& buf : pool) { + DeviceAPI::Get(buf.ctx)->FreeDataSpace(buf.ctx, buf.data); + } + } + memory_pool_.clear(); + used_memory_ = 0; + LOG(INFO) << "release all buffers"; + } + + private: + size_t page_size_; + std::atomic used_memory_; + 
std::unordered_map> memory_pool_; + std::mutex mu_; +}; + +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_POOLED_ALLOCATOR_H_ diff --git a/tests/python/relay/test_vm.py b/tests/python/relay/test_vm.py index b674ed4b1e23..8299be5fd1dc 100644 --- a/tests/python/relay/test_vm.py +++ b/tests/python/relay/test_vm.py @@ -12,14 +12,14 @@ def test_id(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x) x_data = np.random.rand(10, 10).astype('float64') - res = eval_vm(f, x_data) + res = eval_vm(f, tvm.cpu(), x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data) def test_op(): x = relay.var('x', shape=(10, 10)) f = relay.Function([x], x + x) x_data = np.random.rand(10, 10).astype('float32') - res = eval_vm(f, x_data) + res = eval_vm(f, tvm.cpu(), x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data + x_data) def any(x): @@ -35,11 +35,11 @@ def test_cond(): y_data = np.random.rand(10, 10).astype('float32') # same - res = eval_vm(f, x_data, x_data) + res = eval_vm(f, tvm.cpu(), x_data, x_data) tvm.testing.assert_allclose(res.asnumpy(), True) # diff - res = eval_vm(f, x_data, y_data) + res = eval_vm(f, tvm.cpu(), x_data, y_data) tvm.testing.assert_allclose(res.asnumpy(), False) @@ -52,11 +52,11 @@ def test_simple_if(): y_data = np.random.rand(10, 10).astype('float32') # same - res = eval_vm(f, x_data, x_data) + res = eval_vm(f, tvm.cpu(), x_data, x_data) tvm.testing.assert_allclose(res.asnumpy(), x_data) # diff - res = eval_vm(f, x_data, y_data) + res = eval_vm(f, tvm.cpu(), x_data, y_data) tvm.testing.assert_allclose(res.asnumpy(), y_data) def test_simple_call(): @@ -70,7 +70,7 @@ def test_simple_call(): i_data = np.array(0, dtype='int32') # Refactor this bit mod[mod.entry_func] = relay.Function([], sum_up) - result = eval_vm(mod, i_data) + result = eval_vm(mod, tvm.cpu(), i_data) tvm.testing.assert_allclose(result.asnumpy(), i_data) def test_count_loop(): @@ -88,7 +88,7 @@ def test_count_loop(): mod[sum_up] = func i_data = 
np.array(0, dtype='int32') mod[mod.entry_func] = relay.Function([], sum_up) - result = eval_vm(mod, i_data) + result = eval_vm(mod, tvm.cpu(), i_data) tvm.testing.assert_allclose(result.asnumpy(), i_data) def test_sum_loop(): @@ -108,7 +108,7 @@ def test_sum_loop(): i_data = np.array(10, dtype='int32') accum_data = np.array(0, dtype='int32') mod[mod.entry_func] = relay.Function([], sum_up) - result = eval_vm(mod, i_data, accum_data) + result = eval_vm(mod, tvm.cpu(), i_data, accum_data) tvm.testing.assert_allclose(result.asnumpy(), sum(range(1, 11))) def test_tuple_fst(): @@ -117,7 +117,7 @@ def test_tuple_fst(): f = relay.Function([tup], relay.TupleGetItem(tup, 0)) i_data = np.random.rand(1).astype('float32') j_data = np.random.rand(10).astype('float32') - result = eval_vm(f, (i_data, j_data)) + result = eval_vm(f, tvm.cpu(), (i_data, j_data)) tvm.testing.assert_allclose(result.asnumpy(), i_data) def import_mxnet_model(cell_type, input_size, hidden_size, fname, batch=1, seq_len=100): @@ -149,12 +149,12 @@ def test_rnn(): # execute_mxnet_model('gru', 128, 128, "gru_i128_h128") if __name__ == "__main__": - # test_id() - # test_op() - # test_cond() - # test_simple_if() - # test_simple_call() - # test_count_loop() - # test_sum_loop() + test_id() + test_op() + test_cond() + test_simple_if() + test_simple_call() + test_count_loop() + test_sum_loop() test_rnn() test_tuple_fst()