Skip to content

Commit

Permalink
Pop the pack items by round robin to make sure the remaining items will be
Browse files Browse the repository at this point in the history
popped in the same round of MFMA.
  • Loading branch information
vin-huang committed Nov 18, 2024
1 parent 3ce620f commit 2a5d3cc
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,34 +879,45 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))

while packAItems:
#pop the packItems: A, Metadata and B by round robin.
while packAItems or packBItems or packMItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
for n in range(instPerPackA):
packINtems[0].append(packAItems.pop(0))
if packAItems:
packINtems[0].append(packAItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedA):
for n in range(instPerPackA):
packINtems[j].append(packAItems.pop(0))
if packAItems:
packINtems[j].append(packAItems.pop(0))
else:
break

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
while packMItems:
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(self.states.numReadsIterCoalescedMetadata):
for n in range(ceil(instPerPackM)):
if packMItems:
packINtems[j].append(packMItems.pop(0))
else:
break

while packBItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
for n in range(instPerPackB):
packINtems[0].append(packBItems.pop(0))
if packBItems:
packINtems[0].append(packBItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedB):
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))
if packBItems:
packINtems[j].append(packBItems.pop(0))
else:
break;

for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)):
for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)):
packItems += packINtems.pop(0)

# remove s_nop for packing
Expand Down

0 comments on commit 2a5d3cc

Please sign in to comment.