-
Notifications
You must be signed in to change notification settings - Fork 28
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ingest FP8 attn scales and use them in ROCm FlashAttention #338
Conversation
4e42946
to
9ba2fab
Compare
@@ -428,7 +434,9 @@ def load_weights(self, weights: Iterable[Tuple[str, | |||
param = params_dict[scale_name] | |||
weight_loader = getattr(param, "weight_loader", | |||
default_weight_loader) | |||
loaded_weight = loaded_weight[0] | |||
if loaded_weight.shape: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same for mllama.py?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I targeted only the models which unconditionally do this logic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without explicitly disabling VLLM_USE_ROCM_FP8_ATTN now quark quantized models (amd/Meta-Llama-3.1-70B-Instruct-FP8-KV) fail with a triton exception:
python: /root/.triton/llvm/llvm-c08c6a71-ubuntu-x64/include/llvm/Support/Casting.h:566: decltype(auto) llvm::cast(const From&) [with To = mlir::detail::TypedValue<mlir::RankedTensorType>; From = mlir::OpResult]: Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.
…ntion for dynamic quantization
EDIT: @ilia-cher identified the issue and provided a simple fix that works on older Triton. Still recommended to upgrade to latest. |
9639307
to
1ed1389
Compare
Thanks to the work of @ilia-cher in #301, Triton FA supports per-tensor quantized FP8 almost-everything (quantized first and second GEMMs and attention output <- needs per-tensor quantized Q, K, V, softmax(QK^T) and corresponding scales).
This PR enables the aforementioned quantization routines in Triton FA and ROCm PA if a quantized (text-only) Llama model contains attention output scales and an appropriate environment variable is set (
VLLM_USE_ROCM_FP8_FLASH_ATTN={True/1}
, off by default). Extending this to other model architectures is straightforward but not done for now. Accuracy might dip for Triton FA if not all scales are present in the model.