Skip to content

Commit

Permalink
add default scheduling params
Browse files Browse the repository at this point in the history
Signed-off-by: Ivan Butygin <[email protected]>
  • Loading branch information
Hardcode84 committed Nov 26, 2024
1 parent 99bc0a2 commit 2874ca5
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 62 deletions.
18 changes: 18 additions & 0 deletions iree/turbine/kernel/wave/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,24 @@ def get_default_compile_config() -> dict[Any, Any]:
return {"backend": "rocm", "device": "hip", "target": "gfx942"}


def get_default_scheduling_params() -> dict[IndexSymbol, Any]:
    """Return default hyperparameters for the wave scheduler.

    Maps scheduling-related index symbols — per-instruction-class delays and
    available execution-resource unit counts — to fixed default values.

    Returns:
        A dict from each scheduling `IndexSymbol` to its default value.
    """
    # TODO: derive these values from get_default_arch() instead of hardcoding.
    delays = {
        READ_SHARED_DELAY: 1,
        WRITE_SHARED_DELAY: 1,
        READ_GLOBAL_DELAY: 2,
        WRITE_GLOBAL_DELAY: 2,
        MMA_DELAY: 1,
        VALU_DELAY: 1,
        SHUFFLE_DELAY: 1,
    }
    resource_units = {
        SHARED_MEMORY_UNITS: 4,
        GLOBAL_MEMORY_UNITS: 4,
        MMA_UNITS: 4,
        VALU_UNITS: 2,
        SHUFFLE_UNITS: 2,
    }
    return {**delays, **resource_units}


def print_trace(trace: CapturedTrace, custom_print: bool = True):
"""
Prints all subgraphs of a trace starting with the root graph.
Expand Down
45 changes: 5 additions & 40 deletions tests/kernel/wave/wave_attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from iree.turbine.kernel.wave.utils import (
get_default_run_config,
get_default_arch,
get_default_scheduling_params,
get_mfma_load_elems_per_thread,
get_mfma_store_elems_per_thread,
device_randn,
Expand Down Expand Up @@ -165,15 +166,8 @@ def repeat(
N: shape[2],
K1: shape[3],
K2: shape[4],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
}
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
if run_bench:
config["benchmark_batch_size"] = 10
Expand Down Expand Up @@ -314,15 +308,8 @@ def repeat(
N: shape[2],
K1: shape[3],
K2: shape[4],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
}
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
if run_bench:
config["benchmark_batch_size"] = 10
Expand Down Expand Up @@ -491,19 +478,8 @@ def repeat(
N: shape[2],
K1: shape[3],
K2: shape[4],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 2,
SHUFFLE_UNITS: 2,
}
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
if run_bench:
config["benchmark_batch_size"] = 10
Expand Down Expand Up @@ -682,19 +658,8 @@ def repeat(
N: shape[2],
K1: shape[3],
K2: shape[4],
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
VALU_DELAY: 1,
SHUFFLE_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
VALU_UNITS: 2,
SHUFFLE_UNITS: 2,
}
hyperparams.update(get_default_scheduling_params())
config = get_default_run_config()
if run_bench:
config["benchmark_batch_size"] = 10
Expand Down
40 changes: 18 additions & 22 deletions tests/kernel/wave/wave_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from iree.turbine.kernel.wave.utils import (
get_default_arch,
get_default_run_config,
get_default_scheduling_params,
device_randn,
device_randint,
device_randperm,
Expand Down Expand Up @@ -994,6 +995,22 @@ def test_igemm_conv(n, h, w, c, hf, wf, nf, stride, mem_space, layout, request):
)
]

hyperparams = {
N: n,
C: c,
W: w,
H: h,
NF: nf,
WF: wf,
HF: hf,
BLOCK_M: 64,
BLOCK_N: 128,
BLOCK_K: 32,
ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: mem_space,
}
hyperparams.update(get_default_scheduling_params())

@tkw.wave(constraints)
def conv(
x: x_type,
Expand Down Expand Up @@ -1037,28 +1054,7 @@ def repeat(acc: tkl.Register[M, NF, tkl.f32]) -> tkl.Register[M, NF, tkl.f32]:
)

with tk.gen.TestLaunchContext(
{
N: n,
C: c,
W: w,
H: h,
NF: nf,
WF: wf,
HF: hf,
BLOCK_M: 64,
BLOCK_N: 128,
BLOCK_K: 32,
ELEMS_PER_THREAD: 4,
ADDRESS_SPACE: mem_space,
READ_SHARED_DELAY: 1,
WRITE_SHARED_DELAY: 1,
READ_GLOBAL_DELAY: 2,
WRITE_GLOBAL_DELAY: 2,
MMA_DELAY: 1,
SHARED_MEMORY_UNITS: 4,
GLOBAL_MEMORY_UNITS: 4,
MMA_UNITS: 4,
},
hyperparams,
canonicalize=True,
run=True,
run_bench=run_bench,
Expand Down

0 comments on commit 2874ca5

Please sign in to comment.