Quantize megablox #1062

Open · wants to merge 43 commits into base: main
Changes from 14 commits

Commits (43)
61d944d
base
ZhiyuLi-goog Nov 8, 2024
3d25eb3
fix import error in pallas megablox kernel
ZhiyuLi-goog Nov 9, 2024
4a3d9ba
Quantize GMM kernel
ZhiyuLi-goog Nov 9, 2024
2ebe80e
[MoE] fix typo
ZhiyuLi-goog Nov 12, 2024
0e59567
work around quantize_params
ZhiyuLi-goog Nov 12, 2024
bd208ed
inferencing shape hack
ZhiyuLi-goog Nov 12, 2024
57e2383
Enable (1) checkpoint conversion for MoEBlock when megablox=True (2) …
lenscloth Nov 23, 2024
9117c40
Fix lint error
lenscloth Nov 23, 2024
35e243a
[rollback] quantize param when ckpt is not quantized.
lenscloth Nov 25, 2024
94bdc94
Merge branch 'quantize_megablox'
lenscloth Nov 25, 2024
fa5d5f3
Let `gmm` accept pre-quantized rhs (weight).
lenscloth Nov 27, 2024
11f1f43
Support different quantization precision for lhs / rhs of megablox.
lenscloth Dec 4, 2024
1844936
Fix license
lenscloth Dec 5, 2024
0d9db17
Refactoring MoEBlock
lenscloth Dec 5, 2024
cd8b2d1
Rename in_out_block_spec for better readability
lenscloth Dec 5, 2024
cfd3b15
Fix lint error & refactor retrieving quantized weight logic
lenscloth Dec 6, 2024
5d51030
fix lint error
lenscloth Dec 6, 2024
3971306
Merge branch 'quantize_megablox'
lenscloth Dec 9, 2024
325c4c3
Read lhs and rhs quantization dtype directly from DotGeneral
lenscloth Dec 9, 2024
95bb7c1
Fix MoE related tests
RissyRan Nov 26, 2024
7addc0f
Add checkpoint topology discovery for the Replicator Service
xuefgu Nov 26, 2024
f6c7c5a
Fix local restore by re-mapping device ids directly instead of inferr…
cpgaffney1 Nov 27, 2024
dca3ee5
Compact the number of variables for the prefill result cache to reduc…
Nov 27, 2024
c993cf2
add more MoE tests
RissyRan Nov 27, 2024
86d0b54
update setup_gcsfuse for better perf
aireenmei Nov 27, 2024
346218a
Assert multiple slices available when requesting DCN parallelisms
gobbleturk Dec 2, 2024
f6c38f1
Add llama 3.1 70b config
raymondzouu Nov 5, 2024
7a4a44e
Update replicator.yaml to include framework and num_slices information
xuefgu Dec 3, 2024
e3823c0
Support JAX_VERSION for nightly mode on GPU
bvandermoon Dec 4, 2024
e17fcc6
Added a custom_wheel mode for building the dependency image.
lukebaumann Dec 3, 2024
4cdc15d
clean up pipeline config setting in its own method
gobbleturk Dec 4, 2024
7390f56
Fixes for dropping
mailvijayasingh Dec 4, 2024
5831547
Fix setup_training_state
khatwanimohit Dec 4, 2024
90430e1
point to new jax github location in documentation
jakeharmon8 Dec 5, 2024
2d1c51a
Change awk regex command to capture the coordinate address properly
michelle-yooh Dec 3, 2024
53a6abe
Fixes non-hashable error in ragged attn.
patemotter Dec 6, 2024
4546d68
Add new remat policy for save_dot_except_mlp with context
bvandermoon Dec 4, 2024
945ee3d
Fix moe logging to differentiate dense and megablox runs
lenscloth Dec 9, 2024
2b62fe8
update sharding
ZhiyuLi-goog Dec 10, 2024
742019f
Reshape based on the original input shape
lenscloth Dec 12, 2024
91b84f6
Resolve math import error
lenscloth Dec 14, 2024
cd8e6bf
Merge branch 'main' into quantize_megablox
lenscloth Dec 14, 2024
b0f21f9
fix lint error
lenscloth Dec 14, 2024
Empty file added MaxText/kernels/__init__.py
15 changes: 15 additions & 0 deletions MaxText/kernels/megablox/__init__.py
@@ -0,0 +1,15 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kernels.megablox.ops import gmm
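The package's public entry point is `gmm`, re-exported above. Below is a minimal usage sketch; this diff does not show the kernel's signature, so the call convention (`lhs: (m, k)`, per-group `rhs: (num_groups, k, n)`, `group_sizes` summing to `m`) is an assumption based on the standard megablox grouped matmul.

# Hypothetical usage sketch of the re-exported gmm kernel; the signature and
# shape convention are assumptions based on the megablox grouped matmul, not
# taken from this diff.
import jax
import jax.numpy as jnp

from kernels.megablox import gmm

m, k, n, num_groups = 512, 128, 256, 4
lhs = jax.random.normal(jax.random.PRNGKey(0), (m, k), dtype=jnp.bfloat16)
rhs = jax.random.normal(jax.random.PRNGKey(1), (num_groups, k, n), dtype=jnp.bfloat16)
# Each group i multiplies a contiguous block of lhs rows against rhs[i];
# the group sizes must sum to m.
group_sizes = jnp.full((num_groups,), m // num_groups, dtype=jnp.int32)

out = gmm(lhs, rhs, group_sizes)  # expected output shape: (m, n)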
59 changes: 59 additions & 0 deletions MaxText/kernels/megablox/common.py
@@ -0,0 +1,59 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Common utilities for GMM kernels."""

import re

import jax
import jax.numpy as jnp


def is_tpu() -> bool:
  return "TPU" in jax.devices()[0].device_kind


def tpu_kind() -> str:
  """Query identification string for the currently attached TPU."""
  return jax.devices()[0].device_kind


_TPU_KIND_PATTERN = re.compile(r"TPU v(\d+)")


def tpu_generation() -> int:
  """Generation number of the currently attached TPU."""
  if version := _TPU_KIND_PATTERN.match(tpu_kind()):
    return int(version[1])
  raise NotImplementedError("only TPU devices are supported")


def supports_bfloat16_matmul() -> bool:
  """Does the currently attached device support bfloat16 matmuls?"""
  return not is_tpu() or tpu_generation() >= 4


def assert_is_supported_dtype(dtype: jnp.dtype) -> None:
  if dtype != jnp.bfloat16 and dtype != jnp.float32:
    raise ValueError(f"Expected bfloat16 or float32 array but got {dtype}.")


def select_input_dtype(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.dtype:
  """The dtype to which both inputs should be converted before the dot product."""
  # bf16 x bf16 matmul is only supported from the TPU v4 generation onward. With
  # mixed input precision, the bf16 argument must be converted to fp32 beforehand.
  if supports_bfloat16_matmul() and lhs.dtype == jnp.bfloat16 and rhs.dtype == jnp.bfloat16:
    return jnp.bfloat16
  else:
    return jnp.float32
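
A quick sketch of how these helpers fit together (shapes and values are hypothetical, for illustration only): two bfloat16 operands stay in bfloat16 on TPU v4 or later, while any mixed-precision pair falls back to float32.

# Illustrative sketch of the dtype-selection helpers above; shapes and values
# are hypothetical, not from this PR.
import jax.numpy as jnp

from kernels.megablox.common import assert_is_supported_dtype, select_input_dtype

lhs = jnp.ones((8, 16), dtype=jnp.bfloat16)
rhs = jnp.ones((16, 4), dtype=jnp.float32)

assert_is_supported_dtype(lhs.dtype)  # passes: bfloat16 is supported
assert_is_supported_dtype(rhs.dtype)  # passes: float32 is supported

# Mixed bf16/fp32 inputs are promoted to fp32 before the dot product;
# on TPU v4+ two bf16 operands would stay in bf16.
common_dtype = select_input_dtype(lhs, rhs)  # jnp.float32 here
out = jnp.dot(lhs.astype(common_dtype), rhs.astype(common_dtype))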