Grok-1 optimization #164

Closed · wants to merge 24 commits

Commits (24)
1a6a1e9  Support grok1 model  (May 22, 2024)
a3b34f6  fix config class not found issue  (Jul 1, 2024)
88ef830  add fp8 support (still debugging)  (Jul 10, 2024)
270beea  fix the weight name mapping, but got failed in triton part  (Jul 11, 2024)
d15a3c9  change the param_type from torch.float8_e4m3fn to torch.float8_e4m3fnuz  (Jul 13, 2024)
f7cca0c  workaround "fp8e4nv data type is not supported on CUDA"  (Jul 15, 2024)
4e4de7b  do vectorized load and store in scaled_fp8_quant_kernel, Add rpd prof…  (Jul 25, 2024)
c4e9206  1) Add TP8 fused_moe config 2) Add accuracy check script file 3) add …  (Jul 31, 2024)
d66a1d7  Add extra label in grok1 model to debug bubble issues  (Aug 1, 2024)
d904a2d  add ck group gemm support  (Aug 12, 2024)
060a89d  fix the accuracy problem  (Aug 14, 2024)
286210a  support LDS bypass feature for fused_moe  (Aug 16, 2024)
01072a9  Change padding size to 256 for fp8  (Aug 16, 2024)
d47382d  Revise benchmark_moe_rocm.py for more cases tuning.  (Aug 16, 2024)
036f294  Change tuning config for LDS bypass optimization for FP8  (Aug 16, 2024)
2b7a776  Optimize custom all reduce (#130)  (iotamudelta, Aug 14, 2024)
fd8f821  Make CAR ROCm 6.1 compatible. (#137)  (iotamudelta, Aug 14, 2024)
86a5ef3  Add fused_moe configuration files  (Aug 23, 2024)
10684dd  rms_layernorm opt by Jacob  (Aug 23, 2024)
f384516  Add condition to adjust permute N dim size for better performance in …  (Aug 23, 2024)
6f24903  Fix CAR build error issue  (Aug 23, 2024)
1e4a3ec  Add some benchmark and unit-test files  (Aug 23, 2024)
2091e45  Add tuning file in the different hipblaslt version  (Aug 23, 2024)
974c168  Optimize rms_norm kernel with vec8 and fix fused_moe configuration va…  (Aug 23, 2024)
8 changes: 6 additions & 2 deletions CMakeLists.txt
@@ -270,7 +270,10 @@ define_gpu_extension_target(

set(VLLM_MOE_EXT_SRC
"csrc/moe/moe_ops.cpp"
"csrc/moe/topk_softmax_kernels.cu")
"csrc/moe/topk_softmax_kernels.cu"
"csrc/moe/quant_gemm_kernels.cu")

set(CK_LIBS "device_gemm_operations" "utility")

define_gpu_extension_target(
_moe_C
@@ -279,6 +282,7 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
+ LIBRARIES ${CK_LIBS}
WITH_SOABI)

#
@@ -328,7 +332,7 @@ if (VLLM_PUNICA_GPU_ARCHES)
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${VLLM_PUNICA_EXT_SRC}
- COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
+ COMPILE_FLAGS ${VLLM_PUNICA_GPU_eLAGS}

Reviewer comment: This looks to be a typo, probably does not build.
Reviewer comment: typo?

ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
WITH_SOABI)
else()
271 changes: 271 additions & 0 deletions benchmarks/benchmark_latency+accuracy_check.py

Reviewer comment: Would you mind renaming this file to not use a '+' symbol?

@@ -0,0 +1,271 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import json
import time
from pathlib import Path
from typing import List, Optional

import numpy as np
import torch
from tqdm import tqdm

from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

from data_loader import *
from rpd_handler import *

def main(args: argparse.Namespace):
print(args)

# NOTE(woosuk): If the request cannot be processed in a single batch,
# the engine will automatically process the request in multiple batches.
llm = LLM(model=args.model,
speculative_model=args.speculative_model,
num_speculative_tokens=args.num_speculative_tokens,
tokenizer=args.tokenizer,
quantization=args.quantization,
quantized_weights_path=args.quantized_weights_path,
tensor_parallel_size=args.tensor_parallel_size,
trust_remote_code=args.trust_remote_code,
dtype=args.dtype,
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
quantization_param_path=args.quantization_param_path,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
worker_use_ray=args.worker_use_ray,
use_v2_block_manager=args.use_v2_block_manager,
enable_chunked_prefill=args.enable_chunked_prefill,
download_dir=args.download_dir,
block_size=args.block_size,
disable_custom_all_reduce=args.disable_custom_all_reduce,
gpu_memory_utilization=args.gpu_memory_utilization)

sampling_params = SamplingParams(
n=args.n,
temperature=0.0 if args.use_beam_search else 1.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=True,
max_tokens=args.output_len,
)
print(sampling_params)
#dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.batch_size,
# args.input_len))
#dummy_inputs: List[PromptStrictInputs] = [{
# "prompt_token_ids": batch
#} for batch in dummy_prompt_token_ids.tolist()]
#from transformers import AutoTokenizer
#tokenizer = AutoTokenizer.from_pretrained("hpcai-tech/grok-1", trust_remote_code=True)
#input_tokens = get_input_sentences(args.batch_size, args.input_len, "wikitext", 'wikitext-2-raw-v1', tokenizer)
#prompts = [tokenizer.decode(sample) for sample in input_tokens]

prompts = ["Chronicles III . Development work took approximately one year . After the release of Valkyria Chronicles II"]
#prompts = ["Replace this with your text."]

#dummy_inputs: List[PromptStrictInputs] = [{
# "prompt_token_ids": batch
#} for batch in input_tokens]

def run_to_completion(profile_dir: Optional[str] = None):
if profile_dir:
with torch.profiler.profile(
activities=[
torch.profiler.ProfilerActivity.CPU,
torch.profiler.ProfilerActivity.CUDA,
],
on_trace_ready=torch.profiler.tensorboard_trace_handler(
str(profile_dir))) as p:
llm.generate(prompts,
sampling_params=sampling_params,
use_tqdm=False)
print(p.key_averages())
else:
start_time = time.perf_counter()
outputs = llm.generate(prompts,
sampling_params=sampling_params,
use_tqdm=False)
end_time = time.perf_counter()
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}")
print(f"Generated text: {generated_text!r}")
latency = end_time - start_time
return latency

print("Warming up...")
for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
run_to_completion(profile_dir=None)

if args.profile:
profile_dir = args.profile_result_dir
if not profile_dir:
profile_dir = Path(
"."
) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
print(f"Profiling (results will be saved to '{profile_dir}')...")
run_to_completion(profile_dir=profile_dir)
return

# Benchmark.
latencies = []
if args.enable_prof:
profiler = HipTx()
profiler.start_profiling()

for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))

if args.enable_prof:
profiler.stop_profiling()

latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)
print(f'Avg latency: {np.mean(latencies)} seconds')
for percentage, percentile in zip(percentages, percentiles):
print(f'{percentage}% percentile latency: {percentile} seconds')

# Output JSON results if specified
if args.output_json:
results = {
"avg_latency": np.mean(latencies),
"latencies": latencies.tolist(),
"percentiles": dict(zip(percentages, percentiles.tolist())),
}
with open(args.output_json, "w") as f:
json.dump(results, f, indent=4)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Benchmark the latency of processing a single batch of '
'requests till completion.')
parser.add_argument('--model', type=str, default='facebook/opt-125m')
parser.add_argument('--speculative-model', type=str, default=None)
parser.add_argument('--num-speculative-tokens', type=int, default=None)
parser.add_argument('--tokenizer', type=str, default=None)
parser.add_argument('--quantization',
'-q',
choices=[*QUANTIZATION_METHODS, None],
default=None)
parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
parser.add_argument('--input-len', type=int, default=32)
parser.add_argument('--output-len', type=int, default=128)
parser.add_argument('--batch-size', type=int, default=8)
parser.add_argument('--n',
type=int,
default=1,
help='Number of generated sequences per prompt.')
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
default=0,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
default=1,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='trust remote code from huggingface')
parser.add_argument(
'--dtype',
type=str,
default='auto',
choices=['auto', 'half', 'float16', 'bfloat16', 'float', 'float32'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.')
parser.add_argument('--enforce-eager',
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,
default=None,
help='Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
parser.add_argument(
'--quantized-weights-path',
type=str,
default=None,
help='Path to the safetensor file containing the quantized weights '
'and scaling factors. This should generally be supplied, when '
'quantization is FP8.')
parser.add_argument(
'--profile',
action='store_true',
help='profile the generation process of a single batch')
parser.add_argument(
'--profile-result-dir',
type=str,
default=None,
help=('path to save the pytorch profiler output. Can be visualized '
'with ui.perfetto.dev or Tensorboard.'))
parser.add_argument(
"--device",
type=str,
default="cuda",
choices=["cuda", "cpu"],
help='device type for vLLM execution, supporting CUDA and CPU.')
parser.add_argument('--block-size',
type=int,
default=16,
help='block size of key/value cache')
parser.add_argument(
'--enable-chunked-prefill',
action='store_true',
help='If True, the prefill requests can be chunked based on the '
'max_num_batched_tokens')
parser.add_argument('--use-v2-block-manager', action='store_true')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
help="If specified, use nsight to profile ray workers",
)
parser.add_argument('--worker-use-ray',
action='store_true',
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU '
'unless on ROCm where the default is torchrun')
parser.add_argument('--download-dir',
type=str,
default=None,
help='directory to download and load the weights, '
'default to the default cache dir of huggingface')
parser.add_argument(
'--output-json',
type=str,
default=None,
help='Path to save the latency results in JSON format.')
parser.add_argument('--disable_custom_all_reduce', action='store_true')
parser.add_argument('--gpu-memory-utilization',
type=float,
default=0.9,
help='the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1. '
'If unspecified, will use the default value of 0.9.')
parser.add_argument(
"--enable-prof",
action='store_true',
help="enable profiler.")
args = parser.parse_args()
main(args)
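
Both benchmark scripts import HipTx from rpd_handler to bracket the measured iterations with an rpd (rocmProfileData) trace, but the rpd_handler module itself is not shown in this diff. The sketch below is a hypothetical illustration of what such a wrapper could look like on top of rocmProfileData's rpdTracerControl; the class layout, the per-rank file naming, and the assumed setFilename/start/stop API are guesses for illustration, not the PR's actual code.

```python
# Hypothetical rpd_handler.HipTx sketch (assumed, not taken from this PR).
from rpdTracerControl import rpdTracerControl  # provided by rocmProfileData


class HipTx:
    """Start/stop a rocmProfileData trace around a benchmarked region."""

    def __init__(self, rank: int = 0, filename: str = "vllm_grok1.rpd"):
        # One trace file per tensor-parallel rank so concurrent workers don't collide.
        rpdTracerControl.setFilename(name=f"rank{rank}_{filename}", append=True)
        self._tracer = rpdTracerControl()

    def start_profiling(self) -> None:
        self._tracer.start()

    def stop_profiling(self) -> None:
        self._tracer.stop()
```

With a wrapper along these lines, the new --enable-prof flag simply surrounds the timed llm.generate() iterations with start_profiling()/stop_profiling(), and the resulting .rpd trace can then be analyzed with the rocmProfileData tooling.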
20 changes: 18 additions & 2 deletions benchmarks/benchmark_latency.py
@@ -12,7 +12,10 @@
from vllm import LLM, SamplingParams
from vllm.inputs import PromptStrictInputs
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+ from vllm.distributed import (get_tensor_model_parallel_rank,
+                               get_tensor_model_parallel_world_size)

+ from rpd_handler import *

def main(args: argparse.Namespace):
print(args)
@@ -98,8 +101,17 @@ def run_to_completion(profile_dir: Optional[str] = None):

# Benchmark.
latencies = []
+ if args.enable_prof:
+     rank = get_tensor_model_parallel_rank()
+     profiler = HipTx(rank)
+     profiler.start_profiling()

for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
latencies.append(run_to_completion(profile_dir=None))

+ if args.enable_prof:
+     profiler.stop_profiling()

latencies = np.array(latencies)
percentages = [10, 25, 50, 75, 90]
percentiles = np.percentile(latencies, percentages)
@@ -141,11 +153,11 @@ def run_to_completion(profile_dir: Optional[str] = None):
parser.add_argument('--use-beam-search', action='store_true')
parser.add_argument('--num-iters-warmup',
type=int,
- default=10,
+ default=0,
help='Number of iterations to run for warmup.')
parser.add_argument('--num-iters',
type=int,
- default=30,
+ default=1,
help='Number of iterations to run.')
parser.add_argument('--trust-remote-code',
action='store_true',
@@ -248,5 +260,9 @@ def run_to_completion(profile_dir: Optional[str] = None):
'is used, on CUDA this will be automatically set to "ray" if '
'installed or "mp" (multiprocessing) otherwise. On ROCm, this is '
'instead set to torchrun by default.')
+ parser.add_argument(
+     "--enable-prof",
+     action='store_true',
+     help="enable profiler.")
args = parser.parse_args()
main(args)