[jax:pallas] Use 64-bit indexing when necessary when lowering to Triton.
PiperOrigin-RevId: 670530776
chr1sj0nes authored and jax authors committed Sep 3, 2024
1 parent 4c3111b · commit 7b161fb
Showing 1 changed file with 32 additions and 24 deletions.
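
Note on the rule the diff below implements: the element type used for pointer offsets is chosen from the total byte size of the full array, and 64-bit offsets are used only when 32-bit arithmetic could overflow. A minimal standalone sketch of that threshold check, assuming a hypothetical helper name select_offset_bitwidth (not taken from the commit):

import math
import numpy as np

def select_offset_bitwidth(shape, dtype) -> int:
    # Total byte size of the full (unblocked) array.
    full_size_bytes = math.prod(shape) * np.dtype(dtype).itemsize
    # Use 64-bit indexing when an offset might be >= 2**32 bytes.
    return 64 if full_size_bytes > 2**32 else 32

assert select_offset_bitwidth((1024, 1024), np.float32) == 32        # 4 MiB: 32-bit offsets suffice
assert select_offset_bitwidth((1 << 16, 1 << 16), np.float32) == 64  # 16 GiB: needs 64-bit offsets
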
jax/_src/pallas/triton/lowering.py (32 additions, 24 deletions)
@@ -474,7 +474,7 @@ def _atomic_lowering_rule(
raise NotImplementedError("Only single indexer is supported.")
idx = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
val = _ensure_ir_value(val, value_aval)
if mask is not None:
@@ -1674,12 +1674,12 @@ def _compute_pointers_from_indices(
root_ptr: ir.Value,
block_info: BlockInfo | None,
nd_indexer: NDIndexer,
array_shape: tuple[int, ...],
array_shape_dtype: Any,
) -> ir.Value:
if block_info is None: # TODO(necula): is this branch dead?
full_shape = array_shape
full_shape = array_shape_dtype.shape
num_mapped_dims = 0
block_shape = array_shape
block_shape = array_shape_dtype.shape
else:
full_shape = block_info.full_shape_dtype.shape
num_mapped_dims = sum(
@@ -1692,20 +1692,29 @@
_check_tensor_size(indexer_shape)
indices = nd_indexer.indices
other_shape = indexer_shape[len(int_indexer_shape) :]
bcast_indices = []
other_shape_idx = 0
if block_info is None:
start_index_offsets = [None] * len(indices)
else:
start_index_offsets = block_info.start_indices
assert len(indices) + num_mapped_dims == len(full_shape)
assert len(start_index_offsets) == len(full_shape)

array_dtype = jnp.dtype(array_shape_dtype.dtype)
full_size = math.prod(full_shape) * array_dtype.itemsize
# Use 64-bit indexing when offset might be >= 2**32 bytes.
offset_eltype = ir.IntegerType.get_signless(64 if full_size > 2**32 else 32)
if indexer_shape:
offsets = _full(ir.RankedTensorType.get(indexer_shape, offset_eltype), 0)
else:
offsets = _ir_constant(0, offset_eltype)

indexer_iter = iter(indices)
for dim_stride, dim_block_size, start_offset in zip(
strides, block_shape, start_index_offsets
):
if dim_block_size is pallas_core.mapped:
index = _i32_constant(0)
index = _ir_constant(0, offset_eltype)
else:
index = next(indexer_iter)

@@ -1716,22 +1725,27 @@
if index.is_dynamic_start or (index.stride != 1):
start = index.start
if not index.is_dynamic_start:
start = _i32_constant(start)
start = _ir_constant(start, offset_eltype)
start = _ir_cast(start, offset_eltype, signed=False)

iota = _ir_cast(_make_range(0, index.size), start.type, signed=False)
iota = _ir_cast(_make_range(0, index.size), offset_eltype, signed=False)
if index.stride != 1:
iota = _mul(iota, _full(iota.type, index.stride))
dim_offsets = _add(_bcast_to(start, [index.size]), iota)
else:
dim_offsets = _make_range(index.start, index.start + index.size)
iota = _make_range(index.start, index.start + index.size)
dim_offsets = _ir_cast(iota, offset_eltype, signed=False)

other_shape_idx += 1
for _ in other_shape[other_shape_idx:]:
rank = ir.RankedTensorType(dim_offsets.type).rank
dim_offsets = _expand_dims(dim_offsets, rank)
else:
# indexer is either a *scalar* or an array of size `int_indexer_shape`
dim_offsets = _ensure_ir_value(index, jax_core.ShapedArray((), jnp.int32))
dim_offsets = index
if not isinstance(dim_offsets, ir.Value):
dim_offsets = _ir_constant(dim_offsets, offset_eltype)
dim_offsets = _ir_cast(dim_offsets, offset_eltype, signed=False)

if ir.RankedTensorType.isinstance(dim_offsets.type):
for _ in other_shape:
@@ -1742,19 +1756,16 @@
rank = ir.RankedTensorType(dim_offsets.type).rank
for _ in range(len(indexer_shape) - rank):
dim_offsets = _expand_dims(dim_offsets, 0)
dim_offsets = _bcast_to(dim_offsets, indexer_shape)
elif indexer_shape:
dim_offsets = _splat(dim_offsets, indexer_shape)
dim_offsets = _bcast_to(dim_offsets, indexer_shape)

if start_offset is not None:
offset_type = _element_type(dim_offsets.type)
start_offset = _ir_cast(start_offset, offset_type, signed=False)
start_offset = _ir_cast(start_offset, offset_eltype, signed=False)
dim_offsets = _add(dim_offsets, _bcast_to(start_offset, indexer_shape))

bcast_indices.append(_mul(dim_offsets, _full(dim_offsets.type, dim_stride)))
dim_offsets = _mul(dim_offsets, _full(dim_offsets.type, dim_stride))
offsets = _add(offsets, dim_offsets)

ptrs = _bcast_to(root_ptr, indexer_shape)
return functools.reduce(_add, bcast_indices, ptrs)
return _add(_bcast_to(root_ptr, indexer_shape), offsets)


@register_lowering(sp.get_p)
@@ -1869,7 +1880,7 @@ def _masked_load_lowering_rule(
assert len(ctx.avals_in) == 1
return ptr
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
if mask is not None:
mask = _bcast_to(_ensure_ir_value(mask, mask_aval), idx.get_indexer_shape())
@@ -1966,7 +1977,7 @@ def _masked_swap_lowering_rule(
raise NotImplementedError("No support for multiple indexers yet.")
idx = indexers[0]
ptr = _compute_pointers_from_indices(
ptr, ctx.block_infos[0], idx, ctx.avals_in[0].shape
ptr, ctx.block_infos[0], idx, ctx.avals_in[0]
)
other = None
if value is not None:
@@ -1991,10 +2002,7 @@ def _addupdate_lowering_rule(ctx: LoweringRuleContext, ptr, value, *idx, tree):
raise NotImplementedError("No support for multiple indexers yet.")
indexer = indexers[0]
ptr = _compute_pointers_from_indices(
ptr,
ctx.block_infos[0],
indexer,
ctx.avals_in[0].shape,
ptr, ctx.block_infos[0], indexer, ctx.avals_in[0]
)
op = tt_dialect.RMWOp.FADD
if isinstance(_element_type(value.type), ir.IntegerType):
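
The per-dimension loop in _compute_pointers_from_indices above multiplies each dimension's offsets by its element stride and accumulates them into a single offsets tensor, i.e. a row-major flattening of the index. A minimal sketch of that arithmetic, assuming row-major strides as computed by strides_from_shape (the flatten_index helper is hypothetical, not part of the commit):

import math

def flatten_index(indices, shape):
    # Row-major strides: stride_d is the product of the trailing dimension sizes.
    strides = [math.prod(shape[d + 1:]) for d in range(len(shape))]
    # offset = sum_d index_d * stride_d, mirroring the _mul / _add accumulation in the loop.
    return sum(i * s for i, s in zip(indices, strides))

assert flatten_index((1, 2, 3), (4, 5, 6)) == 1 * 30 + 2 * 6 + 3  # == 45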