From ef70987fb1e39a6a319eb7f8c0ecff6a61882404 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Mon, 18 Nov 2024 11:15:02 -0800
Subject: [PATCH] Add a repro for #3282

---
 tests/cpp/test_multidevice_sharding.cpp | 43 +++++++++++++++++++++++++
 1 file changed, 43 insertions(+)

diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 1e1ff2eab9e..873cbd3e8ca 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -340,6 +340,49 @@ TEST_F(MultiDeviceTest, Transpose) {
       UnorderedElementsAre(HeuristicIs(SchedulerType::Transpose)));
 }
 
+// Repro for #3282: copies (`set`) a tensor whose axis 0 is split by the
+// device count on both input and output, with the new axis 0 on DIDx.
+TEST_F(MultiDeviceTest, ParallelizeLoopSplit) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto num_devices = communicator_->size();
+  auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  // Single-axis input whose extent is `num_devices * 3`.
+  TensorView* in = makeContigConcreteTensor({num_devices * 3});
+  in->setDeviceMesh(mesh);
+  fusion->addInput(in);
+  TensorView* out = set(in);
+  fusion->addOutput(out);
+
+  // Schedule both tensors identically: split axis 0 by the device count
+  // (inner_split=false), parallelize the resulting axis 0 on DIDx, and use
+  // the loop domain as the allocation domain.
+  for (auto* tv : {in, out}) {
+    tv->split(0, num_devices, /*inner_split=*/false);
+    tv->axis(0)->parallelize(ParallelType::DIDx);
+    tv->setAllocationDomain(tv->getLoopDomain(), true);
+  }
+
+  // The runtime input has 3 elements, i.e. `num_devices * 3 / num_devices`.
+  // NOTE(review): presumably this is the per-device shard -- confirm.
+  at::Tensor in_tensor = at::randn({3}, tensor_options);
+  FusionExecutorCache executor_cache(std::move(fusion));
+  at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];
+
+  // `in_tensor` is passed twice below. NOTE(review): presumably once as the
+  // fusion input and once as the expected output of the copy -- confirm
+  // against testValidate's signature.
+  testValidate(
+      executor_cache.fusion(),
+      {out_tensor},
+      {in_tensor},
+      {in_tensor},
+      __LINE__,
+      __FILE__);
+}
+
 class MultiDeviceBroadcastTest : public MultiDeviceTest,
                                  public testing::WithParamInterface {};
 