From 04e06a80b7a6341c5037a0b294ccb11594fe7242 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 18 Nov 2024 11:15:02 -0800 Subject: [PATCH 01/20] Add a repro for #3282 --- tests/cpp/test_multidevice_sharding.cpp | 32 +++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 1e1ff2eab9e..873cbd3e8ca 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -340,6 +340,38 @@ TEST_F(MultiDeviceTest, Transpose) { UnorderedElementsAre(HeuristicIs(SchedulerType::Transpose))); } +TEST_F(MultiDeviceTest, ParallelizeLoopSplit) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigConcreteTensor({num_devices * 3}); + in->setDeviceMesh(mesh); + fusion->addInput(in); + TensorView* out = set(in); + fusion->addOutput(out); + + for (auto* tv : {in, out}) { + tv->split(0, num_devices, /*inner_split=*/false); + tv->axis(0)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + + at::Tensor in_tensor = at::randn({3}, tensor_options); + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; + + testValidate( + executor_cache.fusion(), + {out_tensor}, + {in_tensor}, + {in_tensor}, + __LINE__, + __FILE__); +} + class MultiDeviceBroadcastTest : public MultiDeviceTest, public testing::WithParamInterface {}; From 416f1d0df4c239f92d848a7d8b63c46539d70788 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 18 Nov 2024 11:58:07 -0800 Subject: [PATCH 02/20] Remove an assumption in the transpose scheduler. --- csrc/scheduler/transpose.cpp | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index e5023f4e25c..553ba9d773e 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -227,34 +227,31 @@ class DomainMap : public pointwise_utils::DomainMap { root_dim, " in tensor ", tv); - auto replay_exprs = StmtSort::getExprsBetween( + std::vector replay_exprs = StmtSort::getExprsBetween( {mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); // Project the root id to loop id. Similar to projectIdToRFactor. - for (auto expr : replay_exprs) { - if (expr->isA()) { - // Split with factor one is not supposed to be here, reshape would map - // this to a broadcast. This is a conservative assert, we can relaxed it - // and support with mapping it to outer. - NVF_ERROR( - !expr->as()->factor()->isOneInt(), - "split with factor one is supposed to be translated to broadcast by reshape"); - if (expr->as()->in() == mapped_id) { - mapped_id = expr->as()->inner(); + for (auto* expr : replay_exprs) { + if (auto* split = dynamic_cast(expr)) { + if (split->in() == mapped_id) { + mapped_id = split->inner(); } - } else if (expr->isA()) { + } else if (auto* merge = dynamic_cast(expr)) { // Merge with size-1 dimension is not supposed to be here, reshape would // map this to a squeeze. This is a conservative assert, we can relaxed // it and support with mapping it to out. 
NVF_ERROR( - !expr->as()->inner()->extent()->isOneInt(), + !merge->inner()->extent()->isOneInt(), "merge with size-1 dimension is supposed to be translated to squeeze by reshape"); - if (expr->as()->inner() == mapped_id) { - mapped_id = expr->as()->out(); + if (merge->inner() == mapped_id) { + mapped_id = merge->out(); + } + } else if (auto* resize = dynamic_cast(expr)) { + if (resize->in() == mapped_id) { + mapped_id = resize->out(); } - } else if (expr->isA() && expr->as()->in() == mapped_id) { - mapped_id = expr->as()->out(); } } + // Find the position of the loop id const auto& dom = tv->getLoopDomain(); for (auto i : c10::irange(dom.size())) { From 2c984c8dd9918726255c4062b0f2170033f18722 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 18 Nov 2024 17:37:50 -0800 Subject: [PATCH 03/20] Reimplement unshardSizesAndStrides. --- csrc/expr_evaluator.cpp | 91 ++++++++++++++++++++++--------------- csrc/fusion_segmenter.cpp | 4 +- csrc/multidevice/utils.cpp | 5 ++- csrc/tensor_metadata.cpp | 92 ++++++++++++++++++++++++-------------- 4 files changed, 118 insertions(+), 74 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index d4ca6daa022..567447871e4 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -6,6 +6,9 @@ */ // clang-format on +#include +#include + #include #include #include @@ -14,11 +17,9 @@ #include #include #include +#include #include -#include -#include - namespace nvfuser { namespace { @@ -127,6 +128,44 @@ void validateValWithConcreteValue( } } +std::vector unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector unsharded_sizes = sizes.vec(); + + for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { + const ParallelType parallel_type = alloc_id->getParallelType(); + if (!isParallelTypeDeviceDim(parallel_type)) { + continue; + } + + const auto inputs = IterVisitor::getInputsTo( + {alloc_id}, + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); + if (inputs.empty()) { + // FIXME: is this even possible? Logical ought to dominate allocation. + continue; + } + NVF_ERROR(inputs.size() == 1); + + const auto iter = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + inputs[0]); + if (iter == tv->getLogicalDomain().end()) { + // FIXME: is this even possible? Logical ought to dominate allocation. 
+ continue; + } + const auto index = std::count_if( + tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { + return !id->isReduction(); + }); + unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + } + + return unsharded_sizes; +} + } // namespace void ExpressionEvaluator::bindTensorDomain( @@ -143,6 +182,14 @@ void ExpressionEvaluator::bindTensorDomain( logical_domain.size(), ", but got a tensor of rank ", t.dim()); + + std::vector sizes; + if (isSharded(tv)) { + sizes = unshardedSizes(tv, t.sizes()); + } else { + sizes = t.sizes().vec(); + } + for (auto i : c10::irange(t.dim())) { auto id = logical_domain[i]; if (id->isBroadcast()) { @@ -151,7 +198,7 @@ void ExpressionEvaluator::bindTensorDomain( if (id->hasExpandedExtent()) { // Verify that t is also expanded NVF_ERROR( - t.size(i) == 1 || t.stride(i) == 0, + sizes[i] == 1 || t.stride(i) == 0, "IterDomain ", id->toString(), " in ", @@ -159,45 +206,15 @@ void ExpressionEvaluator::bindTensorDomain( "TensorView ", tv->toString(), " has expanded extent but input tensor has size ", - t.size(i), + sizes[i], " and stride ", t.stride(i), " in dimension ", i); - bind_( - logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate); + bind_(logical_domain[i]->expandedExtent(), sizes[i], evaluate_validate); } } else { - if (logical_domain[i]->isDeviceDim()) { - // Currently we have the restrictions: - // (1) Devices parallelized axis extent == DeviceMesh's extent - // (2) Device parallelized axis cannot be split or merged - // Therefore, the device parallelized extents will always be allocated - // with size 1, but the symbolic axis extent is binded with the extent - // of the DeviceMesh - NVF_CHECK( - 1 == t.size(i), - "TensorView ", - tv->toString(), - getInputPosString(tv), - " IterDomain ", - id->toString(), - "is sharded and must have size 1, but input tensor has size ", - t.size(i)); - NVF_CHECK( - tv->hasDeviceMesh(), - "TV ", - tv->toString(), - getInputPosString(tv), - " has an empty DeviceMesh with DID parallelization") - bind_( - logical_domain[i]->extent(), - static_cast( - tv->getDeviceMesh().size(logical_domain[i]->getParallelType())), - evaluate_validate); - } else { - bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate); - } + bind_(logical_domain[i]->extent(), sizes[i], evaluate_validate); } } } diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 506a2e81987..fa8b50d6fa2 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1883,7 +1883,9 @@ void eraseInputDistinctRootDomains(Fusion* fusion) { std::vector new_alloc; new_alloc.reserve(tv->getAllocationDomain().size()); for (IterDomain* alloc_id : tv->getAllocationDomain()) { - new_alloc.push_back(replay.getReplay().at(alloc_id)); + IterDomain* new_alloc_id = replay.getReplay().at(alloc_id); + new_alloc_id->parallelize(alloc_id->getParallelType()); + new_alloc.push_back(new_alloc_id); } std::vector new_loop; diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 1c40ffc5c2b..65cb76b0da7 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -106,8 +106,9 @@ std::pair, std::vector> getShardingChanges bool isSharded(const TensorView* tv) { bool is_sharded = false; - for (IterDomain* id : TensorDomain::noReductions(tv->getLoopDomain())) { - if (!id->isDeviceDim()) { + for (IterDomain* alloc_id : + TensorDomain::noReductions(tv->getMaybeAllocationDomain())) { + if (!alloc_id->isDeviceDim()) { continue; } diff --git a/csrc/tensor_metadata.cpp 
b/csrc/tensor_metadata.cpp index 32fdee2de42..e35d235b709 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -272,6 +272,46 @@ void validateAllocationSizesAndStrides( } } +// FIXME: strides are never changed +std::pair, std::vector> unshardedSizesAndStrides( + TensorView* tv, + c10::IntArrayRef sizes, + c10::IntArrayRef strides) { + std::vector unsharded_sizes = sizes.vec(); + std::vector unsharded_strides = strides.vec(); + + for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { + const ParallelType parallel_type = alloc_id->getParallelType(); + if (!isParallelTypeDeviceDim(parallel_type)) { + continue; + } + + const auto inputs = IterVisitor::getInputsTo( + {alloc_id}, + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); + if (inputs.empty()) { + // FIXME: is this even possible? Logical ought to dominate loop. + continue; + } + NVF_ERROR(inputs.size() == 1); + + const auto iter = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + inputs[0]); + if (iter == tv->getLogicalDomain().end()) { + // FIXME: is this even possible? Logical ought to dominate loop. + continue; + } + const auto index = std::count_if( + tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { + return !id->isReduction(); + }); + unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + } + + return {unsharded_sizes, unsharded_strides}; +} } // namespace std::pair, std::vector> @@ -282,11 +322,21 @@ inferAndValidateAllocationSizesAndStrides( const auto& logical = tv->getLogicalDomain(); const auto& alloc = tv->getMaybeAllocationDomain(); + std::vector logical_sizes; + std::vector logical_strides; + if (isSharded(tv)) { + std::tie(logical_sizes, logical_strides) = + unshardedSizesAndStrides(tv, tensor.sizes(), tensor.strides()); + } else { + logical_sizes = tensor.sizes().vec(); + logical_strides = tensor.strides().vec(); + } + // active IDs and their shape and stride std::unordered_map> active_ids; int64_t dim_index = 0; for (IterDomain* id : TensorDomain::noReductions(logical)) { - active_ids[id] = {tensor.size(dim_index), tensor.stride(dim_index)}; + active_ids[id] = {logical_sizes[dim_index], logical_strides[dim_index]}; dim_index++; } NVF_ERROR(dim_index == tensor.dim()); @@ -296,50 +346,24 @@ inferAndValidateAllocationSizesAndStrides( // Now active_ids should contain the final sizes and strides, unordered. We // need to put them to the correct order. - std::vector sizes; - std::vector strides; - sizes.reserve(alloc.size()); - strides.reserve(alloc.size()); + std::vector allocation_sizes; + std::vector allocation_strides; for (IterDomain* id : TensorDomain::noReductions(alloc)) { if (id->isDeviceDim()) { - sizes.push_back(1); + allocation_sizes.push_back(1); } else { - sizes.push_back(active_ids.at(id).first); + allocation_sizes.push_back(active_ids.at(id).first); } - strides.push_back(active_ids.at(id).second); + allocation_strides.push_back(active_ids.at(id).second); } // Only validate final sizes and strides when we have a non-empty tensor. 
if (tensor.numel() != 0) { validateAllocationSizesAndStrides( - alloc, tv->getContiguity(), sizes, strides); - } - return {std::move(sizes), std::move(strides)}; -} - -namespace { -std::pair, std::vector> unshardedSizesAndStrides( - TensorView* tv, - c10::IntArrayRef sizes, - c10::IntArrayRef strides) { - std::vector unsharded_sizes(sizes.size()); - std::vector unsharded_strides(strides.size()); - for (const auto i : c10::irange(sizes.size())) { - IterDomain* id = tv->getLogicalDomain()[i]; - if (id->isDeviceDim()) { - unsharded_sizes[i] = tv->getDeviceMesh().size(id->getParallelType()); - // This probably doesn't matter in practice unless a kernel accidentally - // tries to access the data on another rank. To be safe, set the stride - // to zero, analogous to an expanded broadcast dimension. - unsharded_strides[i] = 0; - } else { - unsharded_sizes[i] = sizes[i]; - unsharded_strides[i] = strides[i]; - } + alloc, tv->getContiguity(), allocation_sizes, allocation_strides); } - return {unsharded_sizes, unsharded_strides}; + return {std::move(allocation_sizes), std::move(allocation_strides)}; } -} // namespace std::vector GetMetaData::evaluate( const ExpressionEvaluator& ee, From 44b50912bffb6280aab21738f9b45a1d1fd23ae1 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Nov 2024 14:29:06 -0800 Subject: [PATCH 04/20] Inherit parallel type for new allocation IDs --- csrc/fusion_segmenter.cpp | 1 + csrc/transform_replay.cpp | 9 ++++++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index fa8b50d6fa2..4d7cb4e693b 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -1884,6 +1884,7 @@ void eraseInputDistinctRootDomains(Fusion* fusion) { new_alloc.reserve(tv->getAllocationDomain().size()); for (IterDomain* alloc_id : tv->getAllocationDomain()) { IterDomain* new_alloc_id = replay.getReplay().at(alloc_id); + // FIXME: should this be taken care of by ReplayTransformations? new_alloc_id->parallelize(alloc_id->getParallelType()); new_alloc.push_back(new_alloc_id); } diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp index 06e15929aa9..0ac2a9e97a4 100644 --- a/csrc/transform_replay.cpp +++ b/csrc/transform_replay.cpp @@ -770,11 +770,14 @@ std::pair TransformReplay::replayCasP( new_contiguity.reserve(producer_rank); for (auto i : c10::irange(producer_rank)) { - IterDomain* id = producer->getAllocationDomain()[i]; + IterDomain* alloc_id = producer->getAllocationDomain()[i]; // We won't find reduction IterDomains in the map. See // AllocationDomainTest.CacheBefore. - if (auto it = p2c_map.find(id); it != p2c_map.end()) { - new_allocation_domain.push_back(it->second); + if (auto it = p2c_map.find(alloc_id); it != p2c_map.end()) { + IterDomain* new_alloc_id = it->second; + // FIXME: should this be taken care of by ReplayTransformations? 
+ new_alloc_id->parallelize(alloc_id->getParallelType()); + new_allocation_domain.push_back(new_alloc_id); new_contiguity.push_back(producer->getContiguity()[i]); } } From 521c783a6e1a7d7ad5c0dc568ff6e2af79924501 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 21 Nov 2024 15:07:07 -0800 Subject: [PATCH 05/20] Fix broadcast tests --- csrc/expr_evaluator.cpp | 7 +++---- tests/cpp/test_multidevice_sharding.cpp | 18 +++++++++++++----- 2 files changed, 16 insertions(+), 9 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 567447871e4..7e04c3bcd4d 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -193,8 +193,7 @@ void ExpressionEvaluator::bindTensorDomain( for (auto i : c10::irange(t.dim())) { auto id = logical_domain[i]; if (id->isBroadcast()) { - // DIDs are ignored for broadcast. - bind_(logical_domain[i]->extent(), 1, evaluate_validate); + bind_(id->extent(), 1, evaluate_validate); if (id->hasExpandedExtent()) { // Verify that t is also expanded NVF_ERROR( @@ -211,10 +210,10 @@ void ExpressionEvaluator::bindTensorDomain( t.stride(i), " in dimension ", i); - bind_(logical_domain[i]->expandedExtent(), sizes[i], evaluate_validate); + bind_(id->expandedExtent(), sizes[i], evaluate_validate); } } else { - bind_(logical_domain[i]->extent(), sizes[i], evaluate_validate); + bind_(id->extent(), sizes[i], evaluate_validate); } } } diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 873cbd3e8ca..4b1f605313d 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -424,20 +424,28 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) { TensorView* in = TensorViewBuilder() .dtype(DataType::Float) .contiguity({std::nullopt, true}) - .shape({3, -1}) + .shape({num_devices * 3, -1}) .expanded({true, false}) .build(); in->setDeviceMesh(mesh); - if (parallelizes_broadcast) { - in->axis(0)->parallelize(ParallelType::DIDx); - } TensorView* out = set(in); fusion->addInput(in); fusion->addOutput(out); + if (parallelizes_broadcast) { + for (auto* tv : {in, out}) { + tv->split(0, num_devices, /*inner_split=*/false); + tv->axis(0)->parallelize(ParallelType::DIDx); + tv->setAllocationDomain(tv->getLoopDomain(), true); + } + } + FusionExecutorCache executor_cache(std::move(fusion)); auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor in_tensor = at::randn({8}, options).as_strided({3, 8}, {0, 1}); + at::Tensor in_tensor = + at::randn({8}, options) + .as_strided( + {parallelizes_broadcast ? 3 : num_devices * 3, 8}, {0, 1}); at::Tensor out_tensor = executor_cache.runFusionWithInputs({in_tensor})[0]; testValidate( executor_cache.fusion(), {out_tensor}, {in_tensor}, __LINE__, __FILE__); From 9ff10cf8281b78d89c640080d79777b0d9f2513d Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 22 Nov 2024 09:59:00 -0800 Subject: [PATCH 06/20] Unify unshardedSizes. 
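
unshardedSizes maps a rank-local (sharded) shape back to the global shape that
binds to the TensorView's logical domain: every logical axis that is
DID-sharded in the allocation domain is multiplied by the mesh size for that
parallel type. A minimal usage sketch (the shapes and the mesh size D are
illustrative, not taken from a real run):

    // tv: logical = [iM, iN], allocation = [iDIDx{D}, iN/D, iM]
    // tensor: the rank-local at::Tensor, e.g. of shape [2, 3]
    std::vector<int64_t> global_sizes = unshardedSizes(tv, tensor.sizes());
    // global_sizes == {2, 3 * D}: iM is fully allocated, iN is sharded D ways.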
--- csrc/expr_evaluator.cpp | 54 +++--------------------------- csrc/multidevice/utils.cpp | 45 +++++++++++++++++++++++++ csrc/multidevice/utils.h | 29 ++++++++++++++++ csrc/tensor_metadata.cpp | 68 +++++--------------------------------- 4 files changed, 88 insertions(+), 108 deletions(-) diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp index 7e04c3bcd4d..a2ebccfb7b3 100644 --- a/csrc/expr_evaluator.cpp +++ b/csrc/expr_evaluator.cpp @@ -128,44 +128,6 @@ void validateValWithConcreteValue( } } -std::vector unshardedSizes( - const TensorView* tv, - c10::IntArrayRef sizes) { - std::vector unsharded_sizes = sizes.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); - if (inputs.empty()) { - // FIXME: is this even possible? Logical ought to dominate allocation. - continue; - } - NVF_ERROR(inputs.size() == 1); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); - if (iter == tv->getLogicalDomain().end()) { - // FIXME: is this even possible? Logical ought to dominate allocation. - continue; - } - const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); - } - - return unsharded_sizes; -} - } // namespace void ExpressionEvaluator::bindTensorDomain( @@ -183,13 +145,7 @@ void ExpressionEvaluator::bindTensorDomain( ", but got a tensor of rank ", t.dim()); - std::vector sizes; - if (isSharded(tv)) { - sizes = unshardedSizes(tv, t.sizes()); - } else { - sizes = t.sizes().vec(); - } - + std::vector logical_sizes = unshardedSizes(tv, t.sizes()); for (auto i : c10::irange(t.dim())) { auto id = logical_domain[i]; if (id->isBroadcast()) { @@ -197,7 +153,7 @@ void ExpressionEvaluator::bindTensorDomain( if (id->hasExpandedExtent()) { // Verify that t is also expanded NVF_ERROR( - sizes[i] == 1 || t.stride(i) == 0, + logical_sizes[i] == 1 || t.stride(i) == 0, "IterDomain ", id->toString(), " in ", @@ -205,15 +161,15 @@ void ExpressionEvaluator::bindTensorDomain( "TensorView ", tv->toString(), " has expanded extent but input tensor has size ", - sizes[i], + logical_sizes[i], " and stride ", t.stride(i), " in dimension ", i); - bind_(id->expandedExtent(), sizes[i], evaluate_validate); + bind_(id->expandedExtent(), logical_sizes[i], evaluate_validate); } } else { - bind_(id->extent(), sizes[i], evaluate_validate); + bind_(id->extent(), logical_sizes[i], evaluate_validate); } } } diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 65cb76b0da7..f17e9a42c0a 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -122,6 +122,51 @@ bool isSharded(const TensorView* tv) { return is_sharded; } +std::vector unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector unsharded_sizes = sizes.vec(); + + for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { + const ParallelType parallel_type = alloc_id->getParallelType(); + if (!isParallelTypeDeviceDim(parallel_type)) { + continue; + } + + const auto inputs = IterVisitor::getInputsTo( + {alloc_id}, + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); + NVF_ERROR( + !inputs.empty(), + 
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); + NVF_ERROR( + inputs.size() == 1, + "Failed to find the single logical input to ", + alloc_id, + ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", + toDelimitedString(inputs)); + + const auto iter = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + inputs[0]); + NVF_ERROR( + iter != tv->getLogicalDomain().end(), + "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", + inputs[0]); + + // Count the number of non-reduction IterDomains before `iter`. Reduction + // IterDomains are not materialized in the at::Tensor's shape. + const auto index = std::count_if( + tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { + return !id->isReduction(); + }); + unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + } + + return unsharded_sizes; +} + int64_t numDeviceDims(const TensorView* tv) { return std::count_if( tv->getLoopDomain().begin(), diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 12013e918b4..b4f25f0df6c 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -7,6 +7,8 @@ // clang-format on #pragma once +#include + #include #include #include @@ -127,4 +129,31 @@ int64_t getShardedAxis(TensorView*); // Reorders a TensorView so that the DID parallelized axis are in front. void reorderDIDToFront(TensorView*); + +// Given a TensorView and the shape of a sharded tensor of which certain +// dimensions are partially alloated, returns the global shape that'll be used +// to bind to the TensorView's logical domain. This is to solve #3282 so we can +// bind a sharded tensor to a TensorView that has a DID-parallel loop domain. +// +// For example, when `tv` is +// logical: iM, iN +// allocation: iDIDx{D}, iN/D, iM +// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because, +// according to the allocation domain, iM is fully allocated and iN is sharded +// and thus partially allocated. +// +// As a degenerate case, it's fine to call this function with a non-sharded +// TensorView and tensor. +// +// Limitations: +// - The function assumes that there are no Merges from logical to the +// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical +// dimension this DID-parallelization should be attributed to. +// - The function assumes that all Splits from logical to the DID-parallel +// IterDomains in allocation are even. This is because there are currently no +// ways to pass in the global shape without an API overhaul. 
+std::vector unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes); + } // namespace nvfuser diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp index e35d235b709..5bb67cf1b9d 100644 --- a/csrc/tensor_metadata.cpp +++ b/csrc/tensor_metadata.cpp @@ -272,46 +272,6 @@ void validateAllocationSizesAndStrides( } } -// FIXME: strides are never changed -std::pair, std::vector> unshardedSizesAndStrides( - TensorView* tv, - c10::IntArrayRef sizes, - c10::IntArrayRef strides) { - std::vector unsharded_sizes = sizes.vec(); - std::vector unsharded_strides = strides.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); - if (inputs.empty()) { - // FIXME: is this even possible? Logical ought to dominate loop. - continue; - } - NVF_ERROR(inputs.size() == 1); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); - if (iter == tv->getLogicalDomain().end()) { - // FIXME: is this even possible? Logical ought to dominate loop. - continue; - } - const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); - } - - return {unsharded_sizes, unsharded_strides}; -} } // namespace std::pair, std::vector> @@ -322,21 +282,12 @@ inferAndValidateAllocationSizesAndStrides( const auto& logical = tv->getLogicalDomain(); const auto& alloc = tv->getMaybeAllocationDomain(); - std::vector logical_sizes; - std::vector logical_strides; - if (isSharded(tv)) { - std::tie(logical_sizes, logical_strides) = - unshardedSizesAndStrides(tv, tensor.sizes(), tensor.strides()); - } else { - logical_sizes = tensor.sizes().vec(); - logical_strides = tensor.strides().vec(); - } - // active IDs and their shape and stride + std::vector logical_sizes = unshardedSizes(tv, tensor.sizes()); std::unordered_map> active_ids; int64_t dim_index = 0; for (IterDomain* id : TensorDomain::noReductions(logical)) { - active_ids[id] = {logical_sizes[dim_index], logical_strides[dim_index]}; + active_ids[id] = {logical_sizes[dim_index], tensor.stride(dim_index)}; dim_index++; } NVF_ERROR(dim_index == tensor.dim()); @@ -348,6 +299,8 @@ inferAndValidateAllocationSizesAndStrides( // need to put them to the correct order. 
std::vector allocation_sizes; std::vector allocation_strides; + allocation_sizes.reserve(alloc.size()); + allocation_strides.reserve(alloc.size()); for (IterDomain* id : TensorDomain::noReductions(alloc)) { if (id->isDeviceDim()) { allocation_sizes.push_back(1); @@ -388,22 +341,19 @@ std::vector GetMetaData::evaluate( metadata->data = input.data_ptr(); if (isSharded(tv)) { - auto [unsharded_sizes, unsharded_strides] = - unshardedSizesAndStrides(tv, input.sizes(), input.strides()); + std::vector unsharded_sizes = unshardedSizes(tv, input.sizes()); metadata->logical_size_data = std::move(unsharded_sizes); metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data); - metadata->logical_stride_data = std::move(unsharded_strides); - metadata->logical_stride = c10::makeArrayRef(metadata->logical_stride_data); } else { metadata->logical_size = input.sizes(); - metadata->logical_stride = input.strides(); } + metadata->logical_stride = input.strides(); - auto [sizes, strides] = + auto [allocation_sizes, allocation_strides] = inferAndValidateAllocationSizesAndStrides(input, tv, ee); - metadata->alloc_size_data = std::move(sizes); + metadata->alloc_size_data = std::move(allocation_sizes); metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data); - metadata->alloc_stride_data = std::move(strides); + metadata->alloc_stride_data = std::move(allocation_strides); metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data); return {PolymorphicValue(std::move(struct_))}; } From 8bd9486caa36265c81ca715e454ed7164b38abc1 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 25 Nov 2024 16:52:35 -0800 Subject: [PATCH 07/20] Fix a test --- csrc/ir/nodes.cpp | 23 ++++++++---------- tests/cpp/test_sharding.cpp | 47 ++++++++++++++++++++++++------------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index c93c4980e85..6906861814e 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3238,24 +3238,21 @@ bool TensorDomain::sameAs( std::string TensorDomain::toString(const int indent_size, const bool loop_only) const { std::stringstream ss; - if (nDims() == 0) { - indent(ss, indent_size) << "[ ]"; - return ss.str(); - } - indent(ss, indent_size) << "[ " << toDelimitedString(loop()) << " ]"; - if (!loop_only) { + if (loop_only) { + indent(ss, indent_size) << "[" << toDelimitedString(loop()) << "]"; + } else { + indent(ss, indent_size) + << "logical=[" << toDelimitedString(logical()) << "]" << std::endl; if (hasRoot()) { - ss << "," << std::endl; indent(ss, indent_size + 1) - << "root=[ " << toDelimitedString(root()) << " ]"; + << "root=[" << toDelimitedString(root()) << "]" << std::endl; } - ss << "," << std::endl; indent(ss, indent_size + 1) - << "logical=[ " << toDelimitedString(logical()) << " ]"; - if (!allocation_domain_.empty()) { - ss << "," << std::endl; + << "loop=[" << toDelimitedString(loop()) << "]" << std::endl; + if (hasAllocation()) { indent(ss, indent_size + 1) - << "allocation=[ " << toDelimitedString(allocation()) << " ]"; + << "allocation=[" << toDelimitedString(allocation()) << "]" + << std::endl; } } return ss.str(); diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index a0c643e95b4..9c9ca068689 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -23,26 +23,41 @@ namespace nvfuser { using ShardingTest = NVFuserFixtureParamTest; -// TODO: This test checks that isSharded generates an error when a split/merged -// axis is parallelized with DIDx. 
Update when this restriction is lifted. -TEST_F(ShardingTest, IsSharded) { +TEST_F(ShardingTest, LogicalIsSharded) { Fusion fusion; FusionGuard fg(&fusion); - TensorView* a = makeSymbolicTensor(3); - a->axis(2)->parallelize(ParallelType::DIDx); - a->split(0, 4); - EXPECT_TRUE(isSharded(a)) << "DIDx on logical domain"; + TensorView* x = makeSymbolicTensor(3); + x->axis(2)->parallelize(ParallelType::DIDx); + x->split(0, 4); - TensorView* b = makeSymbolicTensor(3); - b->split(1, 4); - b->axis(1)->parallelize(ParallelType::DIDx); - EXPECT_TRUE(isSharded(b)) << "DIDx on loop domain"; - - TensorView* c = makeSymbolicTensor(3); - c->axis(0)->parallelize(ParallelType::DIDx); - c->axis(1)->parallelize(ParallelType::DIDx); - EXPECT_ANY_THROW(isSharded(c)) << "Multiple DIDx"; + EXPECT_TRUE(isSharded(x)) << "DIDx on logical domain:" << std::endl + << x->domain()->toString(0, /*loop_only=*/false); +} + +TEST_F(ShardingTest, AllocationIsSharded) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* x = makeSymbolicTensor(3); + x->split(1, 4); + x->axis(1)->parallelize(ParallelType::DIDx); + x->setAllocationDomain(x->getLoopDomain(), true); + + EXPECT_TRUE(isSharded(x)) << "DIDx on allocation domain:" << std::endl + << x->domain()->toString(0, /*loop_only=*/false); +} + +TEST_F(ShardingTest, MultipleDIDx) { + Fusion fusion; + FusionGuard fg(&fusion); + + TensorView* x = makeSymbolicTensor(3); + x->axis(0)->parallelize(ParallelType::DIDx); + x->axis(1)->parallelize(ParallelType::DIDx); + EXPECT_ANY_THROW(isSharded(x)) + << "Multiple DIDx:" << std::endl + << x->domain()->toString(0, /*loop_only=*/false); } TEST_F(ShardingTest, PropagateSharding) { From 5a16349adcde62fb267259949565f85a8d4660b6 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Mon, 25 Nov 2024 20:23:31 -0800 Subject: [PATCH 08/20] Refine the logic in the transpose scheduler --- csrc/scheduler/transpose.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/scheduler/transpose.cpp b/csrc/scheduler/transpose.cpp index 553ba9d773e..7e320f99a91 100644 --- a/csrc/scheduler/transpose.cpp +++ b/csrc/scheduler/transpose.cpp @@ -233,7 +233,12 @@ class DomainMap : public pointwise_utils::DomainMap { for (auto* expr : replay_exprs) { if (auto* split = dynamic_cast(expr)) { if (split->in() == mapped_id) { - mapped_id = split->inner(); + if (split->inner()->extent()->isOneInt() && + !split->outer()->extent()->isOneInt()) { + mapped_id = split->outer(); + } else { + mapped_id = split->inner(); + } } } else if (auto* merge = dynamic_cast(expr)) { // Merge with size-1 dimension is not supposed to be here, reshape would From 27689708c7e15f91c9ac1383f434284464bf1990 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 27 Nov 2024 11:17:17 -0800 Subject: [PATCH 09/20] Comment --- csrc/multidevice/utils.h | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index b4f25f0df6c..50f0223a034 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -151,7 +151,20 @@ void reorderDIDToFront(TensorView*); // dimension this DID-parallelization should be attributed to. // - The function assumes that all Splits from logical to the DID-parallel // IterDomains in allocation are even. This is because there are currently no -// ways to pass in the global shape without an API overhaul. +// ways to pass in the global shape. +// +// Despite these limitations, I took this approach as a shortcut to fix #3282, +// which blocked many other tasks. 
I'm however open to other better, long-term
+// solutions. Some alternatives considered in #3282 are:
+// - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
+// `*Op::evaluate` methods (e.g.
+// https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
+// assume the input/output `at::Tensor`s have the same dimension order as the
+// logical domain. Doing so would have to change them all.
+// - Try to pass into FusionExecutorCache both logical (global) shapes and
+// allocated (local) tensors for sharded TensorViews. The logical shapes would
+// have to be passed through FusionKernelRuntime, FusionExecutor,
+// ExpressionEvaluator, and so on, which is an API overhaul.
 std::vector<int64_t> unshardedSizes(
     const TensorView* tv,
     c10::IntArrayRef sizes);

From 400684ea2dc7294f18b4e3609764c16e28f3945a Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 27 Nov 2024 11:32:15 -0800
Subject: [PATCH 10/20] Resolve two fixmes

---
 csrc/fusion_segmenter.cpp | 4 +++-
 csrc/transform_replay.cpp | 1 -
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index 4d7cb4e693b..c98543a179a 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -1884,7 +1884,9 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
       new_alloc.reserve(tv->getAllocationDomain().size());
       for (IterDomain* alloc_id : tv->getAllocationDomain()) {
         IterDomain* new_alloc_id = replay.getReplay().at(alloc_id);
-        // FIXME: should this be taken care of by ReplayTransformations?
+        // ReplayTransformations replays transforms but not parallelization,
+        // so we have to manually parallelize the new allocation ID. In other
+        // places, parallelization is usually done through parallelizeAllLike.
         new_alloc_id->parallelize(alloc_id->getParallelType());
         new_alloc.push_back(new_alloc_id);
       }
diff --git a/csrc/transform_replay.cpp b/csrc/transform_replay.cpp
index 0ac2a9e97a4..6c58d83528a 100644
--- a/csrc/transform_replay.cpp
+++ b/csrc/transform_replay.cpp
@@ -775,7 +775,6 @@ std::pair TransformReplay::replayCasP(
     // AllocationDomainTest.CacheBefore.
     if (auto it = p2c_map.find(alloc_id); it != p2c_map.end()) {
       IterDomain* new_alloc_id = it->second;
-      // FIXME: should this be taken care of by ReplayTransformations?
      new_alloc_id->parallelize(alloc_id->getParallelType());
       new_allocation_domain.push_back(new_alloc_id);
       new_contiguity.push_back(producer->getContiguity()[i]);
     }
   }

From 3086237234f21a74fc95014085ed94ae76854e35 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Fri, 25 Oct 2024 16:35:11 -0700
Subject: [PATCH 11/20] Add a repro.
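
The repro allgathers a 1-D tensor whose input is sharded via a DID loop split
rather than on the logical domain. A sketch of the parallelization used (D
stands for the number of devices; the full test is in the diff below):

    in->split(0, D, /*inner_split=*/false);   // logical [N] -> loop [D, N/D]
    in->axis(0)->parallelize(ParallelType::DIDx);
    in->setAllocationDomain(in->getLoopDomain(), true);
    out->split(0, D, /*inner_split=*/false);  // gathered, so not DIDx-parallel
    out->setAllocationDomain(out->getLoopDomain(), true);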
--- csrc/multidevice/utils.cpp | 2 +- .../test_multidevice_lower_communication.cpp | 36 +++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 24b7e582104..1e65bd5ceec 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -503,7 +503,7 @@ std::set involvedDevices(Expr* expr) { } int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getLogicalDomain()); + auto ids = TensorDomain::noReductions(tv->getLoopDomain()); for (size_t i = 0; i < ids.size(); ++i) { if (ids[i]->getParallelType() == ParallelType::DIDx) { return static_cast(i); diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 643b5b2220d..73d5a124357 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -202,6 +202,42 @@ TEST_F(LowerCollectiveTest, Allgather) { EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor)); } +TEST_F(LowerCollectiveTest, Allgather_SplitLoop) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const auto num_devices = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(num_devices); + + TensorView* in = makeContigTensor(1); + in->setDeviceMesh(mesh); + TensorView* out = set(in); + fusion->addInput(in); + fusion->addOutput(out); + + in->split(0, num_devices, /*inner_split=*/false); + in->axis(0)->parallelize(ParallelType::DIDx); + in->setAllocationDomain(in->getLoopDomain(), true); + + out->split(0, num_devices, /*inner_split=*/false); + out->setAllocationDomain(out->getLoopDomain(), true); + + at::Tensor unsharded_tensor = + at::randn({num_devices * kTensorSize}, at::kFloat); + at::Tensor in_tensor = unsharded_tensor + .slice( + 0, + communicator_->deviceId() * kTensorSize, + (communicator_->deviceId() + 1) * kTensorSize) + .to(communicator_->device()); + + FusionExecutorCache fec(std::move(fusion)); + at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; + assertIsCompiledToHostIrContainer(fec); + + EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor)); +} + TEST_F(LowerCollectiveTest, Broadcast) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); From fe0cec638d5e4ad9d85be620ff63f4f62ab20ff6 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 13 Nov 2024 22:42:25 -0800 Subject: [PATCH 12/20] Generalize postAllgather to support DID loop split. 
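
Before this change, postAllgather cut the output buffer into size-1 slices
along dim 0 and asserted one slice per team member, so the gathered axis had
to have extent 1 per rank. With a DID loop split, a rank owns a contiguous
slice of extent > 1, so the output is now split into one chunk per rank
instead. Before/after sketch (`old_splits`/`new_splits` are illustrative
labels, not names used in the code):

    // Before: one slice per element along dim 0; requires extent == team size.
    auto old_splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0);
    // After: one slice per rank; a slice may have extent > 1.
    auto new_splits = at::tensor_split(
        output_tensor, communication->team().size(), /*dim=*/0);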
--- csrc/multidevice/communication.cpp | 4 ++-- csrc/multidevice/utils.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index 29ef6995969..fbef4831e96 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -325,8 +325,8 @@ c10::intrusive_ptr postAllgather( c10d::Backend* backend, at::Tensor input_tensor, at::Tensor output_tensor) { - auto splits = at::split(output_tensor, /*split_size=*/1, /*dim=*/0); - assertBufferCount(splits, communication->team().size()); + auto splits = + at::tensor_split(output_tensor, communication->team().size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 1e65bd5ceec..334f8271aa2 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -503,8 +503,8 @@ std::set involvedDevices(Expr* expr) { } int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getLoopDomain()); - for (size_t i = 0; i < ids.size(); ++i) { + auto ids = TensorDomain::noReductions(tv->getMaybeAllocationDomain()); + for (const auto i : c10::irange(ids.size())) { if (ids[i]->getParallelType() == ParallelType::DIDx) { return static_cast(i); } From 5d60dd551ea0d6fc82b59ad2ea18087ff5bad806 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 27 Nov 2024 15:16:49 -0800 Subject: [PATCH 13/20] Try to reuse getShardedAxis --- csrc/multidevice/lower_communication.cpp | 2 +- csrc/multidevice/utils.cpp | 134 +++++++++--------- csrc/multidevice/utils.h | 6 +- tests/cpp/multidevice.cpp | 11 +- .../test_multidevice_lower_communication.cpp | 10 +- tests/cpp/test_multidevice_transformer.cpp | 124 ++++++++-------- 6 files changed, 140 insertions(+), 147 deletions(-) diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index c8068b5a113..a8fca521c4a 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -196,7 +196,7 @@ void lowerToReduceScatter( std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedAxis(output_tv); + auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped // back onto the input. The input has an reduced axis, so the scattered axis // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 334f8271aa2..0209a6b5d5f 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,48 +121,77 @@ bool isSharded(const TensorView* tv) { return is_sharded; } -std::vector unshardedSizes( - const TensorView* tv, - c10::IntArrayRef sizes) { - std::vector unsharded_sizes = sizes.vec(); - - for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) { - const ParallelType parallel_type = alloc_id->getParallelType(); +namespace { +// Collect device-parallel IterDomains in `domain` and return them as a +// ParallelType-to-IterDomain map. 
+std::unordered_map mapDeviceParallelTypeToId( + const std::vector& domain) { + std::unordered_map parallel_type_to_id; + parallel_type_to_id.reserve(kParallelTypeDIDs.size()); + for (IterDomain* id : domain) { + const ParallelType parallel_type = id->getParallelType(); if (!isParallelTypeDeviceDim(parallel_type)) { continue; } - const auto inputs = IterVisitor::getInputsTo( - {alloc_id}, - {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); NVF_ERROR( - !inputs.empty(), - "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); - NVF_ERROR( - inputs.size() == 1, - "Failed to find the single logical input to ", - alloc_id, - ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", - toDelimitedString(inputs)); - - const auto iter = std::find( - tv->getLogicalDomain().begin(), - tv->getLogicalDomain().end(), - inputs[0]); - NVF_ERROR( - iter != tv->getLogicalDomain().end(), - "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", - inputs[0]); - - // Count the number of non-reduction IterDomains before `iter`. Reduction - // IterDomains are not materialized in the at::Tensor's shape. - const auto index = std::count_if( - tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { - return !id->isReduction(); - }); - unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type); + parallel_type_to_id.try_emplace(parallel_type, id).second, + "Found multiple loop IterDomains with the same parallel type (", + parallel_type, + "): ", + toDelimitedString(domain)); + } + return parallel_type_to_id; +} +} // namespace + +int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) { + std::unordered_map parallel_type_to_id = + mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); + IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); + if (alloc_id == nullptr) { + return -1; } + const auto inputs = IterVisitor::getInputsTo( + {alloc_id}, + {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()}); + NVF_ERROR( + !inputs.empty(), + "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty."); + NVF_ERROR( + inputs.size() == 1, + "Failed to find the single logical input to ", + alloc_id, + ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ", + toDelimitedString(inputs)); + + const auto iter = std::find( + tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]); + NVF_ERROR( + iter != tv->getLogicalDomain().end(), + "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ", + inputs[0]); + + // Count the number of non-reduction IterDomains before `iter`. Reduction + // IterDomains are not materialized in the at::Tensor's shape. 
+ return std::count_if( + tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool { + return !id->isReduction(); + }); +} + +std::vector unshardedSizes( + const TensorView* tv, + c10::IntArrayRef sizes) { + std::vector unsharded_sizes = sizes.vec(); + for (ParallelType parallel_type : kParallelTypeDIDs) { + const int64_t sharded_axis = getShardedAxis(tv, parallel_type); + if (sharded_axis == -1) { + continue; + } + unsharded_sizes.at(sharded_axis) *= tv->getDeviceMesh().size(parallel_type); + } return unsharded_sizes; } @@ -174,27 +203,6 @@ int64_t numDeviceDims(const TensorView* tv) { } namespace { -// Collect device-parallel IterDomains in `loop_domain` and return them as a -// ParallelType-to-IterDomain map. -std::unordered_map mapParallelTypeToId( - const std::vector& loop_domain) { - std::unordered_map parallel_type_to_id; - parallel_type_to_id.reserve(kParallelTypeDIDs.size()); - for (IterDomain* loop_id : loop_domain) { - const ParallelType parallel_type = loop_id->getParallelType(); - if (!isParallelTypeDeviceDim(parallel_type)) { - continue; - } - - NVF_ERROR( - parallel_type_to_id.try_emplace(parallel_type, loop_id).second, - "Found multiple loop IterDomains with the same parallel type (", - parallel_type, - "): ", - toDelimitedString(loop_domain)); - } - return parallel_type_to_id; -} std::vector getInputsInTargetDomain( IterDomain* loop_id, @@ -294,9 +302,9 @@ bool haveDifferentShardings( // 3. Check if the two loop IterDomains are almost-exactly mapped in the // IdModel. std::unordered_map p_parallel_type_to_id = - mapParallelTypeToId(producer->getLoopDomain()); + mapDeviceParallelTypeToId(producer->getLoopDomain()); std::unordered_map c_parallel_type_to_id = - mapParallelTypeToId(consumer->getLoopDomain()); + mapDeviceParallelTypeToId(consumer->getLoopDomain()); for (const auto parallel_type : kParallelTypeDIDs) { IterDomain* p_loop_id = getOrDefault(p_parallel_type_to_id, parallel_type); @@ -502,16 +510,6 @@ std::set involvedDevices(Expr* expr) { return ret; } -int64_t getShardedAxis(TensorView* tv) { - auto ids = TensorDomain::noReductions(tv->getMaybeAllocationDomain()); - for (const auto i : c10::irange(ids.size())) { - if (ids[i]->getParallelType() == ParallelType::DIDx) { - return static_cast(i); - } - } - return -1; -} - void reorderDIDToFront(TensorView* tv) { // new position to old position std::unordered_map order_map; diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 50f0223a034..439ecfd28c2 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -123,9 +123,9 @@ int64_t requestedNumberOfDevices(Fusion*); void unshard(Fusion*); void unshard(TensorView*); -// Returns the index of the a sharded axis if none return -1. -// TODO: Assumes no merges/splits on sharded axis. -int64_t getShardedAxis(TensorView*); +// Returns the index of the sharded logical axis corresponding to +// `parallel_type`. If `tv` isn't sharded on the parallel type, returns -1. +int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. 
void reorderDIDToFront(TensorView*); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index bab5cdccc5e..92c13c1c460 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -128,7 +128,8 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) { return tensor; } NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv); - return shardTensor(tensor, getShardedAxis(tv), tv->getDeviceMesh()); + return shardTensor( + tensor, getShardedAxis(tv, ParallelType::DIDx), tv->getDeviceMesh()); } at::Tensor MultiDeviceTest::shardTensor( @@ -144,13 +145,7 @@ at::Tensor MultiDeviceTest::shardTensor( auto stride = extent / nslices; // TODO: returning slice 0 temporarily when device is not in the mesh. i = (i < 0) ? 0 : i; - auto slice = tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); - // Temporary until https://github.com/NVIDIA/Fuser/issues/2563. Adds DIDx - // axis in front representing the sharded extent of the tensor. - if (stride > 1) { - slice = slice.unsqueeze(0); - } - return slice; + return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp index 73d5a124357..e316ec0f219 100644 --- a/tests/cpp/test_multidevice_lower_communication.cpp +++ b/tests/cpp/test_multidevice_lower_communication.cpp @@ -202,7 +202,7 @@ TEST_F(LowerCollectiveTest, Allgather) { EXPECT_TRUE(at::equal(out_tensor, unsharded_tensor)); } -TEST_F(LowerCollectiveTest, Allgather_SplitLoop) { +TEST_F(LowerCollectiveTest, Allgather_LoopSplit) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -224,12 +224,8 @@ TEST_F(LowerCollectiveTest, Allgather_SplitLoop) { at::Tensor unsharded_tensor = at::randn({num_devices * kTensorSize}, at::kFloat); - at::Tensor in_tensor = unsharded_tensor - .slice( - 0, - communicator_->deviceId() * kTensorSize, - (communicator_->deviceId() + 1) * kTensorSize) - .to(communicator_->device()); + at::Tensor in_tensor = + shardTensor(unsharded_tensor, in).to(communicator_->device()); FusionExecutorCache fec(std::move(fusion)); at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0]; diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 2ef33dcdf8f..0f39ae6f6e5 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -720,14 +720,14 @@ TEST_P(DistributedTransformerTest, MLP_Layer) { std::vector inputs = { x, - shardTensor(w0, 0, mesh), - shardTensor(b0, 0, mesh), - shardTensor(w1, 1, mesh), + shardTensor(w0, 0, mesh).unsqueeze(0), + shardTensor(b0, 0, mesh).unsqueeze(0), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { - shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -801,17 +801,17 @@ TEST_P(DistributedTransformerTest, Sequence_Parallel_MLP_Layer) { auto mask_ = reference_outs[4]; std::vector inputs = { - shardTensor(x_, 0, mesh), - shardTensor(w0_, 0, mesh), - shardTensor(b0_, 0, mesh), - shardTensor(w1_, 1, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), + shardTensor(w0_, 0, mesh).unsqueeze(0), + shardTensor(b0_, 0, mesh).unsqueeze(0), + shardTensor(w1_, 1, mesh).unsqueeze(0), b1_}; std::vector expected_outputs = { - 
shardTensor(reference_outs[0], 1, mesh), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache executor_cache(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -866,12 +866,12 @@ TEST_P(DistributedTransformerTest, MultiheadAttention) { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), reference_outs[2], reference_outs[3]}; @@ -929,17 +929,17 @@ TEST_P(DistributedTransformerTest, MultiheadAttention_SP) { at::manual_seed(getATenRandomSeed()); auto reference_outs = reference_mha(x, w0, b0, w1, b1); std::vector inputs = { - shardTensor(x, 0, mesh), + shardTensor(x, 0, mesh).unsqueeze(0), shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(b0.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), b1}; std::vector expected_outputs = { shardTensor(reference_outs[0].view({B * S, 3, E}), 2, mesh) .view({1, B * S, 3 * E / D}), - shardTensor(reference_outs[1], 1, mesh), - shardTensor(reference_outs[2], 0, mesh), - shardTensor(reference_outs[3], 0, mesh)}; + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), + shardTensor(reference_outs[2], 0, mesh).unsqueeze(0), + shardTensor(reference_outs[3], 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1003,16 +1003,16 @@ TEST_P(DistributedTransformerTest, MLP_Backward) { grad_, x_, mask_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), - shardTensor(linear0_, 1, mesh)}; + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), + shardTensor(linear0_, 1, mesh).unsqueeze(0)}; std::vector expected_outputs = { outs[0], // dropout grad - shardTensor(outs[1], 1, mesh), // linear1 weight grad + shardTensor(outs[1], 1, mesh).unsqueeze(0), // linear1 weight grad outs[2], // linear1 bias grad - shardTensor(outs[3], 1, mesh), // gelu grad - shardTensor(outs[4], 0, mesh), // linear0 weight grad - shardTensor(outs[5], 0, mesh), // linear0 bias grad + shardTensor(outs[3], 1, mesh).unsqueeze(0), // gelu grad + shardTensor(outs[4], 0, mesh).unsqueeze(0), // linear0 weight grad + shardTensor(outs[5], 0, mesh).unsqueeze(0), // linear0 bias grad outs[6]}; // linear0 grad x FusionExecutorCache executor_cache(std::move(fusion)); @@ -1094,22 +1094,23 @@ TEST_P(DistributedTransformerTest, MHA_Backward) { std::vector inputs = { x, shardTensor(w0.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(w1, 1, mesh), + shardTensor(w1, 1, mesh).unsqueeze(0), grad, mask, - shardTensor(reference_outs[0], 1, mesh), // sdpa.output - shardTensor(reference_outs[1], 1, mesh), // sdpa.log_sumexp + shardTensor(reference_outs[0], 1, mesh).unsqueeze(0), // sdpa.output + shardTensor(reference_outs[1], 1, mesh).unsqueeze(0), // sdpa.log_sumexp reference_outs[2], // sdpa.seed reference_outs[3], // 
sdpa.offset - shardTensor(reference_outs[13], 1, mesh) // linear0 + shardTensor(reference_outs[13], 1, mesh).unsqueeze(0) // linear0 }; std::vector expected_outputs = { reference_outs[4], // dropout grad - shardTensor(reference_outs[5], 1, mesh), // linear1 weight grad + shardTensor(reference_outs[5], 1, mesh) + .unsqueeze(0), // linear1 weight grad reference_outs[6], // linear1 bias grad - shardTensor(reference_outs[7], 1, mesh), // q grad - shardTensor(reference_outs[8], 1, mesh), // k grad - shardTensor(reference_outs[9], 1, mesh), // v grad + shardTensor(reference_outs[7], 1, mesh).unsqueeze(0), // q grad + shardTensor(reference_outs[8], 1, mesh).unsqueeze(0), // k grad + shardTensor(reference_outs[9], 1, mesh).unsqueeze(0), // v grad shardTensor(reference_outs[10].view({3, E, E}), 1, mesh) .view({1, 3 * E / D, E}), // linear0 weight grad shardTensor(reference_outs[11].view({3, E}), 1, mesh) @@ -1234,26 +1235,26 @@ TEST_P(DistributedTransformerTest, Forward_SP) { auto at_out = (resid0_ + mlp_out_).to(at_dtype); std::vector inputs = { - shardTensor(x_, 0, mesh), + shardTensor(x_, 0, mesh).unsqueeze(0), ln0_w_, ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector expected_outputs = { - shardTensor(ln0_out_, 0, mesh), - shardTensor(mha_out_, 0, mesh), - shardTensor(ln1_out_, 0, mesh), - shardTensor(mlp_out_, 0, mesh), - shardTensor(at_out, 0, mesh)}; + shardTensor(ln0_out_, 0, mesh).unsqueeze(0), + shardTensor(mha_out_, 0, mesh).unsqueeze(0), + shardTensor(ln1_out_, 0, mesh).unsqueeze(0), + shardTensor(mlp_out_, 0, mesh).unsqueeze(0), + shardTensor(at_out, 0, mesh).unsqueeze(0)}; FusionExecutorCache fec(std::move(fusion)); at::manual_seed(getATenRandomSeed()); @@ -1367,13 +1368,13 @@ TEST_P(DistributedTransformerTest, Forward) { ln0_b_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), shardTensor(mha_b0_.view({3, E}), 1, mesh).view({1, 3 * E / D}), - shardTensor(mha_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), mha_b1_, ln1_w_, ln1_b_, - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_b0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_b0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_b1_}; std::vector expected_outputs = { @@ -1620,13 +1621,16 @@ TEST_P(DistributedTransformerTest, Backward) { auto dx_ = (ln0_x_grad_ + resid1_grad_).to(at_dtype); auto expected_outputs = { - shardTensor(mlp_grads_[1], 1, mesh), // mlp_linear1_weight_grad + shardTensor(mlp_grads_[1], 1, mesh) + .unsqueeze(0), // mlp_linear1_weight_grad mlp_grads_[2], // mlp_linear1_bias_grad - shardTensor(mlp_grads_[4], 0, mesh), // mlp_linear0_weight_grad - shardTensor(mlp_grads_[5], 0, mesh), // mlp_linear0_bias_grad + shardTensor(mlp_grads_[4], 0, mesh) + .unsqueeze(0), // mlp_linear0_weight_grad + shardTensor(mlp_grads_[5], 0, mesh).unsqueeze(0), // mlp_linear0_bias_grad ln1_w_grad_, ln1_b_grad_, - shardTensor(mha_grads_[5], 1, mesh), // mha linear1 weight grad + shardTensor(mha_grads_[5], 1, mesh) + .unsqueeze(0), // mha linear1 weight grad 
mha_grads_[6], // mha linear1 bias grad shardTensor( mha_grads_[10].view({3, E, E}), 1, mesh) // failing starting here @@ -1641,13 +1645,13 @@ TEST_P(DistributedTransformerTest, Backward) { x_, grad_, shardTensor(mha_w0_.view({3, E, E}), 1, mesh).view({1, 3 * E / D, E}), - shardTensor(mha_w1_, 1, mesh), - shardTensor(mlp_w0_, 0, mesh), - shardTensor(mlp_w1_, 1, mesh), + shardTensor(mha_w1_, 1, mesh).unsqueeze(0), + shardTensor(mlp_w0_, 0, mesh).unsqueeze(0), + shardTensor(mlp_w1_, 1, mesh).unsqueeze(0), mlp_out_[4], // mlp dropout mask mha_out_[4], // mha dropout mask - shardTensor(mha_grads_[0], 1, mesh), // sdpa output - shardTensor(mha_grads_[1], 1, mesh), // sdpa logsum_exp + shardTensor(mha_grads_[0], 1, mesh).unsqueeze(0), // sdpa output + shardTensor(mha_grads_[1], 1, mesh).unsqueeze(0), // sdpa logsum_exp mha_grads_[2], // sdpa seed mha_grads_[3], // sdpa offset ln1_w_, @@ -1658,9 +1662,9 @@ TEST_P(DistributedTransformerTest, Backward) { ln0_b_, ln0_mean_, ln0_rstd_, - shardTensor(mha_out_[0], 1, mesh), // mha linear0 + shardTensor(mha_out_[0], 1, mesh).unsqueeze(0), // mha linear0 mha_out_[2].to(at::kFloat), // mha linear1 - shardTensor(mlp_out_[0], 1, mesh) // mlp linear1 + shardTensor(mlp_out_[0], 1, mesh).unsqueeze(0) // mlp linear1 }; FusionExecutorCache executor_cache(std::move(fusion)); From 7ccfccd21b324b5afa38ac7220be49cd13dbf948 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Wed, 27 Nov 2024 16:06:45 -0800 Subject: [PATCH 14/20] Fix lint --- csrc/multidevice/communication.cpp | 2 +- csrc/multidevice/communication.h | 5 +++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/communication.cpp b/csrc/multidevice/communication.cpp index fbef4831e96..3ffd9e3d24d 100644 --- a/csrc/multidevice/communication.cpp +++ b/csrc/multidevice/communication.cpp @@ -326,7 +326,7 @@ c10::intrusive_ptr postAllgather( at::Tensor input_tensor, at::Tensor output_tensor) { auto splits = - at::tensor_split(output_tensor, communication->team().size(), /*dim=*/0); + at::tensor_split(output_tensor, communication->team_size(), /*dim=*/0); assertBuffersHaveSameSize({input_tensor}, splits); // allgather primitive in c10d induces extra buffering time to copy out the diff --git a/csrc/multidevice/communication.h b/csrc/multidevice/communication.h index 45c104b36d3..8631a1a04e5 100644 --- a/csrc/multidevice/communication.h +++ b/csrc/multidevice/communication.h @@ -90,6 +90,11 @@ class Communication : public Expr { return attribute(1); } + // A convenience helper so the user doesn't need to convert size_t to int64_t. 
+  int64_t team_size() const {
+    return static_cast<int64_t>(team().size());
+  }
+
   DeviceIdxType root() const {
     return attribute(2);
   }

From 7cf238402fdbd0656f26577088128418e7433f6d Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 4 Dec 2024 14:37:05 -0800
Subject: [PATCH 15/20] Disallow DID on inner splits

---
 csrc/multidevice/utils.cpp              | 76 +++++++++++++++++--------
 tests/cpp/test_multidevice_sharding.cpp | 48 ++++++++++++++++
 2 files changed, 99 insertions(+), 25 deletions(-)

diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp
index 0209a6b5d5f..ae705e3145a 100644
--- a/csrc/multidevice/utils.cpp
+++ b/csrc/multidevice/utils.cpp
@@ -143,6 +143,22 @@ std::unordered_map<ParallelType, IterDomain*> mapDeviceParallelTypeToId(
   }
   return parallel_type_to_id;
 }
+
+std::unordered_map<IterDomain*, int64_t> mapIterDomainToTensorAxis(
+    const std::vector<IterDomain*>& domain) {
+  std::unordered_map<IterDomain*, int64_t> id_to_axis;
+  int64_t axis = 0;
+  for (auto* id : domain) {
+    // Reduction IterDomains are not materialized as an at::Tensor axis.
+    if (id->isReduction()) {
+      continue;
+    }
+    id_to_axis[id] = axis;
+    axis++;
+  }
+  return id_to_axis;
+}
+
 } // namespace
 
 int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
@@ -153,32 +169,42 @@ int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) {
     return -1;
   }
 
-  const auto inputs = IterVisitor::getInputsTo(
-      {alloc_id},
-      {tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
-  NVF_ERROR(
-      !inputs.empty(),
-      "IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
-  NVF_ERROR(
-      inputs.size() == 1,
-      "Failed to find the single logical input to ",
-      alloc_id,
-      ". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
-      toDelimitedString(inputs));
-
-  const auto iter = std::find(
-      tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), inputs[0]);
-  NVF_ERROR(
-      iter != tv->getLogicalDomain().end(),
-      "The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
-      inputs[0]);
+  std::unordered_map<IterDomain*, int64_t> logical_id_to_axis =
+      mapIterDomainToTensorAxis(tv->getLogicalDomain());
+  IterDomain* id = alloc_id;
+  while (logical_id_to_axis.count(id) == 0) {
+    Expr* def = id->definition();
+    NVF_ERROR(
+        def != nullptr,
+        "Failed to find a non-reduction logical IterDomain that produces ",
+        alloc_id);
+    if (auto* split = dynamic_cast<Split*>(def)) {
+      // FIXME: comment
+      NVF_ERROR(
+          split->outer() == id,
+          "Currently, we don't support DID on inner splits: ",
+          split);
+      id = split->in();
+    } else if (auto* merge = dynamic_cast<Merge*>(def)) {
+      // For example,
+      //
+      // t = makeContigTensor(2);
+      // t->merge(0, 1);
+      // t->axis(0)->parallelize(DIDx);
+      //
+      // When `unshardedSizes` is given a local tensor of shape [1, 1], it's
+      // unclear whether the global shape is [1, D], [D, 1], or even [2, D/2], etc.
+      NVF_THROW(
+          "Failed to attribute the sharding to a single tensor axis and therefore bailed out: ",
+          merge);
+    } else {
+      NVF_THROW(
+          "Unexpected transforms from logical to a DID-parallel allocation IterDomain: ",
+          def);
+    }
+  }
 
-  // Count the number of non-reduction IterDomains before `iter`. Reduction
-  // IterDomains are not materialized in the at::Tensor's shape.
-  return std::count_if(
-      tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
-        return !id->isReduction();
-      });
+  return logical_id_to_axis.at(id);
 }
 
 std::vector<int64_t> unshardedSizes(
diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp
index 3adac90bc5e..aaa5d3a3218 100644
--- a/tests/cpp/test_multidevice_sharding.cpp
+++ b/tests/cpp/test_multidevice_sharding.cpp
@@ -491,4 +491,52 @@ TEST_P(MultiDeviceBroadcastTest, Expanded) {
 
 INSTANTIATE_TEST_SUITE_P(, MultiDeviceBroadcastTest, testing::Bool());
 
+TEST_F(MultiDeviceTest, ShardTensor_OuterSplit) {
+  const int d = communicator_->size();
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv = makeContigConcreteTensor({2, d * 3});
+  tv->setDeviceMesh(DeviceMesh::createForNumDevices(d));
+  tv->split(1, d, /*inner_split=*/false);
+  tv->axis(1)->parallelize(ParallelType::DIDx);
+  tv->setAllocationDomain(tv->getLoopDomain(), true);
+
+  fusion.addInput(tv);
+  fusion.addOutput(tv);
+
+  at::Tensor unsharded = at::arange(2 * d * 3).view({2, d * 3});
+  at::Tensor sharded = shardTensor(unsharded, tv);
+
+  EXPECT_THAT(sharded.sizes(), ElementsAre(2, 3));
+  at::Tensor expected = unsharded.view({2, d, 3}).index(
+      {torch::indexing::Slice(),
+       communicator_->deviceId(),
+       torch::indexing::Slice()});
+  EXPECT_TRUE(at::equal(sharded, expected));
+}
+
+TEST_F(MultiDeviceTest, ShardTensor_InnerSplit) {
+  const int d = communicator_->size();
+
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  TensorView* tv = makeContigConcreteTensor({d * 3});
+  tv->setDeviceMesh(DeviceMesh::createForNumDevices(d));
+  tv->split(0, d, /*inner_split=*/true);
+  tv->axis(-1)->parallelize(ParallelType::DIDx);
+  tv->setAllocationDomain(tv->getLoopDomain(), true);
+
+  fusion.addInput(tv);
+  fusion.addOutput(tv);
+
+  at::Tensor unsharded = at::arange(d * 3);
+  EXPECT_THAT(
+      [&]() { shardTensor(unsharded, tv); },
+      ::testing::ThrowsMessage<nvfuser::nvfError>(
+          ::testing::HasSubstr("DID on inner splits")));
+}
+
 } // namespace nvfuser

From c27a5858fbe1cda203bfb4a55a85c4ade211f8a6 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 4 Dec 2024 15:08:31 -0800
Subject: [PATCH 16/20] Add a test for noncontiguous allgather

---
 .../test_multidevice_lower_communication.cpp | 35 +++++++++++++++++++
 1 file changed, 35 insertions(+)

diff --git a/tests/cpp/test_multidevice_lower_communication.cpp b/tests/cpp/test_multidevice_lower_communication.cpp
index e316ec0f219..d1f06d80e1d 100644
--- a/tests/cpp/test_multidevice_lower_communication.cpp
+++ b/tests/cpp/test_multidevice_lower_communication.cpp
@@ -234,6 +234,41 @@ TEST_F(LowerCollectiveTest, Allgather_LoopSplit) {
   EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
 }
 
+// This currently fails because getShardingChanges reads only root/logical:
+// https://github.com/NVIDIA/Fuser/blob/1dda106a946adcfd1526b83e4f2d4abebb9e32e4/csrc/multidevice/utils.cpp#L77.
+// Will try to fix this in a follow-up PR and reenable the test.
+TEST_F(LowerCollectiveTest, DISABLED_Allgather_LoopSplit_Noncontiguous) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+
+  const auto num_devices = communicator_->size();
+  auto mesh = DeviceMesh::createForNumDevices(num_devices);
+
+  TensorView* in = makeContigTensor(2);
+  in->setDeviceMesh(mesh);
+  TensorView* out = set(in);
+  fusion->addInput(in);
+  fusion->addOutput(out);
+
+  in->split(1, num_devices, /*inner_split=*/false);
+  in->axis(1)->parallelize(ParallelType::DIDx);
+  in->setAllocationDomain(in->getLoopDomain(), true);
+
+  out->split(1, num_devices, /*inner_split=*/false);
+  out->setAllocationDomain(out->getLoopDomain(), true);
+
+  at::Tensor unsharded_tensor =
+      at::arange(2 * num_devices * 3, at::kFloat).view({2, num_devices * 3});
+  at::Tensor in_tensor =
+      shardTensor(unsharded_tensor, in).to(communicator_->device());
+
+  FusionExecutorCache fec(std::move(fusion));
+  at::Tensor out_tensor = fec.runFusionWithInputs({in_tensor})[0];
+  assertIsCompiledToHostIrContainer(fec);
+
+  EXPECT_TRUE(at::equal(out_tensor.cpu(), unsharded_tensor));
+}
+
 TEST_F(LowerCollectiveTest, Broadcast) {
   auto fusion = std::make_unique<Fusion>();
   FusionGuard fg(fusion.get());

From c218c70c55557a07bb1f71a31fe6cc479e2824b9 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Wed, 4 Dec 2024 23:47:29 -0800
Subject: [PATCH 17/20] Comment

---
 csrc/multidevice/utils.h  | 5 +++--
 tests/cpp/multidevice.cpp | 3 +++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h
index 61b06df0383..2d39b4dbe29 100644
--- a/csrc/multidevice/utils.h
+++ b/csrc/multidevice/utils.h
@@ -123,8 +123,9 @@ int64_t requestedNumberOfDevices(Fusion*);
 void unshard(Fusion*);
 void unshard(TensorView*);
 
-// Returns the index of the sharded logical axis corresponding to
-// `parallel_type`. If `tv` isn't sharded on the parallel type, returns -1.
+// Returns the index of the sharded logical axis that produces the allocation
+// IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel
+// type, returns -1.
 int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type);
 
 // Reorders a TensorView so that the DID parallelized axis are in front.
 void reorderDIDToFront(TensorView*);
diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp
index 92c13c1c460..e591d61f980 100644
--- a/tests/cpp/multidevice.cpp
+++ b/tests/cpp/multidevice.cpp
@@ -145,6 +145,9 @@ at::Tensor MultiDeviceTest::shardTensor(
   auto stride = extent / nslices;
   // TODO: returning slice 0 temporarily when device is not in the mesh.
   i = (i < 0) ? 0 : i;
+  // The following slicing is problematic when DID is on an inner split (cf.
+  // MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and
+  // it's enforced by getShardedAxis.
return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } From 9c2a2183df984b018dcbff8a3b63bcac2b596c47 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Thu, 5 Dec 2024 09:36:03 -0800 Subject: [PATCH 18/20] Rename --- csrc/multidevice/lower_communication.cpp | 2 +- csrc/multidevice/utils.cpp | 6 ++++-- csrc/multidevice/utils.h | 2 +- tests/cpp/multidevice.cpp | 6 ++++-- 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/csrc/multidevice/lower_communication.cpp b/csrc/multidevice/lower_communication.cpp index a8fca521c4a..4b878ac7376 100644 --- a/csrc/multidevice/lower_communication.cpp +++ b/csrc/multidevice/lower_communication.cpp @@ -196,7 +196,7 @@ void lowerToReduceScatter( std::vector& comms) { const DeviceMesh& mesh = input_tv->getDeviceMesh(); auto reduction_axis = output_tv->getReductionAxis().value(); - auto scattered_axis = getShardedAxis(output_tv, ParallelType::DIDx); + auto scattered_axis = getShardedLogicalAxis(output_tv, ParallelType::DIDx); // The output tensor is sharded on scattered_axis and needs to be mapped // back onto the input. The input has an reduced axis, so the scattered axis // is adjusted to account for this. Ex: [DIDx(i0), i1] -> [r0, DIDx(i1)] The diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index ae705e3145a..76215d982c6 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -161,7 +161,9 @@ std::unordered_map mapIterDomainToTensorAxis( } // namespace -int64_t getShardedAxis(const TensorView* tv, const ParallelType parallel_type) { +int64_t getShardedLogicalAxis( + const TensorView* tv, + const ParallelType parallel_type) { std::unordered_map parallel_type_to_id = mapDeviceParallelTypeToId(tv->getMaybeAllocationDomain()); IterDomain* alloc_id = getOrDefault(parallel_type_to_id, parallel_type); @@ -212,7 +214,7 @@ std::vector unshardedSizes( c10::IntArrayRef sizes) { std::vector unsharded_sizes = sizes.vec(); for (ParallelType parallel_type : kParallelTypeDIDs) { - const int64_t sharded_axis = getShardedAxis(tv, parallel_type); + const int64_t sharded_axis = getShardedLogicalAxis(tv, parallel_type); if (sharded_axis == -1) { continue; } diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 2d39b4dbe29..287235b554b 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -126,7 +126,7 @@ void unshard(TensorView*); // Returns the index of the sharded logical axis that produces the allocation // IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel // type, returns -1. -int64_t getShardedAxis(const TensorView* tv, ParallelType parallel_type); +int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. void reorderDIDToFront(TensorView*); diff --git a/tests/cpp/multidevice.cpp b/tests/cpp/multidevice.cpp index e591d61f980..22897dc5311 100644 --- a/tests/cpp/multidevice.cpp +++ b/tests/cpp/multidevice.cpp @@ -129,7 +129,9 @@ at::Tensor MultiDeviceTest::shardTensor(at::Tensor tensor, TensorView* tv) { } NVF_ERROR(tv->hasDeviceMesh(), "`tv` has no DeviceMesh: ", tv); return shardTensor( - tensor, getShardedAxis(tv, ParallelType::DIDx), tv->getDeviceMesh()); + tensor, + getShardedLogicalAxis(tv, ParallelType::DIDx), + tv->getDeviceMesh()); } at::Tensor MultiDeviceTest::shardTensor( @@ -147,7 +149,7 @@ at::Tensor MultiDeviceTest::shardTensor( i = (i < 0) ? 0 : i; // The following slicing is problematic when DID is on an inner split (cf. 
// MultiDeviceTest.ShardTensor_InnerSplit). We currently disallow that and - // it's enforced by getShardedAxis. + // it's enforced by getShardedLogicalAxis. return tensor.slice(axis, i * stride, (i + 1) * stride).contiguous(); } From 5229512c792877cd6352f30a1f25bf13daee1cea Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 6 Dec 2024 11:39:38 -0800 Subject: [PATCH 19/20] Comment --- csrc/multidevice/utils.h | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 287235b554b..ef88fbdcf80 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -126,6 +126,12 @@ void unshard(TensorView*); // Returns the index of the sharded logical axis that produces the allocation // IterDomain sharded on `parallel_type`. If `tv` isn't sharded on the parallel // type, returns -1. +// +// This is used to correlate `tv` and its corresponding at::Tensor, e.g., by +// `unshardedSizes` and `shardTensor`. `at::Tensor::sizes` and +// `tv->getLogicalDomain()` map one-to-one modulo reduction. However, a size in +// `at::Tensor::sizes` is a factor of the corresponding logical IterDomain's +// extent if that IterDomain is sharded. int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); // Reorders a TensorView so that the DID parallelized axis are in front. From 5157fe1231337777567378868e0793ec43a08906 Mon Sep 17 00:00:00 2001 From: Jingyue Wu Date: Fri, 6 Dec 2024 14:33:58 -0800 Subject: [PATCH 20/20] More comments --- csrc/multidevice/utils.cpp | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 76215d982c6..54f1303bc16 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -181,7 +181,35 @@ int64_t getShardedLogicalAxis( "Failed to find a non-reduction logical IterDomain that produces ", alloc_id); if (auto* split = dynamic_cast(def)) { - // FIXME: comment + // Returning just which tensor axis is sharded isn't sufficient to let + // shardTensor, a user of this function, know how to shard the tensor. + // For example, + // + // t = makeContigConcreteTensor({6}); + // t->split(0, 2, /*inner_split=*/true); + // t->axis(-1)->parallelize(DIDx); + // // [i{3}, iDIDx{2}] + // + // and the unsharded tensor is [0, 1, 2, 3, 4, 5], regardless of the + // stride. The sharded tensor ought to be [0, 2, 4] for GPU 0 and [1, 3, + // 5] for GPU 1. However, shardTensor as is will return [0, 1, 2] and [3, + // 4, 5], assuming the axis is sharded outermost. + // + // One potential way to solve the general problem is to replay and rewind + // the splits on the at::Tensor. For example, + // + // t = makeContigConcreteTensor({30}); + // t->split(0, 5); + // t->split(0, 3); + // t->axis(0)->parallelize(Host); + // t->axis(1)->parallelize(DIDx); + // // [iHost{2}, iDIDx{3}, i{5}] + // + // Given an unsharded at::Tensor of shape [30], we'll first replay the + // splits using `torch.view` to get a tensor of shape [2,3,5]. Then, we + // `torch.slice` axis 1 for DIDx to get a tensor of shape [2,1,5]. Then, + // we rewind the splits (and therefore apply merging) using + // `torch.reshape` to get a sharded tensor of shape [10]. NVF_ERROR( split->outer() == id, "Currently, we don't support DID on inner splits: ",