Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add distributed backend (XCCL) #1105

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open

Add distributed backend (XCCL) #1105

wants to merge 21 commits into from

Conversation

Chao1Han
Copy link
Contributor

@Chao1Han Chao1Han commented Nov 20, 2024

Motivation:

As design illustrated in Intel distributed support RFC pytorch/pytorch#141741, Intel GPU distributed Backend integration in PyTorch torch-xpu-ops.

Design:

USE_XCCL is set to ON by default. Users can manually set it to OFF to disable XCCL compilation. The OneCCL path is first searched in /opt/intel/oneapi/ccl/latest. If not found, it uses the CCL_ROOT flag set by the user after sourcing OneCCL. The USE_C10D_XCCL variable is intended to align with other distributed backend environment variables.
Oneccl lib link to torch_xpu align with other distribute backend.


include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)

set(XCCL_ROOT $ENV{CCL_ROOT})

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How do you get CCL_ROOT? I think you cannot assume it will be set after oneccl source.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It will auto set after source oneccl env, and I remember oneccl update not affect this flag.

"Not able to create/get "
"XCCL Communicator since the devices are empty ");
{
// todo: why do we need mutex here?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we followed the same code logic as NCCL, right?

m.impl("recv_any_source_", recv_any_source_XPU);
m.impl("reduce_", reduce_XPU);
m.impl("broadcast_", broadcast_XPU);
m.impl("allreduce_", allreduce_XPU);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, we only have allreduce implemented. Then should we only register allreduce here?

@Chao1Han Chao1Han changed the title [WIP] Add distributed backend (XCCL) Add distributed backend (XCCL) Dec 13, 2024
bool is_reduction_op = false) {
TORCH_CHECK(
!isFloat8Type(type) && is_reduction_op,
"Float8 dtypes are not currenlty supported for XCCL reductions");

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For non-reduction collective, please add mapping from FP8 to ccl data type.

{at::kDouble, ccl::datatype::float64},
{at::kBFloat16, ccl::datatype::bfloat16},
{at::kBool, ccl::datatype::uint8},
// use for allgather

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please refine the description.

CMakeLists.txt Outdated Show resolved Hide resolved
CMakeLists.txt Outdated Show resolved Hide resolved
@gujinghui
Copy link
Contributor

LGTM

@zhangxiaoli73
Copy link

@EikanWang Could you please help review this PR?

Copy link
Contributor

@EikanWang EikanWang left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we not reuse PyTorch test cases?

CMakeLists.txt Show resolved Hide resolved

include(${CMAKE_ROOT}/Modules/FindPackageHandleStandardArgs.cmake)

set(XCCL_ROOT $ENV{ONEAPI_ROOT}/ccl/latest)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it mean the ONEAPI_ROOT is a must-to-have environment variable?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, we only require oneCCL source in building, so CCL_ROOT is a must-to-have environment variable. Updated in latest.

@Chao1Han
Copy link
Contributor Author

Why do we not reuse PyTorch test cases?

This pr just implement allreduce, so add simple test case. In the long term, once all operations are implemented, we will have one or two test files to validate the basic operations. Other tests, such as FSDP and DTensor, will directly reuse PyTorch's unit tests.

@EikanWang
Copy link
Contributor

Why do we not reuse PyTorch test cases?

This pr just implement allreduce, so add simple test case. In the long term, once all operations are implemented, we will have one or two test files to validate the basic operations. Other tests, such as FSDP and DTensor, will directly reuse PyTorch's unit tests.

Then, I suggest reusing PyTorch test cases by disabling some test cases that are not applicable. Pls. check with Daisy.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants