diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index ec71c434..83b37a5d 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -370,10 +370,11 @@ def get_or_create_triton_kernel(
   # `JITFunction._get_config` to get the specialization_attr.
   mock_torch_tensor = types.SimpleNamespace(data_ptr=lambda: 16)
   args_for_specialization_attr = [mock_torch_tensor] * len(arg_dtypes)
+  backend = backend_init_func(device, compute_capability)
   for i, _, v in scalar_args:
     args_for_specialization_attr[i] = v
-  specialization_attr = fn._get_config(*args_for_specialization_attr)  # pylint: disable=protected-access
+  specialization_attr = backend.get_attrs_descriptor(fn.params[:len(args_for_specialization_attr)], args_for_specialization_attr)  # pylint: disable=protected-access
 
   constants = dict(metaparams)
   constants.update({k: None for _, k, v in scalar_args if v is None})
   constants.update({fn.arg_names[i]: 1 for i in specialization_attr.equal_to_1})
@@ -382,7 +383,7 @@ def get_or_create_triton_kernel(
   cache_key = (
       fn,
       tuple(signature.items()),
-      tuple(vars(specialization_attr).values()),
+      tuple(specialization_attr.get_fn_attrs()),
       tuple(constants.items()),
       num_warps,
       num_stages,
@@ -402,7 +403,6 @@ def get_or_create_triton_kernel(
       "enable_fp_fusion": enable_fp_fusion,
   }
 
-  backend = backend_init_func(device, compute_capability)
   options = backend.parse_options(opts)
 
   kernel_hash = abs(hash(cache_key))
@@ -645,7 +645,7 @@ def prune_configs(configs, named_args, **kwargs):
         kernel_params.append(
             triton_kernel_call_lib.create_array_parameter(
                 zeroed_params_with_sizes.get(i, 0),
-                16 if (i in specialization_attr.divisible_by_16) else 0,
+                16 if (i in specialization_attr.divisibility_16) else 0,
             )
         )
       elif i not in specialization_attr.equal_to_1:
diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py
index 098e6b5c..cf8bfe00 100644
--- a/tests/triton_call_test.py
+++ b/tests/triton_call_test.py
@@ -564,10 +564,10 @@ def test_specialization(self):
     # Pointers are assumed to divide by 16, as do `M`, `N`, `stride_{bk,cm}`.
     # However, we've marked `a_ptr`, `M`, `stride_bk`, and `c_ptr` as "do not
     # specialize", leaving `b_ptr`, `N`, and `stride_cm`.
-    self.assertEqual(specialization.attrs.divisible_by_16, (1, 3, 9))
+    self.assertEqual(specialization.attrs.divisibility_16, [1, 3, 9])
     # `stride_{ak,bn,cn}` equal 1, but we've marked `stride_ak` as "do not
     # specialize" leaving `stride_{bn,cn}`.
-    self.assertEqual(specialization.attrs.equal_to_1, (8, 10))
+    self.assertEqual(specialization.attrs.equal_to_1, [8, 10])
 
 
 if __name__ == "__main__":
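For reviewers, here is a minimal runnable sketch of how the renamed specialization attributes are consumed after this change. The `fake_attrs` namespace below is a hypothetical stand-in for the descriptor returned by `backend.get_attrs_descriptor(...)`; it only mimics the `divisibility_16`, `equal_to_1`, and `get_fn_attrs()` surface that this diff relies on, and the concrete values are illustrative.

```python
import types

# Hypothetical stand-in for the attrs descriptor; not Triton's real object.
fake_attrs = types.SimpleNamespace(
    divisibility_16=[1, 3, 9],
    equal_to_1=[8, 10],
    get_fn_attrs=lambda: (("divisibility_16", (1, 3, 9)), ("equal_to_1", (8, 10))),
)

# The cache key now hashes the descriptor's reported attrs instead of
# tuple(vars(specialization_attr).values()).
cache_key_part = tuple(fake_attrs.get_fn_attrs())

# Array parameters are marked 16-byte divisible via the renamed field.
alignments = [16 if i in fake_attrs.divisibility_16 else 0 for i in range(11)]

print(cache_key_part)
print(alignments)
```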