-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
internal review #6
base: main
Are you sure you want to change the base?
Conversation
auto currentTimepoint = std::chrono::steady_clock::now(); | ||
auto timeElapsed = std::chrono::duration_cast<std::chrono::milliseconds>( | ||
currentTimepoint - workStartTime_); | ||
std::chrono::milliseconds opTimeout = std::chrono::milliseconds(60000); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
where do you use it?
@@ -67,17 +67,15 @@ ccl::reduction getXcclReduceOp(const ReduceOp& reduceOp, at::Tensor& input) { | |||
return xcclOps.at(reduceOp); | |||
} catch (const std::out_of_range&) { | |||
switch (reduceOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No need to switch
: Work(rank, opType, "profilingTitle", inputs), | ||
device_(device), | ||
workStartTime_(std::chrono::steady_clock::now()) { | ||
unsigned char enable_timing = 0; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you always set it as 0, then we don't need to keep it, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, Defining this variable serves as a form of annotation, informing reviewers and users that 0 represents the state of enable_timing
, which is meaningful.
"Work ran for ", | ||
timeElapsed.count(), | ||
" milliseconds before timing out."); | ||
TORCH_CHECK(false, exceptionMsg) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TORCH_CHECK(false, exceptionMsg);
abort();
} // namespace | ||
|
||
static std::mutex xcclCommDevIdxMapMutex; | ||
static std::unordered_map<std::shared_ptr<xcclComm_t>, int> xcclCommDevIdxMap; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Those static variables are not used in your code. Please check.
blockingWait_ = getCvarBool(TORCH_XCCL_BLOCKING_WAIT, false); | ||
init(); | ||
|
||
// Intel oneCCL requires passing CCL_LOCAL_RANK and CCL_LOCAL_SIZE for non-MPI |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
More comment for why we use LOCAL_RANK
and LOCAL_WORLD_SIZE
.
std::shared_ptr<xcclComm_t> ProcessGroupXCCL::getXCCLComm( | ||
const std::string& deviceKey, | ||
at::Device& device) { | ||
if (deviceKey.empty()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
C10_THROW_ERROR_WITH
devXCCLCommMap_.emplace(deviceKey, XCCLComm); | ||
} | ||
|
||
xcclStreamsMap_.emplace(deviceKey, std::move(stream)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so xcclEventsMap does not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
restore it
PreProcess pre, | ||
PostProcess post, | ||
OpType opType) { | ||
using traits = function_traits<Fn>; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
which collective need attribute as a must?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, allgather meet build error
for (const auto i : c10::irange(inputs.size())) { | ||
c10::xpu::XPUCachingAllocator::recordStream( | ||
inputs[i].storage().data_ptr(), stream); | ||
fn(inputs[i], outputs[i], attr, *comm, stream); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add comment for output record stream.
false, "ProcessGroupXCCL::WorkXCCL::isSuccess not implemented"); | ||
} | ||
|
||
void abort() override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
abort here?
|
||
bool isCompleted() override; | ||
|
||
bool isSuccess() const override { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove
case ReduceOp::BXOR: | ||
C10_THROW_ERROR(ValueError, "Cannot use ReduceOp.BXOR with NCCL"); | ||
C10_THROW_ERROR( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't change NCCL now.
|
||
c10::impl::VirtualGuardImpl impl(device.type()); | ||
c10::Stream stream = impl.getStream(device); | ||
sycl::queue& q = c10::xpu::XPUStream(stream).queue(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a big bug to use current stream as communication stream.
int rank, | ||
OpType opType, | ||
const std::optional<std::vector<at::Tensor>>& inputs) | ||
: Work(rank, opType, "profilingTitle", inputs), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need change
@@ -126,6 +131,13 @@ class TORCH_API ProcessGroup : public torch::CustomClassHolder { | |||
return backendType_; | |||
}; | |||
|
|||
inline bool backendSupportsSequenceNumbers(BackendType backendType) { | |||
if (backendType == BackendType::GLOO || backendType == BackendType::NCCL || | |||
backendType == BackendType::XCCL || backendType == BackendType::UCC) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do you make sure that we need to support this sequence number?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sequence number used by RECORD_PARAM_COMMS. so we need it
@@ -180,7 +181,8 @@ def skip_if_lt_x_gpu(x): | |||
def decorator(func): | |||
@wraps(func) | |||
def wrapper(*args, **kwargs): | |||
if torch.cuda.is_available() and torch.cuda.device_count() >= x: | |||
if (torch.cuda.is_available() and torch.cuda.device_count() >= x) or \ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't use if
for accelerator related check
See pytorch#140725 (comment) Running `torch.mps.synchronize()` after metal kernel resulted in infinite wait inside `[_MTLCommandBuffer waitUntilCompleted]` ``` (lldb) bt * thread #1, queue = 'com.apple.main-thread', stop reason = signal SIGSTOP * frame #0: 0x00000001aa919084 Metal`pthread_cond_wait + 12 frame #1: 0x00000001aa78b1b4 Metal`-[_MTLCommandBuffer waitUntilCompleted] + 84 frame #2: 0x00000001032bf358 libtorch_python.dylib`torch::mps::MPSModule_deviceSynchronize(_object*, _object*) + 40 frame #3: 0x0000000100e94c20 Python`cfunction_vectorcall_NOARGS + 100 frame #4: 0x0000000100e389b8 Python`PyObject_Vectorcall + 92 frame #5: 0x0000000100f61e38 Python`_PyEval_EvalFrameDefault + 19040 frame #6: 0x0000000100f5d180 Python`PyEval_EvalCode + 200 frame #7: 0x0000000100fcd1a4 Python`run_eval_code_obj + 104 frame #8: 0x0000000100fccbe4 Python`run_mod + 168 frame #9: 0x0000000100fcb518 Python`pyrun_file + 164 frame #10: 0x0000000100fca854 Python`_PyRun_SimpleFileObject + 256 frame pytorch#11: 0x0000000100fca4e8 Python`_PyRun_AnyFileObject + 80 frame pytorch#12: 0x0000000100ff2028 Python`pymain_run_file_obj + 164 frame pytorch#13: 0x0000000100ff1ce4 Python`pymain_run_file + 72 frame pytorch#14: 0x0000000100ff0f74 Python`Py_RunMain + 988 frame pytorch#15: 0x0000000100ff1564 Python`pymain_main + 304 frame pytorch#16: 0x0000000100ff1604 Python`Py_BytesMain + 40 frame pytorch#17: 0x000000019f630274 dyld`start + 2840 ``` Pull Request resolved: pytorch#141296 Approved by: https://github.com/huydhn
Fixes #ISSUE_NUMBER