diff --git a/.github/workflows/ci-tk.yaml b/.github/workflows/ci-tk.yaml
index fafbf355..bd23693d 100644
--- a/.github/workflows/ci-tk.yaml
+++ b/.github/workflows/ci-tk.yaml
@@ -57,13 +57,20 @@ jobs:
         run: |
           pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/

-      - name: Run e2e tests on AMD GPU
-        if: "contains(matrix.os, 'amdgpu') && !cancelled()"
+      - name: Run e2e tests on AMD GPU MI300
+        if: "contains(matrix.os, 'mi300') && !cancelled()"
         run: |
           pip install --no-compile -r pytorch-rocm-requirements.txt
           export WAVE_RUN_E2E_TESTS=1
           pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/

+      - name: Run e2e tests on AMD GPU MI250
+        if: "contains(matrix.os, 'mi250') && !cancelled()"
+        run: |
+          pip install --no-compile -r pytorch-rocm-requirements.txt
+          export WAVE_RUN_E2E_TESTS=1
+          pytest -n 2 --capture=tee-sys -vv ./tests/kernel/wave/
+
       - name: Run LIT tests
         if: ${{ !cancelled() }}
         run: |
diff --git a/iree/turbine/kernel/wave/hoisting.py b/iree/turbine/kernel/wave/hoisting.py
index 5a4773d7..a2efbee4 100644
--- a/iree/turbine/kernel/wave/hoisting.py
+++ b/iree/turbine/kernel/wave/hoisting.py
@@ -4,6 +4,8 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+from .constraints import Constraint
+from .utils import get_induction_variable
 from ...support.logging import get_logger
 from iree.turbine.kernel._support.tracing import CapturedTrace
 import torch.fx as fx
@@ -13,25 +15,87 @@
 logger = get_logger("turbine.wave.hoisting")


-def get_allocs(graph: fx.Graph) -> list[CustomOp]:
-    return [
-        custom_node
-        for node in graph.nodes
-        if isinstance((custom_node := get_custom(node)), Allocate)
-    ]
+def get_hoistable_ops(
+    graph: fx.Graph,
+    captured_vars: list[CustomOp],
+    induction_variable: IndexExpr,
+) -> list[CustomOp]:
+    """
+    Get hoistable ops. Currently we only handle allocs and reads that do not
+    depend on the induction variable.
+
+    Note: For codegen to work properly, allocs need to be hoisted first. This
+    avoids use-before-definition/non-dominating behavior
+    (e.g. hoisting a read from global to shared before the shared alloc is defined).
+    """
+    hoistable_allocs = []
+    hoistable_ops = []
+    for node in graph.nodes:
+        custom_node = get_custom(node)
+        if isinstance(custom_node, Allocate):
+            hoistable_allocs.append(custom_node)
+        elif isinstance(custom_node, Read):
+            if custom_node.index is None:
+                continue
+            # Only handle the case where the memory is a captured var,
+            # i.e. it has a source value from the root graph.
+            if not custom_node.memory in captured_vars:
+                continue
+            # Only handle the case where we are not writing to the same memory.
+            # Counterexample: we may see different reads if we write to the same memory.
+            if any(
+                isinstance(get_custom(mem_user), Write)
+                for mem_user in custom_node.memory.users
+            ):
+                continue
+            # Only hoist reads that are loop invariant.
+            if any(
+                ind.start.has(induction_variable) for ind in custom_node.index.values()
+            ):
+                continue
+            hoistable_ops.append(custom_node)
+        else:
+            continue
+    all_hoistables_ops = hoistable_allocs + hoistable_ops
+    return all_hoistables_ops
+
+
+def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph):
+    captured_vars = reduction.captured_vars(subgraph)
+    new_implicit_captures = list(reduction.implicit_captures)
+    for captured_idx in reversed(range(len(captured_vars))):
+        if len(captured_vars[captured_idx].users) == 0:
+            get_custom(captured_vars[captured_idx]).erase()
+            new_implicit_captures.pop(captured_idx)
+    reduction.update_arg("implicit_captures", new_implicit_captures)


-def hoist_allocs(trace: CapturedTrace):
-    """Hoists allocs from reduction subgraphs to outer root graph."""
+def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]):
+    """Hoists loop-invariant ops from reduction subgraphs to the outer root graph."""
     root_graph = trace.get_root_graph()
     for node in root_graph.nodes:
         custom_node = get_custom(node)
         match custom_node:
             case Reduction():
                 with root_graph.inserting_before(custom_node.fx_node):
+                    induction_variable = get_induction_variable(
+                        custom_node, constraints
+                    )
                     subgraph = trace.get_subgraph(custom_node.subgraph_name)
-                    allocs = get_allocs(subgraph)
-                    for alloc in allocs:
-                        new_alloc = alloc.copy(new_graph=root_graph)
-                        alloc.replace_all_uses_with(new_alloc)
-                        alloc.erase()
+                    # Captured/root variables from outside the loop.
+                    implicit_captures = custom_node.implicit_captures
+                    # Captured variables from inside the loop.
+                    captured_vars = custom_node.captured_vars(subgraph)
+                    hoistable_ops = get_hoistable_ops(
+                        subgraph, captured_vars, induction_variable
+                    )
+                    for hoistable_op in hoistable_ops:
+                        new_op = hoistable_op.copy(new_graph=root_graph)
+                        hoistable_op.replace_all_uses_with(new_op)
+                        hoistable_op.erase()
+                        if isinstance(hoistable_op, Read):
+                            capture_arg = captured_vars.index(hoistable_op.memory)
+                            new_op.update_arg("memory", implicit_captures[capture_arg])
+                    # Clear/remove unused captured vars for correct codegen; otherwise
+                    # ops inside the scf.for would index/load from the wrong bindings.
+                    remove_unused_captured_vars(custom_node, subgraph)
diff --git a/iree/turbine/kernel/wave/wave.py b/iree/turbine/kernel/wave/wave.py
index 63fb620d..aaa43e74 100644
--- a/iree/turbine/kernel/wave/wave.py
+++ b/iree/turbine/kernel/wave/wave.py
@@ -22,7 +22,7 @@
 from .codegen import WaveEmitter
 from .expansion import expand_graph
 from .promotion import promote_placeholders
-from .hoisting import hoist_allocs
+from .hoisting import hoist_loop_invariant_ops
 from .utils import (
     canonicalize_module,
     compile_and_invoke,
@@ -232,7 +232,6 @@ def _trace_and_get_kernel_signature(

         # Promote the placeholders to the appropriate address space.
         promote_placeholders(graph, self.constraints)
-        hoist_allocs(graph)

         # Set indices.
         set_node_indices(graph, self.constraints)
@@ -250,6 +249,7 @@
         remove_chained_getresult(graph)

         # Optimizations.
+        hoist_loop_invariant_ops(graph, self.constraints)
         minimize_global_loads(graph, self.constraints)

         # Apply shared memory indexing corrections.
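
# --- Illustration (not part of the patch above) -------------------------------
# The loop-invariance test in get_hoistable_ops boils down to checking whether
# any start index of a Read mentions the reduction's induction variable
# (`ind.start.has(induction_variable)`). Below is a minimal, runnable sketch of
# that check using plain sympy symbols; ARGK and T0 are hypothetical stand-ins,
# while the real pass inspects the index entries attached to each op.
import sympy

ARGK = sympy.Symbol("ARGK")  # induction variable of the reduction loop
T0 = sympy.Symbol("T0")      # thread-dependent, but loop-invariant, offset

invariant_starts = [T0 * 4, sympy.Integer(0)]
assert not any(s.has(ARGK) for s in invariant_starts)  # hoistable out of the loop

variant_starts = [ARGK * 32 + T0 * 4]
assert any(s.has(ARGK) for s in variant_starts)        # must stay inside the loop
# -------------------------------------------------------------------------------
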
diff --git a/lit_tests/kernel/wave/attention.py b/lit_tests/kernel/wave/attention.py
index 2c4d22a3..70753b79 100644
--- a/lit_tests/kernel/wave/attention.py
+++ b/lit_tests/kernel/wave/attention.py
@@ -154,15 +154,15 @@ def repeat(
         output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
         print(dynamic_attention_pipelined(q, k, v, output).module_op)

-        # CHECK: func.func @dynamic_attention_pipelined
-        # CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
+        # CHECK-LABEL: func.func @dynamic_attention_pipelined
+        # CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}}
         # CHECK: {{.*}} = scf.for
-        # CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
-        # CHECK-COUNT-13: {{.*}} = amdgpu.mfma
-        # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
-        # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
-        # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
-        # CHECK-COUNT-3: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}}
+        # CHECK-COUNT-14: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
+        # CHECK-COUNT-7: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}}
+        # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}
         # CHECK-COUNT-16: vector.maskedstore {{.*}}

@@ -281,13 +281,13 @@ def repeat(
         output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
         print(base_attention_pipelined(q, k, v, output).module_op)

-        # CHECK: func.func @base_attention_pipelined
+        # CHECK-LABEL: func.func @base_attention_pipelined
         # CHECK: {{.*}} = scf.for
-        # CHECK-COUNT-13: {{.*}} = amdgpu.mfma
-        # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
-        # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-14: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
+        # CHECK-COUNT-7: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
-        # CHECK-COUNT-3: {{.*}} = amdgpu.mfma
+        # CHECK-COUNT-2: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}

@@ -401,7 +401,7 @@ def repeat(
         output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
         print(base_attention_32x32x8(q, k, v, output).module_op)

-        # CHECK: func.func @base_attention_32x32x8
+        # CHECK-LABEL: func.func @base_attention_32x32x8
         # CHECK: {{.*}} = scf.for
         # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
@@ -524,7 +524,7 @@ def repeat(
         output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
         print(base_attention(q, k, v, output).module_op)

-        # CHECK: func.func @base_attention
+        # CHECK-LABEL: func.func @base_attention
         # CHECK: {{.*}} = scf.for
         # CHECK-COUNT-16: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}}
diff --git a/lit_tests/kernel/wave/barriers.py b/lit_tests/kernel/wave/barriers.py
index 6ddf48d5..78470c6d 100644
--- a/lit_tests/kernel/wave/barriers.py
+++ b/lit_tests/kernel/wave/barriers.py
@@ -12,7 +12,7 @@
 )
 from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
 from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
@@ -177,11 +177,11 @@ def test_gemm():
         read_nodes = get_read_nodes(graph)
         for read_node in read_nodes:
             promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
-        hoist_allocs(trace)
         set_node_indices(trace, constraints)
         expand_graph(trace, constraints)
         set_post_expansion_indices(trace, constraints)
         tweak_index(graph)
+        hoist_loop_invariant_ops(trace, constraints)
         add_shared_memory_barriers(trace)
         print_trace(trace, False)
         # Root graph:
diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index a2eb3ecc..80ef9a58 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -1004,7 +1004,7 @@ def test_chained_gemm():

     @tkw.wave(constraints)
     def chained_gemm(
-        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+        q: tkl.Memory[B, M, K1, ADDRESS_SPACE_0, tkl.f16],
         k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
         v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
         c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
@@ -1051,8 +1051,13 @@ def repeat(
         output = torch.zeros(8, 64, 128, dtype=torch.float32)
         print(chained_gemm(q, k, v, output).module_op)

-        # CHECK-LABEL: func.func @chained_gemm(
+        # CHECK-LABEL: func.func @chained_gemm
+        # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
+        # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x32x36xf16, #gpu.address_space<workgroup>>
+        # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
+        # CHECK-COUNT-4: vector.load %[[GLOBAL_0]]
         # CHECK: {{.*}} = scf.for
+        # CHECK-COUNT-4: {{.*}} = vector.load %[[ALLOC]]
         # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-4: {{.*}} = arith.truncf
         # CHECK-COUNT-8: {{.*}} = amdgpu.mfma
@@ -1131,7 +1136,10 @@ def repeat(
         output = torch.zeros(8, 64, 128, dtype=torch.float32)
         print(chained_gemm_32x32x8(q, k, v, output).module_op)

-        # CHECK-LABEL: func.func @chained_gemm_32x32x8(
+        # CHECK-LABEL: func.func @chained_gemm_32x32x8
+        # CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
+        # CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
+        # CHECK: %[[GLOBAL_READ_0:.+]] = vector.load %[[GLOBAL_0]]
         # CHECK: {{.*}} = scf.for
         # CHECK-COUNT-4: {{.*}} = amdgpu.mfma
         # CHECK-COUNT-1: {{.*}} = arith.truncf
diff --git a/lit_tests/kernel/wave/index_sequence_analysis.py b/lit_tests/kernel/wave/index_sequence_analysis.py
index 96409be8..76b8b63d 100644
--- a/lit_tests/kernel/wave/index_sequence_analysis.py
+++ b/lit_tests/kernel/wave/index_sequence_analysis.py
@@ -6,7 +6,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_placeholders
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
@@ -87,10 +87,10 @@ def test_gemm():
         IndexingContext.current().finalize()
         infer_types(trace)
         promote_placeholders(trace, constraints)
-        hoist_allocs(trace)
         set_node_indices(trace, constraints)
         expand_graph(trace, constraints)
         set_post_expansion_indices(trace, constraints)
+        hoist_loop_invariant_ops(trace, constraints)
         minimize_global_loads(trace, constraints)
         apply_shared_memory_indexing_corrections(trace, constraints)
         partition_strided_operators(trace, constraints)
diff --git a/lit_tests/kernel/wave/minimize_global_loads.py b/lit_tests/kernel/wave/minimize_global_loads.py
index 90053239..406c8172 100644
--- a/lit_tests/kernel/wave/minimize_global_loads.py
+++ b/lit_tests/kernel/wave/minimize_global_loads.py
@@ -7,7 +7,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_placeholders
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
 from iree.turbine.kernel.wave.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
@@ -90,12 +90,12 @@ def test_gemm():
         IndexingContext.current().finalize()
         infer_types(trace)
         promote_placeholders(trace, constraints)
-        hoist_allocs(trace)
         set_node_indices(trace, constraints)
         expand_graph(trace, constraints)
         set_post_expansion_indices(trace, constraints)
         if visualize:
             visualize_graph(trace.get_subgraph("region_0"), "before.png")
+        hoist_loop_invariant_ops(trace, constraints)
         minimize_global_loads(trace, constraints)
         apply_shared_memory_indexing_corrections(trace, constraints)
         if visualize:
diff --git a/lit_tests/kernel/wave/promotion.py b/lit_tests/kernel/wave/promotion.py
index 4b54ab30..925912bf 100644
--- a/lit_tests/kernel/wave/promotion.py
+++ b/lit_tests/kernel/wave/promotion.py
@@ -6,7 +6,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
 from iree.turbine.kernel._support.tracing import CapturedTrace
@@ -35,6 +35,9 @@ def get_read_nodes(graph: fx.Graph) -> list[CustomOp]:
 ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
 ADDRESS_SPACE_1 = tkl.sym.ADDRESS_SPACE_1

+# Induction variable for dimension K
+ARGK = tkl.sym.ARGK
+

 @tkw.wave_trace_only()
 def read_write_same_size(
@@ -160,6 +163,7 @@ def test_gemm():
     constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
     constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
     constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 2)]
+    constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
     constraints += [
         tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
     ]
@@ -177,7 +181,7 @@ def test_gemm():
         infer_types(trace)
         for read_node in read_nodes:
             promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
-        hoist_allocs(trace)
+        hoist_loop_invariant_ops(trace, constraints)
         print_trace(trace, False)
         # Root graph:
         # CHECK: %a
diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py
index 56b4c5d5..83f6053a 100644
--- a/lit_tests/kernel/wave/scheduling.py
+++ b/lit_tests/kernel/wave/scheduling.py
@@ -6,7 +6,7 @@
 import iree.turbine.kernel.lang as tkl
 import iree.turbine.kernel.wave as tkw
 from iree.turbine.kernel.wave.promotion import promote_placeholders
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.lang.global_symbols import *
@@ -99,10 +99,10 @@ def test_gemm_pipelined():
         IndexingContext.current().finalize()
         infer_types(trace)
         promote_placeholders(trace, constraints)
-        hoist_allocs(trace)
         set_node_indices(trace, constraints)
         expand_graph(trace, constraints)
         set_post_expansion_indices(trace, constraints)
+        hoist_loop_invariant_ops(trace, constraints)
         minimize_global_loads(trace, constraints)
         apply_shared_memory_indexing_corrections(trace, constraints)
         schedule_graph(trace, constraints, True)
diff --git a/tests/kernel/wave/scheduling_test.py b/tests/kernel/wave/scheduling_test.py
index b5ebb785..a49e5a50 100644
--- a/tests/kernel/wave/scheduling_test.py
+++ b/tests/kernel/wave/scheduling_test.py
@@ -28,7 +28,7 @@
 from iree.turbine.kernel._support.tracing import CapturedTrace
 from iree.turbine.kernel._support.indexing import IndexingContext
 from iree.turbine.kernel.wave.promotion import promote_placeholders
-from iree.turbine.kernel.wave.hoisting import hoist_allocs
+from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
 from iree.turbine.kernel.wave.expansion import expand_graph
 from iree.turbine.kernel.wave.type_inference import infer_types
 from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
@@ -285,7 +285,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
         IndexingContext.current().finalize()
         infer_types(trace)
         promote_placeholders(trace, constraints)
-        hoist_allocs(trace)
+        hoist_loop_invariant_ops(trace, constraints)
         set_node_indices(trace, constraints)
         expand_graph(trace, constraints)
         set_post_expansion_indices(trace, constraints)
diff --git a/tests/kernel/wave/wave_attention_test.py b/tests/kernel/wave/wave_attention_test.py
index 412dbd04..769fae2d 100644
--- a/tests/kernel/wave/wave_attention_test.py
+++ b/tests/kernel/wave/wave_attention_test.py
@@ -435,7 +435,7 @@ def testAttention(

     @tkw.wave(constraints)
     def base_attention(
-        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
         k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
         v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
         c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
@@ -840,7 +840,7 @@ def testAttentionF8(

     @tkw.wave(constraints)
     def base_attention(
-        q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
+        q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
         k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
         v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
         c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],