Commit b67c1f6

Signed-off-by: youkaichao <[email protected]>
youkaichao committed Dec 2, 2024
1 parent 9b14d97 commit b67c1f6
Showing 8 changed files with 108 additions and 126 deletions.
9 changes: 4 additions & 5 deletions tests/compile/piecewise/test_simple.py
@@ -7,7 +7,6 @@
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
         use_cudagraph=True,
         splitting_ops=["silly.attention"],
         cudagraph_copy_inputs=True,
+        cudagraph_capture_sizes=[1, 2],
     ))
     with set_current_vllm_config(vllm_config):
         model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
             6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
     ):

-        with set_compile_context([1, 2]):
-            model(inputs)
+        model(inputs)

-            model(torch.randn(2).cuda())
-            model(torch.randn(1).cuda())
+        model(torch.randn(2).cuda())
+        model(torch.randn(1).cuda())

         input = torch.zeros(2).cuda()
         global global_counter
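Taken together, these hunks drop the deleted set_compile_context helper and declare the capture sizes up front in CompilationConfig. A condensed sketch of the resulting test pattern (names follow the diff; SillyModel is the test-local toy module, and the import location of set_current_vllm_config plus the exact warm-up input are assumptions, so treat this as illustrative rather than exact):

import torch

from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)  # import path assumed

# Capture sizes now live in the config; no compile-context manager is needed.
vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    splitting_ops=["silly.attention"],
    cudagraph_copy_inputs=True,
    cudagraph_capture_sizes=[1, 2],
))

with set_current_vllm_config(vllm_config):
    model = SillyModel(vllm_config=vllm_config, prefix='')  # test-local toy model

inputs = torch.randn(100).cuda()  # placeholder "general shape" input
model(inputs)                     # triggers compilation for the general shape
model(torch.randn(2).cuda())      # cudagraph capture at size 2
model(torch.randn(1).cuda())      # cudagraph capture at size 1
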
33 changes: 17 additions & 16 deletions tests/compile/piecewise/test_toy_llama.py
@@ -13,7 +13,6 @@
 from torch import nn
 from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.counter import compilation_counter
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -256,6 +255,7 @@ def run_model(llama_config,
         compilation_config = CompilationConfig(
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
+            cudagraph_capture_sizes=[1, 2],
         )
         if split_attn:
             compilation_config.splitting_ops = ["silly.attention"]
@@ -273,10 +273,9 @@ def run_model(llama_config,
     input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
     positions = torch.arange(B).cuda()

-    with set_compile_context([1, 2]):
-        model(input_ids, positions)
-        model(input_ids[:2], positions[:2])
-        model(input_ids[:1], positions[:1])
+    model(input_ids, positions)
+    model(input_ids[:2], positions[:2])
+    model(input_ids[:1], positions[:1])

     input_ids[:2].zero_()
     output = model(input_ids[:2], positions[:2])
@@ -379,10 +378,13 @@ def benchmark():
             level=CompilationLevel.PIECEWISE,
             use_cudagraph=True,
             splitting_ops=["silly.attention"],
+            cudagraph_capture_sizes=cudagraph_sizes,
         )
     else:
         compilation_config = CompilationConfig(
-            level=CompilationLevel.PIECEWISE, )
+            level=CompilationLevel.PIECEWISE,
+            cudagraph_capture_sizes=cudagraph_sizes,
+        )

     vllm_config = VllmConfig(compilation_config=compilation_config)
     with set_current_vllm_config(vllm_config):
@@ -396,17 +398,16 @@ def benchmark():

     graphs = {}

-    with set_compile_context(cudagraph_sizes):
-        model(input_ids, positions)
-        for b in cudagraph_sizes[::-1]:
-            if not piecewise:
-                graph = torch.cuda.CUDAGraph()
-                with torch.cuda.graph(graph, pool=pool):
-                    output = model(input_ids[:b], positions[:b])
-                graphs[b] = (graph, output)
-            else:
-                output = model(input_ids[:b], positions[:b])
-                graphs[b] = (model, output)
+    model(input_ids, positions)
+    for b in cudagraph_sizes[::-1]:
+        if not piecewise:
+            graph = torch.cuda.CUDAGraph()
+            with torch.cuda.graph(graph, pool=pool):
+                output = model(input_ids[:b], positions[:b])
+            graphs[b] = (graph, output)
+        else:
+            output = model(input_ids[:b], positions[:b])
+            graphs[b] = (model, output)
     for b in cudagraph_sizes:
         if piecewise:
             # noqa is for `Function definition does not bind loop variable`
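The non-piecewise branch of the benchmark above drives PyTorch's CUDA-graph API directly. For reference, a minimal vLLM-independent sketch of that capture-and-replay idiom, following the pattern from the PyTorch documentation (requires a CUDA device; the toy step function stands in for the model forward):

import torch

@torch.no_grad()
def step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the model forward pass being captured.
    return (x * 2.0 + 1.0).relu()

static_input = torch.randn(8, device="cuda")

# Warm up on a side stream before capture (standard CUDA-graph practice).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        static_output = step(static_input)
torch.cuda.current_stream().wait_stream(s)

# Capture: kernels launched inside the context are recorded, not run eagerly.
pool = torch.cuda.graph_pool_handle()  # graphs can share one memory pool
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
    static_output = step(static_input)

# Replay: copy new data into the captured input buffer, then replay the graph.
static_input.copy_(torch.randn(8, device="cuda"))
graph.replay()
torch.cuda.synchronize()
print(static_output)  # reflects the new input; buffers are reused in place
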
4 changes: 0 additions & 4 deletions vllm/compilation/backends.py
@@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
         assert not self._called, "VllmBackend can only be called once"

         self.graph = graph
-        # config is updated now, because only here can
-        # we get the sizes to capture for cudagraph
-        # from compilation context
-        self.compilation_configs.init_during_runtime()
         self.configure_post_pass()

         self.split_gm, self.piecewise_graphs = split_graph(
23 changes: 0 additions & 23 deletions vllm/compilation/compile_context.py

This file was deleted.
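The deleted module supplied the set_compile_context / get_compile_context pair that the rest of this commit stops using: a context manager that exposed a value (in practice, the list of cudagraph batch sizes) to code running inside the with block. A reconstruction in spirit, based on how the helpers are used elsewhere in this diff, not necessarily the exact deleted source:

from contextlib import contextmanager
from typing import Any, Optional

_compile_context: Optional[Any] = None

def get_compile_context() -> Optional[Any]:
    """Return whatever the enclosing set_compile_context() installed."""
    return _compile_context

@contextmanager
def set_compile_context(context: Any):
    """Temporarily expose `context` (e.g. the cudagraph batch sizes) to code
    that runs inside the with block, such as the compiler backend."""
    global _compile_context
    prev = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        _compile_context = prev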

80 changes: 72 additions & 8 deletions vllm/config.py
@@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]:
         from vllm.compilation.backends import VllmBackend
         return VllmBackend(self)

-    def init_during_runtime(self):
+    def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
         """To complete the initialization of config,
-        we need to know the compile context, which is only available
-        during the first run of the model.
-        """
-        from vllm.compilation.compile_context import get_compile_context
-        context = get_compile_context()
-        context = copy.deepcopy(context) if context is not None else []
-        sizes_to_specialize: List[int] = context
+        we need to know the cudagraph sizes."""

         if self.cudagraph_capture_sizes is None:
             self.capture_sizes = sizes_to_specialize
         else:
@@ -2386,6 +2381,53 @@ def init_during_runtime(self):
             self.inductor_compile_sizes = []
         self.compile_sizes = self.inductor_compile_sizes

+    @staticmethod
+    def get_graph_batch_size(batch_size: int) -> int:
+        """Returns the padded batch size given actual batch size.
+        Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+        2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+        """
+        if batch_size <= 2:
+            return batch_size
+        elif batch_size <= 4:
+            return 4
+        else:
+            return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                    _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+    @staticmethod
+    def get_max_graph_batch_size(max_num_seqs: int) -> int:
+        """
+        max_num_seqs: Maximum number of sequences in a batch.
+        _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+        pad the max_num_seqs if necessary by calling _get_graph_batch_size,
+        which will deal with some edge cases like 1, 2, 4.
+        if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
+        size. if not, it means the padded size is larger than the largest size
+        in _BATCH_SIZES_TO_CAPTURE, return the largest size in
+        _BATCH_SIZES_TO_CAPTURE.
+        """
+        padded_size = CompilationConfig.get_graph_batch_size(max_num_seqs)
+        if padded_size in _BATCH_SIZES_TO_CAPTURE:
+            return padded_size
+        assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+        return _BATCH_SIZES_TO_CAPTURE[-1]
+
+
+_BATCH_SIZE_ALIGNMENT = 8
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
+# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
+]
+
+
 @dataclass
 class VllmConfig:
@@ -2496,6 +2538,28 @@ def __post_init__(self):
                 self.compilation_config.pass_config.enable_reshape = False
                 self.compilation_config.level = CompilationLevel.PIECEWISE

+        if not envs.VLLM_USE_V1:
+            max_batchsize_to_capture = 0
+            if self.scheduler_config is not None and \
+                self.model_config is not None and \
+                    not self.model_config.enforce_eager:
+                max_batchsize_to_capture = \
+                    CompilationConfig.get_max_graph_batch_size(
+                        self.scheduler_config.max_num_seqs)
+            batch_size_capture_list = [
+                size for size in _BATCH_SIZES_TO_CAPTURE
+                if size <= max_batchsize_to_capture
+            ]
+        else:
+            batch_size_capture_list = []
+            if self.model_config is not None and \
+                    not self.model_config.enforce_eager:
+                batch_size_capture_list = [1, 2, 4
+                                           ] + [i for i in range(8, 513, 8)]
+
+        self.compilation_config.init_with_cudagraph_sizes(
+            batch_size_capture_list)
+
         if self.cache_config is not None and \
             self.cache_config.cpu_offload_gb > 0 and \
             self.compilation_config.level != CompilationLevel.NO_COMPILATION:
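To make the padding logic added to vllm/config.py concrete, here is a small self-contained sketch that reimplements the arithmetic from the diff (standalone copies of the helpers, not imports from vLLM) and checks a few representative values:

from typing import List

_BATCH_SIZE_ALIGNMENT = 8
# Sizes that *can* be captured: 1, 2, 4, then multiples of 8 up to 8192.
_BATCH_SIZES_TO_CAPTURE: List[int] = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]

def get_graph_batch_size(batch_size: int) -> int:
    """Pad a batch size up to 1, 2, 4, or the next multiple of 8."""
    if batch_size <= 2:
        return batch_size
    if batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

def get_max_graph_batch_size(max_num_seqs: int) -> int:
    """Largest capture size needed for max_num_seqs, capped at 8192."""
    padded = get_graph_batch_size(max_num_seqs)
    return padded if padded in _BATCH_SIZES_TO_CAPTURE else _BATCH_SIZES_TO_CAPTURE[-1]

assert get_graph_batch_size(1) == 1
assert get_graph_batch_size(3) == 4
assert get_graph_batch_size(5) == 8
assert get_graph_batch_size(13) == 16

# What __post_init__ ends up passing to init_with_cudagraph_sizes on the
# non-V1 path for a scheduler with max_num_seqs=48:
max_capture = get_max_graph_batch_size(48)  # -> 48
capture_list = [s for s in _BATCH_SIZES_TO_CAPTURE if s <= max_capture]
print(capture_list)  # [1, 2, 4, 8, 16, 24, 32, 40, 48]
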
8 changes: 3 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -8,7 +8,6 @@
 import torch.distributed
 import torch.nn as nn

-from vllm.compilation.compile_context import set_compile_context
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed.parallel_state import graph_capture
 from vllm.forward_context import set_forward_context
@@ -548,10 +547,9 @@ def profile_run(self) -> None:
             torch.tensor([], dtype=torch.float32, device=self.device)
             for _ in range(self.num_attn_layers)
         ]
-        with set_compile_context(self.cudagraph_batch_sizes):
-            # Trigger compilation for general shape.
-            hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                            dummy_kv_caches)
+        # Trigger compilation for general shape.
+        hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                        dummy_kv_caches)
         logits = self.model.compute_logits(hidden_states, None)
         logits = logits[:self.max_num_tokens]
         # TODO(woosuk): Consider the memory usage of the sampler.
7 changes: 4 additions & 3 deletions vllm/worker/enc_dec_model_runner.py
@@ -25,8 +25,7 @@
 from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
 from vllm.worker.model_runner import (GPUModelRunnerBase,
                                       ModelInputForGPUBuilder,
-                                      ModelInputForGPUWithSamplingMetadata,
-                                      _get_graph_batch_size)
+                                      ModelInputForGPUWithSamplingMetadata)
 from vllm.worker.model_runner_base import (
     _add_attn_metadata_broadcastable_dict,
     _add_sampling_metadata_broadcastable_dict)
@@ -465,7 +464,9 @@ def _prepare_encoder_model_input_tensors(
             # We will be using CUDA graph replay for this decode.
             max_len_of_block_table = self.get_max_block_per_batch()
             batch_size = len(encoder_seq_lens)
-            graph_batch_size = _get_graph_batch_size(batch_size)
+            graph_batch_size = \
+                self.vllm_config.compilation_config.get_graph_batch_size(
+                    batch_size)
             assert graph_batch_size >= batch_size
             cuda_graph_pad_size = graph_batch_size - batch_size
             # extend the cross_block_tables and encoder_seq_lens to match
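The net effect of that call is to round the decode batch up to a captured graph size and pad the per-sequence metadata to match. A hedged, standalone illustration (hypothetical helper and pad values chosen for the sketch, not vLLM's actual code):

def pad_for_cuda_graph(encoder_seq_lens, cross_block_tables):
    batch_size = len(encoder_seq_lens)
    # Same rounding rule as CompilationConfig.get_graph_batch_size:
    # 1, 2, 4, then multiples of 8.
    if batch_size <= 2:
        graph_batch_size = batch_size
    elif batch_size <= 4:
        graph_batch_size = 4
    else:
        graph_batch_size = (batch_size + 7) // 8 * 8
    cuda_graph_pad_size = graph_batch_size - batch_size
    # Extend per-sequence metadata so it matches the captured batch size.
    encoder_seq_lens = encoder_seq_lens + [1] * cuda_graph_pad_size
    cross_block_tables = cross_block_tables + [[]] * cuda_graph_pad_size
    return encoder_seq_lens, cross_block_tables

seq_lens, tables = pad_for_cuda_graph([17, 9, 31], [[0], [1], [2]])
print(len(seq_lens), len(tables))  # 4 4  (batch of 3 padded to graph size 4)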