
Quantize megablox #1062

Open · wants to merge 43 commits into main
Conversation

@lenscloth (Collaborator) commented Nov 25, 2024

Description

Support quantization on megablox

  • The current implementation can accelerate training but not serving; serving acceleration will be done in another PR.
  • Tested correctness after applying quantization on megablox.

Tests

End to end tests on mixtral-8x22b model with max engine.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

@RissyRan (Collaborator)

We shouldn't add megablox as a copy from the JAX repo into MaxText. We are good to work in a dev branch, but in the long term we need to merge the related quantization changes into the JAX repo so MaxText can import it directly.

By accepting pre-quantized weights, quantized gmm does not need to
quantize the weights on every iteration.
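
A minimal sketch of the idea behind this commit, assuming a simple symmetric per-channel int8 scheme (the helper names below are hypothetical, not the actual MaxText/megablox API): the weights are quantized once outside the step, and the matmul consumes the int8 values plus a scale on every iteration.

```python
import jax.numpy as jnp

def quantize_int8(w):
  # Per-output-channel symmetric scale so that max |w| maps to 127.
  scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
  q = jnp.clip(jnp.round(w / scale), -127, 127).astype(jnp.int8)
  return q, scale

def matmul_with_prequantized_rhs(lhs, q_rhs, rhs_scale):
  # Dequantize on the fly; a real kernel would accumulate in integers
  # and apply the scale to the accumulator instead.
  return lhs @ (q_rhs.astype(lhs.dtype) * rhs_scale)

w = jnp.ones((16, 8), jnp.float32)
q_w, w_scale = quantize_int8(w)                     # done once, ahead of time
x = jnp.ones((4, 16), jnp.float32)
y = matmul_with_prequantized_rhs(x, q_w, w_scale)   # reused every iteration
```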
@RissyRan (Collaborator) commented Dec 4, 2024

We shouldn't add megablox as a copy from the JAX repo into MaxText. We are good to work in a dev branch, but in the long term we need to merge the related quantization changes into the JAX repo so MaxText can import it directly.

Discussed with the JAX team, and they prefer that we fork the megablox kernel to avoid a dependency issue with using a standalone install of JAX. So we will have a "copy" of the gmm implementation in MaxText. Thanks Wonpyo!

Megablox now supports:
- int8 quantization
- int8w quantization
- int4w quantization
@lenscloth (Collaborator, Author)

@RissyRan This branch now supports setting different precision for lhs and rhs. It now supports the following (see the sketch below):

  • Int8
  • Int8w
  • Int4w
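
A hedged sketch of the distinction between these modes, assuming the usual convention that the "w" suffix means weight-only (hypothetical helpers, not the actual MaxText config or the megablox kernel): int8 quantizes both operands, while int8w/int4w quantize only the rhs (weights) and leave the lhs (activations) in the original dtype.

```python
import jax.numpy as jnp

def quantize(x, bits, axis):
  max_q = 2 ** (bits - 1) - 1                       # 127 for int8, 7 for int4
  scale = jnp.max(jnp.abs(x), axis=axis, keepdims=True) / max_q
  q = jnp.clip(jnp.round(x / scale), -max_q, max_q).astype(jnp.int8)
  return q, scale

x = jnp.ones((4, 16), jnp.float32)                  # lhs (activations)
w = jnp.ones((16, 8), jnp.float32)                  # rhs (weights)

# int8: quantize both lhs (per row) and rhs (per column).
qx, sx = quantize(x, bits=8, axis=1)
qw, sw = quantize(w, bits=8, axis=0)
y_int8 = (qx.astype(jnp.float32) @ qw.astype(jnp.float32)) * sx * sw

# int4w: weight-only; the lhs stays in its original dtype.
qw4, sw4 = quantize(w, bits=4, axis=0)
y_int4w = x @ (qw4.astype(jnp.float32) * sw4)
```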

@RissyRan (Collaborator) left a comment

Thank you Wonpyo!

Created a diff gmm file for future reference.

MaxText/kernels/megablox/common.py (outdated, resolved)
MaxText/kernels/megablox/__init__.py (resolved)
MaxText/kernels/megablox/gmm.py (outdated, resolved)
MaxText/layers/linears.py (outdated, resolved)
@@ -389,7 +395,14 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
reshaped_weights.astype(jnp.float32),
precision=matmul_precision,
)
return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype)
updated_batch = int(self.config.per_device_batch_size * jax.device_count() // self.config.ici_fsdp_parallelism)
Collaborator:

An assertion may be needed if the value is indivisible cc @mailvijayasingh @ZhiyuLi-goog
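
A minimal sketch of the suggested assertion, reusing the names from the diff above (the exact placement in linears.py and the error message are assumptions):

```python
# Hypothetical guard before computing updated_batch: fail fast instead of
# silently truncating the batch dimension when the division is inexact.
global_batch = self.config.per_device_batch_size * jax.device_count()
assert global_batch % self.config.ici_fsdp_parallelism == 0, (
    f"global batch {global_batch} is not divisible by "
    f"ici_fsdp_parallelism={self.config.ici_fsdp_parallelism}"
)
updated_batch = int(global_batch // self.config.ici_fsdp_parallelism)
```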

MaxText/layers/linears.py (outdated, resolved)
MaxText/layers/linears.py (outdated, resolved)
MaxText/kernels/megablox/gmm.py (outdated, resolved)
@RissyRan (Collaborator) left a comment

Others LGTM! @ZhiyuLi-goog could you also help take a look once you are back?

@lenscloth (Collaborator, Author)

Resolved some of the comments. @RissyRan, would you review one more time?

@ZhiyuLi-goog (Collaborator)

Thank you @lenscloth for the awesome PR. LGTM!

MaxText/layers/linears.py (outdated, resolved)
MaxText/layers/linears.py (outdated, resolved)
@RissyRan (Collaborator) left a comment

Thank you! For the code style check, you could work around it using this.

MaxText/kernels/megablox/gmm.py (outdated, resolved)
@RissyRan (Collaborator) left a comment

lgtm!

@gobbleturk (Collaborator) left a comment

Looks great! My only request is to get the sizes based on the inputs instead of the config (ideally we can remove the original reference to tensor_parallelism as well).

MaxText/kernels/__init__.py (outdated, resolved)
@@ -404,7 +403,14 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
reshaped_weights.astype(jnp.float32),
precision=matmul_precision,
)
return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype)
updated_batch = int(self.config.per_device_batch_size * jax.device_count() // self.config.ici_fsdp_parallelism)
Collaborator:

Does this cover all cases? What if we use data_parallelism instead of FSDP (or a mix of them), or even fsdp_transpose?

Collaborator:

Can we get the relevant sizes from the input shape instead of based on the config?

Collaborator (Author):

@ZhiyuLi-goog I made a change there. Would you take a look at this?

Collaborator:

Thank you for the comment, @gobbleturk.

Does this cover all cases? What if we use data_parallelism instead of FSDP (or a mix of them), or even fsdp_transpose?

Updated to cover data and fsdp (or a mix of them). Not yet fsdp_transpose, since it is simply not used in MoE models. (See the sketch at the end of this comment.)

Can we get the relevant sizes from the input shape instead of based on the config?

It does not seem easy to derive the relevant sizes in a way that covers both the prefill and decoding cases, since the prefill length can vary. This trick has worked pretty well in our inference experiments.

@RissyRan @mailvijayasingh feel free to chime in if you have new ideas.
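
A rough sketch of what covering data and fsdp (or a mix of them) could look like, dividing the global batch by the product of the batch-sharding axes; the config field names and the exact formula used in the PR are assumptions:

```python
# Hypothetical: the batch dimension is sharded over both the data and fsdp
# ICI axes, so the per-shard batch divides out their product.
batch_shards = self.config.ici_data_parallelism * self.config.ici_fsdp_parallelism
global_batch = self.config.per_device_batch_size * jax.device_count()
assert global_batch % batch_shards == 0, "global batch must shard evenly"
updated_batch = int(global_batch // batch_shards)
```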

Collaborator:

Can we get the relevant sizes from the input shape instead of based on the config?

Yeah, this is the difference between training and inference. The batch size will be different in the prefill and generate stages (even if we specify per_device_batch_size as a fixed value). In megablox, we need extra permute and unpermute operations, and those shapes need to be calculated manually to map back to the right shape.

@gobbleturk what's your suggestion on this? Basically we need an indicator of whether the current inference stage is prefill or generate, and then reshape to the right shape based on the different batch size.
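
For context, a simplified illustration of the permute/unpermute bookkeeping being discussed (not the actual MaxText MoE code, which also tracks routing weights and group sizes for the grouped matmul):

```python
import jax.numpy as jnp

# Tokens flattened to (batch * seq, emb) and their assigned experts.
tokens = jnp.arange(12, dtype=jnp.float32).reshape(6, 2)
expert_ids = jnp.array([1, 0, 1, 2, 0, 2])

order = jnp.argsort(expert_ids)      # permute: group tokens by expert
grouped = tokens[order]              # this is what the grouped matmul (gmm) sees
# ... gmm over the grouped tokens would run here ...
output = jnp.zeros_like(grouped).at[order].set(grouped)   # unpermute
# The final reshape back to (batch, seq, emb) is where the batch size for
# prefill vs. generate has to be known, which is what this thread is about.
```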

Collaborator (Author):

I have updated this to infer the output shape from the input tensor.
But I am not fully familiar with this code, and I am afraid I might have made a mistake in the implementation here.

@ZhiyuLi-goog Can you take a look at this?
Also @gobbleturk Can you review this branch? It would be great to merge this branch before the holidays.

cc. @RissyRan

Collaborator:

I was thinking something like sharded_batch_shape = inputs.shape[0]. I think prefill and generate are two separate jits (unsure), so this sharded_batch_shape just takes exactly the value it needs.
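
A rough sketch of that suggestion (assumed variable names, not the PR's final code):

```python
# Hypothetical: prefill and generate are traced separately, so reading the
# sharded batch and sequence length off the input gives each trace exactly
# the value it needs without consulting the config.
batch, seq_len = inputs.shape[0], inputs.shape[1]
output = output.reshape(batch, seq_len, -1).astype(self.dtype)
```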

@gobbleturk (Collaborator) left a comment

LGTM, Thanks Wonpyo!

lenscloth and others added 22 commits December 14, 2024 01:09
…ing them from how process indexes changed across restarts with some false assumptions.

PiperOrigin-RevId: 700737164
This mode is the same as nightly except that after nightly is installed, any file in `maxtext/*.whl` is forcefully reinstalled.