Skip to content

Commit

Permalink
Refactoy the pack scheduling for scheduleIterAlg = 3.
Browse files Browse the repository at this point in the history
 * Used 3 different pack pools to store the pack instructions of A, B, and Metadata
 * First, only put the required pack into the code (the number of required packs may differ for each mfma iteration).
   Second, put another pack or SNop before the mfma instruction according to the needed latency. the combination of insertion may be 2 packs, 1 pack + snop 0, or snop 1.
  • Loading branch information
vin-huang committed Nov 21, 2024
1 parent 3ce620f commit 14bd624
Showing 1 changed file with 116 additions and 47 deletions.
163 changes: 116 additions & 47 deletions tensilelite/Tensile/KernelWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
packBIdx = 0
packMIdx = 0

numPackedA = 0
numPackedB = 0
numPackedM = 0
#####
# Prepare localReadCode
####
Expand Down Expand Up @@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe

instPerPackM = 0
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]:
instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1

packItems = []
instPerPackM = 1
if self.states.lrvwTileMetadata > 1:
if kernel["MIInputPerThreadMetadata"] == 1:
instPerPackM = 1.5
elif kernel["MIInputPerThreadMetadata"] == 4:
instPerPackM = 3
packItemsA = []
packItemsB = []
packItemsM = []
for iui in range(kernel["InnerUnroll"]):
packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
packA = packCode.findNamedItem("packA_I%s"%(iui))
packB = packCode.findNamedItem("packB_I%s"%(iui))
packM = packCode.findNamedItem("packMetadata_I%s"%(iui))
Expand All @@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
if packAItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
for n in range(instPerPackA):
packINtems[0].append(packAItems.pop(0))
packINtemsA[0].append(packAItems.pop(0))
else:
for j in range(self.states.numReadsIterCoalescedA):
for n in range(instPerPackA):
packINtems[j].append(packAItems.pop(0))
packINtemsA[j].append(packAItems.pop(0))

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(self.states.numReadsIterCoalescedMetadata):
for n in range(ceil(instPerPackM)):
if packMItems:
packINtems[j].append(packMItems.pop(0))
packINtemsM[j].append(packMItems.pop(0))
else:
break

if packBItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
for n in range(instPerPackB):
packINtems[0].append(packBItems.pop(0))
packINtemsB[0].append(packBItems.pop(0))
else:
for j in range(self.states.numReadsIterCoalescedB):
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))
packINtemsB[j].append(packBItems.pop(0))

while packAItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
for n in range(instPerPackA):
packINtems[0].append(packAItems.pop(0))
if packAItems:
packINtemsA[0].append(packAItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedA):
for n in range(instPerPackA):
packINtems[j].append(packAItems.pop(0))
if packAItems:
packINtemsA[j].append(packAItems.pop(0))
else:
break

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
while packMItems:
while packMItems:
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(self.states.numReadsIterCoalescedMetadata):
for n in range(ceil(instPerPackM)):
if packMItems:
packINtems[j].append(packMItems.pop(0))
packINtemsM[j].append(packMItems.pop(0))
else:
break

while packBItems:
if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
for n in range(instPerPackB):
packINtems[0].append(packBItems.pop(0))
if packBItems:
packINtemsB[0].append(packBItems.pop(0))
else:
break
else:
for j in range(self.states.numReadsIterCoalescedB):
for n in range(instPerPackB):
packINtems[j].append(packBItems.pop(0))
if packBItems:
packINtemsB[j].append(packBItems.pop(0))
else:
break

for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)):
packItems += packINtems.pop(0)
for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)):
packItemsA += packINtemsA.pop(0)
packItemsB += packINtemsB.pop(0)
packItemsM += packINtemsM.pop(0)

# remove s_nop for packing
# we will add s_nop if needed
Expand Down Expand Up @@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
numLocalReadShouldSchedule = 0
# prefetch load for next wave tile along M since we re-use B first.
tileM: int = kernel["MIWaveTileA"]
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems
instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB
localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True)

for lr in localReadItemsThisLoop:
Expand Down Expand Up @@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))]
## To support do["MAC"] is False
mfma = [mfmas[i],] if len(mfmas) > 0 else []
instsToCheck = mfma + packItems
instsToCheck = mfma + packItemsA + packItemsM + packItemsB
numDsInsts = 0
lastLgkmCnt = -1
for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())):
Expand Down Expand Up @@ -1271,18 +1296,25 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
####
# scheduled pack
####
if packItems:
_instPerPackA = 0
_instPerPackB = 0
_instPerPackM = 0
if packItemsA or packItemsB or packItemsM:
# how many pack have to be done
# calculate the data index of this mfma used for A and B
# if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account)
# if i % kernel["MIWaveTile"][0]==0, mfma will use new B
packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
_instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
packAIdx += _instPerPackA
_instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
packBIdx += _instPerPackB
if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
if kernel["ProblemType"]["Sparse"] == 2:
packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
_instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
else:
packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
_instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0

packMIdx += _instPerPackM
# blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8)
if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]:
packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0
Expand All @@ -1298,33 +1330,70 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx))
else:
iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx))
# we put 2 pack in each mfma
for j in range(instPerPackA):
if packItems:
iterCode.add(packItems.pop(0))

# put the required pack into mfma
for j in range(_instPerPackA):
if packItemsA:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedA >= packAIdx:
break
iterCode.add(packItemsA.pop(0))
curPackIdx += 1
numPackedA += 1

if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
for j in range(ceil(instPerPackM)):
if packItems:
iterCode.add(packItems.pop(0))
for j in range(ceil(_instPerPackM)):
if packItemsM:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedM >= packMIdx:
break
iterCode.add(packItemsM.pop(0))
curPackIdx += 1
for j in range(instPerPackB):
if packItems:
iterCode.add(packItems.pop(0))
curPackIdx += 1
# since packed register need to wait 2 quad cycle to finish packing
# we insert pack instruction if we can, or s_nop
while curPackIdx < numPack+2:
if packItems:
iterCode.add(packItems.pop(0))
curPackIdx += 1
else:
iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction"))
numPackedM += 1

for j in range(_instPerPackB):
if packItemsB:
# Skip if the required pack has already been placed in the previous mfma iter.
if numPackedB >= packBIdx:
break
iterCode.add(packItemsB.pop(0))
curPackIdx += 1
break
numPackedB += 1

# put unnecessary pack into mfma to fulfill the latency
remainLatency = 2
if curPackIdx < numPack + 2 :
# since packed register need to wait 2 quad cycle to finish packing
# we insert pack instruction if we can, or s_nop
while remainLatency:
if packItemsA:
iterCode.add(packItemsA.pop(0))
curPackIdx += 1
numPackedA += 1
remainLatency -= 1
elif packItemsM:
iterCode.add(packItemsM.pop(0))
curPackIdx += 1
numPackedM += 1
remainLatency -= 1
elif packItemsB:
iterCode.add(packItemsB.pop(0))
curPackIdx += 1
numPackedB += 1
remainLatency -= 1
else:
latency = remainLatency - 1
iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction"))
curPackIdx += 1
remainLatency -= (latency+1)

if i == numMfmaPerIter - 1:
while packItems:
iterCode.add(packItems.pop(0))
while packItemsA:
iterCode.add(packItemsA.pop(0))
while packItemsM:
iterCode.add(packItemsM.pop(0))
while packItemsB:
iterCode.add(packItemsB.pop(0))

####
# scheduled mfma dependency
Expand Down

0 comments on commit 14bd624

Please sign in to comment.