diff --git a/csrc/evaluator_common.cpp b/csrc/evaluator_common.cpp
index 1042931914e..b17cfea9e8a 100644
--- a/csrc/evaluator_common.cpp
+++ b/csrc/evaluator_common.cpp
@@ -348,19 +348,21 @@ void PrecomputedValues::bindTensorMetaData(
   for (const auto dim : c10::irange(logical_domain.size())) {
     IterDomain* id = logical_domain[dim];
-    auto dim_size = tensor.size(static_cast<int64_t>(dim));
-    if (id->isDeviceDim()) {
-      dim_size = tv->getDeviceMesh().size(id->getParallelType());
-    }
-
-    if (id->hasExpandedExtent()) {
-      Val* extent = id->extent();
-      Val* expanded_extent = id->expandedExtent();
-      bindValue(extent->evaluatorIndex(), 1L);
-      bindValue(expanded_extent->evaluatorIndex(), dim_size);
+    const auto dim_size = tensor.size(static_cast<int64_t>(dim));
+    if (id->isBroadcast()) {
+      // DIDs are ignored.
+      bindValue(id->extent()->evaluatorIndex(), 1L);
+      if (id->hasExpandedExtent()) {
+        bindValue(id->expandedExtent()->evaluatorIndex(), dim_size);
+      }
     } else {
-      Val* extent = id->extent();
-      bindValue(extent->evaluatorIndex(), dim_size);
+      if (id->isDeviceDim()) {
+        bindValue(
+            id->extent()->evaluatorIndex(),
+            tv->getDeviceMesh().size(id->getParallelType()));
+      } else {
+        bindValue(id->extent()->evaluatorIndex(), dim_size);
+      }
     }
   }
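Note: the hunk above reorders the binding logic so that isBroadcast() is checked before isDeviceDim(). A minimal standalone sketch of the resulting priority follows; ToyId, extentToBind, and mesh_size are hypothetical stand-ins for nvFuser's IterDomain and DeviceMesh APIs, not actual nvFuser code.

    // Toy model of the new binding priority in bindTensorMetaData.
    #include <cstdint>
    #include <iostream>

    struct ToyId {
      bool is_broadcast;
      bool is_device_dim;
    };

    // Mirrors the if/else ordering of the new code: broadcast wins over DID.
    int64_t extentToBind(const ToyId& id, int64_t dim_size, int64_t mesh_size) {
      if (id.is_broadcast) {
        // "DIDs are ignored": a DID-parallelized broadcast still binds 1.
        // (An expanded extent, if present, separately binds dim_size.)
        return 1;
      }
      if (id.is_device_dim) {
        // A sharded, non-broadcast axis binds the mesh extent, not the
        // per-device allocation size of 1.
        return mesh_size;
      }
      return dim_size;
    }

    int main() {
      std::cout << extentToBind({true, true}, 1, 8) << '\n';    // 1: sharded broadcast
      std::cout << extentToBind({false, true}, 1, 8) << '\n';   // 8: sharded axis
      std::cout << extentToBind({false, false}, 16, 8) << '\n'; // 16: ordinary axis
    }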
diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp
index c59842b2037..a360950354b 100644
--- a/csrc/expr_evaluator.cpp
+++ b/csrc/expr_evaluator.cpp
@@ -174,53 +174,61 @@ void ExpressionEvaluator::bind_(
         t.dim());
     for (auto i : c10::irange(t.dim())) {
       auto id = logical_domain[i];
-      if (id->hasExpandedExtent()) {
-        // Verify that t is also expanded
-        NVF_ERROR(
-            t.size(i) == 1 || t.stride(i) == 0,
-            "IterDomain ",
-            id->toString(),
-            " in ",
-            getInputPosString(tv),
-            "TensorView ",
-            tv->toString(),
-            " has expanded extent but input tensor has size ",
-            t.size(i),
-            " and stride ",
-            t.stride(i),
-            " in dimension ",
-            i);
-        bind_(
-            logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
-      } else if (logical_domain[i]->isDeviceDim()) {
-        // Currently we have the restrictions:
-        // (1) Devices parallelized axis extent == DeviceMesh's extent
-        // (2) Device parallelized axis cannot be split or merged
-        // Therefore, the device parallelized extents will always be allocated
-        // with size 1, but the symbolic axis extent is binded with the extent
-        // of the DeviceMesh
-        NVF_CHECK(
-            1 == t.size(i),
-            "TensorView ",
-            tv->toString(),
-            getInputPosString(tv),
-            " IterDomain ",
-            id->toString(),
-            "is sharded and must have size 1, but input tensor has size ",
-            t.size(i));
-        NVF_CHECK(
-            tv->hasDeviceMesh(),
-            "TV ",
-            tv->toString(),
-            getInputPosString(tv),
-            " has an empty DeviceMesh with DID parallelization")
-        bind_(
-            logical_domain[i]->extent(),
-            static_cast<int64_t>(
-                tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
-            evaluate_validate);
+      if (id->isBroadcast()) {
+        // DIDs are ignored.
+        bind_(logical_domain[i]->extent(), 1, evaluate_validate);
+        if (id->hasExpandedExtent()) {
+          // Verify that t is also expanded
+          NVF_ERROR(
+              t.size(i) == 1 || t.stride(i) == 0,
+              "IterDomain ",
+              id->toString(),
+              " in ",
+              getInputPosString(tv),
+              "TensorView ",
+              tv->toString(),
+              " has expanded extent but input tensor has size ",
+              t.size(i),
+              " and stride ",
+              t.stride(i),
+              " in dimension ",
+              i);
+          bind_(
+              logical_domain[i]->expandedExtent(),
+              t.size(i),
+              evaluate_validate);
+        }
       } else {
-        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        if (logical_domain[i]->isDeviceDim()) {
+          // Currently we have the restrictions:
+          // (1) Device-parallelized axis extent == DeviceMesh's extent
+          // (2) Device-parallelized axes cannot be split or merged
+          // Therefore, the device-parallelized extents will always be
+          // allocated with size 1, but the symbolic axis extent is bound to
+          // the extent of the DeviceMesh.
+          NVF_CHECK(
+              1 == t.size(i),
+              "TensorView ",
+              tv->toString(),
+              getInputPosString(tv),
+              " IterDomain ",
+              id->toString(),
+              " is sharded and must have size 1, but input tensor has size ",
+              t.size(i));
+          NVF_CHECK(
+              tv->hasDeviceMesh(),
+              "TV ",
+              tv->toString(),
+              getInputPosString(tv),
+              " has an empty DeviceMesh with DID parallelization");
+          bind_(
+              logical_domain[i]->extent(),
+              static_cast<int64_t>(tv->getDeviceMesh().size(
+                  logical_domain[i]->getParallelType())),
+              evaluate_validate);
+        } else {
+          bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        }
       }
     }
   }
diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp
index 8b911d623c0..3b6bf6d561f 100644
--- a/csrc/ops/utils.cpp
+++ b/csrc/ops/utils.cpp
@@ -323,6 +323,13 @@ IterDomain* newOutputIterDomain(
       continue;
     }
 
+    NVF_ERROR(
+        id->getParallelType() == ParallelType::Serial ||
+            isParallelTypeDeviceDim(id->getParallelType()),
+        id->getParallelType(),
+        " is not expected when building ops.");
+    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
+
     if (id->isBroadcast()) {
       if (id->hasExpandedExtent()) {
         expanded_extent_val =
@@ -331,13 +338,6 @@ IterDomain* newOutputIterDomain(
       continue;
     }
 
-    NVF_ERROR(
-        id->getParallelType() == ParallelType::Serial ||
-            isParallelTypeDeviceDim(id->getParallelType()),
-        id->getParallelType(),
-        " is not expected when building ops.");
-    parallel_type = promoteParallelType(parallel_type, id->getParallelType());
-
     if (extent_is_from_symbolic && !id->isSymbolic()) {
       // We prefer to use extents from non-Symbolic inputs if there are any
       // because they might indicate a broadcast axis that is resolved in this
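Note: the ops/utils.cpp hunks only move the parallel-type check and promotion above the isBroadcast() early-continue. A standalone sketch of the behavioral difference follows, assuming the intent is that a DID-parallelized broadcast input should now contribute its parallel type to the output IterDomain; the toy enum, ToyId, and the simplified promotion rule are illustrative, not nvFuser's actual promoteParallelType.

    // Toy model: with the old order, a sharded broadcast input hit
    // `continue` before the promotion, so the output came out Serial.
    #include <cassert>
    #include <vector>

    enum class ParallelType { Serial, DIDx };

    // Hypothetical promotion rule: any non-Serial type wins over Serial.
    ParallelType promoteParallelType(ParallelType a, ParallelType b) {
      return a == ParallelType::Serial ? b : a;
    }

    struct ToyId {
      bool is_broadcast;
      ParallelType ptype;
    };

    ParallelType outputParallelType(const std::vector<ToyId>& inputs) {
      ParallelType result = ParallelType::Serial;
      for (const auto& id : inputs) {
        // New order: promote before skipping broadcasts, so a sharded
        // broadcast input still contributes its DID parallel type.
        result = promoteParallelType(result, id.ptype);
        if (id.is_broadcast) {
          continue;  // only the extent handling is skipped, not the ptype
        }
        // ... extent selection elided ...
      }
      return result;
    }

    int main() {
      // One sharded broadcast input + one ordinary input: output is DIDx.
      assert(
          outputParallelType({{true, ParallelType::DIDx},
                              {false, ParallelType::Serial}}) ==
          ParallelType::DIDx);
    }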