Skip to content

Commit

Permalink
Refactor the pack scheduling for scheduleIterAlg = 3.
Browse files Browse the repository at this point in the history
 * Use three separate pack pools to store the pack instructions for A, B, and Metadata.
 * First, put only the required packs into the code (the number of required packs may differ for each mfma iteration).
   Second, insert additional packs or an SNop before the mfma instruction to cover the required latency. The inserted combination may be 2 packs, 1 pack + snop 0, or snop 1.
  • Loading branch information
vin-huang committed Nov 24, 2024
1 parent 780a5a8 commit 1a0646c
Showing 1 changed file with 138 additions and 40 deletions.
178 changes: 138 additions & 40 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
packBIdx = 0
packMIdx = 0

numPackedA = 0
numPackedB = 0
numPackedM = 0
#####
# Prepare localReadCode
####
Expand Down Expand Up @@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe

instPerPackM = 0
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]:
instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1

packItems = []
instPerPackM = 1
if self.states.lrvwTileMetadata > 1:
if kernel["MIInputPerThreadMetadata"] == 1:
instPerPackM = 1.5
elif kernel["MIInputPerThreadMetadata"] == 4:
instPerPackM = 3
packItemsA = []
packItemsB = []
packItemsM = []
for iui in range(kernel["InnerUnroll"]):
packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packA = packCode.findNamedItem("packA_I%s"%(iui))
packB = packCode.findNamedItem("packB_I%s"%(iui))
packM = packCode.findNamedItem("packMetadata_I%s"%(iui))
Expand All @@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
if packAItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
for n in range(instPerPackA):
packINtems[0].append(packAItems.pop(0))
packINtemsA[0].append(packAItems.pop(0))
else:
for j in range(self.states.numReadsIterCoalescedA):
for n in range(instPerPackA):
packINtems[j].append(packAItems.pop(0))
packINtemsA[j].append(packAItems.pop(0))

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(self.states.numReadsIterCoalescedMetadata):
for n in range(ceil(instPerPackM)):
if packMItems:
packINtems[j].append(packMItems.pop(0))
packINtemsM[j].append(packMItems.pop(0))
else:
break

if packBItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
for n in range(instPerPackB):
packINtems[0].append(packBItems.pop(0))
packINtemsB[0].append(packBItems.pop(0))
else:
for j in range(self.states.numReadsIterCoalescedB):
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))
packINtemsB[j].append(packBItems.pop(0))

while packAItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
for n in range(instPerPackA):
packINtems[0].append(packAItems.pop(0))
if packAItems:
packINtemsA[0].append(packAItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedA):
for n in range(instPerPackA):
packINtems[j].append(packAItems.pop(0))
if packAItems:
packINtemsA[j].append(packAItems.pop(0))
else:
break

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
while packMItems:
for j in range(self.states.numReadsIterCoalescedMetadata):
for n in range(ceil(instPerPackM)):
if packMItems:
packINtems[j].append(packMItems.pop(0))
packINtemsM[j].append(packMItems.pop(0))
else:
break

while packBItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
for n in range(instPerPackB):
packINtems[0].append(packBItems.pop(0))
if packBItems:
packINtemsB[0].append(packBItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedB):
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))
if packBItems:
packINtemsB[j].append(packBItems.pop(0))
else:
break

for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)):
packItems += packINtems.pop(0)
for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)):
packItemsA += packINtemsA.pop(0)
packItemsB += packINtemsB.pop(0)
packItemsM += packINtemsM.pop(0)

# remove s_nop for packing
# we will add s_nop if needed
Expand Down Expand Up @@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
numLocalReadShouldSchedule = 0
# prefetch load for next wave tile along M since we re-use B first.
tileM: int = kernel["MIWaveTileA"]
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB
localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True)

for lr in localReadItemsThisLoop:
Expand Down Expand Up @@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))]
## To support do["MAC"] is False
mfma = [mfmas[i],] if len(mfmas) > 0 else []
instsToCheck = mfma + packItems
instsToCheck = mfma + packItemsA + packItemsM + packItemsB
numDsInsts = 0
lastLgkmCnt = -1
for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())):
Expand Down Expand Up @@ -1271,18 +1296,28 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
####
# scheduled pack
####
if packItems:
_instPerPackA = 0
_instPerPackB = 0
_instPerPackM = 0
instPerPackA_r = instPerPackA
instPerPackB_r = instPerPackB
instPerPackM_r = instPerPackM
if packItemsA or packItemsB or packItemsM:
# how many pack have to be done
# calculate the data index of this mfma used for A and B
# if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account)
# if i % kernel["MIWaveTile"][0]==0, mfma will use new B
packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
_instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
packAIdx += _instPerPackA
_instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
packBIdx += _instPerPackB
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
if kernel["ProblemType"]["Sparse"] == 2:
packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
_instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
else:
packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
_instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0

packMIdx += _instPerPackM
# blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8)
if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]:
packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0
Expand All @@ -1298,33 +1333,96 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx))
else:
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx))
# we put 2 pack in each mfma
for j in range(instPerPackA):
if packItems:
iterCode.add(packItems.pop(0))

# 1st round
# put the required pack into mfma iter
for j in range(_instPerPackA):
if packItemsA:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedA >= packAIdx:
break
iterCode.add(packItemsA.pop(0))
curPackIdx += 1
numPackedA += 1
instPerPackA_r -= 1

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(ceil(instPerPackM)):
if packItems:
iterCode.add(packItems.pop(0))
for j in range(ceil(_instPerPackM)):
if packItemsM:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedM >= packMIdx:
break
iterCode.add(packItemsM.pop(0))
curPackIdx += 1
for j in range(instPerPackB):
if packItems:
iterCode.add(packItems.pop(0))
numPackedM += 1
instPerPackM_r -= 1

for j in range(_instPerPackB):
if packItemsB:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedB >= packBIdx:
break
iterCode.add(packItemsB.pop(0))
curPackIdx += 1
numPackedB += 1
instPerPackB_r -= 1

# 2nd round
# put the desired pack into mfma iter
if instPerPackA_r > 0:
for j in range(instPerPackA_r):
if packItemsA:
iterCode.add(packItemsA.pop(0))
curPackIdx += 1
numPackedA += 1

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and instPerPackM_r > 0:
for j in range(ceil(instPerPackM_r)):
if packItemsM:
iterCode.add(packItemsM.pop(0))
curPackIdx += 1
numPackedM += 1

if instPerPackB_r > 0:
for j in range(instPerPackB_r):
if packItemsB:
iterCode.add(packItemsB.pop(0))
curPackIdx += 1
numPackedB += 1

remainLatency = 2 if curPackIdx < numPack + 2 else 0

# since a packed register needs to wait 2 quad cycles to finish packing,
# we insert pack instructions if we can, or an s_nop otherwise
while curPackIdx < numPack+2:
if packItems:
iterCode.add(packItems.pop(0))
while remainLatency > 0:
if packItemsA:
iterCode.add(packItemsA.pop(0))
curPackIdx += 1
numPackedA += 1
remainLatency -= 1
elif packItemsM:
iterCode.add(packItemsM.pop(0))
curPackIdx += 1
numPackedM += 1
remainLatency -= 1
elif packItemsB:
iterCode.add(packItemsB.pop(0))
curPackIdx += 1
numPackedB += 1
remainLatency -= 1
else:
iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction"))
latency = remainLatency - 1
iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction"))
curPackIdx += 1
break
remainLatency -= (latency+1)

if i == numMfmaPerIter - 1:
while packItems:
iterCode.add(packItems.pop(0))
while packItemsA:
iterCode.add(packItemsA.pop(0))
while packItemsM:
iterCode.add(packItemsM.pop(0))
while packItemsB:
iterCode.add(packItemsB.pop(0))

####
# scheduled mfma dependency
Expand Down

0 comments on commit 1a0646c

Please sign in to comment.