diff --git a/csrc/dynamic_transform.cpp b/csrc/dynamic_transform.cpp
index 49afbac974e..24404db8d65 100644
--- a/csrc/dynamic_transform.cpp
+++ b/csrc/dynamic_transform.cpp
@@ -1048,12 +1048,6 @@ void DynamicTransformConcretizer::mutate(TensorView* tv) {
   // check the root to logical transforms to be sure we have concretized any
   // intermediate IterDomains.
 
-  // At this point, there should be no expr beyond rfactor root
-  NVF_ERROR(
-      tv->getLoopDomain() == tv->getLogicalDomain(),
-      "Invalid tensor: ",
-      tv->toString());
-
   // If it has an root domain, the IterTypes of the logical
   // IDs may need to be updated as well. Traverse the rfactor exprs
   // and mutate the IterTypes of output IDs if symbolic.
diff --git a/csrc/expr_evaluator.cpp b/csrc/expr_evaluator.cpp
index d4ca6daa022..da0a7720e15 100644
--- a/csrc/expr_evaluator.cpp
+++ b/csrc/expr_evaluator.cpp
@@ -60,7 +60,8 @@ void validateValWithConcreteValue(
       concrete_value);
   const auto& t = concrete_value.as<at::Tensor>();
   auto expect_dim =
-      (int64_t)TensorDomain::noReductions(tv->getLogicalDomain()).size();
+      (int64_t)TensorDomain::noReductions(tv->getMaybeAllocationDomain())
+          .size();
   NVF_CHECK(
       t.dim() == expect_dim,
       "Expected ",
@@ -133,21 +134,22 @@ void ExpressionEvaluator::bindTensorDomain(
     const TensorView* tv,
     const at::Tensor& t,
     const bool evaluate_validate) {
-  auto logical_domain = TensorDomain::noReductions(tv->getLogicalDomain());
+  auto alloc_domain =
+      TensorDomain::noReductions(tv->getMaybeAllocationDomain());
   NVF_ERROR(
-      t.dim() == (int64_t)logical_domain.size(),
+      t.dim() == (int64_t)alloc_domain.size(),
       "Expected ",
       getInputPosString(tv),
       tv->toString(),
       ", to be bound to a tensor of rank ",
-      logical_domain.size(),
+      alloc_domain.size(),
       ", but got a tensor of rank ",
       t.dim());
   for (auto i : c10::irange(t.dim())) {
-    auto id = logical_domain[i];
+    auto id = alloc_domain[i];
     if (id->isBroadcast()) {
       // DIDs are ignored for broadcast.
-      bind_(logical_domain[i]->extent(), 1, evaluate_validate);
+      bind_(alloc_domain[i]->extent(), 1, evaluate_validate);
       if (id->hasExpandedExtent()) {
         // Verify that t is also expanded
         NVF_ERROR(
@@ -164,11 +166,10 @@ void ExpressionEvaluator::bindTensorDomain(
             t.stride(i),
             " in dimension ",
             i);
-        bind_(
-            logical_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
+        bind_(alloc_domain[i]->expandedExtent(), t.size(i), evaluate_validate);
       }
     } else {
-      if (logical_domain[i]->isDeviceDim()) {
+      if (alloc_domain[i]->isDeviceDim()) {
         // Currently we have the restrictions:
         // (1) Devices parallelized axis extent == DeviceMesh's extent
         // (2) Device parallelized axis cannot be split or merged
@@ -191,12 +192,12 @@ void ExpressionEvaluator::bindTensorDomain(
             getInputPosString(tv),
             " has an empty DeviceMesh with DID parallelization")
         bind_(
-            logical_domain[i]->extent(),
+            alloc_domain[i]->extent(),
             static_cast<int64_t>(
-                tv->getDeviceMesh().size(logical_domain[i]->getParallelType())),
+                tv->getDeviceMesh().size(alloc_domain[i]->getParallelType())),
             evaluate_validate);
       } else {
-        bind_(logical_domain[i]->extent(), t.size(i), evaluate_validate);
+        bind_(alloc_domain[i]->extent(), t.size(i), evaluate_validate);
       }
     }
   }
diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp
index 29fa52461e6..2dc60e87c8b 100644
--- a/csrc/runtime/allocations.cpp
+++ b/csrc/runtime/allocations.cpp
@@ -10,6 +10,7 @@
 
 #include
 #include
+#include <ir/utils.h>
 #include
 #include
 #include
@@ -404,7 +405,7 @@ namespace {
 
 class ForwardTraverseFromAllocToLogical {
   at::Tensor tensor_;
-  ExpressionEvaluator& ee_;
+  const ExpressionEvaluator& ee_;
   std::list<IterDomain*>& frontier_;
 
   // Forward traverse split from allocation to logical. Needs to, for example,
@@ -521,7 +522,7 @@ class ForwardTraverseFromAllocToLogical {
 
  public:
   ForwardTraverseFromAllocToLogical(
       at::Tensor tensor,
-      ExpressionEvaluator& ee,
+      const ExpressionEvaluator& ee,
       std::list<IterDomain*>& frontier)
       : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}
@@ -541,7 +542,7 @@ class ForwardTraverseFromAllocToLogical {
 // transformations.
 class BackwardTraverseFromAllocToLogical {
   at::Tensor tensor_;
-  ExpressionEvaluator& ee_;
+  const ExpressionEvaluator& ee_;
   std::list<IterDomain*>& frontier_;
 
   // Backward traverse split from allocation to logical. Needs to, for example,
@@ -645,7 +646,7 @@ class BackwardTraverseFromAllocToLogical {
 
  public:
   BackwardTraverseFromAllocToLogical(
       at::Tensor tensor,
-      ExpressionEvaluator& ee,
+      const ExpressionEvaluator& ee,
       std::list<IterDomain*>& frontier)
       : tensor_(std::move(tensor)), ee_(ee), frontier_(frontier) {}
@@ -662,49 +663,6 @@ class BackwardTraverseFromAllocToLogical {
   }
 };
 
-// Start from a tensor whose dimensions are consistent with the allocation
-// domain of tv, apply a sequence of view/permute to the tensor to transform it
-// into a format whose dimensions are consistent with the logical domain of tv.
-// For example, if the logical domain is [I1, I2], and the allocation domain is
-// [I2*I1], then we will allocate as [I2*I1], then do a tensor.view(I2, I1).t()
-// to get a tensor whose semantics is [I1, I2] but its memory is [I2*I1].
-// Another example, if the logical domain is [I1*I2] and the allocation domain
-// is [I1, I2], then we will allocate as [I1, I2] and do a tensor.view(I1*I2) to
-// get a tensor whose semantics is [I1*I2] but memory is [I1,I2]
-at::Tensor transformFromAllocationToLogical(
-    at::Tensor tensor,
-    TensorView* tv,
-    ExpressionEvaluator& ee) {
-  FUSER_PERF_SCOPE("allocations::transformFromAllocationToLogical");
-  // Ignore reductions because reductions does not exist in tensor's definition
-  auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
-  auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
-  // Traverse all affine transformations from allocation domain. Because
-  // allocation domain can be before or after the logical domain, we need both a
-  // forward and a backward traverse.
-  std::list<IterDomain*> frontier(alloc.begin(), alloc.end());
-  NVF_ERROR(tensor.dim() == (int64_t)frontier.size());
-  tensor = ForwardTraverseFromAllocToLogical(tensor, ee, frontier)
-               .run(logical, alloc);
-  tensor = BackwardTraverseFromAllocToLogical(tensor, ee, frontier)
-               .run(logical, alloc);
-  NVF_ERROR(frontier.size() == logical.size());
-  // Now that all affine transformations are handled, and frontiers should
-  // contain the same set of IDs as logical. We still need to do a final
-  // permutation so that their orders are also consistent.
-  std::unordered_map<IterDomain*, int64_t> current_dims;
-  int64_t counter = 0;
-  for (auto id : frontier) {
-    current_dims[id] = counter++;
-  }
-  std::vector<int64_t> dims;
-  dims.reserve(frontier.size());
-  for (auto id : logical) {
-    dims.emplace_back(current_dims.at(id));
-  }
-  return tensor.permute(dims);
-}
-
 } // namespace
 
 std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
@@ -748,11 +706,53 @@ std::pair<std::vector<int64_t>, std::vector<int64_t>> inferShapeOfOutput(
       c10::TensorOptions().device(c10::Device(c10::DeviceType::Meta));
   auto meta_tensor =
       at::empty_strided(size_stride.first, size_stride.second, options);
-  // TODO(jiej): we should refactor it here, there's no need to use
-  // meta_tensor at all, size + stride should be used directly in the
-  // `transformFromAllocationToLogical`
-  meta_tensor = transformFromAllocationToLogical(meta_tensor, tv, expr_eval);
   return {meta_tensor.sizes().vec(), meta_tensor.strides().vec()};
 }
 
+at::Tensor transformFromAllocationToLogical(
+    at::Tensor tensor,
+    TensorView* tv,
+    const ExpressionEvaluator& ee) {
+  FUSER_PERF_SCOPE("allocations::transformFromAllocationToLogical");
+  // Ignore reductions because reductions does not exist in tensor's definition
+  auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
+  auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
+  // Traverse all affine transformations from allocation domain. Because
+  // allocation domain can be before or after the logical domain, we need both a
+  // forward and a backward traverse.
+  std::list<IterDomain*> frontier(alloc.begin(), alloc.end());
+  NVF_ERROR(tensor.dim() == (int64_t)frontier.size());
+  tensor = ForwardTraverseFromAllocToLogical(tensor, ee, frontier)
+               .run(logical, alloc);
+  tensor = BackwardTraverseFromAllocToLogical(tensor, ee, frontier)
+               .run(logical, alloc);
+  NVF_ERROR(frontier.size() == logical.size());
+  // Now that all affine transformations are handled, and frontiers should
+  // contain the same set of IDs as logical. We still need to do a final
+  // permutation so that their orders are also consistent.
+  std::unordered_map<IterDomain*, int64_t> current_dims;
+  int64_t counter = 0;
+  for (auto id : frontier) {
+    current_dims[id] = counter++;
+  }
+  std::vector<int64_t> dims;
+  dims.reserve(frontier.size());
+  for (auto id : logical) {
+    dims.emplace_back(current_dims.at(id));
+  }
+  return tensor.permute(dims);
+}
+
+at::Tensor transformFromLogicalToAllocation(at::Tensor tensor, TensorView* tv) {
+  FUSER_PERF_SCOPE("allocations::transformLogicalToAllocation");
+  // Ignore reductions because reduction dimensions are not allocated in
+  // `tensor`.
+  auto logical = TensorDomain::noReductions(tv->getLogicalDomain());
+  auto alloc = TensorDomain::noReductions(tv->getMaybeAllocationDomain());
+
+  std::vector<int64_t> permutation =
+      *ir_utils::computePermutation(logical, alloc);
+  return tensor.permute(permutation);
+}
+
 } // namespace nvfuser
diff --git a/csrc/runtime/allocations.h b/csrc/runtime/allocations.h
index 1ec77eb3ce2..ca2fde41cd8 100644
--- a/csrc/runtime/allocations.h
+++ b/csrc/runtime/allocations.h
@@ -77,4 +77,20 @@ std::vector<GlobalBufferInfo> getBufferInfos(
     DataType index_dtype,
     const std::vector<Val*>& fusion_outputs);
 
+// Start from a tensor whose dimensions are consistent with the allocation
+// domain of tv, apply a sequence of view/permute to the tensor to transform it
+// into a format whose dimensions are consistent with the logical domain of tv.
+// For example, if the logical domain is [I1, I2], and the allocation domain is
+// [I2*I1], then we will allocate as [I2*I1], then do a tensor.view(I2, I1).t()
+// to get a tensor whose semantics is [I1, I2] but its memory is [I2*I1].
+// Another example, if the logical domain is [I1*I2] and the allocation domain
+// is [I1, I2], then we will allocate as [I1, I2] and do a tensor.view(I1*I2) to
+// get a tensor whose semantics is [I1*I2] but memory is [I1,I2]
+at::Tensor transformFromAllocationToLogical(
+    at::Tensor tensor,
+    TensorView* tv,
+    const ExpressionEvaluator& ee);
+
+at::Tensor transformFromLogicalToAllocation(at::Tensor tensor, TensorView* tv);
+
 } // namespace nvfuser
diff --git a/csrc/runtime/fusion_executor_cache.cpp b/csrc/runtime/fusion_executor_cache.cpp
index 24830ba9bd1..c4337bd4f92 100644
--- a/csrc/runtime/fusion_executor_cache.cpp
+++ b/csrc/runtime/fusion_executor_cache.cpp
@@ -68,6 +68,7 @@ std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
   most_recent_runtime_ = kernel_runtime;
 
   auto fusion = kernel_runtime->fusionSegments()->completeFusion();
+  ExpressionEvaluator evaluator = executor_utils::bindInputs(args, fusion);
 
   // Make sure the forced index type is indeed used
   if (forced_index_type.has_value()) {
@@ -79,16 +80,22 @@ std::vector<at::Tensor> FusionExecutorCache::runFusionWithInputs(
   }
 
   auto outputs = kernel_runtime->runWithInputs(args);
+  NVF_ERROR(fusion->outputs().size() == outputs.size());
 
   // Kernel time measurement is off by default
   kernel_runtime->disableKernelTimeMeasurement();
 
+  for (const auto out_index : c10::irange(outputs.size())) {
+    at::Tensor& output = outputs[out_index];
+    auto* out = fusion->outputs()[out_index]->as<TensorView>();
+    output = transformFromAllocationToLogical(output, out, evaluator);
+  }
+
   // Removing aliased outputs, since those are updated by the Fusion. It is not
   // semantically correct to actually return them as outputs from
   // fusion.
-  NVF_ERROR(fusion->outputs().size() == outputs.size());
   size_t new_size = 0;
-  for (size_t out_index = 0; out_index < outputs.size(); out_index++) {
+  for (const auto out_index : c10::irange(outputs.size())) {
     Val* out = fusion->outputs()[out_index];
     if (!fusion->getOutputAlias(out).hide_output) {
       outputs[new_size] = outputs[out_index];
@@ -113,8 +120,20 @@ KernelArgumentHolder FusionExecutorCache::prepareInputs(
     std::optional<int8_t> selected_device) {
   FUSER_PERF_SCOPE("FusionExecutorCache::prepareInputs");
 
-  KernelArgumentHolder args =
-      KernelArgumentHolder::createKernelArgumentHolder(inputs, selected_device);
+  std::vector<c10::IValue> inputs_matching_allocation;
+  inputs_matching_allocation.reserve(inputs.size());
+  for (const auto i : c10::irange(inputs.size())) {
+    const auto& input = inputs[i];
+    if (!input.isTensor()) {
+      inputs_matching_allocation.push_back(input);
+      continue;
+    }
+    inputs_matching_allocation.push_back(transformFromLogicalToAllocation(
+        input.toTensor(), fusion_->inputs()[i]->as<TensorView>()));
+  }
+
+  KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(
+      inputs_matching_allocation, selected_device);
 
   // TODO: move InputsIdLookup inside KernelArgumentHolder;
   // NOTE: We must ensure that the cache id is in fact unique. Dynamic fusions
@@ -125,7 +144,7 @@ KernelArgumentHolder FusionExecutorCache::prepareInputs(
   // short-circuiting here, resulting in avoidable rebuilds of concretization
   // info.
   auto id_lookup_ret = inputs_id_lookup_.lookupId(
-      inputs,
+      inputs_matching_allocation,
       initialInfo().scalarInputsAffectingConcretization(),
       args.getDeviceIndex());
   if (id_lookup_ret.eviction) {
diff --git a/csrc/tensor_metadata.cpp b/csrc/tensor_metadata.cpp
index 26b5e21338f..220c407107e 100644
--- a/csrc/tensor_metadata.cpp
+++ b/csrc/tensor_metadata.cpp
@@ -14,6 +14,7 @@
 #include
 #include
 #include
+#include <runtime/allocations.h>
 #include
 
 namespace nvfuser {
@@ -291,26 +292,22 @@ inferAndValidateAllocationSizesAndStrides(
   }
   const auto& alloc =
       TensorDomain::noReductions(tv->getMaybeAllocationDomain());
-  const auto& logical = TensorDomain::noReductions(tv->getLogicalDomain());
 
   // active IDs and their shape and stride
   std::unordered_map<IterDomain*, std::pair<int64_t, int64_t>> active_ids;
-  NVF_ERROR((int64_t)logical.size() == tensor.dim());
-  for (int64_t i : c10::irange((int64_t)logical.size())) {
-    auto rf_id = logical.at(i);
-    active_ids[rf_id] = {tensor.size(i), tensor.stride(i)};
+  NVF_ERROR((int64_t)alloc.size() == tensor.dim());
+  for (const auto i : c10::irange(alloc.size())) {
+    IterDomain* alloc_id = alloc.at(i);
+    active_ids[alloc_id] = {tensor.size(i), tensor.stride(i)};
   }
 
-  ForwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
-  BackwardTraverseFromLogicalToAlloc(ee, active_ids).run(tv, logical, alloc);
-
   // Now active_ids should contain the final sizes and strides, unordered. We
   // need to put them to the correct order.
   std::vector<int64_t> sizes;
   std::vector<int64_t> strides;
   sizes.reserve(alloc.size());
   strides.reserve(alloc.size());
-  for (auto i : c10::irange(alloc.size())) {
+  for (const auto i : c10::irange(alloc.size())) {
     auto id = alloc.at(i);
     sizes.emplace_back(active_ids.at(id).first);
     strides.emplace_back(active_ids.at(id).second);
@@ -369,30 +366,19 @@ std::vector<PolymorphicValue> GetMetaData::evaluate(
       std::get<PrimDataType>(aten_to_data_type(input.scalar_type()).type);
   metadata->data = input.data_ptr();
 
-  if (isSharded(tv)) {
-    auto [unsharded_sizes, unsharded_strides] =
-        unshardedSizesAndStrides(tv, input.sizes(), input.strides());
-    metadata->logical_size_data = std::move(unsharded_sizes);
-    metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data);
-    metadata->logical_stride_data = std::move(unsharded_strides);
-    metadata->logical_stride = c10::makeArrayRef(metadata->logical_stride_data);
-  } else {
-    metadata->logical_size = input.sizes();
-    metadata->logical_stride = input.strides();
-  }
+  metadata->alloc_size = input.sizes();
+  metadata->alloc_stride = input.strides();
+
+  at::Tensor meta_input = at::empty_like(input, at::device(at::kMeta));
+  // FIXME: change size-1 to the device mesh size.
+  meta_input = transformFromAllocationToLogical(meta_input, tv, ee);
+  std::vector<int64_t> logical_sizes = meta_input.sizes().vec();
+  std::vector<int64_t> logical_strides = meta_input.strides().vec();
+  metadata->logical_size_data = std::move(logical_sizes);
+  metadata->logical_size = c10::makeArrayRef(metadata->logical_size_data);
+  metadata->logical_stride_data = std::move(logical_strides);
+  metadata->logical_stride = c10::makeArrayRef(metadata->logical_stride_data);
 
-  if (tv->hasAllocation()) {
-    auto allocation_data =
-        inferAndValidateAllocationSizesAndStrides(input, tv, ee);
-    metadata->alloc_size_data = std::move(allocation_data.first);
-    metadata->alloc_size = c10::makeArrayRef(metadata->alloc_size_data);
-    metadata->alloc_stride_data = std::move(allocation_data.second);
-    metadata->alloc_stride = c10::makeArrayRef(metadata->alloc_stride_data);
-  } else {
-    metadata->alloc_size = metadata->logical_size;
-    metadata->alloc_stride = metadata->logical_stride;
-    // TODO: validateAllocationSizesAndStrides
-  }
   return {PolymorphicValue(std::move(struct_))};
 }
 
diff --git a/tests/cpp/test_allocation_domain.cpp b/tests/cpp/test_allocation_domain.cpp
index 42e1c48df8b..9d53496a2f5 100644
--- a/tests/cpp/test_allocation_domain.cpp
+++ b/tests/cpp/test_allocation_domain.cpp
@@ -107,24 +107,20 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC4d) {
 // A global->global copy kernel converting NCHW memory format into NHWC, with a
 // 1d allocation domain in output.
 TEST_F(AllocationDomainTest, NCHW4d_To_NHWC1d) {
-  Fusion fusion;
-  FusionGuard fg(&fusion);
+  auto fusion = std::make_unique<Fusion>();
+  FusionGuard fg(fusion.get());
 
   auto tv0 = makeContigTensor(4);
-  fusion.addInput(tv0);
+  fusion->addInput(tv0);
 
   auto tv1 = set(tv0);
-  fusion.addOutput(tv1);
+  fusion->addOutput(tv1);
 
   // [N, C, H, W]
   tv1->reorder({{1, -1}});
   // [N, H, W, C]
   tv1->flatten();
-  tv1->setAllocationDomain({tv1->axis(0)}, true); // [N*H*W*C]
-  tv1->split(0, 128);
-  tv1->axis(1)->parallelize(ParallelType::TIDx);
-  tv1->axis(0)->parallelize(ParallelType::BIDx);
-  // [BIDx, TIDx]
+  tv1->setAllocationDomain(tv1->getLoopDomain(), true);
 
   auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
 
@@ -132,14 +128,12 @@ TEST_F(AllocationDomainTest, NCHW4d_To_NHWC1d) {
 
   at::Tensor t0 = at::randn({n, c, h, w}, options);
 
-  FusionExecutor fe;
-  fe.compileFusion(&fusion, {t0});
-
-  auto cg_outputs = fe.runFusion({t0});
+  FusionExecutorCache executor_cache(std::move(fusion));
+  auto cg_outputs = executor_cache.runFusionWithInputs({t0});
 
-  ASSERT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
+  EXPECT_TRUE(cg_outputs[0].is_contiguous(at::MemoryFormat::ChannelsLast));
 
-  testValidate(&fusion, cg_outputs, {t0}, __LINE__, __FILE__);
+  testValidate(executor_cache.fusion(), cg_outputs, {t0}, __LINE__, __FILE__);
 }
 
 // A global->global copy kernel converting NCHW memory format into NHWC, with a
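For reference, the standalone ATen sketch below (not part of the patch) illustrates the permute round-trip that this change applies to fusion inputs and outputs: a tensor is reordered from its logical dimension order into the allocation order before binding (as transformFromLogicalToAllocation does), and permuted back afterwards (as transformFromAllocationToLogical does). The NCHW-to-NHWC permutation is a hypothetical example chosen to mirror the NCHW4d_To_NHWC1d test; only standard ATen calls are used.

// Editorial sketch, not nvFuser code: permute-based round-trip between a
// "logical" dimension order and an assumed "allocation" order.
#include <ATen/ATen.h>

#include <cstdint>
#include <vector>

int main() {
  // Logical order [N, C, H, W]; assumed allocation order [N, H, W, C].
  const std::vector<int64_t> logical_to_alloc = {0, 2, 3, 1};

  at::Tensor logical = at::randn({2, 3, 4, 5});

  // Logical -> allocation: permute into allocation order and materialize
  // that layout in memory (what happens to inputs before binding).
  at::Tensor alloc = logical.permute(logical_to_alloc).contiguous();

  // Allocation -> logical: apply the inverse permutation so the tensor reads
  // in logical order again while its memory stays NHWC-ordered (what happens
  // to outputs before they are returned).
  std::vector<int64_t> alloc_to_logical(logical_to_alloc.size());
  for (int64_t i = 0; i < (int64_t)logical_to_alloc.size(); ++i) {
    alloc_to_logical[logical_to_alloc[i]] = i;
  }
  at::Tensor back = alloc.permute(alloc_to_logical);

  // back has logical sizes [2, 3, 4, 5] but channels-last strides, which is
  // the property the updated NCHW4d_To_NHWC1d test checks on the output.
  TORCH_CHECK(back.sizes() == logical.sizes());
  TORCH_CHECK(back.is_contiguous(at::MemoryFormat::ChannelsLast));
  return 0;
}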