diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index 3e08c9a74..b4fa1e5ef 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -142,7 +142,7 @@ def enable(lib=aten_lib):
     lib.impl("fill.Scalar", fill_scalar, "CUDA")
     lib.impl("fill.Tensor", fill_tensor, "CUDA")
     lib.impl("flip", flip, "CUDA")
-    lib.impl("slice_scatter", slice_scatter, "CUDA")
+    lib.impl("slice_scatter", slice_scatter_v2, "CUDA")
     lib.impl("select_scatter", select_scatter, "CUDA")
     lib.impl("index_select", index_select, "CUDA")
     lib.impl("tile", tile, "CUDA")
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 746657597..f7f80c6fa 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -88,7 +88,7 @@
 from .sigmoid import sigmoid
 from .silu import silu
 from .sin import sin
-from .slice_scatter import slice_scatter
+from .slice_scatter import slice_scatter, slice_scatter_v2
 from .softmax import softmax
 from .stack import stack
 from .sub import sub
@@ -234,6 +234,7 @@
     "where_scalar_other",
     "select_scatter",
     "slice_scatter",
+    "slice_scatter_v2",
     "masked_fill",
     "_unique2",
     "_upsample_bicubic2d_aa",
diff --git a/src/flag_gems/ops/slice_scatter.py b/src/flag_gems/ops/slice_scatter.py
index 71881a5c4..7bbf516df 100644
--- a/src/flag_gems/ops/slice_scatter.py
+++ b/src/flag_gems/ops/slice_scatter.py
@@ -94,3 +94,491 @@ def slice_scatter(inp, src, dim=0, start=None, end=None, step=1):
     slice_scatter_kernel[grid](out, indices, src, src_offsets, M, N)
 
     return out
+
+
+def simplify(x, retain_dim, ordered_dims=None):
+    # This helper tries to create a new view of the input in which the dims
+    # other than retain_dim are coalesced into one outer dim and/or one
+    # inner dim.
+    ordered_dims = ordered_dims or sorted(range(x.ndim), key=lambda i: x.stride(i))
+    assert x.ndim == len(ordered_dims)
+    if len(ordered_dims) == 1:
+        return x, ordered_dims
+
+    size_list = [x.size(dim) for dim in ordered_dims]
+    stride_list = [x.stride(dim) for dim in ordered_dims]
+
+    n = ordered_dims.index(retain_dim)
+
+    # Try to merge into a 3d tensor, keeping retain_dim in the middle
+    new_sizes = [1, size_list[n], 1]
+    new_dim = 0
+    for i in range(x.ndim):
+        if i == n:
+            new_dim += 2
+            continue
+        new_sizes[new_dim] *= size_list[i]
+        if 0 < i < n or i > n + 1:
+            if stride_list[i] != stride_list[i - 1] * size_list[i - 1]:
+                # cannot merge!
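+                # e.g. adjacent dims of sizes (4, 3) with strides (1, 4) can
+                # be viewed as one dim of size 12 with stride 1, but strides
+                # (1, 5) leave a gap between rows, so no single strided view
+                # covers both dims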
+                return None, None
+
+    K, N, M = new_sizes
+    K_stride, N_stride = stride_list[0], stride_list[n]
+    M_stride = stride_list[n + 1] if n < x.ndim - 1 else None
+
+    if n == 0:
+        new_x = x.as_strided((M, N), (M_stride, N_stride))
+    elif n == x.ndim - 1:
+        new_x = x.as_strided((N, K), (N_stride, K_stride))
+    else:
+        new_x = x.as_strided((M, N, K), (M_stride, N_stride, K_stride))
+
+    return new_x, ordered_dims
+
+
+@triton.jit(
+    do_not_specialize=[
+        "N1",
+        "N2",
+        "K",
+        "am_stride",
+        "an_stride",
+        "ak_stride",
+        "bm_stride",
+        "bn_stride",
+        "bk_stride",
+        "start",
+        "end",
+        "step",
+    ]
+)
+def scatter_by_row_kernel(
+    A,
+    B,
+    A_out,
+    # mid dim size of A and A_out
+    N1,
+    # mid dim size of B
+    N2,
+    # inner dim size of inp, src and out
+    K,
+    # strides of A and A_out
+    am_stride,
+    an_stride,
+    ak_stride,
+    # strides of B
+    bm_stride,
+    bn_stride,
+    bk_stride,
+    # slice start index
+    start,
+    # slice end index
+    end,
+    # slice step
+    step,
+    BLOCK: tl.constexpr,
+):
+    pidx = tl.program_id(0)
+    pidy = tl.program_id(1)
+
+    am_idx = pidx // N1
+    an_idx = pidx % N1
+
+    # inp and out share the same layout
+    offset = BLOCK * pidy + tl.arange(0, BLOCK)
+    mask = offset < K
+
+    if (start <= an_idx) & (an_idx < end) & ((an_idx - start) % step == 0):
+        # mid dim size of src
+        tl.device_assert(N2 == (end - start + step - 1) // step)
+        # slice index in B
+        bn_idx = (an_idx - start) // step
+        B_start = B + am_idx * bm_stride + bn_idx * bn_stride
+        slice = tl.load(B_start + offset, mask=mask)
+    else:
+        # slice index in A
+        A_start = A + am_idx * am_stride + an_idx * an_stride
+        slice = tl.load(A_start + offset, mask=mask)
+
+    A_out_start = A_out + am_idx * am_stride + an_idx * an_stride
+    tl.store(A_out_start + offset, slice, mask=mask)
+
+
+def scatter_by_row(A, B, A_out, start, end, step):
+    assert A_out.ndim in (2, 3)
+    if A_out.ndim == 2:
+        A_out = A_out.unsqueeze(0)
+        B = B.unsqueeze(0)
+    M, N1, K = A_out.size()
+    N2 = B.size(1)
+    am_stride, an_stride, ak_stride = A_out.stride()
+    bm_stride, bn_stride, bk_stride = B.stride()
+
+    grid = lambda meta: (M * N1, triton.cdiv(K, meta["BLOCK"]))
+    scatter_by_row_kernel[grid](
+        A,
+        B,
+        A_out,
+        N1,
+        N2,
+        K,
+        am_stride,
+        an_stride,
+        ak_stride,
+        bm_stride,
+        bn_stride,
+        bk_stride,
+        start,
+        end,
+        step,
+        BLOCK=512,
+    )
+
+
+@triton.jit(
+    do_not_specialize=[
+        "N1",
+        "N2",
+        "K",
+        "am_stride",
+        "an_stride",
+        "ak_stride",
+        "bm_stride",
+        "bn_stride",
+        "bk_stride",
+        "start",
+        "end",
+        "step",
+    ]
+)
+def scatter_3d_mid_kernel(
+    A,
+    B,
+    A_out,
+    # mid dim size of A and A_out
+    N1,
+    # mid dim size of B
+    N2,
+    # inner dim size of inp, src and out
+    K,
+    # strides of A and A_out
+    am_stride,
+    an_stride,
+    ak_stride,
+    # strides of B
+    bm_stride,
+    bn_stride,
+    bk_stride,
+    # slice start index
+    start,
+    # slice end index
+    end,
+    # slice step
+    step,
+    NBLOCK: tl.constexpr,
+    KBLOCK: tl.constexpr,
+):
+    # Each cta processes one [1, NBLOCK, KBLOCK] chunk of the input and output.
+    # The src chunk size is determined dynamically.
+
+    pidx = tl.program_id(0)
+    pidy = tl.program_id(1)
+
+    am_idx = pidx
+    an_idx = pidy * NBLOCK
+
+    # Offsets into inp and out chunks
+    n_idx_tile = an_idx + tl.arange(0, NBLOCK)[:, None]
+    k_idx_tile = tl.arange(0, KBLOCK)[None, :]
+    a_offset = am_idx * am_stride + n_idx_tile * an_stride + k_idx_tile * ak_stride
+    n_mask = n_idx_tile < N1
+    k_mask = k_idx_tile < K
+
+    # Offsets into src
+    b_offset = am_idx * bm_stride
+    b_offset += (n_idx_tile - start) * bn_stride // step
+    b_offset += k_idx_tile * bk_stride
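+    # Wherever the mask below is set, (n_idx_tile - start) is a multiple of
+    # step, so the floor division above is exact.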
+    # This mask applies to the [NBLOCK, KBLOCK] tile of src
+    bn_mask = start <= n_idx_tile
+    bn_mask &= n_idx_tile < end
+    bn_mask &= (n_idx_tile - start) % step == 0
+
+    # merge inp and src, then write back
+    inp = tl.load(A + a_offset, mask=n_mask & k_mask)
+    src = tl.load(B + b_offset, mask=bn_mask & k_mask)
+    out = tl.where(bn_mask & k_mask, src, inp)
+    tl.store(A_out + a_offset, out, mask=n_mask & k_mask)
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config(kwargs={"R": 1, "C": 512}, num_warps=4),
+        triton.Config(kwargs={"R": 32, "C": 32}, num_warps=4),
+        triton.Config(kwargs={"R": 64, "C": 64}, num_warps=4),
+        triton.Config(kwargs={"R": 4, "C": 512}, num_warps=4),
+        triton.Config(kwargs={"R": 16, "C": 128}, num_warps=4),
+    ],
+    key=["strided", "pivoted"],
+)
+@triton.heuristics(
+    values={
+        "predicate_load": lambda args: args["x_ncol"] > 0.5 * args["y_ncol"],
+    }
+)
+@triton.jit(
+    do_not_specialize=[
+        "nrow",
+        "x_ncol",
+        "y_ncol",
+        "x_si",
+        "x_sj",
+        "y_si",
+        "y_sj",
+        "start",
+        "end",
+        "step",
+    ]
+)
+def scatter_2d_inner_kernel(
+    X,
+    Y,
+    X_out,
+    # sizes
+    nrow,
+    x_ncol,
+    y_ncol,
+    # strides
+    x_si,
+    x_sj,
+    y_si,
+    y_sj,
+    # slice
+    start,
+    end,
+    step,
+    strided: tl.constexpr,
+    pivoted: tl.constexpr,
+    predicate_load: tl.constexpr,
+    R: tl.constexpr,
+    C: tl.constexpr,
+):
+    i0 = tl.program_id(0) * R
+    j0 = tl.program_id(1) * C
+    ii = i0 + tl.arange(0, R)[:, None]
+    jj = j0 + tl.arange(0, C)[None, :]
+
+    if predicate_load:
+        # predicate then load
+        px = X + ii * x_si + jj * x_sj
+        if (j0 + C < start) | (j0 >= end):
+            p = px
+        else:
+            py = Y + ii * y_si + (jj - start) * y_sj // step
+            mask = (start <= jj) & (jj < end) & ((jj - start) % step == 0)
+            p = tl.where((ii < nrow) & mask, py, px)
+        tmp = tl.load(p, mask=(ii < nrow) & (jj < x_ncol))
+        tl.store(X_out + ii * x_si + jj * x_sj, tmp, mask=(ii < nrow) & (jj < x_ncol))
+    else:
+        # load then predicate
+        x = tl.load(X + ii * x_si + jj * x_sj, mask=(ii < nrow) & (jj < x_ncol))
+        if (j0 + C < start) | (j0 >= end):
+            tl.store(X_out + ii * x_si + jj * x_sj, x, mask=(ii < nrow) & (jj < x_ncol))
+        else:
+            mask = (start <= jj) & (jj < end) & ((jj - start) % step == 0)
+            y = tl.load(
+                Y + ii * y_si + (jj - start) * y_sj // step, mask=(ii < nrow) & mask
+            )
+            z = tl.where((ii < nrow) & mask, y, x)
+            tl.store(X_out + ii * x_si + jj * x_sj, z, mask=(ii < nrow) & (jj < x_ncol))
+
+
+@triton.jit(
+    do_not_specialize=[
+        "nrow_A",
+        "nrow_B",
+        "ncol",
+        "a_stride0",
+        "a_stride1",
+        "b_stride0",
+        "b_stride1",
+        "start",
+        "end",
+        "step",
+    ]
+)
+def scatter_2d_outer_kernel(
+    A,
+    B,
+    A_out,
+    nrow_A,
+    nrow_B,
+    ncol,
+    a_stride0,
+    a_stride1,
+    b_stride0,
+    b_stride1,
+    start,
+    end,
+    step,
+    NROW: tl.constexpr,
+    NCOL: tl.constexpr,
+):
+    pidx = tl.program_id(0)
+    pidy = tl.program_id(1)
+
+    row_idx = pidx * NROW
+    col_idx = pidy * NCOL
+
+    # Offsets into inp and out chunks
+    row_idx = row_idx + tl.arange(0, NROW)[:, None]
+    col_idx = col_idx + tl.arange(0, NCOL)[None, :]
+    a_offset = row_idx * a_stride0 + col_idx * a_stride1
+    a_row_mask = row_idx < nrow_A
+    col_mask = col_idx < ncol
+
+    # Offsets into src
+    b_offset = (row_idx - start) * b_stride0 // step
+    b_offset += col_idx * b_stride1
+    # This mask applies to the [NROW, NCOL] tile of src
+    b_row_mask = start <= row_idx
+    b_row_mask &= row_idx < end
+    b_row_mask &= (row_idx - start) % step == 0
+
+    # merge inp and src, then write back
+    inp = tl.load(A + a_offset, mask=a_row_mask & col_mask)
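+    # The masked load never dereferences out-of-slice positions of B; tl.where
+    # below falls back to the inp value wherever the src mask is unset.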
+    src = tl.load(B + b_offset, mask=b_row_mask & col_mask)
+    out = tl.where(b_row_mask & col_mask, src, inp)
+    tl.store(A_out + a_offset, out, mask=a_row_mask & col_mask)
+
+
+def scatter_2d_inner(x, y, x_out, start, end, step):
+    nrow, x_ncol = x_out.size()
+    y_ncol = y.size(1)
+    x_stride_i, x_stride_j = x_out.stride()
+    y_stride_i, y_stride_j = y.stride()
+    strided = step > 1
+    pivoted = y_stride_i < y_stride_j
+    grid = lambda meta: (
+        triton.cdiv(nrow, meta["R"]),
+        triton.cdiv(x_ncol, meta["C"]),
+    )
+    scatter_2d_inner_kernel[grid](
+        x,
+        y,
+        x_out,
+        nrow,
+        x_ncol,
+        y_ncol,
+        x_stride_i,
+        x_stride_j,
+        y_stride_i,
+        y_stride_j,
+        start,
+        end,
+        step,
+        strided=strided,
+        pivoted=pivoted,
+    )
+
+
+def scatter_2d_outer(A, B, A_out, start, end, step):
+    nrow_A, ncol = A_out.size()
+    nrow_B = B.size(0)
+    a_stride0, a_stride1 = A_out.stride()
+    b_stride0, b_stride1 = B.stride()
+    grid = lambda meta: (
+        triton.cdiv(nrow_A, meta["NROW"]),
+        triton.cdiv(ncol, meta["NCOL"]),
+    )
+    scatter_2d_outer_kernel[grid](
+        A,
+        B,
+        A_out,
+        nrow_A,
+        nrow_B,
+        ncol,
+        a_stride0,
+        a_stride1,
+        b_stride0,
+        b_stride1,
+        start,
+        end,
+        step,
+        NROW=64,
+        NCOL=64,
+    )
+
+
+def scatter_3d_mid(A, B, A_out, start, end, step):
+    M, N1, K = A_out.size()
+    N2 = B.size(1)
+    am_stride, an_stride, ak_stride = A_out.stride()
+    bm_stride, bn_stride, bk_stride = B.stride()
+
+    grid = lambda meta: (M, triton.cdiv(N1, meta["NBLOCK"]))
+    scatter_3d_mid_kernel[grid](
+        A,
+        B,
+        A_out,
+        N1,
+        N2,
+        K,
+        am_stride,
+        an_stride,
+        ak_stride,
+        bm_stride,
+        bn_stride,
+        bk_stride,
+        start,
+        end,
+        step,
+        NBLOCK=512,
+        KBLOCK=triton.next_power_of_2(K),
+    )
+
+
+def slice_scatter_v2(inp, src, dim=0, start=None, end=None, step=1):
+    logging.debug("GEMS SLICE_SCATTER")
+    assert src.device == inp.device, "inp and src reside on different devices."
+    assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
+    assert step > 0, "slice step must be positive"
+    dim = dim % inp.ndim
+
+    start = 0 if start is None else start
+    end = inp.size(dim) if end is None else end
+
+    out = torch.empty_strided(
+        inp.size(), inp.stride(), dtype=inp.dtype, device=inp.device
+    )
+
+    # Look for a permutation of dims such that, after the permute, the outer
+    # and inner dims relative to dim can be coalesced. Finding one in general
+    # is hard, so we resort to a sufficient but not necessary condition.
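+    # For example, a contiguous tensor of shape (A, B, N, C) with dim=2 can be
+    # viewed as (A*B, N, C); when no such view exists, simplify returns
+    # (None, None) and we take the fallback path below.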
+    new_out, ordered_dims = simplify(out, dim)
+    if new_out is not None:
+        new_src, _ = simplify(src, dim, ordered_dims)
+    else:
+        new_src = None
+
+    if new_out is not None and new_src is not None:
+        # out is allocated with inp's exact strides, so new_out's sizes and
+        # strides address inp's storage as well; the kernels read inp through
+        # them directly.
+        if dim == ordered_dims[0]:
+            if len(ordered_dims) == 1:
+                new_out = new_out.unsqueeze(0)
+                new_src = new_src.unsqueeze(0)
+            # slices on the inner dim
+            scatter_2d_inner(inp, new_src, new_out, start, end, step)
+        elif new_src.stride(-1) == new_out.stride(-1) == 1 and new_src.size(-1) >= 128:
+            # slices on outer dims while the inner dims are contiguous
+            scatter_by_row(inp, new_src, new_out, start, end, step)
+        elif dim == ordered_dims[-1]:
+            # slices on outer dims while the inner dims may not be contiguous
+            scatter_2d_outer(inp, new_src, new_out, start, end, step)
+        else:
+            # slices on middle dims
+            scatter_3d_mid(inp, new_src, new_out, start, end, step)
+        return out
+    # Fall back to the reference slice_scatter
+    return slice_scatter(inp, src, dim, start, end, step)
diff --git a/tests/accuracy_utils.py b/tests/accuracy_utils.py
index cacca65ff..0be6544a0 100644
--- a/tests/accuracy_utils.py
+++ b/tests/accuracy_utils.py
@@ -60,7 +60,42 @@ def SkipTorchVersion(skip_pattern):
     [(16, 256), (16, 256)],
     [(20, 320, 15), (20, 320, 15), (20, 320, 15)],
 ]
+CONTIGUOUS_SHAPE_STRIDES_1D = [
+    ((1,), (1,)),
+    ((1024,), (1,)),
+    ((1000000,), (1,)),
+]
+DILATED_SHAPE_STRIDES_1D = [
+    ((1,), (2,)),
+    ((1024,), (2,)),
+    ((1000000,), (2,)),
+]
+CONTIGUOUS_SHAPE_STRIDES_2D = [
+    ((1, 1024), (1024, 1)),
+    ((10000, 128), (128, 1)),
+]
+TRANSPOSED_SHAPE_STRIDES_2D = [
+    ((1024, 1), (1, 1024)),
+    ((128, 10000), (1, 128)),
+]
+CONTIGUOUS_SHAPE_STRIDES_3D = [
+    ((20, 320, 15), (4800, 15, 1)),
+    ((200, 40999, 3), (122997, 3, 1)),
+]
+TRANSPOSED_SHAPE_STRIDES_3D = [
+    ((320, 20, 15), (15, 4800, 1)),
+    ((3, 40999, 200), (1, 3, 122997)),
+]
+SHAPE_STRIDES = (
+    CONTIGUOUS_SHAPE_STRIDES_1D
+    + DILATED_SHAPE_STRIDES_1D
+    + CONTIGUOUS_SHAPE_STRIDES_2D
+    + TRANSPOSED_SHAPE_STRIDES_2D
+    + CONTIGUOUS_SHAPE_STRIDES_3D
+    + TRANSPOSED_SHAPE_STRIDES_3D
+)
+IRREGULAR_SHAPE_STRIDES = [((10, 10, 10, 10, 10), (1, 10000, 23, 399, 1024))]
 
 UPSAMPLE_SHAPES = [
     (32, 16, 128, 128),
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 542516452..4197ccc98 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -1,13 +1,18 @@
+import random
+
 import pytest
 import torch
 
 import flag_gems
 
 from .accuracy_utils import (
+    CONTIGUOUS_SHAPE_STRIDES_2D,
     FLOAT_DTYPES,
     INT_DTYPES,
+    IRREGULAR_SHAPE_STRIDES,
     REDUCTION_SHAPES,
     REDUCTION_SMALL_SHAPES,
+    SHAPE_STRIDES,
     gems_assert_close,
     gems_assert_equal,
     to_reference,
@@ -32,11 +37,24 @@
     if QUICK_MODE
     else list(zip([1, 0.1, 0], REDUCTION_SHAPES))
 )
-DIM_SHAPE = (
-    [(1, REDUCTION_SMALL_SHAPES[0])]
+DIM_SHAPE_STRIDES = (
+    [(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])]
     if QUICK_MODE
-    else list(zip([0, 1, 1], REDUCTION_SMALL_SHAPES))
+    else list(
+        (random.randint(0, len(shape) - 1), shape, stride)
+        for shape, stride in SHAPE_STRIDES
+    )
+)
+REGULAR_DIM_SHAPE_STRIDES = (
+    [(1, *CONTIGUOUS_SHAPE_STRIDES_2D[1])]
+    if QUICK_MODE
+    else list(
+        (random.randint(0, len(shape) - 1), shape, stride)
+        for shape, stride in CONTIGUOUS_SHAPE_STRIDES_2D
+    )
 )
+IRREGULAR_DIM_SHAPE_STRIDES = [(3, *IRREGULAR_SHAPE_STRIDES[0])]
+
 THRESHOLD_SHAPE = (
     [(0.3, REDUCTION_SHAPES[0])]
     if QUICK_MODE
@@ -467,26 +485,29 @@ def test_accuracy_select_scatter(shape, dim, dtype):
 
 
 @pytest.mark.slice_scatter
-@pytest.mark.parametrize(("dim", "shape"), DIM_SHAPE)
+@pytest.mark.parametrize(("dim", "shape", "stride"), DIM_SHAPE_STRIDES)
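+# The start/end values parametrized below may exceed a given dim's extent; the
+# test body remaps them into range (start %= size, end %= size + 1) and swaps
+# them when reversed, so every (shape, start, end, step) combination is valid.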
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
 @pytest.mark.parametrize("start", [16, 64])
 @pytest.mark.parametrize("end", [1024, 256])
 @pytest.mark.parametrize("step", [1, 2])
-def test_accuracy_slice_scatter(shape, dim, dtype, start, end, step):
-    inp = torch.randn(shape, dtype=dtype, device="cuda")
+def test_accuracy_slice_scatter_v2(shape, stride, dim, dtype, start, end, step):
+    inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
+    inp.fill_(1)
 
-    range = end - start
     valid_shape = list(inp.shape)
+    size = valid_shape[dim]
+
+    start = start % size
+    end = end % (size + 1)
+
     if end < start:
-        range = 0
-    elif (end - start) > valid_shape[dim]:
-        range = valid_shape[dim]
-        start = 0
-        end = valid_shape[dim]
+        end, start = start, end
+    elif end == start:
+        end = size
 
-    valid_shape[dim] = (range + (step - 1)) // step
+    valid_shape[dim] = (end - start + step - 1) // step
 
-    src = torch.randn(valid_shape, dtype=dtype, device="cuda")
+    src = torch.rand(valid_shape, dtype=dtype, device="cuda")
 
     ref_inp = to_reference(inp)
     ref_src = to_reference(src)
@@ -501,6 +522,44 @@
     gems_assert_equal(res_out, ref_out)
 
 
+@pytest.mark.slice_scatter
+@pytest.mark.parametrize(("dim", "shape", "stride"), REGULAR_DIM_SHAPE_STRIDES)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+@pytest.mark.parametrize("start", [16, 64])
+@pytest.mark.parametrize("end", [1024, 256])
+@pytest.mark.parametrize("step", [1, 2])
+def test_accuracy_slice_scatter_fallback(shape, stride, dim, dtype, start, end, step):
+    inp = torch.empty_strided(shape, stride, dtype=dtype, device="cuda")
+    inp.fill_(1)
+
+    valid_shape = list(inp.shape)
+    size = valid_shape[dim]
+
+    start = start % size
+    end = end % (size + 1)
+
+    if end < start:
+        end, start = start, end
+    elif end == start:
+        end = size
+
+    valid_shape[dim] = (end - start + step - 1) // step
+
+    src = torch.rand(valid_shape, dtype=dtype, device="cuda")
+
+    ref_inp = to_reference(inp)
+    ref_src = to_reference(src)
+    ref_out = torch.slice_scatter(
+        ref_inp, dim=dim, src=ref_src, start=start, end=end, step=step
+    )
+
+    res_out = flag_gems.ops.slice_scatter(
+        inp, dim=dim, src=src, start=start, end=end, step=step
+    )
+
+    gems_assert_equal(res_out, ref_out)
+
+
 # TODO: failed at (200, 40999, 3)
 @pytest.mark.index_select
 @pytest.mark.parametrize("shape", REDUCTION_SHAPES)