diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 47a858b..3499669 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -8,12 +8,16 @@ on: jobs: tests: - name: ${{ matrix.os }} + name: ${{ matrix.os }}-${{ matrix.jax-version }} runs-on: ${{ matrix.os }} strategy: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] + jax-version: ["jax[cpu]"] + include: + - os: ubuntu-latest + jax-version: "jax[cpu]==0.4.20" steps: - uses: actions/checkout@v4 @@ -37,7 +41,7 @@ jobs: - name: Build run: | python -m pip install -U pip - python -m pip install -U jax[cpu] + python -m pip install -U ${{ matrix.jax-version }} python -m pip install -v .[test] - name: Run tests diff --git a/pyproject.toml b/pyproject.toml index 8d77032..c9f81da 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.9" license = { file = "LICENSE" } urls = { Homepage = "https://github.com/dfm/jax-finufft" } -dependencies = ["jax", "numpy", "pydantic>=2"] +dependencies = ["jax>=0.4.20", "numpy", "pydantic>=2"] dynamic = ["version"] [project.optional-dependencies] diff --git a/src/jax_finufft/lowering.py b/src/jax_finufft/lowering.py index 977a16e..365973d 100644 --- a/src/jax_finufft/lowering.py +++ b/src/jax_finufft/lowering.py @@ -1,43 +1,10 @@ import numpy as np +from jax.interpreters import mlir from jax.interpreters.mlir import ir from jax.lib import xla_client -from jaxlib.hlo_helpers import custom_call as _custom_call -from jax_finufft import options - -try: - from jaxlib.hlo_helpers import hlo_const -except ImportError: - # Copied from jaxlib/hlo_helpers.py for old versions of jax - from functools import partial - - import jaxlib.mlir.dialects.stablehlo as hlo - - _dtype_to_ir_type_factory = { - np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1), - np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8), - np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16), - np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32), - np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64), - np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8), - np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16), - np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32), - np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64), - np.dtype(np.float16): ir.F16Type.get, - np.dtype(np.float32): ir.F32Type.get, - np.dtype(np.float64): ir.F64Type.get, - np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()), - np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()), - } - - def dtype_to_ir_type(dtype) -> ir.Type: - return _dtype_to_ir_type_factory[np.dtype(dtype)]() - - def hlo_const(x): - assert isinstance(x, np.ndarray) - return hlo.ConstantOp( - ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype)) - ).result +from jaxlib.hlo_helpers import custom_call, hlo_const +from jax_finufft import options from . import jax_finufft_cpu @@ -45,7 +12,7 @@ def hlo_const(x): from . import jax_finufft_gpu for _name, _value in jax_finufft_gpu.registrations().items(): - xla_client.register_custom_call_target(_name, _value, platform="gpu") + xla_client.register_custom_call_target(_name, _value, platform="CUDA") except ImportError: jax_finufft_gpu = None @@ -53,26 +20,25 @@ def hlo_const(x): xla_client.register_custom_call_target(_name, _value, platform="cpu") -# Handle old versions of jax which had a different syntax for custom_call -def custom_call(*args, **kwargs): - try: - return _custom_call(*args, **kwargs).results - except TypeError: - kwargs["out_types"] = kwargs.pop("result_types") - return (_custom_call(*args, **kwargs),) - - def default_layouts(*shapes): return [range(len(shape) - 1, -1, -1) for shape in shapes] -def lowering(platform, ctx, source, *points, output_shape, iflag, eps, opts): - del ctx - - if platform not in ["cpu", "gpu"]: - raise ValueError(f"Unrecognized platform '{platform}'") - - if platform == "gpu" and jax_finufft_gpu is None: +def lowering( + ctx: mlir.LoweringRuleContext, + source: ir.Value, + *points, + output_shape, + iflag, + eps, + opts, +): + if len(ctx.module_context.platforms) > 1: + raise ValueError("Multi-platform lowering is not supported") + platform = ctx.module_context.platforms[0] + if platform not in {"cpu", "cuda"}: + raise ValueError(f"Unsupported platform '{platform}'") + if platform == "cuda" and jax_finufft_gpu is None: raise ValueError("jax-finufft was not compiled with GPU support") ndim = len(points) @@ -80,36 +46,24 @@ def lowering(platform, ctx, source, *points, output_shape, iflag, eps, opts): if platform == "gpu" and ndim == 1: raise ValueError("1-D transforms are not yet supported on the GPU") - source_type = ir.RankedTensorType(source.type) - points_type = [ir.RankedTensorType(x.type) for x in points] - - # Check supported and consistent dtypes - f32 = ir.F32Type.get() - f64 = ir.F64Type.get() - source_dtype = source_type.element_type - single = source_dtype == ir.ComplexType.get(f32) and all( - x.element_type == f32 for x in points_type - ) - double = source_dtype == ir.ComplexType.get(f64) and all( - x.element_type == f64 for x in points_type - ) - assert single or double + source_aval = ctx.avals_in[0] + single = source_aval.dtype == np.complex64 suffix = "f" if single else "" - # Check shapes - source_shape = source_type.shape - points_shape = tuple(x.shape for x in points_type) + source_shape = source_aval.shape + points_shape = tuple(x.shape for x in ctx.avals_in[1:]) n_tot = source_shape[0] n_transf = source_shape[1] n_j = points_shape[0][1] - if output_shape is None: + + # Dispatch to the correct custom call target depending on the dimension, + # dtype, and NUFFT type. + if output_shape is None: # Type 2 op_name = f"nufft{ndim}d2{suffix}".encode("ascii") n_k = np.array(source_shape[2:], dtype=np.int64) - full_output_shape = tuple(source_shape[:2]) + (n_j,) - else: + else: # Type 1 op_name = f"nufft{ndim}d1{suffix}".encode("ascii") n_k = np.array(output_shape, dtype=np.int64) - full_output_shape = tuple(source_shape[:2]) + tuple(output_shape) # The backend expects the output shape in Fortran order, so we'll just # fake it here, by sending in n_k and x in the reverse order. @@ -123,37 +77,30 @@ def lowering(platform, ctx, source, *points, output_shape, iflag, eps, opts): if platform == "cpu": opts = opts.to_finufft_opts() - opaque = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")( + descriptor_bytes = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")( eps, iflag, n_tot, n_transf, n_j, *n_k_full, opts ) - opaque_arg = hlo_const(np.frombuffer(opaque, dtype=np.uint8)) - opaque_shape = ir.RankedTensorType(opaque_arg.type).shape + descriptor = hlo_const(np.frombuffer(descriptor_bytes, dtype=np.uint8)) return custom_call( op_name, - result_types=[ - ir.RankedTensorType.get(full_output_shape, source_type.element_type) - ], + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], # Reverse points because backend uses Fortran order - operands=[opaque_arg, source, *points[::-1]], - operand_layouts=default_layouts( - opaque_shape, source_shape, *points_shape[::-1] - ), - result_layouts=default_layouts(full_output_shape), - ) + operands=[descriptor, source, *points[::-1]], + operand_layouts=default_layouts([0], source_shape, *points_shape[::-1]), + result_layouts=default_layouts(ctx.avals_out[0].shape), + ).results else: opts = opts.to_cufinufft_opts() - opaque = getattr(jax_finufft_gpu, f"build_descriptor{suffix}")( + descriptor_bytes = getattr(jax_finufft_gpu, f"build_descriptor{suffix}")( eps, iflag, n_tot, n_transf, n_j, *n_k_full, opts ) return custom_call( op_name, - result_types=[ - ir.RankedTensorType.get(full_output_shape, source_type.element_type) - ], + result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], # Reverse points because backend uses Fortran order operands=[source, *points[::-1]], - backend_config=opaque, + backend_config=descriptor_bytes, operand_layouts=default_layouts(source_shape, *points_shape[::-1]), - result_layouts=default_layouts(full_output_shape), - ) + result_layouts=default_layouts(ctx.avals_out[0].shape), + ).results diff --git a/src/jax_finufft/ops.py b/src/jax_finufft/ops.py index ce50b26..545f56e 100644 --- a/src/jax_finufft/ops.py +++ b/src/jax_finufft/ops.py @@ -195,9 +195,9 @@ def batch(args, axes, *, output_shape, **kwargs): nufft1_p = core.Primitive("nufft1") nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p)) nufft1_p.def_abstract_eval(shapes.abstract_eval) -mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="cpu") +mlir.register_lowering(nufft1_p, lowering.lowering, platform="cpu") if lowering.jax_finufft_gpu is not None: - mlir.register_lowering(nufft1_p, partial(lowering.lowering, "gpu"), platform="gpu") + mlir.register_lowering(nufft1_p, lowering.lowering, platform="cuda") ad.primitive_jvps[nufft1_p] = partial(jvp, nufft1_p) ad.primitive_transposes[nufft1_p] = transpose batching.primitive_batchers[nufft1_p] = batch @@ -206,9 +206,9 @@ def batch(args, axes, *, output_shape, **kwargs): nufft2_p = core.Primitive("nufft2") nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p)) nufft2_p.def_abstract_eval(shapes.abstract_eval) -mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu") +mlir.register_lowering(nufft2_p, lowering.lowering, platform="cpu") if lowering.jax_finufft_gpu is not None: - mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu") + mlir.register_lowering(nufft2_p, lowering.lowering, platform="cuda") ad.primitive_jvps[nufft2_p] = partial(jvp, nufft2_p) ad.primitive_transposes[nufft2_p] = transpose batching.primitive_batchers[nufft2_p] = batch diff --git a/src/jax_finufft/shapes.py b/src/jax_finufft/shapes.py index e3dab0b..6ac7346 100644 --- a/src/jax_finufft/shapes.py +++ b/src/jax_finufft/shapes.py @@ -6,7 +6,6 @@ import jax.numpy as jnp import numpy as np from jax import dtypes -from jax.core import ShapedArray @dataclass @@ -116,10 +115,15 @@ def abstract_eval(source, *points, output_shape, **_): assert all(p.ndim == 2 for p in points) assert all(p.shape == points[0].shape for p in points[1:]) assert source.shape[0] == points[0].shape[0] - if output_shape is None: + + if output_shape is None: # Type 2 assert source.ndim == 2 + ndim - return ShapedArray(source.shape[:2] + (points[0].shape[-1],), source_dtype) - else: + return source.update( + shape=source.shape[:2] + (points[0].shape[-1],), dtype=source_dtype + ) + else: # Type 1 assert source.ndim == 3 assert source.shape[2] == points[0].shape[1] - return ShapedArray(source.shape[:2] + tuple(output_shape), source_dtype) + return source.update( + shape=source.shape[:2] + tuple(output_shape), dtype=source_dtype + )