[TKW] Hoist loop invariant reads #296

Merged
6 commits merged on Nov 27, 2024

Changes from 4 commits
88 changes: 75 additions & 13 deletions iree/turbine/kernel/wave/hoisting.py
@@ -4,6 +4,8 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from .constraints import Constraint
from .utils import get_induction_variable
from ...support.logging import get_logger
from iree.turbine.kernel._support.tracing import CapturedTrace
import torch.fx as fx
@@ -13,25 +15,85 @@
logger = get_logger("turbine.wave.hoisting")


def get_allocs(graph: fx.Graph) -> list[CustomOp]:
    return [
        custom_node
        for node in graph.nodes
        if isinstance((custom_node := get_custom(node)), Allocate)
    ]
def get_hoistable_ops(
    graph: fx.Graph,
    captured_vars: list[CustomOp],
    induction_variable: IndexExpr,
) -> list[CustomOp]:
    """
    Get hoistable ops. Currently we only handle allocs and reads that do not
    depend on the induction variable.

    Note: For codegen to work properly, allocs need to be hoisted first. This avoids
    using an alloc before it is defined / non-dominating behavior
    (e.g. hoisting a read from global to shared before the shared alloc is defined).
    """
    hoistable_allocs = []
    hoistable_ops = []
    for node in graph.nodes:
        custom_node = get_custom(node)
        if isinstance(custom_node, Allocate):
            hoistable_allocs.append(custom_node)
        elif isinstance(custom_node, Read):
            if custom_node.index is None:
                continue
            # Only handle the case where the memory is a captured variable,
            # i.e. it has a source value coming from the root graph.
            if custom_node.memory not in captured_vars:
                continue
            # Only handle the case where nothing in the loop writes to the same
            # memory; otherwise each iteration may read a different value.
            if any(
                isinstance(get_custom(memory_user), Write)
                for memory_user in custom_node.memory.users
            ):
                continue
            # Only hoist reads that are loop invariant.
            dims_indexing = [ind.start for ind in custom_node.index.values()]
            dim_depends_on_ivar = [
                ind.has(induction_variable) for ind in dims_indexing
            ]
            if any(dim_depends_on_ivar):
                continue
            hoistable_ops.append(custom_node)
        else:
            continue
    all_hoistables_ops = hoistable_allocs + hoistable_ops
    return all_hoistables_ops


def remove_unused_captured_vars(reduction: CustomOp, subgraph: fx.Graph):
    """Removes captured variables that no longer have users after hoisting."""
    captured_vars = reduction.captured_vars(subgraph)
    new_implicit_captures = list(reduction.implicit_captures)
    for captured_idx in reversed(range(len(captured_vars))):
        if len(captured_vars[captured_idx].users) == 0:
            get_custom(captured_vars[captured_idx]).erase()
            new_implicit_captures.pop(captured_idx)
    reduction.update_arg("implicit_captures", new_implicit_captures)


def hoist_allocs(trace: CapturedTrace):
    """Hoists allocs from reduction subgraphs to outer root graph."""
def hoist_loop_invariant_ops(trace: CapturedTrace, constraints: list[Constraint]):
    """Hoists loop-invariant ops from reduction subgraphs to the outer root graph."""
    root_graph = trace.get_root_graph()
    for node in root_graph.nodes:
        custom_node = get_custom(node)
        match custom_node:
            case Reduction():
                with root_graph.inserting_before(custom_node.fx_node):
                    induction_variable = get_induction_variable(
                        custom_node, constraints
                    )
                    subgraph = trace.get_subgraph(custom_node.subgraph_name)
                    allocs = get_allocs(subgraph)
                    for alloc in allocs:
                        new_alloc = alloc.copy(new_graph=root_graph)
                        alloc.replace_all_uses_with(new_alloc)
                        alloc.erase()
                    # Captured/root variables from outside the loop.
                    implicit_captures = custom_node.implicit_captures
                    # Captured variables from inside the loop.
                    captured_vars = custom_node.captured_vars(subgraph)
                    hoistable_ops = get_hoistable_ops(
                        subgraph, captured_vars, induction_variable
                    )
                    for hoistable_op in hoistable_ops:
                        new_op = hoistable_op.copy(new_graph=root_graph)
                        hoistable_op.replace_all_uses_with(new_op)
                        hoistable_op.erase()
                        if isinstance(hoistable_op, Read):
                            capture_arg = captured_vars.index(hoistable_op.memory)
                            new_op.update_arg("memory", implicit_captures[capture_arg])
                    # Remove unused captured vars so codegen stays correct; otherwise,
                    # ops inside the scf.for would index/load from the wrong bindings.
                    remove_unused_captured_vars(custom_node, subgraph)
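
For reference, here is a minimal, self-contained sketch of the loop-invariance test this pass relies on, assuming index expressions behave like sympy expressions; the symbol names ARGK, WG0, BLOCK_K and the sample indices below are made up for illustration:

import sympy

# Hypothetical induction-variable and tile symbols, for illustration only.
ARGK = sympy.Symbol("ARGK")
WG0, BLOCK_K = sympy.symbols("WG0 BLOCK_K")


def is_loop_invariant(start_exprs, induction_var):
    # Mirrors the check in get_hoistable_ops: a read is hoistable only if none
    # of its per-dimension start expressions mention the induction variable.
    return not any(expr.has(induction_var) for expr in start_exprs)


# Indexed only by workgroup/tile symbols: invariant, safe to hoist out of the loop.
assert is_loop_invariant([WG0 * 16, sympy.Integer(0)], ARGK)
# Index advances with the loop (ARGK * BLOCK_K): must stay inside the loop.
assert not is_loop_invariant([WG0 * 16, ARGK * BLOCK_K], ARGK)
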
4 changes: 2 additions & 2 deletions iree/turbine/kernel/wave/wave.py
@@ -22,7 +22,7 @@
from .codegen import WaveEmitter
from .expansion import expand_graph
from .promotion import promote_placeholders
from .hoisting import hoist_allocs
from .hoisting import hoist_loop_invariant_ops
from .utils import (
canonicalize_module,
compile_and_invoke,
@@ -232,7 +232,6 @@ def _trace_and_get_kernel_signature(

# Promote the placeholders to the appropriate address space.
promote_placeholders(graph, self.constraints)
hoist_allocs(graph)

# Set indices.
set_node_indices(graph, self.constraints)
@@ -250,6 +249,7 @@
remove_chained_getresult(graph)

# Optimizations.
hoist_loop_invariant_ops(graph, self.constraints)
minimize_global_loads(graph, self.constraints)

# Apply shared memory indexing corrections.
30 changes: 15 additions & 15 deletions lit_tests/kernel/wave/attention.py
@@ -154,15 +154,15 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(dynamic_attention_pipelined(q, k, v, output).module_op)

# CHECK: func.func @dynamic_attention_pipelined
# CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
# CHECK-LABEL: func.func @dynamic_attention_pipelined
# CHECK-COUNT-6: {{.*}} = vector.maskedload {{.*}}
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-13: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-3: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = vector.maskedload {{.*}}
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-5: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-16: vector.maskedstore {{.*}}

@@ -281,13 +281,13 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention_pipelined(q, k, v, output).module_op)

# CHECK: func.func @base_attention_pipelined
# CHECK-LABEL: func.func @base_attention_pipelined
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-13: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-14: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-7: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
# CHECK-COUNT-3: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = gpu.shuffle xor {{.*}}


@@ -401,7 +401,7 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention_32x32x8(q, k, v, output).module_op)

# CHECK: func.func @base_attention_32x32x8
# CHECK-LABEL: func.func @base_attention_32x32x8
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-2: {{.*}} = gpu.shuffle xor {{.*}}
@@ -524,7 +524,7 @@ def repeat(
output = torch.zeros(shape[0], shape[1], shape[2], dtype=torch.float32)
print(base_attention(q, k, v, output).module_op)

# CHECK: func.func @base_attention
# CHECK-LABEL: func.func @base_attention
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-16: {{.*}} = amdgpu.mfma
# CHECK-COUNT-8: {{.*}} = gpu.shuffle xor {{.*}}
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/barriers.py
@@ -12,7 +12,7 @@
)
from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -177,11 +177,11 @@ def test_gemm():
read_nodes = get_read_nodes(graph)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
tweak_index(graph)
hoist_loop_invariant_ops(trace, constraints)
add_shared_memory_barriers(trace)
print_trace(trace, False)
# Root graph:
14 changes: 11 additions & 3 deletions lit_tests/kernel/wave/codegen.py
@@ -1004,7 +1004,7 @@ def test_chained_gemm():

@tkw.wave(constraints)
def chained_gemm(
q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
q: tkl.Memory[B, M, K1, ADDRESS_SPACE_0, tkl.f16],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[B, M, N, ADDRESS_SPACE_0, tkl.f32],
@@ -1051,8 +1051,13 @@ def repeat(
output = torch.zeros(8, 64, 128, dtype=torch.float32)
print(chained_gemm(q, k, v, output).module_op)

# CHECK-LABEL: func.func @chained_gemm(
# CHECK-LABEL: func.func @chained_gemm
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
# CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1x32x36xf16, #gpu.address_space<workgroup>>
# CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
# CHECK-COUNT-4: vector.load %[[GLOBAL_0]]
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = vector.load %[[ALLOC]]
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
# CHECK-COUNT-4: {{.*}} = arith.truncf
# CHECK-COUNT-8: {{.*}} = amdgpu.mfma
@@ -1131,7 +1136,10 @@ def repeat(
output = torch.zeros(8, 64, 128, dtype=torch.float32)
print(chained_gemm_32x32x8(q, k, v, output).module_op)

# CHECK-LABEL: func.func @chained_gemm_32x32x8(
# CHECK-LABEL: func.func @chained_gemm_32x32x8
# CHECK-SAME: (%[[ARG0:.*]]: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding, %{{.+}}: !stream.binding)
# CHECK: %[[GLOBAL_0:.+]] = stream.binding.subspan %[[ARG0]]
# CHECK: %[[GLOBAL_READ_0:.+]] = vector.load %[[GLOBAL_0]]
# CHECK: {{.*}} = scf.for
# CHECK-COUNT-4: {{.*}} = amdgpu.mfma
# CHECK-COUNT-1: {{.*}} = arith.truncf
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/index_sequence_analysis.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -87,10 +87,10 @@ def test_gemm():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
partition_strided_operators(trace, constraints)
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/minimize_global_loads.py
@@ -7,7 +7,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.barriers import add_shared_memory_barriers
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
@@ -90,12 +90,12 @@ def test_gemm():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
if visualize:
visualize_graph(trace.get_subgraph("region_0"), "before.png")
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
if visualize:
8 changes: 6 additions & 2 deletions lit_tests/kernel/wave/promotion.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_node, promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
from iree.turbine.kernel._support.tracing import CapturedTrace
@@ -35,6 +35,9 @@ def get_read_nodes(graph: fx.Graph) -> list[CustomOp]:
ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0
ADDRESS_SPACE_1 = tkl.sym.ADDRESS_SPACE_1

# Induction variable for dimension K
ARGK = tkl.sym.ARGK


@tkw.wave_trace_only()
def read_write_same_size(
@@ -160,6 +163,7 @@ def test_gemm():
constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WorkgroupConstraint(K, BLOCK_K, 2)]
constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)]
constraints += [
tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(1, 1, 1))
]
@@ -177,7 +181,7 @@
infer_types(trace)
for read_node in read_nodes:
promote_node(read_node, SHARED_ADDRESS_SPACE, constraints)
hoist_allocs(trace)
hoist_loop_invariant_ops(trace, constraints)
print_trace(trace, False)
# Root graph:
# CHECK: %a
4 changes: 2 additions & 2 deletions lit_tests/kernel/wave/scheduling.py
@@ -6,7 +6,7 @@
import iree.turbine.kernel.lang as tkl
import iree.turbine.kernel.wave as tkw
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.lang.global_symbols import *
@@ -99,10 +99,10 @@ def test_gemm_pipelined():
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
hoist_loop_invariant_ops(trace, constraints)
minimize_global_loads(trace, constraints)
apply_shared_memory_indexing_corrections(trace, constraints)
schedule_graph(trace, constraints, True)
4 changes: 2 additions & 2 deletions tests/kernel/wave/scheduling_test.py
@@ -28,7 +28,7 @@
from iree.turbine.kernel._support.tracing import CapturedTrace
from iree.turbine.kernel._support.indexing import IndexingContext
from iree.turbine.kernel.wave.promotion import promote_placeholders
from iree.turbine.kernel.wave.hoisting import hoist_allocs
from iree.turbine.kernel.wave.hoisting import hoist_loop_invariant_ops
from iree.turbine.kernel.wave.expansion import expand_graph
from iree.turbine.kernel.wave.type_inference import infer_types
from iree.turbine.kernel.wave.minimize_global_loads import minimize_global_loads
@@ -285,7 +285,7 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]:
IndexingContext.current().finalize()
infer_types(trace)
promote_placeholders(trace, constraints)
hoist_allocs(trace)
hoist_loop_invariant_ops(trace, constraints)
set_node_indices(trace, constraints)
expand_graph(trace, constraints)
set_post_expansion_indices(trace, constraints)
4 changes: 2 additions & 2 deletions tests/kernel/wave/wave_attention_test.py
@@ -435,7 +435,7 @@ def testAttention(

@tkw.wave(constraints)
def base_attention(
q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],
@@ -840,7 +840,7 @@ def testAttentionF8(

@tkw.wave(constraints)
def base_attention(
q: tkl.Memory[B, M, K1, ADDRESS_SPACE, tkl.f16],
q: tkl.Memory[B, M, K1, GLOBAL_ADDRESS_SPACE, tkl.f16],
k: tkl.Memory[B, K2, K1, ADDRESS_SPACE, tkl.f16],
v: tkl.Memory[B, N, K2, ADDRESS_SPACE, tkl.f16],
c: tkl.Memory[B, M, N, GLOBAL_ADDRESS_SPACE, tkl.f32],