Skip to content

Commit

Permalink
[Bugfix] Enable some fp8 and quantized fullgraph tests (vllm-project#…
Browse files Browse the repository at this point in the history
…10171)

Signed-off-by: Bill Nell <[email protected]>
Signed-off-by: Maxime Fournioux <[email protected]>
  • Loading branch information
bnellnm authored and mfournioux committed Nov 20, 2024
1 parent f6f2a4b commit d183cd9
Showing 1 changed file with 13 additions and 16 deletions.
29 changes: 13 additions & 16 deletions tests/compile/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,29 +9,26 @@

TEST_MODELS = [
("facebook/opt-125m", {}),
# TODO: add fake implementation for compressed-tensors
# ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
# "dtype": torch.float16,
# "quantization": "compressed-tensors"
# }),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
"dtype": torch.float16,
"quantization": "compressed-tensors"
}),
("neuralmagic/Meta-Llama-3-8B-Instruct-FP8", {
"dtype": torch.float16,
"quantization": "fp8"
}),
# TODO: add fake implementation for compressed-tensors
# ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
# "quantization": "compressed-tensors"
# }),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]

# TODO: enable in pytorch 2.5
if False and is_quant_method_supported("aqlm"): # noqa: SIM223
if is_quant_method_supported("aqlm"):
TEST_MODELS.append(("ISTA-DASLab/Llama-2-7b-AQLM-2Bit-1x16-hf", {
"quantization": "aqlm"
}))

# TODO: enable in pytorch 2.5
# TODO: figure out why this fails.
if False and is_quant_method_supported("gguf"): # noqa: SIM223
TEST_MODELS.append(("TheBloke/TinyLlama-1.1B-Chat-v1.0-GGUF", {
"quantization": "gguf"
Expand Down Expand Up @@ -71,13 +68,13 @@ def check_full_graph_support(model,
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

# Inductor doesn't support fp8 and the base meta llama uses too
# much memory.
quantization = model_kwargs.get("quantization")
if ((quantization == "fp8" or model == "meta-llama/Meta-Llama-3-8B")
# The base meta llama uses too much memory.
if (model == "meta-llama/Meta-Llama-3-8B"
and optimization_level >= CompilationLevel.PIECEWISE):
return

print(f"MODEL={model}")

prompts = [
"Hello, my name is",
"The president of the United States is",
Expand Down

0 comments on commit d183cd9

Please sign in to comment.