diff --git a/MaxText/kernels/__init__.py b/MaxText/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/MaxText/kernels/megablox/__init__.py b/MaxText/kernels/megablox/__init__.py new file mode 100644 index 000000000..5df55e489 --- /dev/null +++ b/MaxText/kernels/megablox/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Megablox kernel""" + +from kernels.megablox.ops import gmm diff --git a/MaxText/kernels/megablox/common.py b/MaxText/kernels/megablox/common.py new file mode 100644 index 000000000..d11c80387 --- /dev/null +++ b/MaxText/kernels/megablox/common.py @@ -0,0 +1,59 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common utilities for GMM kernels.""" + +import re + +import jax +import jax.numpy as jnp + + +def is_tpu() -> bool: + return "TPU" in jax.devices()[0].device_kind + + +def tpu_kind() -> str: + """Query identification string for the currently attached TPU.""" + return jax.devices()[0].device_kind + + +_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)") + + +def tpu_generation() -> int: + """Generation number of the currently attached TPU.""" + if version := _TPU_KIND_PATTERN.match(tpu_kind()): + return int(version[1]) + raise NotImplementedError("only TPU devices are supported") + + +def supports_bfloat16_matmul() -> bool: + """Does the currently attached CPU support bfloat16 inputs?""" + return not is_tpu() or tpu_generation() >= 4 + + +def assert_is_supported_dtype(dtype: jnp.dtype) -> None: + if dtype not in (jnp.bfloat16, jnp.float32): + raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.") + + +def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype: + """A type to which both input should be adapted to before dot product.""" + # bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed + # input precision, we need to convert bf16 argument to fp32 beforehand. + if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16: + return jnp.bfloat16 + else: + return jnp.float32 diff --git a/MaxText/kernels/megablox/gmm.py b/MaxText/kernels/megablox/gmm.py new file mode 100644 index 000000000..0c2a06fde --- /dev/null +++ b/MaxText/kernels/megablox/gmm.py @@ -0,0 +1,822 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grouped matrix multiplication kernels for TPU written in Pallas.""" + +# pylint: disable=too-many-positional-arguments, unnecessary-lambda-assignment + +from collections.abc import Callable +import dataclasses +import functools +from typing import Any, Optional, Literal + +import jax +import jax.numpy as jnp +from jax import lax +from jax.experimental import pallas as pl +from jax.experimental.pallas import tpu as pltpu +from kernels.megablox import common + +from aqt.jax.v2 import pallas as aqt_pl +from aqt.jax.v2 import aqt_tensor + +QTensor = aqt_tensor.QTensor +partial = functools.partial + + +def _validate_args( + *, + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + expected_rhs_dims: int = 3, +) -> tuple[jnp.ndarray, jnp.ndarray, jnp.dtype]: + """Validates the arguments for the gmm function.""" + # Validate 'lhs'. + if lhs.ndim != 2: + raise ValueError(f"Expected 2-tensor for 'lhs' but got {lhs.ndim}-tensor.") + common.assert_is_supported_dtype(lhs.dtype) + + # Validate 'rhs'. + if rhs.ndim != expected_rhs_dims: + raise ValueError(f"Expected {expected_rhs_dims}-tensor for 'rhs' but got" f" {rhs.ndim}-tensor.") + common.assert_is_supported_dtype(rhs.dtype) + + # Validate 'group_sizes'. + if group_sizes.dtype != jnp.int32: + raise ValueError(f"Expected 32-bit integer 'group_sizes' but got {group_sizes.dtype}.") + + return lhs, group_sizes, common.select_input_dtype(lhs, rhs) + + +def _calculate_num_tiles(x: int, tx: int) -> int: + tiles, rem = divmod(x, tx) + if rem: + raise ValueError(f"{x} must be divisible by x-dimension tile size ({tx}).") + return tiles + + +def _calculate_irregular_num_tiles(x: int, tx: int) -> tuple[int, int]: + tiles, rem = divmod(x, tx) + if rem: + tiles += 1 + return tiles, rem + + +GroupMetadata = Any # TODO(enriqueps): Clean this up and use a namedtuple + + +def make_group_metadata( + *, + group_sizes: jnp.ndarray, + m: int, + tm: int, + start_group: jnp.ndarray, + num_nonzero_groups: int, + visit_empty_groups: bool = True, +) -> GroupMetadata: + """Create the metadata needed for grouped matmul computation. + + Args: + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + m: The number of rows in lhs. + tm: The m-dimension tile size being used. + start_group: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + num_nonzero_groups: Number of groups in group sizes to compute on. Useful in + combination with group_offset. + visit_empty_groups: If True, do not squeeze tiles for empty groups out of + the metadata. This is necessary for tgmm, where we at least need to zero + the output for each group. + + Returns: + tuple of: + group_offsets: A 1d, jnp.ndarray with shape [num_groups+1] and jnp.int32 + dtype. group_offsets[i] indicates the row at which group [i] starts in + the lhs matrix and group_offsets[i-1] = m. + group_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32 dtype. group_ids[i] indicates which group grid index 'i' will + work on. + m_tile_ids: A 1d, jnp.ndarray with shape [m_tiles + num_groups] and + jnp.int32. m_tile_ids[i] indicates which m-dimension tile grid index 'i' + will work on. + num_tiles: The number of m-dimension tiles to execute. + """ + num_groups = group_sizes.shape[0] + end_group = start_group + num_nonzero_groups - 1 + + # Calculate the offset of each group, starting at zero. This metadata is + # similar to row offsets in a CSR matrix. The following properties hold: + # + # group_offsets.shape = [num_groups + 1] + # group_offsets[0] = 0 + # group_offsets[num_groups] = m + # + # The row at which group 'i' starts is group_offsets[i]. + group_ends = jnp.cumsum(group_sizes) + group_offsets = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends]) + + # Assign a group id to each grid index. + # + # If a group starts somewhere other than the start of a tile or ends somewhere + # other than the end of a tile we need to compute that full tile. Calculate + # the number of tiles for each group by rounding their end up to the nearest + # 'tm' and their start down to the nearest 'tm'. + + # (1) Round the group_ends up to the nearest multiple of 'tm'. + # + # NOTE: This does not change group_offsets[num_groups], which is m + # (because we enforce m is divisible by tm). + rounded_group_ends = ((group_ends + tm - 1) // tm * tm).astype(jnp.int32) + + # (2) Round the group_starts down to the nearest multiple of 'tm'. + group_starts = jnp.concatenate([jnp.zeros(1, dtype=jnp.int32), group_ends[:-1]]) + rounded_group_starts = group_starts // tm * tm + + # (3) Calculate the number of rows in each group. + # + # NOTE: Handle zero-sized groups as a special case. If the start for a + # zero-sized group is not divisible by 'tm' its start will be rounded down and + # its end will be rounded up such that its size will become 1 tile here. + rounded_group_sizes = rounded_group_ends - rounded_group_starts + rounded_group_sizes = jnp.where(group_sizes == 0, 0, rounded_group_sizes) + + # (4) Convert the group sizes from units of rows to unit of 'tm' sized tiles. + # + # An m-dimension tile is 'owned' by group 'i' if the first row of the tile + # belongs to group 'i'. In addition to owned tiles, each group can have 0 or 1 + # initial partial tiles if it's first row does not occur in the first row of a + # tile. The '0-th' group never has a partial tile because it always starts at + # the 0-th row. + # + # If no group has a partial tile, the total number of tiles is equal to + # 'm // tm'. If every group has a partial except the 0-th group, the total + # number of tiles is equal to 'm // tm + num_groups - 1'. Thus we know that + # + # tiles_m <= group_tiles.sum() <= tiles_m + num_groups - 1 + # + # Where tiles_m = m // tm. + # + # NOTE: All group sizes are divisible by 'tm' because of the rounding in steps + # (1) and (2) so this division is exact. + group_tiles = rounded_group_sizes // tm + + if visit_empty_groups: + # Insert one tile for empty groups. + group_tiles = jnp.where(group_sizes == 0, 1, group_tiles) + + # Create the group ids for each grid index based on the tile counts for each + # group. + # + # NOTE: This repeat(...) will pad group_ids with the final group id if + # group_tiles.sum() < tiles_m + num_groups - 1. The kernel grid will be sized + # such that we only execute the necessary number of tiles. + tiles_m = _calculate_num_tiles(m, tm) + group_ids = jnp.repeat( + jnp.arange(num_groups, dtype=jnp.int32), + group_tiles, + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Assign an m-dimension tile id to each grid index. + # + # NOTE: Output tiles can only be re-visited consecutively. The following + # procedure guarantees that m-dimension tile indices respect this. + + # (1) Calculate how many times each m-dimension tile will be visited. + # + # Each tile is guaranteed to be visited once by the group that owns the tile. + # The remaining possible visits occur when a group starts inside of a tile at + # a position other than the first row. We can calculate which m-dimension tile + # each group starts in by floor-dividing its offset with `tm` and then count + # tile visits with a histogram. + # + # To avoid double counting tile visits from the group that owns the tile, + # filter these out by assigning their tile id to `tile_m` (one beyond the max) + # such that they're ignored by the subsequent histogram. Also filter out any + # group which is empty. + # + # TODO(tgale): Invert the 'partial_tile_mask' predicates to be more clear. + partial_tile_mask = jnp.logical_or((group_offsets[:-1] % tm) == 0, group_sizes == 0) + + # Explicitly enable tiles for zero sized groups, if specified. This covers + # zero sized groups that start on a tile-aligned row and those that do not. + if visit_empty_groups: + partial_tile_mask = jnp.where(group_sizes == 0, 0, partial_tile_mask) + + partial_tile_ids = jnp.where(partial_tile_mask, tiles_m, group_offsets[:-1] // tm) + + tile_visits = jnp.histogram(partial_tile_ids, bins=tiles_m, range=(0, tiles_m - 1))[0] + 1 + + # Create the m-dimension tile ids for each grid index based on the visit + # counts for each tile. + m_tile_ids = jnp.repeat( + jnp.arange(tiles_m, dtype=jnp.int32), + tile_visits.astype(jnp.int32), + total_repeat_length=tiles_m + num_groups - 1, + ) + + # Account for sharding. + # + # Find the start of the groups owned by our shard and shift the group_ids and + # m_tile_ids s.t. the metadata for our tiles are at the front of the arrays. + # + # TODO(tgale): Move this offset into the kernel to avoid these rolls. + first_tile_in_shard = (group_ids < start_group).sum() + group_ids = jnp.roll(group_ids, shift=-first_tile_in_shard, axis=0) + m_tile_ids = jnp.roll(m_tile_ids, shift=-first_tile_in_shard, axis=0) + + # Calculate the number of tiles we need to compute for our shard. + # + # Remove tile visits that belong to a group not in our shard. + iota = jnp.arange(num_groups, dtype=jnp.int32) + active_group_mask = jnp.logical_and(iota <= end_group, iota >= start_group) + group_tiles = jnp.where(active_group_mask, group_tiles, 0) + num_tiles = group_tiles.sum() + return (group_offsets, group_ids, m_tile_ids), num_tiles + + +def _get_group_size(*, grid_id: jnp.ndarray, group_metadata: GroupMetadata) -> jnp.ndarray: + """Calculate the number of rows in the current group.""" + group_offsets, group_ids = group_metadata[:2] + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + return group_end - group_start + + +def _get_store_mask( + *, + grid_id: jnp.ndarray, + group_metadata: GroupMetadata, + tm: int, + tn: int, +) -> jnp.ndarray: + """Mask for rows that belong to the current group in the current tile.""" + group_offsets, group_ids, m_tile_ids = group_metadata[:3] + group_id = group_ids[grid_id] + group_start = group_offsets[group_id] + group_end = group_offsets[group_id + 1] + m_id = m_tile_ids[grid_id] * tm + iota = jax.lax.broadcasted_iota(jnp.int32, (tm, tn), 0) + m_id + return jnp.logical_and(iota >= group_start, iota < group_end) + + +def _zero_uninitialized_memory( + out: jnp.ndarray, + *, + start_group: jnp.ndarray, + num_nonzero_groups: int, + group_metadata: GroupMetadata, +) -> jnp.ndarray: + """Zero out uninitialized memory from output.""" + group_offsets = group_metadata[0] + group_start = group_offsets[start_group] + group_end = group_offsets[start_group + num_nonzero_groups] + valid_mask = jax.lax.broadcasted_iota(jnp.int32, (out.shape[0],), 0) + valid_mask = (valid_mask >= group_start) & (valid_mask < group_end) + return jnp.where(valid_mask[:, None], out, 0) + + +LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]] + + +@functools.partial( + jax.jit, + static_argnames=[ + "preferred_element_type", + "tiling", + "transpose_rhs", + "interpret", + "lhs_quantize_dtype", + "rhs_quantize_dtype", + ], +) +def gmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray | QTensor, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, + transpose_rhs: bool = False, + interpret: bool = False, + lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None, + rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None, +) -> jnp.ndarray: + """Compute lhs[sizes[i-1]:sizes[i], :] @ rhs for each group 'i'. + + Args: + lhs: A 2d, jnp.ndarray with shape [m, k]. + rhs: A 3d, jnp.ndarray with shape [num_groups, k, n]. + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + preferred_element_type: jnp.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. + group_offset: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + existing_out: Existing output to write to. + transpose_rhs: True if the rhs needs to be transposed. + interpret: Whether or not to run the kernel in interpret mode, helpful for + testing and debugging. + lhs_quantize_dtype: Bit precision of lhs after quantization. + rhs_quantirhs_quantize_dtypeze_bits: Bit precision of rhs after quantization. + If rhs (weight) is already quantized this must be samed with precision of the given weight. + + Returns: + A 2d, jnp.ndarray with shape [m, n]. + """ + + if rhs_quantize_dtype is None and isinstance(rhs, QTensor): + raise ValueError("rhs_quantize_dtype is None, but quantized rhs is given.") + + if rhs_quantize_dtype is not None and isinstance(rhs, QTensor): + # If weight is alreeady quantized check precision. + if rhs_quantize_dtype != rhs.qvalue.dtype: + raise ValueError( + f"{rhs_quantize_dtype=} and already given quantized {rhs.qvalue.dtype=} does not have the same precision" + ) + + if existing_out is not None: + assert isinstance(existing_out, jax.Array) + expected_dtype = existing_out.dtype + if expected_dtype != preferred_element_type: + raise ValueError("Existing output dtype must match preferred_element_type.") + if group_offset is None: + group_offset = jnp.array([0], dtype=jnp.int32) + else: + if group_offset.shape: + raise ValueError(f"group_offset must be a ()-shaped array. Got: {group_offset.shape}.") + group_offset = group_offset[None] + num_current_groups = rhs.shape[0] + num_total_groups = group_sizes.shape[0] + lhs, group_sizes, input_dtype = _validate_args(lhs=lhs, rhs=rhs, group_sizes=group_sizes) + + # Gather shape information. + m, k, n = (lhs.shape[0], lhs.shape[1], rhs.shape[2]) + if transpose_rhs: + n = rhs.shape[1] + + # If tiling is callable, look up the problem dimensions in the LUT. If no tuned + # tile dimensions are available throw an error. + if callable(tiling): + tiling = tiling(m, k, n) + + if tiling is None: + raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})") + + tm, tk, tn = tiling + tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk) + tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn) + del n_rem + + # Create the metadata we need for computation. + group_metadata, num_active_tiles = make_group_metadata( # pylint: disable=unbalanced-tuple-unpacking + group_sizes=group_sizes, + m=m, + tm=tm, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + visit_empty_groups=False, + ) + + # We need to know contracting axis when we quantized lhs and rhs + # Thus move this code part outside of kernel. + if transpose_rhs: + dot_general_dims = (((1,), (1,)), ((), ())) + else: + dot_general_dims = (((1,), (0,)), ((), ())) + + def kernel( + group_metadata, + group_offset, + lhs: jax.Array | QTensor, + rhs: jax.Array | QTensor, + existing_out, + out, + acc_scratch, + ): + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, group_ids, group_offset + + grid_id = pl.program_id(1) + k_i = pl.program_id(2) + + @pl.when(k_i == 0) + def _zero_acc(): + acc_scratch[...] = jnp.zeros_like(acc_scratch) + + if existing_out is not None: + prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) + is_first_processed_group = grid_id == 0 + m_tile_changed = m_tile_ids[grid_id] != m_tile_ids[prev_grid_id] + first_time_seeing_out = jnp.logical_or(is_first_processed_group, m_tile_changed) + + @pl.when(first_time_seeing_out) + def _init_out(): + out[...] = existing_out[...] + + def mask_k_rem(x, *, dim, quantize): + if k_rem == 0: + return x + + orig_dtype = x.dtype + iota = lax.broadcasted_iota(jnp.int32, x.shape, dim) + if quantize is None: + x = x.astype(jnp.float32) + else: + x = x.astype(jnp.int32) + return jnp.where(iota < k_rem, x, 0).astype(orig_dtype) + + def _store_accum(): + mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tn, + ) + to_store = acc_scratch[...] + out[...] = jax.lax.select(mask[...], to_store, out[...].astype(jnp.float32)).astype(preferred_element_type) + + def _accum(is_last_k_tile): + if is_last_k_tile: + mask_k_rem_lhs = partial(mask_k_rem, dim=1, quantize=lhs_quantize_dtype is not None) + mask_k_rem_rhs = partial(mask_k_rem, dim=int(transpose_rhs), quantize=rhs_quantize_dtype is not None) + else: + mask_k_rem_lhs = lambda x: x + mask_k_rem_rhs = lambda x: x + + if isinstance(lhs, QTensor): + # loaded_lhs = aqt_pl.load_qtensor(lhs) + # Let qx: QTensor, qx = quant(x, 8 , ...) + # qx.dequant() == qx.qvalue * qx.scale ~= x + # Thus, setting qvalue to zero is equivalent to setting original tensor + # to zero. + qvalue = mask_k_rem_lhs(lhs.qvalue[...]) + loaded_lhs = dataclasses.replace(lhs, qvalue=qvalue) + loaded_lhs = aqt_pl.load_qtensor(loaded_lhs) + else: + loaded_lhs = mask_k_rem_lhs(lhs[...]).astype(input_dtype) + + if isinstance(rhs, QTensor): + qvalue = mask_k_rem_rhs(rhs.qvalue[...]) + loaded_rhs = dataclasses.replace(rhs, qvalue=qvalue) + loaded_rhs = aqt_pl.load_qtensor(loaded_rhs) + else: + loaded_rhs = mask_k_rem_rhs(rhs[...]).astype(input_dtype) + + acc_scratch[...] += aqt_pl.dot_general( + loaded_lhs, + loaded_rhs, + preferred_element_type=jnp.float32, + dimension_numbers=dot_general_dims, + ) + if is_last_k_tile: + _store_accum() + + lax.cond( + k_i == tiles_k - 1, + partial(_accum, True), + partial(_accum, False), + ) + + def lhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # lhs is (m, k). Load the [tm, tk] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del n_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], k_i + + def rhs_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # rhs is (num_groups, k, n). Load the [tk, tn] matrix based on the group id + # for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, m_tile_ids + if transpose_rhs: + k_i, n_i = n_i, k_i + + # NOTE: If we're working on only a shard of the rhs we need to adjust the + # group index we load from to account for this. The group_ids are in the + # "unsharded" domain. + return group_ids[grid_id] - group_offset[0], k_i, n_i + + def out_transform_indices(n_i, grid_id, k_i, group_metadata, group_offset): + # out is (m, n). Load the [tm, tn] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del k_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], n_i + + out_block_spec = pl.BlockSpec((tm, tn), out_transform_indices) + if existing_out is None: + in_out_block_spec: Any = None + input_output_aliases = {} + else: + in_out_block_spec = out_block_spec + existing_out_arg_index = 6 + # adding one more input because of scale factor of quantized tensor. + if lhs_quantize_dtype is not None: + existing_out_arg_index += 1 + if rhs_quantize_dtype is not None: + existing_out_arg_index += 1 + input_output_aliases = {existing_out_arg_index: 0} + + lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices) + if transpose_rhs: + rhs_block_spec = pl.BlockSpec((None, tn, tk), rhs_transform_indices) + else: + rhs_block_spec = pl.BlockSpec((None, tk, tn), rhs_transform_indices) + + lhs_bytes = lhs.size * lhs.itemsize + if isinstance(rhs, QTensor): + rhs_bytes = (k * n) * rhs.qvalue.itemsize # ignore scale factor as its size marginal. + else: + rhs_bytes = (k * n) * rhs.itemsize # We don't read all of rhs + + out_bytes = (m * n) * jnp.dtype(preferred_element_type).itemsize + max_active_tiles = group_metadata[1].size + bytes_accessed = (lhs_bytes * tiles_n) + (rhs_bytes * max_active_tiles) + out_bytes + flops = 2 * m * k * n + cost_estimate = pl.CostEstimate(flops=flops, bytes_accessed=bytes_accessed, transcendentals=0) + if lhs_quantize_dtype is not None or rhs_quantize_dtype is not None: + pallas_call_fn = aqt_pl.pallas_call + else: + pallas_call_fn = pl.pallas_call + call_gmm = pallas_call_fn( + kernel, + out_shape=jax.ShapeDtypeStruct((m, n), preferred_element_type), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + in_specs=[ + lhs_block_spec, + rhs_block_spec, + in_out_block_spec, + ], + out_specs=out_block_spec, + grid=(tiles_n, num_active_tiles, tiles_k), + scratch_shapes=[pltpu.VMEM((tm, tn), jnp.float32)], + ), + input_output_aliases=input_output_aliases, + compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")), + interpret=interpret, + cost_estimate=cost_estimate, + ) + + lhs_contracting_axis, rhs_contracting_axis = dot_general_dims[0] + # Since block_spec.block_shape of rhs is None, the first axis is reduced + # inside kernel, e.g., if block_shape is (None, tn, tk) then a tensor of + # shape (tn, tk) will be feteched inside kernel instead of (1, tn, tk). + # Therefore, we need to add one to rhs_contracting_axis. + rhs_contracting_axis = map(lambda x: x + 1, rhs_contracting_axis) + + if lhs_quantize_dtype is not None: + lhs_quantize_bits = 4 if lhs_quantize_dtype == jnp.int4 else 8 + lhs = aqt_pl.quant(lhs, lhs_quantize_bits, lhs_contracting_axis) + + if not isinstance(rhs, QTensor) and rhs_quantize_dtype is not None: + rhs_quantize_bits = 4 if rhs_quantize_dtype == jnp.int4 else 8 + rhs = aqt_pl.quant(rhs, rhs_quantize_bits, list(rhs_contracting_axis)) + + out = call_gmm( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + ) + if existing_out is None and num_current_groups < num_total_groups: + out = _zero_uninitialized_memory( + out, + start_group=group_offset[0], + num_nonzero_groups=rhs.shape[0], + group_metadata=group_metadata, + ) + return out + + +@functools.partial( + jax.jit, + static_argnames=[ + "preferred_element_type", + "tiling", + "num_actual_groups", + "interpret", + ], +) +def tgmm( + lhs: jnp.ndarray, + rhs: jnp.ndarray, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] | LutFn | None = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + num_actual_groups: int | None = None, + existing_out: jnp.ndarray | None = None, + interpret: bool = False, +) -> jnp.ndarray: + """Compute lhs[:, sizes[i-1]:sizes[i]] @ rhs[sizes[i-1]:sizes[i], :]. + + Args: + lhs: A 2d, jnp.ndarray with shape [k, m]. + rhs: A 2d, jnp.ndarray with shape [m, n]. + group_sizes: A 1d, jnp.ndarray with shape [num_groups] and jnp.int32 dtype. + preferred_element_type: jnp.dtype, the element type for the output matrix. + tiling: 3-tuple of ints. The m, k and n-dimension tile sizes. + group_offset: The group in group sizes to start computing from. This is + particularly useful for when rhs num_groups is sharded. + num_actual_groups: For when num_groups is sharded and we should only compute + the groups that are local, starting from group_offset. + existing_out: Existing output to write to. + interpret: Whether or not to run the kernel in interpret mode, helpful for + testing and debugging. + + Returns: + A 3d, jnp.ndarray with shape [num_groups, k, n]. + """ + if group_offset is None: + group_offset = jnp.array([0], dtype=jnp.int32) + else: + group_offset = group_offset[None] + lhs, group_sizes, input_dtype = _validate_args(lhs=lhs, rhs=rhs, group_sizes=group_sizes, expected_rhs_dims=2) + + # Gather shape information. + k, m, n = (lhs.shape[0], lhs.shape[1], rhs.shape[1]) + num_groups = group_sizes.shape[0] + num_actual_groups = num_actual_groups if num_actual_groups is not None else num_groups + + # If tiling is callable, look up the problem dimensions in the LUT. If no tuned + # tile dimensions are available throw an error. + if callable(tiling): + tiling = tiling(m, k, n) + + if tiling is None: + raise ValueError(f"No tuned tiling found for (m, k, n) = ({m}, {k}, {n})") + + tm, tk, tn = tiling + tiles_k, k_rem = _calculate_irregular_num_tiles(k, tk) + del k_rem + tiles_n, n_rem = _calculate_irregular_num_tiles(n, tn) + del n_rem + + # Create the metadata we need for computation. + group_metadata, num_active_tiles = make_group_metadata( + group_sizes=group_sizes, + m=m, + tm=tm, + start_group=group_offset[0], + num_nonzero_groups=num_actual_groups, + visit_empty_groups=True, + ) + + def kernel( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + out, + acc_scratch, + ): + grid_id = pl.program_id(2) + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, group_offset, m_tile_ids + + group = group_ids[grid_id] + prev_grid_id = jnp.where(grid_id > 0, grid_id - 1, 0) + prev_group = group_ids[prev_grid_id] + + group_has_changed = jnp.logical_or(grid_id == 0, prev_group != group) + + @pl.when(group_has_changed) + def _zero_acc(): + acc_scratch[...] = jnp.zeros_like(acc_scratch) + + # We'll only do computation if our group has a nonzero number of rows in it. + dont_skip = _get_group_size(grid_id=grid_id, group_metadata=group_metadata) > 0 + + @pl.when(dont_skip) + def _do(): + rhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tn, + ) + lhs_mask = _get_store_mask( + grid_id=grid_id, + group_metadata=group_metadata, + tm=tm, + tn=tk, + ) + + loaded_lhs = lhs[...] + loaded_rhs = rhs[...] + loaded_lhs = lax.select( + lhs_mask[...], + loaded_lhs.astype(jnp.float32), + jnp.zeros_like(lhs, jnp.float32), + ).swapaxes(0, 1) + loaded_rhs = lax.select( + rhs_mask[...], + loaded_rhs.astype(jnp.float32), + jnp.zeros_like(rhs, jnp.float32), + ) + + acc_scratch[...] += lax.dot( + loaded_lhs.astype(input_dtype), + loaded_rhs.astype(input_dtype), + preferred_element_type=jnp.float32, + ) + + is_end_of_grid = grid_id == (pl.num_programs(2) - 1) + next_grid_id = jnp.where(is_end_of_grid, grid_id, grid_id + 1) + next_group = group_ids[next_grid_id] + + group_is_changing = jnp.logical_or(is_end_of_grid, group != next_group) + + @pl.when(group_is_changing) + def _store_accum(): + to_store = acc_scratch[...] + if existing_out is not None: + to_store += existing_out[...].astype(jnp.float32) + out[...] = to_store.astype(preferred_element_type) + + def lhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # lhs is (m, k). Load the [tm, tk] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del n_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], k_i + + def rhs_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # rhs is (m, n). Load the [tm, tn] matrix for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del k_i, group_offsets, group_ids, group_offset + return m_tile_ids[grid_id], n_i + + def out_transform_indices(n_i, k_i, grid_id, group_metadata, group_offset): + # out is (num_groups, k, n). Load the [tk, tn] matrix based on the group id + # for this m-tile. + group_offsets, group_ids, m_tile_ids = group_metadata + del group_offsets, m_tile_ids + + # NOTE: If we're working on only a shard of the output we need to adjust the + # group index we load from to account for this. The group_ids are in the + # "unsharded" domain. + return group_ids[grid_id] - group_offset[0], k_i, n_i + + out_block_spec = pl.BlockSpec((None, tk, tn), out_transform_indices) + if existing_out is None: + in_out_block_spec: Any = None + input_output_aliases = {} + else: + in_out_block_spec = out_block_spec + input_output_aliases = {6: 0} + + lhs_block_spec = pl.BlockSpec((tm, tk), lhs_transform_indices) + rhs_block_spec = pl.BlockSpec((tm, tn), rhs_transform_indices) + + lhs_bytes = lhs.size * lhs.itemsize + rhs_bytes = rhs.size * rhs.itemsize + out_bytewidth = jnp.dtype(preferred_element_type).itemsize + out_bytes = (num_actual_groups * k * n) * out_bytewidth + bytes_accessed = (lhs_bytes * tiles_n) + (rhs_bytes * tiles_k) + out_bytes + flops = 2 * m * k * n + cost_estimate = pl.CostEstimate(flops=flops, bytes_accessed=bytes_accessed, transcendentals=0) + lhs = lhs.swapaxes(0, 1) + call_gmm = pl.pallas_call( + kernel, + out_shape=jax.ShapeDtypeStruct((num_actual_groups, k, n), preferred_element_type), + grid_spec=pltpu.PrefetchScalarGridSpec( + num_scalar_prefetch=2, + in_specs=[ + lhs_block_spec, + rhs_block_spec, + in_out_block_spec, + ], + out_specs=out_block_spec, + grid=(tiles_n, tiles_k, num_active_tiles), + scratch_shapes=[pltpu.VMEM((tk, tn), jnp.float32)], + ), + input_output_aliases=input_output_aliases, + compiler_params=pltpu.TPUCompilerParams(dimension_semantics=("parallel", "arbitrary", "arbitrary")), + interpret=interpret, + cost_estimate=cost_estimate, + ) + + out = call_gmm( + group_metadata, + group_offset, + lhs, + rhs, + existing_out, + ) + return out diff --git a/MaxText/kernels/megablox/ops.py b/MaxText/kernels/megablox/ops.py new file mode 100644 index 000000000..1fce34b8a --- /dev/null +++ b/MaxText/kernels/megablox/ops.py @@ -0,0 +1,113 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Grouped matrix multiplication operations with custom VJPs.""" + +# pylint: disable=too-many-positional-arguments + +import jax +import jax.numpy as jnp +from kernels.megablox import gmm as backend +from aqt.jax.v2 import aqt_tensor +from typing import Literal + +gmm = jax.custom_vjp( + backend.gmm, + nondiff_argnums=(3, 4, 7, 8, 9, 10), +) + + +def _gmm_fwd( + lhs: jnp.ndarray, + rhs: jnp.ndarray | aqt_tensor.QTensor, + group_sizes: jnp.ndarray, + preferred_element_type: jnp.dtype = jnp.float32, + tiling: tuple[int, int, int] = (128, 128, 128), + group_offset: jnp.ndarray | None = None, + existing_out: jnp.ndarray | None = None, + transpose_rhs: bool = False, + interpret: bool = False, + lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None, + rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None = None, +) -> tuple[ + jnp.ndarray, + tuple[ + jnp.ndarray, + jnp.ndarray | aqt_tensor.QTensor, + jnp.ndarray, + jnp.ndarray | None, + int, + ], +]: + """Forward function for GMM VJP.""" + out = backend.gmm( + lhs, + rhs, + group_sizes, + preferred_element_type, + tiling, + group_offset, + existing_out, + transpose_rhs=transpose_rhs, + interpret=interpret, + lhs_quantize_dtype=lhs_quantize_dtype, + rhs_quantize_dtype=rhs_quantize_dtype, + ) + return out, (lhs, rhs, group_sizes, group_offset, rhs.shape[0]) + + +def _gmm_bwd( + preferred_element_type: jnp.dtype, + tiling: tuple[int, int, int], + transpose_rhs: bool, + interpret: bool, + lhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None, + rhs_quantize_dtype: Literal[jnp.int4, jnp.int8] | None, + residual: tuple[ + jnp.ndarray, + jnp.ndarray | aqt_tensor.QTensor, + jnp.ndarray, + jnp.ndarray | None, + int, + ], + grad: jnp.ndarray, +) -> tuple[jnp.ndarray, jnp.ndarray, None, None, jnp.ndarray]: + """Backward function for throughput GMM VJP.""" + del preferred_element_type + lhs, rhs, group_sizes, group_offset, num_actual_groups = residual + grad_lhs = backend.gmm( + grad, + rhs, + group_sizes, + lhs[0].dtype, + tiling, + group_offset, + transpose_rhs=not transpose_rhs, + interpret=interpret, + lhs_quantize_dtype=lhs_quantize_dtype, + rhs_quantize_dtype=rhs_quantize_dtype, + ) + grad_rhs = backend.tgmm( + lhs.swapaxes(0, 1), grad, group_sizes, rhs.dtype, tiling, group_offset, num_actual_groups, interpret=interpret + ) + + # NOTE: If the rhs transposition is fused into the forward pass we need to + # return the transpose of the rhs gradient that we calculated above. + # + # TODO(tgale, enriqueps, apaske): Fuse this transposition into the tgmm. + grad_rhs = grad_rhs.swapaxes(1, 2) if transpose_rhs else grad_rhs + return grad_lhs, grad_rhs, None, None, grad + + +gmm.defvjp(_gmm_fwd, _gmm_bwd) diff --git a/MaxText/kernels/ragged_attention.py b/MaxText/kernels/ragged_attention.py index 8ddeb7214..20cd292e8 100644 --- a/MaxText/kernels/ragged_attention.py +++ b/MaxText/kernels/ragged_attention.py @@ -20,14 +20,13 @@ import jax from jax import lax +from jax.experimental import shard_map from jax.experimental import pallas as pl from jax.experimental.pallas import tpu as pltpu import jax.numpy as jnp import numpy as np import common_types -from jax.experimental import shard_map - BATCH = common_types.BATCH DEFAULT_MASK_VALUE = common_types.DEFAULT_MASK_VALUE @@ -166,7 +165,7 @@ def reference_gqa( return o, logits_max, denominator -def ragged_flash_attention_kernel( +def ragged_flash_attention_kernel( # pylint: disable=too-many-positional-arguments lengths_ref, q_ref, k_ref, @@ -277,11 +276,11 @@ def compute_ragged_block_indices(b, i, lengths_ref): ], grid=(batch_size, seq_len // block_size), ), - compiler_params=dict( - mosaic=dict( - dimension_semantics=("parallel", "arbitrary"), - ) - ), + compiler_params={ + "mosaic": { + "dimension_semantics": ("parallel", "arbitrary"), + }, + }, out_shape=[ jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), jax.ShapeDtypeStruct((batch_size, num_heads, head_dim), jnp.float32), diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 7500b023f..7a622734d 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -18,6 +18,7 @@ import operator from typing import Any, Callable, Iterable, Sequence, Tuple, Union, Optional +import flax import flax.linen as nn import jax from jax import lax @@ -29,14 +30,12 @@ import numpy as np from jax.ad_checkpoint import checkpoint_name from jax.experimental import shard_map -import max_logging import math +import max_logging +import max_utils +from aqt.jax.v2 import aqt_tensor +from kernels import megablox as mblx -try: - from jax.experimental.pallas.ops.tpu import megablox as mblx -except ImportError: - max_logging.log("JAX megablox is available for TPU only.") - pass Array = common_types.Array Config = common_types.Config @@ -49,6 +48,18 @@ RMSNorm = normalizations.RMSNorm Quant = quantizations.AqtQuantization +QTensor = aqt_tensor.QTensor + +DISPATCH = "dispatch" +COMBINE = "combine" + + +def _get_model_call_mode(config): + if config.model_cal_mode == "inference": + return "inference" + else: + return None + DISPATCH = "dispatch" COMBINE = "combine" @@ -386,15 +397,14 @@ def permute(self, inputs, gate_logits): group_size = jnp.bincount(flatten_selected_experts, length=self.num_experts) return sorted_inputs, sorted_selected_experts, weights, group_size - def unpermute(self, intermediate, sorted_selected_experts, weights): + def unpermute(self, intermediate, sorted_selected_experts, weights, batch_size: int, sequence_length: int): """Unpermute tokens to original order and combine weights.""" unsort_intermediate = jnp.take(intermediate, indices=jnp.argsort(sorted_selected_experts), axis=0) reshaped_weights = jnp.reshape(weights, (-1, self.num_experts_per_tok)) - tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism reshaped_intermediate = jnp.reshape( unsort_intermediate, - (-1, self.num_experts_per_tok, self.config.emb_dim // tensor_parallelism), + (reshaped_weights.shape[0], self.num_experts_per_tok, -1), ) with jax.named_scope("weight_sum"): matmul_precision = lax.Precision(self.config.matmul_precision) @@ -404,7 +414,7 @@ def unpermute(self, intermediate, sorted_selected_experts, weights): reshaped_weights.astype(jnp.float32), precision=matmul_precision, ) - return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype) + return output.reshape(batch_size, sequence_length, -1).astype(self.dtype) def megablox(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): tile_size = (512, 1024, 1024) @@ -419,14 +429,22 @@ def gmm(inputs, kernel, group_sizes): inputs = inputs.astype(self.dtype) kernel = kernel.astype(self.dtype) + + lhs_quantize_dtype, rhs_quantize_dtype = None, None + if self.quant is not None: + quant_dg = self.quant.quant_dg + lhs_quantize_dtype = quant_dg.fwd.dg_quantizer.lhs.numerics.get_dtype() + rhs_quantize_dtype = quant_dg.fwd.dg_quantizer.rhs.numerics.get_dtype() + output = mblx.gmm( lhs=inputs, rhs=kernel, group_sizes=group_sizes, preferred_element_type=jnp.bfloat16, tiling=tile_size, + lhs_quantize_dtype=lhs_quantize_dtype, + rhs_quantize_dtype=rhs_quantize_dtype, ) - if hs_shape[0] % pad_length: output = output[: hs_shape[0]] return output @@ -434,25 +452,33 @@ def gmm(inputs, kernel, group_sizes): # Currently, we only support data and tensor parallelism with Megablox. # We all gather the input activations over tensor parallelism to follow strategy # in https://parsa.epfl.ch/course-info/cs723/papers/Megatron.pdf. + input_partition_spec = nn.logical_to_mesh_axes(("activation_batch", None, None)) + gate_logits_pspec = nn.logical_to_mesh_axes(("activation_batch", None, None)) + w0_pspec = nn.logical_to_mesh_axes((None, None, "mlp")) + w1_pspec = nn.logical_to_mesh_axes((None, None, "mlp")) + wo_pspec = nn.logical_to_mesh_axes((None, "mlp", None)) + + if isinstance(w0_kernel, QTensor): + w0_pspec = aqt_tensor.partition_spec(w0_pspec, (1,), w0_kernel.dtype, use_bias=False) + if isinstance(w1_kernel, QTensor): + w1_pspec = aqt_tensor.partition_spec(w1_pspec, (1,), w1_kernel.dtype, use_bias=False) + if isinstance(wo_kernel, QTensor): + wo_pspec = aqt_tensor.partition_spec(wo_pspec, (1,), wo_kernel.dtype, use_bias=False) + @functools.partial( shard_map.shard_map, mesh=self.mesh, - in_specs=( - (nn.logical_to_mesh_axes(("activation_batch", None, None))), - (nn.logical_to_mesh_axes(("activation_batch", None, None))), - (nn.logical_to_mesh_axes((None, None, "mlp"))), - (nn.logical_to_mesh_axes((None, None, "mlp"))), - (nn.logical_to_mesh_axes((None, "mlp", None))), - ), + in_specs=(input_partition_spec, gate_logits_pspec, w0_pspec, w1_pspec, wo_pspec), out_specs=(nn.logical_to_mesh_axes(("activation_batch", None, "activation_embed"))), check_rep=False, ) def wrapper(x, logits, w0, w1, wo): + batch_size, sequence_length, _ = x.shape x, sorted_selected_experts, weights, group_sizes = self.permute(x, logits) layer_w0 = gmm(x, w0, group_sizes) layer_w0 = checkpoint_name(layer_w0, "mlpwi_0") layer_w1 = gmm(x, w1, group_sizes) - layer_w1 = checkpoint_name(layer_w0, "mlpwi_1") + layer_w1 = checkpoint_name(layer_w1, "mlpwi_1") layer_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) intermediate_layer = jnp.multiply(layer_act, layer_w1) intermediate_output = gmm(intermediate_layer, wo, group_sizes) @@ -460,7 +486,9 @@ def wrapper(x, logits, w0, w1, wo): tensor_parallelism = self.config.ici_tensor_parallelism * self.config.dcn_tensor_parallelism if tensor_parallelism > 1: intermediate_output = jax.lax.psum_scatter(intermediate_output, "tensor", scatter_dimension=1, tiled=True) - output = self.unpermute(intermediate_output, sorted_selected_experts, weights) + output = self.unpermute( + intermediate_output, sorted_selected_experts, weights, batch_size=batch_size, sequence_length=sequence_length + ) return output, None return wrapper(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) @@ -678,6 +706,22 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): ).astype(self.dtype) return output, None + def retrieve_quantized_weight( + self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel + ) -> tuple[QTensor, QTensor, QTensor]: + # This is called only during tracing. This is to invoke creation of quantized tensor inside AqtEinsum. + # After jit, this will become no-op and will not affect performance. + _ = self.dense_matmul(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) + + w0_kernel = self.variables["aqt"]["AqtEinsum_0"]["AqtDotGeneral_0"]["qrhs"]["frozen"] + w1_kernel = self.variables["aqt"]["AqtEinsum_1"]["AqtDotGeneral_0"]["qrhs"]["frozen"] + wo_kernel = self.variables["aqt"]["AqtEinsum_2"]["AqtDotGeneral_0"]["qrhs"]["frozen"] + + w0_kernel = max_utils.unbox_logicallypartioned(w0_kernel) + w1_kernel = max_utils.unbox_logicallypartioned(w1_kernel) + wo_kernel = max_utils.unbox_logicallypartioned(wo_kernel) + return w0_kernel, w1_kernel, wo_kernel + @nn.compact def __call__(self, inputs): cfg = self.config @@ -692,11 +736,14 @@ def __call__(self, inputs): name="gate", matmul_precision=self.config.matmul_precision, )(inputs) - + cfg = self.config w0_kernel, w1_kernel, wo_kernel = self.generate_kernels(cfg.num_experts, cfg.emb_dim, cfg.mlp_dim) - if cfg.megablox: max_logging.log("Running MoE megablox implementation.") + if quantizations.in_serve_mode(self.quant): + w0_kernel, w1_kernel, wo_kernel = self.retrieve_quantized_weight( + inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel + ) return self.megablox(inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel) else: max_logging.log("Running MoE matmul implementation.") diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 9e48e27e2..c6af8b591 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -412,6 +412,7 @@ def user_init(raw_keys): raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys) raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys) + raw_keys = create_parallelisms_list(raw_keys) raw_keys = set_and_validate_pipeline_config(raw_keys) @@ -520,6 +521,7 @@ def validate_multiple_slices(raw_keys): def set_and_validate_pipeline_config(raw_keys): if using_pipeline_parallelism(raw_keys): + raw_keys["using_pipeline_parallelism"] = True def modify_activation_embed_and_logits_batch(logical_axis_rules): for idx, logical_rule in enumerate(logical_axis_rules):