[torch.compile] remove compilation_context and simplify code #10838

Merged (7 commits) on Dec 3, 2024
Changes from 2 commits
9 changes: 4 additions & 5 deletions tests/compile/piecewise/test_simple.py
@@ -7,7 +7,6 @@
from torch import nn
from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
+cudagraph_capture_sizes=[1, 2],
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
@@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):

-with set_compile_context([1, 2]):
-    model(inputs)
+model(inputs)

-    model(torch.randn(2).cuda())
-    model(torch.randn(1).cuda())
+model(torch.randn(2).cuda())
+model(torch.randn(1).cuda())

input = torch.zeros(2).cuda()
global global_counter
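With `set_compile_context` gone, this test declares its cudagraph capture sizes on the compilation config and then simply runs the model at those sizes. A minimal sketch of the new configuration flow (the splitting op and model class come from the test above, so this is illustrative rather than a standalone script):

```python
# Sketch: capture sizes are declared up front on CompilationConfig; warm-up
# runs at those sizes no longer need a compile-context wrapper.
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)

vllm_config = VllmConfig(compilation_config=CompilationConfig(
    level=CompilationLevel.PIECEWISE,
    use_cudagraph=True,
    splitting_ops=["silly.attention"],   # custom op registered by the test
    cudagraph_copy_inputs=True,
    cudagraph_capture_sizes=[1, 2],
))

with set_current_vllm_config(vllm_config):
    # Build the model under this config (SillyModel in the test above),
    # then call it at batch sizes 1 and 2 to exercise the captured graphs.
    ...
```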
33 changes: 17 additions & 16 deletions tests/compile/piecewise/test_toy_llama.py
@@ -13,7 +13,6 @@
from torch import nn
from torch.library import Library

-from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
@@ -256,6 +255,7 @@ def run_model(llama_config,
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
+cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
@@ -273,10 +273,9 @@ def run_model(llama_config,
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()

-with set_compile_context([1, 2]):
-    model(input_ids, positions)
-    model(input_ids[:2], positions[:2])
-    model(input_ids[:1], positions[:1])
+model(input_ids, positions)
+model(input_ids[:2], positions[:2])
+model(input_ids[:1], positions[:1])

input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
@@ -379,10 +378,13 @@ def benchmark():
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
+cudagraph_capture_sizes=cudagraph_sizes,
)
else:
compilation_config = CompilationConfig(
-    level=CompilationLevel.PIECEWISE, )
+    level=CompilationLevel.PIECEWISE,
+    cudagraph_capture_sizes=cudagraph_sizes,
+)

vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
@@ -396,17 +398,16 @@

graphs = {}

-with set_compile_context(cudagraph_sizes):
-    model(input_ids, positions)
-    for b in cudagraph_sizes[::-1]:
-        if not piecewise:
-            graph = torch.cuda.CUDAGraph()
-            with torch.cuda.graph(graph, pool=pool):
-                output = model(input_ids[:b], positions[:b])
-            graphs[b] = (graph, output)
-        else:
-            output = model(input_ids[:b], positions[:b])
-            graphs[b] = (model, output)
+model(input_ids, positions)
+for b in cudagraph_sizes[::-1]:
+    if not piecewise:
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph, pool=pool):
+            output = model(input_ids[:b], positions[:b])
+        graphs[b] = (graph, output)
+    else:
+        output = model(input_ids[:b], positions[:b])
+        graphs[b] = (model, output)
for b in cudagraph_sizes:
if piecewise:
# noqa is for `Function definition does not bind loop variable`
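In the benchmark above, the non-piecewise branch captures one full CUDA graph per batch size from a shared memory pool and later replays it instead of re-launching kernels; the piecewise branch just calls the compiled model and lets vLLM manage capture internally. A rough, generic sketch of the manual capture pattern (plain torch code, not the benchmark itself):

```python
import torch

@torch.no_grad()
def capture_per_size(model, static_input, sizes):
    """Capture one CUDA graph per batch size, sharing a single memory pool."""
    pool = torch.cuda.graph_pool_handle()
    graphs = {}
    model(static_input)                        # eager warm-up before capture
    for b in sorted(sizes, reverse=True):      # largest batch first, as above
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph, pool=pool):
            output = model(static_input[:b])   # capture records work on this slice
        graphs[b] = (graph, output)
    return graphs

# To replay: copy fresh data into static_input[:b] in place, then call
# graphs[b][0].replay(); graphs[b][1] holds the refreshed output tensor.
```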
4 changes: 0 additions & 4 deletions vllm/compilation/backends.py
@@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
assert not self._called, "VllmBackend can only be called once"

self.graph = graph
-# config is updated now, because only here can
-# we get the sizes to capture for cudagraph
-# from compilation context
-self.compilation_configs.init_during_runtime()
self.configure_post_pass()

self.split_gm, self.piecewise_graphs = split_graph(
23 changes: 0 additions & 23 deletions vllm/compilation/compile_context.py

This file was deleted.
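For reference, the deleted module was a small piece of global state: a context manager stored the capture sizes so that VllmBackend could read them during the first forward pass. A paraphrased sketch of that kind of helper (not the verbatim deleted file):

```python
from contextlib import contextmanager
from typing import Any

_compile_context: Any = None


def get_compile_context() -> Any:
    """Return whatever the surrounding set_compile_context() installed."""
    return _compile_context


@contextmanager
def set_compile_context(context: Any):
    """Temporarily expose `context` (e.g. cudagraph sizes) as global state."""
    global _compile_context
    prev = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        _compile_context = prev
```

Replacing this implicit channel with an explicit `cudagraph_capture_sizes` config field is the core simplification of the PR.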

83 changes: 75 additions & 8 deletions vllm/config.py
@@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]:
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)

-def init_during_runtime(self):
+def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
    """To complete the initialization of config,
-    we need to know the compile context, which is only available
-    during the first run of the model.
-    """
-    from vllm.compilation.compile_context import get_compile_context
-    context = get_compile_context()
-    context = copy.deepcopy(context) if context is not None else []
-    sizes_to_specialize: List[int] = context
+    we need to know the cudagraph sizes."""

if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
@@ -2385,6 +2380,56 @@ def init_during_runtime(self):
if self.inductor_compile_sizes is None:
self.inductor_compile_sizes = []
self.compile_sizes = self.inductor_compile_sizes

+# sort to make sure cudagraph capture sizes are in descending order
+self.capture_sizes.sort(reverse=True)

+@staticmethod
+def get_graph_batch_size(batch_size: int) -> int:
+    """Returns the padded batch size given actual batch size.
+
+    Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
+    2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
+    """
+    if batch_size <= 2:
+        return batch_size
+    elif batch_size <= 4:
+        return 4
+    else:
+        return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
+                _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+@staticmethod
+def get_max_graph_batch_size(max_num_seqs: int) -> int:
+    """
+    max_num_seqs: Maximum number of sequences in a batch.
+    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
+    which will deal with some edge cases like 1, 2, 4.
+
+    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
+    size. if not, it means the padded size is larger than the largest size
+    in _BATCH_SIZES_TO_CAPTURE, return the largest size in
+    _BATCH_SIZES_TO_CAPTURE.
+    """
+    padded_size = CompilationConfig.get_graph_batch_size(max_num_seqs)
+    if padded_size in _BATCH_SIZES_TO_CAPTURE:
+        return padded_size
+    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+    return _BATCH_SIZES_TO_CAPTURE[-1]
+
+
+_BATCH_SIZE_ALIGNMENT = 8
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
+# NOTE: _get_graph_batch_size needs to be updated if this list is changed.
+_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
+]


@dataclass
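The two helpers added above are pure padding arithmetic. A few worked values (a sketch; the results follow directly from the code, with `_BATCH_SIZE_ALIGNMENT = 8`):

```python
from vllm.config import CompilationConfig

# Batch sizes 1-2 pass through, 3-4 round to 4, larger sizes round up to a
# multiple of _BATCH_SIZE_ALIGNMENT (8).
assert CompilationConfig.get_graph_batch_size(1) == 1
assert CompilationConfig.get_graph_batch_size(3) == 4
assert CompilationConfig.get_graph_batch_size(5) == 8
assert CompilationConfig.get_graph_batch_size(17) == 24

# get_max_graph_batch_size pads the same way, then clamps to the largest
# entry in _BATCH_SIZES_TO_CAPTURE (8 * 1024 = 8192).
assert CompilationConfig.get_max_graph_batch_size(100) == 104
assert CompilationConfig.get_max_graph_batch_size(20000) == 8192
```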
@@ -2496,6 +2541,28 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

+if not envs.VLLM_USE_V1:
+    max_batchsize_to_capture = 0
+    if self.scheduler_config is not None and \
+        self.model_config is not None and \
+            not self.model_config.enforce_eager:
+        max_batchsize_to_capture = \
+            CompilationConfig.get_max_graph_batch_size(
+                self.scheduler_config.max_num_seqs)
+    batch_size_capture_list = [
+        size for size in _BATCH_SIZES_TO_CAPTURE
+        if size <= max_batchsize_to_capture
+    ]
+else:
+    batch_size_capture_list = []
+    if self.model_config is not None and \
+        not self.model_config.enforce_eager:
+        batch_size_capture_list = [1, 2, 4
+                                   ] + [i for i in range(8, 513, 8)]
+
+self.compilation_config.init_with_cudagraph_sizes(
+    batch_size_capture_list)

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
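To make the new `__post_init__` logic concrete: on V0 the capture list is every `_BATCH_SIZES_TO_CAPTURE` entry up to the padded `max_num_seqs`, on V1 it is the fixed ladder `[1, 2, 4] + range(8, 513, 8)`, and `init_with_cudagraph_sizes` then stores the result sorted in descending order, unless the user pinned `cudagraph_capture_sizes` explicitly, in which case that list takes precedence. A short self-contained sketch with hypothetical values:

```python
# Hypothetical V0 example: max_num_seqs = 32, enforce_eager = False.
_BATCH_SIZE_ALIGNMENT = 8
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]

max_bs = 32  # get_max_graph_batch_size(32): already a multiple of 8, no padding
batch_size_capture_list = [s for s in _BATCH_SIZES_TO_CAPTURE if s <= max_bs]
assert batch_size_capture_list == [1, 2, 4, 8, 16, 24, 32]

# On V1 (VLLM_USE_V1), the list is instead [1, 2, 4] + list(range(8, 513, 8)),
# i.e. 1, 2, 4, 8, 16, ..., 512, whenever cudagraphs are enabled.
# init_with_cudagraph_sizes(batch_size_capture_list) then records these as
# capture_sizes (sorted descending), unless cudagraph_capture_sizes was set.
```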
8 changes: 3 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -8,7 +8,6 @@
import torch.distributed
import torch.nn as nn

-from vllm.compilation.compile_context import set_compile_context
from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import graph_capture
from vllm.forward_context import set_forward_context
@@ -548,10 +547,9 @@ def profile_run(self) -> None:
torch.tensor([], dtype=torch.float32, device=self.device)
for _ in range(self.num_attn_layers)
]
-with set_compile_context(self.cudagraph_batch_sizes):
-    # Trigger compilation for general shape.
-    hidden_states = self._dummy_run(self.model, self.max_num_tokens,
-                                    dummy_kv_caches)
+# Trigger compilation for general shape.
+hidden_states = self._dummy_run(self.model, self.max_num_tokens,
+                                dummy_kv_caches)
logits = self.model.compute_logits(hidden_states, None)
logits = logits[:self.max_num_tokens]
# TODO(woosuk): Consider the memory usage of the sampler.
7 changes: 4 additions & 3 deletions vllm/worker/enc_dec_model_runner.py
@@ -25,8 +25,7 @@
from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad
from vllm.worker.model_runner import (GPUModelRunnerBase,
ModelInputForGPUBuilder,
-ModelInputForGPUWithSamplingMetadata,
-_get_graph_batch_size)
+ModelInputForGPUWithSamplingMetadata)
from vllm.worker.model_runner_base import (
_add_attn_metadata_broadcastable_dict,
_add_sampling_metadata_broadcastable_dict)
@@ -465,7 +464,9 @@ def _prepare_encoder_model_input_tensors(
# We will be using CUDA graph replay for this decode.
max_len_of_block_table = self.get_max_block_per_batch()
batch_size = len(encoder_seq_lens)
-graph_batch_size = _get_graph_batch_size(batch_size)
+graph_batch_size = \
+    self.vllm_config.compilation_config.get_graph_batch_size(
+        batch_size)
assert graph_batch_size >= batch_size
cuda_graph_pad_size = graph_batch_size - batch_size
# extend the cross_block_tables and encoder_seq_lens to match
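The encoder-decoder runner now reaches the padding helper through the compilation config instead of the module-private `_get_graph_batch_size`; since it is a staticmethod, calling it on the class or on the config instance is equivalent. A small sketch of the padding it computes (hypothetical batch size):

```python
from vllm.config import CompilationConfig

batch_size = 5  # hypothetical decode batch
graph_batch_size = CompilationConfig.get_graph_batch_size(batch_size)
cuda_graph_pad_size = graph_batch_size - batch_size
assert (graph_batch_size, cuda_graph_pad_size) == (8, 3)
# encoder_seq_lens and the cross block tables are then extended by the pad
# amount so the cudagraph captured for batch size 8 can be replayed.
```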