
Improve bulking in Gluon #13890

Merged: 3 commits, merged on Jan 30, 2019
4 changes: 4 additions & 0 deletions include/mxnet/base.h
@@ -261,6 +261,10 @@ struct RunContext {
* \brief the stream of the device, can be NULL or Stream<gpu>* in GPU mode
*/
void *stream;
+ /*!
+  * \brief indicator of whether this execution is run in bulk mode
+  */
+ bool is_bulk;
/*!
* \brief get mshadow stream from Context
* \return the mshadow stream
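Note on the new field: when an execution runs as part of a bulked segment, the engine performs a single device synchronization for the whole segment, so per-op code can skip its own sync. A minimal, self-contained sketch of the idea, using illustrative names (FakeStream, FakeRunContext, run_op) rather than MXNet's actual types:

    #include <iostream>

    struct FakeStream { void Wait() { std::cout << "stream sync\n"; } };

    struct FakeRunContext {
      FakeStream *stream;   // device stream (may be null on CPU)
      bool is_bulk;         // true while executing inside a bulked segment
    };

    void run_op(FakeRunContext rctx, bool is_gpu) {
      // ... launch the op's kernels on rctx.stream ...
      if (is_gpu && !rctx.is_bulk) {
        rctx.stream->Wait();   // per-op sync only when not bulked
      }
    }

    int main() {
      FakeStream s;
      run_op({&s, false}, true);  // standalone op: syncs here
      run_op({&s, true},  true);  // bulked op: sync deferred to the end of the segment
      return 0;
    }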
10 changes: 6 additions & 4 deletions src/engine/stream_manager.h
@@ -67,7 +67,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
- ret = RunContext{ctx, nullptr};
+ ret = RunContext{ctx, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
@@ -85,7 +85,9 @@ RunContext StreamManager<kNumGpus, kStreams>::GetRunContext(
use_counter = counter;
counter = (counter + 1) % kStreams;
}
- ret = RunContext{ctx, gpu_streams_.at(ctx.dev_id).at(use_counter)};
+ ret = RunContext{ctx,
+                  gpu_streams_.at(ctx.dev_id).at(use_counter),
+                  false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
@@ -103,7 +105,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
RunContext ret;
switch (ctx.dev_mask()) {
case cpu::kDevMask:
- ret = RunContext{ctx, nullptr};
+ ret = RunContext{ctx, nullptr, false};
break;
case gpu::kDevMask: {
#if MXNET_USE_CUDA
@@ -114,7 +116,7 @@ RunContext StreamManager<kNumGpus, kStreams>::GetIORunContext(
gpu_io_streams_.at(ctx.dev_id) = mshadow::NewStream<gpu>(false, false, ctx.dev_id);
}
}
- ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id)};
+ ret = RunContext{ctx, gpu_io_streams_.at(ctx.dev_id), false};
break;
#else
LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
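Every site that brace-initializes RunContext now supplies the third member explicitly. As a side note, RunContext is an aggregate, so an initializer that omits a trailing member would value-initialize it (false for a bool); the explicit false just makes the intent visible. A small stand-alone sketch of that C++ rule, using stand-in types rather than the real RunContext:

    #include <cassert>

    struct Ctx {};   // stand-in for the device context

    struct Rc {      // stand-in for RunContext
      Ctx ctx;
      void *stream;
      bool is_bulk;
    };

    int main() {
      Rc a{Ctx{}, nullptr, false};  // explicit, as in this patch
      Rc b{Ctx{}, nullptr};         // omitted trailing member is value-initialized
      assert(a.is_bulk == false && b.is_bulk == false);
      return 0;
    }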
6 changes: 6 additions & 0 deletions src/engine/threaded_engine.h
@@ -498,7 +498,13 @@ class ThreadedEngine : public Engine {
DeduplicateVarHandle(&bulk_status.const_vars, &bulk_status.mutable_vars);
SyncFn fn = std::move(bulk_status.fn);
this->PushAsync([fn](RunContext ctx, CallbackOnComplete on_complete) {
+ ctx.is_bulk = true;
fn(ctx);
+ ctx.is_bulk = false;
+ bool is_gpu = ctx.ctx.dev_mask() == gpu::kDevMask;
+ if (is_gpu) {
+   ctx.get_stream<gpu>()->Wait();
+ }
on_complete();
}, bulk_status.ctx, bulk_status.const_vars, bulk_status.mutable_vars,
FnProperty::kNormal, 0, "ImperativeBulk");
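The bulked segment is pushed as one engine operation: the flag is set for the duration of the fused function, cleared afterwards, and the GPU stream is synchronized once at the end instead of once per operator. A rough stand-alone sketch of that flow, with placeholder types (Stream, Rctx, push_bulk) standing in for the engine pieces:

    #include <functional>
    #include <iostream>
    #include <vector>

    struct Stream { void Wait() { std::cout << "one sync for the whole segment\n"; } };
    struct Rctx { Stream *stream; bool is_bulk; };

    void push_bulk(const std::vector<std::function<void(Rctx&)>> &ops, Stream *s, bool is_gpu) {
      Rctx ctx{s, false};
      ctx.is_bulk = true;
      for (auto &op : ops) op(ctx);    // run every op without per-op syncs
      ctx.is_bulk = false;
      if (is_gpu) ctx.stream->Wait();  // synchronize once at the end
    }

    int main() {
      Stream s;
      std::vector<std::function<void(Rctx&)>> ops(3, [](Rctx&) { /* launch kernels */ });
      push_bulk(ops, &s, true);
      return 0;
    }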
5 changes: 1 addition & 4 deletions src/imperative/cached_op.cc
@@ -583,14 +583,11 @@ void CachedOp::StaticInitExec(
}

size_t bulk_size = idx.num_nodes();
- std::unordered_set<uint32_t> excludes;
if (recording || keep_fwd) {
bulk_size = keep_fwd ? config_.backward_bulk_size : config_.forward_bulk_size;
- for (const auto& i : idx.outputs()) excludes.insert(idx.entry_id(i));
- for (const auto& i : idx.input_nodes()) excludes.insert(idx.entry_id(i, 0));
}

- CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size, excludes,
+ CreateEngineOpSeg(idx, default_ctx, start_nid, end_nid, bulk_size,
state.execs, skip_plus_node, &state.opr_segs);
}

10 changes: 1 addition & 9 deletions src/imperative/imperative_utils.h
@@ -406,7 +406,7 @@ inline void PushFCompute(const FCompute& fn,
fn(attrs, opctx, input_blobs, tmp_req, output_blobs);
// post-fcompute fallback, cast to original storage type
CastNonDefaultStorage(post_temp_src, post_temp_dst, opctx, is_gpu);
- if (is_gpu) {
+ if (is_gpu && !rctx.is_bulk) {
rctx.get_stream<gpu>()->Wait();
}
}, ctx, read_vars, write_vars, FnProperty::kNormal,
@@ -928,7 +928,6 @@ inline void CreateEngineOpSeg(
const size_t start_nid,
const size_t end_nid,
const size_t bulk_size,
- const std::unordered_set<uint32_t>& excludes,
const std::vector<std::shared_ptr<exec::OpExecutor> >& execs,
const std::vector<int> skip_plus_node,
std::vector<EngineOprSeg> *opr_segs) {
@@ -944,13 +943,6 @@

// Stop at async nodes and invalid node (due to input/output is not allocated)
bool stop = is_async || !valid || seg_execs.size() >= bulk_size;
- for (size_t i = 0; i < node.inputs.size() && !stop; ++i) {
-   if (excludes.count(idx.entry_id(node.inputs[i]))) stop = true;
- }
- auto num_outputs = node.source->num_outputs();
- for (size_t i = 0; i < num_outputs && !stop; ++i) {
-   if (excludes.count(idx.entry_id(nid, i))) stop = true;
- }

// Create opr segment for previous nodes.
if (stop && nid > seg_start) {
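With the excludes set gone, CreateEngineOpSeg cuts segments only at async ops, invalid (unallocated) nodes, and the bulk_size cap; graph inputs and outputs no longer force a break. A simplified stand-alone sketch of such a segmentation loop, using hypothetical types rather than the real indexed-graph code:

    #include <cstdio>
    #include <utility>
    #include <vector>

    struct Node { bool is_async; bool valid; };

    // Cut the node list into half-open segments of at most bulk_size ops,
    // also breaking before any async or invalid node.
    std::vector<std::pair<size_t, size_t>> segment(const std::vector<Node> &nodes, size_t bulk_size) {
      std::vector<std::pair<size_t, size_t>> segs;
      size_t start = 0, count = 0;
      for (size_t nid = 0; nid < nodes.size(); ++nid) {
        bool stop = nodes[nid].is_async || !nodes[nid].valid || count >= bulk_size;
        if (stop && nid > start) {
          segs.emplace_back(start, nid);  // previous nodes become one segment
          start = nid;
          count = 0;
        }
        ++count;
      }
      if (start < nodes.size()) segs.emplace_back(start, nodes.size());
      return segs;
    }

    int main() {
      std::vector<Node> nodes(10, Node{false, true});
      for (auto &s : segment(nodes, 4)) std::printf("[%zu, %zu)\n", s.first, s.second);
      return 0;   // prints [0, 4) [4, 8) [8, 10)
    }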