Quantize megablox #1062
base: main
Conversation
…load quantized weights from AqtEinsum to feed to gmm kernel later.
We shouldn't add
By accepting pre-quantized weights, the quantized gmm does not need to quantize the weights on every iteration.
We discussed with the JAX team, and they prefer that we fork the megablox kernel to avoid a dependency issue when using a standalone install of JAX. So we will have a "copy" of the gmm implementation in MaxText. Thanks Wonpyo!
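For context, here is a minimal JAX sketch of the idea (the function names are illustrative, not the actual MaxText/megablox API): the expert weights are quantized once up front, and the grouped matmul then consumes the pre-quantized values and scales on every call instead of re-quantizing them each iteration.

```python
import jax
import jax.numpy as jnp

def quantize_weights(w):
  """Symmetric per-column int8 quantization; returns int8 values and float scales."""
  scale = jnp.max(jnp.abs(w), axis=0, keepdims=True) / 127.0
  q = jnp.clip(jnp.round(w / scale), -128, 127).astype(jnp.int8)
  return q, scale

def quantized_matmul(x, q_w, scale):
  """Stand-in for the quantized gmm call: int8 weights dequantized on the fly."""
  return jnp.dot(x, q_w.astype(jnp.float32) * scale)

w = jax.random.normal(jax.random.PRNGKey(0), (512, 1024))
q_w, scale = quantize_weights(w)      # done once, outside the step loop
x = jax.random.normal(jax.random.PRNGKey(1), (8, 512))
y = quantized_matmul(x, q_w, scale)   # reused across iterations without re-quantizing w
```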
Megablox now supports:
- int8 quantization
- int8w quantization
- int4w quantization
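For readers, the modes above roughly differ in which operand gets quantized; the mapping below is illustrative only and does not use the actual MaxText config schema.

```python
# Assumed meaning of the listed modes: the "w" suffix means weight-only quantization.
QUANT_MODES = {
    "int8":  {"activation_bits": 8,    "weight_bits": 8},  # quantize both lhs and rhs
    "int8w": {"activation_bits": None, "weight_bits": 8},  # weight-only int8
    "int4w": {"activation_bits": None, "weight_bits": 4},  # weight-only int4
}
```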
@RissyRan This branch now supports setting different precisions for lhs and rhs.
Thank you Wonpyo!
Created a diff gmm file for future reference.
MaxText/layers/linears.py
Outdated
@@ -389,7 +395,14 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
reshaped_weights.astype(jnp.float32),
precision=matmul_precision,
)
return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype)
updated_batch = int(self.config.per_device_batch_size * jax.device_count() // self.config.ici_fsdp_parallelism)
An assertion may be needed in case the value is indivisible. cc @mailvijayasingh @ZhiyuLi-goog
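Something along these lines, meant to sit where the updated_batch line appears in the diff above (only a sketch, exact wording up to the authors):

```python
global_batch = int(self.config.per_device_batch_size * jax.device_count())
assert global_batch % self.config.ici_fsdp_parallelism == 0, (
    f"Global batch {global_batch} is not divisible by "
    f"ici_fsdp_parallelism={self.config.ici_fsdp_parallelism}"
)
updated_batch = global_batch // self.config.ici_fsdp_parallelism
```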
Others LGTM! @ZhiyuLi-goog could you also help take a look once you are back?
Resolved some of the comments. @RissyRan would you review one more time?
Thank you @lenscloth for the awesome PR. LGTM!
Thank you! For the code style check, you could work around it using this.
lgtm!
Looks great! My only request is to get the sizes from the inputs instead of the config (ideally we can remove the original reference to tensor_parallelism as well).
MaxText/layers/linears.py
Outdated
@@ -404,7 +403,14 @@ def unpermute(self, intermediate, sorted_selected_experts, weights):
reshaped_weights.astype(jnp.float32),
precision=matmul_precision,
)
return output.reshape(-1, self.config.max_target_length, self.config.emb_dim // tensor_parallelism).astype(self.dtype)
updated_batch = int(self.config.per_device_batch_size * jax.device_count() // self.config.ici_fsdp_parallelism)
Does this cover all cases - what if we use data_parallelism instead of FSDP (or a mix of them) or even fsdp_transpose?
Can we get the relevant sizes from the input shape instead of based on the config?
@ZhiyuLi-goog made a change there. Would you take a look at this?
Thank you for the comment @gobbleturk
Does this cover all cases - what if we use data_parallelism instead of FSDP (or a mix of them) or even fsdp_transpose?
Updated to cover data and fsdp (or a mix of them). Not yet fsdp_transpose since it is simply not used in MoE models.
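Roughly along these lines (config field names assumed from the MaxText base config; this is a sketch, not the exact change):

```python
# Divide the global batch by every mesh axis that shards the batch dimension:
# data parallelism and FSDP, but deliberately not fsdp_transpose.
batch_sharding = self.config.ici_data_parallelism * self.config.ici_fsdp_parallelism
updated_batch = int(self.config.per_device_batch_size * jax.device_count() // batch_sharding)
```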
Can we get the relevant sizes from the input shape instead of based on the config?
It does not seem easy to derive the relevant sizes in a way that covers both the prefill and decoding situations, since the prefill length can vary. And this trick has worked pretty well in our inference experiments.
@RissyRan @mailvijayasingh feel free to chime in if you have new ideas.
Can we get the relevant sizes from the input shape instead of based on the config?
Yeah, this is the difference between training and inference. The batch size will be different in the prefill and generate stages (even if we specify per_device_batch_size as a fixed value). In megablox, we need extra permute and unpermute operations, and those need to be calculated manually to map back to the right shape.
@gobbleturk what's your suggestion on this? Basically we need an indicator of whether we are currently in the prefill or generate stage of inference, and then reshape to the right shape based on the corresponding batch size.
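As a toy illustration of why the shape has to be recovered manually (this is not the MaxText implementation, just a simplified sketch): permute flattens and sorts tokens by expert before the grouped matmul, so unpermute has to scatter the result back and reshape to an explicitly known (batch, seq, emb) shape.

```python
import jax.numpy as jnp

batch, seq, emb = 2, 4, 8
inputs = jnp.arange(batch * seq * emb, dtype=jnp.float32).reshape(batch, seq, emb)
expert_ids = jnp.array([0, 1, 0, 1, 1, 0, 1, 0])    # one expert per flattened token

tokens = inputs.reshape(-1, emb)                    # (batch * seq, emb)
sort_order = jnp.argsort(expert_ids)                # permute: group tokens by expert
permuted = tokens[sort_order]

gmm_out = permuted * 2.0                            # stand-in for the grouped matmul
unpermuted = jnp.zeros_like(gmm_out).at[sort_order].set(gmm_out)  # unpermute
output = unpermuted.reshape(batch, seq, emb)        # batch/seq must be known here
```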
I have updated this to infer the output shape from the input tensor.
But I am not fully familiar with this code, and I am afraid I might have implemented it incorrectly here.
@ZhiyuLi-goog can you take a look at this?
Also @gobbleturk, can you review this branch? It would be great to merge it before the holidays.
cc @RissyRan
I was thinking something like sharded_batch_shape = inputs.shape[0]. I think prefill and generate are two separate jits (unsure), so this sharded_batch_shape just takes exactly the value it needs.
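Something like the following (variable names assumed; just a sketch of the suggestion), so each jitted program traces with the shape it actually receives:

```python
sharded_batch, seq_len = inputs.shape[0], inputs.shape[1]
output = output.reshape(sharded_batch, seq_len, -1).astype(self.dtype)
```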
LGTM, Thanks Wonpyo!
…ing them from how process indexes changed across restarts with some false assumptions. PiperOrigin-RevId: 700737164
…e Python layer overhead
This mode is the same as nightly except that after nightly is installed, any file in `maxtext/*.whl` is forcefully reinstalled.
PiperOrigin-RevId: 703063210
ae9a68f to 742019f
Description
Support quantization on megablox
Tests
End-to-end tests on the mixtral-8x22b model with MaxEngine.
Checklist
Before submitting this PR, please make sure (put X in square brackets):