Failed to lower a set between bDID and b. #3488 (Open)

wujingyue opened this issue Nov 27, 2024 · 6 comments

wujingyue (Collaborator):
(I ran into this issue incidentally but haven't tried to reduce the repros or identify the reasons.)

Symptoms

Below are two minimal repros. Both run the following definition but with different parallelizations. The first test shards y but not x or z, and the second test shards x and z but not y.

x: [D]
y: [1]
z = add(x, y): [D]
diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 1e1ff2ea..4c5ba235 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -413,4 +413,60 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) {

 INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool());

+TEST_F(MultiDeviceTest, AddWithBroadcast_BroadcastIsParallelized) {
+  const auto num_devices = communicator_->size();
+  const auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  TensorView* x = makeContigConcreteTensor({num_devices});
+  x->setDeviceMesh(mesh);
+  TensorView* y = makeContigConcreteTensor({1});
+  y->setDeviceMesh(mesh);
+  TensorView* z = add(x, y);
+
+  fusion->addInput(x);
+  fusion->addInput(y);
+  fusion->addOutput(z);
+
+  y->axis(0)->parallelize(ParallelType::DIDx);
+
+  std::vector<c10::IValue> in_tensors(
+      {at::randn({num_devices}, tensor_options),
+       at::randn({1}, tensor_options)});
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto out_tensors = executor_cache.runFusionWithInputs(in_tensors);
+  testValidate(
+      executor_cache.fusion(), out_tensors, in_tensors, __LINE__, __FILE__);
+}
+
+TEST_F(MultiDeviceTest, AddWithBroadcast_BroadcastIsNotParallelized) {
+  const auto num_devices = communicator_->size();
+  const auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  TensorView* x = makeContigConcreteTensor({num_devices});
+  x->setDeviceMesh(mesh);
+  TensorView* y = makeContigConcreteTensor({1});
+  y->setDeviceMesh(mesh);
+  TensorView* z = add(x, y);
+
+  fusion->addInput(x);
+  fusion->addInput(y);
+  fusion->addOutput(z);
+
+  x->axis(0)->parallelize(ParallelType::DIDx);
+  z->axis(0)->parallelize(ParallelType::DIDx);
+
+  std::vector<c10::IValue> in_tensors(
+      {at::randn({1}, tensor_options), at::randn({1}, tensor_options)});
+
+  FusionExecutorCache executor_cache(std::move(fusion));
+  executor_cache.runFusionWithInputs(in_tensors);
+}
+
 } // namespace nvfuser

Both tests fail to execute and throw errors like:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/multidevice/communication.cpp":72, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. all buffers must have the same number of elements

Reasons for failure

Currently, isResharding doesn't map bDID and b. As a result, after InsertReshardingsPass, a set with two different shardings is inserted between y and z. This set is lowered to either an Allgather or a Scatter, neither of which can execute: the Allgather tries to concatenate D input tensors of shape [1] into an output tensor of shape [1], and the Scatter tries to split an input tensor of shape [1] across D devices.
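
For concreteness, here is a minimal standalone C++ sketch, not NVFuser's actual check in communication.cpp, of the element-count mismatch described above, assuming D = 2 devices:

// Sketch only: models the buffer-size arithmetic behind the failing Allgather.
#include <cassert>
#include <cstddef>

int main() {
  const std::size_t D = 2;                 // assumed number of devices
  const std::size_t per_device_numel = 1;  // each device holds y as a [1] tensor
  const std::size_t output_numel = 1;      // the set's output is also shape [1]
  // Allgather semantics: the output must hold one input-sized chunk per device,
  // i.e. D * per_device_numel elements. Here 1 != D, which is exactly the
  // "all buffers must have the same number of elements" failure.
  assert(output_numel == D * per_device_numel);  // fires for any D > 1
  return 0;
}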

Failed attempts

My first reaction was to let isResharding ignore the DID on broadcast dimensions.

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index c1943fed..18c51e08 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -278,8 +278,12 @@ bool haveDifferentShardings(
         return true;
       }

-      if (a == nullptr || b == nullptr) {
-        return false;
+      if (a == nullptr) {
+        return b->isBroadcast();
+      }
+
+      if (b == nullptr) {
+        return a->isBroadcast();
       }

       // Going between bDIDx{1} and iDIDx{N} doesn't trigger resharding, but

This avoided the set, and therefore the communication. However, the first test then failed with a different error:

C++ exception with description " INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/multidevice/utils.cpp":144, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Found multiple loop IterDomains with the same parallel type (deviceIdx.x): bdeviceIdx.x22{1}, bdeviceIdx.x23{1}, bdeviceIdx.x21{128}

This is because the pointwise scheduler

  1. picked z as the reference TensorView, which isn't sharded,
  2. loop-split z on intra-GPU parallel types, and
  3. replayed the same splits on the inputs, which produced multiple DIDx dimensions in y's loop domain (see the sketch after this list).
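
Below is a standalone C++ sketch, not NVFuser code, of how replaying z's splits on y's bDIDx{1} axis yields three deviceIdx.x loop IterDomains. It assumes, as the transform dump later in this thread shows for T1_g_float, that every IterDomain produced by these splits keeps the deviceIdx.x parallel type.

// Sketch only: models the split replay on y's DIDx-parallelized broadcast axis.
#include <cstdio>

struct Dim {
  long extent;
  bool broadcast;
  const char* parallel;  // e.g. "deviceIdx.x"
};

// ceilDiv-style split: outer gets ceil(extent / factor), inner gets factor.
// In this sketch both outputs inherit the broadcast flag and parallel type.
static void split(const Dim& in, long factor, Dim& outer, Dim& inner) {
  outer = {(in.extent + factor - 1) / factor, in.broadcast, in.parallel};
  inner = {factor, in.broadcast, in.parallel};
}

int main() {
  Dim y_axis{1, true, "deviceIdx.x"};  // y's bDIDx{1} axis
  Dim o1, i1, o2, i2;
  split(y_axis, 128, o1, i1);  // replay of "split by 128" -> bDIDx{1}, bDIDx{128}
  split(o1, 1, o2, i2);        // replay of "split by 1"   -> bDIDx{1}, bDIDx{1}
  // Three loop IterDomains now share the deviceIdx.x parallel type, which is
  // what the "Found multiple loop IterDomains with the same parallel type"
  // assert complains about.
  std::printf("loop domain: b%s{%ld}, b%s{%ld}, b%s{%ld}\n",
              o2.parallel, o2.extent, i2.parallel, i2.extent,
              i1.parallel, i1.extent);
  return 0;
}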

Potential solutions

  1. Let schedulers support inconsistent DIDs on broadcast dimensions.
  2. Disallow b(DID) in favor of b.
  3. Fix the lowering of a set between b(DID) and b, e.g., lower it to an alias operation instead.
wujingyue (Collaborator, Author):

cc @naoyam, @cowanmeg and @samnordmann

naoyam (Collaborator) commented Nov 27, 2024:

  • picked z as the reference TensorView, which isn't sharded,
  • loop-split z on intra-GPU parallel types,
  • replayed the same splits on the inputs, which produced multiple DIDx dimensions in y's loop domain.

Can you show the actual scheduling results of these tensors?

wujingyue (Collaborator, Author):

Can you show the actual scheduling results of these tensors?

%kernel {
T3_l_float[ iblockIdx.x26{1}, iUS27{1}, ithreadIdx.x25{128} ] ca_pos( 2 ) (DeviceMesh{0 1})
   = Set( T0_g_float[ iS30{1}, iS31{1}, iS29{128} ] (DeviceMesh{0 1}), cache_op=Streaming )
T4_l_float[ bblockIdx.x18{1}, bUS19{1}, bthreadIdx.x17{128} ] (DeviceMesh{0 1})
   = Set( T1_g_float[ bdeviceIdx.x22{1}, bdeviceIdx.x23{1}, bdeviceIdx.x21{128} ] (DeviceMesh{0 1}), cache_op=AllLevels )
T5_l_float[ iblockIdx.x14{1}, iUS15{1}, ithreadIdx.x13{128} ] ca_pos( 3 ) produce_pos( 2 ) (DeviceMesh{0 1})
   = T3_l_float[ iblockIdx.x26{1}, iUS27{1}, ithreadIdx.x25{128} ] ca_pos( 2 ) (DeviceMesh{0 1})
   + T4_l_float[ bblockIdx.x18{1}, bUS19{1}, bthreadIdx.x17{128} ] (DeviceMesh{0 1});
T2_g_float[ iblockIdx.x10{1}, iUS11{1}, ithreadIdx.x9{128} ] ca_pos( 2 ) produce_pos( 3 ) (DeviceMesh{0 1})
   = Set( T5_l_float[ iblockIdx.x14{1}, iUS15{1}, ithreadIdx.x13{128} ] ca_pos( 3 ) produce_pos( 2 ) (DeviceMesh{0 1}), cache_op=Streaming )

TransformPrinter :
T0_g_float[ iS30{1}, iS31{1}, iS29{128} ] (DeviceMesh{0 1})
 logical domain : (iS0{2})
 contiguity: t
  Split: iS0{2} by factor 128 -> iS28{1}, iS29{128}
  Split: iS28{1} by factor 1 -> iS30{1}, iS31{1}
 loop domain : (iS30{1}, iS31{1}, iS29{128})
T3_l_float[ iblockIdx.x26{1}, iUS27{1}, ithreadIdx.x25{128} ] ca_pos( 2 ) (DeviceMesh{0 1})
 logical domain : (iS5{2})
 contiguity: t
  Split: iS5{2} by factor 128 -> iS24{1}, ithreadIdx.x25{128}
  Split: iS24{1} by factor 1 -> iblockIdx.x26{1}, iUS27{1}
 loop domain : (iblockIdx.x26{1}, iUS27{1}, ithreadIdx.x25{128})
T1_g_float[ bdeviceIdx.x22{1}, bdeviceIdx.x23{1}, bdeviceIdx.x21{128} ] (DeviceMesh{0 1})
 logical domain : (bdeviceIdx.x1{1})
 contiguity: n
  Split: bdeviceIdx.x1{1} by factor 128 -> bdeviceIdx.x20{1}, bdeviceIdx.x21{128}
  Split: bdeviceIdx.x20{1} by factor 1 -> bdeviceIdx.x22{1}, bdeviceIdx.x23{1}
 loop domain : (bdeviceIdx.x22{1}, bdeviceIdx.x23{1}, bdeviceIdx.x21{128})
T4_l_float[ bblockIdx.x18{1}, bUS19{1}, bthreadIdx.x17{128} ] (DeviceMesh{0 1})
 logical domain : (bdeviceIdx.x6{1})
 contiguity: n
  Split: bdeviceIdx.x6{1} by factor 128 -> bdeviceIdx.x16{1}, bthreadIdx.x17{128}
  Split: bdeviceIdx.x16{1} by factor 1 -> bblockIdx.x18{1}, bUS19{1}
 loop domain : (bblockIdx.x18{1}, bUS19{1}, bthreadIdx.x17{128})
T5_l_float[ iblockIdx.x14{1}, iUS15{1}, ithreadIdx.x13{128} ] ca_pos( 3 ) produce_pos( 2 ) (DeviceMesh{0 1})
 logical domain : (iS2{2})
 contiguity: t
  Split: iS2{2} by factor 128 -> iS12{1}, ithreadIdx.x13{128}
  Split: iS12{1} by factor 1 -> iblockIdx.x14{1}, iUS15{1}
 loop domain : (iblockIdx.x14{1}, iUS15{1}, ithreadIdx.x13{128})
T2_g_float[ iblockIdx.x10{1}, iUS11{1}, ithreadIdx.x9{128} ] ca_pos( 2 ) produce_pos( 3 ) (DeviceMesh{0 1})
 logical domain : (iS7{2})
 contiguity: t
  Split: iS7{2} by factor 128 -> iS8{1}, ithreadIdx.x9{128}
  Split: iS8{1} by factor 1 -> iblockIdx.x10{1}, iUS11{1}
 loop domain : (iblockIdx.x10{1}, iUS11{1}, ithreadIdx.x9{128})
} // %kernel

See T1_g_float, which corresponds to y.

naoyam (Collaborator) commented Dec 6, 2024:

Thanks. Does the second pattern work fine with the change to haveDifferentShardings?

For the first pattern, how is the parallelization interpreted? Since z is not sharded, is the add operation executed by all devices?

wujingyue (Collaborator, Author):

For the first pattern, how is the parallelization interpreted? Since z is not sharded, is the add operation executed by all devices?

That's right. I forgot where this pattern actually came up, probably in dropout, which is replicated until sequence parallelism is enabled.

wujingyue (Collaborator, Author):

Does the second pattern work fine with the change to haveDifferentShardings?

Yes. I guess this is because the pointwise scheduler picked z as the reference TV, which has a DID in it, so the schedule it proposes skips the DID.

wujingyue added a commit that referenced this issue Dec 7, 2024
a single-device ReduceScatter