diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c1b10b3cf8f58..3e3a2ae46cf60 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -2,6 +2,7 @@ import argparse import dataclasses import json +import pickle import random import time from functools import cache @@ -21,10 +22,14 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict +from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators +SAMPLING_TEMPERATURE = 0.0 +SAMPLING_TOP_P = 1.0 + @dataclasses.dataclass class SampleRequest: @@ -165,7 +170,7 @@ def run_vllm( requests: List[SampleRequest], n: int, engine_args: EngineArgs, -) -> float: +) -> Tuple[float, Optional[List[RequestOutput]]]: from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) @@ -179,8 +184,8 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=1.0, - top_p=1.0, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, ignore_eos=True, max_tokens=request.expected_output_len, )) @@ -190,12 +195,13 @@ def run_vllm( use_beam_search = False + outputs = None if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + outputs = llm.generate(prompts, + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -213,7 +219,7 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - return end - start + return end - start, outputs async def run_vllm_async( @@ -221,7 +227,7 @@ async def run_vllm_async( n: int, engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> float: +) -> Tuple[float, Optional[List[RequestOutput]]]: from vllm import SamplingParams async with build_async_engine_client_from_engine_args( @@ -238,8 +244,8 @@ async def run_vllm_async( sampling_params.append( SamplingParams( n=n, - temperature=1.0, - top_p=1.0, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, ignore_eos=True, max_tokens=request.expected_output_len, )) @@ -255,10 +261,17 @@ async def run_vllm_async( request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) + outputs_dict = {} async for i, res in all_gens: - pass + outputs_dict[i] = res end = time.perf_counter() - return end - start + + num_prompts = len(prompts) + outputs = [] + for i in range(num_prompts): + outputs.append(outputs_dict[i]) + + return end - start, outputs def run_hf( @@ -391,7 +404,7 @@ def main(args: argparse.Namespace): for request in requests) if args.backend == "vllm": if args.async_engine: - elapsed_time = uvloop.run( + elapsed_time, outputs = uvloop.run( run_vllm_async( requests, args.n, @@ -399,8 +412,14 @@ def main(args: argparse.Namespace): args.disable_frontend_multiprocessing, )) else: - elapsed_time = run_vllm(requests, args.n, - EngineArgs.from_cli_args(args)) + elapsed_time, outputs = run_vllm(requests, args.n, + EngineArgs.from_cli_args(args)) + + if args.pickle_outputs: + print("Pickling request outputs : ") + with open("outputs.pkl", "wb+") as f: + pickle.dump(outputs, f) + elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -490,6 +509,11 @@ def main(args: argparse.Namespace): help="Path to the lora adapters to use. This can be an absolute path, " "a relative path, or a Hugging Face model identifier.") + parser.add_argument("--pickle-outputs", + action="store_true", + default=False, + help="Pickle outputs got from benchmark") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 8b247fb9b2388..acc0f04ea5332 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -1,5 +1,6 @@ import tempfile from collections import OrderedDict +from contextlib import contextmanager from typing import Dict, List, TypedDict from unittest.mock import MagicMock, patch @@ -76,6 +77,21 @@ def dist_init(): cleanup_dist_env_and_memory(shutdown_ray=True) +@contextmanager +def _dist_init(): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup_dist_env_and_memory(shutdown_ray=True) + + @pytest.fixture def dist_init_torch_only(): if torch.distributed.is_initialized(): @@ -274,3 +290,20 @@ def get_model_patched(**kwargs): def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. model_runner.model) + + +@pytest.fixture(params=[True]) +def run_with_both_engines_lora(request): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + with patch('vllm.envs.VLLM_USE_V1', True): + yield + else: + with patch('vllm.envs.VLLM_USE_V1', False): + yield diff --git a/tests/lora/lora_torch_compile.py b/tests/lora/lora_torch_compile.py new file mode 100644 index 0000000000000..502da2807afed --- /dev/null +++ b/tests/lora/lora_torch_compile.py @@ -0,0 +1,241 @@ +import random +from typing import Dict, List, Optional, Tuple + +import torch +from conftest import _dist_init +from utils import DummyLoRAManager + +from vllm.config import LoRAConfig +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, + VocabParallelEmbeddingWithLoRA) +from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights +# yapf: enable +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras: List[LoRALayerWeights] = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "cuda" +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs: List[torch.Tensor] = [] + index_mapping: List[int] = [] + prompt_mapping: List[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +num_loras = 4 +vocab_size = 512 +is_prefill = True +max_loras = 8 +device = "cuda:0" + + +def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: + print("Pre-pass:") + print(graph) + + return graph + + +def custom_backend(graph: torch.fx.GraphModule, + example_inputs: List[torch.Tensor]): + print("Graph entering custom_backend:") + print(graph.print_readable()) + from torch._inductor import config + current_config = config.shallow_copy_dict() + from torch._inductor.compile_fx import compile_fx + current_config['post_grad_custom_post_pass'] = custom_pass + return compile_fx(graph, example_inputs, config_patches=current_config) + + +@torch.inference_mode() +def test_embeddings() -> None: + + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment(1, 0) + ensure_model_parallel_initialized(1, 1) + + max_loras = 8 + punica_wrapper = get_punica_wrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[vocab_size:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_embedding.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=True) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_embedding_compiled = torch.compile(lora_embedding, + backend=custom_backend) + + embedding_compiled = torch.compile(embedding, backend=custom_backend) + + input = torch.cat(inputs) + torch._dynamo.mark_dynamic(input, 0) + + embedding_compiled(input) + lora_embedding_compiled(input) + + +if __name__ == '__main__': + with _dist_init(): + test_embeddings() diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 0ba2ce3617b67..3a88d1af57277 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_baichuan_lora(baichuan_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 49a527b99ac16..85e4b8578ea11 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest @@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 5ae705e474ec6..93bc619069570 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index dfeac380951d8..035ff600bb410 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,5 +1,6 @@ from typing import List +import pytest import ray import vllm @@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files): print("removing lora") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_llama_lora(sql_lora_files): diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index c2520c847d873..2f7ab4128b553 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("lora_bias", [True]) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 78bf5a1617233..5f91f5710bc69 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -56,6 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 5a3fcb8d690d9..8656a4b6b870b 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97d..bcb0438e00887 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -13,6 +13,10 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.v1.lora_expand import lora_expand +from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice +from vllm.lora.ops.v1.lora_shrink import lora_shrink +from vllm.lora.punica_wrapper.v1_gpu import V1KernelMeta from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -378,3 +382,157 @@ def test_punica_expand_nslices( slice_offset += hidden_size assert_close(our_outputs, ref_outputs) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seq_length", [1, 128]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_v1_shrink_expand( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seq_length: int, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + prompt_lora_mapping, + seq_len_tensor, + token_lora_mapping, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=seq_length * + batches, + device=device) + v1_meta.reset() + v1_meta.prepare_tensors(token_lora_mapping) + + if op_type == "shrink": + lora_shrink(inputs_tensor, lora_weights, our_out_tensor, + *v1_meta.meta_args(inputs_tensor.size(0)), scaling) + else: + lora_expand(inputs_tensor, + lora_weights, + our_out_tensor, + *v1_meta.meta_args(inputs_tensor.size(0)), + add_inputs=True) + + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + prompt_lora_mapping, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seq_length", [1, 128]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_v1_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + seq_length: int, + seed: int, + device: str, +): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + prompt_lora_mapping, + seq_len_tensor, + token_lora_mapping, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + nslices, + device, + ) + + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=seq_length * + batches, + device=device) + v1_meta.reset() + v1_meta.prepare_tensors(token_lora_mapping) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + lora_expand_slice(inputs_tensor, + lora_weights, + our_outputs, + *v1_meta.meta_args(inputs_tensor.size(0)), + slice_offset, + hidden_size, + add_inputs=True) + + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + prompt_lora_mapping, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 026269667b473..20532400bb1ac 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -68,6 +68,14 @@ def format_prompt_tuples(prompt): return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, diff --git a/vllm/config.py b/vllm/config.py index ac767bbe14be4..4fe707992f8f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3149,12 +3149,6 @@ def __post_init__(self): " Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.lora_config is not None and self.compilation_config.level !=\ - CompilationLevel.NO_COMPILATION: - logger.warning("LoRA is not supported with `torch.compile` yet. " - "Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION - current_platform.check_and_update_config(self) if not self.instance_id: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 85164c2165a3c..bf0ad2e988ba9 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -231,8 +231,11 @@ def set_lora( self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - embeddings_indices = self.punica_wrapper.embeddings_indices + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + embeddings_indices = torch.narrow( + self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, @@ -245,11 +248,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_output_org = full_output if full_output.ndim == 3: full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) + full_output.size(0) * full_output.size(1), -1) if full_lora_a_embeddings.ndim == 3: full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], + full_lora_a_embeddings.size(0) * + full_lora_a_embeddings.size(1), -1, ) self.punica_wrapper.add_lora_embedding(full_output, @@ -1028,7 +1031,19 @@ def _get_logits( logits = lm_head.linear_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + + # TODO (varun) : Replace with base layer get_logits() + if self.use_gather: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + if logits is None: return None @@ -1065,19 +1080,19 @@ def _get_logits( lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], + lora_logits.size(0) * lora_logits.size(1), + lora_logits.size(2), ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), posinf=float("inf"), neginf=float("-inf"))) # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): - lora_logits = lora_logits[:logits.shape[0], :] + lora_logits = lora_logits[:logits.size(0), :] logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits + lora_logits.size(1)] = lora_logits # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 5c0e4e5cbc636..8002ce2694d6d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -327,9 +327,11 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device) + self.punica_wrapper = get_punica_wrapper( + max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device, + max_loras=lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py new file mode 100644 index 0000000000000..8b143c2eead5b --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand.py @@ -0,0 +1,242 @@ +import math + +import torch +import triton +import triton.language as tl + +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max rank + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_N_CTAS = tl.cdiv(N, BLOCK_N) + + pid = tl.program_id(0) + lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_n[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _lora_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + 1 + lora_token_start_loc: torch.Tensor, # max-loras + 2 + lora_ids: torch.Tensor, # max-loras + 1 + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + should process. + lora_ids: LoRA ids to process. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + M = inputs.size(0) + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + BLOCK_M = 32 + BLOCK_N = 128 + BLOCK_K = 16 + + NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA + NUM_N_CTAS = math.ceil(N / BLOCK_N) + MAX_LORAS = lora_ids.size(0) + + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) + lora_n_stride = lora_b_weights.stride(1) + lora_k_stride = lora_b_weights.stride(2) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + grid = (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) + + _lora_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + return + + +def lora_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand", + op_func=_lora_expand, + mutates_args=["output_tensor"], + fake_impl=lora_expand_fake, + ) + lora_expand = torch.ops.vllm.lora_expand + +except AttributeError: + lora_expand = _lora_expand diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py new file mode 100644 index 0000000000000..b89d46512a4e7 --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -0,0 +1,251 @@ +import math + +import torch +import triton +import triton.language as tl + +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max rank + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + slice_offset, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_N_CTAS = tl.cdiv(N, BLOCK_N) + + pid = tl.program_id(0) + lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < + (N + slice_offset)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _lora_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, # inputs.size(0) + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + should process. + lora_ids: LoRA ids to process. + slice_offset (int): output_tensor's offset + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + assert slice_size == lora_b_weights.size(-2) + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + M = inputs.size(0) + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + BLOCK_M = 16 + BLOCK_N = 64 + BLOCK_K = 16 + + NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA + NUM_N_CTAS = math.ceil(N / BLOCK_N) + MAX_LORAS = lora_ids.size(0) + + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) + lora_n_stride = lora_b_weights.stride(1) + lora_k_stride = lora_b_weights.stride(2) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + grid = (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) + + _lora_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + slice_offset, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + return + +def lora_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand_slice", + op_func=_lora_expand_slice, + mutates_args=["output_tensor"], + fake_impl=lora_expand_slice_fake, + ) + lora_expand_slice = torch.ops.vllm.lora_expand_slice + +except AttributeError: + lora_expand_slice = _lora_expand_slice diff --git a/vllm/lora/ops/v1/lora_expand_slices.py b/vllm/lora/ops/v1/lora_expand_slices.py new file mode 100644 index 0000000000000..e82c7ece71fd8 --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand_slices.py @@ -0,0 +1,270 @@ +import math + +import torch +import triton +import triton.language as tl + +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_expand_slices_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_seq_indices, + lora_seq_counts, + lora_seq_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max_rank * num_loras + l1_stride, # hidden_size*max_rank + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + NUM_SLICES: tl.constexpr, + MAX_LORAS: tl.constexpr, + NUM_M_CTAS: tl.constexpr, + NUM_N_CTAS: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + pid = tl.program_id(0) + cta_s = pid // (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS) + cta_l = (pid // (NUM_M_CTAS * NUM_N_CTAS)) % MAX_LORAS + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + cta_l) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + if cta_l == 0: + lora_m_indices_start = tl.cast(0, tl.int32) + else: + lora_m_indices_start = tl.load(lora_seq_start_loc + cta_l - 1) + lora_m_size = tl.load(lora_seq_counts + cta_l) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = (lora_seq_indices + lora_m_indices_start + + cta_m_offset) + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * cta_s + l1_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + slice_offset = cta_s * N + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < + (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _lora_expand_slices( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_seq_indices: torch.Tensor, + lora_seq_counts: torch.Tensor, + lora_seq_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + add_inputs: bool = False, +) -> None: + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + + token_lora_mapping_tensor: Each token's lora id as it appears in the + A matrix. + + lora_seq_indices: sorted lora-token mapping. Tokens of the same lora + appear next to each other. This is used so a thread block knows + what tokens to put next to each other when constructing a matrix + block. Essentially, + _, lora_seq_indices = torch.sort(token_lora_mapping, stable=True) + + lora_seq_counts: number of tokens per lora id. essentially, + lora_ids, lora_seq_counts = torch.unique(indices, + sorted=False, + return_counts=True) + + lora_seq_start_loc: start index of each lora id in lora_seq_indices. + essentially, + lora_seq_start_loc = torch.cumsum(lora_seq_counts, dim = 0) + + lora_ids : Set of lora ids in order according to lora_seq_counts, + and lora_seq_indices. + lora_ids, lora_seq_counts = torch.unique(indices, + sorted=False, + return_counts=True) + + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + num_slices = lora_b_weights.size(0) + assert output_tensor.size(1) == lora_b_weights.size(-2) * num_slices + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + assert lora_b_weights.ndim == 4 # (nslices, lora_num, hidden-size, rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + NUM_SLICES = lora_b_weights.size(0) + M = inputs.size(0) + MAX_LORAS = lora_ids.size(0) + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + NUM_M_CTAS = math.ceil(M / BLOCK_M) + NUM_N_CTAS = math.ceil(N / BLOCK_N) + + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + grid = (NUM_SLICES * MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) # slice stride + l1_stride = lora_b_weights.stride(1) # lora stride + lora_n_stride = lora_b_weights.stride(2) + lora_k_stride = lora_b_weights.stride(3) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + _lora_expand_slices_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_seq_indices, + lora_seq_counts, + lora_seq_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + l1_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + NUM_SLICES, + MAX_LORAS, + NUM_M_CTAS, + NUM_N_CTAS, + ADD_INPUTS, + CAST_TYPE, + ) + return + + +try: + lora_expand_slices = torch.library.custom_op( + "lora::v1::lora_expand_slices", + _lora_expand_slices, + mutates_args=["output_tensor"]) +except AttributeError: + lora_expand_slices = _lora_expand_slices + + +def lora_expand_slices_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_seq_indices: torch.Tensor, + lora_seq_counts: torch.Tensor, + lora_seq_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + add_inputs: bool = False, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_expand_slices", + op_func=_lora_expand_slices, + mutates_args=["output_tensor"], + fake_impl=lora_expand_slices_fake, + ) + lora_expand_slices = torch.ops.vllm.lora_expand_slices + +except AttributeError: + lora_expand = _lora_expand_slices diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py new file mode 100644 index 0000000000000..bc3ecf036f9b9 --- /dev/null +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -0,0 +1,254 @@ +import torch +import triton +import triton.language as tl + +from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SMALL_BLOCK_M: tl.constexpr, + NUM_M_CTAS: tl.constexpr, + NUM_N_CTAS: tl.constexpr, +): + lora_idx = tl.program_id(1) + lora_id = tl.load(lora_ids + lora_idx) + if lora_id == -1: + # early exit for the no-lora case. + return + + pid = tl.program_id(0) + cta_sk = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.max_contiguous(BLOCK_K * cta_sk + tl.arange(0, BLOCK_K), + BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.max_contiguous((cta_n * BLOCK_N) + tl.arange(0, BLOCK_N), + BLOCK_N) + rn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = lora_ptr + lora_id * l0_stride + rn[ + None, :] * lora_n_stride + offset_k[:, None] * lora_k_stride + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + max_k = tl.cdiv(K, BLOCK_K * SPLIT_K) + for k in range(0, max_k): + if EVEN_K: + b_tile = tl.load(b_ptr) + a_tile = tl.load(a_ptr) + else: + b_mask = offset_k[:, None] < K + b_tile = tl.load(b_ptr, mask=b_mask, other=0.0) + + a_mask = offset_k[None, :] < K + a_tile = tl.load(a_ptr, mask=a_mask, other=0.0) + + # TODO (varun) : When a_tile and b_tile are float16s the output is + # also float16. this can lead to infs in the output. + if SMALL_BLOCK_M: + #acc += tl.sum(a_tile * b_tile.T) + acc += tl.sum(a_tile * b_tile.T, 1) + else: + acc += tl.dot(a_tile, b_tile) + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_k_stride + offset_k += BLOCK_K * SPLIT_K + + acc *= scaling + + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_n[None, :] * cn_stride + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_n[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptr, acc, mask=c_mask) + else: + tl.atomic_add(c_ptr, acc, mask=c_mask) + + +@torch.inference_mode() +def _lora_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] should + process. + lora_ids: LoRA ids to process. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + M = inputs.size(0) # num tokens + if M <= 16: + # GemmV is better for smaller batchsizes + return bgmv_shrink(inputs, lora_a_weights, output_tensor, + token_lora_mapping, scaling) + + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_a_weights.stride(0) + lora_k_stride = lora_a_weights.stride(2) + lora_n_stride = lora_a_weights.stride(1) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + # TODO tuning this config + N = lora_a_weights.size(-2) + K = lora_a_weights.size(-1) + MAX_LORAS = lora_ids.size(0) + + BLOCK_M = 32 + BLOCK_N = 16 + + if M < 128: + BLOCK_K = 256 + SPLIT_K = 64 + else: + BLOCK_K = 32 + SPLIT_K = 8 + + num_warps = 4 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + SMALL_BLOCK_M = BLOCK_M < 16 + NUM_M_CTAS = triton.cdiv(M, BLOCK_M) + NUM_N_CTAS = triton.cdiv(N, BLOCK_N) + + grid = ( + SPLIT_K * NUM_M_CTAS * NUM_N_CTAS, + MAX_LORAS, + ) + _lora_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + SMALL_BLOCK_M, + NUM_M_CTAS, + NUM_N_CTAS, + num_warps=num_warps, + ) + return + +def lora_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + scaling: float, +) -> None: + return + + +try: + direct_register_custom_op( + op_name="lora_shrink", + op_func=_lora_shrink, + mutates_args=["output_tensor"], + fake_impl=lora_shrink_fake, + ) + lora_shrink = torch.ops.vllm.lora_shrink + +except AttributeError: + lora_shrink = _lora_shrink diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b9ec0c4bc6323..dd463dd68dbab 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -167,7 +167,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.is_prefill = False self.no_lora = False - def _update_base_metadata( + def update_base_metadata( self, mapping: "LoRAMapping", lora_index_to_id: List[Optional[int]], @@ -192,16 +192,18 @@ def _update_base_metadata( self.device, long_lora_context, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) + self._token_lora_indices[:base_indices.size(0)].copy_( + base_indices, non_blocking=True) + self._sampler_indices[:sampler_indices.size(0)].copy_( + sampler_indices, non_blocking=True) + self._sampler_indices_padded[:sampler_indices_padded.size(0)].copy_( + sampler_indices_padded, non_blocking=True) self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) + size(0), :embeddings_indices.size(1)].copy_( + embeddings_indices, non_blocking=True) if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) + self._long_lora_indices[:long_lora_offsets_tensor.size(0)].copy_( + long_lora_offsets_tensor, non_blocking=True) else: self._long_lora_indices.zero_() self.indices_len[:] = indices_len @@ -212,10 +214,10 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: batch_size, max_length, token_nums, no_lora) = compute_meta(token_lora_tensor) - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + self._seq_start_locs[:b_seq_start_tensor.size(0)].copy_( b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + self._seq_lengths[:seq_length_tensor.size(0)].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.size(0)].copy_( lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length @@ -239,14 +241,14 @@ def _apply_bias( where n is number of slices """ org_output = output - output = output.view(-1, output.shape[-1]) + output = output.view(-1, output.size(-1)) indices = indices.view(-1) offset_left = 0 for slice_idx, slice in enumerate(output_slices): bias = lora_bias_stacked[slice_idx] if bias is not None: - bias = bias.view(-1, bias.shape[-1]) + bias = bias.view(-1, bias.size(-1)) bias = bias[indices] bias[indices == -1] = 0 output[:, offset_left:offset_left + slice] += bias @@ -328,9 +330,9 @@ def update_metadata( long_lora_context: Optional["LongContextLoRAContext"] = None, **kwargs): - self._update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) + self.update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) if mapping.is_prefill: # Update metadata required for prefill-related operators. self._update_prefill_metada(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index cd64878d95ae3..8c102a15b638d 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,3 +1,4 @@ +import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils import print_info_once @@ -7,9 +8,14 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: if current_platform.is_cuda_alike(): # Lazy import to avoid ImportError - from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU - print_info_once("Using PunicaWrapperGPU.") - return PunicaWrapperGPU(*args, **kwargs) + if envs.VLLM_USE_V1: + from vllm.lora.punica_wrapper.v1_gpu import V1LoRAGPU + print_info_once("Using V1LoRAGPU.") + return V1LoRAGPU(*args, **kwargs) + else: + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + print_info_once("Using PunicaWrapperGPU.") + return PunicaWrapperGPU(*args, **kwargs) elif current_platform.is_hpu(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py new file mode 100644 index 0000000000000..9e02afc61b3ac --- /dev/null +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -0,0 +1,423 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final + +import torch + +from vllm.lora.layers import LoRAMapping +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.v1.lora_expand import lora_expand + from vllm.lora.ops.v1.lora_shrink import lora_shrink + from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice + +from .punica_base import PunicaWrapperBase + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.models import LongContextLoRAContext + + +@dataclass +class V1KernelMeta: + token_lora_mapping: torch.Tensor + token_indices_sorted_by_lora_ids: torch.Tensor + active_lora_ids: torch.Tensor + num_tokens_per_lora: torch.Tensor + lora_token_start_loc: torch.Tensor + + @staticmethod + def make(max_loras: int, max_num_tokens: int, + device: torch.device) -> "V1KernelMeta": + + token_lora_mapping = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + # +1 because "no-lora" is also a possibility + # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] + # is a possibility. + active_lora_ids = torch.empty(max_loras + 1, + dtype=torch.int32, + device=device) + + # using running example, [3, 10, 5, 2] is a possibility. + num_tokens_per_lora = torch.zeros(max_loras + 1, + dtype=torch.int32, + device=device) + + # +2 for this because, the first index is always 0 + # for example: let max loras be 3, then lora_token_start_loc, + # can be [0, 3, 13, 18, 20]. + lora_token_start_loc = torch.zeros(max_loras + 2, + dtype=torch.int32, + device=device) + return V1KernelMeta( + token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc) + + def reset(self): + self.active_lora_ids.fill_(-1) + self.num_tokens_per_lora.fill_(0) + self.lora_token_start_loc.fill_(0) + + def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: + num_tokens = token_lora_mapping.size(0) + + # copy token lora mapping + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, + non_blocking=True) + + # token_indices_sorted_by_lora_ids + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, + stable=True) + # start gpu transfer + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( + token_indices_sorted_by_lora_ids, non_blocking=True) + + # active_lora_ids, num_tokens_per_lora + lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, + sorted=False, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, + non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True) + + # lora_token_start_loc + lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) + self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True) + + def meta_args( + self, num_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: + return (self.token_lora_mapping[:num_tokens], + self.token_indices_sorted_by_lora_ids[:num_tokens], + self.num_tokens_per_lora, self.lora_token_start_loc, + self.active_lora_ids) + + +@final +class V1LoRAGPU(PunicaWrapperBase): + """ + TODO (varun) + _summary_ + + Args: + PunicaWrapperBase (_type_): _description_ + """ + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + self.max_loras = kwargs['max_loras'] + self.token_mapping_v1_meta = V1KernelMeta.make(self.max_loras, + max_num_batched_tokens, + device=device) + self.prompt_mapping_v1_meta = V1KernelMeta.make(self.max_loras, + max_batches, + device=device) + + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + # TODO (varun) : Make no_lora case work with torch compile + #self.no_lora = all([x == -1 for x in mapping.prompt_mapping]) + #if self.no_lora: + # # no update required + # return + #num_tokens: int = len(mapping.index_mapping) + + self.token_mapping_v1_meta.reset() + self.prompt_mapping_v1_meta.reset() + + self.update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + + self.token_mapping_v1_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_v1_meta.prepare_tensors(self.sampler_indices) + + def _shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + lora_shrink(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), scale) + + def _expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + lora_expand(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + add_inputs) + + def _expand_slice( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + + lora_expand_slice(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + y_offset, y_slice_size, add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + self._expand_slice(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.size(-1)) + self._shrink(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.size(-1)) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. + """ + y_org = y + y = y.view(-1, y.size(-1)) + offset_left = offset_start + if lora_bias_stacked is not None: + assert isinstance(x, torch.Tensor) + self._apply_bias(self._token_lora_indices[:x.size(0)], y, + output_slices, lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + self._expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._token_lora_indices[:x.size(0)], y, + output_slices, lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. + """ + y_org = y + y = y.view(-1, y.size(-1)) + x = x.view(-1, x.size(-1)) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + lora_shrink(x, lora_a_stacked, buffer, + *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) + lora_expand(buffer, + lora_b_stacked, + y, + *self.prompt_mapping_v1_meta.meta_args(x.size(0)), + add_inputs=True) + + y = y.view_as(y_org) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 34d65ed51ef3f..de615a5348e1f 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -60,7 +60,7 @@ class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" - d = x.shape[-1] // 2 + d = x.size(-1) // 2 return F.silu(x[..., :d]) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 43ea4eb5a4d1a..43c9424801222 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -46,7 +46,7 @@ def forward_native( x = x + residual.to(torch.float32) residual = x.to(orig_dtype) - hidden_size = x.shape[-1] + hidden_size = x.size(-1) if hidden_size != self.hidden_size: raise ValueError("Expected hidden_size to be " f"{self.hidden_size}, but found: {hidden_size}") diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 117fe086e5e87..eaf8e356ed416 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -129,7 +129,7 @@ def forward_native( if offsets is not None: positions = positions + offsets positions = positions.flatten() - num_tokens = positions.shape[0] + num_tokens = positions.size(0) cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9ddbff7c9a604..7ab435954c5d5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -164,14 +164,13 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def generate_block_hash_extra_keys( +def generate_block_hash_extra_keys_for_mm( request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: - """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). - For multi-modal inputs, the extra keys are (mm_hash, start_offset) that - indicate a mm input contained in the block and its starting offset in - the block tokens. + start_mm_idx: int) -> Tuple[Optional[List[Any]], int]: + """Generate extra keys related to MultiModal request for block hash + computation. For multi-modal inputs, the extra keys are + (mm_hash, start_offset) that indicate a mm input contained in the + block and its starting offset in the block tokens. Args: request: The request object. @@ -182,7 +181,6 @@ def generate_block_hash_extra_keys( Returns: A tuple of extra keys and the next multi-modal index. """ - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes if not mm_positions: return None, start_mm_idx @@ -231,7 +229,56 @@ def generate_block_hash_extra_keys( else: # This block has not reached the current mm input. break - return tuple(extra_keys), curr_mm_idx + return extra_keys, curr_mm_idx + + +def generate_block_hash_extra_keys_for_lora( + request: Request) -> Optional[List[int]]: + """Generate extra keys related to LoRA for block hash computation. + + Args: + request: The request object. + + Returns: + Return LoRA id of the request if it is a LoRA request. Return None + otherwise. + """ + if not request.lora_request: + return None + return [request.lora_request.lora_int_id] + + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. + """ + mm_extra_keys: Optional[List[Any]] + mm_extra_keys, new_start_mm_idx = generate_block_hash_extra_keys_for_mm( + request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: Optional[ + List[int]] = generate_block_hash_extra_keys_for_lora(request) + + extra_keys: List[Any] = [] + if mm_extra_keys: + extra_keys.extend(mm_extra_keys) + if lora_extra_keys: + extra_keys.extend(lora_extra_keys) + + if not extra_keys: + return None, new_start_mm_idx + + return tuple(extra_keys), new_start_mm_idx def hash_block_tokens( @@ -280,7 +327,7 @@ def hash_request_tokens(block_size: int, "The number of multi-modal positions and hashes must match.") # TODO: Extend this to support other features such as LoRA. - need_extra_keys = bool(mm_positions) + need_extra_keys = bool(mm_positions) or (request.lora_request is not None) extra_keys = None curr_mm_idx = 0 diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 08e7c0fd4dc9b..140d6d00b705d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -5,6 +5,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams @@ -32,8 +33,6 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -173,6 +172,14 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: Set[int] = set() + if self.lora_config: + requested_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. if not preempted_reqs: while self.waiting: @@ -184,6 +191,17 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and ( + req_lora_id not in requested_loras): + # cannot schedule + break + requested_loras.add(req_lora_id) + # Get already-cached tokens. computed_blocks = self.kv_cache_manager.get_computed_blocks( request) @@ -521,6 +539,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -539,6 +558,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cc0c7ea23469a..51a83847330d3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -22,6 +22,8 @@ class DetokenizerRequest: stop: List[str] include_stop_str_in_output: bool + lora_request: Optional[LoRARequest] + @dataclass class EngineCoreRequest: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba2b8377759d6..bbe1d8f4c2ac5 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -314,8 +314,7 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - assert lora_request is None - return self.detokenizer.tokenizer + return self.detokenizer.get_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: return False diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd5..af8bc151500d4 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -3,11 +3,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_tokenizer from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput logger = init_logger(__name__) @@ -197,16 +198,25 @@ def __init__(self, tokenizer_mode: str = "auto", trust_remote_code: bool = False, revision: Optional[str] = None): - # TODO: once we support LoRA, we should should pass the tokenizer - # here. We currently have two copies (this + in the LLMEngine). - self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - revision=revision) + # per-request tokenizers, like in LoRA, are created in + # add_request. All other requests use the base_tokenizer. + self._base_tokenizer = get_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + revision=revision) # Request id -> IncrementalDetokenizer self.request_states: Dict[str, IncrementalDetokenizer] = {} + def get_tokenizer(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + if lora_request: + return get_lora_tokenizer(lora_request) + else: + return self._base_tokenizer + def is_request_active(self, request_id: str): return request_id in self.request_states @@ -233,8 +243,9 @@ def add_request( assert (request.request_id not in self.request_states) + req_tokenizer = self.get_tokenizer(request.lora_request) request_state = IncrementalDetokenizer.from_new_request( - self.tokenizer, request) + req_tokenizer, request) self.request_states[request.request_id] = request_state def step( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6ee8732bc902c..d3be673044d28 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -50,7 +50,7 @@ def __init__( self.mm_hasher = MMHasher() # TODO: run in an ThreadpoolExecutor or BackgroundProcess. - # This ideally should releases the GIL, so we should not block the + # This ideally should release the GIL, so we should not block the # asyncio loop while this is running. def process_inputs( self, @@ -133,6 +133,7 @@ def process_inputs( sampling_params.output_kind, sampling_params.stop, sampling_params.include_stop_str_in_output, + lora_request, ) # Make Request for EngineCore. diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 6c4d300ec6efe..a7bb4d37e30da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,11 +1,12 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata @@ -29,6 +30,8 @@ class CachedRequestState: num_computed_tokens: int output_token_ids: List[int] + lora_request: Optional[LoRARequest] + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) @@ -157,6 +160,11 @@ def __init__( ] self.prompt_token_ids: Optional[torch.Tensor] = None + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_requests: Set[LoRARequest] = set() + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. @@ -231,6 +239,15 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) + # Add request lora ID + if request.lora_request: + self.request_lora_mapping[ + req_index] = request.lora_request.lora_int_id + self.lora_requests.add(request.lora_request) + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: @@ -247,6 +264,12 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) + + # LoRA + # only update request_lora_mapping. Defer the updates + # to lora_requests to prepare_lora_inputs. + self.request_lora_mapping[req_index] = 0 + return req_index def clear(self) -> None: @@ -262,6 +285,9 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_requests.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: @@ -315,6 +341,9 @@ def condense(self, empty_req_indices: List[int]) -> None: if generator is not None: self.generators[empty_index] = generator + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -398,6 +427,34 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. + """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + + active_lora_ids: Set[int] = set(np.unique(req_lora_mapping)) + active_lora_requests: Set[LoRARequest] = \ + set({lr for lr in self.lora_requests \ + if lr.lora_int_id in active_lora_ids}) + # Update lora requests + self.lora_requests = active_lora_requests + + return prompt_lora_mapping, token_lora_mapping, self.lora_requests + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 509771b7e2e5a..cb6a866c74906 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -23,6 +23,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -30,7 +31,7 @@ logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -117,6 +118,10 @@ def __init__( # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( reversed(self.vllm_config.compilation_config.capture_sizes)) + self.cudagraph_batch_sizes = [ + bs for bs in self.cudagraph_batch_sizes + if bs <= self.scheduler_config.max_num_seqs + ] # Persistent buffers for CUDA graphs. self.input_ids = torch.zeros(self.max_num_tokens, @@ -231,6 +236,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], + lora_request=new_req_data.lora_request, ) req_ids_to_add.append(req_id) @@ -274,15 +280,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] + num_scheduled_tokens_list = [] max_num_scheduled_tokens = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) + num_scheduled_tokens_list.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) assert max_num_scheduled_tokens > 0 # Get request indices. @@ -364,6 +371,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_table=self.input_batch.block_table[:num_reqs], slot_mapping=slot_mapping, ) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this # partial request, we do so for simplicity. We will ignore the sampled @@ -587,6 +599,12 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -704,15 +722,33 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. - torch.cuda.synchronize() - del hidden_states, logits - self.encoder_cache.clear() + # TODO (varun): Reconcile text-only with multi-modal + # compute num tokens per request. For profile, have maximum num_reqs and + # that collectively have maximum num_tokens. + num_reqs = self.scheduler_config.max_num_seqs + num_tokens = self.max_num_tokens + min_tokens_per_req: int = num_tokens // num_reqs + + num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + + with self.maybe_profile_with_lora(self.lora_config, + num_scheduled_tokens): + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) + hidden_states = hidden_states[logit_indices] + logits = self.model.compute_logits(hidden_states, None) + # TODO(woosuk): Consider the memory usage of the sampler. + torch.cuda.synchronize() + del hidden_states, logits + self.encoder_cache.clear() gc.collect() def capture_model(self) -> None: @@ -730,10 +766,13 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): + print(f"running capture model for {num_tokens}...") + with self.maybe_capture_model_with_lora( + self.lora_config, num_tokens): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(self.model, num_tokens, self.kv_caches) self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000000000..d989019341043 --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,146 @@ +""" +Define LoRA adapter for model runner. +""" + +from contextlib import contextmanager +from typing import Set, Tuple + +import numpy as np +import torch.nn as nn + +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + LORA_WARMUP_RANK = 8 + + def load_lora_model(self, model: nn.Module, model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: LoRAConfig, device: str) -> nn.Module: + + assert supports_lora( + model), f"{model.__class__.__name__} does not support LoRA yet." + + if supports_multimodal(model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(model.config, "max_position_embeddings"): + max_pos_embeddings = model.config.max_position_embeddings + else: + max_pos_embeddings = ( + model.config.text_config.max_position_embeddings) + + # Add LoRA Manager to the Model Runner + self.lora_manager = LRUCacheWorkerLoRAManager( + scheduler_config.max_num_seqs, + scheduler_config.max_num_batched_tokens, + model_config.get_vocab_size(), + lora_config, + device, + model.embedding_modules, + model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(model) + + def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], + token_lora_mapping: Tuple[int, ...], + lora_requests: Set[LoRARequest]) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + # In V1, we use the same kernels for both prefill and decode. + # To that effect, is_prefill is marked None. + lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=None) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.ndarray) -> None: + + prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: Tuple[int, + ...] # of size np.sum(num_scheduled_tokens) + lora_requests: Set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_profile_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs to requests arbitrarily + prompt_lora_mapping = np.random.randint(low=1, + high=num_loras + 1, + size=num_reqs, + dtype=np.int32) + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: Set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. + for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + @contextmanager + def maybe_capture_model_with_lora(self, lora_config: LoRAConfig, + batch_size: int): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + prompt_lora_mapping = tuple([0] * batch_size) + token_lora_mapping = tuple([0] * batch_size) + + self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + set()) + yield + + # __exit__ code