[BACKEND] Handle repetitive threads in scan op when the tensor dim is…
Jokeren authored Sep 20, 2023
1 parent e5eda09 commit ed5a530
Showing 4 changed files with 43 additions and 37 deletions.
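The tensor dims handled here are the ones smaller than the number of threads the blocked layout assigns to them: in that case several lanes of a warp hold copies of the same element ("repetitive threads"). The scan lowering previously sized its warp-level scan with getAxisNumThreadsPerWarp(), which counts every lane; this commit switches it to a new getAxisNumThreadsPerWarpWithUniqueData() accessor so duplicated lanes are excluded, and refactors the helper to reuse getEncoding() and a new getShape() instead of re-deriving them at each call site.

A minimal Python sketch of the "unique data" idea; the exact clamping formula is an assumption for illustration, not the Triton source:

    from math import ceil

    def threads_per_warp_with_unique_data(threads_per_warp, size_per_thread, shape):
        # Threads beyond the elements actually present along a dim hold
        # duplicated data, so clamp the per-dim thread count to the extent.
        return [min(t, max(1, ceil(s / e)))
                for t, e, s in zip(threads_per_warp, size_per_thread, shape)]

    # 32 lanes along dim 1, but the tensor has only 16 columns -> 16 unique lanes.
    print(threads_per_warp_with_unique_data([1, 32], [1, 1], [128, 16]))  # [1, 16]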
3 changes: 3 additions & 0 deletions include/triton/Analysis/Utility.h
@@ -92,6 +92,8 @@ class ScanLoweringHelper {
   unsigned getAxisNumWarpsWithUniqueData();
   // Return the number of threads per warp along axis dim.
   unsigned getAxisNumThreadsPerWarp();
+  // Return the number of threads per warp along axis dim with unique data.
+  unsigned getAxisNumThreadsPerWarpWithUniqueData();
   // Return the number of blocks along axis dim.
   unsigned getAxisNumBlocks();
   // Return the number of blocks along non axis dim.
@@ -109,6 +111,7 @@ class ScanLoweringHelper {
   Location getLoc() { return scanOp.getLoc(); }
   unsigned getAxis() { return scanOp.getAxis(); }
   triton::gpu::BlockedEncodingAttr getEncoding();
+  llvm::ArrayRef<int64_t> getShape();
   Region &getCombineOp();

 private:
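The header change adds the unique-data variant alongside the existing per-warp count, plus a getShape() accessor that centralizes the operand-shape lookup the .cpp file previously spelled out as scanOp.getOperand(0).getType().cast&lt;RankedTensorType&gt;().getShape() at each call site.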
65 changes: 34 additions & 31 deletions lib/Analysis/Utility.cpp
@@ -212,7 +212,8 @@ unsigned ScanLoweringHelper::getAxisNumElementsPerThread() {
 }

 unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread() {
-  SmallVector<unsigned> sizePerThreads = getContigPerThread(getEncoding());
+  SmallVector<unsigned> sizePerThreads =
+      triton::gpu::getContigPerThread(getEncoding());
   sizePerThreads[getAxis()] = 1;
   return product<unsigned>(sizePerThreads);
 }
@@ -223,6 +224,11 @@ unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp() {
   return triton::gpu::getThreadsPerWarp(getEncoding())[getAxis()];
 }

+unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData() {
+  return triton::gpu::getThreadsPerWarpWithUniqueData(getEncoding(),
+                                                      getShape())[getAxis()];
+}
+
 unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp() {
   auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
   threadsPerWarp[getAxis()] = 1;
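A hedged worked example of what the two accessors return for a scan whose axis is narrower than the warp; the layout numbers are illustrative, not taken from the diff:

    # threadsPerWarp = [1, 32], sizePerThread = [1, 1], shape = [2, 8], axis = 1
    axis_threads = 32                 # getAxisNumThreadsPerWarp()
    unique_axis_threads = min(32, 8)  # getAxisNumThreadsPerWarpWithUniqueData() -> 8
    # A shuffle-based scan then needs log2(8) = 3 doubling steps, not log2(32) = 5,
    # and lanes 8..31 (duplicates) stay out of the accumulation.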
@@ -239,42 +245,36 @@ unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA() {
 }

 unsigned ScanLoweringHelper::getAxisNumWarps() {
-  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
-  return warpsPerCTA[getAxis()];
+  return triton::gpu::getWarpsPerCTA(getEncoding())[getAxis()];
 }

 unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData() {
-  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
-  auto shape = type.getShape();
-  auto warpsPerCTA =
-      triton::gpu::getWarpsPerCTAWithUniqueData(srcEncoding, shape);
-  return warpsPerCTA[getAxis()];
+  return triton::gpu::getWarpsPerCTAWithUniqueData(getEncoding(),
+                                                   getShape())[getAxis()];
 }

 unsigned ScanLoweringHelper::getAxisNumBlocks() {
-  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
-  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
-  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
-  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
+  auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding());
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding());
   unsigned axis = getAxis();
   return ceil<unsigned>(
-      type.getShape()[axis],
+      getShape()[axis],
       (sizePerThreads[axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
 }

 unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
-  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
-  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
-  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
-  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
+  auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding());
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding());
   unsigned axis = getAxis();
   unsigned numBlocks = 1;
   for (unsigned i = 0; i < sizePerThreads.size(); i++) {
     if (i == axis)
       continue;
-    numBlocks *= ceil<unsigned>(
-        type.getShape()[i],
-        (sizePerThreads[i] * threadsPerWarp[i] * warpsPerCTA[i]));
+    numBlocks *=
+        ceil<unsigned>(getShape()[i], (sizePerThreads[i] * threadsPerWarp[i] *
+                                       warpsPerCTA[i]));
   }
   return numBlocks;
 }
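A quick numeric check of the block-count arithmetic above, with illustrative layout values rather than anything from the diff:

    from math import ceil

    shape, size_per_thread = [4, 512], [1, 4]
    threads_per_warp, warps_per_cta = [1, 32], [1, 4]
    axis = 1

    # getAxisNumBlocks(): elements covered along the axis per block = 4*32*4 = 512.
    axis_blocks = ceil(shape[axis] / (size_per_thread[axis] *
                                      threads_per_warp[axis] *
                                      warps_per_cta[axis]))        # ceil(512/512) = 1

    # getNonAxisNumBlocks(): product of the same ratio over the other dims.
    non_axis_blocks = 1
    for i in range(len(shape)):
        if i != axis:
            non_axis_blocks *= ceil(shape[i] / (size_per_thread[i] *
                                                threads_per_warp[i] *
                                                warps_per_cta[i]))  # ceil(4/1) = 4
    print(axis_blocks, non_axis_blocks)  # 1 4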
@@ -283,7 +283,7 @@ bool ScanLoweringHelper::isSupported() {
   // TODO: Support the following cases:
   // 1. Scan on non-blocking encodings
   // 2. Scan with multiple operands
-  if (!isa<triton::gpu::BlockedEncodingAttr>(srcEncoding))
+  if (!isa<triton::gpu::BlockedEncodingAttr>(getEncoding()))
     return false;
   if (scanOp.getNumOperands() != 1)
     return false;
@@ -309,8 +309,12 @@ triton::gpu::BlockedEncodingAttr ScanLoweringHelper::getEncoding() {
   return srcEncoding.cast<triton::gpu::BlockedEncodingAttr>();
 }

+llvm::ArrayRef<int64_t> ScanLoweringHelper::getShape() {
+  return scanOp.getOperand(0).getType().cast<RankedTensorType>().getShape();
+}
+
 unsigned ScanLoweringHelper::getAxisElementStride() {
-  auto order = triton::gpu::getOrder(srcEncoding);
+  auto order = triton::gpu::getOrder(getEncoding());
   unsigned stride = 1;
   for (unsigned dim : order) {
     if (dim == getAxis())
@@ -321,7 +325,7 @@ unsigned ScanLoweringHelper::getAxisElementStride() {
 }

 unsigned ScanLoweringHelper::getAxisThreadStride() {
-  auto order = triton::gpu::getOrder(srcEncoding);
+  auto order = triton::gpu::getOrder(getEncoding());
   unsigned stride = 1;
   for (unsigned dim : order) {
     if (dim == getAxis())
@@ -332,18 +336,17 @@ unsigned ScanLoweringHelper::getAxisThreadStride() {
 }

 unsigned ScanLoweringHelper::getAxisBlockStride() {
-  auto order = triton::gpu::getOrder(srcEncoding);
+  auto order = triton::gpu::getOrder(getEncoding());
   unsigned stride = 1;
-  auto type = scanOp.getOperand(0).getType().cast<RankedTensorType>();
-  auto sizePerThreads = triton::gpu::getSizePerThread(srcEncoding);
-  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(srcEncoding);
-  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(srcEncoding);
+  auto sizePerThreads = triton::gpu::getSizePerThread(getEncoding());
+  auto threadsPerWarp = triton::gpu::getThreadsPerWarp(getEncoding());
+  auto warpsPerCTA = triton::gpu::getWarpsPerCTA(getEncoding());
   for (unsigned dim : order) {
     if (dim == getAxis())
       return stride;
-    stride *= ceil<unsigned int>(type.getShape()[dim], sizePerThreads[dim] *
-                                                       threadsPerWarp[dim] *
-                                                       warpsPerCTA[dim]);
+    stride *= ceil<unsigned int>(getShape()[dim], sizePerThreads[dim] *
+                                                  threadsPerWarp[dim] *
+                                                  warpsPerCTA[dim]);
   }
   llvm_unreachable("Axis not found in order");
 }
10 changes: 5 additions & 5 deletions lib/Conversion/TritonGPUToLLVM/ScanOpToLLVM.cpp
@@ -64,7 +64,7 @@ static void warpScan(SmallVector<Value> &srcValues,
   unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
   unsigned elementStride = helper.getAxisElementStride();
   unsigned threadStride = helper.getAxisThreadStride();
-  unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
   for (unsigned srcIndex = 0; srcIndex < srcValues.size(); srcIndex++) {
     unsigned elementIdx = (srcIndex / elementStride) % scanElementsPerThreads;
     // Only consider the last element of each contiguous chunk of elements.
@@ -96,7 +96,7 @@ static void storeWarpAccumulator(SmallVector<Value> &srcValues,
                                  Value parallelLaneId) {
   Location loc = helper.getLoc();
   unsigned scanElementsPerThreads = helper.getAxisNumElementsPerThread();
-  unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
   unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
   unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
   unsigned chunkId = 0;
@@ -222,7 +222,7 @@ static void AddPartialReduceOneWarp(SmallVector<Value> &srcValues,
   unsigned threadStride = helper.getAxisThreadStride();
   unsigned axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
   unsigned numParallelLane = helper.getNonAxisNumThreadsPerCTA();
-  unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+  unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
   Value maskFirstWarp = icmp_eq(warpId, i32_val(0));
   Value maskFirstLane = icmp_eq(laneIdAxis, i32_val(0));
   Value maskFirstThread = and_(maskFirstWarp, maskFirstLane);
@@ -394,7 +394,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
   auto input = adaptor.getOperands()[0];
   auto type = op.getOperand(0).getType().cast<RankedTensorType>();
   auto axisNumWarps = helper.getAxisNumWarpsWithUniqueData();
-  auto axisNumThreads = helper.getAxisNumThreadsPerWarp();
+  auto axisNumThreads = helper.getAxisNumThreadsPerWarpWithUniqueData();
   warpIdAxis = urem(warpIdAxis, i32_val(axisNumWarps));
   SmallVector<Value> srcValues =
       getTypeConverter()->unpackLLElements(loc, input, rewriter, type);
@@ -423,7 +423,7 @@ ScanOpConversion::emitFastScan(triton::ScanOp op, triton::ScanOpAdaptor adaptor,
   } else if (srcValues.size() > 1) {
     // Fast path for the case where there is only one warp with unique data on
     // the axis.
-    unsigned scanDim = helper.getAxisNumThreadsPerWarp();
+    unsigned scanDim = helper.getAxisNumThreadsPerWarpWithUniqueData();
     auto multiDimLaneId = getMultiDimLaneId(rewriter, helper, laneId);
     multiDimLaneId[helper.getAxis()] = i32_val(scanDim - 1);
     auto threadsPerWarp = triton::gpu::getThreadsPerWarp(helper.getEncoding());
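All five call sites now bound scanDim by the unique-lane count. A minimal Python model (a sketch of a shuffle-based scan under stated assumptions, not the actual LLVM lowering) of why that bound matters:

    def warp_inclusive_scan(lane_values, scan_dim):
        # Hillis-Steele scan over the first scan_dim lanes; walking lanes
        # high-to-low makes each read see the previous step's value.
        vals = list(lane_values)
        offset = 1
        while offset < scan_dim:
            for lane in range(scan_dim - 1, offset - 1, -1):
                vals[lane] += vals[lane - offset]
            offset *= 2
        return vals

    # 8 unique elements broadcast across a 32-lane warp: scan only the first 8.
    # With scan_dim = 32, the duplicated lanes would be summed into the prefixes.
    print(warp_inclusive_scan([1] * 32, 8)[:8])  # [1, 2, 3, 4, 5, 6, 7, 8]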
2 changes: 1 addition & 1 deletion python/test/unit/operators/test_inductor.py
@@ -156,7 +156,7 @@ def triton_(in_ptr0, out_ptr0, XBLOCK: tl.constexpr):
     torch.testing.assert_close(out, out_ref)


-@pytest.mark.parametrize("RBLOCK", [32, 64, 128])
+@pytest.mark.parametrize("RBLOCK", [1, 16, 32, 64, 128])
 @pytest.mark.parametrize("num_warps", [1, 4])
 def test_scan2d_broadcast(RBLOCK, num_warps):
     @triton.jit(debug=True)
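The new RBLOCK values 1 and 16 are the ones that leave the scanned dimension narrower than a warp, which is exactly the repetitive-thread case this commit fixes. A hedged back-of-envelope, assuming the kernel maps 32 lanes of each warp along the RBLOCK axis (an illustration, not read off the truncated kernel body):

    WARP_SIZE = 32
    for rblock in [1, 16, 32, 64, 128]:
        unique_lanes = min(WARP_SIZE, rblock)
        print(f"RBLOCK={rblock:3d}: unique lanes={unique_lanes:2d}, "
              f"repetitive={'yes' if unique_lanes < WARP_SIZE else 'no'}")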
