
Commit 536fd2a: WIP

wujingyue committed Nov 12, 2024
1 parent 6c29653
Showing 4 changed files with 87 additions and 23 deletions.
36 changes: 18 additions & 18 deletions csrc/ir/internal_nodes.h
@@ -2579,40 +2579,40 @@ class SdpaBwdOp : public Expr {
   std::string toString(int indent_size = 0) const override;
   std::string toInlineString(int indent_size = 0) const override;

-  Val* grad_query() const {
-    return output(0);
+  TensorView* grad_query() const {
+    return output(0)->as<TensorView>();
   }

-  Val* grad_key() const {
-    return output(1);
+  TensorView* grad_key() const {
+    return output(1)->as<TensorView>();
   }

-  Val* grad_value() const {
-    return output(2);
+  TensorView* grad_value() const {
+    return output(2)->as<TensorView>();
   }

-  Val* grad_attn() const {
-    return input(0);
+  TensorView* grad_attn() const {
+    return input(0)->as<TensorView>();
   }

-  Val* query() const {
-    return input(1);
+  TensorView* query() const {
+    return input(1)->as<TensorView>();
   }

-  Val* key() const {
-    return input(2);
+  TensorView* key() const {
+    return input(2)->as<TensorView>();
   }

-  Val* value() const {
-    return input(3);
+  TensorView* value() const {
+    return input(3)->as<TensorView>();
   }

-  Val* attn_out() const {
-    return input(4);
+  TensorView* attn_out() const {
+    return input(4)->as<TensorView>();
   }

-  Val* logsumexp() const {
-    return input(5);
+  TensorView* logsumexp() const {
+    return input(5)->as<TensorView>();
   }

   Val* dropout_p() const {
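
The change above narrows the SdpaBwdOp accessors from Val* to TensorView*, so call sites no longer have to downcast. A minimal sketch of a hypothetical call site (variable names are illustrative; as<TensorView>() is the downcast spelling already used in this codebase):

    // Before: the getter returns Val*, so a caller that needs TensorView
    // members must downcast explicitly.
    TensorView* grad_q = sdpa_bwd->grad_query()->as<TensorView>();

    // After: the getter already returns TensorView*, so the result can be
    // passed straight to helpers that take TensorView*, as the
    // allocation-order pass below does.
    TensorView* grad_q = sdpa_bwd->grad_query();
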
19 changes: 19 additions & 0 deletions csrc/ir/utils.h
@@ -658,6 +658,25 @@ std::optional<std::vector<int64_t>> computePermutation(
   return permutation;
 }

+template <typename T>
+std::vector<T> applyPermutation(
+    const std::vector<T>& in,
+    const std::vector<int64_t>& permutation) {
+  NVF_CHECK(in.size() == permutation.size());
+
+  std::vector<int64_t> identity(permutation.size());
+  std::iota(identity.begin(), identity.end(), 0);
+  NVF_CHECK(std::is_permutation(
+      permutation.begin(), permutation.end(), identity.begin()));
+
+  std::vector<T> out;
+  out.reserve(permutation.size());
+  for (auto i : permutation) {
+    out.push_back(in[i]);
+  }
+  return out;
+}
+
 bool hasTrivialAllocationDomain(const TensorView* tv);

 // Returns true if all expr outputs should be mapped unconditionally
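
For reference, a minimal standalone sketch (not part of the commit; names and values are illustrative) of what the new applyPermutation helper computes: out[k] = in[permutation[k]], i.e. the input is gathered in the order given by the permutation.

    #include <cstdint>
    #include <iostream>
    #include <string>
    #include <vector>

    int main() {
      // Mirrors the semantics of ir_utils::applyPermutation added above.
      std::vector<std::string> in = {"d0", "d1", "d2"};
      std::vector<int64_t> permutation = {2, 0, 1};

      std::vector<std::string> out;
      out.reserve(permutation.size());
      for (int64_t i : permutation) {
        out.push_back(in.at(i));
      }

      // Prints: d2 d0 d1
      for (const std::string& s : out) {
        std::cout << s << ' ';
      }
      std::cout << '\n';
      return 0;
    }
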
53 changes: 49 additions & 4 deletions csrc/preseg_passes/allocation_order_inference.cpp
@@ -336,6 +336,18 @@ void inferenceAllocationOrder(
   }
 }

+namespace {
+void propagateAllocation(TensorView* in, TensorView* out) {
+  auto in_order = ir_utils::computePermutation(
+      in->getLogicalDomain(), in->getMaybeAllocationDomain());
+  if (!in_order.has_value()) {
+    return;
+  }
+  out->setAllocationDomain(
+      ir_utils::applyPermutation(out->getLogicalDomain(), *in_order), true);
+}
+} // namespace
+
 void AllocationDomainPass::runPass(Fusion* fusion) {
   // mark input TensorViews as propagation sources
   auto input_tvs = ir_utils::filterByType<TensorView>(fusion->inputs());
@@ -351,15 +363,48 @@ void AllocationDomainPass::runPass(Fusion* fusion) {
   // hint, but they should respect semantic requirement.
   // see issue: https://github.com/NVIDIA/Fuser/pull/2425
   for (TensorView* output : output_tvs) {
-    if (output->isDefinitionType<LinearOp>() ||
-        output->isDefinitionType<MatmulOp>() ||
-        output->isDefinitionType<MmaOp>()) {
-      continue;
+    if (Expr* def = output->definition()) {
+      if (def->isOneOf<LinearOp, SdpaFwdOp, SdpaBwdOp, MatmulOp, MmaOp>()) {
+        continue;
+      }
     }
     dsts.push_back(output);
   }
   // propagate allocation domain from sources to destinations
   inferenceAllocationOrder(fusion, srcs, dsts);
+
+  for (Expr* e : fusion->exprs()) {
+    if (auto* sdpa_fwd = dynamic_cast<SdpaFwdOp*>(e)) {
+      std::optional<std::vector<int64_t>> out_order = {};
+      for (TensorView* in :
+           {sdpa_fwd->query(), sdpa_fwd->key(), sdpa_fwd->value()}) {
+        auto in_order = ir_utils::computePermutation(
+            in->getLogicalDomain(), in->getMaybeAllocationDomain());
+        if (!in_order.has_value()) {
+          out_order = std::nullopt;
+          break;
+        }
+        if (!out_order.has_value()) {
+          out_order = *in_order;
+          continue;
+        }
+        if (*out_order != *in_order) {
+          out_order = std::nullopt;
+          break;
+        }
+      }
+      if (out_order.has_value()) {
+        TensorView* out = sdpa_fwd->attn_out();
+        out->setAllocationDomain(
+            ir_utils::applyPermutation(out->getLogicalDomain(), *out_order),
+            true);
+      }
+    } else if (auto* sdpa_bwd = dynamic_cast<SdpaBwdOp*>(e)) {
+      propagateAllocation(sdpa_bwd->query(), sdpa_bwd->grad_query());
+      propagateAllocation(sdpa_bwd->key(), sdpa_bwd->grad_key());
+      propagateAllocation(sdpa_bwd->value(), sdpa_bwd->grad_value());
+    }
+  }
 }

 } // namespace nvfuser::preseg_passes
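
The new loop above handles scaled dot-product attention ops specially. For SdpaFwdOp, an allocation order is set on attn_out only when query, key, and value all agree on the same logical-to-allocation permutation; for SdpaBwdOp, each gradient output simply copies the permutation of its corresponding input via propagateAllocation. A standalone sketch of the forward consensus rule (the function name consensusOrder is hypothetical, not code from the commit):

    #include <cstdint>
    #include <optional>
    #include <vector>

    // Returns the permutation shared by all inputs, or nullopt if any input
    // has no known permutation or two inputs disagree.
    std::optional<std::vector<int64_t>> consensusOrder(
        const std::vector<std::optional<std::vector<int64_t>>>& input_orders) {
      std::optional<std::vector<int64_t>> out_order;
      for (const auto& in_order : input_orders) {
        if (!in_order.has_value()) {
          return std::nullopt;  // an input has no recoverable permutation
        }
        if (!out_order.has_value()) {
          out_order = in_order;  // the first input seeds the consensus
          continue;
        }
        if (*out_order != *in_order) {
          return std::nullopt;  // inputs disagree, so nothing is propagated
        }
      }
      return out_order;
    }
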
2 changes: 1 addition & 1 deletion csrc/runtime/fusion_kernel_runtime.cpp
@@ -84,7 +84,7 @@ FusionKernelRuntime::FusionKernelRuntime(
     if (!communicator.is_available() || communicator.local_rank() == 0) {
       debug() << "Fusion IR after pre-segmenter optimization passes:"
               << std::endl;
-      fusion->printMath();
+      fusion->print();
     }
   }
