diff --git a/src/context.cc b/src/context.cc index f8bb3ec83..12b52b78a 100644 --- a/src/context.cc +++ b/src/context.cc @@ -11,7 +11,7 @@ namespace mscclpp { -Context::Impl::Impl() : ipcStream_(cudaStreamNonBlocking) {} +Context::Impl::Impl() {} IbCtx* Context::Impl::getIbContext(Transport ibTransport) { // Find IB context or create it @@ -43,7 +43,10 @@ MSCCLPP_API_CPP std::shared_ptr Context::connect(Endpoint localEndpo if (remoteEndpoint.transport() != Transport::CudaIpc) { throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage); } - conn = std::make_shared(localEndpoint, remoteEndpoint, pimpl_->ipcStream_); + if (!(pimpl_->ipcStream_)) { + pimpl_->ipcStream_ = std::make_shared(cudaStreamNonBlocking); + } + conn = std::make_shared(localEndpoint, remoteEndpoint, cudaStream_t(*(pimpl_->ipcStream_))); } else if (AllIBTransports.has(localEndpoint.transport())) { if (!AllIBTransports.has(remoteEndpoint.transport())) { throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage); diff --git a/src/include/context.hpp b/src/include/context.hpp index abb95b27d..e88c7e5fa 100644 --- a/src/include/context.hpp +++ b/src/include/context.hpp @@ -16,7 +16,7 @@ namespace mscclpp { struct Context::Impl { std::vector> connections_; std::unordered_map> ibContexts_; - CudaStreamWithFlags ipcStream_; + std::shared_ptr ipcStream_; CUmemGenericAllocationHandle mcHandle_; Impl();