
Commit e82c529

Copybara import of the project:
--
1c012dc by Rahul Batra <[email protected]>:

ROCM updates

PiperOrigin-RevId: 673474325
teijeong authored and The jax_triton Authors committed Sep 11, 2024
1 parent 9e7d6a9 commit e82c529
Showing 2 changed files with 36 additions and 168 deletions.
jax_triton/triton_lib.py (198 changes: 33 additions & 165 deletions)
@@ -27,7 +27,6 @@
import types
from typing import Any, Callable, Dict, Optional, Protocol, Sequence, Tuple, Union
import zlib
from functools import partial

from absl import logging
import jax
@@ -57,14 +56,6 @@
CAN_USE_TRITON = True
except ModuleNotFoundError:
pass

try:
import triton.backends.amd.compiler as hb
except ImportError:
hb = None
pass


try:
from jax._src.lib import gpu_triton as triton_kernel_call_lib
except ImportError:
@@ -99,6 +90,7 @@
jnp.dtype("bool"): "B",
}


Grid = Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]]
GridOrLambda = Union[Grid, Callable[[Dict[str, Any]], Grid]]

@@ -165,61 +157,22 @@ def aval_size_bytes(aval):
return np.dtype(aval.dtype).itemsize * aval.size


def get_cuda_backend(device, compute_capability):
target = cb.GPUTarget('cuda', compute_capability, 32)
backend = cb.CUDABackend(target)
return backend

def get_hip_backend(device, compute_capability):
arch = triton_kernel_call_lib.get_arch_details(device)
arch = arch.split(":")[0]
target = hb.GPUTarget('hip', arch, 64)
backend = hb.HIPBackend(target)
return backend

@dataclasses.dataclass
class CompilationResult:
binary: str
class PtxCompilationResult:
ptx: str
name: str
shared_mem_bytes: int
cluster_dims: tuple
ttgir: Optional[str]
llir: Optional[str]

def compile_ttir_inplace(
ttir,
backend: [cb.CUDABackend | hb.HIPBackend],
options: [cb.CUDAOptions | hb.HIPOptions],
compute_capability,
platform
):
if platform == 'cuda':
return compile_ttir_to_ptx_inplace(
ttir,
backend,
options,
compute_capability,
)

elif platform == 'rocm':
return compile_ttir_to_hsaco_inplace(
ttir,
backend,
options,
compute_capability,
)
else:
raise ValueError(
"Unsupported device."
)


def compile_ttir_to_ptx_inplace(
ttir,
cuda_backend: cb.CUDABackend,
cuda_options: cb.CUDAOptions,
compute_capability,
) -> CompilationResult:
) -> PtxCompilationResult:
if cuda_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
@@ -236,7 +189,7 @@ def compile_ttir_to_ptx_inplace(
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = {}
metadata = dict()
opt_ttir = cuda_backend.make_ttir(ttir, metadata, cuda_options)
ttgir = cuda_backend.make_ttgir(
opt_ttir,
@@ -274,95 +227,20 @@ def compile_ttir_to_ptx_inplace(
cluster_dims = metadata["cluster_dims"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
return CompilationResult(
binary=ptx,
return PtxCompilationResult(
ptx=ptx,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)

def compile_ttir_to_hsaco_inplace(
ttir,
hip_backend: hb.HIPBackend,
hip_options: hb.HIPOptions,
compute_capability,
) -> CompilationResult:
if hip_options.debug:
print(ttir)
if isinstance(ttir, ir.Module):
context = _triton.ir.context()
_triton.ir.load_dialects(context)
hip_backend.load_dialects(context)

# Triton compilation APIs only accept Triton-specific MLIR wrappers.
# So, here we serialize an ir.Module to a file and then deserialize
# it as a tl_ir.module.
with tempfile.NamedTemporaryFile(mode="wb") as f:
ttir.operation.write_bytecode(f)
f.flush()
ttir = tl_ir.parse_mlir_module(f.name, context)
ttir.context = context
try:
metadata = {}
opt_ttir = hip_backend.make_ttir(ttir, metadata, hip_options)
ttgir = hip_backend.make_ttgir(
opt_ttir,
metadata,
hip_options
)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if hip_options.debug:
print(ttgir)
try:
llir = hip_backend.make_llir(
ttgir,
metadata,
hip_options
)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = metadata["shared"]
if hip_options.debug:
print(llir)

amdgcn = hip_backend.make_amdgcn(llir, metadata, hip_options)
hsaco = hip_backend.make_hsaco(amdgcn, metadata, hip_options)

if hip_options.debug:
print(x)
name = metadata["name"]
ttgir = str(ttgir) if _JAX_TRITON_DUMP_DIR else None
llir = str(llir) if _JAX_TRITON_DUMP_DIR else None
# cluster dims are NOT useful on hip backend.
# We just fill up with some value for API compatibility
cluster_dims = (0, 0, 0)
# Instead of passing hsaco which are "bytes", we first write
# to a file and then pass the "string" path. This is needed because
# nanobind doesn't automatically convert between bytes and string.
# https://github.com/wjakob/nanobind/discussions/137
fd, hsaco_path = tempfile.mkstemp()
with os.fdopen(fd, "wb") as f:
f.write(hsaco)
return CompilationResult(
binary=hsaco_path,
name=name,
shared_mem_bytes=shared_mem_bytes,
cluster_dims=cluster_dims,
ttgir=ttgir,
llir=llir,
)

_COMPILED_KERNEL_CACHE = {} # TODO(cjfj): Convert to LRU cache?


def get_or_create_triton_kernel(
backend_init_func,
platform,
fn,
arg_dtypes,
scalar_args,
@@ -419,29 +297,29 @@ def get_or_create_triton_kernel(
kernel = _COMPILED_KERNEL_CACHE.get(cache_key)

if kernel is None:
opts = {
"num_warps": num_warps,
"num_stages": num_stages,
"num_ctas": num_ctas,
"optimize_epilogue": False,
"debug": dump,
"enable_fp_fusion": enable_fp_fusion,
}

backend = backend_init_func(device, compute_capability)
options = backend.parse_options(opts)

target = cb.GPUTarget('cuda', compute_capability, 32)
cuda_backend = cb.CUDABackend(target)
cuda_options = cuda_backend.parse_options(
dict(
num_warps=num_warps,
num_stages=num_stages,
num_ctas=num_ctas,
optimize_epilogue=False,
debug=dump,
enable_fp_fusion=enable_fp_fusion,
)
)
kernel_hash = abs(hash(cache_key))
if _JAX_TRITON_DUMP_DIR:
os.makedirs(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}")
with open(f"{_JAX_TRITON_DUMP_DIR}/{kernel_hash}/config", "w") as f:
pprint.pprint(cache_key, stream=f)
pprint.pprint(options, stream=f)
pprint.pprint(cuda_options, stream=f)

context = _triton.ir.context()
_triton.ir.load_dialects(context)
backend.load_dialects(context)
codegen_fns = backend.get_codegen_implementation()
cuda_backend.load_dialects(context)
codegen_fns = cuda_backend.get_codegen_implementation()

module = (
code_gen.ast_to_ttir(
@@ -452,10 +330,10 @@
signature=signature,
attrs=specialization_attr,
),
options=options,
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
module_map=backend.get_module_map(),
module_map=cuda_backend.get_module_map(),
)
if "module_map" in inspect.getfullargspec(code_gen.ast_to_ttir).args
# Triton changes ASTSource.ast_to_ttir to include module_map. Handle
@@ -468,21 +346,19 @@
signature=signature,
attrs=specialization_attr,
),
options=options,
options=cuda_options,
codegen_fns=codegen_fns,
context=context,
)
)
ttir = str(module)

compilation_result = compile_ttir_inplace(
module,
backend,
options,
compute_capability,
platform
compilation_result = compile_ttir_to_ptx_inplace(
module,
cuda_backend,
cuda_options,
compute_capability,
)

kernel_name = compilation_result.name
if _JAX_TRITON_DUMP_DIR:
with open(
@@ -515,7 +391,7 @@ def get_or_create_triton_kernel(
kernel_name,
num_warps,
compilation_result.shared_mem_bytes,
compilation_result.binary,
compilation_result.ptx,
ttir,
compute_capability,
*compilation_result.cluster_dims,
@@ -527,7 +403,6 @@


def triton_kernel_call_lowering(
backend_init_func,
ctx,
*array_args,
fn,
@@ -552,7 +427,6 @@
"`input_output_aliases` only supported on `jaxlib>=0.3.22"
)


kernel_call_name = name
args = list(ctx.avals_in)
arg_dtypes = list(map(get_triton_type, ctx.avals_in))
@@ -647,8 +521,6 @@ def prune_configs(configs, named_args, **kwargs):
kernel_calls = []
for params in config_params:
kernel, specialization_attr = get_or_create_triton_kernel(
backend_init_func,
ctx.module_context.platforms[0],
fn,
arg_dtypes,
scalar_args,
@@ -718,13 +590,9 @@ def prune_configs(configs, named_args, **kwargs):
operand_output_aliases=dict(input_output_aliases),
).results

mlir.register_lowering(triton_kernel_call_p,
partial(triton_kernel_call_lowering, get_cuda_backend),
platform='cuda')

mlir.register_lowering(triton_kernel_call_p,
partial(triton_kernel_call_lowering, get_hip_backend),
platform='rocm')
mlir.register_lowering(triton_kernel_call_p, triton_kernel_call_lowering)


class ShapeDtype(Protocol):

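Note on the backend helpers removed in the diff above: the third argument to GPUTarget is the warp/wavefront width, which differs between vendors. Below is a minimal sketch of how the two backends are constructed; the helper names are illustrative, the NVIDIA import path is an assumption (only the AMD one appears in the diff), and the exact GPUTarget signature may vary across Triton releases.

import triton.backends.nvidia.compiler as cb  # assumed import path for the CUDA backend
import triton.backends.amd.compiler as hb     # HIP/ROCm backend, as imported in the removed code

def make_cuda_backend(compute_capability: int) -> cb.CUDABackend:
    # NVIDIA GPUs execute 32-thread warps.
    return cb.CUDABackend(cb.GPUTarget("cuda", compute_capability, 32))

def make_hip_backend(arch_details: str) -> hb.HIPBackend:
    # ROCm reports an arch string such as "gfx90a:sramecc+:xnack-"; only the
    # gfx name before the first ":" is needed. AMD GCN/CDNA GPUs execute
    # 64-lane wavefronts.
    return hb.HIPBackend(hb.GPUTarget("hip", arch_details.split(":")[0], 64))
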
tests/triton_call_test.py (6 changes: 3 additions & 3 deletions)
@@ -328,16 +328,16 @@ def test_kernel_cache_equivalent_kernels(self):
x1, y1 = create_random_inputs([42])
x2, y2 = create_random_inputs([43])

compile_ttir_inplace = jt.triton_lib.compile_ttir_inplace
compile_ttir_to_ptx_inplace = jt.triton_lib.compile_ttir_to_ptx_inplace

call_count = [0]

def my_compile(*args, **kwargs):
call_count[0] += 1
return compile_ttir_inplace(*args, **kwargs)
return compile_ttir_to_ptx_inplace(*args, **kwargs)

with mock.patch.object(
jt.triton_lib, "compile_ttir_inplace", new=my_compile
jt.triton_lib, "compile_ttir_to_ptx_inplace", new=my_compile
):
_ = fn1(x1, y1)
self.assertEqual(call_count[0], 1)
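For context, the fn1/fn2 functions exercised by the cache test above presumably wrap Triton kernels through jax_triton's user-facing triton_call API, whose lowering path this commit touches. A minimal sketch of that call path, adapted from the jax_triton README-style usage; add_kernel, block_size, and the shapes are illustrative, and exact keyword names may vary across versions.

import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, block_size: tl.constexpr):
    # Each program instance handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * block_size + tl.arange(0, block_size)
    tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

def add(x: jax.Array, y: jax.Array) -> jax.Array:
    block_size = 8
    return jt.triton_call(
        x, y,
        kernel=add_kernel,
        out_shape=jax.ShapeDtypeStruct(shape=x.shape, dtype=x.dtype),
        grid=(x.size // block_size,),
        block_size=block_size,  # forwarded to the kernel as a constexpr meta-parameter
    )

x = jnp.arange(8, dtype=jnp.float32)
print(jax.jit(add)(x, x))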
