Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TMA support for circular buffering pass #2833

Merged
merged 40 commits into from
Sep 29, 2024
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
12db3ee
Add allocation changes
rdspring1 Aug 21, 2024
2491171
Add Indexing changes
rdspring1 Aug 21, 2024
6d8ad5f
Add circular buffering pass and testing
rdspring1 Aug 22, 2024
25c482d
Merge branch 'main' of https://github.com/nvidia/fuser into tma_cb
rdspring1 Sep 8, 2024
2f8d9e9
predicate and mbarrier changes
rdspring1 Sep 5, 2024
2a06157
add mbarrier_wait immediately
rdspring1 Sep 7, 2024
0c8858f
skip expressions_allocated_in_main_loop
rdspring1 Sep 5, 2024
f8123af
Ensure a full warp exists if there is elect sync predicate
rdspring1 Sep 9, 2024
508d674
comments≈
rdspring1 Sep 9, 2024
d4c7938
Merge branch 'main' of https://github.com/nvidia/fuser into tma_cb
rdspring1 Sep 9, 2024
ccfedfc
Add compatibility check for elect sync
rdspring1 Sep 16, 2024
f29aa22
add test for elect sync compatibility
rdspring1 Sep 16, 2024
f685a9b
Use MBarrierArrive
rdspring1 Sep 16, 2024
b84eb96
comments
rdspring1 Sep 17, 2024
c1fdec5
Add has_elect_sync_predicate to kernel_summary
rdspring1 Sep 17, 2024
4f011b5
Merge branch 'main' of https://github.com/nvidia/fuser into tma_cb
rdspring1 Sep 18, 2024
d132ef4
minor fixes
rdspring1 Sep 21, 2024
c393776
use inlineAt and inlineMost
rdspring1 Sep 22, 2024
dc9f20e
add string exception check
rdspring1 Sep 22, 2024
4058d73
clean-up
rdspring1 Sep 22, 2024
4cb4342
comment
rdspring1 Sep 23, 2024
cdbf609
comment
rdspring1 Sep 23, 2024
b6c4e20
comment
rdspring1 Sep 23, 2024
b137637
comment
rdspring1 Sep 23, 2024
c319e5c
generalize short-circuit
rdspring1 Sep 23, 2024
e8c7fd5
comment
rdspring1 Sep 23, 2024
95e7bd0
use scalar hoisting
rdspring1 Sep 23, 2024
89b61bb
rename
rdspring1 Sep 23, 2024
444252d
Merge branch 'main' into tma_cb
rdspring1 Sep 25, 2024
64ef3cb
Create TmaCircularBufferInfo to consolidate data fields. (#3004)
rdspring1 Sep 25, 2024
7ebffe3
Initialize and invalidate mbarrier in allocation pass
rdspring1 Sep 26, 2024
508dbb0
comments
rdspring1 Sep 26, 2024
7ee80d3
move to allocation pass
rdspring1 Sep 26, 2024
1f5bed1
create TmaCircularBufferInfo class
rdspring1 Sep 26, 2024
95cbba1
Merge branch 'main' into tma_cb
rdspring1 Sep 26, 2024
0a6abcd
rename CloneTmaCircularBufferLoopAndInsertSync
rdspring1 Sep 27, 2024
b9cc784
comments
rdspring1 Sep 27, 2024
afb4e1c
Add PointwiseCpAsync failing test
rdspring1 Sep 27, 2024
3bcc32c
Merge branch 'main' into tma_cb
rdspring1 Sep 28, 2024
5df582a
Merge branch 'main' into tma_cb
rdspring1 Sep 29, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions csrc/device_lower/pass/allocation.cpp
rdspring1 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -483,10 +483,19 @@ class AllocationInserter : public kir::ExprMutator {
.build();
mbarrier->setMemoryType(MemoryType::Shared);

// The wait condition for mbarrier is a single thread and the expected
// number of transaction bytes
kir::MBarrierInit* mbarrier_init = IrBuilder::create<kir::MBarrierInit>(
mbarrier, expr->container()->oneVal(DataType::UInt32));
// Get all threads in CTA
NamedScalar* bdimx = NamedScalar::getParallelDim(ParallelType::TIDx);
NamedScalar* bdimy = NamedScalar::getParallelDim(ParallelType::TIDy);
NamedScalar* bdimz = NamedScalar::getParallelDim(ParallelType::TIDz);
rdspring1 marked this conversation as resolved.
Show resolved Hide resolved
Val* all_threads_in_cta = SimplifyingIrBuilder::mulExpr(
bdimx, SimplifyingIrBuilder::mulExpr(bdimy, bdimz));
all_threads_in_cta = SimplifyingIrBuilder::maybeCastExpr(
DataType::UInt32, all_threads_in_cta);

// The wait condition for mbarrier is a all participating threads in CTA
// and the expected number of transaction bytes
kir::MBarrierInit* mbarrier_init =
IrBuilder::create<kir::MBarrierInit>(mbarrier, all_threads_in_cta);

kir::Allocate* mbarrier_alloc =
IrBuilder::create<kir::Allocate>(mbarrier, MemoryType::Shared);
Expand Down
Loading
Loading