
Attempt to fix a sharded matmul test.
wujingyue committed Aug 28, 2024
1 parent 4fbbe0b commit fd03214
Showing 3 changed files with 75 additions and 65 deletions.
csrc/evaluator_common.cpp (26 changes: 14 additions & 12 deletions)
@@ -348,19 +348,21 @@ void PrecomputedValues::bindTensorMetaData(
 
   for (const auto dim : c10::irange(logical_domain.size())) {
     IterDomain* id = logical_domain[dim];
-    auto dim_size = tensor.size(static_cast<int64_t>(dim));
-    if (id->isDeviceDim()) {
-      dim_size = tv->getDeviceMesh().size(id->getParallelType());
-    }
-
-    if (id->hasExpandedExtent()) {
-      Val* extent = id->extent();
-      Val* expanded_extent = id->expandedExtent();
-      bindValue(extent->evaluatorIndex(), 1L);
-      bindValue(expanded_extent->evaluatorIndex(), dim_size);
+    const auto dim_size = tensor.size(static_cast<int64_t>(dim));
+    if (id->isBroadcast()) {
+      // DIDs are ignored.
+      bindValue(id->extent()->evaluatorIndex(), 1L);
+      if (id->hasExpandedExtent()) {
+        bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
+      }
     } else {
-      Val* extent = id->extent();
-      bindValue(extent->evaluatorIndex(), dim_size);
+      if (id->isDeviceDim()) {
+        bindValue(
+            id->extent()->evaluatorIndex(),
+            tv->getDeviceMesh().size(id->getParallelType()));
+      } else {
+        bindValue(id->extent()->evaluatorIndex(), dim_size);
+      }
     }
   }

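The restructured binding in bindTensorMetaData checks for broadcast axes before device parallelization, so a DID-parallelized broadcast dimension now binds extent 1 rather than the mesh size. A standalone C++ sketch of the new decision, using plain stand-in types (DomainInfo, boundExtent, and mesh_size are illustrative names, not nvFuser API):

#include <cassert>
#include <cstdint>

// Minimal stand-ins for the IterDomain properties consulted above.
struct DomainInfo {
  bool is_broadcast = false;
  bool is_device_dim = false;  // DIDx/DIDy/... parallelized
};

// Mirrors the new rule: broadcast axes bind extent 1 (the expanded extent,
// if any, gets the tensor size), sharded non-broadcast axes bind the
// device-mesh size, and ordinary axes bind the tensor size.
int64_t boundExtent(const DomainInfo& id, int64_t tensor_size, int64_t mesh_size) {
  if (id.is_broadcast) {
    return 1;  // DIDs are ignored on broadcast axes.
  }
  if (id.is_device_dim) {
    return mesh_size;
  }
  return tensor_size;
}

int main() {
  assert(boundExtent({/*is_broadcast=*/true, /*is_device_dim=*/true}, 1, 2) == 1);
  assert(boundExtent({/*is_broadcast=*/false, /*is_device_dim=*/true}, 1, 2) == 2);
  assert(boundExtent({/*is_broadcast=*/false, /*is_device_dim=*/false}, 8, 2) == 8);
  return 0;
}

The first case is the behavioral change: previously isDeviceDim() was consulted unconditionally, so such an axis would have been bound to the mesh size.
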
csrc/expr_evaluator.cpp (100 changes: 54 additions & 46 deletions)
@@ -174,53 +174,61 @@ void ExpressionEvaluator::bind_(
         t.dim());
     for (auto i : c10::irange(t.dim())) {
       auto id = logical_domain[i];
-      if (id->hasExpandedExtent()) {
-        // Verify that t is also expanded
-        NVF_ERROR(
-            t.size(i) == 1 || t.stride(i) == 0,
-            "IterDomain ",
-            id->toString(),
-            " in ",
-            getInputPosString(tv),
-            "TensorView ",
-            tv->toString(),
-            " has expanded extent but input tensor has size ",
-            t.size(i),
-            " and stride ",
-            t.stride(i),
-            " in dimension ",
-            i);
-        bind_(
-            logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
-      } else if (logical_domain[i]->isDeviceDim()) {
-        // Currently we have the restrictions:
-        // (1) Devices parallelized axis extent == DeviceMesh's extent
-        // (2) Device parallelized axis cannot be split or merged
-        // Therefore, the device parallelized extents will always be allocated
-        // with size 1, but the symbolic axis extent is binded with the extent
-        // of the DeviceMesh
-        NVF_CHECK(
-            1 == t.size(i),
-            "TensorView ",
-            tv->toString(),
-            getInputPosString(tv),
-            " IterDomain ",
-            id->toString(),
-            "is sharded and must have size 1, but input tensor has size ",
-            t.size(i));
-        NVF_CHECK(
-            tv->hasDeviceMesh(),
-            "TV ",
-            tv->toString(),
-            getInputPosString(tv),
-            " has an empty DeviceMesh with DID parallelization")
-        bind_(
-            logical_domain[i]->extent(),
-            static_cast<int>(
-                tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
-            evaluate_validate);
+      if (id->isBroadcast()) {
+        // DIDs are ignored.
+        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        if (id->hasExpandedExtent()) {
+          // Verify that t is also expanded
+          NVF_ERROR(
+              t.size(i) == 1 || t.stride(i) == 0,
+              "IterDomain ",
+              id->toString(),
+              " in ",
+              getInputPosString(tv),
+              "TensorView ",
+              tv->toString(),
+              " has expanded extent but input tensor has size ",
+              t.size(i),
+              " and stride ",
+              t.stride(i),
+              " in dimension ",
+              i);
+          bind_(
+              logical_domain[i]->expandedExtent(),
+              t.size(i),
+              evaluate_validate);
+        }
       } else {
-        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        if (logical_domain[i]->isDeviceDim()) {
+          // Currently we have the restrictions:
+          // (1) Devices parallelized axis extent == DeviceMesh's extent
+          // (2) Device parallelized axis cannot be split or merged
+          // Therefore, the device parallelized extents will always be allocated
+          // with size 1, but the symbolic axis extent is binded with the extent
+          // of the DeviceMesh
+          NVF_CHECK(
+              1 == t.size(i),
+              "TensorView ",
+              tv->toString(),
+              getInputPosString(tv),
+              " IterDomain ",
+              id->toString(),
+              "is sharded and must have size 1, but input tensor has size ",
+              t.size(i));
+          NVF_CHECK(
+              tv->hasDeviceMesh(),
+              "TV ",
+              tv->toString(),
+              getInputPosString(tv),
+              " has an empty DeviceMesh with DID parallelization")
+          bind_(
+              logical_domain[i]->extent(),
+              static_cast<int>(tv->getDeviceMesh().size(
+                  logical_domain[i]->getParallelType())),
+              evaluate_validate);
+        } else {
+          bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        }
+        }
       }
     }
   }
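The same broadcast-first restructuring is applied in ExpressionEvaluator::bind_, which additionally validates sharded inputs. A standalone sketch of the two preconditions enforced for a non-broadcast, DID-parallelized axis (bindShardedExtent is an illustrative helper, not the nvFuser API):

#include <cstdint>
#include <stdexcept>
#include <string>

// A sharded, non-broadcast axis must arrive with per-device size 1 and a
// DeviceMesh must be attached; its extent is then bound to the mesh size.
int64_t bindShardedExtent(int64_t local_size, bool has_device_mesh, int64_t mesh_size) {
  if (local_size != 1) {
    throw std::runtime_error(
        "sharded axis must have size 1, but input tensor has size " +
        std::to_string(local_size));
  }
  if (!has_device_mesh) {
    throw std::runtime_error("empty DeviceMesh with DID parallelization");
  }
  return mesh_size;
}

For example, an operand sharded along one logical axis over a mesh of 4 devices arrives as a size-1 slice of that axis on each device, and the symbolic extent of the axis is bound to 4.
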
csrc/ops/utils.cpp (14 changes: 7 additions & 7 deletions)
@@ -323,6 +323,13 @@ IterDomain* newOutputIterDomain(
       continue;
     }
 
+    NVF_ERROR(
+        id->getParallelType() == ParallelType::Serial ||
+            isParallelTypeDeviceDim(id->getParallelType()),
+        id->getParallelType(),
+        " is not expected when building ops.");
+    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
+
     if (id->isBroadcast()) {
       if (id->hasExpandedExtent()) {
         expanded_extent_val =
@@ -331,13 +338,6 @@ IterDomain* newOutputIterDomain(
       continue;
     }
 
-    NVF_ERROR(
-        id->getParallelType() == ParallelType::Serial ||
-            isParallelTypeDeviceDim(id->getParallelType()),
-        id->getParallelType(),
-        " is not expected when building ops.");
-    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
-
     if (extent_is_from_symbolic && !id->isSymbolic()) {
       // We prefer to use extents from non-Symbolic inputs if there are any
       // because they might indicate a broadcast axis that is resolved in this
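Hoisting the assertion and promoteParallelType above the isBroadcast early-continue means a DID-parallelized broadcast input axis now contributes its parallel type to the output IterDomain. A simplified, standalone model of that promotion, assuming it keeps the non-Serial type and requires device-parallel types to agree (the real promoteParallelType in csrc/ops/utils.cpp may differ):

#include <cassert>
#include <stdexcept>

enum class ParallelType { Serial, DIDx };

// Assumed promotion rule: Serial defers to any device-parallel type;
// conflicting device-parallel types are an error.
ParallelType promote(ParallelType a, ParallelType b) {
  if (a == ParallelType::Serial) {
    return b;
  }
  if (b == ParallelType::Serial) {
    return a;
  }
  if (a != b) {
    throw std::runtime_error("conflicting device parallel types");
  }
  return a;
}

int main() {
  // With promotion running before broadcast axes are skipped, a sharded
  // broadcast operand still marks the output axis as DIDx.
  ParallelType out = ParallelType::Serial;
  out = promote(out, ParallelType::DIDx);    // broadcast operand, sharded
  out = promote(out, ParallelType::Serial);  // concrete operand, serial
  assert(out == ParallelType::DIDx);
  return 0;
}
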
