
[Multi backend] add heuristics_config for multi_backend (FlagOpen#377)
* muti_backend

* [Fix bugs] Variable reference dependency error (FlagOpen#374)

* fix benchmark bugs

* fix

* [bugfix] disable more shapes from blas benchmark (FlagOpen#375)

* add heuristics_config

* modify

* modify

* modify

* modify

---------

Co-authored-by: StrongSpoon <[email protected]>
Galaxy1458 and StrongSpoon authored Dec 26, 2024
1 parent 3183842 commit 101227f
Showing 45 changed files with 377 additions and 366 deletions.
2 changes: 1 addition & 1 deletion src/flag_gems/ops/addmm.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("addmm"),
+    configs=runtime.get_tuned_config("addmm"),
     key=["M", "N", "K"],
 )
 @triton.jit(do_not_specialize=["alpha", "beta"])
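
The change is uniform across these files: autotune candidates are now fetched through runtime.get_tuned_config(op_name) instead of runtime.get_triton_config(op_name), so each backend can supply its own tuning table. A minimal sketch of what such an accessor could look like, assuming configs are stored per op as plain dictionaries (the table name and layout here are illustrative, not the actual FlagGems runtime internals):

import triton

# Hypothetical per-backend tuning table, e.g. loaded from a vendor-specific config file.
_TUNED_CONFIGS = {
    "addmm": [
        {"meta": {"BLOCK_M": 64, "BLOCK_N": 64, "BLOCK_K": 32}, "num_warps": 4, "num_stages": 2},
        {"meta": {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 32}, "num_warps": 8, "num_stages": 3},
    ],
}

def get_tuned_config(op_name):
    # Wrap each stored entry in a triton.Config so @triton.autotune can benchmark it.
    return [
        triton.Config(entry["meta"], num_warps=entry["num_warps"], num_stages=entry["num_stages"])
        for entry in _TUNED_CONFIGS.get(op_name, [])
    ]
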
2 changes: 1 addition & 1 deletion src/flag_gems/ops/all.py
@@ -21,7 +21,7 @@ def reduce_all(a, b):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("all"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("all"), key=["M", "N"])
 @triton.jit
 def all_kernel_dim(
     inp,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/amax.py
@@ -42,7 +42,7 @@ def amax_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("amax"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("amax"), key=["M", "N"])
 @triton.jit
 def amax_kernel(
     inp,
2 changes: 1 addition & 1 deletion src/flag_gems/ops/any.py
@@ -21,7 +21,7 @@ def reduce_any(a, b):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("any"), key=["M", "N"])
+@triton.autotune(configs=runtime.get_tuned_config("any"), key=["M", "N"])
 @triton.jit
 def any_kernel_dim(
     inp,
16 changes: 2 additions & 14 deletions src/flag_gems/ops/argmax.py
@@ -5,6 +5,7 @@
 import triton
 import triton.language as tl
 
+from .. import runtime
 from ..runtime import torch_device_fn
 from ..utils import libentry
 from ..utils import triton_lang_extension as tle
@@ -45,21 +46,8 @@ def argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr
     tl.store(out, out_val)
 
 
-def heur_block_m(args):
-    return 4 if args["M"] < 4096 else 8
-
-
-def heur_block_n(args):
-    return min(4096, triton.next_power_of_2(args["N"]))
-
-
 @libentry()
-@triton.heuristics(
-    {
-        "BLOCK_M": heur_block_m,
-        "BLOCK_N": heur_block_n,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("argmax"))
 @triton.jit
 def argmax_kernel(
     inp,
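
The second half of the change: heuristics that used to be defined inline next to each kernel (heur_block_m / heur_block_n above) are now fetched with runtime.get_heuristic_config(op_name), again so that backends can override them. A plausible backend-side entry for "argmax", reproducing exactly the heuristics this hunk deletes (only the dictionary form and variable name are assumptions; the lambdas mirror the removed code):

import triton

# Illustrative "argmax" entry: the same rules as the deleted heur_block_m /
# heur_block_n, keyed by the kernel's constexpr parameter names.
ARGMAX_HEURISTICS = {
    "BLOCK_M": lambda args: 4 if args["M"] < 4096 else 8,
    "BLOCK_N": lambda args: min(4096, triton.next_power_of_2(args["N"])),
}
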
2 changes: 1 addition & 1 deletion src/flag_gems/ops/attention.py
@@ -123,7 +123,7 @@ def early_config_prune(configs, nargs, **kwargs):
 
 
 @triton.autotune(
-    configs=runtime.get_triton_config("attention"),
+    configs=runtime.get_tuned_config("attention"),
     key=["KV_CTX", "HEAD_DIM"],
     prune_configs_by={
         "early_config_prune": early_config_prune,
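
The attention kernel additionally prunes its candidate list before benchmarking via prune_configs_by. The hook's signature is visible in the context line above; the body below is a made-up example of such a pruning rule, not the one attention.py actually uses:

# Illustrative pruning hook with the signature shown in this diff; the real
# early_config_prune in attention.py applies its own, hardware-aware rules.
def early_config_prune(configs, nargs, **kwargs):
    head_dim = nargs["HEAD_DIM"]
    # Assumed example rule: drop candidates whose BLOCK_N exceeds the head dimension.
    pruned = [cfg for cfg in configs if cfg.kwargs.get("BLOCK_N", 0) <= head_dim]
    return pruned or configs
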
22 changes: 2 additions & 20 deletions src/flag_gems/ops/bmm.py
@@ -10,30 +10,12 @@
 from ..utils import triton_lang_extension as tle
 
 
-def heur_divisible_m(args):
-    return args["M"] % args["TILE_M"] == 0
-
-
-def heur_divisible_n(args):
-    return args["N"] % args["TILE_N"] == 0
-
-
-def heur_divisible_k(args):
-    return args["K"] % args["TILE_K"] == 0
-
-
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("bmm"),
+    configs=runtime.get_tuned_config("bmm"),
     key=["M", "N", "K"],
 )
-@triton.heuristics(
-    {
-        "DIVISIBLE_M": heur_divisible_m,
-        "DIVISIBLE_N": heur_divisible_n,
-        "DIVISIBLE_K": heur_divisible_k,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("bmm"))
 @triton.jit
 def bmm_kernel(
     A,
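
For bmm the relocated heuristics are divisibility flags rather than block sizes: they tell the kernel when a dimension is an exact multiple of its tile so boundary masking can be skipped. An illustrative backend-side entry built from the deleted helpers (only the dictionary form is assumed):

# Illustrative "bmm" entry built from the deleted heur_divisible_* helpers.
BMM_HEURISTICS = {
    "DIVISIBLE_M": lambda args: args["M"] % args["TILE_M"] == 0,
    "DIVISIBLE_N": lambda args: args["N"] % args["TILE_N"] == 0,
    "DIVISIBLE_K": lambda args: args["K"] % args["TILE_K"] == 0,
}
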
4 changes: 2 additions & 2 deletions src/flag_gems/ops/count_nonzero.py
@@ -23,7 +23,7 @@ def count_nonzero_kernel_1(x_ptr, out_ptr, numel, BLOCK_SIZE: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
@@ -41,7 +41,7 @@ def count_nonzero_kernel(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
 
 
 @libentry()
-@triton.autotune(configs=runtime.get_triton_config("count_nonzero"), key=["numel"])
+@triton.autotune(configs=runtime.get_tuned_config("count_nonzero"), key=["numel"])
 @triton.jit
 def count_nonzero_combin_kernel_1(x_ptr, out_ptr, N, numel, BLOCK_SIZE: tl.constexpr):
     pid_x = tle.program_id(0)
14 changes: 7 additions & 7 deletions src/flag_gems/ops/cross_entropy_loss.py
@@ -12,7 +12,7 @@
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index"])
@@ -73,7 +73,7 @@ def celoss_indices_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing"])
@@ -132,7 +132,7 @@ def celoss_probability_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing"])
@@ -213,7 +213,7 @@ def celoss_indices_smooth_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "mean_num"])
@@ -284,7 +284,7 @@ def celoss_indices_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["label_smoothing", "mean_num"])
@@ -369,7 +369,7 @@ def celoss_probability_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss"),
+    configs=runtime.get_tuned_config("cross_entropy_loss"),
     key=["C", "D"],
 )
 @triton.jit(do_not_specialize=["ignore_index", "label_smoothing", "mean_num"])
@@ -463,7 +463,7 @@ def celoss_indices_smooth_bwd(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("cross_entropy_loss_sum_and_scale"),
+    configs=runtime.get_tuned_config("cross_entropy_loss_sum_and_scale"),
     key=[
         "N",
     ],
31 changes: 3 additions & 28 deletions src/flag_gems/ops/dropout.py
@@ -9,31 +9,11 @@
     uint_to_uniform_float,
 )
 
+from .. import runtime
 from ..runtime import torch_device_fn
 
 
-def heur_block(args):
-    if args["N"] <= 512:
-        return 512
-    else:
-        return 1024
-
-
-def heur_num_warps(args):
-    if args["N"] <= 512:
-        return 4
-    elif args["N"] <= 1024:
-        return 8
-    else:
-        return 16
-
-
-@triton.heuristics(
-    {
-        "BLOCK": heur_block,
-        "num_warps": heur_num_warps,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_forward_kernel(
     X,
@@ -85,12 +65,7 @@ def dropout_forward_kernel(
     tl.store(Y + off_3, y3, mask=off_3 < N, eviction_policy="evict_first")
 
 
-@triton.heuristics(
-    {
-        "BLOCK": heur_block,
-        "num_warps": heur_num_warps,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("dropout"))
 @triton.jit(do_not_specialize=["p", "philox_seed", "philox_offset"])
 def dropout_backward_kernel(
     DY,
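
dropout (and exponential_ below) shows that these heuristic tables can drive launch parameters such as num_warps as well as tl.constexpr block sizes. An illustrative entry equivalent to the removed heur_block / heur_num_warps (again, only the dictionary form is assumed):

# Illustrative "dropout" entry: block size and warp count both scale with N,
# exactly as in the deleted heur_block / heur_num_warps.
DROPOUT_HEURISTICS = {
    "BLOCK": lambda args: 512 if args["N"] <= 512 else 1024,
    "num_warps": lambda args: 4 if args["N"] <= 512 else (8 if args["N"] <= 1024 else 16),
}
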
24 changes: 2 additions & 22 deletions src/flag_gems/ops/exponential_.py
@@ -9,31 +9,11 @@
     uint_to_uniform_float,
 )
 
+from .. import runtime
 from ..runtime import torch_device_fn
 
 
-def heur_block(args):
-    if args["N"] <= 512:
-        return 512
-    else:
-        return 1024
-
-
-def heur_num_warps(args):
-    if args["N"] <= 512:
-        return 4
-    elif args["N"] <= 1024:
-        return 8
-    else:
-        return 16
-
-
-@triton.heuristics(
-    {
-        "BLOCK": heur_block,
-        "num_warps": heur_num_warps,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("exponential_"))
 @triton.jit(do_not_specialize=["philox_seed", "philox_offset", "N"])
 def fused_exponential_kernel(
     out_ptr,
20 changes: 1 addition & 19 deletions src/flag_gems/ops/gather.py
@@ -32,29 +32,11 @@ def generate_gather_kernel(
     # make the inlined function visible in the context
     code.newline()
 
-    code.writeline("def heur_block_m(args):")
-    with code.indent():
-        code.writeline(
-            "return min(4, triton.next_power_of_2(triton.cdiv(args['N'], 2048)))"
-        )
-
-    code.newline()
-    code.writeline("def heur_block_n(args):")
-    with code.indent():
-        code.writeline("return min(2048, triton.next_power_of_2(args['N']))")
-
-    code.newline()
-    code.newline()
-
     # the decorators
     code.writeline("@libentry()")
     code.writeline("@triton.heuristics(")
     with code.indent():
-        code.writeline("{")
-        with code.indent():
-            code.writeline('"BLOCK_M": heur_block_m,')
-            code.writeline('"BLOCK_N": heur_block_n,')
-        code.writeline("}")
+        code.writeline("runtime.get_heuristic_config('gather')")
     code.writeline(")")
     code.writeline("@triton.jit")
 
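
gather.py generates its kernel source at runtime through a code writer, so the same change lands one level up: the generator now emits a call to runtime.get_heuristic_config instead of writing out a literal heuristics dict. Assuming the writer's indentation behaves as the surrounding calls suggest, the emitted decorator text would look roughly like this (the generated module is presumably given access to runtime elsewhere in the template):

@libentry()
@triton.heuristics(
    runtime.get_heuristic_config('gather')
)
@triton.jit
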
17 changes: 2 additions & 15 deletions src/flag_gems/ops/index_select.py
@@ -4,26 +4,13 @@
 import triton
 import triton.language as tl
 
+from .. import runtime
 from ..utils import dim_compress, libentry
 from ..utils import triton_lang_extension as tle
 
 
-def heur_block_m(args):
-    return min(4, triton.next_power_of_2(triton.cdiv(256, args["N"])))
-
-
-def heur_block_n(args):
-    m = min(triton.next_power_of_2(triton.cdiv(args["N"], 16)), 512)
-    return max(m, 16)
-
-
 @libentry()
-@triton.heuristics(
-    {
-        "BLOCK_M": heur_block_m,
-        "BLOCK_N": heur_block_n,
-    }
-)
+@triton.heuristics(runtime.get_heuristic_config("index_select"))
 @triton.jit
 def index_select_kernel(
     inp, out, M, N, index, index_len, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
12 changes: 6 additions & 6 deletions src/flag_gems/ops/instancenorm.py
@@ -22,7 +22,7 @@ def prev_multiple_of(a, b):
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -72,7 +72,7 @@ def instance_norm_persistent_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -126,7 +126,7 @@ def instance_norm_persistent_kernel_multiline(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_loop"),
+    configs=runtime.get_tuned_config("instance_norm_loop"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -218,7 +218,7 @@ def instance_norm_loop_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instancenorm"),
+    configs=runtime.get_tuned_config("instancenorm"),
     key=["M", "N"],
 )
 @triton.jit(do_not_specialize=["eps"])
@@ -312,7 +312,7 @@ def update_running_stats_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_backward"),
+    configs=runtime.get_tuned_config("instance_norm_backward"),
     key=["M", "N", "C"],
 )
 @triton.jit
@@ -379,7 +379,7 @@ def instance_norm_backward_kernel(
 
 @libentry()
 @triton.autotune(
-    configs=runtime.get_triton_config("instance_norm_weight_bias_backward"),
+    configs=runtime.get_tuned_config("instance_norm_weight_bias_backward"),
     key=["N", "B", "C"],
 )
 @triton.jit
(The remaining 31 changed files are not shown here.)
