diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py
index 5af14ca0c..1e8fe1767 100644
--- a/tensilelite/Tensile/KernelWriter.py
+++ b/tensilelite/Tensile/KernelWriter.py
@@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
     packBIdx = 0
     packMIdx = 0
+    numPackedA = 0
+    numPackedB = 0
+    numPackedM = 0
 
     #####
     # Prepare localReadCode
     ####
@@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
     instPerPackM = 0
     if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]:
-      instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1
-
-    packItems = []
+      instPerPackM = 1
+      if self.states.lrvwTileMetadata > 1:
+        if kernel["MIInputPerThreadMetadata"] == 1:
+          instPerPackM = 1.5
+        elif kernel["MIInputPerThreadMetadata"] == 4:
+          instPerPackM = 3
+    packItemsA = []
+    packItemsB = []
+    packItemsM = []
 
     for iui in range(kernel["InnerUnroll"]):
-      packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
+      packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
+      packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
+      packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ]
       packA = packCode.findNamedItem("packA_I%s"%(iui))
       packB = packCode.findNamedItem("packB_I%s"%(iui))
       packM = packCode.findNamedItem("packMetadata_I%s"%(iui))
@@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe
       if packAItems:
         if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
           for n in range(instPerPackA):
-            packINtems[0].append(packAItems.pop(0))
+            packINtemsA[0].append(packAItems.pop(0))
         else:
           for j in range(self.states.numReadsIterCoalescedA):
             for n in range(instPerPackA):
-              packINtems[j].append(packAItems.pop(0))
+              packINtemsA[j].append(packAItems.pop(0))
       if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
         for j in range(self.states.numReadsIterCoalescedMetadata):
           for n in range(ceil(instPerPackM)):
             if packMItems:
-              packINtems[j].append(packMItems.pop(0))
+              packINtemsM[j].append(packMItems.pop(0))
             else:
               break
       if packBItems:
         if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
           for n in range(instPerPackB):
-            packINtems[0].append(packBItems.pop(0))
+            packINtemsB[0].append(packBItems.pop(0))
         else:
           for j in range(self.states.numReadsIterCoalescedB):
             for n in range(instPerPackB):
-              packINtems[j].append(packBItems.pop(0))
+              packINtemsB[j].append(packBItems.pop(0))
       while packAItems:
         if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8():
           for n in range(instPerPackA):
-            packINtems[0].append(packAItems.pop(0))
+            if packAItems:
+              packINtemsA[0].append(packAItems.pop(0))
+            else:
+              break
         else:
           for j in range(self.states.numReadsIterCoalescedA):
             for n in range(instPerPackA):
-              packINtems[j].append(packAItems.pop(0))
+              if packAItems:
+                packINtemsA[j].append(packAItems.pop(0))
+              else:
+                break
       if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
         while packMItems:
           for j in range(self.states.numReadsIterCoalescedMetadata):
             for n in range(ceil(instPerPackM)):
               if packMItems:
-                packINtems[j].append(packMItems.pop(0))
+                packINtemsM[j].append(packMItems.pop(0))
               else:
                 break
       while packBItems:
         if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8():
           for n in range(instPerPackB):
-            packINtems[0].append(packBItems.pop(0))
+            if packBItems:
+              packINtemsB[0].append(packBItems.pop(0))
+            else:
+              break
         else:
           for j in range(self.states.numReadsIterCoalescedB):
             for n in range(instPerPackB):
-              packINtems[j].append(packBItems.pop(0))
+              if packBItems:
+                packINtemsB[j].append(packBItems.pop(0))
+              else:
+                break
 
-      for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)):
-        packItems += packINtems.pop(0)
+      for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)):
+        packItemsA += packINtemsA.pop(0)
+        packItemsB += packINtemsB.pop(0)
+        packItemsM += packINtemsM.pop(0)
 
     # remove s_nop for packing
     # we will add s_nop if needed
@@ -993,7 +1018,7 @@ def checkLocalReadFIFOFull(currentMFMA, fifo, lrItems, numLR, numLREven):
       return numToBeIssued
 
     oneBufferScheduling = kernel["1LDSBuffer"] or kernel["DirectToLdsA"] or kernel["DirectToLdsB"]
-    
+
     def hasDependency(lr: DSLoadInstruction, inst: Instruction) -> bool:
       lrDataReg = lr.dst
 
@@ -1009,6 +1034,10 @@ def hasDependency(lr: DSLoadInstruction, inst: Instruction) -> bool:
     def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
       return any(hasDependency(lr, inst) for inst in insts)
 
+
+    insertedPackA = 0
+    insertedPackB = 0
+    insertedPackM = 0
     for i in range(numMfmaPerIter):
       mfmaIndex = iteration * numMfmaPerIter + i
       insertInst = iterCode.countType(Instruction)
@@ -1053,7 +1082,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
             numLocalReadShouldSchedule = 0
             # prefetch load for next wave tile along M since we re-use B first.
             tileM: int = kernel["MIWaveTileA"]
-            instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems
+            instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB
             localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True)
             for lr in localReadItemsThisLoop:
@@ -1241,7 +1270,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
         mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))]
         ## To support do["MAC"] is False
         mfma = [mfmas[i],] if len(mfmas) > 0 else []
-        instsToCheck = mfma + packItems
+        instsToCheck = mfma + packItemsA + packItemsM + packItemsB
         numDsInsts = 0
         lastLgkmCnt = -1
         for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())):
@@ -1271,18 +1300,29 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
       ####
       # scheduled pack
       ####
-      if packItems:
+      _instPerPackA = 0
+      _instPerPackB = 0
+      _instPerPackM = 0
+      if packItemsA or packItemsB or packItemsM:
+
+        # check the remaining latency before the mfma
+        currentInsertInst = iterCode.countType(Instruction) - insertInst
+        latencyLeft = self.states.miLatencyLeft - currentInsertInst
         # how many pack have to be done
         # calculate the data index of this mfma used for A and B
         # if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account)
         # if i % kernel["MIWaveTile"][0]==0, mfma will use new B
-        packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
-        packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
+        _instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
+        packAIdx += _instPerPackA
+        _instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0
+        packBIdx += _instPerPackB
         if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
           if kernel["ProblemType"]["Sparse"] == 2:
-            packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
+            _instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0
           else:
-            packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
+            _instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0
+
+          packMIdx += _instPerPackM
         # blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8)
         if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]:
           packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0
@@ -1298,33 +1338,149 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]):
           iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx))
         else:
           iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx))
-        # we put 2 pack in each mfma
-        for j in range(instPerPackA):
-          if packItems:
-            iterCode.add(packItems.pop(0))
+
+        # Step 1
+        # put the required pack into mfma iter
+        for j in range(_instPerPackA):
+          if packItemsA:
+            # Skip if the required pack has already been placed in the previous mfma iter.
+            if numPackedA >= packAIdx:
+              break
+            iterCode.add(packItemsA.pop(0))
             curPackIdx += 1
+            numPackedA += 1
+            latencyLeft -= 1
+            insertedPackA += 1
+
+        # check whether the inserted pack instructions fulfill instPerPack.
+        # If not, insert further pack instructions until it is satisfied.
+        # The shortfall is usually caused by Step 3 of the previous round.
+        instPackLeft = (insertedPackA % instPerPackA) if instPerPackA > 0 else 0
+        instPackLeft = (instPerPackA - instPackLeft) if instPackLeft > 0 else 0
+        if instPackLeft > 0:
+          insertedPackA = 0
+          for j in range(instPackLeft):
+            if packItemsA:
+              iterCode.add(packItemsA.pop(0))
+              curPackIdx += 1
+              numPackedA += 1
+              latencyLeft -= 1
+
         if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]:
-          for j in range(ceil(instPerPackM)):
-            if packItems:
-              iterCode.add(packItems.pop(0))
+          # put the required pack into mfma iter
+          for j in range(ceil(_instPerPackM)):
+            if packItemsM:
+              # Skip if the required pack has already been placed in the previous mfma iter.
+              if numPackedM >= packMIdx:
+                break
+              iterCode.add(packItemsM.pop(0))
               curPackIdx += 1
-        for j in range(instPerPackB):
-          if packItems:
-            iterCode.add(packItems.pop(0))
+              numPackedM += 1
+              latencyLeft -= 1
+              insertedPackM += 1
+
+          # check whether the inserted pack instructions fulfill instPerPack.
+          # If not, insert further pack instructions until it is satisfied.
+          # The shortfall is usually caused by Step 3 of the previous round.
+          instPackLeft = (insertedPackM % ceil(instPerPackM)) if instPerPackM > 0 else 0
+          instPackLeft = (ceil(instPerPackM) - instPackLeft) if instPackLeft > 0 else 0
+          if instPackLeft > 0:
+            insertedPackM = 0
+            for j in range(instPackLeft):
+              if packItemsM:
+                iterCode.add(packItemsM.pop(0))
+                curPackIdx += 1
+                numPackedM += 1
+                latencyLeft -= 1
+
+        # put the required pack into mfma iter
+        for j in range(_instPerPackB):
+          if packItemsB:
+            # Skip if the required pack has already been placed in the previous mfma iter.
+            if numPackedB >= packBIdx:
+              break
+            iterCode.add(packItemsB.pop(0))
             curPackIdx += 1
+            numPackedB += 1
+            latencyLeft -= 1
+            insertedPackB += 1
+
+        # check whether the inserted pack instructions fulfill instPerPack.
+        # If not, insert further pack instructions until it is satisfied.
+        # The shortfall is usually caused by Step 3 of the previous round.
+        instPackLeft = (insertedPackB % instPerPackB) if instPerPackB > 0 else 0
+        instPackLeft = (instPerPackB - instPackLeft) if instPackLeft > 0 else 0
+        if instPackLeft > 0:
+          insertedPackB = 0
+          for j in range(instPackLeft):
+            if packItemsB:
+              iterCode.add(packItemsB.pop(0))
+              curPackIdx += 1
+              numPackedB += 1
+              latencyLeft -= 1
+
+        # Step 2
+        # put the desired pack into mfma iter
+        if latencyLeft > 0:
+          for j in range(instPerPackA):
+            if packItemsA:
+              iterCode.add(packItemsA.pop(0))
+              curPackIdx += 1
+              numPackedA += 1
+              latencyLeft -= 1
+
+        if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and latencyLeft > 0:
+          for j in range(ceil(instPerPackM)):
+            if packItemsM:
+              iterCode.add(packItemsM.pop(0))
+              curPackIdx += 1
+              numPackedM += 1
+              latencyLeft -= 1
+
+        if latencyLeft > 0:
+          for j in range(instPerPackB):
+            if packItemsB:
+              iterCode.add(packItemsB.pop(0))
+              curPackIdx += 1
+              numPackedB += 1
+              latencyLeft -= 1
+
+        # Step 3
         # since packed register need to wait 2 quad cycle to finish packing
         # we insert pack instruction if we can, or s_nop
-        while curPackIdx < numPack+2:
-          if packItems:
-            iterCode.add(packItems.pop(0))
+        remainLatency = 2 if curPackIdx < numPack + 2 else 0
+        while remainLatency > 0:
+          if packItemsA:
+            iterCode.add(packItemsA.pop(0))
+            curPackIdx += 1
+            numPackedA += 1
+            remainLatency -= 1
+            insertedPackA += 1
+          elif packItemsM:
+            iterCode.add(packItemsM.pop(0))
            curPackIdx += 1
+            numPackedM += 1
+            remainLatency -= 1
+            insertedPackM += 1
+          elif packItemsB:
+            iterCode.add(packItemsB.pop(0))
+            curPackIdx += 1
+            numPackedB += 1
+            remainLatency -= 1
+            insertedPackB += 1
           else:
-            iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction"))
+            latency = remainLatency - 1
+            iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction"))
            curPackIdx += 1
-            break
+            remainLatency -= (latency+1)
+
        if i == numMfmaPerIter - 1:
-          while packItems:
-            iterCode.add(packItems.pop(0))
+          while packItemsA:
+            iterCode.add(packItemsA.pop(0))
+          while packItemsM:
+            iterCode.add(packItemsM.pop(0))
+          while packItemsB:
+            iterCode.add(packItemsB.pop(0))
 
       ####
       # scheduled mfma dependency
@@ -2366,7 +2522,7 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
     # Open persistent loop
     loopComponent = Component.PersistentLoop.find(self)
     module.add(loopComponent.openPersistentLoop(self, kernel))
-    
+
     module.add(self.setupNewTile(kernel, tensorParametersA, tensorParametersB, isOptNLL=False))
 
     if self.do["executeToPrefetchEnd"]:
@@ -2568,7 +2724,7 @@ def kernelBody( self, kernel, tensorParametersA, tensorParametersB ):
            # last NLL or pack DTV case, no deep copy for pack
            # pack code for local prefetch is generated in noLoadLoopBody and used for DTV even
            deepCopyPack = pack
-          else:
+          else: # deepCopy packCode for OptNLL noLoadLoop
            deepCopyPack = fastdeepcopy(pack)
          module.add(self.noLoadLoop(kernel, tensorParametersA, tensorParametersB, isOptNLL=False, isNGLL=False, pack=deepCopyPack, NLLindex=NLLindex, NLLnum=NLLnum))
@@ -2912,7 +3068,7 @@ def initKernel(self, kernel, tensorParametersA, tensorParametersB):
     self.states.asmCaps = self.ti.getAsmCaps()
     self.states.archCaps = self.ti.getArchCaps()
     self.states.regCaps = self.ti.getRegCaps()
-    
+
     self.asmAssert = Assert(self.states.laneSGPRCount, kernel["WavefrontSize"], self.db["EnableAsserts"])
 
     # Only assembly supports scheduling
@@ -3264,7 +3420,7 @@ def readWriteVectors(mat, vw, kernel):
     self.states.bpeCexternal = self.states.bpeCexternalGSU1
     if kernel["_GlobalAccumulation"] and kernel["_GlobalAccumulation"] != 'PartialsBuffer':
       self.states.bpeCexternal = self.states.bpeCinternal
-    
+
     # special case for wmma h and b
     if (kernel["EnableMatrixInstruction"]
@@ -4071,7 +4227,7 @@ def readWriteVectors(mat, vw, kernel):
      self.defineSgpr("AddressWS", numSgprAddressWS)
      self.defineSgpr("AddressFlags", numSgprAddressFlags)
      self.states.numSgprStreamK += numSgprAddressWS + numSgprAddressFlags
-    
+
    #asm input interface depen
    self.defineSgpr("StridesD", self.states.d.numSgprStrides)
    self.defineSgpr("StridesC", self.states.c.numSgprStrides)
@@ -4144,7 +4300,7 @@ def readWriteVectors(mat, vw, kernel):
      self.defineSgpr("StreamKLocalEnd", 1)
      if kernel["StreamKAtomic"] == 0:
        self.defineSgpr("SrdWS", 4, 4)
-    
+
    #------------------------
    # Registers defined below this point are not available in the post-loop
    # Post-loop is after tail loop exits, ie the store code.