Adding resize(PadOp) vectorization analysis (#3321)
Adding **conditional** support of resize in vectorization analysis. This
PR allows a vectorized load on `PadOp` directly, without using a cache load,
which improves the performance of the generated kernel.

What's in this PR:
1. Add a propagation rule for resize in the vectorization analysis (see the
sketch after this list). The rule works as follows:
   i. For a supported resize: a) project the resize op onto the frontier and
   clear `[frontier.begin(), resize_position)`; b) record the projected extent
   of the resized iter domain as `gcd(projectedExtent(id_from),
   resize_op->leftExpand(), resize_op->rightExpand())`.
   ii. For an unsupported resize: clear `[frontier.begin(), resize_position]`;
   no behavior change.

2. Update `TensorView::cacheAfter` to accept an opt-in set of uses to cache
while leaving other uses unchanged. This is necessary when an input is
consumed by a `PadOp` as well as by other operations that rely on a cached
load for vectorization.

Follow-up to #3261.
Part of the work on improving RoPE performance; see the [design
doc](https://docs.google.com/document/d/1tafRMNIXMmHlIGAiNlaPkYp6mZAzJ2Rh_NtARHbmYNA/edit?disco=AAABYEnV_ZY).

---------

Co-authored-by: Naoya Maruyama <[email protected]>
jjsjann123 and naoyam authored Nov 13, 2024
1 parent 61ffac9 commit 2fb5539
Showing 7 changed files with 463 additions and 63 deletions.
8 changes: 7 additions & 1 deletion csrc/ir/interface_nodes.h
@@ -482,10 +482,16 @@ class NVF_API TensorView : public Val {
//!
//! @param op_type: memory operator to use for the inserted op between
//! the data tensor and the cache tensor
//! @param cache_op: cache operator, see enum class CacheOp
//! @param propagate_allocation_domain: replay allocation domain on cached
//! load
//! @param cached_uses: if empty, cache all uses; otherwise, only try to cache
//! uses in cached_uses.
TensorView* cacheAfter(
LoadStoreOpType op_type = LoadStoreOpType::Set,
CacheOp cache_op = CacheOp::Unspecified,
bool propagate_allocation_domain = true);
bool propagate_allocation_domain = true,
std::vector<Expr*> cached_uses = {});

// For a fusion output with other uses, we want to avoid writing to global
// memory and then reading the output again. We write to global memory
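
For reference, a usage sketch of the extended `cacheAfter` signature — caching all uses of an input except its `PadOp` consumers, so the pad itself can be vectorized directly. This mirrors the call site added to `cacheInputs` in `csrc/scheduler/utils.cpp` below; the surrounding scheduler context (`tv` being a fusion input) is assumed:

```cpp
// Sketch only: cache every use of `tv` except PadOp consumers.
std::vector<Expr*> cached_uses;
for (Expr* use : tv->uses()) {
  if (!use->isA<PadOp>()) {
    cached_uses.push_back(use);
  }
}

if (!cached_uses.empty()) {
  TensorView* cached_tv = tv->cacheAfter(
      /*op_type=*/LoadStoreOpType::Set,
      /*cache_op=*/CacheOp::Unspecified,
      /*propagate_allocation_domain=*/true,
      /*cached_uses=*/cached_uses);
}
```
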
5 changes: 4 additions & 1 deletion csrc/preseg_passes/move_pad.cpp
@@ -287,7 +287,10 @@ TensorView* replayConcretePad(

auto* new_out = IrBuilder::create<TensorView>(
IrBuilder::create<TensorDomain>(
merged_root_ids, merged_logical_ids, merged_logical_ids),
merged_root_ids,
merged_logical_ids,
merged_logical_ids,
TensorDomain::getContiguityFilledWith(merged_logical_ids, true)),
pad_tv->getDataType().value());
IrBuilder::create<PadOp>(
new_out,
68 changes: 34 additions & 34 deletions csrc/scheduler/utils.cpp
@@ -1198,12 +1198,36 @@ std::vector<TensorView*> cacheInputs(Fusion* fusion, bool unroll) {
for (auto tv : in_tvs) {
if (tv->uses().empty() || ir_utils::isTorchGatherLookupTv(tv) ||
ir_utils::isIndexSelectLookupTv(tv) ||
ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp, PadOp>(tv)) {
ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp>(tv)) {
// Right now, tensors that are input to the slice, select, and pad ops
// can't be cached as they must be in global memory.
continue;
}
auto cached_tv = tv->cacheAfter();

// TODO: might need to reverse this when scheduler handles pad directly
// Do not insert a cache for pad as vectorization needs to be
// done directly.
//
// Note that this means that if an input is padded and also used
// without padding, it will be read twice: once for the pad and once
// more for the cached load. It would make sense to use the PTX
// caching load instructions here.
std::vector<Expr*> cached_uses;
for (auto use : tv->uses()) {
if (!use->isA<PadOp>()) {
cached_uses.push_back(use);
}
}

if (cached_uses.empty()) {
continue;
}

auto cached_tv = tv->cacheAfter(
/*op_type=*/LoadStoreOpType::Set,
/*cache_op=*/CacheOp::Unspecified,
/*propagate_allocation_domain=*/true,
/*cached_uses=*/cached_uses);
cached_inputs.emplace_back(cached_tv);
}
return cached_inputs;
@@ -1290,12 +1314,7 @@ IterDomain* projectIdToRoot(
} else if (expr->isA<Resize>()) {
auto resize = expr->as<Resize>();
if (resize->out() == projected_id) {
// We do not allow vectorization with resize at this moment
if (vectorize_pass) {
projected_id = nullptr;
} else {
projected_id = resize->in();
}
projected_id = resize->in();
}
} else {
NVF_THROW("Didn't recognize the iterdomain expression: ", expr);
@@ -1350,12 +1369,7 @@ IterDomain* projectIdToRFactor(
} else if (expr->isA<Resize>()) {
auto resize = expr->as<Resize>();
if (resize->in() == projected_id) {
// We do not allow vectorization wit resize at this moment
if (vectorize_pass) {
projected_id = nullptr;
} else {
projected_id = resize->out();
}
projected_id = resize->out();
}
} else {
NVF_THROW("Didn't recognize the iterdomain expression: ", expr);
@@ -1549,12 +1563,6 @@ std::vector<TensorView*> getInputsOutputsWithInnerDim(
// scheduler prefer to use output instead of input as reference tensor.
for (auto output_tv :
ir_utils::filterByType<TensorView>(reference_tv->fusion()->outputs())) {
// At this moment, vectorization through resize is not
// supported. This is not required currently as we always insert
// cacheBefore, but just in case.
if (ir_utils::hasResizedRfactor(output_tv)) {
continue;
}
if (hasInnerDim(output_tv, vectorizable_dims, vectorize_pass)) {
vectorizable_tensors.push_back(output_tv);
}
@@ -1569,19 +1577,11 @@ std::vector<TensorView*> getInputsOutputsWithInnerDim(
continue;
}

auto expr_resizes = [](Expr* e) -> bool {
return std::any_of(
e->outputs().begin(), e->outputs().end(), [](Val* out) -> bool {
if (auto* out_tv = dynamic_cast<TensorView*>(out)) {
return ir_utils::hasResizedRfactor(out_tv);
}
return false;
});
};

// At this moment, vectorization through resize is not supported
if (std::any_of(
input_tv->uses().begin(), input_tv->uses().end(), expr_resizes)) {
// Slice op is explicitly not enabled for vectorized load.
if (std::all_of(
input_tv->uses().begin(),
input_tv->uses().end(),
[](Expr* e) -> bool { return e->isA<SliceOp>(); })) {
continue;
}

@@ -2385,7 +2385,7 @@ bool revertUseOfInputCache(
void prepareForMemoryTypePromotion(Fusion* fusion) {
auto non_pwise_pairs = getNonPointwiseProducerConsumerPairs(fusion);

// Inserting a copy of each proucer. If a tensor shows up as a
// Inserting a copy of each producer. If a tensor shows up as a
// producer for multiple consumers, only insert one
// copy and share it with all the consumers.

81 changes: 74 additions & 7 deletions csrc/scheduler/vectorize_helper.cpp
@@ -55,6 +55,30 @@ Val* ContiguousInnerDimensionsMapper::isFullyProjected(IterDomain* id) {
getProjectedExtent(id), commonOrConstExtent(ca_map_, id));
}

void ContiguousInnerDimensionsMapper::initializeResizeInfo(Fusion* fusion) {
auto exprs = fusion->exprs();
for (auto* pad_op : ir_utils::filterByType<PadOp>(exprs)) {
if (!pad_op->out()->isA<TensorView>()) {
continue;
}

auto* out_tv = pad_op->out()->as<TensorView>();

auto consumer_exprs = StmtSort::getExprsBetween(
{out_tv->getMaybeRootDomain().begin(),
out_tv->getMaybeRootDomain().end()},
{out_tv->getLogicalDomain().begin(), out_tv->getLogicalDomain().end()});

// NOTE: if we can assume that PadOp is always on inputs, then we can skip
// to innermost resize instead.
auto resize_ops = ir_utils::filterByType<Resize>(consumer_exprs);
std::copy(
resize_ops.begin(),
resize_ops.end(),
std::inserter(resize_in_pad_, resize_in_pad_.end()));
}
}

ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
TensorView* reference,
const std::vector<IterDomain*>& ids,
@@ -67,6 +91,9 @@ ContiguousInnerDimensionsMapper::ContiguousInnerDimensionsMapper(
ca_map_(std::move(ca_map)),
divisible_splits_(divisible_splits) {
FusionGuard fg(reference->fusion());

initializeResizeInfo(reference->fusion());

// Exclude reduction IDs if the reference is a fusion input as they
// don't manifest at all in the fusion. This simplifies the
// analysis in getContigMergeOfInnerSize, which only looks at
@@ -365,9 +392,51 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
distributePE(merge_or_split);
};

auto clear_left_of = [&frontier](IterDomain* id) {
auto it = std::find(frontier.begin(), frontier.end(), id);
if (it != frontier.end()) {
auto propagateResize = [&frontier, this](Resize* resize_op, bool p2c) {
IterDomain* id_from = p2c ? resize_op->in() : resize_op->out();
IterDomain* id_to = p2c ? resize_op->out() : resize_op->in();

auto it = std::find(frontier.begin(), frontier.end(), id_from);
if (it == frontier.end()) {
return;
}

auto pos = std::distance(frontier.begin(), it);
if (resize_in_pad_.count(resize_op) != 0) {
// resize created by PadOp.

// project resize op to frontier.
frontier[pos] = id_to;
// Clear to the left of the resize, since those ids are no longer contiguous.
frontier.erase(frontier.begin(), it);

if (recording_) {
// TODO: support negative resize extent.
//
// Limit current support to non-negative resize extents for now, so we
// only consider the pad extents, which become the real buffer on the
// output. Hence we take the GCD of the padded extents and the extent of
// id_from. Note that since we are taking the GCD here, using id_from or
// id_to should not make a difference.
auto consumer_factor = getProjectedExtent(id_from);
auto comp = [](Val* factor, Val* extent) {
return SimplifyingIrBuilder::whereExpr(
SimplifyingIrBuilder::eqExpr(
extent, extent->container()->zeroVal()),
factor,
// For extent < 0, we take max(1, extent); because of the gcd,
// this effectively excludes the resize id from vectorization.
SimplifyingIrBuilder::gcdExpr(
factor,
SimplifyingIrBuilder::maxExpr(
extent->container()->oneVal(), extent)));
};
consumer_factor = comp(consumer_factor, resize_op->leftExpand());
consumer_factor = comp(consumer_factor, resize_op->rightExpand());
addProjectedExtent(id_to, consumer_factor);
}
} else {
// Unsupported resize.
frontier.erase(frontier.begin(), it + 1);
}
};
@@ -391,8 +460,7 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
} else if (Merge* merge = dynamic_cast<Merge*>(expr)) {
propagateDistribute(merge);
} else if (Resize* resize = dynamic_cast<Resize*>(expr)) {
// Cannot vectorize through resize
clear_left_of(resize->out());
propagateResize(resize, false);
} else {
// TODO: I wonder if we should just remove all inputs instead of erroring.
// Seems that would be safe.
@@ -415,8 +483,7 @@ std::vector<IterDomain*> ContiguousInnerDimensionsMapper::projectId(
} else if (Split* split = dynamic_cast<Split*>(expr)) {
propagateDistribute(split);
} else if (Resize* resize = dynamic_cast<Resize*>(expr)) {
// Cannot vectorize through resize
clear_left_of(resize->in());
propagateResize(resize, true);
} else {
// TODO: I wonder if we should just remove all inputs instead of erroring.
// Seems that would be safe.
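
To make the frontier bookkeeping in `propagateResize` (in `vectorize_helper.cpp` above) concrete, here is a standalone illustration — plain C++ with stand-in strings instead of `IterDomain*`, so the names are purely hypothetical — of the supported (PadOp-originated) versus unsupported resize cases:

```cpp
// Standalone illustration of the frontier update in propagateResize.
#include <iostream>
#include <string>
#include <vector>

void applyResize(std::vector<std::string>& frontier, size_t pos,
                 const std::string& resized_id, bool from_pad_op) {
  if (from_pad_op) {
    // Supported resize: replace the id at `pos` with the resized id and drop
    // everything to its left, which is no longer contiguous with it.
    frontier[pos] = resized_id;
    frontier.erase(frontier.begin(), frontier.begin() + pos);
  } else {
    // Unsupported resize: also drop the resized id itself (previous behavior).
    frontier.erase(frontier.begin(), frontier.begin() + pos + 1);
  }
}

int main() {
  std::vector<std::string> frontier{"i0", "i1", "i2"};
  applyResize(frontier, 1, "i1_resized", /*from_pad_op=*/true);
  for (const auto& id : frontier) {
    std::cout << id << " ";  // prints: i1_resized i2
  }
  std::cout << "\n";
  return 0;
}
```
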
10 changes: 8 additions & 2 deletions csrc/scheduler/vectorize_helper.h
@@ -31,7 +31,7 @@ namespace vectorize_helper {

// Projects IterDomains through the fusion starting at provided reference. IDs
// in the reference are expected to be "contiguous", simply means dimensions
// that the iter domains are consecutive and next to eachother in the
// that the iter domains are consecutive and next to each other in the
// reference. This property is not enforced, but mapping can have some
unpredictable results if they are not. The reason we want contiguity here
// is this class is primarily used for vectorization analysis. Domains may be
@@ -78,7 +78,7 @@ namespace vectorize_helper {
// tv1[2*3, 5, 7*11] = view(tv0)
// with tv1 and [2*3, 7*11] as the reference and ids. tv0's 2 and 11 dim are
// easily identified as being mapped. The 3*5*7 dimension however, is
// partially mapped on the left and right side. Since this class is intended to
// partially mapped on the left and right side. Since this class is intended to
// line up "inner dimensions" of tensors through out the graph for the purpose
// of unrolling and vectorization, it only tracks partial dimensions as they are
// on the right hand side of iteration domains. For example in the last case we
@@ -289,6 +289,9 @@ class NVF_API ContiguousInnerDimensionsMapper
void propagateP2C(TensorView* from, TensorView* to) final;
void propagateSibling(TensorView* from, TensorView* to) final;

// Traverse the fusion to mark Resize ops that originate from a PadOp
void initializeResizeInfo(Fusion* fusion);

// Initialized to false, series of compute... calls will be performed to find
// the spanning tree. Then propagate... calls will call the compute... calls.
// recording_ starts as false, and stays that way during the first series of
@@ -308,6 +311,9 @@ class NVF_API ContiguousInnerDimensionsMapper
tv_infos_;

std::unordered_map<IterDomain*, Val*> projected_extent_;

//! Stores all Resize ops that were created by a PadOp
std::unordered_set<Resize*> resize_in_pad_;
};

// logical_reorder_map is provided to assume reference_tv will be reordered per
43 changes: 33 additions & 10 deletions csrc/tensor_view.cpp
@@ -1169,15 +1169,37 @@ TensorView* TensorView::cacheFork() {
TensorView* TensorView::cacheAfter(
LoadStoreOpType op_type,
CacheOp cache_op,
bool propagate_allocation_domain) {
bool propagate_allocation_domain,
std::vector<Expr*> cached_uses) {
NVF_ERROR(
!container()->isA<kir::Kernel>(),
"Function invalid for kernel container.");
FusionGuard fg(fusion());

if (!cached_uses.empty()) {
std::unordered_set<Expr*> unique_uses = fusion()->unordered_uses(this);
for (auto use : cached_uses) {
NVF_ERROR(
unique_uses.count(use),
"cached_uses is not among the use of the TensorView");
}
} else {
// Iterate over uses() to keep deterministic ordering and ensure uniqueness.
std::unordered_set<Expr*> unique_uses;
auto this_uses = uses();
cached_uses.reserve(this_uses.size());
for (Expr* use : this_uses) {
NVF_ERROR(
unique_uses.count(use) == 0,
"detect duplicated entries in TensorView::uses()");
cached_uses.push_back(use);
unique_uses.insert(use);
}
}

// Make sure there is at least one use to cache
NVF_CHECK(
!uses().empty(),
!cached_uses.empty(),
"Error adding cacheAfter ",
this,
" we restrict using cacheAfter on tensors that have no further uses.");
Expand All @@ -1188,18 +1210,19 @@ TensorView* TensorView::cacheAfter(
!hasComputeAt(),
"Caching computed-at tensors is not allowed. Apply caching before computeAt.");

bool is_allowed_op =
!ir_utils::isTvUsedByOpsOfType<SliceOp, SelectOp, PadOp>(this) &&
!ir_utils::isIndexSelectLookupTv(this);
NVF_CHECK(
is_allowed_op,
"Right now, caching tensors that are input to the select/slice/pad ops are not allowed as they must be in global memory.")
// Disallow caching for uses that require their input to stay in global memory.
for (auto use : cached_uses) {
NVF_ERROR(
!(use->isOneOf<SliceOp, SelectOp, PadOp>()) &&
!(use->isA<IndexSelectOp>() && use->input(0) == this),
"Right now, caching tensors that are input to the select/slice/pad ops are not allowed as they must be in global memory.");
}

// It also did additional transformation when this tensor is an
// input and the outputs of its consumers have computeAt. Make sure
// we no longer rely on that behavior.
if (isFusionInput()) {
for (const auto& expr : uses()) {
for (const auto& expr : cached_uses) {
for (TensorView* output :
ir_utils::filterByType<TensorView>(expr->outputs())) {
NVF_CHECK(
@@ -1242,7 +1265,7 @@ TensorView* TensorView::cacheAfter(
// After: This TV -> [Set Op] -> New CA TV -> [Use Op] -> Next TV

// Expr* consumer_uses =
for (auto expr : fusion()->unordered_uses(this)) {
for (auto expr : cached_uses) {
ir_utils::replaceValInExprInputs(expr, this, consumer);
}
