From a559b5c6701a1a05fa60586d2174e5e0a2e5e894 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 14 May 2024 11:17:49 -0700 Subject: [PATCH 1/2] fix bug in scatter --- csrc/multidevice/communication.cpp | 2 +- tests/cpp/test_multidevice_pipeline.cpp | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index b7ffc3d0bfd..87dfca9d063 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -294,7 +294,7 @@ c10::intrusive_ptr postScatter( input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1)); + input_tensors.front().push_back(input_tensor.slice(0, j, j + 1).contiguous()); j++; } diff --git a/tests/cpp/test_multidevice_pipeline.cpp b/tests/cpp/test_multidevice_pipeline.cpp index 6c114448744..e11e491a889 100644 --- a/tests/cpp/test_multidevice_pipeline.cpp +++ b/tests/cpp/test_multidevice_pipeline.cpp @@ -217,8 +217,9 @@ DeviceMesh mesh1({1}); DeviceMesh mesh2({0, 1, 2, 3}); DeviceMesh mesh3({0, 2, 3}); DeviceMesh mesh4({1, 0, 2}); -auto all_meshes = testing::Values(mesh0, mesh1, mesh2, mesh3, mesh4); -auto all_nontrivial_meshes = testing::Values(mesh2, mesh3, mesh4); +DeviceMesh mesh5({1, 0}); +auto all_meshes = testing::Values(mesh0, mesh1, mesh2, mesh3, mesh4, mesh5); +auto all_nontrivial_meshes = testing::Values(mesh2, mesh3, mesh4, mesh5); } // namespace From bc50c2b4ddbdb6737a7f018a000717a116b8b0c0 Mon Sep 17 00:00:00 2001 From: snordmann Date: Tue, 14 May 2024 11:49:41 -0700 Subject: [PATCH 2/2] lintrunner --- csrc/multidevice/communication.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 87dfca9d063..4ad1cab4818 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -294,7 +294,8 @@ c10::intrusive_ptr postScatter( input_tensors.front().push_back(output_tensor); continue; } - input_tensors.front().push_back(input_tensor.slice(0, j, j + 1).contiguous()); + input_tensors.front().push_back( + input_tensor.slice(0, j, j + 1).contiguous()); j++; }