GEMM API for efficient LLM inference with W8A16 #1788
Labels: enhancement (a feature or an optimization request), help wanted, platform:cpu-aarch64 (codeowner: @oneapi-src/onednn-cpu-aarch64)
I want to run inference on a quantized LLaMA model (W8A16) on Armv9 (with SVE) using oneDNN. The LLaMA weights are quantized per group.
As I understand it, I need to prepack the weights once to avoid the cost of repacking on every call. However, packing rearranges the weights and therefore breaks the correspondence with the per-group quantization scales and shifts. Dequantization therefore has to be fused into the compute kernel: if it were folded into the packing step instead, I would end up storing a second copy of the weights in FP16, which defeats the purpose of quantization.
I haven't figured out how to combine prepacking and per-group dequantization.
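To make the question concrete, here is roughly what I imagine the fused-dequantization side looking like, using the matmul primitive's grouped weight scales / zero-points attributes together with the fpmath mode (this is only a sketch against a recent oneDNN API; the group size, shapes and data types are placeholders, and I'm not sure whether this path is supported on aarch64):

```cpp
#include <oneapi/dnnl/dnnl.hpp>
using namespace dnnl;

// Placeholder problem sizes: (K x N) int8 weights, quantized per group of
// G consecutive values along K; activations and output are f16.
constexpr memory::dim K = 4096, N = 4096, G = 128;

primitive_attr make_w8a16_attr() {
    primitive_attr attr;
    // One f16 scale per (G x 1) block of the (K x N) weight matrix:
    // mask (1<<0)+(1<<1) says scales vary along K and N, and the groups
    // argument {G, 1} collapses K into K/G groups.
    attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::f16);
    // Per-group zero points (shifts) with the same grouping, stored as s8.
    attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::s8);
    // Request weight decompression: the int8 weights are up-converted to
    // f16 inside the kernel, so no dequantized copy is ever materialized.
    attr.set_fpmath_mode(fpmath_mode::f16, true);
    return attr;
}
```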
Which API should I use for prepacking? SVE vectors can be 256 or 512 bits wide; how does oneDNN choose the packed layout for the target vector length?
After prepacking the weights and saving them, how do I fuse dequantization into the kernel at compute time?
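For the prepacking part, this is the flow I have in mind, continuing from the attribute sketch above: `format_tag::any` lets the implementation pick its preferred blocked weight layout, and a one-time reorder does the packing, while the scales and zero points stay in plain buffers that are passed at execute time. Again just a sketch with placeholder shapes, not something I have verified on aarch64:

```cpp
#include <unordered_map>
#include <vector>
#include <oneapi/dnnl/dnnl.hpp>
using namespace dnnl;

int main() {
    engine eng(engine::kind::cpu, 0);
    stream strm(eng);

    // Placeholder shapes: (M x K) f16 activations, (K x N) s8 weights,
    // per-group quantization with group size G along K.
    const memory::dim M = 32, K = 4096, N = 4096, G = 128;

    // Same grouped-scales / zero-points / fpmath attributes as sketched above.
    primitive_attr attr;
    attr.set_scales(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::f16);
    attr.set_zero_points(DNNL_ARG_WEIGHTS, (1 << 0) + (1 << 1), {G, 1},
            memory::data_type::s8);
    attr.set_fpmath_mode(fpmath_mode::f16, true);

    // format_tag::any lets the implementation choose its preferred blocked
    // weight layout for the target ISA (e.g. depending on the SVE vector
    // length), instead of me hard-coding a packing scheme.
    auto src_md = memory::desc({M, K}, memory::data_type::f16, memory::format_tag::ab);
    auto wei_md = memory::desc({K, N}, memory::data_type::s8, memory::format_tag::any);
    auto dst_md = memory::desc({M, N}, memory::data_type::f16, memory::format_tag::ab);
    auto pd = matmul::primitive_desc(eng, src_md, wei_md, dst_md, attr);

    // One-time prepacking: reorder the plain row-major int8 weights into the
    // layout the primitive asked for. Only the int8 values are shuffled; the
    // scales and zero points stay in their own plain buffers and are applied
    // at execute time, so packing does not disturb the per-group metadata.
    std::vector<int8_t> wei_plain(K * N); // quantized LLaMA weights go here
    auto wei_plain_md = memory::desc({K, N}, memory::data_type::s8, memory::format_tag::ab);
    auto wei_plain_mem = memory(wei_plain_md, eng, wei_plain.data());
    auto wei_packed_mem = memory(pd.weights_desc(), eng);
    reorder(wei_plain_mem, wei_packed_mem)
            .execute(strm, wei_plain_mem, wei_packed_mem);
    strm.wait();
    // wei_packed_mem can now be kept and reused for every matmul call.

    // Per-call execution: scales and zero points are passed as runtime
    // arguments in plain (K/G x N) layout.
    auto src_mem = memory(pd.src_desc(), eng);
    auto dst_mem = memory(pd.dst_desc(), eng);
    auto scales_md = memory::desc({K / G, N}, memory::data_type::f16, memory::format_tag::ab);
    auto zps_md = memory::desc({K / G, N}, memory::data_type::s8, memory::format_tag::ab);
    auto scales_mem = memory(scales_md, eng);
    auto zps_mem = memory(zps_md, eng);

    std::unordered_map<int, memory> args{
            {DNNL_ARG_SRC, src_mem},
            {DNNL_ARG_WEIGHTS, wei_packed_mem},
            {DNNL_ARG_DST, dst_mem},
            {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scales_mem},
            {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_WEIGHTS, zps_mem}};
    matmul(pd).execute(strm, args);
    strm.wait();
    return 0;
}
```

If packing only permutes the int8 payload and the scales/shifts are runtime arguments, then prepacking and per-group dequantization would not conflict at all; that is what I am hoping to confirm, along with whether this works with the aarch64/SVE backend.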