From 88d740e902cf2d6dfae517159ee8d4163708b96b Mon Sep 17 00:00:00 2001 From: Vin Huang Date: Fri, 15 Nov 2024 10:03:28 +0800 Subject: [PATCH] Refactor the pack scheduling for scheduleIterAlg = 3. * Used 3 different pack pools to store the pack instructions of A, B, and Metadata * Step 1, only put the required pack into the code (the number of required packs may differ for each mfma iteration). Check whether the inserted packs fulfill instPerPack; if not, insert the next pack instructions until satisfied. Step 2, if there is still room before the mfma, then insert the next pack instructions (#instPerPack) Step 3, put another pack or SNop before the mfma instruction according to the needed latency. The combination of insertions may be 2 packs, 1 pack + snop 0, or snop 1. --- tensilelite/Tensile/KernelWriter.py | 250 ++++++++++++++++++++++------ 1 file changed, 203 insertions(+), 47 deletions(-) diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index 5af14ca0c3..1e8fe17677 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe packBIdx = 0 packMIdx = 0 + numPackedA = 0 + numPackedB = 0 + numPackedM = 0 ##### # Prepare localReadCode #### @@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe instPerPackM = 0 if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]: - instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1 - - packItems = [] + instPerPackM = 1 + if self.states.lrvwTileMetadata > 1: + if kernel["MIInputPerThreadMetadata"] == 1: + instPerPackM = 1.5 + elif kernel["MIInputPerThreadMetadata"] == 4: + instPerPackM = 3 + packItemsA = [] + packItemsB = [] + packItemsM = [] for iui in range(kernel["InnerUnroll"]): - packINtems = [ [] for j 
in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] packA = packCode.findNamedItem("packA_I%s"%(iui)) packB = packCode.findNamedItem("packB_I%s"%(iui)) packM = packCode.findNamedItem("packMetadata_I%s"%(iui)) @@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe if packAItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + packINtemsA[0].append(packAItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + packINtemsA[j].append(packAItems.pop(0)) if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break if packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + packINtemsB[0].append(packBItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + packINtemsB[j].append(packBItems.pop(0)) while packAItems: if kernel["ConvertAfterDS"] and 
kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + if packAItems: + packINtemsA[0].append(packAItems.pop(0)) + else: + break else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + if packAItems: + packINtemsA[j].append(packAItems.pop(0)) + else: + break if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: while packMItems: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break while packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + if packBItems: + packINtemsB[0].append(packBItems.pop(0)) + else: + break else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + if packBItems: + packINtemsB[j].append(packBItems.pop(0)) + else: + break - for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)): - packItems += packINtems.pop(0) + for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)): + packItemsA += packINtemsA.pop(0) + packItemsB += packINtemsB.pop(0) + packItemsM += packINtemsM.pop(0) # remove s_nop for packing # we will add s_nop if needed @@ -993,7 +1018,7 @@ def checkLocalReadFIFOFull(currentMFMA, fifo, lrItems, numLR, numLREven): return numToBeIssued oneBufferScheduling = kernel["1LDSBuffer"] or kernel["DirectToLdsA"] or kernel["DirectToLdsB"] - + def hasDependency(lr: DSLoadInstruction, inst: Instruction) -> bool: lrDataReg = lr.dst @@ -1009,6 +1034,10 @@ def hasDependency(lr: DSLoadInstruction, inst: Instruction) -> bool: def 
hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): return any(hasDependency(lr, inst) for inst in insts) + + insertedPackA = 0 + insertedPackB = 0 + insertedPackM = 0 for i in range(numMfmaPerIter): mfmaIndex = iteration * numMfmaPerIter + i insertInst = iterCode.countType(Instruction) @@ -1053,7 +1082,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): numLocalReadShouldSchedule = 0 # prefetch load for next wave tile along M since we re-use B first. tileM: int = kernel["MIWaveTileA"] - instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems + instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True) for lr in localReadItemsThisLoop: @@ -1241,7 +1270,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))] ## To support do["MAC"] is False mfma = [mfmas[i],] if len(mfmas) > 0 else [] - instsToCheck = mfma + packItems + instsToCheck = mfma + packItemsA + packItemsM + packItemsB numDsInsts = 0 lastLgkmCnt = -1 for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())): @@ -1271,18 +1300,29 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): #### # scheduled pack #### - if packItems: + _instPerPackA = 0 + _instPerPackB = 0 + _instPerPackM = 0 + if packItemsA or packItemsB or packItemsM: + + # check the remain latency before mfma + currentInsertInst = iterCode.countType(Instruction) - insertInst + latencyLeft = self.states.miLatencyLeft - currentInsertInst # how many pack have to be done # calculate the data index of this mfma used for A and B # if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into 
account) # if i % kernel["MIWaveTile"][0]==0, mfma will use new B - packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 - packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + packAIdx += _instPerPackA + _instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + packBIdx += _instPerPackB if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: if kernel["ProblemType"]["Sparse"] == 2: - packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 else: - packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + _instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + + packMIdx += _instPerPackM # blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8) if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]: packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0 @@ -1298,33 +1338,149 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx)) else: iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx)) - # we put 2 pack in each mfma - for j in range(instPerPackA): - if packItems: - iterCode.add(packItems.pop(0)) + + # Step 1 + # put the required pack into mfma iter + for j in range(_instPerPackA): + if packItemsA: + # Skip if the required 
pack has already been placed in the previous mfma iter. + if numPackedA >= packAIdx: + break + iterCode.add(packItemsA.pop(0)) curPackIdx += 1 + numPackedA += 1 + latencyLeft -= 1 + insertedPackA += 1 + + # check if the inserted pack instructions are fulfilled instPerPack. + # If not, insert the next pack instructions until satisfied. + # The unsatisfied is usually caused by Step 3 of the previous round. + instPackLeft = (insertedPackA % instPerPackA) if instPerPackA > 0 else 0 + instPackLeft = (instPerPackA - instPackLeft) if instPackLeft > 0 else 0 + if instPackLeft > 0: + insertedPackA = 0 + for j in range(instPackLeft): + if packItemsA: + iterCode.add(packItemsA.pop(0)) + curPackIdx += 1 + numPackedA += 1 + latencyLeft -= 1 + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: - for j in range(ceil(instPerPackM)): - if packItems: - iterCode.add(packItems.pop(0)) + # put the required pack into mfma iter + for j in range(ceil(_instPerPackM)): + if packItemsM: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedM >= packMIdx: + break + iterCode.add(packItemsM.pop(0)) curPackIdx += 1 - for j in range(instPerPackB): - if packItems: - iterCode.add(packItems.pop(0)) + numPackedM += 1 + latencyLeft -= 1 + insertedPackM += 1 + + # check if the inserted pack instructions are fulfilled instPerPack. + # If not, insert the next pack instructions until satisfied. + # The unsatisfied is usually caused by Step 3 of the previous round. 
+ instPackLeft = (insertedPackM % ceil(instPerPackM)) if instPerPackM > 0 else 0 + instPackLeft = (ceil(instPerPackM) - instPackLeft) if instPackLeft > 0 else 0 + if instPackLeft > 0: + insertedPackM = 0 + for j in range(instPackLeft): + if packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + latencyLeft -= 1 + + # put the required pack into mfma iter + for j in range(_instPerPackB): + if packItemsB: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedB >= packBIdx: + break + iterCode.add(packItemsB.pop(0)) curPackIdx += 1 + numPackedB += 1 + latencyLeft -= 1 + insertedPackB += 1 + + # check if the inserted pack instructions are fulfilled instPerPack. + # If not, insert the next pack instructions until satisfied. + # The unsatisfied is usually caused by Step 3 of the previous round. + instPackLeft = (insertedPackB % instPerPackB) if instPerPackB > 0 else 0 + instPackLeft = (instPerPackB - instPackLeft) if instPackLeft > 0 else 0 + if instPackLeft > 0: + insertedPackB = 0 + for j in range(instPackLeft): + if packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + latencyLeft -= 1 + + # Step 2 + # put the desired pack into mfma iter + if latencyLeft > 0: + for j in range(instPerPackA): + if packItemsA: + iterCode.add(packItemsA.pop(0)) + curPackIdx += 1 + numPackedA += 1 + latencyLeft -= 1 + + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and latencyLeft > 0: + for j in range(ceil(instPerPackM)): + if packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + latencyLeft -= 1 + + if latencyLeft > 0: + for j in range(instPerPackB): + if packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + latencyLeft -= 1 + + # Step 3 # since packed register need to wait 2 quad cycle to finish packing # we insert pack instruction if we can, or s_nop - while curPackIdx < numPack+2: - if packItems: 
- iterCode.add(packItems.pop(0)) + remainLatency = 2 if curPackIdx < numPack + 2 else 0 + while remainLatency > 0: + if packItemsA: + iterCode.add(packItemsA.pop(0)) + curPackIdx += 1 + numPackedA += 1 + remainLatency -= 1 + insertedPackA += 1 + elif packItemsM: + iterCode.add(packItemsM.pop(0)) curPackIdx += 1 + numPackedM += 1 + remainLatency -= 1 + insertedPackM += 1 + elif packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + remainLatency -= 1 + insertedPackB += 1 else: - iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction")) + latency = remainLatency - 1 + iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction")) curPackIdx += 1 - break + remainLatency -= (latency+1) + if i == numMfmaPerIter - 1: - while packItems: - iterCode.add(packItems.pop(0)) + while packItemsA: + iterCode.add(packItemsA.pop(0)) + while packItemsM: + iterCode.add(packItemsM.pop(0)) + while packItemsB: + iterCode.add(packItemsB.pop(0)) #### # scheduled mfma dependency @@ -2366,7 +2522,7 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): # Open persistent loop loopComponent = Component.PersistentLoop.find(self) module.add(loopComponent.openPersistentLoop(self, kernel)) - + module.add(self.setupNewTile(kernel, tensorParametersA, tensorParametersB, isOptNLL=False)) if self.do["executeToPrefetchEnd"]: @@ -2568,7 +2724,7 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ): # last NLL or pack DTV case, no deep copy for pack # pack code for local prefetch is generated in noLoadLoopBody and used for DTV even deepCopyPack = pack - else: + else: # deepCopy packCode for OptNLL noLoadLoop deepCopyPack = fastdeepcopy(pack) module.add(self.noLoadLoop(kernel, tensorParametersA, tensorParametersB, isOptNLL=False, isNGLL=False, pack=deepCopyPack, NLLindex=NLLindex, NLLnum=NLLnum)) @@ -2912,7 +3068,7 @@ def initKernel(self, kernel, 
tensorParametersA, tensorParametersB): self.states.asmCaps = self.ti.getAsmCaps() self.states.archCaps = self.ti.getArchCaps() self.states.regCaps = self.ti.getRegCaps() - + self.asmAssert = Assert(self.states.laneSGPRCount, kernel["WavefrontSize"], self.db["EnableAsserts"]) # Only assembly supports scheduling @@ -3264,7 +3420,7 @@ def readWriteVectors(mat, vw, kernel): self.states.bpeCexternal = self.states.bpeCexternalGSU1 if kernel["_GlobalAccumulation"] and kernel["_GlobalAccumulation"] != 'PartialsBuffer': self.states.bpeCexternal = self.states.bpeCinternal - + # special case for wmma h and b if (kernel["EnableMatrixInstruction"] @@ -4071,7 +4227,7 @@ def readWriteVectors(mat, vw, kernel): self.defineSgpr("AddressWS", numSgprAddressWS) self.defineSgpr("AddressFlags", numSgprAddressFlags) self.states.numSgprStreamK += numSgprAddressWS + numSgprAddressFlags - + #asm input interface depen self.defineSgpr("StridesD", self.states.d.numSgprStrides) self.defineSgpr("StridesC", self.states.c.numSgprStrides) @@ -4144,7 +4300,7 @@ def readWriteVectors(mat, vw, kernel): self.defineSgpr("StreamKLocalEnd", 1) if kernel["StreamKAtomic"] == 0: self.defineSgpr("SrdWS", 4, 4) - + #------------------------ # Registers defined below this point are not available in the post-loop # Post-loop is after tail loop exits, ie the store code.