Commit: add argmin (FlagOpen#318) (FlagOpen#346)
* add argmin

* support int dtype and int64 index

* add test for dim=None

* rebase master

---------

Co-authored-by: wuyangjun <[email protected]>
wyjoutstanding and wuyangjun authored Jan 10, 2025
1 parent ac17cd1 commit bb358f3
Showing 8 changed files with 214 additions and 0 deletions.
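
For orientation, this is what the new operator enables once registered; a minimal usage sketch (mirroring the test added at the bottom of this diff, and assuming flag_gems is installed with a supported device):

import torch
import flag_gems

x = torch.randn(1024, 512, dtype=torch.float16, device=flag_gems.device)

with flag_gems.use_gems():          # dispatches torch.argmin to the Triton kernels added here
    idx = torch.argmin(x, dim=1, keepdim=True)

print(idx.shape, idx.dtype)         # torch.Size([1024, 1]), torch.int64
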
1 change: 1 addition & 0 deletions OperatorList.md
@@ -51,6 +51,7 @@
- tanh
- amax
- argmax
- argmin
- max
- min
- outer
1 change: 1 addition & 0 deletions benchmark/test_reduction_perf.py
@@ -53,6 +53,7 @@ def get_input_iter(self, cur_dtype) -> Generator:
("amax", torch.amax, FLOAT_DTYPES),
("any", torch.any, FLOAT_DTYPES),
("argmax", torch.argmax, FLOAT_DTYPES),
("argmin", torch.argmin, FLOAT_DTYPES),
("max", torch.max, FLOAT_DTYPES),
("mean", torch.mean, FLOAT_DTYPES),
("min", torch.min, FLOAT_DTYPES),
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -143,6 +143,7 @@ def enable(lib=aten_lib, unused=None, registrar=registrar):
("min.dim", min_dim, Autograd.disable),
("amax", amax, Autograd.disable),
("argmax", argmax, Autograd.disable),
("argmin", argmin, Autograd.disable),
("prod", prod, Autograd.disable),
("prod.dim_int", prod_dim, Autograd.disable),
("sum", sum, Autograd.disable),
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -6,6 +6,7 @@
from .any import any, any_dim, any_dims
from .arange import arange, arange_start
from .argmax import argmax
from .argmin import argmin
from .attention import scaled_dot_product_attention
from .batch_norm import batch_norm
from .bitwise_and import (
@@ -249,6 +250,7 @@
"sum_dim",
"amax",
"argmax",
"argmin",
"prod",
"prod_dim",
"var_mean",
168 changes: 168 additions & 0 deletions src/flag_gems/ops/argmin.py
@@ -0,0 +1,168 @@
import logging
import math

import torch
import triton
import triton.language as tl

from .. import runtime
from ..runtime import torch_device_fn
from ..utils import libentry
from ..utils import triton_lang_extension as tle

torch_dtype_to_tl_dtype_and_max_value = {
torch.int16: (tl.int16, torch.iinfo(torch.int16).max),
torch.int32: (tl.int32, torch.iinfo(torch.int32).max),
torch.float16: (tl.float16, torch.finfo(torch.float16).max),
torch.float32: (tl.float32, torch.finfo(torch.float32).max),
torch.bfloat16: (tl.float32, torch.finfo(torch.float32).max),
}


@libentry()
@triton.jit
def argmin_kernel_1(
inp,
mid_value,
mid_index,
M,
BLOCK_SIZE: tl.constexpr,
):
pid = tle.program_id(0)
offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
inp_ptrs = inp + offset
mask = offset < M
inp_val = tl.load(inp_ptrs, mask=mask, other=float("inf"))
min_val, min_index = tl.min(inp_val, axis=0, return_indices=True)
min_index = min_index + pid * BLOCK_SIZE
mid_value_ptr = mid_value + pid
min_index_ptr = mid_index + pid
tl.store(mid_value_ptr, min_val)
tl.store(min_index_ptr, min_index)


@libentry()
@triton.jit
def argmin_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):
offset = tl.arange(0, BLOCK_MID)
mid_ptrs = mid_value + offset
mask = offset < mid_size
mid_val = tl.load(mid_ptrs, mask=mask, other=float("inf"))
index_val = tl.argmin(mid_val, axis=0)
mid_index_ptrs = mid_index + index_val
out_val = tl.load(mid_index_ptrs)
tl.store(out, out_val)


def heur_block_n(args):
return min(4096, triton.next_power_of_2(args["N"]))


@libentry()
@triton.heuristics(runtime.get_heuristic_config("argmin"))
@triton.jit
def argmin_kernel(
inp,
out_index,
M,
N,
K,
tl_dtype: tl.constexpr,
dtype_max_value: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
):
# set offset
pid_m = tle.program_id(0)
pid_k = tle.program_id(1)
m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)

# min_values = tl.full([BLOCK_M], dtype=tl.float32, value=float("inf"))
min_values = tl.full([BLOCK_M], dtype=tl_dtype, value=dtype_max_value)
argmin_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)
for start_n in range(0, N, BLOCK_N):
n_offset = start_n + tl.arange(0, BLOCK_N)
offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k
mask = m_offset[:, None] < M and n_offset[None, :] < N
inp_ptrs = inp + offset
# inp_vals = tl.load(inp_ptrs, mask=mask, other=float("inf"))
inp_vals = tl.load(inp_ptrs, mask=mask, other=dtype_max_value)
local_min, local_argmin = tl.min(
inp_vals, 1, return_indices=True, return_indices_tie_break_left=True
)
# if return indices is not supported, call a tl.argmin in addition
# local_argmin = tl.argmin(inp_vals, 1)
update = local_min < min_values
min_values = tl.where(update, local_min, min_values)
argmin_values = tl.where(update, start_n + local_argmin, argmin_values)

offset_index = m_offset * K + pid_k
out_index_ptrs = out_index + offset_index
mask1 = m_offset < M
tl.store(out_index_ptrs, argmin_values, mask=mask1)


def argmin(inp, dim=None, keepdim=False, *, dtype=None):
logging.debug("GEMS argmin")
if dim is None:
M = inp.numel()
if dtype is None:
dtype = inp.dtype
block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))
mid_size = triton.cdiv(M, block_size)
block_mid = triton.next_power_of_2(mid_size)

mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)
mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)
if keepdim:
shape = list(inp.shape)
for i in range(0, inp.dim()):
shape[i] = 1
out = torch.empty(shape, dtype=torch.int64, device=inp.device)
else:
out = torch.empty([], dtype=torch.int64, device=inp.device)

with torch_device_fn.device(inp.device):
argmin_kernel_1[(mid_size, 1, 1)](
inp,
mid_value,
mid_index,
M,
block_size,
)
argmin_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)
return out
else:
assert dim >= -inp.ndim and dim < inp.ndim, "Invalid dim"
shape = inp.shape
dim = dim % inp.ndim
N = shape[dim]
M = math.prod(shape[:dim])
K = inp.numel() // M // N

inp = inp.contiguous()

shape_list = list(shape)
shape_list[dim] = 1
out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)
if not keepdim:
out_index = torch.squeeze(out_index, dim)

tl_dtype, dtype_max_value = torch_dtype_to_tl_dtype_and_max_value[inp.dtype]

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_M"]),
K,
)
with torch_device_fn.device(inp.device):
argmin_kernel[grid](
inp,
out_index,
M,
N,
K,
tl_dtype,
dtype_max_value,
)

return out_index
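
As a reading aid (not part of the commit), the dim=None branch above is a two-stage reduction: argmin_kernel_1 writes one (min value, flat index) pair per block, and argmin_kernel_2 picks the winning block. A plain-PyTorch sketch of the same idea, assuming a float input:

import math
import torch

def argmin_two_stage(inp):
    flat = inp.contiguous().view(-1)
    M = flat.numel()
    # roughly next_power_of_2(sqrt(M)), as in the launcher above
    block_size = 1 << (max(1, math.isqrt(M)) - 1).bit_length()
    pad = (-M) % block_size
    padded = torch.cat([flat, flat.new_full((pad,), float("inf"))])  # same role as other=inf
    blocks = padded.view(-1, block_size)
    mid_value, local_idx = blocks.min(dim=1)                         # stage 1: per-block min
    mid_index = local_idx + torch.arange(blocks.shape[0], device=flat.device) * block_size
    return mid_index[torch.argmin(mid_value)]                        # stage 2: winning block

x = torch.randn(10_000)
assert argmin_two_stage(x) == torch.argmin(x)
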
12 changes: 12 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/heuristics_config_utils.py
@@ -10,6 +10,14 @@ def argmax_heur_block_n(args):
return min(4096, triton.next_power_of_2(args["N"]))


def argmin_heur_block_m(args):
return 4 if args["M"] < 4096 else 8


def argmin_heur_block_n(args):
return min(4096, triton.next_power_of_2(args["N"]))


def bmm_heur_divisible_m(args):
return args["M"] % args["TILE_M"] == 0

@@ -211,6 +219,10 @@ def batch_norm_heur_block_n(args):
"BLOCK_M": argmax_heur_block_m,
"BLOCK_N": argmax_heur_block_n,
},
"argmin": {
"BLOCK_M": argmin_heur_block_m,
"BLOCK_N": argmin_heur_block_n,
},
"bmm": {
"DIVISIBLE_M": bmm_heur_divisible_m,
"DIVISIBLE_N": bmm_heur_divisible_n,
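
The two heuristics above mirror the existing argmax ones: BLOCK_M is picked from the number of output rows M and BLOCK_N from the reduced length N, capped at 4096. A standalone check with illustrative shapes:

import triton

def argmin_heur_block_m(args):
    return 4 if args["M"] < 4096 else 8

def argmin_heur_block_n(args):
    return min(4096, triton.next_power_of_2(args["N"]))

print(argmin_heur_block_m({"M": 1024}), argmin_heur_block_n({"N": 1000}))   # 4 1024
print(argmin_heur_block_m({"M": 8192}), argmin_heur_block_n({"N": 10000}))  # 8 4096
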
10 changes: 10 additions & 0 deletions src/flag_gems/runtime/backend/_nvidia/tune_configs.yaml
@@ -110,6 +110,16 @@ bmm:
num_warps: 4
num_stages: 3
argmax:
- META:
BLOCK_M: 8
num_warps: 8
- META:
BLOCK_M: 16
num_warps: 8
- META:
BLOCK_M: 32
num_warps: 8
argmin:
- META:
BLOCK_M: 8
num_warps: 8
19 changes: 19 additions & 0 deletions tests/test_reduction_ops.py
@@ -95,6 +95,25 @@ def test_accuracy_argmax(shape, dim, keepdim, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.argmin
@pytest.mark.parametrize("shape", REDUCTION_SMALL_SHAPES)
@pytest.mark.parametrize("dim", DIM_LIST + [None])
@pytest.mark.parametrize("keepdim", [True, False])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES)
def test_accuracy_argmin(shape, dim, keepdim, dtype):
if dtype in INT_DTYPES:
inp = torch.randint(-1024, 1024, size=shape, device=flag_gems.device).to(dtype)
else:
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)
ref_inp = to_reference(inp)

ref_out = torch.argmin(ref_inp, dim=dim, keepdim=keepdim)
with flag_gems.use_gems():
res_out = torch.argmin(inp, dim=dim, keepdim=keepdim)

gems_assert_equal(res_out, ref_out)


@pytest.mark.CrossEntropyLoss
@pytest.mark.parametrize("label_smoothing, ignore_index, shape", SMOOTH_IGNORE_SHAPE)
@pytest.mark.parametrize("reduction", CROSS_ENTROPY_LOSS_REDUCTION)
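
The INT_DTYPES parametrization above is what exercises the "support int dtype and int64 index" bullet from the commit message; a minimal sketch of that path, under the same assumptions as the test (flag_gems installed, supported device):

import torch
import flag_gems

x = torch.randint(-1024, 1024, size=(64, 1000), device=flag_gems.device).to(torch.int32)

with flag_gems.use_gems():
    idx = torch.argmin(x, dim=1)    # int32 values reduced with an int32 max sentinel

print(idx.dtype)                    # torch.int64, regardless of the input dtype
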
