Skip to content

Commit

Permalink
Add a repro for #3282
Browse files Browse the repository at this point in the history
  • Loading branch information
wujingyue committed Nov 21, 2024
1 parent cf126e6 commit ef70987
Showing 1 changed file with 32 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/cpp/test_multidevice_sharding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,38 @@ TEST_F(MultiDeviceTest, Transpose) {
UnorderedElementsAre(HeuristicIs(SchedulerType::Transpose)));
}

TEST_F(MultiDeviceTest, ParallelizeLoopSplit) {
auto fusion = std::make_unique<Fusion>();
FusionGuard fg(fusion.get());

const auto num_devices = communicator_->size();
auto mesh = DeviceMesh::createForNumDevices(num_devices);

TensorView* in = makeContigConcreteTensor({num_devices * 3});
in->setDeviceMesh(mesh);
fusion->addInput(in);
TensorView* out = set(in);
fusion->addOutput(out);

for (auto* tv : {in, out}) {
tv->split(0, num_devices, /*inner_split=*/false);
tv->axis(0)->parallelize(ParallelType::DIDx);
tv->setAllocationDomain(tv->getLoopDomain(), true);
}

at::Tensor in_tensor = at::randn({3}, tensor_options);
FusionExecutorCache executor_cache(std::move(fusion));
at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0];

testValidate(
executor_cache.fusion(),
{out_tensor},
{in_tensor},
{in_tensor},
__LINE__,
__FILE__);
}

class MultiDeviceBroadcastTest : public MultiDeviceTest,
public testing::WithParamInterface<bool> {};

Expand Down

0 comments on commit ef70987

Please sign in to comment.