From 1a0646cce1a9ebede3dcc4e5f8b7339c570ebefd Mon Sep 17 00:00:00 2001 From: Vin Huang Date: Fri, 15 Nov 2024 10:03:28 +0800 Subject: [PATCH] Refactor the pack scheduling for scheduleIterAlg = 3. * Used 3 different pack pools to store the pack instructions of A, B, and Metadata. * First, only put the required packs into the code (the number of required packs may differ for each mfma iteration). Second, put another pack or SNop before the mfma instruction according to the needed latency. The combination of insertions may be 2 packs, 1 pack + snop 0, or snop 1. --- tensilelite/Tensile/KernelWriter.py | 178 +++++++++++++++++++++------- 1 file changed, 138 insertions(+), 40 deletions(-) diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index 5af14ca0c3..fec67efc6e 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe packBIdx = 0 packMIdx = 0 + numPackedA = 0 + numPackedB = 0 + numPackedM = 0 ##### # Prepare localReadCode #### @@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe instPerPackM = 0 if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]: - instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1 - - packItems = [] + instPerPackM = 1 + if self.states.lrvwTileMetadata > 1: + if kernel["MIInputPerThreadMetadata"] == 1: + instPerPackM = 1.5 + elif kernel["MIInputPerThreadMetadata"] == 4: + instPerPackM = 3 + packItemsA = [] + packItemsB = [] + packItemsM = [] for iui in range(kernel["InnerUnroll"]): - packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsA = [ [] for j in 
range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] packA = packCode.findNamedItem("packA_I%s"%(iui)) packB = packCode.findNamedItem("packB_I%s"%(iui)) packM = packCode.findNamedItem("packMetadata_I%s"%(iui)) @@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe if packAItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + packINtemsA[0].append(packAItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + packINtemsA[j].append(packAItems.pop(0)) if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break if packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + packINtemsB[0].append(packBItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + packINtemsB[j].append(packBItems.pop(0)) while packAItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + if packAItems: + packINtemsA[0].append(packAItems.pop(0)) + else: + break 
else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + if packAItems: + packINtemsA[j].append(packAItems.pop(0)) + else: + break if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: while packMItems: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break while packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + if packBItems: + packINtemsB[0].append(packBItems.pop(0)) + else: + break else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + if packBItems: + packINtemsB[j].append(packBItems.pop(0)) + else: + break - for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)): - packItems += packINtems.pop(0) + for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)): + packItemsA += packINtemsA.pop(0) + packItemsB += packINtemsB.pop(0) + packItemsM += packINtemsM.pop(0) # remove s_nop for packing # we will add s_nop if needed @@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): numLocalReadShouldSchedule = 0 # prefetch load for next wave tile along M since we re-use B first. 
tileM: int = kernel["MIWaveTileA"] - instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems + instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True) for lr in localReadItemsThisLoop: @@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))] ## To support do["MAC"] is False mfma = [mfmas[i],] if len(mfmas) > 0 else [] - instsToCheck = mfma + packItems + instsToCheck = mfma + packItemsA + packItemsM + packItemsB numDsInsts = 0 lastLgkmCnt = -1 for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())): @@ -1271,18 +1296,28 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): #### # scheduled pack #### - if packItems: + _instPerPackA = 0 + _instPerPackB = 0 + _instPerPackM = 0 + instPerPackA_r = instPerPackA + instPerPackB_r = instPerPackB + instPerPackM_r = instPerPackM + if packItemsA or packItemsB or packItemsM: # how many pack have to be done # calculate the data index of this mfma used for A and B # if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account) # if i % kernel["MIWaveTile"][0]==0, mfma will use new B - packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 - packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + packAIdx += _instPerPackA + _instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + packBIdx 
+= _instPerPackB if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: if kernel["ProblemType"]["Sparse"] == 2: - packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 else: - packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + _instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + + packMIdx += _instPerPackM # blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8) if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]: packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0 @@ -1298,33 +1333,96 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx)) else: iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx)) - # we put 2 pack in each mfma - for j in range(instPerPackA): - if packItems: - iterCode.add(packItems.pop(0)) + + # 1st round + # put the required pack into mfma iter + for j in range(_instPerPackA): + if packItemsA: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedA >= packAIdx: + break + iterCode.add(packItemsA.pop(0)) curPackIdx += 1 + numPackedA += 1 + instPerPackA_r -= 1 + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: - for j in range(ceil(instPerPackM)): - if packItems: - iterCode.add(packItems.pop(0)) + for j in range(ceil(_instPerPackM)): + if packItemsM: + # Skip if the required pack has already been placed in the previous mfma iter. 
+ if numPackedM >= packMIdx: + break + iterCode.add(packItemsM.pop(0)) curPackIdx += 1 - for j in range(instPerPackB): - if packItems: - iterCode.add(packItems.pop(0)) + numPackedM += 1 + instPerPackM_r -= 1 + + for j in range(_instPerPackB): + if packItemsB: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedB >= packBIdx: + break + iterCode.add(packItemsB.pop(0)) curPackIdx += 1 + numPackedB += 1 + instPerPackB_r -= 1 + + # 2nd round + # put the desired pack into mfma iter + if instPerPackA_r > 0: + for j in range(instPerPackA_r): + if packItemsA: + iterCode.add(packItemsA.pop(0)) + curPackIdx += 1 + numPackedA += 1 + + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and instPerPackM_r > 0: + for j in range(ceil(instPerPackM_r)): + if packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + + if instPerPackB_r > 0: + for j in range(instPerPackB_r): + if packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + + remainLatency = 2 if curPackIdx < numPack + 2 else 0 + # since packed register need to wait 2 quad cycle to finish packing # we insert pack instruction if we can, or s_nop - while curPackIdx < numPack+2: - if packItems: - iterCode.add(packItems.pop(0)) + while remainLatency > 0: + if packItemsA: + iterCode.add(packItemsA.pop(0)) curPackIdx += 1 + numPackedA += 1 + remainLatency -= 1 + elif packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + remainLatency -= 1 + elif packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + remainLatency -= 1 else: - iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction")) + latency = remainLatency - 1 + iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction")) curPackIdx += 1 - break + remainLatency -= (latency+1) + if i == numMfmaPerIter - 1: - 
while packItems: - iterCode.add(packItems.pop(0)) + while packItemsA: + iterCode.add(packItemsA.pop(0)) + while packItemsM: + iterCode.add(packItemsM.pop(0)) + while packItemsB: + iterCode.add(packItemsB.pop(0)) #### # scheduled mfma dependency