Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] remove compilation_context and simplify code #10838

Merged
merged 7 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions tests/compile/piecewise/test_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
Expand Down Expand Up @@ -81,6 +80,7 @@ def test_simple_piecewise_compile():
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_copy_inputs=True,
cudagraph_capture_sizes=[1, 2],
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
Expand All @@ -96,11 +96,10 @@ def test_simple_piecewise_compile():
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
):

with set_compile_context([1, 2]):
model(inputs)
model(inputs)

model(torch.randn(2).cuda())
model(torch.randn(1).cuda())
model(torch.randn(2).cuda())
model(torch.randn(1).cuda())

input = torch.zeros(2).cuda()
global global_counter
Expand Down
33 changes: 17 additions & 16 deletions tests/compile/piecewise/test_toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from torch import nn
from torch.library import Library

from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
Expand Down Expand Up @@ -256,6 +255,7 @@ def run_model(llama_config,
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
cudagraph_capture_sizes=[1, 2],
)
if split_attn:
compilation_config.splitting_ops = ["silly.attention"]
Expand All @@ -273,10 +273,9 @@ def run_model(llama_config,
input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda()
positions = torch.arange(B).cuda()

with set_compile_context([1, 2]):
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])
model(input_ids, positions)
model(input_ids[:2], positions[:2])
model(input_ids[:1], positions[:1])

input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
Expand Down Expand Up @@ -379,10 +378,13 @@ def benchmark():
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
splitting_ops=["silly.attention"],
cudagraph_capture_sizes=cudagraph_sizes,
)
else:
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, )
level=CompilationLevel.PIECEWISE,
cudagraph_capture_sizes=cudagraph_sizes,
)

vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
Expand All @@ -396,17 +398,16 @@ def benchmark():

graphs = {}

with set_compile_context(cudagraph_sizes):
model(input_ids, positions)
for b in cudagraph_sizes[::-1]:
if not piecewise:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
output = model(input_ids[:b], positions[:b])
graphs[b] = (graph, output)
else:
model(input_ids, positions)
for b in cudagraph_sizes[::-1]:
if not piecewise:
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph, pool=pool):
output = model(input_ids[:b], positions[:b])
graphs[b] = (model, output)
graphs[b] = (graph, output)
else:
output = model(input_ids[:b], positions[:b])
graphs[b] = (model, output)
for b in cudagraph_sizes:
if piecewise:
# noqa is for `Function definition does not bind loop variable`
Expand Down
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_jamba.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import pytest

from tests.utils import multi_gpu_test
from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

Expand Down Expand Up @@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
Expand Down
5 changes: 3 additions & 2 deletions tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer

from vllm.config import VllmConfig
from vllm.sampling_params import SamplingParams
from vllm.worker.model_runner import _get_graph_batch_size

from ...utils import check_outputs_equal

Expand Down Expand Up @@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding(
# This test is for verifying that mamba cache is padded to CG captured
# batch size. If it's not, a torch RuntimeError will be raised because
# tensor dimensions aren't compatible
while len(example_prompts) == _get_graph_batch_size(len(example_prompts)):
while len(example_prompts) == VllmConfig.get_graph_batch_size(
len(example_prompts)):
example_prompts.append(example_prompts[0])

try:
Expand Down
4 changes: 2 additions & 2 deletions tests/worker/test_encoder_decoder_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import make_tensor_with_pad
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import _get_graph_batch_size

BATCH_SIZES = [1, 4, 16, 64, 256]

Expand Down Expand Up @@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group):
# With CUDA Graph capture and replay enabled, the decoder and encoder
# input sequences will be padded. Create the expected padded tensors
# accordingly.
graph_batch_size = _get_graph_batch_size(expanded_batch_size)
graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size)
cuda_graph_pad_size = graph_batch_size - expanded_batch_size
padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size))
padded_encoder_seq_lens = encoder_seq_lens + list(
Expand Down
5 changes: 3 additions & 2 deletions tests/worker/test_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
import pytest
import torch

from vllm.config import VllmConfig
from vllm.distributed.parallel_state import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size
from vllm.worker.model_runner import ModelRunner


def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
Expand Down Expand Up @@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
assert len(slot_mapping) == len(input_tokens)

expected_bs = _get_graph_batch_size(len(seq_group_metadata_list))
expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list))
# Verify input metadata is correct for prompts.
device = model_runner.device
assert attn_metadata.num_prefills == 0
Expand Down
4 changes: 0 additions & 4 deletions vllm/compilation/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
assert not self._called, "VllmBackend can only be called once"

self.graph = graph
# config is updated now, because only here can
# we get the sizes to capture for cudagraph
# from compilation context
self.compilation_configs.init_during_runtime()
self.configure_post_pass()

self.split_gm, self.piecewise_graphs = split_graph(
Expand Down
23 changes: 0 additions & 23 deletions vllm/compilation/compile_context.py

This file was deleted.

83 changes: 75 additions & 8 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]:
from vllm.compilation.backends import VllmBackend
return VllmBackend(self)

def init_during_runtime(self):
def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from vllm.compilation.compile_context import get_compile_context
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
we need to know the cudagraph sizes."""

if self.cudagraph_capture_sizes is None:
self.capture_sizes = sizes_to_specialize
else:
Expand All @@ -2386,6 +2381,21 @@ def init_during_runtime(self):
self.inductor_compile_sizes = []
self.compile_sizes = self.inductor_compile_sizes

# sort to make sure cudagraph capture sizes are in descending order
self.capture_sizes.sort(reverse=True)


_BATCH_SIZE_ALIGNMENT = 8
# all the token sizes that **can** be captured by cudagraph.
# they can be arbitrarily large.
# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
# the actual sizes to capture will be determined by the model,
# depending on the model's max_num_seqs.
# NOTE: get_graph_batch_size needs to be updated if this list is changed.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
_BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class VllmConfig:
Expand Down Expand Up @@ -2413,6 +2423,41 @@ class VllmConfig:
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore

@staticmethod
def get_graph_batch_size(batch_size: int) -> int:
"""Returns the padded batch size given actual batch size.

Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT,
2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT...
"""
if batch_size <= 2:
return batch_size
elif batch_size <= 4:
return 4
else:
return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
_BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)

@staticmethod
def get_max_graph_batch_size(max_num_seqs: int) -> int:
"""
max_num_seqs: Maximum number of sequences in a batch.
_BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.

pad the max_num_seqs if necessary by calling get_graph_batch_size,
which will deal with some edge cases like 1, 2, 4.

if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded
size. if not, it means the padded size is larger than the largest size
in _BATCH_SIZES_TO_CAPTURE, return the largest size in
_BATCH_SIZES_TO_CAPTURE.
"""
padded_size = VllmConfig.get_graph_batch_size(max_num_seqs)
if padded_size in _BATCH_SIZES_TO_CAPTURE:
return padded_size
assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
return _BATCH_SIZES_TO_CAPTURE[-1]

@staticmethod
def _get_quantization_config(
model_config: ModelConfig,
Expand Down Expand Up @@ -2496,6 +2541,28 @@ def __post_init__(self):
self.compilation_config.pass_config.enable_reshape = False
self.compilation_config.level = CompilationLevel.PIECEWISE

if not envs.VLLM_USE_V1:
max_batchsize_to_capture = 0
if self.scheduler_config is not None and \
self.model_config is not None and \
not self.model_config.enforce_eager:
max_batchsize_to_capture = \
self.get_max_graph_batch_size(
self.scheduler_config.max_num_seqs)
batch_size_capture_list = [
size for size in _BATCH_SIZES_TO_CAPTURE
if size <= max_batchsize_to_capture
]
else:
batch_size_capture_list = []
if self.model_config is not None and \
not self.model_config.enforce_eager:
batch_size_capture_list = [1, 2, 4
] + [i for i in range(8, 513, 8)]

self.compilation_config.init_with_cudagraph_sizes(
batch_size_capture_list)
WoosukKwon marked this conversation as resolved.
Show resolved Hide resolved

if self.cache_config is not None and \
self.cache_config.cpu_offload_gb > 0 and \
self.compilation_config.level != CompilationLevel.NO_COMPILATION:
Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.layer import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
Expand All @@ -25,8 +25,6 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

from .interfaces import HasInnerState, SupportsLoRA
from .utils import maybe_prefix
Expand Down Expand Up @@ -404,7 +402,7 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)

Expand Down
6 changes: 2 additions & 4 deletions vllm/model_executor/models/mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from transformers import MambaConfig

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.logits_processor import LogitsProcessor
Expand All @@ -23,8 +23,6 @@
MambaCacheParams)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE,
_get_graph_batch_size)

from .utils import maybe_prefix

Expand Down Expand Up @@ -187,7 +185,7 @@ def forward(self,
inputs_embeds: Optional[torch.Tensor] = None,
**kwargs):
if self.mamba_cache is None:
max_batch_size = (_get_graph_batch_size(
max_batch_size = (VllmConfig.get_graph_batch_size(
self.scheduler_config.max_num_seqs) if self.scheduler_config
else max(_BATCH_SIZES_TO_CAPTURE) + 2)
self.mamba_cache = MambaCacheManager(
Expand Down
Loading
Loading