diff --git a/tensilelite/Tensile/KernelWriter.py b/tensilelite/Tensile/KernelWriter.py index 5af14ca0c..fec67efc6 100644 --- a/tensilelite/Tensile/KernelWriter.py +++ b/tensilelite/Tensile/KernelWriter.py @@ -758,6 +758,9 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe packBIdx = 0 packMIdx = 0 + numPackedA = 0 + numPackedB = 0 + numPackedM = 0 ##### # Prepare localReadCode #### @@ -832,11 +835,19 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe instPerPackM = 0 if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and not kernel["UnrollMajorLDSMetadata"]: - instPerPackM = 1.5 if self.states.lrvwTileMetadata > 1 and kernel["MIInputPerThreadMetadata"] == 1 else 1 - - packItems = [] + instPerPackM = 1 + if self.states.lrvwTileMetadata > 1: + if kernel["MIInputPerThreadMetadata"] == 1: + instPerPackM = 1.5 + elif kernel["MIInputPerThreadMetadata"] == 4: + instPerPackM = 3 + packItemsA = [] + packItemsB = [] + packItemsM = [] for iui in range(kernel["InnerUnroll"]): - packINtems = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsA = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsB = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] + packINtemsM = [ [] for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)) ] packA = packCode.findNamedItem("packA_I%s"%(iui)) packB = packCode.findNamedItem("packB_I%s"%(iui)) packM = packCode.findNamedItem("packMetadata_I%s"%(iui)) @@ -856,58 +867,72 @@ def makeSubIterSchedule(self, kernel, tPA, tPB, localReadCode, iteration, pointe if packAItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + packINtemsA[0].append(packAItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + packINtemsA[j].append(packAItems.pop(0)) if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break if packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + packINtemsB[0].append(packBItems.pop(0)) else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + packINtemsB[j].append(packBItems.pop(0)) while packAItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeA"].isFloat8(): for n in range(instPerPackA): - packINtems[0].append(packAItems.pop(0)) + if packAItems: + packINtemsA[0].append(packAItems.pop(0)) + else: + break else: for j in range(self.states.numReadsIterCoalescedA): for n in range(instPerPackA): - packINtems[j].append(packAItems.pop(0)) + if packAItems: + packINtemsA[j].append(packAItems.pop(0)) + else: + break if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: while packMItems: for j in range(self.states.numReadsIterCoalescedMetadata): for n in range(ceil(instPerPackM)): if packMItems: - packINtems[j].append(packMItems.pop(0)) + packINtemsM[j].append(packMItems.pop(0)) else: break while packBItems: if kernel["ConvertAfterDS"] and kernel["ProblemType"]["DataTypeB"].isFloat8(): for n in range(instPerPackB): - packINtems[0].append(packBItems.pop(0)) + if packBItems: + packINtemsB[0].append(packBItems.pop(0)) + else: + break else: for j in range(self.states.numReadsIterCoalescedB): for n in range(instPerPackB): - packINtems[j].append(packBItems.pop(0)) + if packBItems: + packINtemsB[j].append(packBItems.pop(0)) + else: + break - for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB)): - packItems += packINtems.pop(0) + for j in range(max(self.states.numReadsIterCoalescedA,self.states.numReadsIterCoalescedB,self.states.numReadsIterCoalescedMetadata)): + packItemsA += packINtemsA.pop(0) + packItemsB += packINtemsB.pop(0) + packItemsM += packINtemsM.pop(0) # remove s_nop for packing # we will add s_nop if needed @@ -1053,7 +1078,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): numLocalReadShouldSchedule = 0 # prefetch load for next wave tile along M since we re-use B first. tileM: int = kernel["MIWaveTileA"] - instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItems + instsToCheck = mfmas[i:min(i+tileM+1, numMfmaPerIter)] + packItemsA + packItemsM + packItemsB localReadItemsThisLoop = sorted(localReadItemsThisLoop, key=lambda o: hasAnyDependency(o, instsToCheck), reverse=True) for lr in localReadItemsThisLoop: @@ -1241,7 +1266,7 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): mfmas = [mfma for mfma in macIterCode.flatitems() if isinstance(mfma, (MFMAInstruction, SMFMAInstruction,))] ## To support do["MAC"] is False mfma = [mfmas[i],] if len(mfmas) > 0 else [] - instsToCheck = mfma + packItems + instsToCheck = mfma + packItemsA + packItemsM + packItemsB numDsInsts = 0 lastLgkmCnt = -1 for ds in filter(lambda j: isinstance(j, (DSLoadInstruction, DSStoreInstruction, SWaitCnt)), reversed(prevIterCode.flatitems() + iterCode.flatitems())): @@ -1271,18 +1296,28 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): #### # scheduled pack #### - if packItems: + _instPerPackA = 0 + _instPerPackB = 0 + _instPerPackM = 0 + instPerPackA_r = instPerPackA + instPerPackB_r = instPerPackB + instPerPackM_r = instPerPackM + if packItemsA or packItemsB or packItemsM: # how many pack have to be done # calculate the data index of this mfma used for A and B # if i // kernel["MIWaveTile"][0]==0, mfma will use new A (need to take iu into account) # if i % kernel["MIWaveTile"][0]==0, mfma will use new B - packAIdx += instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 - packBIdx += instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackA = instPerPackA if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + packAIdx += _instPerPackA + _instPerPackB = instPerPackB if i % kernel["MIWaveTileA"] == 0 else 0 + packBIdx += _instPerPackB if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: if kernel["ProblemType"]["Sparse"] == 2: - packMIdx += instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 + _instPerPackM = instPerPackM if i % kernel["MIWaveTileA"] == 0 else 0 else: - packMIdx += instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + _instPerPackM = instPerPackM if i//(kernel["MIWaveTileA"]+kernel["MIWaveTileA"]*kernel["MIWaveTileB"]*(i//(kernel["MIWaveTileA"]*kernel["MIWaveTileB"]))) == 0 else 0 + + packMIdx += _instPerPackM # blockWidth < 1, means 0.5 or 0.25 (BF,H,Int8) if self.states.archCaps["HasEccHalf"] or not self.states.asmCaps["HasWMMA_V1"]: packAIdx = packAIdx if tPA["bpe"] < 4 and (not kernel["UnrollMajorLDSA"] or kernel["ConvertAfterDS"]) else 0 @@ -1298,33 +1333,96 @@ def hasAnyDependency(lr: DSLoadInstruction, insts: List[Instruction]): iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u, packMIdx:%u" %(packAIdx,packBIdx,packMIdx)) else: iterCode.addComment0("pack scheduling: packAIdx:%u, packBIdx:%u" %(packAIdx,packBIdx)) - # we put 2 pack in each mfma - for j in range(instPerPackA): - if packItems: - iterCode.add(packItems.pop(0)) + + # 1st round + # put the required pack into mfma iter + for j in range(_instPerPackA): + if packItemsA: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedA >= packAIdx: + break + iterCode.add(packItemsA.pop(0)) curPackIdx += 1 + numPackedA += 1 + instPerPackA_r -= 1 + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"]: - for j in range(ceil(instPerPackM)): - if packItems: - iterCode.add(packItems.pop(0)) + for j in range(ceil(_instPerPackM)): + if packItemsM: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedM >= packMIdx: + break + iterCode.add(packItemsM.pop(0)) curPackIdx += 1 - for j in range(instPerPackB): - if packItems: - iterCode.add(packItems.pop(0)) + numPackedM += 1 + instPerPackM_r -= 1 + + for j in range(_instPerPackB): + if packItemsB: + # Skip if the required pack has already been placed in the previous mfma iter. + if numPackedB >= packBIdx: + break + iterCode.add(packItemsB.pop(0)) curPackIdx += 1 + numPackedB += 1 + instPerPackB_r -= 1 + + # 2nd round + # put the desired pack into mfma iter + if instPerPackA_r > 0: + for j in range(instPerPackA_r): + if packItemsA: + iterCode.add(packItemsA.pop(0)) + curPackIdx += 1 + numPackedA += 1 + + if kernel["ProblemType"]["Sparse"] and not kernel["DirectToVgprSparseMetadata"] and instPerPackM_r > 0: + for j in range(ceil(instPerPackM_r)): + if packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + + if instPerPackB_r > 0: + for j in range(instPerPackB_r): + if packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + + remainLatency = 2 if curPackIdx < numPack + 2 else 0 + # since packed register need to wait 2 quad cycle to finish packing # we insert pack instruction if we can, or s_nop - while curPackIdx < numPack+2: - if packItems: - iterCode.add(packItems.pop(0)) + while remainLatency > 0: + if packItemsA: + iterCode.add(packItemsA.pop(0)) curPackIdx += 1 + numPackedA += 1 + remainLatency -= 1 + elif packItemsM: + iterCode.add(packItemsM.pop(0)) + curPackIdx += 1 + numPackedM += 1 + remainLatency -= 1 + elif packItemsB: + iterCode.add(packItemsB.pop(0)) + curPackIdx += 1 + numPackedB += 1 + remainLatency -= 1 else: - iterCode.add(SNop(waitState=1, comment="VALU packing writes to be consumed by matrix instruction")) + latency = remainLatency - 1 + iterCode.add(SNop(waitState=latency, comment="VALU packing writes to be consumed by matrix instruction")) curPackIdx += 1 - break + remainLatency -= (latency+1) + if i == numMfmaPerIter - 1: - while packItems: - iterCode.add(packItems.pop(0)) + while packItemsA: + iterCode.add(packItemsA.pop(0)) + while packItemsM: + iterCode.add(packItemsM.pop(0)) + while packItemsB: + iterCode.add(packItemsB.pop(0)) #### # scheduled mfma dependency