Simplifying lowering and abstract evaluation rules #89

Merged (6 commits) on May 2, 2024
Changes from all commits
8 changes: 6 additions & 2 deletions .github/workflows/tests.yml
@@ -8,12 +8,16 @@ on:

jobs:
tests:
name: ${{ matrix.os }}
name: ${{ matrix.os }}-${{ matrix.jax-version }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest]
jax-version: ["jax[cpu]"]
include:
- os: ubuntu-latest
jax-version: "jax[cpu]==0.4.20"

steps:
- uses: actions/checkout@v4
@@ -37,7 +41,7 @@ jobs:
- name: Build
run: |
python -m pip install -U pip
python -m pip install -U jax[cpu]
python -m pip install -U ${{ matrix.jax-version }}
python -m pip install -v .[test]

- name: Run tests
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -13,7 +13,7 @@ authors = [
requires-python = ">=3.9"
license = { file = "LICENSE" }
urls = { Homepage = "https://github.com/dfm/jax-finufft" }
dependencies = ["jax", "numpy", "pydantic>=2"]
dependencies = ["jax>=0.4.20", "numpy", "pydantic>=2"]
dynamic = ["version"]

[project.optional-dependencies]
133 changes: 40 additions & 93 deletions src/jax_finufft/lowering.py
@@ -1,115 +1,69 @@
import numpy as np
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jax.lib import xla_client
from jaxlib.hlo_helpers import custom_call as _custom_call
from jax_finufft import options

try:
from jaxlib.hlo_helpers import hlo_const
except ImportError:
# Copied from jaxlib/hlo_helpers.py for old versions of jax
from functools import partial

import jaxlib.mlir.dialects.stablehlo as hlo

_dtype_to_ir_type_factory = {
np.dtype(np.bool_): partial(ir.IntegerType.get_signless, 1),
np.dtype(np.int8): partial(ir.IntegerType.get_signless, 8),
np.dtype(np.int16): partial(ir.IntegerType.get_signless, 16),
np.dtype(np.int32): partial(ir.IntegerType.get_signless, 32),
np.dtype(np.int64): partial(ir.IntegerType.get_signless, 64),
np.dtype(np.uint8): partial(ir.IntegerType.get_unsigned, 8),
np.dtype(np.uint16): partial(ir.IntegerType.get_unsigned, 16),
np.dtype(np.uint32): partial(ir.IntegerType.get_unsigned, 32),
np.dtype(np.uint64): partial(ir.IntegerType.get_unsigned, 64),
np.dtype(np.float16): ir.F16Type.get,
np.dtype(np.float32): ir.F32Type.get,
np.dtype(np.float64): ir.F64Type.get,
np.dtype(np.complex64): lambda: ir.ComplexType.get(ir.F32Type.get()),
np.dtype(np.complex128): lambda: ir.ComplexType.get(ir.F64Type.get()),
}

def dtype_to_ir_type(dtype) -> ir.Type:
return _dtype_to_ir_type_factory[np.dtype(dtype)]()

def hlo_const(x):
assert isinstance(x, np.ndarray)
return hlo.ConstantOp(
ir.DenseElementsAttr.get(x, type=dtype_to_ir_type(x.dtype))
).result
from jaxlib.hlo_helpers import custom_call, hlo_const

from jax_finufft import options

from . import jax_finufft_cpu

try:
from . import jax_finufft_gpu

for _name, _value in jax_finufft_gpu.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
xla_client.register_custom_call_target(_name, _value, platform="CUDA")
except ImportError:
jax_finufft_gpu = None

for _name, _value in jax_finufft_cpu.registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="cpu")


# Handle old versions of jax which had a different syntax for custom_call
def custom_call(*args, **kwargs):
try:
return _custom_call(*args, **kwargs).results
except TypeError:
kwargs["out_types"] = kwargs.pop("result_types")
return (_custom_call(*args, **kwargs),)


def default_layouts(*shapes):
return [range(len(shape) - 1, -1, -1) for shape in shapes]


def lowering(platform, ctx, source, *points, output_shape, iflag, eps, opts):
del ctx

if platform not in ["cpu", "gpu"]:
raise ValueError(f"Unrecognized platform '{platform}'")

if platform == "gpu" and jax_finufft_gpu is None:
def lowering(
ctx: mlir.LoweringRuleContext,
source: ir.Value,
*points,
output_shape,
iflag,
eps,
opts,
):
if len(ctx.module_context.platforms) > 1:
raise ValueError("Multi-platform lowering is not supported")
platform = ctx.module_context.platforms[0]
if platform not in {"cpu", "cuda"}:
raise ValueError(f"Unsupported platform '{platform}'")
if platform == "cuda" and jax_finufft_gpu is None:
raise ValueError("jax-finufft was not compiled with GPU support")

ndim = len(points)
assert 1 <= ndim <= 3
if platform == "gpu" and ndim == 1:
raise ValueError("1-D transforms are not yet supported on the GPU")

source_type = ir.RankedTensorType(source.type)
points_type = [ir.RankedTensorType(x.type) for x in points]

# Check supported and consistent dtypes
f32 = ir.F32Type.get()
f64 = ir.F64Type.get()
source_dtype = source_type.element_type
single = source_dtype == ir.ComplexType.get(f32) and all(
x.element_type == f32 for x in points_type
)
double = source_dtype == ir.ComplexType.get(f64) and all(
x.element_type == f64 for x in points_type
)
assert single or double
source_aval = ctx.avals_in[0]
single = source_aval.dtype == np.complex64
suffix = "f" if single else ""

# Check shapes
source_shape = source_type.shape
points_shape = tuple(x.shape for x in points_type)
source_shape = source_aval.shape
points_shape = tuple(x.shape for x in ctx.avals_in[1:])
n_tot = source_shape[0]
n_transf = source_shape[1]
n_j = points_shape[0][1]
if output_shape is None:

# Dispatch to the correct custom call target depending on the dimension,
# dtype, and NUFFT type.
if output_shape is None: # Type 2
op_name = f"nufft{ndim}d2{suffix}".encode("ascii")
n_k = np.array(source_shape[2:], dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + (n_j,)
else:
else: # Type 1
op_name = f"nufft{ndim}d1{suffix}".encode("ascii")
n_k = np.array(output_shape, dtype=np.int64)
full_output_shape = tuple(source_shape[:2]) + tuple(output_shape)

# The backend expects the output shape in Fortran order, so we'll just
# fake it here, by sending in n_k and x in the reverse order.
@@ -123,37 +77,30 @@ def lowering(platform, ctx, source, *points, output_shape, iflag, eps, opts):

if platform == "cpu":
opts = opts.to_finufft_opts()
opaque = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
descriptor_bytes = getattr(jax_finufft_cpu, f"build_descriptor{suffix}")(
eps, iflag, n_tot, n_transf, n_j, *n_k_full, opts
)
opaque_arg = hlo_const(np.frombuffer(opaque, dtype=np.uint8))
opaque_shape = ir.RankedTensorType(opaque_arg.type).shape
descriptor = hlo_const(np.frombuffer(descriptor_bytes, dtype=np.uint8))
return custom_call(
op_name,
result_types=[
ir.RankedTensorType.get(full_output_shape, source_type.element_type)
],
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
# Reverse points because backend uses Fortran order
operands=[opaque_arg, source, *points[::-1]],
operand_layouts=default_layouts(
opaque_shape, source_shape, *points_shape[::-1]
),
result_layouts=default_layouts(full_output_shape),
)
operands=[descriptor, source, *points[::-1]],
operand_layouts=default_layouts([0], source_shape, *points_shape[::-1]),
result_layouts=default_layouts(ctx.avals_out[0].shape),
).results

else:
opts = opts.to_cufinufft_opts()
opaque = getattr(jax_finufft_gpu, f"build_descriptor{suffix}")(
descriptor_bytes = getattr(jax_finufft_gpu, f"build_descriptor{suffix}")(
eps, iflag, n_tot, n_transf, n_j, *n_k_full, opts
)
return custom_call(
op_name,
result_types=[
ir.RankedTensorType.get(full_output_shape, source_type.element_type)
],
result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
# Reverse points because backend uses Fortran order
operands=[source, *points[::-1]],
backend_config=opaque,
backend_config=descriptor_bytes,
operand_layouts=default_layouts(source_shape, *points_shape[::-1]),
result_layouts=default_layouts(full_output_shape),
)
result_layouts=default_layouts(ctx.avals_out[0].shape),
).results
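
The main simplification in this file is that the lowering rule no longer receives the target platform through functools.partial at registration time; it reads it from the lowering context. Below is a minimal sketch of that pattern, assuming only the JAX APIs that appear in the diff above (mlir.LoweringRuleContext and ctx.module_context.platforms); the rule body is a placeholder, not jax-finufft's actual lowering.

# Sketch of the platform-aware lowering pattern used above; only the context
# handling mirrors the new jax-finufft code, the rest is a placeholder.
from jax.interpreters import mlir

def my_lowering(ctx: mlir.LoweringRuleContext, *operands, **params):
    # The same rule is registered once per platform, so multi-platform
    # modules are rejected, matching the check in the diff above.
    if len(ctx.module_context.platforms) > 1:
        raise ValueError("Multi-platform lowering is not supported")
    platform = ctx.module_context.platforms[0]
    if platform not in {"cpu", "cuda"}:
        raise ValueError(f"Unsupported platform '{platform}'")
    # ... emit the platform-specific custom call here, branching on
    # `platform` the way the jax-finufft rule above does.

Because the same function is registered for both platforms (see ops.py below), the platform string only needs to be resolved once, inside the rule, instead of being baked in with functools.partial.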
8 changes: 4 additions & 4 deletions src/jax_finufft/ops.py
@@ -195,9 +195,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft1_p = core.Primitive("nufft1")
nufft1_p.def_impl(partial(xla.apply_primitive, nufft1_p))
nufft1_p.def_abstract_eval(shapes.abstract_eval)
mlir.register_lowering(nufft1_p, partial(lowering.lowering, "cpu"), platform="cpu")
mlir.register_lowering(nufft1_p, lowering.lowering, platform="cpu")
if lowering.jax_finufft_gpu is not None:
mlir.register_lowering(nufft1_p, partial(lowering.lowering, "gpu"), platform="gpu")
mlir.register_lowering(nufft1_p, lowering.lowering, platform="cuda")
ad.primitive_jvps[nufft1_p] = partial(jvp, nufft1_p)
ad.primitive_transposes[nufft1_p] = transpose
batching.primitive_batchers[nufft1_p] = batch
@@ -206,9 +206,9 @@ def batch(args, axes, *, output_shape, **kwargs):
nufft2_p = core.Primitive("nufft2")
nufft2_p.def_impl(partial(xla.apply_primitive, nufft2_p))
nufft2_p.def_abstract_eval(shapes.abstract_eval)
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "cpu"), platform="cpu")
mlir.register_lowering(nufft2_p, lowering.lowering, platform="cpu")
if lowering.jax_finufft_gpu is not None:
mlir.register_lowering(nufft2_p, partial(lowering.lowering, "gpu"), platform="gpu")
mlir.register_lowering(nufft2_p, lowering.lowering, platform="cuda")
ad.primitive_jvps[nufft2_p] = partial(jvp, nufft2_p)
ad.primitive_transposes[nufft2_p] = transpose
batching.primitive_batchers[nufft2_p] = batch
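
For users nothing changes here: the same primitives are exposed, and JAX selects the "cpu" or "cuda" lowering registered above based on the backend in use. A hypothetical usage sketch follows; the nufft1(output_shape, source, *points) call signature is assumed from the jax-finufft README and is not part of this diff.

# Hypothetical usage sketch; the nufft1 signature is an assumption taken from
# the package README, not something introduced or changed by this PR.
import numpy as np
import jax.numpy as jnp
from jax_finufft import nufft1

rng = np.random.default_rng(0)
x = jnp.asarray(2 * np.pi * rng.uniform(size=100))  # nonuniform points
c = jnp.asarray(rng.normal(size=100) + 1j * rng.normal(size=100))  # complex weights
f = nufft1(50, c, x)  # type-1 NUFFT onto 50 modes; dispatched to the CPU or
                      # CUDA lowering registered above depending on the backend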
14 changes: 9 additions & 5 deletions src/jax_finufft/shapes.py
@@ -6,7 +6,6 @@
import jax.numpy as jnp
import numpy as np
from jax import dtypes
from jax.core import ShapedArray


@dataclass
@@ -116,10 +115,15 @@ def abstract_eval(source, *points, output_shape, **_):
assert all(p.ndim == 2 for p in points)
assert all(p.shape == points[0].shape for p in points[1:])
assert source.shape[0] == points[0].shape[0]
if output_shape is None:

if output_shape is None: # Type 2
assert source.ndim == 2 + ndim
return ShapedArray(source.shape[:2] + (points[0].shape[-1],), source_dtype)
else:
return source.update(
shape=source.shape[:2] + (points[0].shape[-1],), dtype=source_dtype
)
Review comment from the PR author (illustrated by the sketch after this diff):

This is an unrelated change that I noticed. We're no longer supposed to manually construct a ShapedArray here because it should return the same type as the input (tracer vs concrete value).

else: # Type 1
assert source.ndim == 3
assert source.shape[2] == points[0].shape[1]
return ShapedArray(source.shape[:2] + tuple(output_shape), source_dtype)
return source.update(
shape=source.shape[:2] + tuple(output_shape), dtype=source_dtype
)
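
To illustrate the review comment above: returning source.update(...) derives the output abstract value from the input, so the rule yields the same aval type it was given, whereas constructing a ShapedArray by hand always produces a plain ShapedArray. A generic sketch of the pattern, with a placeholder rule that is not jax-finufft's actual shape logic:

# Generic sketch of the abstract-eval pattern adopted above; `my_abstract_eval`
# and its shape arithmetic are placeholders.
def my_abstract_eval(source, *, output_shape):
    new_shape = source.shape[:2] + tuple(output_shape)
    # Preferred: build the result from the input aval so it keeps the same type.
    return source.update(shape=new_shape, dtype=source.dtype)
    # Discouraged (the removed code above): jax.core.ShapedArray(new_shape,
    # source.dtype) always returns a plain ShapedArray, dropping whatever
    # richer abstract-value type the input carried.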