Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apply quantization on megablox kernel; support both training and serving #1100

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions MaxText/kernels/megablox/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megablox kernel"""

from kernels.megablox.ops import gmm
59 changes: 59 additions & 0 deletions MaxText/kernels/megablox/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for GMM kernels."""

import re

import jax
import jax.numpy as jnp


def is_tpu() -> bool:
return "TPU" in jax.devices()[0].device_kind


def tpu_kind() -> str:
"""Query identification string for the currently attached TPU."""
return jax.devices()[0].device_kind


_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")


def tpu_generation() -> int:
"""Generation number of the currently attached TPU."""
if version := _TPU_KIND_PATTERN.match(tpu_kind()):
return int(version[1])
raise NotImplementedError("only TPU devices are supported")


def supports_bfloat16_matmul() -> bool:
"""Does the currently attached CPU support bfloat16 inputs?"""
return not is_tpu() or tpu_generation() >= 4


def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
if dtype not in (jnp.bfloat16, jnp.float32):
raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")


def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
"""A type to which both input should be adapted to before dot product."""
# bf16xbf16 matmul is only supported since TPUv4 generation. In case of mixed
# input precision, we need to convert bf16 argument to fp32 beforehand.
if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16:
return jnp.bfloat16
else:
return jnp.float32
Loading
Loading