diff --git a/include/flexflow/model.h b/include/flexflow/model.h index ea64f65a95..6dda67bbfe 100644 --- a/include/flexflow/model.h +++ b/include/flexflow/model.h @@ -1079,6 +1079,7 @@ class FFModel { bool use_propagation) const; #ifdef FF_USE_NCCL ncclComm_t *find_nccl_comms(MachineView const &view) const; + void finish_nccl_comms(); #endif #ifdef FF_USE_PROPAGATE void propagate(std::map const ¤t, diff --git a/src/runtime/model.cc b/src/runtime/model.cc index f1e222e6e3..4c67de1aa9 100644 --- a/src/runtime/model.cc +++ b/src/runtime/model.cc @@ -1589,41 +1589,47 @@ FFModel::FFModel(FFConfig &_config, bool cpu_offload) model_id = model_counter++; } +#ifdef FF_USE_NCCL +void FFModel::finish_nccl_comms() { + Context ctx = config.lg_ctx; + Runtime *runtime = config.lg_hlr; + for (auto const &comm : view_hash_to_nccl_comms) { + // Find the machine view that has the hash + MachineView view; + for (size_t l = 0; l < operators.size(); l++) { + view = operators[l]->outputs[0]->machine_view; + if (view.hash() == comm.first) { + break; + } + } + assert(view.hash() == comm.first && "Cannot find the machine view"); + IndexSpace task_is = get_or_create_task_is(view); + Domain domain = runtime->get_index_space_domain(ctx, task_is); + ArgumentMap argmap; + int idx = 0; + for (Domain::DomainPointIterator it(domain); it; it++, idx++) { + argmap.set_point(*it, + TaskArgument(&comm.second[idx], sizeof(ncclComm_t))); + } + IndexLauncher index_launcher(NCCL_FINISH_COMMS_TASK_ID, + task_is, + TaskArgument(nullptr, 0), + argmap, + Predicate::TRUE_PRED, + false /*must*/, + 0 /*mapper_id*/, + comm.first); + FutureMap fm = runtime->execute_index_space(ctx, index_launcher); + fm.wait_all_results(); + } +} +#endif + FFModel::~FFModel() { // Destroy nccl communication groups #ifdef FF_USE_NCCL if (config.computationMode == COMP_MODE_TRAINING) { - Context ctx = config.lg_ctx; - Runtime *runtime = config.lg_hlr; - for (auto const &comm : view_hash_to_nccl_comms) { - // Find the machine view that has the hash - MachineView view; - for (size_t l = 0; l < operators.size(); l++) { - view = operators[l]->outputs[0]->machine_view; - if (view.hash() == comm.first) { - break; - } - } - assert(view.hash() == comm.first && "Cannot find the machine view"); - IndexSpace task_is = get_or_create_task_is(view); - Domain domain = runtime->get_index_space_domain(ctx, task_is); - ArgumentMap argmap; - int idx = 0; - for (Domain::DomainPointIterator it(domain); it; it++, idx++) { - argmap.set_point(*it, - TaskArgument(&comm.second[idx], sizeof(ncclComm_t))); - } - IndexLauncher index_launcher(NCCL_FINISH_COMMS_TASK_ID, - task_is, - TaskArgument(nullptr, 0), - argmap, - Predicate::TRUE_PRED, - false /*must*/, - 0 /*mapper_id*/, - comm.first); - FutureMap fm = runtime->execute_index_space(ctx, index_launcher); - fm.wait_all_results(); - } + finish_nccl_comms(); } #endif } diff --git a/src/runtime/request_manager.cc b/src/runtime/request_manager.cc index d21285eef2..bada87ab19 100644 --- a/src/runtime/request_manager.cc +++ b/src/runtime/request_manager.cc @@ -2365,6 +2365,9 @@ void RequestManager::background_serving_task( // Registered SSMs: perform speculative inference rm->serve_spec_infer(llm); } +#ifdef FF_USE_NCCL + llm->finish_nccl_comms(); +#endif } /*static*/