MarkAliasesPrepare applies bookend from inputs as well as outputs. #2815

Closed
wants to merge 15 commits
26 changes: 14 additions & 12 deletions csrc/evaluator_common.cpp
@@ -348,19 +348,21 @@ void PrecomputedValues::bindTensorMetaData(

   for (const auto dim : c10::irange(logical_domain.size())) {
     IterDomain* id = logical_domain[dim];
-    auto dim_size = tensor.size(static_cast<int64_t>(dim));
-    if (id->isDeviceDim()) {
-      dim_size = tv->getDeviceMesh().size(id->getParallelType());
-    }
-
-    if (id->hasExpandedExtent()) {
-      Val* extent = id->extent();
-      Val* expanded_extent = id->expandedExtent();
-      bindValue(extent->evaluatorIndex(), 1L);
-      bindValue(expanded_extent->evaluatorIndex(), dim_size);
+    const auto dim_size = tensor.size(static_cast<int64_t>(dim));
+    if (id->isBroadcast()) {
+      // DIDs are ignored.
+      bindValue(id->extent()->evaluatorIndex(), 1L);
+      if (id->hasExpandedExtent()) {
+        bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
+      }
     } else {
-      Val* extent = id->extent();
-      bindValue(extent->evaluatorIndex(), dim_size);
+      if (id->isDeviceDim()) {
+        bindValue(
+            id->extent()->evaluatorIndex(),
+            tv->getDeviceMesh().size(id->getParallelType()));
+      } else {
+        bindValue(id->extent()->evaluatorIndex(), dim_size);
+      }
     }
   }

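Taken together, this hunk reduces the binding rule to: broadcast dimensions always bind extent 1 (with any expanded extent bound to the runtime size), DID-parallelized non-broadcast dimensions bind the device-mesh size, and everything else binds the tensor's runtime size. Below is a minimal, self-contained sketch of that rule; the types and names are stand-ins for illustration, not nvFuser's real IterDomain/DeviceMesh API.

#include <cstdint>
#include <iostream>

// Stand-in for the per-dimension state the evaluator sees.
struct Dim {
  bool is_broadcast = false;
  bool is_device_dim = false;      // DID-parallelized (sharded) axis
  bool has_expanded_extent = false;
  int64_t tensor_size = 1;         // size reported by the ATen tensor
  int64_t mesh_size = 1;           // device-mesh size along this axis
};

int64_t boundExtent(const Dim& d) {
  if (d.is_broadcast) {
    return 1;  // DIDs are ignored; any expanded extent binds tensor_size.
  }
  return d.is_device_dim ? d.mesh_size : d.tensor_size;
}

int main() {
  // A DID-sharded broadcast axis: previously the mesh size overrode the
  // extent; with this change it binds to 1 like any other broadcast.
  Dim sharded_broadcast{.is_broadcast = true, .is_device_dim = true,
                        .tensor_size = 1, .mesh_size = 8};
  std::cout << boundExtent(sharded_broadcast) << '\n';  // 1

  // A DID-sharded non-broadcast axis still binds the mesh size.
  Dim sharded_iter{.is_device_dim = true, .tensor_size = 1, .mesh_size = 8};
  std::cout << boundExtent(sharded_iter) << '\n';  // 8
}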
Expand Down
100 changes: 54 additions & 46 deletions csrc/expr_evaluator.cpp
@@ -174,53 +174,61 @@ void ExpressionEvaluator::bind_(
       t.dim());
   for (auto i : c10::irange(t.dim())) {
     auto id = logical_domain[i];
-    if (id->hasExpandedExtent()) {
-      // Verify that t is also expanded
-      NVF_ERROR(
-          t.size(i) == 1 || t.stride(i) == 0,
-          "IterDomain ",
-          id->toString(),
-          " in ",
-          getInputPosString(tv),
-          "TensorView ",
-          tv->toString(),
-          " has expanded extent but input tensor has size ",
-          t.size(i),
-          " and stride ",
-          t.stride(i),
-          " in dimension ",
-          i);
-      bind_(
-          logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
-    } else if (logical_domain[i]->isDeviceDim()) {
-      // Currently we have the restrictions:
-      // (1) Device-parallelized axis extent == DeviceMesh's extent
-      // (2) Device-parallelized axes cannot be split or merged
-      // Therefore, a device-parallelized extent is always allocated with
-      // size 1, but the symbolic axis extent is bound to the extent of the
-      // DeviceMesh.
-      NVF_CHECK(
-          1 == t.size(i),
-          "TensorView ",
-          tv->toString(),
-          getInputPosString(tv),
-          " IterDomain ",
-          id->toString(),
-          " is sharded and must have size 1, but input tensor has size ",
-          t.size(i));
-      NVF_CHECK(
-          tv->hasDeviceMesh(),
-          "TV ",
-          tv->toString(),
-          getInputPosString(tv),
-          " has an empty DeviceMesh with DID parallelization");
-      bind_(
-          logical_domain[i]->extent(),
-          static_cast<int>(
-              tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
-          evaluate_validate);
+    if (id->isBroadcast()) {
+      // DIDs are ignored.
+      bind_(logical_domain[i]->extent(), 1, evaluate_validate);
+      if (id->hasExpandedExtent()) {
+        // Verify that t is also expanded
+        NVF_ERROR(
+            t.size(i) == 1 || t.stride(i) == 0,
+            "IterDomain ",
+            id->toString(),
+            " in ",
+            getInputPosString(tv),
+            "TensorView ",
+            tv->toString(),
+            " has expanded extent but input tensor has size ",
+            t.size(i),
+            " and stride ",
+            t.stride(i),
+            " in dimension ",
+            i);
+        bind_(
+            logical_domain[i]->expandedExtent(),
+            t.size(i),
+            evaluate_validate);
+      }
     } else {
-      bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+      if (logical_domain[i]->isDeviceDim()) {
+        // Currently we have the restrictions:
+        // (1) Device-parallelized axis extent == DeviceMesh's extent
+        // (2) Device-parallelized axes cannot be split or merged
+        // Therefore, a device-parallelized extent is always allocated with
+        // size 1, but the symbolic axis extent is bound to the extent of the
+        // DeviceMesh.
+        NVF_CHECK(
+            1 == t.size(i),
+            "TensorView ",
+            tv->toString(),
+            getInputPosString(tv),
+            " IterDomain ",
+            id->toString(),
+            " is sharded and must have size 1, but input tensor has size ",
+            t.size(i));
+        NVF_CHECK(
+            tv->hasDeviceMesh(),
+            "TV ",
+            tv->toString(),
+            getInputPosString(tv),
+            " has an empty DeviceMesh with DID parallelization");
+        bind_(
+            logical_domain[i]->extent(),
+            static_cast<int64_t>(tv->getDeviceMesh().size(
+                logical_domain[i]->getParallelType())),
+            evaluate_validate);
+      } else {
+        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+      }
     }
   }
 }
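The NVF_ERROR/NVF_CHECK calls in this hunk enforce two runtime invariants on input tensors. A small sketch of those invariants as plain assertions follows; the function and parameter names are illustrative stand-ins for t.size(i), t.stride(i), and the TensorView queries, not nvFuser's API.

#include <cassert>
#include <cstdint>

void validateDim(bool has_expanded_extent, bool is_device_dim,
                 bool has_device_mesh, int64_t size, int64_t stride) {
  if (has_expanded_extent) {
    // An expanded broadcast must also look expanded at runtime:
    // size 1, or stride 0 in that dimension.
    assert(size == 1 || stride == 0);
  }
  if (is_device_dim) {
    // Each device holds one shard, so the sharded axis has size 1, and the
    // TensorView must carry a DeviceMesh from which to bind the extent.
    assert(size == 1);
    assert(has_device_mesh);
  }
}

int main() {
  validateDim(/*has_expanded_extent=*/true, /*is_device_dim=*/false,
              /*has_device_mesh=*/false, /*size=*/1, /*stride=*/0);  // OK
  validateDim(/*has_expanded_extent=*/false, /*is_device_dim=*/true,
              /*has_device_mesh=*/true, /*size=*/1, /*stride=*/4);  // OK
}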
14 changes: 7 additions & 7 deletions csrc/ops/utils.cpp
@@ -323,6 +323,13 @@ IterDomain* newOutputIterDomain(
       continue;
     }

+    NVF_ERROR(
+        id->getParallelType() == ParallelType::Serial ||
+            isParallelTypeDeviceDim(id->getParallelType()),
+        id->getParallelType(),
+        " is not expected when building ops.");
+    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
+
     if (id->isBroadcast()) {
       if (id->hasExpandedExtent()) {
         expanded_extent_val =
@@ -331,13 +338,6 @@ IterDomain* newOutputIterDomain(
       continue;
     }

-    NVF_ERROR(
-        id->getParallelType() == ParallelType::Serial ||
-            isParallelTypeDeviceDim(id->getParallelType()),
-        id->getParallelType(),
-        " is not expected when building ops.");
-    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
-
     if (extent_is_from_symbolic && !id->isSymbolic()) {
       // We prefer to use extents from non-Symbolic inputs if there are any
       // because they might indicate a broadcast axis that is resolved in this
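These two hunks move the parallel-type check and promotion above the broadcast early-exit (`continue`), so a DID-parallelized broadcast input now contributes its parallel type to the output IterDomain instead of being skipped. A sketch of the effect follows; promote() is a hypothetical stand-in for promoteParallelType, assumed to prefer a device-parallel type over Serial.

#include <cassert>

enum class PT { Serial, DIDx };

// Hypothetical promotion rule: keep the accumulated type unless it is
// still Serial, in which case take the incoming one.
PT promote(PT acc, PT next) { return acc == PT::Serial ? next : acc; }

int main() {
  // Inputs: a DID-sharded broadcast axis and a Serial non-broadcast axis.
  // Old order: the broadcast axis hit `continue` before promotion, so its
  // DIDx never reached the output and the result stayed Serial.
  // New order: promotion runs on every input axis first.
  PT out = PT::Serial;
  out = promote(out, PT::DIDx);    // broadcast input, DID-parallelized
  out = promote(out, PT::Serial);  // non-broadcast input
  assert(out == PT::DIDx);
}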