diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index fbef4831e96..3ffd9e3d24d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -326,7 +326,7 @@ c10::intrusive_ptr postAllgather( at::Tensor input_tensor, at::Tensor output_tensor) { auto splits = - at::tensor_split(output_tensor, communication->team().size(), /*dim=*/0); + at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c104b36d3..8631a1a04e5 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -90,6 +90,11 @@ class Communication : public Expr { return attribute(1); } + // A convenience helper so the user doesn't need to convert size_t to int64_t. + int64_t team_size() const { + return static_cast(team().size()); + } + DeviceIdxType root() const { return attribute(2); }