Commit
Merge remote-tracking branch 'origin/pw_scheduler_reference_find_patch' into HEAD
jjsjann123 committed Dec 12, 2024
2 parents d46323c + 7668d4e commit c70f160
Showing 22 changed files with 1,503 additions and 525 deletions.
6 changes: 6 additions & 0 deletions csrc/device_lower/lower2device.cpp
@@ -348,6 +348,12 @@ IdModelOptions getIdModelOptions(Fusion* fusion) {
} else if (expr->isA<MmaOp>()) {
options.setBuildTensorIndexer(true);
continue;
} else if (expr->isOneOf<SliceOp, PadOp>()) {
options.setProducerIndex(true);
options.setConsumerIndex(true);
options.setInlinePredicate(true);
options.setUnswitchPredicate(true);
continue;
} else if (auto reshape = dynamic_cast<ViewOp*>(expr)) {
// The legacy indexer has an issue when an expand broadcast is
// involved in reshape transformations. Enable both tensor and
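As context for the hunk above: a minimal standalone sketch of the per-expression dispatch pattern it extends. The types here (Expr, SliceOp, PadOp, MmaOp, IdModelOptions, getOptions) are simplified stand-ins for illustration only, not nvFuser's actual classes; the real function also handles several other expression kinds.

#include <memory>
#include <vector>

// Stand-ins for nvFuser's Expr hierarchy and IdModelOptions (illustration only).
struct Expr { virtual ~Expr() = default; };
struct SliceOp : Expr {};
struct PadOp : Expr {};
struct MmaOp : Expr {};

struct IdModelOptions {
  bool build_tensor_indexer = false;
  bool producer_index = false;
  bool consumer_index = false;
  bool inline_predicate = false;
  bool unswitch_predicate = false;
};

IdModelOptions getOptions(const std::vector<std::unique_ptr<Expr>>& exprs) {
  IdModelOptions options;
  for (const auto& expr : exprs) {
    if (dynamic_cast<MmaOp*>(expr.get()) != nullptr) {
      options.build_tensor_indexer = true;
    } else if (
        dynamic_cast<SliceOp*>(expr.get()) != nullptr ||
        dynamic_cast<PadOp*>(expr.get()) != nullptr) {
      // Resize-based ops now opt in to IdModel-based indexing and predication,
      // mirroring the SliceOp/PadOp branch added in the hunk above.
      options.producer_index = true;
      options.consumer_index = true;
      options.inline_predicate = true;
      options.unswitch_predicate = true;
    }
  }
  return options;
}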
34 changes: 24 additions & 10 deletions csrc/id_model/indexing.cpp
@@ -413,19 +413,29 @@ class AllocationDomainSetup : private kir::IrVisitor {
}

// Reorder non-logical allocation domains to follow the ordering of
// the logical domain. This is necessary when an allocation domain
// includes a vectorized loop iter domain since it must be at the
// the set allocation domain. This is necessary when an allocation
// domain includes a vectorized loop iter domain since it must be at the
// innermost position but that may not be the case in the loop
// domain. Not strictly necessary otherwise, but this should also
// domain. It is also necessary when the tensor is a producer of a
// vectorized store. Not strictly necessary otherwise, but this should also
// minimize the deviation from the old indexing scheme which always
// uses the logical domain to index.
//
// Returns reordered allocation domains if reordering is done.
std::optional<std::vector<IterDomain*>> reorderAllocationDomains(
const TensorView* tv,
const std::vector<IterDomain*>& allocation_domains) const {
// Use getMaybeAllocationDomain instead of getLogicalDomain. When
// this tv is a producer of a vectorized store, the consumer
// tensor should be a global memory tensor and this is likely a
// cache tensor created by cacheBefore. The consumer tensor may
// have a reordered allocation domain and that dictates the actual
// allocation ordering of this producer local tensor as well. If
// getLogicalDomain is used, DistributedTransformerTest.Backward
// fails at the result validation.
auto exprs = DependencyCheck::getAllExprsBetween(
{tv->getLogicalDomain().begin(), tv->getLogicalDomain().end()},
{tv->getMaybeAllocationDomain().begin(),
tv->getMaybeAllocationDomain().end()},
{allocation_domains.begin(), allocation_domains.end()});

if (exprs.empty()) {
@@ -434,7 +444,7 @@ class AllocationDomainSetup : private kir::IrVisitor {

// Replay exprs from the logical domain to get the non-reordered
// domains
auto ordered_domains = tv->getLogicalDomain();
auto ordered_domains = tv->getMaybeAllocationDomain();
for (auto expr : exprs) {
// Find the position to insert the outputs.
int64_t insertion_pos = -1;
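Conceptually, the goal of this reordering is to make the given allocation domains follow the relative order that their origins have in the reference (maybe-allocation) domain. A standalone, concept-only sketch of that ordering idea, with strings standing in for IterDomains; the actual implementation instead replays the exprs found by DependencyCheck, as shown in the hunk above.

#include <algorithm>
#include <cstdint>
#include <string>
#include <unordered_map>
#include <vector>

// Reorder `domains` to follow the relative order the same entries have in
// `reference`. Assumes every entry of `domains` appears in `reference`.
std::vector<std::string> reorderLikeReference(
    std::vector<std::string> domains,
    const std::vector<std::string>& reference) {
  std::unordered_map<std::string, int64_t> position;
  for (int64_t i = 0; i < static_cast<int64_t>(reference.size()); ++i) {
    position[reference[i]] = i;
  }
  std::sort(domains.begin(), domains.end(), [&](const auto& a, const auto& b) {
    return position.at(a) < position.at(b);
  });
  return domains;
}

// Example: reorderLikeReference({"i1", "i0"}, {"i0", "i1", "i2"}) == {"i0", "i1"}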
@@ -845,14 +855,18 @@ std::vector<Val*> TensorIndexer::getIndexFor(
const auto& replacement_map = getIndexReplacementMap(
expr, as_consumer, info.loop_domains, for_loops, info.index_map);

const auto index_groups = traversalGraph().toGroups(index_ids);
// Note that IDs of index_ids may be mapped as the traversal graph
// is the AlmostExact graph.

std::vector<Val*> result;
result.reserve(index_groups.size());
for (const auto& g : index_groups) {
auto it = info.index_map.find(g);
result.reserve(index_ids.size());
for (IterDomain* index_id : index_ids) {
const auto& index_group = traversalGraph().toGroup(index_id);
auto it = info.index_map.find(index_group);
NVF_ERROR(
it != info.index_map.end(), "Index not found for ", g->toString());
it != info.index_map.end(),
"Index not found for ",
index_id->toString());
result.push_back(
ir_utils::replaceValRecursively(it->second, replacement_map));
}
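The getIndexFor hunk above switches from iterating ID groups to iterating the requested IDs themselves, so the result keeps exactly one index per requested ID, in the requested order, even when several IDs fall into the same AlmostExact group. A minimal standalone sketch of that lookup pattern, with plain strings and std::map standing in for IterDomains, ValGroups, and the index map (illustration only, not nvFuser's API):

#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

int main() {
  // Stand-in for the AlmostExact graph: each ID maps to its group.
  std::map<std::string, std::string> id_to_group = {
      {"id0", "g0"}, {"id1", "g1"}, {"id2", "g1"}}; // id1 and id2 are mapped
  // Stand-in for info.index_map: one index expression per group.
  std::map<std::string, std::string> index_map = {{"g0", "i"}, {"g1", "j"}};

  // Requested IDs; the caller expects exactly one index per entry, in order.
  std::vector<std::string> index_ids = {"id0", "id1", "id2"};

  std::vector<std::string> result;
  result.reserve(index_ids.size());
  for (const auto& id : index_ids) {
    const auto& group = id_to_group.at(id);
    auto it = index_map.find(group);
    if (it == index_map.end()) {
      throw std::runtime_error("Index not found for " + id);
    }
    result.push_back(it->second);
  }

  // Iterating groups instead would have collapsed id1 and id2 into one entry.
  assert(result.size() == index_ids.size());
  return 0;
}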
36 changes: 18 additions & 18 deletions csrc/kernel_ir.h
@@ -77,7 +77,7 @@ class Predicate final : public Val {

std::string toString(int indent_size = 0) const override;

NVF_API std::string toInlineString(int indent_size = 0) const override;
std::string toInlineString(int indent_size = 0) const override;

PredicateType predicate_type() const {
return ptype_;
@@ -148,7 +148,7 @@ class Predicate final : public Val {
Val* value_ = nullptr;
};

class NVF_API TensorIndex final : public Val {
class TensorIndex final : public Val {
public:
TensorIndex(
IrBuilderPasskey,
@@ -252,7 +252,7 @@ class Asm final : public Expr {
//! is required as an intermediate within a kernel. The extent is the expression
//! of the size of the buffer that is generated from the TensorView that
//! describes the output of an operation.
class NVF_API Allocate final : public Expr {
class Allocate final : public Expr {
public:
using Expr::Expr;

@@ -385,7 +385,7 @@ class NVF_API Allocate final : public Expr {
//
// TODO(kir): change name to SyncThreads as we could have other barriers.
//
class NVF_API BlockSync final : public Expr {
class BlockSync final : public Expr {
public:
using Expr::Expr;

@@ -408,7 +408,7 @@ class NVF_API BlockSync final : public Expr {

// Synchronize all blocks in device, implies cooperative group launch is
// required.
class NVF_API GridSync final : public Expr {
class GridSync final : public Expr {
public:
using Expr::Expr;

@@ -436,7 +436,7 @@ class NVF_API GridSync final : public Expr {
};

// PTX: fence.proxy.async
class NVF_API FenceAsyncProxy final : public Expr {
class FenceAsyncProxy final : public Expr {
public:
using Expr::Expr;

@@ -453,7 +453,7 @@ class NVF_API FenceAsyncProxy final : public Expr {
};

// PTX: wgmma.fence.sync.aligned
class NVF_API WgMmaFence final : public Expr {
class WgMmaFence final : public Expr {
public:
using Expr::Expr;

@@ -469,7 +469,7 @@ class NVF_API WgMmaFence final : public Expr {
std::string toInlineString(int indent_size = 0) const override;
};

class NVF_API MBarrierInit final : public Expr {
class MBarrierInit final : public Expr {
public:
using Expr::Expr;
explicit MBarrierInit(
@@ -495,7 +495,7 @@ class NVF_API MBarrierInit final : public Expr {
}
};

class NVF_API MBarrierInvalidate final : public Expr {
class MBarrierInvalidate final : public Expr {
public:
using Expr::Expr;
explicit MBarrierInvalidate(IrBuilderPasskey passkey, Val* mbarrier);
@@ -514,7 +514,7 @@ class NVF_API MBarrierInvalidate final : public Expr {
}
};

class NVF_API MBarrierArrive final : public Expr {
class MBarrierArrive final : public Expr {
public:
using Expr::Expr;
explicit MBarrierArrive(IrBuilderPasskey passkey, Val* state, Val* mbarrier);
@@ -544,7 +544,7 @@ class NVF_API MBarrierArrive final : public Expr {
// This is usually used to specify the number of bytes that will be
// transferred for cp.async and cp.async.bulk, so that future mbarrier.wait
// can wait for the completion of the transfer.
class NVF_API MBarrierArriveExpectTx final : public Expr {
class MBarrierArriveExpectTx final : public Expr {
public:
using Expr::Expr;
explicit MBarrierArriveExpectTx(
@@ -578,7 +578,7 @@ class NVF_API MBarrierArriveExpectTx final : public Expr {
}
};

class NVF_API MBarrierWait final : public Expr {
class MBarrierWait final : public Expr {
public:
using Expr::Expr;
explicit MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state);
@@ -601,7 +601,7 @@ class NVF_API MBarrierWait final : public Expr {
}
};

class NVF_API MBarrierWaitParity final : public Expr {
class MBarrierWaitParity final : public Expr {
public:
using Expr::Expr;
explicit MBarrierWaitParity(
@@ -796,7 +796,7 @@ class UpdateMagicZero final : public Expr {
//!
//! TODO(kir): this is not a real expression
//!
class NVF_API IfThenElse final : public Expr {
class IfThenElse final : public Expr {
public:
using Expr::Expr;

@@ -915,7 +915,7 @@ class GridReduction final : public ReductionOp {
}
};

class NVF_API GroupedGridReduction final : public GroupedReductionOp {
class GroupedGridReduction final : public GroupedReductionOp {
public:
using GroupedReductionOp::GroupedReductionOp;

@@ -1006,7 +1006,7 @@ class NVF_API GroupedGridReduction final : public GroupedReductionOp {
//!
//! This node provides KernelExecutor the information it needs to allocate the
//! broadcast and sync buffers.
class NVF_API GridBroadcast final : public Expr {
class GridBroadcast final : public Expr {
public:
using Expr::Expr;

@@ -1117,7 +1117,7 @@ class GridWelford final : public Expr {
}
};

class NVF_API GroupedGridWelford final : public GroupedWelfordOp {
class GroupedGridWelford final : public GroupedWelfordOp {
public:
using GroupedWelfordOp::GroupedWelfordOp;

@@ -1211,7 +1211,7 @@ class NVF_API GroupedGridWelford final : public GroupedWelfordOp {

//! Represents a WelfordOp with the division by count is hoisted out
//! of an innermost loop
class NVF_API VectorizedWelfordOp final : public WelfordOp {
class VectorizedWelfordOp final : public WelfordOp {
public:
using WelfordOp::WelfordOp;

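The kernel_ir.h hunks above only drop the NVF_API export annotation from kernel IR nodes that do not need to be visible outside the library. As a rough illustration of why this is a benign change, an export macro of this kind typically expands to a symbol-visibility attribute; the definition below is an assumed, generic pattern for illustration, not a quote of nvFuser's actual headers.

// Hypothetical definition following the common export-macro pattern; the real
// NVF_API lives in nvFuser's headers and may differ in detail.
#if defined(_WIN32)
#define NVF_API __declspec(dllexport)
#else
#define NVF_API __attribute__((visibility("default")))
#endif

// Annotated: the class's symbols are exported from the shared library.
// Unannotated: with hidden default visibility the symbols stay internal,
// which suffices for IR nodes used only inside code generation.
class NVF_API ExportedNode {};
class InternalNode {};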
33 changes: 21 additions & 12 deletions csrc/multidevice/communication.cpp
@@ -429,27 +429,36 @@ c10::intrusive_ptr<c10d::Work> postReduceScatter(
scattered_axis >= 0,
"scattered_axis is expected to be non-negative: ",
scattered_axis);
// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.

std::vector<at::Tensor> input_tensors = at::tensor_split(
input_tensor, communication->team_size(), scattered_axis);
// We could have checked the output shape as well if reduction_axis is
// available. It's not always available via
// `communication->out()->getReductionAxis()` for manually constructed host
// IRs like
// https://github.com/NVIDIA/Fuser/blob/89c47f695b296eb4ffd27984bd4c953fc3f3264b/tests/cpp/test_multidevice_overlap.cpp#L347.
assertBuffersHaveSameSize(input_tensors, {});

// reduce_scatter primitive in c10d induces extra buffering time to copy the
// user's input tensors to an internal source buffer. It is therefore always
// preferable to use _reduce_scatter_base (which does not perform any extra
// copy) when the tensors are stored contiguously (i.e., when
// scattered_axis==0). Note however that only nccl supports
// _reduce_scatter_base, not ucc.
#if defined(NVFUSER_DISTRIBUTED) && defined(USE_C10D_NCCL)
if (scattered_axis == 0 &&
backend->getBackendName() == c10d::NCCL_BACKEND_NAME) {
return backend->_reduce_scatter_base(
output_tensor, input_tensor, {.reduceOp = communication->reduceOp()});
}
#endif
std::vector<std::vector<at::Tensor>> input_tensors(1);
input_tensors[0] = at::split(input_tensor, /*split_size=*/1, scattered_axis);

std::vector<at::Tensor> output_tensors({output_tensor});

assertBufferCount(input_tensors[0], communication->team().size());
std::vector<std::vector<at::Tensor>> input_tensors_vec({input_tensors});
std::vector<at::Tensor> output_tensor_vec({output_tensor});
return backend->reduce_scatter(
output_tensors, input_tensors, {.reduceOp = communication->reduceOp()});
output_tensor_vec,
input_tensors_vec,
{.reduceOp = communication->reduceOp()});
}

c10::intrusive_ptr<c10d::Work> postSendRecv(
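The postReduceScatter change above replaces at::split(input_tensor, /*split_size=*/1, scattered_axis), which yields one size-1 slice per element of the scattered axis (and therefore required that axis's extent to equal the team size), with at::tensor_split(input_tensor, communication->team_size(), scattered_axis), which yields exactly team_size chunks. A small ATen sketch of the difference, assuming a libtorch build; the 4x6 shape, team size, and axis are chosen purely for illustration.

#include <ATen/ATen.h>
#include <iostream>
#include <vector>

int main() {
  const int64_t team_size = 2;
  const int64_t scattered_axis = 1;
  at::Tensor input = at::arange(24, at::kFloat).reshape({4, 6});

  // Old behavior: size-1 slices along the scattered axis -> 6 tensors of
  // shape {4, 1}; the code then asserted the count equals the team size.
  std::vector<at::Tensor> size_one_slices =
      at::split(input, /*split_size=*/1, scattered_axis);

  // New behavior: exactly team_size chunks -> 2 tensors of shape {4, 3},
  // one per rank, as reduce_scatter expects.
  std::vector<at::Tensor> per_rank_chunks =
      at::tensor_split(input, team_size, scattered_axis);

  std::cout << size_one_slices.size() << " vs " << per_rank_chunks.size()
            << std::endl; // prints: 6 vs 2
  return 0;
}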