diff --git a/jax_triton/__init__.py b/jax_triton/__init__.py index 064313c..977bda5 100644 --- a/jax_triton/__init__.py +++ b/jax_triton/__init__.py @@ -24,7 +24,6 @@ "__version_info__", ] -import jaxlib from jax._src.lib import gpu_triton from jax_triton import utils from jax_triton.triton_lib import triton_call @@ -36,6 +35,7 @@ try: get_compute_capability = gpu_triton.get_compute_capability + get_serialized_metadata = gpu_triton.get_serialized_metadata except AttributeError: raise ImportError( "jax-triton requires JAX to be installed with GPU support. The " @@ -43,13 +43,7 @@ "instructions for installing a supported version:\n" "https://jax.readthedocs.io/en/latest/installation.html" ) - -if jaxlib.version.__version_info__ >= (0, 4, 14): - try: - get_serialized_metadata = gpu_triton.get_serialized_metadata - except AttributeError: - get_serialized_metadata = None +else: + del gpu_triton # Not part of the API. # trailer -del gpu_triton -del jaxlib