newValLike tries to propagate sharding. (#2734)
wujingyue authored Aug 3, 2024
1 parent 4a2987e commit 346e51c
Showing 4 changed files with 333 additions and 172 deletions.
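In short: outputs built through newValLike/newOutputTV used to come out with no device mesh and Serial parallel types, so sharding carried by the inputs was silently dropped. After this commit, the output inherits the device mesh of the first meshed input and the promoted parallel type of each axis. A minimal sketch of the resulting behavior (illustrative only, not code from this commit; neg stands in for any op routed through newValLike):

// Sketch: how sharding now survives an elementwise op.
Fusion fusion;
FusionGuard fg(&fusion);

TensorView* in = makeContigTensor(2);
fusion.addInput(in);
in->setDeviceMesh(DeviceMesh::createForNumDevices(4));
in->axis(0)->parallelize(ParallelType::DIDx);

TensorView* out = neg(in); // goes through newValLike -> newOutputTV

// After this commit: `out` carries `in`'s mesh, and its axis(0) is DIDx
// (Serial + DIDx promotes to DIDx). Before, `out` was unsharded.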
csrc/ops/utils.cpp: 54 additions, 20 deletions
@@ -5,13 +5,14 @@
  * SPDX-License-Identifier: BSD-3-Clause
  */
 // clang-format on
+#include <algorithm>
+#include <limits>
+
 #include <ir/builder.h>
 #include <ir/utils.h>
 #include <ops/arith.h>
 #include <ops/utils.h>
 
-#include <algorithm>
-#include <limits>
-
 namespace nvfuser {
 namespace ops {
 
@@ -278,6 +279,22 @@ std::vector<IterDomain*> mapLinearOpIterDomains(
   return mapping;
 }
 
+namespace {
+ParallelType promoteParallelType(ParallelType a, ParallelType b) {
+  if (a == b) {
+    return a;
+  }
+  NVF_ERROR(
+      a == ParallelType::Serial || b == ParallelType::Serial,
+      "Doesn't know how to resolve ",
+      a,
+      " and ",
+      b,
+      " at this moment.");
+  return a == ParallelType::Serial ? b : a;
+}
+} // namespace
+
 // Adding these pragmas since gcc-12.2.1
 // incorrectly reports a warning with the use of evaluate
 #if defined(__GNUC__) && !defined(__clang__)
@@ -297,26 +314,30 @@ IterDomain* newOutputIterDomain(
   Val* extent_val = nullptr;
   bool extent_is_from_symbolic = true;
   Val* expanded_extent_val = nullptr;
+  auto parallel_type = ParallelType::Serial;
   std::optional<IterType> iter_type = std::nullopt;
 
-  std::vector<IterDomain*> ids;
-  ids.reserve(input_ids.size());
-
-  // Filter out any nullptrs
-  std::copy_if(
-      input_ids.begin(),
-      input_ids.end(),
-      std::back_inserter(ids),
-      [](IterDomain* id) { return id != nullptr; });
+  for (auto id : input_ids) {
+    // Filter out any nullptrs
+    if (id == nullptr) {
+      continue;
+    }
 
-  for (auto id : ids) {
     if (id->isBroadcast()) {
       if (id->hasExpandedExtent()) {
         expanded_extent_val =
             promoteSize(expanded_extent_val, id->expandedExtent());
       }
       continue;
     }
 
-    NVF_ERROR(
-        id->getParallelType() == ParallelType::Serial ||
-            isParallelTypeDeviceDim(id->getParallelType()),
-        id->getParallelType(),
-        " is not expected when building ops.");
+    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
 
     if (extent_is_from_symbolic && !id->isSymbolic()) {
       // We prefer to use extents from non-Symbolic inputs if there are any
       // because they might indicate a broadcast axis that is resolved in this
Expand Down Expand Up @@ -362,13 +383,15 @@ IterDomain* newOutputIterDomain(
IterDomainBuilder(
IrBuilder::create<Val>(start_offset, DataType::Index), extent_val)
.stop_offset(IrBuilder::create<Val>(stop_offset, DataType::Index))
.parallel_type(parallel_type)
.iter_type(iter_type.value())
.build();
} else {
out_domain = IterDomainBuilder(
FusionGuard::getCurFusion()->zeroVal(),
FusionGuard::getCurFusion()->oneVal())
.expanded_extent(expanded_extent_val)
.parallel_type(parallel_type)
.iter_type(IterType::Broadcast)
.build();
}
@@ -381,8 +404,8 @@ IterDomain* newOutputIterDomain(
 std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals) {
   std::vector<TensorView*> tvs;
   for (auto val : vals) {
-    if (val->getValType() == ValType::TensorView) {
-      tvs.push_back(val->as<TensorView>());
+    if (auto* tv = dynamic_cast<TensorView*>(val)) {
+      tvs.push_back(tv);
     }
   }
   NVF_CHECK(
@@ -395,7 +418,7 @@ std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals) {
   for (const auto dim_i : c10::irange(out_domain.size())) {
     std::vector<IterDomain*> input_ids;
     input_ids.reserve(tvs.size());
-    for (auto tv : tvs) {
+    for (auto* tv : tvs) {
       auto dom = TensorDomain::noReductions(tv->getLogicalDomain());
       input_ids.emplace_back(dom[dim_i]);
     }
@@ -406,10 +429,23 @@
 
 TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype) {
   auto out_domain = newOutputDomain(vals);
-  return IrBuilder::create<TensorView>(
+  auto* new_out = IrBuilder::create<TensorView>(
       IrBuilder::create<TensorDomain>(
           out_domain, TensorDomain::getContiguityFilledWith(out_domain, true)),
       dtype);
+
+  DeviceMesh new_mesh;
+  // Find the first input that has a mesh. This seems arbitrary, but is at this
+  // moment safest because it's consistent with PropagateShardingsPass.
+  for (auto* tv : ir_utils::filterByType<TensorView>(vals)) {
+    if (tv->hasDeviceMesh()) {
+      new_mesh = tv->getDeviceMesh();
+      break;
+    }
+  }
+  new_out->setDeviceMesh(new_mesh);
+
+  return new_out;
 }

 std::vector<Val*> maybeBroadcast(const std::vector<Val*>& vals) {
@@ -439,9 +475,7 @@ Val* newValLike(Val* val, DataType dtype) {
   NVF_CHECK(
       dtype != DataType::Null, "Invalid datatype provided for new value.");
 
-  const ValType vtype = val->getValType().value();
-
-  if (vtype == ValType::TensorView) {
+  if (val->isA<TensorView>()) {
     return newOutputTV({val}, dtype);
   }
 
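A note on the new helper above: promoteParallelType resolves each output axis's parallel type from the input axes. By its definition, equal types pass through, Serial acts as the identity, and any other combination trips the NVF_ERROR:

promoteParallelType(ParallelType::Serial, ParallelType::Serial); // -> Serial
promoteParallelType(ParallelType::DIDx, ParallelType::Serial);   // -> DIDx
promoteParallelType(ParallelType::Serial, ParallelType::DIDx);   // -> DIDx
promoteParallelType(ParallelType::DIDx, ParallelType::DIDx);     // -> DIDx
promoteParallelType(ParallelType::DIDx, ParallelType::TIDx);     // NVF_ERROR

The device mesh is handled separately in newOutputTV: it is copied from the first input that has one, which, as the in-code comment notes, keeps the behavior consistent with PropagateShardingsPass.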
csrc/ops/utils.h: 4 additions, 2 deletions
@@ -85,8 +85,10 @@ IterDomain* newOutputIterDomain(
     const std::vector<IterDomain*>& ids,
     const std::optional<IterType> force_iter_type = std::nullopt);
 
-// Takes a vector of tensorviews and assumes they are all aligned to create the
-// output tensorview. For eg: BinaryOp.
+// Takes a vector of `Val*`s and assumes they are all aligned to create the
+// output tensorview, e.g., for BinaryOp. `vals` can contain scalars, e.g.,
+// when creating the output TensorView for `tv0+scalar`. This is for
+// convenience and scalars will be ignored.
 std::vector<IterDomain*> newOutputDomain(const std::vector<Val*>& vals);
 
 TensorView* newOutputTV(const std::vector<Val*>& vals, DataType dtype);
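To illustrate the scalar tolerance the rewritten comment describes, a small sketch (assuming the usual ops/arith.h entry points; not part of the diff):

TensorView* tv0 = makeContigTensor(2);
Val* s = IrBuilder::create<Val>(2.0, DataType::Double);
// add() builds its output domain from {tv0, s}; newOutputDomain()
// keeps only the TensorView and silently ignores the scalar.
TensorView* tv1 = add(tv0, s);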
tests/cpp/test_multidevice_matmul.cpp: 48 additions, 0 deletions
@@ -389,4 +389,52 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutNT_ReduceScatter) {
       ->heuristic();
   EXPECT_EQ(heuristic, ScheduleHeuristic::ExprEval);
 }
 
+// Reproduces #2721.
+TEST_F(DistributedMatmulTest, PresegPreservesSharding) {
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
+  auto mesh = DeviceMesh::createForNumDevices(communicator_->size());
+
+  TensorView* x = makeContigTensor(2);
+  TensorView* w = makeContigTensor(3);
+  fusion->addInput(x);
+  fusion->addInput(w);
+
+  TensorView* w_t = transpose(w, 1, 2);
+  TensorView* mm = matmul(x, w_t);
+  TensorView* mm_t = transpose(mm, 1, 2);
+  fusion->addOutput(mm_t);
+
+  for (auto tv : {x}) {
+    tv->setDeviceMesh(mesh);
+  }
+  for (auto tv : {w, w_t, mm, mm_t}) {
+    tv->setDeviceMesh(mesh);
+    tv->axis(0)->parallelize(ParallelType::DIDx);
+  }
+
+  const auto options = at::TensorOptions().device(communicator_->device());
+  auto x_tensor = at::randn({12, 48}, options);
+  auto w_tensor = at::randn({mesh.size(), 36, 48}, options);
+  auto sharded_w_tensor = shardTensor(w_tensor, w);
+
+  MultiDeviceExecutor runtime(
+      std::move(fusion), *communicator_, executor_params_);
+  std::vector<c10::IValue> inputs({x_tensor, sharded_w_tensor});
+  std::vector<at::Tensor> outputs = runtime.runWithInput(inputs);
+
+  at::Tensor expected_mm_t_tensor =
+      atMatmul(x_tensor, w_tensor.view({mesh.size() * 36, 48}), MmaLayout::TN)
+          .transpose(0, 1)
+          .view({mesh.size(), 36, 12});
+  testValidate(
+      runtime.completeFusion(),
+      outputs,
+      inputs,
+      {shardTensor(expected_mm_t_tensor, mm_t)},
+      __LINE__,
+      __FILE__);
+}
+
 } // namespace nvfuser
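For readers outside the repo: shardTensor in the test above is assumed to pick the slice of the full, unsharded tensor owned by the current rank along the DIDx-parallelized axis. A rough, hypothetical ATen equivalent, for illustration only:

// Hypothetical stand-in for the test helper, sharding along dim 0.
at::Tensor shardAlongDim0(const at::Tensor& full, int64_t rank) {
  // slice() keeps the dimension with extent 1, i.e. one element per device.
  return full.slice(/*dim=*/0, /*start=*/rank, /*end=*/rank + 1);
}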