[TKW] Hoist loop invariant reads (#296)
raikonenfnu authored Nov 27, 2024
1 parent d9d2e7b commit e3b6c87
Showing 12 changed files with 132 additions and 49 deletions.
11 changes: 9 additions & 2 deletions .github/workflows/ci-tk.yaml
@@ -57,13 +57,20 @@ jobs:
        run: |
          pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
      - name: Run e2e tests on AMD GPU
        if: "contains(matrix.os, 'amdgpu') && !cancelled()"
      - name: Run e2e tests on AMD GPU MI300
        if: "contains(matrix.os, 'mi300') && !cancelled()"
        run: |
          pip install --no-compile -r pytorch-rocm-requirements.txt
          export WAVE_RUN_E2E_TESTS=1
          pytest -n 4 --capture=tee-sys -vv ./tests/kernel/wave/
      - name: Run e2e tests on AMD GPU MI250
        if: "contains(matrix.os, 'mi250') && !cancelled()"
        run: |
          pip install --no-compile -r pytorch-rocm-requirements.txt
          export WAVE_RUN_E2E_TESTS=1
          pytest -n 2 --capture=tee-sys -vv ./tests/kernel/wave/
      - name: Run LIT tests
        if: ${{ !cancelled() }}
        run: |
90 changes: 77 additions & 13 deletions iree/turbine/kernel/wave/hoisting.py
@@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .constraints import Constraint
from .utils import get_induction_variable
from ...support.logging import get_logger
from iree.turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
@@ -13,25 +15,87 @@
logger = get_logger("turbine.wave.hoisting")


def get_allocs(graph: fx.Graph) -> list[CustomOp]:
    return [
        custom_node
        for node in graph.nodes
        if isinstance((custom_node := get_custom(node)), Allocate)
    ]


def get_hoistable_ops(
    graph: fx.Graph,
    captured_vars: list[CustomOp],
    induction_variable: IndexExpr,
) -> list[CustomOp]:
    """
    Get hoistable ops. Currently we only handle allocs and reads that do not
    depend on the induction variable.
    Note: for codegen to work properly, we need to hoist allocs first. This avoids
    using an alloc before it is defined (non-dominating behavior), e.g. hoisting a
    read from global to shared memory before the shared alloc is defined.
    """
    hoistable_allocs = []
    hoistable_ops = []
    for node in graph.nodes:
        custom_node = get_custom(node)
        if isinstance(custom_node, Allocate):
            hoistable_allocs.append(custom_node)
        elif isinstance(custom_node, Read):
            if custom_node.index is None:
                continue
            # Only handle the case where the memory is a captured variable,
            # i.e. its source value comes from the root graph.
            if custom_node.memory not in captured_vars:
                continue
            # Only handle the case where we are not writing to the same memory;
            # if the memory is also written to, the read may see different values
            # across iterations.
            if any(
                isinstance(get_custom(mem_user), Write)
                for mem_user in custom_node.memory.users
            ):
                continue
            # Only hoist reads that are loop invariant.
            if any(
                ind.start.has(induction_variable) for ind in custom_node.index.values()
            ):
                continue
            hoistable_ops.append(custom_node)
        else:
            continue
    all_hoistable_ops = hoistable_allocs + hoistable_ops
    return all_hoistable_ops


def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph):
    captured_vars = reduction.captured_vars(subgraph)
    new_implicit_captures = list(reduction.implicit_captures)
    for captured_idx in reversed(range(len(captured_vars))):
        if len(captured_vars[captured_idx].users) == 0:
            get_custom(captured_vars[captured_idx]).erase()
            new_implicit_captures.pop(captured_idx)
    reduction.update_arg("implicit_captures", new_implicit_captures)


def hoist_allocs(trace: CapturedTrace):
    """Hoists allocs from reduction subgraphs to outer root graph."""
def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]):
    """Hoists ops that are loop-invariant from reduction subgraphs to the outer root graph."""
    root_graph = trace.get_root_graph()
    for node in root_graph.nodes:
        custom_node = get_custom(node)
        match custom_node:
            case Reduction():
                with root_graph.inserting_before(custom_node.fx_node):
                    induction_variable = get_induction_variable(
                        custom_node, constraints
                    )
                    subgraph = trace.get_subgraph(custom_node.subgraph_name)
                    allocs = get_allocs(subgraph)
                    for alloc in allocs:
                        new_alloc = alloc.copy(new_graph=root_graph)
                        alloc.replace_all_uses_with(new_alloc)
                        alloc.erase()
                    # Captured/root variables from outside the loop.
                    implicit_captures = custom_node.implicit_captures
                    # Captured variables from inside the loop.
                    captured_vars = custom_node.captured_vars(subgraph)
                    hoistable_ops = get_hoistable_ops(
                        subgraph, captured_vars, induction_variable
                    )
                    for hoistable_op in hoistable_ops:
                        new_op = hoistable_op.copy(new_graph=root_graph)
                        hoistable_op.replace_all_uses_with(new_op)
                        hoistable_op.erase()
                        if isinstance(hoistable_op, Read):
                            capture_arg = captured_vars.index(hoistable_op.memory)
                            new_op.update_arg("memory", implicit_captures[capture_arg])
                    # Clear/remove unused captured vars to get correct codegen; ops inside
                    # scf.for would otherwise index/load from the wrong bindings.
                    remove_unused_captured_vars(custom_node, subgraph)
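
For readers skimming the new pass: the invariance test in get_hoistable_ops boils down to a sympy containment check on each index entry's start expression. A minimal standalone illustration of that check, using plain sympy with made-up symbols rather than the wave IndexSequence API:

import sympy

# Illustrative symbols only: ARGK stands in for the loop's induction variable,
# WG0/T0 for workgroup and thread offsets produced by the index-setting passes.
ARGK, WG0, T0 = sympy.symbols("ARGK WG0 T0")

loop_variant_start = WG0 * 64 + ARGK * 32 + T0  # mentions the induction variable
loop_invariant_start = WG0 * 64 + T0            # does not

print(loop_variant_start.has(ARGK))    # True  -> read stays inside the loop
print(loop_invariant_start.has(ARGK))  # False -> read is a hoisting candidate

Even an invariant read is skipped when its captured memory is also written inside the loop, since the loaded value could then change between iterations.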
4 changes: 2 additions & 2 deletions iree/turbine/kernel/wave/wave.py
@@ -22,7 +22,7 @@
from .codegen import WaveEmitter
from .expansion import expand_graph
from .promotion import promote_placeholders
from .hoisting import hoist_allocs
from .hoisting import hoist_loop_invariant_ops
from .utils import (
canonicalize_module,
compile_and_invoke,
@@ -232,7 +232,6 @@ def _trace_and_get_kernel_signature(

# Promote the placeholders to the appropriate address space.
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)

# Set indices.
set_node_indices(graph, self.constraints)
@@ -250,6 +249,7 @@ def _trace_and_get_kernel_signature(
remove_chained_getresult(graph)

# Optimizations.
hoist_loop_invariant_ops(graph, self.constraints)
minimize_global_loads(graph, self.constraints)

# Apply shared memory indexing corrections.
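
The wave.py hunks above move hoisting out of the promotion stage and into the optimization stage. The reason is implicit in the new pass: the loop-invariance test reads each op's index expressions, so it can only run after indices are set and the graph is expanded. A condensed sketch of the resulting ordering, mirroring the updated lit tests below (the import path for the index-setting passes is an assumption):

from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
# Assumed import path for the index-setting passes used in the lit tests.
from iree.turbine.kernel.wave.index_sequence_analysis import (
    set_node_indices,
    set_post_expansion_indices,
)


def run_passes(trace, constraints):
    promote_placeholders(trace, constraints)       # hoist_allocs(trace) used to run here
    set_node_indices(trace, constraints)
    expand_graph(trace, constraints)
    set_post_expansion_indices(trace, constraints)
    hoist_loop_invariant_ops(trace, constraints)   # now hoists allocs and invariant reads
    minimize_global_loads(trace, constraints)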
30 changes: 15 additions & 15 deletions lit_tests/kernel/wave/attention.py
@@ -154,15 +154,15 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(dynamic_attention_pipelined(q, k, v, output).module_op)

# CHECK: func.func @dynamic_attention_pipelined
# CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
# CHECK-LABEL: func.func @dynamic_attention_pipelined
# CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}}
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-13: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-3: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-16: vector.maskedstore {{.*}}

@@ -281,13 +281,13 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention_pipelined(q, k, v, output).module_op)

# CHECK: func.func @base_attention_pipelined
# CHECK-LABEL: func.func @base_attention_pipelined
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-13: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-3: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}


@@ -401,7 +401,7 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention_32x32x8(q, k, v, output).module_op)

# CHECK: func.func @base_attention_32x32x8
# CHECK-LABEL: func.func @base_attention_32x32x8
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
@@ -524,7 +524,7 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention(q, k, v, output).module_op)

# CHECK: func.func @base_attention
# CHECK-LABEL: func.func @base_attention
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-16: {{.*}} = amdgpu.mfma
# CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}}
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/barriers.py
@@ -12,7 +12,7 @@
)
from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -177,11 +177,11 @@ def test_gemm():
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
tweak_index(graph)
hoist_loop_invariant_ops(trace, constraints)
add_shared_memory_barriers(trace)
print_trace(trace, False)
# Root graph:
14 changes: 11 additions & 3 deletions lit_tests/kernel/wave/codegen.py
@@ -1004,7 +1004,7 @@ def test_chained_gemm():

@tkw.wave(constraints)
def chained_gemm(
q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
q: tkl.Memory[B, M, K1, ADDRESS_SPACE_0, tkl.f16],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
@@ -1051,8 +1051,13 @@ def repeat(
output = torch.zeros(8, 64, 128, dtype=torch.float32)
print(chained_gemm(q, k, v, output).module_op)

# CHECK-LABEL: func.func @chained_gemm(
# CHECK-LABEL: func.func @chained_gemm
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x32x36xf16, #gpu.address_space<workgroup>>
# CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
# CHECK-COUNT-4: vector.load %[[GLOBAL_0]]
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = vector.load %[[ALLOC]]
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = arith.truncf
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
@@ -1131,7 +1136,10 @@ def repeat(
output = torch.zeros(8, 64, 128, dtype=torch.float32)
print(chained_gemm_32x32x8(q, k, v, output).module_op)

# CHECK-LABEL: func.func @chained_gemm_32x32x8(
# CHECK-LABEL: func.func @chained_gemm_32x32x8
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
# CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
# CHECK: %[[GLOBAL_READ_0:.+]] = vector.load %[[GLOBAL_0]]
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = arith.truncf
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -87,10 +87,10 @@ def test_gemm():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
partition_strided_operators(trace, constraints)
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -7,7 +7,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
@@ -90,12 +90,12 @@ def test_gemm():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
if visualize:
visualize_graph(trace.get_subgraph("region_0"), "before.png")
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
if visualize:
8 changes: 6 additions & 2 deletions lit_tests/kernel/wave/promotion.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
@@ -35,6 +35,9 @@ def get_read_nodes(graph: fx.Graph) -> list[CustomOp]:
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
ADDRESS_SPACE_1 = tkl.sym.ADDRESS_SPACE_1

# Induction variable for dimension K
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def read_write_same_size(
@@ -160,6 +163,7 @@ def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 2)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
@@ -177,7 +181,7 @@ def test_gemm():
infer_types(trace)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
hoist_loop_invariant_ops(trace, constraints)
print_trace(trace, False)
# Root graph:
# CHECK: %a
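
The promotion test gains a TilingConstraint (and the ARGK symbol) because hoist_loop_invariant_ops now asks get_induction_variable for the reduction axis's induction symbol. A rough, hypothetical sketch of what such a lookup amounts to (attribute names are assumptions; see utils.get_induction_variable for the real helper):

from iree.turbine.kernel.wave.constraints import Constraint, TilingConstraint


# Hypothetical sketch, NOT the real utils.get_induction_variable; the dim,
# induction_var, and axis attribute names are assumed for illustration.
def get_induction_variable_sketch(reduction, constraints: list[Constraint]):
    for constraint in constraints:
        # In the test above, TilingConstraint(K, BLOCK_K, ARGK) associates the
        # tiled dimension K with the induction symbol ARGK.
        if isinstance(constraint, TilingConstraint) and constraint.dim == reduction.axis:
            return constraint.induction_var
    raise ValueError(f"No tiling constraint found for axis {reduction.axis}")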
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/scheduling.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -99,10 +99,10 @@ def test_gemm_pipelined():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
schedule_graph(trace, constraints, True)
4 changes: 2 additions & 2 deletions tests/kernel/wave/scheduling_test.py
@@ -28,7 +28,7 @@
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
@@ -285,7 +285,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
hoist_loop_invariant_ops(trace, constraints)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)