Grok-1 optimization #164
Closed

Changes from all commits (24 commits)
1a6a1e9  Support grok1 model
a3b34f6  fix config class not found issue
88ef830  add fp8 support (still debugging)
270beea  fix the weight name mapping, but got failed in triton part
d15a3c9  change the param_type from torch.float8_e4m3fn to torch.float8_e4m3fnuz
f7cca0c  workaround "fp8e4nv data type is not supported on CUDA"
4e4de7b  do vectorized load and store in scaled_fp8_quant_kernel, Add rpd prof…
c4e9206  1) Add TP8 fused_moe config 2) Add accuracy check script file 3) add …
d66a1d7  Add extra label in grok1 model to debug bubble issues
d904a2d  add ck group gemm support
060a89d  fix the accuracy problem
286210a  support LDS bypass feature for fused_moe
01072a9  Change padding size to 256 for fp8
d47382d  Revise benchmark_moe_rocm.py for more cases tunning.
036f294  Change tunning config for LDS bypass optimization for FP8
2b7a776  Optimize custom all reduce (#130)
fd8f821  Make CAR ROCm 6.1 compatible. (#137)  (iotamudelta)
86a5ef3  Add fused_moe configuration files  (iotamudelta)
10684dd  rms_layernorm opt by Jacob
f384516  Add condition to adjust permute N dim size for better performance in …
6f24903  Fix CAR build error issue
1e4a3ec  Add some benchmark and unit-test files
2091e45  Add tuning file in the different hipblaslt version
974c168  Optimize rms_norm kernel with vec8 and fix fused_moe configuration va…
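A note on the FP8 dtype change in d15a3c9/f7cca0c: AMD MI300-class GPUs use the `fnuz` FP8 encoding (`torch.float8_e4m3fnuz`, max representable value 240, no infinities) instead of the OCP `torch.float8_e4m3fn` used on NVIDIA hardware (max 448), so per-tensor scales have to be computed against the smaller range. The sketch below is a plain-PyTorch reference of that per-tensor scaled quantization, assuming PyTorch >= 2.1; it is only an illustration, not the PR's `scaled_fp8_quant_kernel`, which is a custom GPU kernel.

import torch

def scaled_fp8_quant_ref(x: torch.Tensor,
                         fp8_dtype: torch.dtype = torch.float8_e4m3fnuz):
    """Per-tensor dynamic FP8 quantization (reference only).

    fp8_dtype is torch.float8_e4m3fnuz on ROCm/MI300 (max 240.0) and
    torch.float8_e4m3fn on CUDA (max 448.0).
    """
    fp8_max = torch.finfo(fp8_dtype).max
    # One scale for the whole tensor: map the observed amax onto fp8_max.
    scale = x.abs().amax().clamp(min=1e-12).float() / fp8_max
    x_q = (x / scale).clamp(-fp8_max, fp8_max).to(fp8_dtype)
    return x_q, scale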
CMakeLists.txt

@@ -270,7 +270,10 @@ define_gpu_extension_target
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/moe_ops.cpp"
-  "csrc/moe/topk_softmax_kernels.cu")
+  "csrc/moe/topk_softmax_kernels.cu"
+  "csrc/moe/quant_gemm_kernels.cu")
+
+set(CK_LIBS "device_gemm_operations" "utility")

 define_gpu_extension_target(
   _moe_C
@@ -279,6 +282,7 @@ define_gpu_extension_target
   SOURCES ${VLLM_MOE_EXT_SRC}
   COMPILE_FLAGS ${VLLM_GPU_FLAGS}
   ARCHITECTURES ${VLLM_GPU_ARCHES}
+  LIBRARIES ${CK_LIBS}
   WITH_SOABI)

 #
@@ -328,7 +332,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
   DESTINATION vllm
   LANGUAGE ${VLLM_GPU_LANG}
   SOURCES ${VLLM_PUNICA_EXT_SRC}
-  COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
+  COMPILE_FLAGS ${VLLM_PUNICA_GPU_eLAGS}
   ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
   WITH_SOABI)
 else()

Review comment on the `VLLM_PUNICA_GPU_eLAGS` line: typo?
Review comment (on the following file): Would you mind renaming this file to not use a '+' symbol?
@@ -0,0 +1,271 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

from data_loader import *
from rpd_handler import *


def main(args: argparse.Namespace):
    print(args)

    # NOTE(woosuk): If the request cannot be processed in a single batch,
    # the engine will automatically process the request in multiple batches.
    llm = LLM(model=args.model,
              speculative_model=args.speculative_model,
              num_speculative_tokens=args.num_speculative_tokens,
              tokenizer=args.tokenizer,
              quantization=args.quantization,
              quantized_weights_path=args.quantized_weights_path,
              tensor_parallel_size=args.tensor_parallel_size,
              trust_remote_code=args.trust_remote_code,
              dtype=args.dtype,
              enforce_eager=args.enforce_eager,
              kv_cache_dtype=args.kv_cache_dtype,
              quantization_param_path=args.quantization_param_path,
              device=args.device,
              ray_workers_use_nsight=args.ray_workers_use_nsight,
              worker_use_ray=args.worker_use_ray,
              use_v2_block_manager=args.use_v2_block_manager,
              enable_chunked_prefill=args.enable_chunked_prefill,
              download_dir=args.download_dir,
              block_size=args.block_size,
              disable_custom_all_reduce=args.disable_custom_all_reduce,
              gpu_memory_utilization=args.gpu_memory_utilization)

    sampling_params = SamplingParams(
        n=args.n,
        temperature=0.0 if args.use_beam_search else 1.0,
        top_p=1.0,
        use_beam_search=args.use_beam_search,
        ignore_eos=True,
        max_tokens=args.output_len,
    )
    print(sampling_params)
    #dummy_prompt_token_ids = np.random.randint(10000,
    #                                           size=(args.batch_size,
    #                                                 args.input_len))
    #dummy_inputs: List[PromptStrictInputs] = [{
    #    "prompt_token_ids": batch
    #} for batch in dummy_prompt_token_ids.tolist()]
    #from transformers import AutoTokenizer
    #tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
    #input_tokens = get_input_sentences(args.batch_size, args.input_len, "wikitext", 'wikitext-2-raw-v1', tokenizer)
    #prompts = [tokenizer.decode(sample) for sample in input_tokens]

    prompts = ["Chronicles III . Development work took approximately one year . After the release of Valkyria Chronicles II"]
    #prompts = ["Replace this with your text."]

    #dummy_inputs: List[PromptStrictInputs] = [{
    #    "prompt_token_ids": batch
    #} for batch in input_tokens]

    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompts,
                             sampling_params=sampling_params,
                             use_tqdm=False)
            print(p.key_averages())
        else:
            start_time = time.perf_counter()
            outputs = llm.generate(prompts,
                                   sampling_params=sampling_params,
                                   use_tqdm=False)
            end_time = time.perf_counter()
            for output in outputs:
                prompt = output.prompt
                generated_text = output.outputs[0].text
                print(f"Prompt: {prompt!r}")
                print(f"Generated text: {generated_text!r}")
            latency = end_time - start_time
            return latency

    print("Warming up...")
    for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
        run_to_completion(profile_dir=None)

    if args.profile:
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    if args.enable_prof:
        profiler = HipTx()
        profiler.start_profiling()

    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile_dir=None))

    if args.enable_prof:
        profiler.stop_profiling()

    latencies = np.array(latencies)
    percentages = [10, 25, 50, 75, 90]
    percentiles = np.percentile(latencies, percentages)
    print(f'Avg latency: {np.mean(latencies)} seconds')
    for percentage, percentile in zip(percentages, percentiles):
        print(f'{percentage}% percentile latency: {percentile} seconds')

    # Output JSON results if specified
    if args.output_json:
        results = {
            "avg_latency": np.mean(latencies),
            "latencies": latencies.tolist(),
            "percentiles": dict(zip(percentages, percentiles.tolist())),
        }
        with open(args.output_json, "w") as f:
            json.dump(results, f, indent=4)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Benchmark the latency of processing a single batch of '
        'requests till completion.')
    parser.add_argument('--model', type=str, default='facebook/opt-125m')
    parser.add_argument('--speculative-model', type=str, default=None)
    parser.add_argument('--num-speculative-tokens', type=int, default=None)
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=[*QUANTIZATION_METHODS, None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
    parser.add_argument('--output-len', type=int, default=128)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--n',
                        type=int,
                        default=1,
                        help='Number of generated sequences per prompt.')
    parser.add_argument('--use-beam-search', action='store_true')
    parser.add_argument('--num-iters-warmup',
                        type=int,
                        default=0,
                        help='Number of iterations to run for warmup.')
    parser.add_argument('--num-iters',
                        type=int,
                        default=1,
                        help='Number of iterations to run.')
    parser.add_argument('--trust-remote-code',
                        action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        '--dtype',
        type=str,
        default='auto',
        choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
        help='data type for model weights and activations. '
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        '--kv-cache-dtype',
        type=str,
        choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
        default="auto",
        help='Data type for kv cache storage. If "auto", will use model '
        'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
        'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
    parser.add_argument(
        '--quantization-param-path',
        type=str,
        default=None,
        help='Path to the JSON file containing the KV cache scaling factors. '
        'This should generally be supplied, when KV cache dtype is FP8. '
        'Otherwise, KV cache scaling factors default to 1.0, which may cause '
        'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
        'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
        'instead supported for common inference criteria.')
    parser.add_argument(
        '--quantized-weights-path',
        type=str,
        default=None,
        help='Path to the safetensor file containing the quantized weights '
        'and scaling factors. This should generally be supplied, when '
        'quantization is FP8.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda", "cpu"],
        help='device type for vLLM execution, supporting CUDA and CPU.')
    parser.add_argument('--block-size',
                        type=int,
                        default=16,
                        help='block size of key/value cache')
    parser.add_argument(
        '--enable-chunked-prefill',
        action='store_true',
        help='If True, the prefill requests can be chunked based on the '
        'max_num_batched_tokens')
    parser.add_argument('--use-v2-block-manager', action='store_true')
    parser.add_argument(
        "--ray-workers-use-nsight",
        action='store_true',
        help="If specified, use nsight to profile ray workers",
    )
    parser.add_argument('--worker-use-ray',
                        action='store_true',
                        help='use Ray for distributed serving, will be '
                        'automatically set when using more than 1 GPU '
                        'unless on ROCm where the default is torchrun')
    parser.add_argument('--download-dir',
                        type=str,
                        default=None,
                        help='directory to download and load the weights, '
                        'default to the default cache dir of huggingface')
    parser.add_argument(
        '--output-json',
        type=str,
        default=None,
        help='Path to save the latency results in JSON format.')
    parser.add_argument('--disable_custom_all_reduce', action='store_true')
    parser.add_argument('--gpu-memory-utilization',
                        type=float,
                        default=0.9,
                        help='the fraction of GPU memory to be used for '
                        'the model executor, which can range from 0 to 1. '
                        'If unspecified, will use the default value of 0.9.')
    parser.add_argument(
        "--enable-prof",
        action='store_true',
        help="enable profiler.")
    args = parser.parse_args()
    main(args)
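For reference, when `--output-json` is given the script writes a dict with `avg_latency`, `latencies`, and `percentiles` (as shown above). A minimal consumer of that file, assuming a placeholder path `results.json`:

import json

# Read the file written by --output-json ("results.json" is a placeholder path).
with open("results.json") as f:
    results = json.load(f)

print(f"avg latency: {results['avg_latency']:.3f} s "
      f"over {len(results['latencies'])} iterations")
for pct, seconds in results["percentiles"].items():
    print(f"p{pct}: {seconds:.3f} s")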
Review comment (likely on the `VLLM_PUNICA_GPU_eLAGS` line above): This looks to be a typo, probably does not build.