From 683ad7066ff6cf0daf6b8b14e8a1afd3d585b5c1 Mon Sep 17 00:00:00 2001
From: Aliia Khasanova
Date: Thu, 10 Oct 2024 05:11:13 -0700
Subject: [PATCH] Integrate Triton up to
 [68aa962e67baa191cec5aac173255abdba80db1a](https://github.com/openai/triton/commits/68aa962e67baa191cec5aac173255abdba80db1a)

PiperOrigin-RevId: 684403022
---
 jax_triton/triton_lib.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index 33459d3e..2af07f9d 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -370,20 +370,20 @@ def get_or_create_triton_kernel(
   # We replace array arguments with mock Torch tensors, to allow us to use
   # `JITFunction._get_config` to get the specialization_attr.
   mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
-  args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
+  args_for_specialization_attr = [mock_torch_tensor] * len(fn.params)
+  backend = backend_init_func(device, compute_capability)
   for i, _, v in scalar_args:
     args_for_specialization_attr[i] = v
-  specialization_attr = fn._get_config(*args_for_specialization_attr)  # pylint: disable=protected-access
+  specialization_attr = backend.get_attrs_descriptor(fn.params, args_for_specialization_attr)  # pylint: disable=protected-access
 
   constants = {k: v for k, v in metaparams.items()}
   constants.update({k: None for _, k, v in scalar_args if v is None})
   constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
-
   # Cache key should contain any parameter that can affect the compiler output.
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(vars(specialization_attr).values()),
+      tuple(specialization_attr.arg_properties),
       tuple(constants.items()),
       num_warps,
       num_stages,
@@ -403,7 +403,6 @@ def get_or_create_triton_kernel(
       "enable_fp_fusion": enable_fp_fusion,
   }
 
-  backend = backend_init_func(device, compute_capability)
   options = backend.parse_options(opts)
 
   kernel_hash = abs(hash(cache_key))
@@ -643,7 +642,7 @@ def prune_configs(configs, named_args, **kwargs):
       kernel_params.append(
           triton_kernel_call_lib.create_array_parameter(
               zeroed_params_with_sizes.get(i, 0),
-              16 if (i in specialization_attr.divisible_by_16) else 0,
+              16 if (i in specialization_attr.divisibility_16) else 0,
           )
       )
     elif i not in specialization_attr.equal_to_1:
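
Note on the change above: the cache key now hashes the descriptor's arg_properties instead of vars(specialization_attr).values(), and array-parameter alignment keys off divisibility_16. Below is a minimal, self-contained Python sketch of that pattern. The mock descriptor and its field values are illustrative assumptions, not Triton's real AttrsDescriptor; only the attribute names (arg_properties, divisibility_16, equal_to_1) are taken from the diff.

import types

# Hypothetical stand-in for the object returned by backend.get_attrs_descriptor();
# the attribute names come from the diff above, but the contents are made up
# purely for illustration.
mock_descriptor = types.SimpleNamespace(
    arg_properties={"tt.divisibility": (0, 1), "tt.equal_to_1": (3,)},
    divisibility_16=(0, 1),
    equal_to_1=(3,),
)

# New cache-key component, as in the patch: iterate the descriptor's
# arg_properties (a dict in this sketch), which yields its keys.
cache_key_part = tuple(mock_descriptor.arg_properties)
print(cache_key_part)  # ('tt.divisibility', 'tt.equal_to_1')

# Array-parameter alignment, mirroring the last hunk: request 16-byte
# alignment only for argument indices the descriptor marks as divisible by 16.
for i in range(4):
    alignment = 16 if (i in mock_descriptor.divisibility_16) else 0
    print(i, alignment)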