[wip] fixing or breaking indexing?
drzejan2 committed Feb 22, 2024
1 parent 8e0b625 commit b0fd2be
Showing 3 changed files with 62 additions and 25 deletions.
16 changes: 15 additions & 1 deletion csrc/device_lower/lower2device.h
@@ -219,6 +219,14 @@ class GpuLower : public NonCopyable {
return mbarrier_token_smem_alloc_set_;
}

std::unordered_map<const Expr*, Val*>& ldstMBarrierIndexMap() {
return ldst_mbarrier_index_map_;
}

const std::unordered_map<const Expr*, Val*>& ldstMBarrierIndexMap() const {
return ldst_mbarrier_index_map_;
}

bool isNvFuserZeroEnabled() {
if (isOptionDisabled(DisableOption::MagicZero)) {
return false;
@@ -313,12 +321,18 @@ class GpuLower : public NonCopyable {
std::unordered_map<const Expr*, TensorView*> ldst_mbarrier_map_;

// Keep track of placeholders for tokens returned by arrive/expected tx
// mbarrier operations
// mbarrier operations for each load/store operation that requires such
// synchronization
std::unordered_map<const Expr*, TensorView*> ldst_mbarrier_token_map_;

// Collection of kir::Allocate for smem buffers used for mbarrier and token
// objects from cpAsyncBulk synchronization
std::unordered_set<const Expr*> mbarrier_token_smem_alloc_set_;

// Keep track of which mbarrier object is used in a load/store operation
// that requires such synchronization; needed by the indexing pass
std::unordered_map<const Expr*, Val*> ldst_mbarrier_index_map_;

Fusion* fusion_ = nullptr;
};

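For orientation, the pattern the new map enables is: the double buffer pass records which mbarrier index a cpAsyncBulk load should use, and the indexing pass later looks it up instead of re-deriving it. A minimal, self-contained sketch of that register-then-lookup flow (simplified stand-ins, not the actual nvFuser `Expr`/`Val` classes):

```cpp
#include <iostream>
#include <unordered_map>

// Simplified stand-ins for nvFuser's Expr and Val; only the map usage
// pattern from this commit is illustrated here.
struct Val {
  const char* name;
};
struct Expr {
  const char* name;
};

int main() {
  // Analogous to GpuLower::ldstMBarrierIndexMap(): keyed by the load/store
  // expression, valued by the mbarrier index to reuse during index lowering.
  std::unordered_map<const Expr*, Val*> ldst_mbarrier_index_map;

  Expr cp_async_bulk_load{"cpAsyncBulk"};
  Val mbarrier_index{"mbarrier[stage]"};

  // Double buffer pass: register the index chosen for this load.
  ldst_mbarrier_index_map[&cp_async_bulk_load] = &mbarrier_index;

  // Indexing pass: look it up instead of recomputing it.
  if (auto it = ldst_mbarrier_index_map.find(&cp_async_bulk_load);
      it != ldst_mbarrier_index_map.end()) {
    std::cout << it->second->name << "\n";
  }
  return 0;
}
```
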
11 changes: 11 additions & 0 deletions csrc/device_lower/pass/double_buffer.cpp
@@ -455,6 +455,12 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
->withPredicate(ldst->predicate());

body.push_back(new_ldst);

// Register the mbarrier object to be used with the new LoadStoreOp
// from the prolog loop
gpu_lower->ldstMBarrierIndexMap()[new_ldst] =
mbarrier_arrive_tx->mbarrier();

cloned_scopes_.back()->push_back(if_expr);
#ifdef EXTRA_LOGS
std::cout << "[DEBUG] new MBarrierArriveExpectTx node: "
@@ -531,6 +537,11 @@ class DoubleBufferLoopCloner : public kir::IrVisitor {
body.push_back(mbarrier_arrive_tx);
body.push_back(ldst);

// Register the mbarrier object to be used with the LoadStoreOp
// from the main loop
gpu_lower->ldstMBarrierIndexMap()[ldst] =
mbarrier_arrive_tx->mbarrier();

cloned_scopes_.back()->push_back(if_expr);

// Construct mBarrier::wait for current stage
Expand Down
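
The prolog/main-loop registration above feeds the same split arrive/wait pattern that the tokens (`state`) implement: arrive (expect-tx) early, keep the token, and wait only when the data is actually needed. A rough host-side analogy using C++20 `std::barrier` (an illustration of the token semantics, not the generated CUDA mbarrier code; the "copy engine" thread is purely a stand-in for the asynchronous bulk copy):

```cpp
#include <barrier>
#include <cstring>
#include <iostream>
#include <thread>
#include <vector>

int main() {
  std::vector<int> src(1024, 7);
  std::vector<int> dst(1024, 0);
  std::barrier<> sync(2); // one consumer + one copy "engine"

  // Stand-in for the asynchronous bulk copy: finishes the transfer,
  // then signals the barrier (like the hardware completing the expected bytes).
  std::thread copy_engine([&] {
    std::memcpy(dst.data(), src.data(), src.size() * sizeof(int));
    sync.arrive_and_wait();
  });

  // MBarrierArriveExpectTx analogue: arrive now and keep the token (state).
  auto state = sync.arrive();
  // ... other work could overlap with the in-flight copy here ...
  // MBarrierWait analogue: block until the copy has signalled completion.
  sync.wait(std::move(state));
  std::cout << dst[0] << "\n"; // prints 7

  copy_engine.join();
  return 0;
}
```
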
60 changes: 36 additions & 24 deletions csrc/device_lower/pass/index.cpp
@@ -1409,41 +1409,46 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
auto out_tv = ldst->out()->as<TensorView>();
auto in_tv = ldst->in()->as<TensorView>();

// mbarrier and tokens/state placeholders are smem arrays due to
// double/circular buffering? check if tokens have a smem placeholder
const auto is_smem_array =
(0 != GpuLower::current()->ldstMBarrierTokenMap().count(ldst));
auto gpu_lower = GpuLower::current();

// If the LoadStoreOp has a smem TV in ldstMBarrierTokenMap, it is part
// of a double buffer loop and has already been handled by the double
// buffer pass - kir nodes for arrive/expect tx and wait are added in the
// proper scope there.
// Otherwise, add these nodes here, at this stage.
const auto is_double_buffered =
(0 != gpu_lower->ldstMBarrierTokenMap().count(ldst));

// indexing mbarrier
auto mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst);
auto mbarrier = gpu_lower->ldstMBarrierMap().at(ldst);
Val* mbarrier_index = nullptr;
if (is_smem_array) {
mbarrier_index = IrBuilder::create<kir::TensorIndex>(
mbarrier, GpuLower::current()->kernel()->zeroVal());
if (is_double_buffered) {
NVF_ERROR(
(gpu_lower->ldstMBarrierIndexMap().count(ldst) != 0),
"Expected LoadStoreOp participating in double buffering ",
"loop to have a defined mbarrier index object, one to be ",
"used for generating cpAsyncBulk.");
mbarrier_index = gpu_lower->ldstMBarrierIndexMap()[ldst];
} else {
mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier);
}

// arrive and expect_tx mbarrier
Val* state = nullptr;
Val* expect_bytes = IrBuilder::create<Val>(dataTypeSize(in_tv->dtype()));
for (auto id : in_tv->getLeafDomain()) {
expect_bytes = SimplifyingIrBuilder::mulExpr(expect_bytes, id->extent());
}
expect_bytes =
SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, expect_bytes);
if (!is_double_buffered) {
Val* expect_bytes = IrBuilder::create<Val>(dataTypeSize(in_tv->dtype()));
for (auto id : in_tv->getLeafDomain()) {
expect_bytes = SimplifyingIrBuilder::mulExpr(expect_bytes, id->extent());
}
expect_bytes =
SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, expect_bytes);

if (is_smem_array) {
auto mbarrier_tokens = GpuLower::current()->ldstMBarrierTokenMap().at(ldst);
state = IrBuilder::create<kir::TensorIndex>(
mbarrier_tokens, GpuLower::current()->kernel()->zeroVal());
} else {
state = IrBuilder::create<Val>(DataType::UInt);
pushBack(IrBuilder::create<kir::Allocate>(
state, MemoryType::Local, ldst->container()->oneVal()));
pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
state, mbarrier_index, expect_bytes));
}
pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
state, mbarrier_index, expect_bytes));

// indexing ldst op
auto out = lowerDstIndex(ldst->out(), {}, true);
@@ -1452,9 +1457,16 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
IrBuilder::create<LoadStoreOp>(ldst->opType(), out, in, ldst->cacheOp())
->withPredicate(ldst->predicate());
pushBack(new_ldst);
GpuLower::current()->propagateExprInfo(ldst, back());
// wait mbarrier
pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));

// Register the new LoadStoreOp with its mbarrier index
gpu_lower->ldstMBarrierIndexMap()[new_ldst] = mbarrier_index;

gpu_lower->propagateExprInfo(ldst, back());

if (!is_double_buffered) {
// wait mbarrier
pushBack(IrBuilder::create<kir::MBarrierWait>(mbarrier_index, state));
}
}

void IndexLowering::handleCpAsyncBulkStore(const LoadStoreOp* ldst) {
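
In the non-double-buffered path above, the expect-tx byte count is built as the element size multiplied by every leaf-domain extent and then narrowed to UInt32. A tiny standalone sketch of that arithmetic (hypothetical helper, assumes the extents are known constants rather than symbolic Vals):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the expect-tx byte count built in
// handleCpAsyncBulkLoad: element size times the product of the leaf-domain
// extents, narrowed to uint32_t for the mbarrier arrive/expect-tx call.
uint32_t expectedTxBytes(int64_t element_size, const std::vector<int64_t>& extents) {
  int64_t bytes = element_size;
  for (int64_t extent : extents) {
    bytes *= extent;
  }
  return static_cast<uint32_t>(bytes);
}

int main() {
  // e.g. a [128, 64] float tile: 4 * 128 * 64 = 32768 bytes per transaction
  std::cout << expectedTxBytes(4, {128, 64}) << "\n";
  return 0;
}
```
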
