From 21c6e2e4a0f029766330a7c3acb964b956567395 Mon Sep 17 00:00:00 2001
From: Jingyue Wu
Date: Tue, 12 Nov 2024 10:15:45 -0800
Subject: [PATCH] WIP

---
 csrc/ir/internal_nodes.h                   | 36 ++++++-------
 csrc/ir/utils.h                            | 19 +++++++
 .../allocation_order_inference.cpp         | 53 +++++++++++++++++--
 csrc/runtime/fusion_kernel_runtime.cpp     |  2 +-
 4 files changed, 87 insertions(+), 23 deletions(-)

diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h
index 4d5c41e3ec2..b861eec8bd1 100644
--- a/csrc/ir/internal_nodes.h
+++ b/csrc/ir/internal_nodes.h
@@ -2579,40 +2579,40 @@ class SdpaBwdOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;
 
-  Val* grad_query() const {
-    return output(0);
+  TensorView* grad_query() const {
+    return output(0)->as<TensorView>();
   }
 
-  Val* grad_key() const {
-    return output(1);
+  TensorView* grad_key() const {
+    return output(1)->as<TensorView>();
   }
 
-  Val* grad_value() const {
-    return output(2);
+  TensorView* grad_value() const {
+    return output(2)->as<TensorView>();
   }
 
-  Val* grad_attn() const {
-    return input(0);
+  TensorView* grad_attn() const {
+    return input(0)->as<TensorView>();
   }
 
-  Val* query() const {
-    return input(1);
+  TensorView* query() const {
+    return input(1)->as<TensorView>();
   }
 
-  Val* key() const {
-    return input(2);
+  TensorView* key() const {
+    return input(2)->as<TensorView>();
   }
 
-  Val* value() const {
-    return input(3);
+  TensorView* value() const {
+    return input(3)->as<TensorView>();
   }
 
-  Val* attn_out() const {
-    return input(4);
+  TensorView* attn_out() const {
+    return input(4)->as<TensorView>();
   }
 
-  Val* logsumexp() const {
-    return input(5);
+  TensorView* logsumexp() const {
+    return input(5)->as<TensorView>();
   }
 
   Val* dropout_p() const {
diff --git a/csrc/ir/utils.h b/csrc/ir/utils.h
index 60062b0e440..228cda74b40 100644
--- a/csrc/ir/utils.h
+++ b/csrc/ir/utils.h
@@ -658,6 +658,25 @@ std::optional<std::vector<int64_t>> computePermutation(
   return permutation;
 }
 
+template <typename T>
+std::vector<T> applyPermutation(
+    const std::vector<T>& in,
+    const std::vector<int64_t>& permutation) {
+  NVF_CHECK(in.size() == permutation.size());
+
+  std::vector<int64_t> identity(permutation.size());
+  std::iota(identity.begin(), identity.end(), 0);
+  NVF_CHECK(std::is_permutation(
+      permutation.begin(), permutation.end(), identity.begin()));
+
+  std::vector<T> out;
+  out.reserve(permutation.size());
+  for (auto i : permutation) {
+    out.push_back(in[i]);
+  }
+  return out;
+}
+
 bool hasTrivialAllocationDomain(const TensorView* tv);
 
 // Returns true if all expr outputs should be mapped unconditionally
diff --git a/csrc/preseg_passes/allocation_order_inference.cpp b/csrc/preseg_passes/allocation_order_inference.cpp
index c01f07d7b15..fd5375e123d 100644
--- a/csrc/preseg_passes/allocation_order_inference.cpp
+++ b/csrc/preseg_passes/allocation_order_inference.cpp
@@ -336,6 +336,18 @@ void inferenceAllocationOrder(
   }
 }
 
+namespace {
+void propagateAllocation(TensorView* in, TensorView* out) {
+  auto in_order = ir_utils::computePermutation(
+      in->getLogicalDomain(), in->getMaybeAllocationDomain());
+  if (!in_order.has_value()) {
+    return;
+  }
+  out->setAllocationDomain(
+      ir_utils::applyPermutation(out->getLogicalDomain(), *in_order), true);
+}
+} // namespace
+
 void AllocationDomainPass::runPass(Fusion* fusion) {
   // mark input TensorViews as propagation sources
   auto input_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
@@ -351,15 +363,48 @@ void AllocationDomainPass::runPass(Fusion* fusion) {
   // hint, but they should respect semantic requirement.
   // see issue: https://github.com/NVIDIA/Fuser/pull/2425
   for (TensorView* output : output_tvs) {
-    if (output->isDefinitionType<LinearOp>() ||
-        output->isDefinitionType<SdpaFwdOp>() ||
-        output->isDefinitionType<SdpaBwdOp>()) {
-      continue;
+    if (Expr* def = output->definition()) {
+      if (def->isOneOf<LinearOp, SdpaFwdOp, SdpaBwdOp>()) {
+        continue;
+      }
     }
     dsts.push_back(output);
   }
 
   // propagate allocation domain from sources to destinations
   inferenceAllocationOrder(fusion, srcs, dsts);
+
+  for (Expr* e : fusion->exprs()) {
+    if (auto* sdpa_fwd = dynamic_cast<SdpaFwdOp*>(e)) {
+      std::optional<std::vector<int64_t>> out_order = std::vector<int64_t>{};
+      for (TensorView* in :
+           {sdpa_fwd->query(), sdpa_fwd->key(), sdpa_fwd->value()}) {
+        auto in_order = ir_utils::computePermutation(
+            in->getLogicalDomain(), in->getMaybeAllocationDomain());
+        if (!in_order.has_value()) {
+          out_order = std::nullopt;
+          break;
+        }
+        if ((*out_order).empty()) {
+          out_order = *in_order;
+          continue;
+        }
+        if (*out_order != *in_order) {
+          out_order = std::nullopt;
+          break;
+        }
+      }
+      if (out_order.has_value()) {
+        TensorView* out = sdpa_fwd->attn_out();
+        out->setAllocationDomain(
+            ir_utils::applyPermutation(out->getLogicalDomain(), *out_order),
+            true);
+      }
+    } else if (auto* sdpa_bwd = dynamic_cast<SdpaBwdOp*>(e)) {
+      propagateAllocation(sdpa_bwd->query(), sdpa_bwd->grad_query());
+      propagateAllocation(sdpa_bwd->key(), sdpa_bwd->grad_key());
+      propagateAllocation(sdpa_bwd->value(), sdpa_bwd->grad_value());
+    }
+  }
 }
 } // namespace nvfuser::preseg_passes
diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp
index 8069ed3ee3a..8cd85117b51 100644
--- a/csrc/runtime/fusion_kernel_runtime.cpp
+++ b/csrc/runtime/fusion_kernel_runtime.cpp
@@ -84,7 +84,7 @@ FusionKernelRuntime::FusionKernelRuntime(
   if (!communicator.is_available() || communicator.local_rank() == 0) {
     debug() << "Fusion IR after pre-segmenter optimization passes:"
             << std::endl;
-    fusion->printMath();
+    fusion->print();
   }
 }
 
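For context on the permutation round-trip the patch relies on: ir_utils::computePermutation derives the index order that maps a TensorView's logical domain onto its allocation domain, and the new ir_utils::applyPermutation replays that order on another tensor's logical domain. Below is a minimal standalone sketch of that behavior; compute_permutation and apply_permutation are simplified stand-ins written for illustration (not the nvFuser implementations), and the axis names are hypothetical.

// sketch_apply_permutation.cpp -- illustration only; compute_permutation and
// apply_permutation are simplified stand-ins for ir_utils::computePermutation
// and ir_utils::applyPermutation.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Finds the index order that rearranges `in` into `out`, or nullopt if `out`
// is not a permutation of `in`.
std::optional<std::vector<int64_t>> compute_permutation(
    const std::vector<std::string>& in,
    const std::vector<std::string>& out) {
  std::vector<int64_t> permutation;
  permutation.reserve(out.size());
  for (const std::string& item : out) {
    auto it = std::find(in.begin(), in.end(), item);
    if (it == in.end()) {
      return std::nullopt;
    }
    permutation.push_back(std::distance(in.begin(), it));
  }
  return permutation;
}

// Replays a previously computed order on another vector.
std::vector<std::string> apply_permutation(
    const std::vector<std::string>& in,
    const std::vector<int64_t>& permutation) {
  std::vector<std::string> out;
  out.reserve(permutation.size());
  for (int64_t i : permutation) {
    out.push_back(in.at(i));
  }
  return out;
}

int main() {
  // Hypothetical query tensor: logical domain [b, h, s, e], allocation domain
  // storing the same axes as [b, s, h, e].
  std::vector<std::string> q_logical = {"b", "h", "s", "e"};
  std::vector<std::string> q_alloc = {"b", "s", "h", "e"};

  std::optional<std::vector<int64_t>> order =
      compute_permutation(q_logical, q_alloc);
  assert(order.has_value()); // order == {0, 2, 1, 3}

  // Replaying that order on the output's logical domain gives the output the
  // same memory layout as the input.
  std::vector<std::string> out_logical = {"b", "h", "s", "ev"};
  for (const std::string& axis : apply_permutation(out_logical, *order)) {
    std::cout << axis << ' '; // prints: b s h ev
  }
  std::cout << '\n';
  return 0;
}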