Unshard tensor sizes before binding. (#3444)
Fixes #3282 

With this PR, we still bind input tensors to logical domains; however, the tensor sizes are "unsharded" before binding.
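
As a rough illustration of what "unsharding" the sizes means, here is a
minimal sketch independent of nvFuser types. The shapes, the mesh size D = 4,
and the helper name unshard_demo are hypothetical; the real logic is
unshardedSizes in csrc/multidevice/utils.cpp below.

#include <cstdint>
#include <vector>

// Assume a logical domain [iM, iN] and an allocation domain
// [iDIDx{D}, iN/D, iM] on a 1-D device mesh of size D. Each rank then holds
// a local tensor of shape {2, 3}.
std::vector<int64_t> unshard_demo(std::vector<int64_t> local_sizes, int64_t d) {
  // iN is sharded over DIDx, so its local size is scaled back up by the mesh
  // size; iM is fully allocated and stays as-is.
  local_sizes[1] *= d;
  return local_sizes; // with d = 4: {2, 3} -> {2, 12}, bound to [iM, iN]
}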

---------

Co-authored-by: samnordmann <[email protected]>
wujingyue and samnordmann authored Nov 30, 2024
1 parent 889262e commit c154e90
Showing 10 changed files with 291 additions and 145 deletions.
50 changes: 11 additions & 39 deletions csrc/expr_evaluator.cpp
@@ -6,6 +6,9 @@
*/
// clang-format on

#include <functional>
#include <iostream>

#include <debug.h>
#include <evaluator_common.h>
#include <expr_evaluator.h>
@@ -14,11 +17,9 @@
#include <ir/iostream.h>
#include <ir/utils.h>
#include <logical_domain_map.h>
#include <multidevice/utils.h>
#include <polymorphic_value.h>

#include <functional>
#include <iostream>

namespace nvfuser {

namespace {
@@ -143,61 +144,32 @@ void ExpressionEvaluator::bindTensorDomain(
logical_domain.size(),
", but got a tensor of rank ",
t.dim());

std::vector<int64_t> logical_sizes = unshardedSizes(tv, t.sizes());
for (auto i : c10::irange(t.dim())) {
auto id = logical_domain[i];
if (id->isBroadcast()) {
// DIDs are ignored for broadcast.
bind_(logical_domain[i]->extent(), 1, evaluate_validate);
bind_(id->extent(), 1, evaluate_validate);
if (id->hasExpandedExtent()) {
// Verify that t is also expanded
NVF_ERROR(
t.size(i) == 1 || t.stride(i) == 0,
logical_sizes[i] == 1 || t.stride(i) == 0,
"IterDomain ",
id->toString(),
" in ",
getInputPosString(tv),
"TensorView ",
tv->toString(),
" has expanded extent but input tensor has size ",
t.size(i),
logical_sizes[i],
" and stride ",
t.stride(i),
" in dimension ",
i);
bind_(
logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
bind_(id->expandedExtent(), logical_sizes[i], evaluate_validate);
}
} else {
if (logical_domain[i]->isDeviceDim()) {
// Currently we have the restrictions:
// (1) Devices parallelized axis extent == DeviceMesh's extent
// (2) Device parallelized axis cannot be split or merged
// Therefore, the device parallelized extents will always be allocated
// with size 1, but the symbolic axis extent is binded with the extent
// of the DeviceMesh
NVF_CHECK(
1 == t.size(i),
"TensorView ",
tv->toString(),
getInputPosString(tv),
" IterDomain ",
id->toString(),
"is sharded and must have size 1, but input tensor has size ",
t.size(i));
NVF_CHECK(
tv->hasDeviceMesh(),
"TV ",
tv->toString(),
getInputPosString(tv),
" has an empty DeviceMesh with DID parallelization")
bind_(
logical_domain[i]->extent(),
static_cast<int64_t>(
tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
evaluate_validate);
} else {
bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
}
bind_(id->extent(), logical_sizes[i], evaluate_validate);
}
}
}
7 changes: 6 additions & 1 deletion csrc/fusion_segmenter.cpp
@@ -1883,7 +1883,12 @@ void eraseInputDistinctRootDomains(Fusion* fusion) {
std::vector<IterDomain*> new_alloc;
new_alloc.reserve(tv->getAllocationDomain().size());
for (IterDomain* alloc_id : tv->getAllocationDomain()) {
new_alloc.push_back(replay.getReplay().at(alloc_id));
IterDomain* new_alloc_id = replay.getReplay().at(alloc_id);
// ReplayTransformations replays transforms but not parallelization, so
// we have to manually parallelize the new allocation ID. Elsewhere,
// parallelization is usually done through parallelizeAllLike.
new_alloc_id->parallelize(alloc_id->getParallelType());
new_alloc.push_back(new_alloc_id);
}

std::vector<IterDomain*> new_loop;
23 changes: 10 additions & 13 deletions csrc/ir/nodes.cpp
@@ -3238,24 +3238,21 @@ bool TensorDomain::sameAs(
std::string TensorDomain::toString(const int indent_size, const bool loop_only)
const {
std::stringstream ss;
if (nDims() == 0) {
indent(ss, indent_size) << "[ ]";
return ss.str();
}
indent(ss, indent_size) << "[ " << toDelimitedString(loop()) << " ]";
if (!loop_only) {
if (loop_only) {
indent(ss, indent_size) << "[" << toDelimitedString(loop()) << "]";
} else {
indent(ss, indent_size)
<< "logical=[" << toDelimitedString(logical()) << "]" << std::endl;
if (hasRoot()) {
ss << "," << std::endl;
indent(ss, indent_size + 1)
<< "root=[ " << toDelimitedString(root()) << " ]";
<< "root=[" << toDelimitedString(root()) << "]" << std::endl;
}
ss << "," << std::endl;
indent(ss, indent_size + 1)
<< "logical=[ " << toDelimitedString(logical()) << " ]";
if (!allocation_domain_.empty()) {
ss << "," << std::endl;
<< "loop=[" << toDelimitedString(loop()) << "]" << std::endl;
if (hasAllocation()) {
indent(ss, indent_size + 1)
<< "allocation=[ " << toDelimitedString(allocation()) << " ]";
<< "allocation=[" << toDelimitedString(allocation()) << "]"
<< std::endl;
}
}
return ss.str();
49 changes: 47 additions & 2 deletions csrc/multidevice/utils.cpp
@@ -106,8 +106,8 @@ std::pair<std::vector<IterDomain*>, std::vector<IterDomain*>> getShardingChanges

bool isSharded(const TensorView* tv) {
bool is_sharded = false;
for (IterDomain* id : tv->getLoopDomain()) {
if (!id->isDeviceDim()) {
for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
if (!alloc_id->isDeviceDim()) {
continue;
}

@@ -121,6 +121,51 @@ bool isSharded(const TensorView* tv) {
return is_sharded;
}

std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes) {
std::vector<int64_t> unsharded_sizes = sizes.vec();

for (IterDomain* alloc_id : tv->getMaybeAllocationDomain()) {
const ParallelType parallel_type = alloc_id->getParallelType();
if (!isParallelTypeDeviceDim(parallel_type)) {
continue;
}

const auto inputs = IterVisitor::getInputsTo(
{alloc_id},
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()});
NVF_ERROR(
!inputs.empty(),
"IterVisitor::getInputsTo shouldn't return empty unless `of` is empty.");
NVF_ERROR(
inputs.size() == 1,
"Failed to find the single logical input to ",
alloc_id,
". This is likely because there's a Merge expression from logical to allocation, which isn't supported. Inputs are: ",
toDelimitedString(inputs));

const auto iter = std::find(
tv->getLogicalDomain().begin(),
tv->getLogicalDomain().end(),
inputs[0]);
NVF_ERROR(
iter != tv->getLogicalDomain().end(),
"The found input IterDomain isn't logical. This is likely because logical doesn't dominate allocation: ",
inputs[0]);

// Count the number of non-reduction IterDomains before `iter`. Reduction
// IterDomains are not materialized in the at::Tensor's shape.
const auto index = std::count_if(
tv->getLogicalDomain().begin(), iter, [](IterDomain* id) -> bool {
return !id->isReduction();
});
unsharded_sizes.at(index) *= tv->getDeviceMesh().size(parallel_type);
}

return unsharded_sizes;
}

int64_t numDeviceDims(const TensorView* tv) {
return std::count_if(
tv->getLoopDomain().begin(),
41 changes: 41 additions & 0 deletions csrc/multidevice/utils.h
@@ -7,6 +7,8 @@
// clang-format on
#pragma once

#include <c10/util/ArrayRef.h>

#include <compute_at_map.h>
#include <fusion.h>
#include <id_model/id_model.h>
@@ -127,4 +129,43 @@ int64_t getShardedAxis(TensorView*);

// Reorders a TensorView so that the DID parallelized axis are in front.
void reorderDIDToFront(TensorView*);

// Given a TensorView and the shape of a sharded tensor of which certain
// dimensions are partially allocated, returns the global shape that'll be used
// to bind to the TensorView's logical domain. This is to solve #3282 so we can
// bind a sharded tensor to a TensorView that has a DID-parallel loop domain.
//
// For example, when `tv` is
// logical: iM, iN
// allocation: iDIDx{D}, iN/D, iM
// and `sizes` is [2, 3], the returned shape will be [2, 3D]. This is because,
// according to the allocation domain, iM is fully allocated and iN is sharded
// and thus partially allocated.
//
// If the TensorView is not sharded, this function returns `sizes`.
//
// Limitations:
// - The function assumes that there are no Merges from logical to the
// DID-parallel IterDomains in allocation. Otherwise, it's unclear which logical
// dimension this DID-parallelization should be attributed to.
// - The function assumes that all Splits from logical to the DID-parallel
// IterDomains in allocation are even. This is because there is currently no
// way to pass in the global shape.
//
// Despite these limitations, I took this approach as a shortcut to fix #3282,
// which blocked many other tasks. However, I'm open to better, long-term
// solutions. Some alternatives considered in #3282 are:
// - Try to bind `at::Tensor`s to allocation domains instead of logical. Many
// `*Op::evaluate` methods (e.g.
// https://github.com/NVIDIA/Fuser/blob/2415d904d1e9a5da7ca6fb1a55d3045bbd510341/csrc/ir/nodes.cpp#L4321-L4329)
// assume the input/output `at::Tensor`s have the same dimension order as the
// logical domain. Doing so would have to change them all.
// - Try to pass into FusionExecutorCache both logical (global) shapes and
// allocated (local) tensors for sharded TensorViews. The logical shapes would
// have to be passed through FusionKernelRuntime, FusionExecutor,
// ExpressionEvaluator, and so on, which is an API overhaul.
std::vector<int64_t> unshardedSizes(
const TensorView* tv,
c10::IntArrayRef sizes);

} // namespace nvfuser
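
To make the example in the header comment above concrete, here is a hedged
usage sketch. The tv below is assumed to be a TensorView whose logical domain
is [iM, iN], whose allocation domain is [iDIDx{D}, iN/D, iM], and whose device
mesh has D = 4 devices; how that TensorView is constructed is not shown.

// Hypothetical setup: tv's logical domain is [iM, iN], its allocation domain
// is [iDIDx{D}, iN/D, iM], and its device mesh has D = 4 devices.
at::Tensor local = at::randn({2, 3}); // the rank-local, sharded tensor
std::vector<int64_t> global_sizes = unshardedSizes(tv, local.sizes());
// global_sizes == {2, 12}: iM is fully allocated, while iN is sharded and is
// scaled by D before being bound to the logical domain.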
36 changes: 19 additions & 17 deletions csrc/scheduler/transpose.cpp
@@ -227,34 +227,36 @@ class DomainMap : public pointwise_utils::DomainMap {
root_dim,
" in tensor ",
tv);
auto replay_exprs = StmtSort::getExprsBetween(
std::vector<Expr*> replay_exprs = StmtSort::getExprsBetween(
{mapped_id}, {tv->getLoopDomain().begin(), tv->getLoopDomain().end()});
// Project the root id to loop id. Similar to projectIdToRFactor.
for (auto expr : replay_exprs) {
if (expr->isA<Split>()) {
// Split with factor one is not supposed to be here, reshape would map
// this to a broadcast. This is a conservative assert, we can relaxed it
// and support with mapping it to outer.
NVF_ERROR(
!expr->as<Split>()->factor()->isOneInt(),
"split with factor one is supposed to be translated to broadcast by reshape");
if (expr->as<Split>()->in() == mapped_id) {
mapped_id = expr->as<Split>()->inner();
for (auto* expr : replay_exprs) {
if (auto* split = dynamic_cast<Split*>(expr)) {
if (split->in() == mapped_id) {
if (split->inner()->extent()->isOneInt() &&
!split->outer()->extent()->isOneInt()) {
mapped_id = split->outer();
} else {
mapped_id = split->inner();
}
}
} else if (expr->isA<Merge>()) {
} else if (auto* merge = dynamic_cast<Merge*>(expr)) {
// Merge with a size-1 dimension is not supposed to be here; reshape would
// map this to a squeeze. This is a conservative assert; we could relax it
// and support this case by mapping to the merge's output.
NVF_ERROR(
!expr->as<Merge>()->inner()->extent()->isOneInt(),
!merge->inner()->extent()->isOneInt(),
"merge with size-1 dimension is supposed to be translated to squeeze by reshape");
if (expr->as<Merge>()->inner() == mapped_id) {
mapped_id = expr->as<Merge>()->out();
if (merge->inner() == mapped_id) {
mapped_id = merge->out();
}
} else if (auto* resize = dynamic_cast<Resize*>(expr)) {
if (resize->in() == mapped_id) {
mapped_id = resize->out();
}
} else if (expr->isA<Resize>() && expr->as<Resize>()->in() == mapped_id) {
mapped_id = expr->as<Resize>()->out();
}
}

// Find the position of the loop id
const auto& dom = tv->getLoopDomain();
for (auto i : c10::irange(dom.size())) {