Skip to content

Commit

Permalink
[wip] smem tokens storage is added
Browse files Browse the repository at this point in the history
  • Loading branch information
drzejan2 committed Jan 24, 2024
1 parent 8099e38 commit 15a17a6
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 20 deletions.
42 changes: 32 additions & 10 deletions csrc/device_lower/pass/alias_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -869,17 +869,39 @@ class AllocationInfoMap : private kir::IrVisitor {
const auto expr_pos = scope_map_.getExprPos(expr);

if (auto init = dynamic_cast<kir::MBarrierInit*>(expr)) {
auto alloc_info = getAllocInfoFromTV(init->mbarrier()->as<TensorView>());
alloc_info->inner_live_interval->markWrite(expr_pos);
auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info);
auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos;
alloc_info->outer_live_interval->markWrite(write_pos);
const auto markWrite = [&expr_pos, this](TensorView* tv) {
auto alloc_info = getAllocInfoFromTV(tv);
alloc_info->inner_live_interval->markWrite(expr_pos);
auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info);
auto write_pos =
outer_loop_info ? outer_loop_info->start_pos : expr_pos;
alloc_info->outer_live_interval->markWrite(write_pos);
};

markWrite(init->mbarrier()->as<TensorView>());

// Register life time start for a smem placeholder with tokens
// returned by MBarrierArriveExpectTx / MBarrierArrive
if (GpuLower::current()->tokenMBarrierMap().count(expr)) {
markWrite(GpuLower::current()->tokenMBarrierMap()[expr]);
}
} else if (auto inval = dynamic_cast<kir::MBarrierInvalidate*>(expr)) {
auto alloc_info = getAllocInfoFromTV(inval->mbarrier()->as<TensorView>());
alloc_info->inner_live_interval->markWrite(expr_pos);
auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info);
auto write_pos = outer_loop_info ? outer_loop_info->start_pos : expr_pos;
alloc_info->outer_live_interval->markWrite(write_pos);
const auto markRead = [&expr_pos, this](TensorView* tv) {
auto alloc_info = getAllocInfoFromTV(tv);
alloc_info->inner_live_interval->markRead(expr_pos);
auto outer_loop_info = ascendLoopNestToSameLevelAs(alloc_info);
auto write_pos =
outer_loop_info ? outer_loop_info->start_pos : expr_pos;
alloc_info->outer_live_interval->markRead(write_pos);
};

markRead(inval->mbarrier()->as<TensorView>());

// Register life time end for a smem placeholder with tokens
// returned by MBarrierArriveExpectTx / MBarrierArrive
if (GpuLower::current()->tokenMBarrierMap().count(expr)) {
markRead(GpuLower::current()->tokenMBarrierMap()[expr]);
}
}
}

Expand Down
9 changes: 7 additions & 2 deletions csrc/device_lower/pass/allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -575,7 +575,6 @@ class AllocationInserter : public kir::ExprMutator {
.contiguity(true)
.build();
mbarrier->setMemoryType(MemoryType::Shared);
mbarrier->setMBarrierPlaceholder(true);

kir::Allocate* mbarrier_alloc =
IrBuilder::create<kir::Allocate>(mbarrier, MemoryType::Shared);
Expand Down Expand Up @@ -604,15 +603,21 @@ class AllocationInserter : public kir::ExprMutator {
registerInsertBefore(expr, mbarrier_init, expr_scope);
registerInsertAfter(expr, mbarrier_inval, expr_scope);
GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier;

GpuLower::current()->tokenMBarrierMap()[expr] = mbarrier_tokens;
// Resgiter tokens placeholder for MBarrierInit and MBarrierInvalidate,
// needed to manage life time of smem buffor in alias memory
GpuLower::current()->tokenMBarrierMap()[mbarrier_init] =
mbarrier_tokens;
GpuLower::current()->tokenMBarrierMap()[mbarrier_inval] =
mbarrier_tokens;
} else {
TensorView* mbarrier = TensorViewBuilder()
.shape(std::vector<int64_t>{})
.dtype(DataType::UInt)
.contiguity(true)
.build();
mbarrier->setMemoryType(MemoryType::Shared);
mbarrier->setMBarrierPlaceholder(true);
auto mbarrier_init = IrBuilder::create<kir::MBarrierInit>(
mbarrier, expr->container()->oneVal(DataType::UInt32));
auto mbarrier_inval =
Expand Down
1 change: 1 addition & 0 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1325,6 +1325,7 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
SimplifyingIrBuilder::maybeCastExpr(DataType::UInt32, expect_bytes);
// set proper placeholder for MBarrierArriveExpectTx token
if (GpuLower::current()->tokenMBarrierMap().count(ldst)) {
NVF_ERROR(false, "Using tokens placeholder is not implemented");
auto mbarrier_tokens = GpuLower::current()->ldstMBarrierMap().at(ldst);
state = lower_utils::u32IndexScalarSmemTv(mbarrier_tokens);
} else {
Expand Down
8 changes: 0 additions & 8 deletions csrc/ir/interface_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -413,13 +413,6 @@ class TensorView : public Val {
return circular_buffer_stage_;
}

void setMBarrierPlaceholder(bool state) {
is_mbarrier_placeholder_ = state;
}
bool isMBarrierPlaceholder() const {
return is_mbarrier_placeholder_;
}

//! Transforms the innermost iterdomains according to the given mma swizzle,
//! this should be used on the tvs that are either inputs/outputs of an
//! MmaOp, or any tv's that are involved in prolog/epilog fusions and need to
Expand Down Expand Up @@ -583,7 +576,6 @@ class TensorView : public Val {
unsigned int max_producer_pos_ = 0;
MemoryType memory_type_ = MemoryType::Local;
bool is_double_buffered_ = false;
bool is_mbarrier_placeholder_ = false;

//! Indicates if the tensor is circular buffered.
bool is_circular_buffered_ = false;
Expand Down

0 comments on commit 15a17a6

Please sign in to comment.