Skip to content

Commit

Permalink
Fix nccl-induced segfault (#1481)
Browse files Browse the repository at this point in the history
  • Loading branch information
goliaro authored Aug 31, 2024
1 parent 3b59f05 commit 28aff70
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 31 deletions.
1 change: 1 addition & 0 deletions include/flexflow/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -1079,6 +1079,7 @@ class FFModel {
bool use_propagation) const;
#ifdef FF_USE_NCCL
ncclComm_t *find_nccl_comms(MachineView const &view) const;
void finish_nccl_comms();
#endif
#ifdef FF_USE_PROPAGATE
void propagate(std::map<Op *, ParallelConfig> const &current,
Expand Down
68 changes: 37 additions & 31 deletions src/runtime/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1589,41 +1589,47 @@ FFModel::FFModel(FFConfig &_config, bool cpu_offload)
model_id = model_counter++;
}

#ifdef FF_USE_NCCL
void FFModel::finish_nccl_comms() {
Context ctx = config.lg_ctx;
Runtime *runtime = config.lg_hlr;
for (auto const &comm : view_hash_to_nccl_comms) {
// Find the machine view that has the hash
MachineView view;
for (size_t l = 0; l < operators.size(); l++) {
view = operators[l]->outputs[0]->machine_view;
if (view.hash() == comm.first) {
break;
}
}
assert(view.hash() == comm.first && "Cannot find the machine view");
IndexSpace task_is = get_or_create_task_is(view);
Domain domain = runtime->get_index_space_domain(ctx, task_is);
ArgumentMap argmap;
int idx = 0;
for (Domain::DomainPointIterator it(domain); it; it++, idx++) {
argmap.set_point(*it,
TaskArgument(&comm.second[idx], sizeof(ncclComm_t)));
}
IndexLauncher index_launcher(NCCL_FINISH_COMMS_TASK_ID,
task_is,
TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
0 /*mapper_id*/,
comm.first);
FutureMap fm = runtime->execute_index_space(ctx, index_launcher);
fm.wait_all_results();
}
}
#endif

FFModel::~FFModel() {
// Destroy nccl communication groups
#ifdef FF_USE_NCCL
if (config.computationMode == COMP_MODE_TRAINING) {
Context ctx = config.lg_ctx;
Runtime *runtime = config.lg_hlr;
for (auto const &comm : view_hash_to_nccl_comms) {
// Find the machine view that has the hash
MachineView view;
for (size_t l = 0; l < operators.size(); l++) {
view = operators[l]->outputs[0]->machine_view;
if (view.hash() == comm.first) {
break;
}
}
assert(view.hash() == comm.first && "Cannot find the machine view");
IndexSpace task_is = get_or_create_task_is(view);
Domain domain = runtime->get_index_space_domain(ctx, task_is);
ArgumentMap argmap;
int idx = 0;
for (Domain::DomainPointIterator it(domain); it; it++, idx++) {
argmap.set_point(*it,
TaskArgument(&comm.second[idx], sizeof(ncclComm_t)));
}
IndexLauncher index_launcher(NCCL_FINISH_COMMS_TASK_ID,
task_is,
TaskArgument(nullptr, 0),
argmap,
Predicate::TRUE_PRED,
false /*must*/,
0 /*mapper_id*/,
comm.first);
FutureMap fm = runtime->execute_index_space(ctx, index_launcher);
fm.wait_all_results();
}
finish_nccl_comms();
}
#endif
}
Expand Down
3 changes: 3 additions & 0 deletions src/runtime/request_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2365,6 +2365,9 @@ void RequestManager::background_serving_task(
// Registered SSMs: perform speculative inference
rm->serve_spec_infer(llm);
}
#ifdef FF_USE_NCCL
llm->finish_nccl_comms();
#endif
}

/*static*/
Expand Down

0 comments on commit 28aff70

Please sign in to comment.