From 78737aa03685ce7f3a131e797177b433e0f0aa39 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Fri, 8 Nov 2024 17:17:26 +0000 Subject: [PATCH 01/12] Add lora support Signed-off-by: Varun Sundar Rabindranath --- benchmarks/benchmark_throughput.py | 47 +- tests/lora/conftest.py | 32 ++ tests/lora/lora_torch_compile.py | 243 +++++++++++ tests/lora/test_baichuan.py | 8 + tests/lora/test_chatglm3_tp.py | 10 + tests/lora/test_gemma.py | 8 + tests/lora/test_llama_tp.py | 9 + tests/lora/test_lora_bias_e2e.py | 8 + tests/lora/test_minicpmv.py | 8 + tests/lora/test_phi.py | 10 + tests/lora/test_punica_sizes.py | 156 +++++++ tests/lora/test_quant_model.py | 8 + vllm/config.py | 10 +- vllm/lora/layers.py | 33 +- vllm/lora/models.py | 3 +- vllm/lora/ops/v1/lora_expand.py | 234 ++++++++++ vllm/lora/ops/v1/lora_expand_slice.py | 244 +++++++++++ vllm/lora/ops/v1/lora_expand_slices.py | 276 ++++++++++++ vllm/lora/ops/v1/lora_shrink.py | 225 ++++++++++ vllm/lora/punica_wrapper/punica_base.py | 30 +- vllm/lora/punica_wrapper/punica_selector.py | 12 +- vllm/lora/punica_wrapper/v1_gpu.py | 403 ++++++++++++++++++ vllm/model_executor/layers/activation.py | 2 +- vllm/model_executor/layers/layernorm.py | 2 +- .../model_executor/layers/rotary_embedding.py | 2 +- vllm/v1/core/scheduler.py | 24 +- vllm/v1/engine/__init__.py | 2 + vllm/v1/engine/async_llm.py | 3 +- vllm/v1/engine/detokenizer.py | 27 +- vllm/v1/engine/processor.py | 3 +- vllm/v1/worker/gpu_input_batch.py | 59 ++- vllm/v1/worker/gpu_model_runner.py | 68 ++- vllm/v1/worker/lora_model_runner_mixin.py | 147 +++++++ 33 files changed, 2277 insertions(+), 79 deletions(-) create mode 100644 tests/lora/lora_torch_compile.py create mode 100644 vllm/lora/ops/v1/lora_expand.py create mode 100644 vllm/lora/ops/v1/lora_expand_slice.py create mode 100644 vllm/lora/ops/v1/lora_expand_slices.py create mode 100644 vllm/lora/ops/v1/lora_shrink.py create mode 100644 vllm/lora/punica_wrapper/v1_gpu.py create mode 100644 vllm/v1/worker/lora_model_runner_mixin.py diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c1b10b3cf8f58..cc2020e2a3332 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -4,6 +4,7 @@ import json import random import time +import pickle from functools import cache from typing import Dict, List, Optional, Tuple @@ -24,7 +25,10 @@ from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators +from vllm.outputs import RequestOutput +SAMPLING_TEMPERATURE=0.0 +SAMPLING_TOP_P=1.0 @dataclasses.dataclass class SampleRequest: @@ -165,7 +169,7 @@ def run_vllm( requests: List[SampleRequest], n: int, engine_args: EngineArgs, -) -> float: +) -> Tuple[float, Optional[List[RequestOutput]]]: from vllm import LLM, SamplingParams llm = LLM(**dataclasses.asdict(engine_args)) @@ -179,8 +183,8 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=1.0, - top_p=1.0, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, ignore_eos=True, max_tokens=request.expected_output_len, )) @@ -190,9 +194,10 @@ def run_vllm( use_beam_search = False + outputs = None if not use_beam_search: start = time.perf_counter() - llm.generate(prompts, + outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests, use_tqdm=True) @@ -213,7 +218,7 @@ def run_vllm( ignore_eos=True, )) end = time.perf_counter() - return end - start + 
return end - start, outputs async def run_vllm_async( @@ -221,7 +226,7 @@ async def run_vllm_async( n: int, engine_args: AsyncEngineArgs, disable_frontend_multiprocessing: bool = False, -) -> float: +) -> Tuple[float, Optional[List[RequestOutput]]]: from vllm import SamplingParams async with build_async_engine_client_from_engine_args( @@ -238,8 +243,8 @@ async def run_vllm_async( sampling_params.append( SamplingParams( n=n, - temperature=1.0, - top_p=1.0, + temperature=SAMPLING_TEMPERATURE, + top_p=SAMPLING_TOP_P, ignore_eos=True, max_tokens=request.expected_output_len, )) @@ -255,10 +260,17 @@ async def run_vllm_async( request_id=f"test{i}") generators.append(generator) all_gens = merge_async_iterators(*generators) + outputs_dict = {} async for i, res in all_gens: - pass + outputs_dict[i] = res end = time.perf_counter() - return end - start + + num_prompts = len(prompts) + outputs = [] + for i in range(num_prompts): + outputs.append(outputs_dict[i]) + + return end - start, outputs def run_hf( @@ -391,7 +403,7 @@ def main(args: argparse.Namespace): for request in requests) if args.backend == "vllm": if args.async_engine: - elapsed_time = uvloop.run( + elapsed_time, outputs = uvloop.run( run_vllm_async( requests, args.n, @@ -399,8 +411,14 @@ def main(args: argparse.Namespace): args.disable_frontend_multiprocessing, )) else: - elapsed_time = run_vllm(requests, args.n, + elapsed_time, outputs = run_vllm(requests, args.n, EngineArgs.from_cli_args(args)) + + if args.pickle_outputs: + print("Pickling request outputs : ") + with open("outputs.pkl", "wb+") as f: + pickle.dump(outputs, f) + elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -490,6 +508,11 @@ def main(args: argparse.Namespace): help="Path to the lora adapters to use. This can be an absolute path, " "a relative path, or a Hugging Face model identifier.") + parser.add_argument("--pickle-outputs", + action="store_true", + default=False, + help="Pickle outputs got from benchmark") + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() if args.tokenizer is None: diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 8b247fb9b2388..40f5316f9d031 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -20,6 +20,7 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model +from contextlib import contextmanager class ContextIDInfo(TypedDict): @@ -75,6 +76,20 @@ def dist_init(): yield cleanup_dist_env_and_memory(shutdown_ray=True) +@contextmanager +def _dist_init(): + temp_file = tempfile.mkstemp()[1] + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"file://{temp_file}", + local_rank=0, + backend="nccl", + ) + initialize_model_parallel(1, 1) + yield + cleanup_dist_env_and_memory(shutdown_ray=True) + @pytest.fixture def dist_init_torch_only(): @@ -274,3 +289,20 @@ def get_model_patched(**kwargs): def llama_2_7b_model_extra_embeddings(llama_2_7b_engine_extra_embeddings): yield (llama_2_7b_engine_extra_embeddings.model_executor.driver_worker. 
model_runner.model) + + +@pytest.fixture(params=[True]) +def run_with_both_engines_lora(request): + # Automatically runs tests twice, once with V1 and once without + use_v1 = request.param + # Tests decorated with `@skip_v1` are only run without v1 + skip_v1 = request.node.get_closest_marker("skip_v1") + + if use_v1: + if skip_v1: + pytest.skip("Skipping test on vllm V1") + with patch('vllm.envs.VLLM_USE_V1', True): + yield + else: + with patch('vllm.envs.VLLM_USE_V1', False): + yield diff --git a/tests/lora/lora_torch_compile.py b/tests/lora/lora_torch_compile.py new file mode 100644 index 0000000000000..fef535df698d6 --- /dev/null +++ b/tests/lora/lora_torch_compile.py @@ -0,0 +1,243 @@ +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +import torch +import torch.nn.functional as F + +from vllm.config import LoRAConfig +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (LoRAMapping, + BaseLayerWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.lora.punica_wrapper import get_punica_wrapper +from vllm.model_executor.utils import set_random_seed + +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) + +from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, + PackedLoRALayerWeights) + +from utils import DummyLoRAManager +from vllm.distributed.parallel_state import ensure_model_parallel_initialized, init_distributed_environment +from conftest import _dist_init + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. 
+ sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras: List[LoRALayerWeights] = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager( + layer_weights.device).init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, + device: torch.device = "cuda" +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs: List[torch.Tensor] = [] + index_mapping: List[int] = [] + prompt_mapping: List[int] = [] + + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device=device)) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device=device) * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +num_loras = 4 +vocab_size = 512 +is_prefill = True +max_loras = 8 +device="cuda:0" + +def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: + print("Pre-pass:") + print(graph) + + return graph + + +def custom_backend(graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): + print("Graph entering custom_backend:") + print(graph.print_readable()) + from torch._inductor import config + current_config = config.shallow_copy_dict() + from torch._inductor.compile_fx import compile_fx + current_config['post_grad_custom_post_pass'] = custom_pass + return compile_fx(graph, example_inputs, config_patches=current_config) + +@torch.inference_mode() +def test_embeddings() -> None: + + torch.cuda.set_device(device) + torch.set_default_device(device) + + init_distributed_environment(1, 0) + ensure_model_parallel_initialized(1,1) + + max_loras = 8 + punica_wrapper = get_punica_wrapper(8192, 256, device) + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(vocab_size, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[vocab_size:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = 
create_random_embedding_layer() + + lora_embedding.set_mapping(punica_wrapper) + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, vocab_size), + device=device) + lora_mapping = LoRAMapping(index_mapping, + prompt_mapping, + is_prefill=True) + punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, + vocab_size, + lora_config.lora_extra_vocab_size) + + lora_embedding_compiled = torch.compile(lora_embedding, backend=custom_backend) + + embedding_compiled = torch.compile(embedding, backend=custom_backend) + + input = torch.cat(inputs) + torch._dynamo.mark_dynamic(input, 0) + + lr = embedding_compiled(input) + lora_result = lora_embedding_compiled(input) + +if __name__ == '__main__': + with _dist_init(): + test_embeddings() \ No newline at end of file diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 0ba2ce3617b67..3a88d1af57277 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -40,6 +40,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_baichuan_lora(baichuan_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, diff --git a/tests/lora/test_chatglm3_tp.py b/tests/lora/test_chatglm3_tp.py index 49a527b99ac16..85e4b8578ea11 100644 --- a/tests/lora/test_chatglm3_tp.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest @@ -45,6 +47,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py index 5ae705e474ec6..93bc619069570 100644 --- a/tests/lora/test_gemma.py +++ b/tests/lora/test_gemma.py @@ -31,6 +31,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail(current_platform.is_rocm(), reason="There can be output mismatch on ROCm") def test_gemma_lora(gemma_lora_files): diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index dfeac380951d8..035ff600bb410 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -1,5 +1,6 @@ from typing import List +import pytest import ray import vllm @@ -71,6 +72,14 @@ def generate_and_test(llm, sql_lora_files): print("removing lora") +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be 
promoted up to conftest.py to run for every + # test in a package + pass + + @fork_new_process_for_each_test def test_llama_lora(sql_lora_files): diff --git a/tests/lora/test_lora_bias_e2e.py b/tests/lora/test_lora_bias_e2e.py index c2520c847d873..2f7ab4128b553 100644 --- a/tests/lora/test_lora_bias_e2e.py +++ b/tests/lora/test_lora_bias_e2e.py @@ -28,6 +28,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("lora_bias", [True]) @pytest.mark.parametrize("fully_sharded", [True, False]) def test_lora_bias(lora_bias_files: str, lora_bias: bool, fully_sharded: bool): diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py index 78bf5a1617233..5f91f5710bc69 100644 --- a/tests/lora/test_minicpmv.py +++ b/tests/lora/test_minicpmv.py @@ -56,6 +56,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.xfail( current_platform.is_rocm(), reason="MiniCPM-V dependency xformers incompatible with ROCm") diff --git a/tests/lora/test_phi.py b/tests/lora/test_phi.py index 5a3fcb8d690d9..8656a4b6b870b 100644 --- a/tests/lora/test_phi.py +++ b/tests/lora/test_phi.py @@ -1,5 +1,7 @@ from typing import List +import pytest + import vllm from vllm.lora.request import LoRARequest @@ -46,6 +48,14 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + def test_phi2_lora(phi2_lora_files): # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI, # Otherwise, the lora-test will fail due to CUDA OOM. 
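Note: the new v1 kernel tests below (test_punica_sizes.py) drive lora_shrink / lora_expand /
lora_expand_slice through V1KernelMeta, which groups tokens by LoRA id before launching the
Triton kernels. A minimal sketch of how that grouping is derived from a per-token LoRA mapping,
mirroring the torch.sort / torch.unique / torch.cumsum calls in V1KernelMeta.prepare_tensors
later in this patch (the example tensor values here are purely illustrative):

    import torch

    # One LoRA id per token; -1 marks tokens that use no LoRA.
    token_lora_mapping = torch.tensor([0, 1, 0, -1, 1, 1], dtype=torch.int32)

    # A stable sort groups tokens of the same LoRA next to each other.
    _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
                                                     stable=True)

    # Which LoRA ids are active and how many tokens each one owns.
    lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
                                                 sorted=False,
                                                 return_counts=True)

    # Start offset of each LoRA's token block; index 0 stays 0 so that
    # lora_token_start_loc[i] together with num_tokens_per_lora[i] selects
    # the slice of token_indices_sorted_by_lora_ids owned by lora_ids[i].
    lora_token_start_loc = torch.zeros(lora_ids.size(0) + 1, dtype=torch.int64)
    lora_token_start_loc[1:] = torch.cumsum(num_tokens_per_lora, dim=0)

These four tensors are what the kernels receive via V1KernelMeta.meta_args.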
diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97d..c90baa5307116 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -13,6 +13,12 @@ from vllm.lora.ops.sgmv_expand import sgmv_expand from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink + +from vllm.lora.ops.v1.lora_expand import lora_expand +from vllm.lora.ops.v1.lora_shrink import lora_shrink +from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice +from vllm.lora.punica_wrapper.v1_gpu import V1KernelMeta + from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -378,3 +384,153 @@ def test_punica_expand_nslices( slice_offset += hidden_size assert_close(our_outputs, ref_outputs) + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("op_type", ["shrink", "expand"]) +@pytest.mark.parametrize("seq_length", [1, 128]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_v1_shrink_expand( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + scaling: float, + dtype: torch.dtype, + op_type: str, + seq_length: int, + seed: int, + device: str, +): + torch.set_default_device(device) + current_platform.seed_everything(seed) + + ( + inputs_tensor, + lora_weights, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + prompt_lora_mapping, + seq_len_tensor, + token_lora_mapping, + ) = generate_data( + batches, + hidden_size, + num_loras, + rank, + seq_length, + dtype, + op_type, + device, + ) + + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras = num_loras, max_num_tokens = seq_length * batches, device=device) + v1_meta.reset() + v1_meta.prepare_tensors(token_lora_mapping) + + if op_type == "shrink": + lora_shrink(inputs_tensor, + lora_weights, + our_out_tensor, + *v1_meta.meta_args, + scaling ) + else: + lora_expand(inputs_tensor, + lora_weights, + our_out_tensor, + *v1_meta.meta_args, + add_inputs=True) + + ref_torch_groupgemm( + ref_out_tensor, + inputs_tensor, + lora_weights, + prompt_lora_mapping, + seq_len_tensor, + batches, + scaling if op_type == "shrink" else 1.0, + op_type, + ) + if op_type == "shrink": + ref_out_tensor = ref_out_tensor.to(torch.float32) + assert_close(our_out_tensor, ref_out_tensor) + + +@pytest.mark.parametrize("batches", BATCHES) +@pytest.mark.parametrize("num_loras", NUM_LORA) +@pytest.mark.parametrize("rank", MAX_RANKS) +@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) +@pytest.mark.parametrize("nslices", [2, 3]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seq_length", [1, 128]) +@pytest.mark.parametrize("seed", SEED) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_v1_expand_nslices( + batches: int, + num_loras: int, + rank: int, + hidden_size: int, + nslices: int, + dtype: torch.dtype, + seq_length: int, + seed: int, + device: str, +): + + torch.set_default_device(device) + current_platform.seed_everything(seed) + + ( + inputs_tensor, + lora_weights_lst, + our_outputs, + ref_outputs, + b_seq_start_loc, + prompt_lora_mapping, + seq_len_tensor, + token_lora_mapping, + ) = generate_data_for_expand_nslices( + batches, + hidden_size, + num_loras, + rank, + 
seq_length, + dtype, + nslices, + device, + ) + + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras = num_loras, max_num_tokens = seq_length * batches, device=device) + v1_meta.reset() + v1_meta.prepare_tensors(token_lora_mapping) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + lora_expand_slice(inputs_tensor, + lora_weights, + our_outputs, + *v1_meta.meta_args, + slice_offset, + hidden_size, + add_inputs=True) + + ref_torch_groupgemm( + ref_outputs[:, slice_offset:slice_offset + hidden_size], + inputs_tensor, + lora_weights, + prompt_lora_mapping, + seq_len_tensor, + batches, + 1.0, + op_type="expand", + ) + + slice_offset += hidden_size + assert_close(our_outputs, ref_outputs) \ No newline at end of file diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 026269667b473..20532400bb1ac 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -68,6 +68,14 @@ def format_prompt_tuples(prompt): return generated_texts +@pytest.fixture(autouse=True) +def v1(run_with_both_engines_lora): + # Simple autouse wrapper to run both engines for each test + # This can be promoted up to conftest.py to run for every + # test in a package + pass + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, diff --git a/vllm/config.py b/vllm/config.py index ac767bbe14be4..d488a95f2e171 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3149,11 +3149,11 @@ def __post_init__(self): " Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - if self.lora_config is not None and self.compilation_config.level !=\ - CompilationLevel.NO_COMPILATION: - logger.warning("LoRA is not supported with `torch.compile` yet. " - "Disabling `torch.compile`.") - self.compilation_config.level = CompilationLevel.NO_COMPILATION + #if self.lora_config is not None and self.compilation_config.level !=\ + # CompilationLevel.NO_COMPILATION: + # logger.warning("LoRA is not supported with `torch.compile` yet. 
" + # "Disabling `torch.compile`.") + # self.compilation_config.level = CompilationLevel.NO_COMPILATION current_platform.check_and_update_config(self) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 85164c2165a3c..5f04d2f7adb10 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -231,8 +231,9 @@ def set_lora( self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = x > self.base_layer.org_vocab_size - 1 - embeddings_indices = self.punica_wrapper.embeddings_indices + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) + embeddings_indices = torch.narrow(self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, @@ -245,11 +246,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: full_output_org = full_output if full_output.ndim == 3: full_output = full_output.view( - full_output.shape[0] * full_output.shape[1], -1) + full_output.size(0) * full_output.size(1), -1) if full_lora_a_embeddings.ndim == 3: full_lora_a_embeddings = full_lora_a_embeddings.view( - full_lora_a_embeddings.shape[0] * - full_lora_a_embeddings.shape[1], + full_lora_a_embeddings.size(0) * + full_lora_a_embeddings.size(1), -1, ) self.punica_wrapper.add_lora_embedding(full_output, @@ -1028,7 +1029,19 @@ def _get_logits( logits = lm_head.linear_method.apply(lm_head, hidden_states) if embedding_bias is not None: logits += embedding_bias - logits = tensor_model_parallel_gather(logits) + + # TODO (varun) : Replace with base layer get_logits() + if self.use_gather: + # None may be returned for rank > 0 + logits = tensor_model_parallel_gather(logits) + else: + # Gather is not supported for some devices such as TPUs. + # Use all-gather instead. + # NOTE(woosuk): Here, the outputs of every device should not be None + # because XLA requires strict SPMD among all devices. Every device + # should execute the same operations after gathering the logits. + logits = tensor_model_parallel_all_gather(logits) + if logits is None: return None @@ -1065,19 +1078,19 @@ def _get_logits( lora_logits = lora_logits.mT indices_padded = self.punica_wrapper.sampler_indices_padded lora_logits = (lora_logits.reshape( - lora_logits.shape[0] * lora_logits.shape[1], - lora_logits.shape[2], + lora_logits.size(0) * lora_logits.size(1), + lora_logits.size(2), ).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"), posinf=float("inf"), neginf=float("-inf"))) # HPU needs special handling to prune out dummy samples. if current_platform.is_hpu(): - lora_logits = lora_logits[:logits.shape[0], :] + lora_logits = lora_logits[:logits.size(0), :] logits[:, self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + - lora_logits.shape[1]] = lora_logits + lora_logits.size(1)] = lora_logits # LogitsProcessorWithLoRA always using bgmv self.punica_wrapper.add_lora_logits(logits, hidden_states, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 5c0e4e5cbc636..74a7ba8ff9cb9 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -329,7 +329,8 @@ def __init__( self.long_lora_context: Optional[LongContextLoRAContext] = None self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, max_batches=self.max_num_seqs, - device=self.device) + device=self.device, + max_loras=lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. 
self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py new file mode 100644 index 0000000000000..a33e62a001ec7 --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand.py @@ -0,0 +1,234 @@ +import torch +import triton +import triton.language as tl +import math +from vllm.utils import direct_register_custom_op + +@triton.jit +def _lora_expand_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max rank + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_N_CTAS = tl.cdiv(N, BLOCK_N) + + pid = tl.program_id(0) + l = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + l) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + l) + lora_m_size = tl.load(num_tokens_per_lora + l) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_n[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < N) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + +@torch.inference_mode() +def _lora_expand( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + 1 + lora_token_start_loc: torch.Tensor, # max-loras + 2 + lora_ids: torch.Tensor, # max-loras + 1 + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A 
matrix grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be + processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] + is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + should process. + lora_ids: LoRA ids to process. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + M = inputs.size(0) + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + BLOCK_M = 16 + BLOCK_N = 64 + BLOCK_K = 16 + + NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA + NUM_N_CTAS = math.ceil(N / BLOCK_N) + MAX_LORAS = lora_ids.size(0) + + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) + lora_n_stride = lora_b_weights.stride(1) + lora_k_stride = lora_b_weights.stride(2) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + grid = ( + MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, + ) + + _lora_expand_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + return + +def lora_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + add_inputs: bool = False, +) -> None: + return + +try: + direct_register_custom_op( + op_name="lora_expand", + op_func=_lora_expand, + mutates_args=["output_tensor"], + fake_impl=lora_expand_fake, + ) + lora_expand = torch.ops.vllm.lora_expand + +except AttributeError: + lora_expand = _lora_expand \ No newline at end of file diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py new file mode 100644 index 0000000000000..4690b6f59dfed --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -0,0 +1,244 @@ +import torch +import triton +import triton.language as tl +import math +from vllm.utils import direct_register_custom_op + +@triton.jit +def _lora_expand_slice_kernel( + input_ptr, + lora_ptr, + out_ptr, + M, + N, + K, + token_indices_sorted_by_lora_ids, + 
num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max rank + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + slice_offset, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, +): + + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_N_CTAS = tl.cdiv(N, BLOCK_N) + + pid = tl.program_id(0) + l = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + l) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + l) + lora_m_size = tl.load(num_tokens_per_lora + l) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < (N + slice_offset)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + + +@torch.inference_mode() +def _lora_expand_slice( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'b weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be + processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] + is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + should process. 
+ lora_ids: LoRA ids to process. + slice_offset (int): output_tensor's offset + slice_size (int): current output_tensor's size + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + assert slice_size == lora_b_weights.size(-2) + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weights.size(1) == 1 + lora_b_weights = lora_b_weights.squeeze(dim=1) + else: + assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) + + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + + M = inputs.size(0) + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + BLOCK_M = 16 + BLOCK_N = 64 + BLOCK_K = 16 + + NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA + NUM_N_CTAS = math.ceil(N / BLOCK_N) + MAX_LORAS = lora_ids.size(0) + + EVEN_K = K % BLOCK_K == 0 + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) + lora_n_stride = lora_b_weights.stride(1) + lora_k_stride = lora_b_weights.stride(2) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + grid = ( + MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, + ) + + _lora_expand_slice_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + M, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + slice_offset, + EVEN_K, + ADD_INPUTS, + CAST_TYPE, + BLOCK_M, + BLOCK_N, + BLOCK_K, + ) + return + +def lora_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + return + +try: + direct_register_custom_op( + op_name="lora_expand_slice", + op_func=_lora_expand_slice, + mutates_args=["output_tensor"], + fake_impl=lora_expand_slice_fake, + ) + lora_expand_slice = torch.ops.vllm.lora_expand_slice + +except AttributeError: + lora_expand_slice = _lora_expand_slice \ No newline at end of file diff --git a/vllm/lora/ops/v1/lora_expand_slices.py b/vllm/lora/ops/v1/lora_expand_slices.py new file mode 100644 index 0000000000000..1bd20865a3aca --- /dev/null +++ b/vllm/lora/ops/v1/lora_expand_slices.py @@ -0,0 +1,276 @@ +import torch +import triton +import triton.language as tl +import math + +from vllm.utils import direct_register_custom_op + +@triton.jit +def _lora_expand_slices_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + lora_seq_indices, + lora_seq_counts, + lora_seq_start_loc, + lora_ids, + xm_stride, + xk_stride, # 1 + l0_stride, # hidden_size * max_rank * num_loras + l1_stride, # hidden_size*max_rank + lora_n_stride, + lora_k_stride, + 
cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + NUM_SLICES: tl.constexpr, + MAX_LORAS: tl.constexpr, + NUM_M_CTAS: tl.constexpr, + NUM_N_CTAS: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, +): + pid = tl.program_id(0) + cta_s = pid // (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS) + cta_l = (pid // (NUM_M_CTAS * NUM_N_CTAS)) % MAX_LORAS + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + + lora_id = tl.load(lora_ids + cta_l) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + if cta_l == 0: + lora_m_indices_start = tl.cast(0, tl.int32) + else: + lora_m_indices_start = tl.load(lora_seq_start_loc + cta_l - 1) + lora_m_size = tl.load(lora_seq_counts + cta_l) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = lora_seq_indices + lora_m_indices_start + cta_m_offset + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.arange(0, BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = (lora_ptr + l0_stride * cta_s + l1_stride * lora_id + + offset_k[:, None] * lora_k_stride + rbn[None, :] * lora_n_stride) + + accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + for k in range(tl.cdiv(K, BLOCK_K)): + if EVEN_K: + tiled_a = tl.load(a_ptr) + tiled_b = tl.load(b_ptr) + else: + tiled_a = tl.load(a_ptr, + mask=offset_k[None, :] < K - k * BLOCK_K, + other=0) + tiled_b = tl.load(b_ptr, + mask=offset_k[:, None] < K - k * BLOCK_K, + other=0) + if CAST_TYPE: + tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + accumulator += tl.dot( + tiled_a, + tiled_b, + ) + a_ptr += BLOCK_K * xk_stride + b_ptr += BLOCK_K * lora_k_stride + + + slice_offset = cta_s * N + tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + offset_cm = tl.arange(0, BLOCK_M) + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride + + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < (slice_offset + N)) + if ADD_INPUTS: + tiled_out = tl.load(c_ptr, mask=c_mask) + tiled_c += tiled_out + tl.store(c_ptr, tiled_c, mask=c_mask) + +@torch.inference_mode() +def _lora_expand_slices( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_seq_indices: torch.Tensor, + lora_seq_counts: torch.Tensor, + lora_seq_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + add_inputs: bool = False, +) -> None: + """_summary_ + + Args: + inputs (torch.Tensor): input tensor + lora_b_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + + token_lora_mapping_tensor: Each token's lora id as it appears in the A matrix. + + lora_seq_indices: sorted lora-token mapping. Tokens of the same lora appear next to eachother. + This is used so a thread block knows what tokens to put next to eachother when constructing a matrix block. + Essentially, + _, lora_seq_indices = torch.sort(token_lora_mapping, stable=True) + + lora_seq_counts: number of tokens per lora id. 
essentially, + lora_ids, lora_seq_counts = torch.unique(indices, + sorted=False, + return_counts=True) + + lora_seq_start_loc: start index of each lora id in lora_seq_indices. essentially, + lora_seq_start_loc = torch.cumsum(lora_seq_counts, dim = 0) + + lora_ids : Set of lora ids in order according to lora_seq_counts, and lora_seq_indices + lora_ids, lora_seq_counts = torch.unique(indices, + sorted=False, + return_counts=True) + + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] + assert lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_b_weights.size(-1) + num_slices = lora_b_weights.size(0) + assert output_tensor.size(1) == lora_b_weights.size(-2) * num_slices + assert inputs.is_contiguous() + assert output_tensor.is_contiguous() + + assert lora_b_weights.ndim == 4 # (nslices, lora_num, hidden-size, rank) + assert lora_b_weights.is_contiguous() + + # TODO tuning this config + N = lora_b_weights.size(-2) + K = lora_b_weights.size(-1) + NUM_SLICES = lora_b_weights.size(0) + M = inputs.size(0) + MAX_LORAS = lora_ids.size(0) + + BLOCK_M = 32 + BLOCK_N = 32 + BLOCK_K = 16 + EVEN_K = K % BLOCK_K == 0 + NUM_M_CTAS = math.ceil(M / BLOCK_M) + NUM_N_CTAS = math.ceil(N / BLOCK_N) + + ADD_INPUTS = add_inputs + CAST_TYPE = False + if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + torch.float16, + torch.bfloat16, + ]: + CAST_TYPE = True + + grid = ( + NUM_SLICES * MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, + ) + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_b_weights.stride(0) # slice stride + l1_stride = lora_b_weights.stride(1) # lora stride + lora_n_stride = lora_b_weights.stride(2) + lora_k_stride = lora_b_weights.stride(3) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + #print (f"lora seq indices : {lora_seq_indices.dtype} {lora_seq_indices}") + #print (f"lora seq counts : {lora_seq_counts.dtype} {lora_seq_counts}") + #print (f"lora seq start loc : {lora_seq_start_loc.dtype} {lora_seq_start_loc}") + #print (f"lora ids : {lora_ids.dtype} {lora_ids}") + #print (f"num loras : {NUM_LORAS}") + #print (f"num slices : {NUM_SLICES}") + #print (f"N : {N}") + #print (f"K : {K}") + #print (f"A : {inputs.dtype} {inputs.shape}") + #print (f"B : {lora_b_weights.dtype} {lora_b_weights.shape}") + #print (f"C : {output_tensor.dtype} {output_tensor.shape}") + #print (f"A m k strides : {xm_stride} {xk_stride}") + #print (f"B k n strides : {lora_k_stride} {lora_n_stride}") + #print (f"C m n strides : {cm_stride} {cn_stride}") + + _lora_expand_slices_kernel[grid]( + inputs, + lora_b_weights, + output_tensor, + N, + K, + lora_seq_indices, + lora_seq_counts, + lora_seq_start_loc, + lora_ids, + xm_stride, + xk_stride, + l0_stride, + l1_stride, + lora_n_stride, + lora_k_stride, + cm_stride, + cn_stride, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + NUM_SLICES, + MAX_LORAS, + NUM_M_CTAS, + NUM_N_CTAS, + ADD_INPUTS, + CAST_TYPE, + ) + return + + +try: + lora_expand_slices = torch.library.custom_op("lora::v1::lora_expand_slices", + _lora_expand_slices, + mutates_args=["output_tensor"]) +except AttributeError: + lora_expand_slices = _lora_expand_slices + +def lora_expand_slices_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_seq_indices: torch.Tensor, + lora_seq_counts: torch.Tensor, + lora_seq_start_loc: torch.Tensor, + lora_ids: 
torch.Tensor, + add_inputs: bool = False, +) -> None: + return + +try: + direct_register_custom_op( + op_name="lora_expand_slices", + op_func= _lora_expand_slices, + mutates_args=["output_tensor"], + fake_impl=lora_expand_slices_fake, + ) + lora_expand_slices = torch.ops.vllm.lora_expand_slices + +except AttributeError: + lora_expand = _lora_expand_slices \ No newline at end of file diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py new file mode 100644 index 0000000000000..0057b066a58ae --- /dev/null +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -0,0 +1,225 @@ +import torch +import triton +import triton.language as tl +import math +from vllm.utils import direct_register_custom_op + + +@triton.jit +def _lora_shrink_kernel( + input_ptr, + lora_ptr, + out_ptr, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + BLOCK_K : tl.constexpr, + EVEN_K : tl.constexpr, + SPLIT_K : tl.constexpr, + NUM_M_CTAS : tl.constexpr, + NUM_N_CTAS : tl.constexpr, + ): + + pid = tl.program_id(0) + l = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + cta_sk = tl.program_id(1) + + lora_id = tl.load(lora_ids + l) + if lora_id == -1: + # early exit for the no-lora case. + return + + # lora m indices offsets + lora_m_indices_start = tl.load(lora_token_start_loc + l) + lora_m_size = tl.load(num_tokens_per_lora + l) + + cta_m_offset = cta_m * BLOCK_M + if cta_m_offset >= lora_m_size: + # early exit CTA + return + + cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) + + offset_k = tl.max_contiguous(BLOCK_K * cta_sk + tl.arange(0, BLOCK_K), BLOCK_K) + + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + rm = tl.load(cta_lora_seq_indices + offset_rm) + a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride + + offset_n = tl.max_contiguous((cta_n * BLOCK_N)+ tl.arange(0, BLOCK_N), BLOCK_N) + rn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + b_ptr = lora_ptr + lora_id * l0_stride + rn[None, :] * lora_n_stride + offset_k[:, None] * lora_k_stride + + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) + max_k = tl.cdiv(K, BLOCK_K * SPLIT_K) + for k in range(0, max_k): + if EVEN_K: + b_tile = tl.load(b_ptr) + a_tile = tl.load(a_ptr) + else: + b_mask = offset_k[:, None] < K + b_tile = tl.load(b_ptr, mask=b_mask, other=0.0) + + a_mask = offset_k[None, :] < K + a_tile = tl.load(a_ptr, mask=a_mask, other=0.0) + + # TODO (varun) : When a_tile and b_tile are float16s the output is also float16. this can + # lead to infs in the output. 
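+        # acc is the fp32 accumulator declared before this loop; partial tile
+        # products are summed into it and scaled by `scaling` after the loop.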
+ acc += tl.dot(a_tile, b_tile) + + + a_ptr += BLOCK_K * SPLIT_K * xk_stride + b_ptr += BLOCK_K * SPLIT_K * lora_k_stride + offset_k += BLOCK_K * SPLIT_K + + acc *= scaling + + offset_cm = tl.arange(0, BLOCK_M) + c_ptr = out_ptr + rm[:, None] * cm_stride + offset_n[None, :] * cn_stride + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_n[None, :] < N) + if SPLIT_K == 1: + tl.store(c_ptr, acc, mask=c_mask) + else: + tl.atomic_add(c_ptr, acc, mask=c_mask) + +@torch.inference_mode() +def _lora_shrink( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras + scaling: float, +) -> None: + """ + Args: + inputs (torch.Tensor): input tensor + lora_a_weights (torch.Tensor): lora'a weight + output_tensor (torch.Tensor): output tensor + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be + processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] + is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] + identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + should process. + lora_ids: LoRA ids to process. + add_inputs (bool, optional): Defaults to False, adds the final lora + results to the output. + """ + + assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype in [torch.float16, torch.bfloat16] + assert lora_a_weights.dtype in [ + torch.float16, + torch.bfloat16, + ] + assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.is_contiguous() + + if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) + assert lora_a_weights.size(1) == 1 + lora_a_weights = lora_a_weights.squeeze(dim=1) + else: + assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) + assert lora_a_weights.is_contiguous() + assert output_tensor.is_contiguous() + assert token_indices_sorted_by_lora_ids.size(0) == inputs.size(0) + assert num_tokens_per_lora.size(0) == lora_ids.size(0) + assert lora_token_start_loc.size(0) == lora_ids.size(0) + 1 + + xm_stride = inputs.stride(0) + xk_stride = inputs.stride(1) + l0_stride = lora_a_weights.stride(0) + lora_k_stride = lora_a_weights.stride(2) + lora_n_stride = lora_a_weights.stride(1) + cm_stride = output_tensor.stride(0) + cn_stride = output_tensor.stride(1) + + # TODO tuning this config + M = inputs.size(0) # num tokens + N = lora_a_weights.size(-2) + K = lora_a_weights.size(-1) + MAX_LORAS = lora_ids.size(0) + BLOCK_M = 32 + BLOCK_N = 16 + BLOCK_K = 16 + SPLIT_K = 64 + EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 + NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA + NUM_N_CTAS = math.ceil(N / BLOCK_N) + + grid = ( + MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, + SPLIT_K, + ) + + _lora_shrink_kernel[grid]( + inputs, + lora_a_weights, + output_tensor, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M, + BLOCK_N, + BLOCK_K, + EVEN_K, + SPLIT_K, + NUM_M_CTAS, + NUM_N_CTAS, + ) + return + +def lora_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + 
token_indices_sorted_by_lora_ids: torch.Tensor, + num_tokens_per_lora: torch.Tensor, + lora_token_start_loc: torch.Tensor, + lora_ids: torch.Tensor, + scaling: float, +)-> None: + return + +try: + direct_register_custom_op( + op_name="lora_shrink", + op_func=_lora_shrink, + mutates_args=["output_tensor"], + fake_impl=lora_shrink_fake, + ) + lora_shrink = torch.ops.vllm.lora_shrink + +except AttributeError: + lora_shrink = _lora_shrink \ No newline at end of file diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index b9ec0c4bc6323..e86201909896b 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -167,7 +167,7 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, self.is_prefill = False self.no_lora = False - def _update_base_metadata( + def update_base_metadata( self, mapping: "LoRAMapping", lora_index_to_id: List[Optional[int]], @@ -192,16 +192,16 @@ def _update_base_metadata( self.device, long_lora_context, ) - self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices) - self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) - self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( - sampler_indices_padded) + self._token_lora_indices[:base_indices.size(0)].copy_(base_indices, non_blocking=True) + self._sampler_indices[:sampler_indices.size(0)].copy_(sampler_indices, non_blocking=True) + self._sampler_indices_padded[:sampler_indices_padded.size(0)].copy_( + sampler_indices_padded, non_blocking=True) self._embeddings_indices[:embeddings_indices. - shape[0], :embeddings_indices.shape[1]].copy_( - embeddings_indices) + size(0), :embeddings_indices.size(1)].copy_( + embeddings_indices, non_blocking=True) if long_lora_offsets_tensor is not None: - self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_( - long_lora_offsets_tensor) + self._long_lora_indices[:long_lora_offsets_tensor.size(0)].copy_( + long_lora_offsets_tensor, non_blocking=True) else: self._long_lora_indices.zero_() self.indices_len[:] = indices_len @@ -212,10 +212,10 @@ def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None: batch_size, max_length, token_nums, no_lora) = compute_meta(token_lora_tensor) - self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_( + self._seq_start_locs[:b_seq_start_tensor.size(0)].copy_( b_seq_start_tensor) - self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor) - self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_( + self._seq_lengths[:seq_length_tensor.size(0)].copy_(seq_length_tensor) + self._lora_indices_per_batch[:lora_indices_tensor.size(0)].copy_( lora_indices_tensor) self.batch_size = batch_size self.max_length = max_length @@ -239,14 +239,14 @@ def _apply_bias( where n is number of slices """ org_output = output - output = output.view(-1, output.shape[-1]) + output = output.view(-1, output.size(-1)) indices = indices.view(-1) offset_left = 0 for slice_idx, slice in enumerate(output_slices): bias = lora_bias_stacked[slice_idx] if bias is not None: - bias = bias.view(-1, bias.shape[-1]) + bias = bias.view(-1, bias.size(-1)) bias = bias[indices] bias[indices == -1] = 0 output[:, offset_left:offset_left + slice] += bias @@ -328,7 +328,7 @@ def update_metadata( long_lora_context: Optional["LongContextLoRAContext"] = None, **kwargs): - self._update_base_metadata(mapping, lora_index_to_id, max_loras, + self.update_base_metadata(mapping, lora_index_to_id, max_loras, vocab_size, 
extra_vocab_size, long_lora_context) if mapping.is_prefill: diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index cd64878d95ae3..d8ca03dff8982 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,5 +1,6 @@ from vllm.platforms import current_platform from vllm.utils import print_info_once +import vllm.envs as envs from .punica_base import PunicaWrapperBase @@ -7,9 +8,14 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: if current_platform.is_cuda_alike(): # Lazy import to avoid ImportError - from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU - print_info_once("Using PunicaWrapperGPU.") - return PunicaWrapperGPU(*args, **kwargs) + if envs.VLLM_USE_V1: + from vllm.lora.punica_wrapper.v1_gpu import V1LoRAGPU + print_info_once("Using V1LoRAGPU.") + return V1LoRAGPU(*args, **kwargs) + else: + from vllm.lora.punica_wrapper.punica_gpu import PunicaWrapperGPU + print_info_once("Using PunicaWrapperGPU.") + return PunicaWrapperGPU(*args, **kwargs) elif current_platform.is_hpu(): # Lazy import to avoid ImportError from vllm.lora.punica_wrapper.punica_hpu import PunicaWrapperHPU diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py new file mode 100644 index 0000000000000..eb4f55b2a93a3 --- /dev/null +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -0,0 +1,403 @@ +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union, final, List +from dataclasses import dataclass + +import torch + +from vllm.lora.layers import LoRAMapping +from vllm.triton_utils import HAS_TRITON + +if HAS_TRITON: + from vllm.lora.ops.v1.lora_expand import lora_expand + from vllm.lora.ops.v1.lora_shrink import lora_shrink + from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice + #from vllm.lora.ops.v1.lora_expand_slices import lora_expand_slices + +from .punica_base import PunicaWrapperBase + +if TYPE_CHECKING: + # avoid circuit import + from vllm.lora.models import LongContextLoRAContext + +@dataclass +class V1KernelMeta: + token_indices_sorted_by_lora_ids: torch.Tensor + active_lora_ids: torch.Tensor + num_tokens_per_lora: torch.Tensor + lora_token_start_loc: torch.Tensor + + @staticmethod + def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1KernelMeta": + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + + # +1 because "no-lora" is also a possibility + # example: let max_loras be 3, active_lora_ids of [-1, 0, 1, 2] + # is a possibility. + active_lora_ids = torch.empty(max_loras + 1, + dtype=torch.int32, + device=device) + + # using running example, [3, 10, 5, 2] is a possibility. + num_tokens_per_lora = torch.zeros(max_loras + 1, + dtype=torch.int32, + device=device) + + # +2 for this because, the first index is always 0 + # for example: let max loras be 3, then lora_token_start_loc, + # can be [0, 3, 13, 18, 20]. 
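+        # Continuing the running example: lora_token_start_loc is just a
+        # zero-prefixed cumulative sum of num_tokens_per_lora ([3, 10, 5, 2]
+        # -> [0, 3, 13, 18, 20]), so the tokens belonging to lora_ids[i] are
+        # token_indices_sorted_by_lora_ids[
+        #     lora_token_start_loc[i]:lora_token_start_loc[i] +
+        #     num_tokens_per_lora[i]].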
+ lora_token_start_loc = torch.zeros(max_loras + 2, + dtype=torch.int32, + device=device) + return V1KernelMeta(token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids = active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc = lora_token_start_loc) + + def reset(self): + self.active_lora_ids.fill_(-1) + self.num_tokens_per_lora.fill_(0) + self.lora_token_start_loc.fill_(0) + + def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: + num_tokens = token_lora_mapping.size(0) + # token_indices_sorted_by_lora_ids + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, stable=True) + # start gpu transfer + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(token_indices_sorted_by_lora_ids, + non_blocking=True) + + # active_lora_ids, num_tokens_per_lora + lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, + sorted=False, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(num_tokens_per_lora, non_blocking=True) + + # lora_token_start_loc + lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) + self.lora_token_start_loc[1: 1 + lora_token_start_loc.size(0)].copy_(lora_token_start_loc, + non_blocking=True) + + def meta_args(self, num_tokens: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return (self.token_indices_sorted_by_lora_ids[:num_tokens], + self.num_tokens_per_lora, + self.lora_token_start_loc, + self.active_lora_ids) + +@final +class V1LoRAGPU(PunicaWrapperBase): + """ + TODO (varun) + _summary_ + + Args: + PunicaWrapperBase (_type_): _description_ + """ + + + def __init__(self, max_num_batched_tokens: int, max_batches: int, + device: Union[torch.device, str], **kwargs): + PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, + device) + self.max_loras = kwargs['max_loras'] + self.token_mapping_v1_meta = V1KernelMeta.make(self.max_loras, max_num_batched_tokens, device=device) + self.prompt_mapping_v1_meta = V1KernelMeta.make(self.max_loras, max_batches, device=device) + + def update_metadata( + self, + mapping: LoRAMapping, + lora_index_to_id: List[Optional[int]], + max_loras: int, + vocab_size: int, + extra_vocab_size: int, + long_lora_context: Optional["LongContextLoRAContext"] = None, + **kwargs): + + # TODO (varun) : Make no_lora case work with torch compile + #self.no_lora = all([x == -1 for x in mapping.prompt_mapping]) + #if self.no_lora: + # # no update required + # return + + print (f"lora update metadata ...") + + num_tokens: int = len(mapping.index_mapping) + + self.token_mapping_v1_meta.reset() + self.prompt_mapping_v1_meta.reset() + + self.update_base_metadata(mapping, lora_index_to_id, max_loras, + vocab_size, extra_vocab_size, + long_lora_context) + + self.token_mapping_v1_meta.prepare_tensors(self.token_lora_indices) + self.prompt_mapping_v1_meta.prepare_tensors(self.sampler_indices) + + def _shrink( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + scale: float, + ): + lora_shrink(x, + w_t_all, + y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + scale) + + def _expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + add_inputs: bool, + ): + lora_expand(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + add_inputs) + + def _expand_slice( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + 
y_slice_size: Optional[int], + add_inputs: bool, + ): + #No LoRA request, so return directly + if self.no_lora: + return + + lora_expand_slice(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), + y_offset, + y_slice_size, + add_inputs) + + def _apply_expand( + self, + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + y_offset: Optional[int], + y_slice_size: Optional[int], + add_inputs: bool = True, + ): + """ + Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` + computation, which is suitable for the + GEMM of lora'b. + """ + self._expand_slice(y, x, w_t_all, y_offset, y_slice_size, add_inputs) + + def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, scale: float): + """ + Perform the ` y+=x@w_t_all` computation, which is suitable for the + GEMM of lora'a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + """ + y_org = y + y = y.view(-1, y.size(-1)) + self._shrink(y, x, w_t_all, scale) + y = y.view_as(y_org) + + def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], + x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], + scale: float, **kwargs): + """ + Performs GEMM for multiple slices of lora_a. + When `is_prefill is` true, it indicates that it is currently the + prefill stage, and the `_shrink_prefill` function should be called. + Otherwise, it is the decode stage, and the _shrink_decode function + should be called. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += (x @ lora_a_stacked[i]) * scale + + Args: + y (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Output tensors + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weights + scale (float): Scaling factor for the operation + """ + + x = x.view(-1, x.size(-1)) + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], + scale) + + def add_expand(self, + y: torch.Tensor, + x: Union[Tuple[torch.Tensor, ...], torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + output_slices: Tuple[int, ...], + offset_start: int = 0, + add_inputs=True, + **kwargs) -> None: + """ + Performs GEMM and bias addition for multiple slices of lora_b. + + Semantics: + for i in range(len(lora_b_stacked)): + slice = output_slices[i] + y[:, offset:offset+slice] += x[i] @ lora_b_stacked[i] + + lora_bias_stacked[i] + offset += slice + + Args: + y (torch.Tensor): Output tensor. + x (Union[Tuple[torch.Tensor, ...], torch.Tensor]): Input tensors + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): + bias's weight + output_slices (Tuple[int, ...]): Every slice's size + add_inputs (bool): Defaults to True. 
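+
+        Example (illustrative): with offset_start == 0 and
+        output_slices == (q, k, v), e.g. for a fused QKV projection,
+            y[:, 0:q]         += x[0] @ lora_b_stacked[0]
+            y[:, q:q+k]       += x[1] @ lora_b_stacked[1]
+            y[:, q+k:q+k+v]   += x[2] @ lora_b_stacked[2]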
+ """ + y_org = y + y = y.view(-1, y.size(-1)) + offset_left = offset_start + if lora_bias_stacked is not None: + self._apply_bias(self._token_lora_indices[:x.size(0)], y, output_slices, + lora_bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_left, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_left += output_slices[slice_idx] + y = y.view_as(y_org) + + def add_lora_embedding(self, + y: torch.Tensor, + x: torch.Tensor, + lora_b_stacked: torch.Tensor, + add_inputs: bool = True, + **kwargs) -> None: + """ + Applies lora specifically for VocabParallelEmbeddingWithLoRA. + + Semantics: + y += x @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_b_stacked (torch.Tensor): lora_b's weights. + add_inputs (bool): Default to True. + """ + + # Embedding layer only need expand op + self._expand(y, x, lora_b_stacked, add_inputs) + + def add_lora_linear(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + lora_bias_stacked: Optional[Tuple[torch.Tensor, ...]], + scale: float, + output_slices: Tuple[int, ...], + *, + buffer: Optional[Tuple[torch.Tensor, ...]] = None, + **kwargs) -> None: + """ + Applicable to linear-related lora. + + Semantics: + for i in range(len(lora_a_stacked)): + y[i] += ( + x[i].unsqueeze(0) + @ lora_a_stacked[indices[i], layer_idx, :, :] + @ lora_b_stacked[indices[i], layer_idx, :, :] + * scale + ).squeeze(0)+lora_bias_stacked[i] + + Args: + y (torch.Tensor): Output tensor. Will be changed in-place. + x (torch.Tensor): Input tensor + lora_a_stacked (Tuple[torch.Tensor, ...]): lora_a's weight. + lora_b_stacked (Tuple[torch.Tensor, ...]): lora_b's weight. + lora_bias_stacked (Optional[Tuple[torch.Tensor, ...]]): lora's bias. + scale (float): Scaling factor. + output_slices (Tuple[int, ...]): Every slice's size. + buffer (Optional[Tuple[torch.Tensor, ...]]): Defaults to None. + """ + + assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) + if lora_bias_stacked is not None: + assert len(lora_bias_stacked) == len(output_slices) + y = self._apply_bias(self._token_lora_indices[:x.size(0)], y, output_slices, + lora_bias_stacked) + + if buffer is None: + r = lora_b_stacked[0].size(-1) + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = tuple( + torch.zeros( + (x.size(0), r), dtype=torch.float32, device=x.device) + for _ in range(len(output_slices))) + self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) + self.add_expand(y, + buffer, + lora_b_stacked, + None, + output_slices, + add_inputs=True, + **kwargs) + + def add_lora_logits(self, + y: torch.Tensor, + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + scale, + *, + buffer: Optional[torch.Tensor] = None, + **kwargs) -> None: + """ + Applies lora specifically for LogitsProcessorWithLoRA. + + Semantics: + buffer = (x @ lora_a_stacked) * scale + y += buffer @ lora_b_stacked + + Args: + y (torch.Tensor): Output tensor. + x (torch.Tensor): Input tensor. + lora_a_stacked (torch.Tensor): lora_a's weights. + lora_b_stacked (torch.Tensor):lora_b's weights. + scale (float): Scaling factor. + buffer (Optional[torch.Tensor]):Default to None. 
+ """ + y_org = y + y = y.view(-1, y.size(-1)) + x = x.view(-1, x.size(-1)) + r = lora_b_stacked.size(-1) + if buffer is None: + # We set the buffer to be float32 by default ,refer to: + # https://github.com/triton-lang/triton/issues/1387 + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + + lora_shrink(x, lora_a_stacked, buffer, *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) + lora_expand(buffer, lora_b_stacked, y, + *self.prompt_mapping_v1_meta.meta_args(x.size(0)), + add_inputs=True) + + y = y.view_as(y_org) diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py index 34d65ed51ef3f..de615a5348e1f 100644 --- a/vllm/model_executor/layers/activation.py +++ b/vllm/model_executor/layers/activation.py @@ -60,7 +60,7 @@ class SiluAndMul(CustomOp): def forward_native(self, x: torch.Tensor) -> torch.Tensor: """PyTorch-native implementation equivalent to forward().""" - d = x.shape[-1] // 2 + d = x.size(-1) // 2 return F.silu(x[..., :d]) * x[..., d:] def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index 43ea4eb5a4d1a..43c9424801222 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -46,7 +46,7 @@ def forward_native( x = x + residual.to(torch.float32) residual = x.to(orig_dtype) - hidden_size = x.shape[-1] + hidden_size = x.size(-1) if hidden_size != self.hidden_size: raise ValueError("Expected hidden_size to be " f"{self.hidden_size}, but found: {hidden_size}") diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index 117fe086e5e87..eaf8e356ed416 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -129,7 +129,7 @@ def forward_native( if offsets is not None: positions = positions + offsets positions = positions.flatten() - num_tokens = positions.shape[0] + num_tokens = positions.size(0) cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py index 08e7c0fd4dc9b..140d6d00b705d 100644 --- a/vllm/v1/core/scheduler.py +++ b/vllm/v1/core/scheduler.py @@ -5,6 +5,7 @@ from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.multimodal.base import PlaceholderRange from vllm.sampling_params import SamplingParams @@ -32,8 +33,6 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config self.lora_config = lora_config - # TODO: Support LoRA. - assert lora_config is None, "V1 does not support LoRA yet." # Scheduling constraints. self.max_num_running_reqs = self.scheduler_config.max_num_seqs @@ -173,6 +172,14 @@ def schedule(self) -> "SchedulerOutput": self.encoder_cache_manager.allocate(request, i) encoder_budget = new_encoder_budget + # Record the LoRAs in scheduled_running_reqs + requested_loras: Set[int] = set() + if self.lora_config: + requested_loras = set( + req.lora_request.lora_int_id for req in scheduled_running_reqs + if req.lora_request and req.lora_request.lora_int_id > 0) + assert len(requested_loras) <= self.lora_config.max_loras + # Next, schedule the WAITING requests. 
if not preempted_reqs: while self.waiting: @@ -184,6 +191,17 @@ def schedule(self) -> "SchedulerOutput": break request = self.waiting[0] + + # Check that adding the request still respects the max_loras + # constraint. + if self.lora_config and request.lora_request: + req_lora_id = request.lora_request.lora_int_id + if len(requested_loras) == self.lora_config.max_loras and ( + req_lora_id not in requested_loras): + # cannot schedule + break + requested_loras.add(req_lora_id) + # Get already-cached tokens. computed_blocks = self.kv_cache_manager.get_computed_blocks( request) @@ -521,6 +539,7 @@ class NewRequestData: sampling_params: SamplingParams block_ids: List[int] num_computed_tokens: int + lora_request: Optional[LoRARequest] @classmethod def from_request( @@ -539,6 +558,7 @@ def from_request( sampling_params=request.sampling_params, block_ids=block_ids, num_computed_tokens=num_computed_tokens, + lora_request=request.lora_request, ) diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index cc0c7ea23469a..51a83847330d3 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -22,6 +22,8 @@ class DetokenizerRequest: stop: List[str] include_stop_str_in_output: bool + lora_request: Optional[LoRARequest] + @dataclass class EngineCoreRequest: diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index ba2b8377759d6..bbe1d8f4c2ac5 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -314,8 +314,7 @@ async def get_tokenizer( self, lora_request: Optional[LoRARequest] = None, ) -> AnyTokenizer: - assert lora_request is None - return self.detokenizer.tokenizer + return self.detokenizer.get_tokenizer(lora_request) async def is_tracing_enabled(self) -> bool: return False diff --git a/vllm/v1/engine/detokenizer.py b/vllm/v1/engine/detokenizer.py index 02f34e2b54dd5..af8bc151500d4 100644 --- a/vllm/v1/engine/detokenizer.py +++ b/vllm/v1/engine/detokenizer.py @@ -3,11 +3,12 @@ from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger +from vllm.lora.request import LoRARequest from vllm.outputs import RequestOutput from vllm.sampling_params import RequestOutputKind from vllm.transformers_utils.detokenizer_utils import ( AnyTokenizer, convert_prompt_ids_to_tokens, detokenize_incrementally) -from vllm.transformers_utils.tokenizer import get_tokenizer +from vllm.transformers_utils.tokenizer import get_lora_tokenizer, get_tokenizer from vllm.v1.engine import DetokenizerRequest, EngineCoreOutput logger = init_logger(__name__) @@ -197,16 +198,25 @@ def __init__(self, tokenizer_mode: str = "auto", trust_remote_code: bool = False, revision: Optional[str] = None): - # TODO: once we support LoRA, we should should pass the tokenizer - # here. We currently have two copies (this + in the LLMEngine). - self.tokenizer = get_tokenizer(tokenizer_name=tokenizer_name, - tokenizer_mode=tokenizer_mode, - trust_remote_code=trust_remote_code, - revision=revision) + # per-request tokenizers, like in LoRA, are created in + # add_request. All other requests use the base_tokenizer. 
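+        # get_tokenizer() resolves the tokenizer for a request: a
+        # LoRA-specific tokenizer (via get_lora_tokenizer) when the
+        # request carries a lora_request, otherwise this shared base
+        # tokenizer.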
+ self._base_tokenizer = get_tokenizer( + tokenizer_name=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=trust_remote_code, + revision=revision) # Request id -> IncrementalDetokenizer self.request_states: Dict[str, IncrementalDetokenizer] = {} + def get_tokenizer(self, + lora_request: Optional[LoRARequest] = None + ) -> AnyTokenizer: + if lora_request: + return get_lora_tokenizer(lora_request) + else: + return self._base_tokenizer + def is_request_active(self, request_id: str): return request_id in self.request_states @@ -233,8 +243,9 @@ def add_request( assert (request.request_id not in self.request_states) + req_tokenizer = self.get_tokenizer(request.lora_request) request_state = IncrementalDetokenizer.from_new_request( - self.tokenizer, request) + req_tokenizer, request) self.request_states[request.request_id] = request_state def step( diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 6ee8732bc902c..d3be673044d28 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -50,7 +50,7 @@ def __init__( self.mm_hasher = MMHasher() # TODO: run in an ThreadpoolExecutor or BackgroundProcess. - # This ideally should releases the GIL, so we should not block the + # This ideally should release the GIL, so we should not block the # asyncio loop while this is running. def process_inputs( self, @@ -133,6 +133,7 @@ def process_inputs( sampling_params.output_kind, sampling_params.stop, sampling_params.include_stop_str_in_output, + lora_request, ) # Make Request for EngineCore. diff --git a/vllm/v1/worker/gpu_input_batch.py b/vllm/v1/worker/gpu_input_batch.py index 6c4d300ec6efe..a7bb4d37e30da 100644 --- a/vllm/v1/worker/gpu_input_batch.py +++ b/vllm/v1/worker/gpu_input_batch.py @@ -1,11 +1,12 @@ # Datastructures defining an input batch from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Set +from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple import numpy as np import torch +from vllm.lora.request import LoRARequest from vllm.multimodal import MultiModalKwargs from vllm.sampling_params import SamplingParams, SamplingType from vllm.v1.sample.metadata import SamplingMetadata @@ -29,6 +30,8 @@ class CachedRequestState: num_computed_tokens: int output_token_ids: List[int] + lora_request: Optional[LoRARequest] + @property def num_tokens(self) -> int: return len(self.prompt_token_ids) + len(self.output_token_ids) @@ -157,6 +160,11 @@ def __init__( ] self.prompt_token_ids: Optional[torch.Tensor] = None + # lora related + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_requests: Set[LoRARequest] = set() + # req_index -> generator # NOTE(woosuk): The indices of the requests that do not have their own # generator should not be included in the dictionary. 
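A standalone sketch (not part of the diff; the values are hypothetical) of how the per-request request_lora_mapping above, combined with num_scheduled_tokens, becomes the per-token mapping returned by make_lora_inputs and the grouped metadata that V1KernelMeta.prepare_tensors builds for the V1 Triton kernels:

import numpy as np
import torch

# Per-request LoRA ids and scheduled token counts (hypothetical values).
request_lora_mapping = np.array([1, 0, 2], dtype=np.int32)
num_scheduled_tokens = np.array([2, 1, 3], dtype=np.int32)

# make_lora_inputs: one LoRA id per scheduled token.
token_lora_mapping = np.repeat(request_lora_mapping, num_scheduled_tokens)
# -> [1, 1, 0, 2, 2, 2]

# V1KernelMeta.prepare_tensors: group token indices by LoRA id.
t = torch.from_numpy(token_lora_mapping)
_, token_indices_sorted_by_lora_ids = torch.sort(t, stable=True)
# -> [2, 0, 1, 3, 4, 5]
lora_ids, num_tokens_per_lora = torch.unique(t, sorted=True,
                                             return_counts=True)
# -> lora_ids [0, 1, 2], num_tokens_per_lora [1, 2, 3]
lora_token_start_loc = torch.cat([torch.zeros(1, dtype=torch.long),
                                  torch.cumsum(num_tokens_per_lora, dim=0)])
# -> [0, 1, 3, 6]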
@@ -231,6 +239,15 @@ def add_request( if sampling_params.prompt_logprobs: self.prompt_logprob_reqs.add(req_id) + # Add request lora ID + if request.lora_request: + self.request_lora_mapping[ + req_index] = request.lora_request.lora_int_id + self.lora_requests.add(request.lora_request) + else: + # No LoRA + self.request_lora_mapping[req_index] = 0 + def remove_request(self, req_id: str) -> Optional[int]: req_index = self.req_id_to_index.pop(req_id, None) if req_index is None: @@ -247,6 +264,12 @@ def remove_request(self, req_id: str) -> Optional[int]: self.generators.pop(req_index, None) self.num_logprobs.pop(req_id, None) self.prompt_logprob_reqs.discard(req_id) + + # LoRA + # only update request_lora_mapping. Defer the updates + # to lora_requests to prepare_lora_inputs. + self.request_lora_mapping[req_index] = 0 + return req_index def clear(self) -> None: @@ -262,6 +285,9 @@ def clear(self) -> None: self.generators.clear() self.num_logprobs.clear() self.prompt_logprob_reqs.clear() + self.request_lora_mapping = np.zeros((self.max_num_reqs, ), + dtype=np.int32) + self.lora_requests.clear() def condense(self, empty_req_indices: List[int]) -> None: if self.num_reqs == 0: @@ -315,6 +341,9 @@ def condense(self, empty_req_indices: List[int]) -> None: if generator is not None: self.generators[empty_index] = generator + self.request_lora_mapping[empty_index] = self.request_lora_mapping[ + last_req_index] + # Decrement last_req_index since it is now empty. last_req_index -= 1 @@ -398,6 +427,34 @@ def _make_prompt_token_ids_tensor(self) -> torch.Tensor: return prompt_token_ids_cpu_tensor.to(device=self.device, non_blocking=True) + def make_lora_inputs( + self, num_scheduled_tokens: np.ndarray + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Set[LoRARequest]]: + """ + Given the num_scheduled_tokens for each request in the batch, return + datastructures used to activate the current LoRAs. + Returns: + 1. prompt_lora_mapping: A tuple of size self.num_reqs where, + prompt_lora_mapping[i] is the LoRA id to use for the ith prompt. + 2. token_lora_mapping: A tuple of size np.sum(num_scheduled_tokens) + where, token_lora_mapping[i] is the LoRA id to use for ith token. + 3. lora_requests: Set of relevant LoRA requests. 
+ """ + + req_lora_mapping = self.request_lora_mapping[:self.num_reqs] + prompt_lora_mapping = tuple(req_lora_mapping) + token_lora_mapping = tuple( + req_lora_mapping.repeat(num_scheduled_tokens)) + + active_lora_ids: Set[int] = set(np.unique(req_lora_mapping)) + active_lora_requests: Set[LoRARequest] = \ + set({lr for lr in self.lora_requests \ + if lr.lora_int_id in active_lora_ids}) + # Update lora requests + self.lora_requests = active_lora_requests + + return prompt_lora_mapping, token_lora_mapping, self.lora_requests + @property def num_reqs(self) -> int: return len(self.req_id_to_index) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 509771b7e2e5a..56f872265336d 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -23,6 +23,7 @@ from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch +from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin if TYPE_CHECKING: from vllm.v1.core.scheduler import SchedulerOutput @@ -30,7 +31,7 @@ logger = init_logger(__name__) -class GPUModelRunner: +class GPUModelRunner(LoRAModelRunnerMixin): def __init__( self, @@ -231,6 +232,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None: block_ids=new_req_data.block_ids, num_computed_tokens=new_req_data.num_computed_tokens, output_token_ids=[], + lora_request=new_req_data.lora_request, ) req_ids_to_add.append(req_id) @@ -274,15 +276,16 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): # Get the number of scheduled tokens for each request. # TODO: The Python loop can be slow. Optimize. - num_scheduled_tokens = [] + num_scheduled_tokens_list = [] max_num_scheduled_tokens = 0 for req_id in self.input_batch.req_ids[:num_reqs]: assert req_id is not None num_tokens = scheduler_output.num_scheduled_tokens[req_id] - num_scheduled_tokens.append(num_tokens) + num_scheduled_tokens_list.append(num_tokens) max_num_scheduled_tokens = max(max_num_scheduled_tokens, num_tokens) - num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) assert max_num_scheduled_tokens > 0 # Get request indices. @@ -364,6 +367,11 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"): block_table=self.input_batch.block_table[:num_reqs], slot_mapping=slot_mapping, ) + + # Hot-Swap lora model + if self.lora_config: + self.set_active_loras(self.input_batch, num_scheduled_tokens) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial # request in the batch. While we should not sample any token from this # partial request, we do so for simplicity. We will ignore the sampled @@ -587,6 +595,12 @@ def load_model(self) -> None: logger.info("Starting to load model %s...", self.model_config.model) with DeviceMemoryProfiler() as m: # noqa: SIM117 self.model = get_model(vllm_config=self.vllm_config) + if self.lora_config: + self.model = self.load_lora_model(self.model, + self.model_config, + self.scheduler_config, + self.lora_config, + self.device) self.model_memory_usage = m.consumed_memory logger.info("Loading model weights took %.4f GB", @@ -704,15 +718,34 @@ def profile_run(self) -> None: # Cache the dummy encoder outputs. self.encoder_cache["tmp"] = dict(enumerate(dummy_encoder_outputs)) - # Trigger compilation for general shape. 
- hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) - logits = self.model.compute_logits(hidden_states, None) - logits = logits[:self.max_num_tokens] - # TODO(woosuk): Consider the memory usage of the sampler. - torch.cuda.synchronize() - del hidden_states, logits - self.encoder_cache.clear() + # TODO (varun): Reconcile text-only with multi-modal + # compute num tokens per request. For profile, have maximum num_reqs and + # that collectively have maximum num_tokens. + num_reqs = self.scheduler_config.max_num_seqs + num_tokens = self.max_num_tokens + min_tokens_per_req: int = num_tokens // num_reqs + + num_scheduled_tokens_list: List[int] = [min_tokens_per_req] * num_reqs + num_scheduled_tokens_list[-1] += num_tokens % num_reqs + assert sum(num_scheduled_tokens_list) == num_tokens + assert len(num_scheduled_tokens_list) == num_reqs + + num_scheduled_tokens: np.ndarray = np.array(num_scheduled_tokens_list, + dtype=np.int32) + logit_indices = np.cumsum(num_scheduled_tokens) - 1 + + with self.maybe_profile_with_lora(self.lora_config, + num_scheduled_tokens): + print(f"running profile code ...") + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) + hidden_states = hidden_states[logit_indices] + logits = self.model.compute_logits(hidden_states, None) + # TODO(woosuk): Consider the memory usage of the sampler. + torch.cuda.synchronize() + del hidden_states, logits + self.encoder_cache.clear() gc.collect() def capture_model(self) -> None: @@ -730,10 +763,13 @@ def capture_model(self) -> None: # can reuse the memory pool allocated for the large shapes. with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): - for _ in range(self.vllm_config.compilation_config. - cudagraph_num_of_warmups): + print(f"running capture model for {num_tokens}...") + with self.maybe_capture_model_with_lora(self.lora_config, + num_tokens): + for _ in range(self.vllm_config.compilation_config. + cudagraph_num_of_warmups): + self._dummy_run(self.model, num_tokens, self.kv_caches) self._dummy_run(self.model, num_tokens, self.kv_caches) - self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py new file mode 100644 index 0000000000000..d7b379c85976f --- /dev/null +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -0,0 +1,147 @@ +""" +Define LoRA adapter for model runner. +""" + +from contextlib import contextmanager +from typing import Set, Tuple, Optional + +import numpy as np +import torch.nn as nn + +from vllm.config import LoRAConfig, ModelConfig, SchedulerConfig +from vllm.logger import init_logger +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.model_executor.models import supports_lora, supports_multimodal +from vllm.v1.worker.gpu_input_batch import InputBatch + +logger = init_logger(__name__) + + +# Defined as a mixin for GPUModelRunner +class LoRAModelRunnerMixin: + + LORA_WARMUP_RANK = 8 + + def load_lora_model(self, model: nn.Module, model_config: ModelConfig, + scheduler_config: SchedulerConfig, + lora_config: LoRAConfig, device: str) -> nn.Module: + + assert supports_lora( + model), f"{model.__class__.__name__} does not support LoRA yet." 
+ + if supports_multimodal(model): + logger.warning("Regarding multimodal models, vLLM currently " + "only supports adding LoRA to language model.") + + # It's necessary to distinguish between the max_position_embeddings + # of VLMs and LLMs. + if hasattr(model.config, "max_position_embeddings"): + max_pos_embeddings = model.config.max_position_embeddings + else: + max_pos_embeddings = ( + model.config.text_config.max_position_embeddings) + + # Add LoRA Manager to the Model Runner + self.lora_manager = LRUCacheWorkerLoRAManager( + scheduler_config.max_num_seqs, + scheduler_config.max_num_batched_tokens, + model_config.get_vocab_size(), + lora_config, + device, + model.embedding_modules, + model.embedding_padding_modules, + max_position_embeddings=max_pos_embeddings, + ) + return self.lora_manager.create_lora_manager(model) + + def _set_active_loras(self, prompt_lora_mapping: Tuple[int, ...], + token_lora_mapping: Tuple[int, ...], + lora_requests: Set[LoRARequest]) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + + # In V1, we use the same kernels for both prefill and decode. + # To that effect, is_prefill is marked None. + lora_mapping = LoRAMapping(token_lora_mapping, + prompt_lora_mapping, + is_prefill=None) + self.lora_manager.set_active_adapters(lora_requests, lora_mapping) + + def set_active_loras(self, input_batch: InputBatch, + num_scheduled_tokens: np.ndarray) -> None: + + prompt_lora_mapping: Tuple[int, ...] # of size input_batch.num_reqs + token_lora_mapping: Tuple[int, + ...] # of size np.sum(num_scheduled_tokens) + lora_requests: Set[LoRARequest] + prompt_lora_mapping, token_lora_mapping, lora_requests = \ + input_batch.make_lora_inputs(num_scheduled_tokens) + return self._set_active_loras(prompt_lora_mapping, token_lora_mapping, + lora_requests) + + @contextmanager + def maybe_profile_with_lora(self, lora_config: LoRAConfig, + num_scheduled_tokens: np.ndarray): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + num_reqs = len(num_scheduled_tokens) + num_loras = lora_config.max_loras + + # Make prompt lora mapping + # Assign LoRA IDs to requests arbitrarily + prompt_lora_mapping = np.random.randint(low=1, + high=num_loras + 1, + size=num_reqs, + dtype=np.int32) + # Make token lora mapping + token_lora_mapping = np.repeat(prompt_lora_mapping, + num_scheduled_tokens) + + # Make dummy lora requests + lora_requests: Set[LoRARequest] = { + LoRARequest(lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_path="/not/a/real/path") + for lora_id in range(1, num_loras + 1) + } + + with self.lora_manager.dummy_lora_cache(): + # Add the dummy LoRAs here so _set_active_loras doesn't try to + # load from disk. 
+ for lr in lora_requests: + self.lora_manager.add_dummy_lora( + lr, rank=self.LORA_WARMUP_RANK) + + self._set_active_loras(tuple(prompt_lora_mapping), + tuple(token_lora_mapping), + lora_requests) + + yield + + # __exit__ code + self.lora_manager.remove_all_adapters() + + + @contextmanager + def maybe_capture_model_with_lora(self, lora_config: LoRAConfig, batch_size: int): + if lora_config is None: + yield + else: + # __enter__ code + assert self.lora_manager is not None, "LoRA is not enabled" + + prompt_lora_mapping = tuple([0] * batch_size) + token_lora_mapping = tuple([0] * batch_size) + + self._set_active_loras(prompt_lora_mapping, + token_lora_mapping, + set()) + yield + + # __exit__ code From e5b4087dee617cda4f0e395f30b7a4234463a1ec Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 16:06:29 -0500 Subject: [PATCH 02/12] lora id for prefix caching --- vllm/v1/core/kv_cache_utils.py | 56 +++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 7 deletions(-) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9ddbff7c9a604..9f0f0a47168cb 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -164,11 +164,10 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: return ret -def generate_block_hash_extra_keys( +def generate_block_hash_extra_keys_for_mm( request: Request, start_token_idx: int, end_token_idx: int, - start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: - """Generate extra keys for the block hash. The extra keys can come from - the multi-modal inputs and request specific metadata (e.g., LoRA ID). + start_mm_idx: int) -> Tuple[Optional[List[Any]], int]: + """Generate extra keys related to MultiModal request for block hash computation. For multi-modal inputs, the extra keys are (mm_hash, start_offset) that indicate a mm input contained in the block and its starting offset in the block tokens. @@ -182,7 +181,6 @@ def generate_block_hash_extra_keys( Returns: A tuple of extra keys and the next multi-modal index. """ - mm_positions, mm_hashes = request.mm_positions, request.mm_hashes if not mm_positions: return None, start_mm_idx @@ -231,8 +229,52 @@ def generate_block_hash_extra_keys( else: # This block has not reached the current mm input. break - return tuple(extra_keys), curr_mm_idx + return extra_keys, curr_mm_idx + +def generate_block_hash_extra_keys_for_lora( + request: Request) -> Optional[List[int]]: + """Generate extra keys related to LoRA for block hash computation. + + Args: + request: The request object. + + Returns: + Return LoRA id of the request if it is a LoRA request. Return None + otherwise. + """ + if not request.lora_request: + return None + return [request.lora_request.lora_int_id] + +def generate_block_hash_extra_keys( + request: Request, start_token_idx: int, end_token_idx: int, + start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: + """Generate extra keys for the block hash. The extra keys can come from + the multi-modal inputs and request specific metadata (e.g., LoRA ID). + + Args: + request: The request object. + start_token_idx: The start token index of the block. + end_token_idx: The end token index of the block. + start_mm_idx: The start multi-modal index of the block. + + Returns: + A tuple of extra keys and the next multi-modal index. 
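+
+    For example, a request with no multi-modal inputs whose lora_request
+    has lora_int_id 3 yields extra keys (3,); a request with neither
+    multi-modal inputs nor a LoRA yields (None, start_mm_idx).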
+ """ + mm_extra_keys: Optional[List[Any]] + mm_extra_keys, new_start_mm_idx = generate_block_hash_extra_keys_for_mm(request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: Optional[List[int]] = generate_block_hash_extra_keys_for_lora(request) + + extra_keys: List[Any] = [] + if mm_extra_keys: + extra_keys.extend(mm_extra_keys) + if lora_extra_keys: + extra_keys.extend(lora_extra_keys) + + if not extra_keys: + return None, new_start_mm_idx + return tuple(extra_keys), new_start_mm_idx def hash_block_tokens( parent_block_hash: Optional[int], @@ -280,7 +322,7 @@ def hash_request_tokens(block_size: int, "The number of multi-modal positions and hashes must match.") # TODO: Extend this to support other features such as LoRA. - need_extra_keys = bool(mm_positions) + need_extra_keys = bool(mm_positions) or (request.lora_request is not None) extra_keys = None curr_mm_idx = 0 From 4a5b550e107ed8358fd23655a959d7efe082ffdc Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 16:06:47 -0500 Subject: [PATCH 03/12] remove comment --- vllm/lora/punica_wrapper/v1_gpu.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py index eb4f55b2a93a3..5214d905379d4 100644 --- a/vllm/lora/punica_wrapper/v1_gpu.py +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -119,9 +119,6 @@ def update_metadata( #if self.no_lora: # # no update required # return - - print (f"lora update metadata ...") - num_tokens: int = len(mapping.index_mapping) self.token_mapping_v1_meta.reset() From 80ff344b9cec00bb448314fcc4ccbedc13103342 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 16:33:47 -0500 Subject: [PATCH 04/12] limit cudagraph capture size to max_num_seqs --- vllm/v1/worker/gpu_model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 56f872265336d..22490bf5c550f 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -118,6 +118,7 @@ def __init__( # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( reversed(self.vllm_config.compilation_config.capture_sizes)) + self.cudagraph_batch_sizes = [bs for bs in self.cudagraph_batch_sizes if bs <= self.scheduler_config.max_num_seqs] # Persistent buffers for CUDA graphs. self.input_ids = torch.zeros(self.max_num_tokens, From dee4001294b16f76adda40a5e526bd8bda1c7ea6 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 17:02:43 -0500 Subject: [PATCH 05/12] remove torch compile comment --- vllm/config.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index d488a95f2e171..4fe707992f8f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -3149,12 +3149,6 @@ def __post_init__(self): " Disabling `torch.compile`.") self.compilation_config.level = CompilationLevel.NO_COMPILATION - #if self.lora_config is not None and self.compilation_config.level !=\ - # CompilationLevel.NO_COMPILATION: - # logger.warning("LoRA is not supported with `torch.compile` yet. 
" - # "Disabling `torch.compile`.") - # self.compilation_config.level = CompilationLevel.NO_COMPILATION - current_platform.check_and_update_config(self) if not self.instance_id: From dc11242c11a2ec091fa0cca9d698bf13de2fac14 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 17:25:19 -0500 Subject: [PATCH 06/12] format --- benchmarks/benchmark_throughput.py | 15 +-- tests/lora/conftest.py | 1 + tests/lora/lora_torch_compile.py | 36 ++++--- tests/lora/test_punica_sizes.py | 32 +++--- vllm/lora/layers.py | 8 +- vllm/lora/models.py | 9 +- vllm/lora/ops/v1/lora_expand.py | 56 +++++----- vllm/lora/ops/v1/lora_expand_slice.py | 58 +++++----- vllm/lora/ops/v1/lora_expand_slices.py | 75 ++++++------- vllm/lora/ops/v1/lora_shrink.py | 112 +++++++++++--------- vllm/lora/punica_wrapper/punica_base.py | 10 +- vllm/lora/punica_wrapper/punica_selector.py | 2 +- vllm/lora/punica_wrapper/v1_gpu.py | 95 +++++++++-------- vllm/v1/core/kv_cache_utils.py | 17 +-- vllm/v1/worker/gpu_model_runner.py | 10 +- vllm/v1/worker/lora_model_runner_mixin.py | 9 +- 16 files changed, 291 insertions(+), 254 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index cc2020e2a3332..10863749778fd 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -27,8 +27,9 @@ from vllm.utils import FlexibleArgumentParser, merge_async_iterators from vllm.outputs import RequestOutput -SAMPLING_TEMPERATURE=0.0 -SAMPLING_TOP_P=1.0 +SAMPLING_TEMPERATURE = 0.0 +SAMPLING_TOP_P = 1.0 + @dataclasses.dataclass class SampleRequest: @@ -198,9 +199,9 @@ def run_vllm( if not use_beam_search: start = time.perf_counter() outputs = llm.generate(prompts, - sampling_params, - lora_request=lora_requests, - use_tqdm=True) + sampling_params, + lora_request=lora_requests, + use_tqdm=True) end = time.perf_counter() else: assert lora_requests is None, "BeamSearch API does not support LoRA" @@ -412,8 +413,8 @@ def main(args: argparse.Namespace): )) else: elapsed_time, outputs = run_vllm(requests, args.n, - EngineArgs.from_cli_args(args)) - + EngineArgs.from_cli_args(args)) + if args.pickle_outputs: print("Pickling request outputs : ") with open("outputs.pkl", "wb+") as f: diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 40f5316f9d031..1b58db7fa9f88 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -76,6 +76,7 @@ def dist_init(): yield cleanup_dist_env_and_memory(shutdown_ray=True) + @contextmanager def _dist_init(): temp_file = tempfile.mkstemp()[1] diff --git a/tests/lora/lora_torch_compile.py b/tests/lora/lora_torch_compile.py index fef535df698d6..2afd44e54a6ee 100644 --- a/tests/lora/lora_torch_compile.py +++ b/tests/lora/lora_torch_compile.py @@ -1,10 +1,7 @@ import random -from copy import deepcopy -from dataclasses import dataclass from typing import Dict, List, Optional, Tuple import torch -import torch.nn.functional as F from vllm.config import LoRAConfig # yapf conflicts with isort for this block @@ -14,18 +11,18 @@ VocabParallelEmbeddingWithLoRA) # yapf: enable from vllm.lora.punica_wrapper import get_punica_wrapper -from vllm.model_executor.utils import set_random_seed from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights, - PackedLoRALayerWeights) +from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights) from utils import DummyLoRAManager -from vllm.distributed.parallel_state import 
ensure_model_parallel_initialized, init_distributed_environment +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) from conftest import _dist_init + def get_random_id_to_index(num_loras: int, num_slots: int, log: bool = True) -> List[Optional[int]]: @@ -53,6 +50,7 @@ def get_random_id_to_index(num_loras: int, return slots + def populate_loras( id_to_index: List[Optional[int]], layer: BaseLayerWithLoRA, @@ -115,6 +113,7 @@ def populate_loras( return lora_dict, sublora_dict + def create_random_inputs( active_lora_ids: List[int], num_inputs: int, @@ -163,7 +162,8 @@ def create_random_inputs( vocab_size = 512 is_prefill = True max_loras = 8 -device="cuda:0" +device = "cuda:0" + def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: print("Pre-pass:") @@ -172,7 +172,8 @@ def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: return graph -def custom_backend(graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): +def custom_backend(graph: torch.fx.GraphModule, + example_inputs: List[torch.Tensor]): print("Graph entering custom_backend:") print(graph.print_readable()) from torch._inductor import config @@ -181,6 +182,7 @@ def custom_backend(graph: torch.fx.GraphModule, example_inputs: List[torch.Tenso current_config['post_grad_custom_post_pass'] = custom_pass return compile_fx(graph, example_inputs, config_patches=current_config) + @torch.inference_mode() def test_embeddings() -> None: @@ -188,7 +190,7 @@ def test_embeddings() -> None: torch.set_default_device(device) init_distributed_environment(1, 0) - ensure_model_parallel_initialized(1,1) + ensure_model_parallel_initialized(1, 1) max_loras = 8 punica_wrapper = get_punica_wrapper(8192, 256, device) @@ -221,23 +223,23 @@ def create_random_embedding_layer(): input_size=(200, ), input_range=(1, vocab_size), device=device) - lora_mapping = LoRAMapping(index_mapping, - prompt_mapping, - is_prefill=True) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping, is_prefill=True) punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras, vocab_size, lora_config.lora_extra_vocab_size) - lora_embedding_compiled = torch.compile(lora_embedding, backend=custom_backend) + lora_embedding_compiled = torch.compile(lora_embedding, + backend=custom_backend) embedding_compiled = torch.compile(embedding, backend=custom_backend) input = torch.cat(inputs) torch._dynamo.mark_dynamic(input, 0) - lr = embedding_compiled(input) - lora_result = lora_embedding_compiled(input) + embedding_compiled(input) + lora_embedding_compiled(input) + if __name__ == '__main__': with _dist_init(): - test_embeddings() \ No newline at end of file + test_embeddings() diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index c90baa5307116..12d5f2a857b66 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -385,6 +385,7 @@ def test_punica_expand_nslices( slice_offset += hidden_size assert_close(our_outputs, ref_outputs) + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @@ -430,16 +431,16 @@ def test_v1_shrink_expand( device, ) - v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras = num_loras, max_num_tokens = seq_length * batches, device=device) + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=seq_length * + batches, + device=device) v1_meta.reset() v1_meta.prepare_tensors(token_lora_mapping) if op_type == "shrink": - 
lora_shrink(inputs_tensor, - lora_weights, - our_out_tensor, - *v1_meta.meta_args, - scaling ) + lora_shrink(inputs_tensor, lora_weights, our_out_tensor, + *v1_meta.meta_args, scaling) else: lora_expand(inputs_tensor, lora_weights, @@ -506,7 +507,10 @@ def test_v1_expand_nslices( device, ) - v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras = num_loras, max_num_tokens = seq_length * batches, device=device) + v1_meta: V1KernelMeta = V1KernelMeta.make(max_loras=num_loras, + max_num_tokens=seq_length * + batches, + device=device) v1_meta.reset() v1_meta.prepare_tensors(token_lora_mapping) @@ -514,12 +518,12 @@ def test_v1_expand_nslices( for index in range(nslices): lora_weights = lora_weights_lst[index] lora_expand_slice(inputs_tensor, - lora_weights, - our_outputs, - *v1_meta.meta_args, - slice_offset, - hidden_size, - add_inputs=True) + lora_weights, + our_outputs, + *v1_meta.meta_args, + slice_offset, + hidden_size, + add_inputs=True) ref_torch_groupgemm( ref_outputs[:, slice_offset:slice_offset + hidden_size], @@ -533,4 +537,4 @@ def test_v1_expand_nslices( ) slice_offset += hidden_size - assert_close(our_outputs, ref_outputs) \ No newline at end of file + assert_close(our_outputs, ref_outputs) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 5f04d2f7adb10..bf0ad2e988ba9 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -231,9 +231,11 @@ def set_lora( self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def forward(self, x: torch.Tensor) -> torch.Tensor: - added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, 1, 0) - embeddings_indices = torch.narrow(self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) - + added_tokens_mask = torch.where(x > self.base_layer.org_vocab_size - 1, + 1, 0) + embeddings_indices = torch.narrow( + self.punica_wrapper._embeddings_indices, 1, 0, x.size(0)) + indices = embeddings_indices[1].view_as(x) full_lora_a_embeddings = F.embedding( x + indices, diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 74a7ba8ff9cb9..8002ce2694d6d 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -327,10 +327,11 @@ def __init__( self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.long_lora_context: Optional[LongContextLoRAContext] = None - self.punica_wrapper = get_punica_wrapper(max_num_batched_tokens, - max_batches=self.max_num_seqs, - device=self.device, - max_loras=lora_config.max_loras) + self.punica_wrapper = get_punica_wrapper( + max_num_batched_tokens, + max_batches=self.max_num_seqs, + device=self.device, + max_loras=lora_config.max_loras) # Scaling factor -> offset to the sin_cos_cache to it. # Used for long context lora. 
self.scaling_factor_to_offset: Dict[float, int] = {} diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py index a33e62a001ec7..88db011e13182 100644 --- a/vllm/lora/ops/v1/lora_expand.py +++ b/vllm/lora/ops/v1/lora_expand.py @@ -4,6 +4,7 @@ import math from vllm.utils import direct_register_custom_op + @triton.jit def _lora_expand_kernel( input_ptr, @@ -18,7 +19,7 @@ def _lora_expand_kernel( lora_ids, xm_stride, xk_stride, # 1 - l0_stride, # hidden_size * max rank + l0_stride, # hidden_size * max rank lora_n_stride, lora_k_stride, cm_stride, @@ -31,34 +32,35 @@ def _lora_expand_kernel( BLOCK_K: tl.constexpr, ): - NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) NUM_N_CTAS = tl.cdiv(N, BLOCK_N) pid = tl.program_id(0) - l = pid // (NUM_M_CTAS * NUM_N_CTAS) + lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS cta_m = pid % NUM_M_CTAS - lora_id = tl.load(lora_ids + l) + lora_id = tl.load(lora_ids + lora_idx) if lora_id == -1: # early exit for the no-lora case. return # lora m indices offsets - lora_m_indices_start = tl.load(lora_token_start_loc + l) - lora_m_size = tl.load(num_tokens_per_lora + l) + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) - cta_m_offset = cta_m * BLOCK_M + cta_m_offset = cta_m * BLOCK_M if cta_m_offset >= lora_m_size: # early exit CTA return - cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) offset_k = tl.arange(0, BLOCK_K) - offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size rm = tl.load(cta_lora_seq_indices + offset_rm) a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride @@ -88,7 +90,6 @@ def _lora_expand_kernel( a_ptr += BLOCK_K * xk_stride b_ptr += BLOCK_K * lora_k_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) offset_cm = tl.arange(0, BLOCK_M) offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N @@ -100,15 +101,16 @@ def _lora_expand_kernel( tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) + @torch.inference_mode() def _lora_expand( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras + 1 - lora_token_start_loc: torch.Tensor, # max-loras + 2 - lora_ids: torch.Tensor, # max-loras + 1 + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + 1 + lora_token_start_loc: torch.Tensor, # max-loras + 2 + lora_ids: torch.Tensor, # max-loras + 1 add_inputs: bool = False, ) -> None: """ @@ -116,12 +118,14 @@ def _lora_expand( inputs (torch.Tensor): input tensor lora_b_weights (torch.Tensor): lora'b weight output_tensor (torch.Tensor): output tensor - token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix grouped by LoRA IDs. - num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be - processed by LoRA ID lora_ids[i] - lora_token_start_loc: A cumulative sum of num_tokens_per_lora. 
lora_token_start_loc[0] - is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] - identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] should process. lora_ids: LoRA ids to process. add_inputs (bool, optional): Defaults to False, adds the final lora @@ -159,7 +163,7 @@ def _lora_expand( NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA NUM_N_CTAS = math.ceil(N / BLOCK_N) - MAX_LORAS = lora_ids.size(0) + MAX_LORAS = lora_ids.size(0) EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs @@ -175,12 +179,10 @@ def _lora_expand( l0_stride = lora_b_weights.stride(0) lora_n_stride = lora_b_weights.stride(1) lora_k_stride = lora_b_weights.stride(2) - cm_stride = output_tensor.stride(0) + cm_stride = output_tensor.stride(0) cn_stride = output_tensor.stride(1) - grid = ( - MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, - ) + grid = (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) _lora_expand_kernel[grid]( inputs, @@ -209,6 +211,7 @@ def _lora_expand( ) return + def lora_expand_fake( inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -221,6 +224,7 @@ def lora_expand_fake( ) -> None: return + try: direct_register_custom_op( op_name="lora_expand", @@ -231,4 +235,4 @@ def lora_expand_fake( lora_expand = torch.ops.vllm.lora_expand except AttributeError: - lora_expand = _lora_expand \ No newline at end of file + lora_expand = _lora_expand diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py index 4690b6f59dfed..c64f95e514445 100644 --- a/vllm/lora/ops/v1/lora_expand_slice.py +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -4,6 +4,7 @@ import math from vllm.utils import direct_register_custom_op + @triton.jit def _lora_expand_slice_kernel( input_ptr, @@ -18,7 +19,7 @@ def _lora_expand_slice_kernel( lora_ids, xm_stride, xk_stride, # 1 - l0_stride, # hidden_size * max rank + l0_stride, # hidden_size * max rank lora_n_stride, lora_k_stride, cm_stride, @@ -32,34 +33,35 @@ def _lora_expand_slice_kernel( BLOCK_K: tl.constexpr, ): - NUM_M_CTAS = tl.cdiv(M, BLOCK_M) + NUM_M_CTAS = tl.cdiv(M, BLOCK_M) NUM_N_CTAS = tl.cdiv(N, BLOCK_N) pid = tl.program_id(0) - l = pid // (NUM_M_CTAS * NUM_N_CTAS) + lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS cta_m = pid % NUM_M_CTAS - lora_id = tl.load(lora_ids + l) + lora_id = tl.load(lora_ids + lora_idx) if lora_id == -1: # early exit for the no-lora case. 
return # lora m indices offsets - lora_m_indices_start = tl.load(lora_token_start_loc + l) - lora_m_size = tl.load(num_tokens_per_lora + l) + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) - cta_m_offset = cta_m * BLOCK_M + cta_m_offset = cta_m * BLOCK_M if cta_m_offset >= lora_m_size: # early exit CTA return - cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) offset_k = tl.arange(0, BLOCK_K) - offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size rm = tl.load(cta_lora_seq_indices + offset_rm) a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride @@ -89,13 +91,13 @@ def _lora_expand_slice_kernel( a_ptr += BLOCK_K * xk_stride b_ptr += BLOCK_K * lora_k_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) offset_cm = tl.arange(0, BLOCK_M) offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride - c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < (N + slice_offset)) + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < + (N + slice_offset)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out @@ -107,10 +109,10 @@ def _lora_expand_slice( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras - lora_token_start_loc: torch.Tensor, # max-loras - lora_ids: torch.Tensor, # max-loras + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras slice_offset: int, slice_size: int, add_inputs: bool = False, @@ -120,12 +122,14 @@ def _lora_expand_slice( inputs (torch.Tensor): input tensor lora_b_weights (torch.Tensor): lora'b weight output_tensor (torch.Tensor): output tensor - token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix grouped by LoRA IDs. - num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be - processed by LoRA ID lora_ids[i] - lora_token_start_loc: A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] - is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] - identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. + num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] should process. lora_ids: LoRA ids to process. 
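For reference, all four of these index tensors can be derived from the per-token LoRA mapping with a stable sort, a unique-with-counts, and a cumulative sum; this mirrors what V1KernelMeta.prepare_tensors does later in this series. A minimal eager sketch with a made-up six-token mapping (the values shown in comments are illustrative):

    import torch

    # Hypothetical per-token LoRA assignment for 6 tokens (-1 means "no LoRA").
    token_lora_mapping = torch.tensor([0, 1, 0, -1, 1, 1], dtype=torch.int32)

    # Token indices grouped by LoRA id; the stable sort keeps per-LoRA order.
    _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping,
                                                     stable=True)
    # -> tensor([3, 0, 2, 1, 4, 5])

    # Which LoRA ids are active and how many tokens each one owns.
    lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping,
                                                 sorted=False,
                                                 return_counts=True)
    # e.g. lora_ids = tensor([-1, 0, 1]), num_tokens_per_lora = tensor([1, 2, 3])

    # Start of each LoRA's token block, with a leading 0, so that
    # token_indices_sorted_by_lora_ids[lora_token_start_loc[i] :
    #     lora_token_start_loc[i] + num_tokens_per_lora[i]]
    # are the tokens that LoRA lora_ids[i] should process.
    lora_token_start_loc = torch.zeros(lora_ids.size(0) + 1, dtype=torch.int64)
    lora_token_start_loc[1:] = torch.cumsum(num_tokens_per_lora, dim=0)
    # with the example counts above -> tensor([0, 1, 3, 6])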
slice_offset (int): output_tensor's offset @@ -166,7 +170,7 @@ def _lora_expand_slice( NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA NUM_N_CTAS = math.ceil(N / BLOCK_N) - MAX_LORAS = lora_ids.size(0) + MAX_LORAS = lora_ids.size(0) EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs @@ -182,12 +186,10 @@ def _lora_expand_slice( l0_stride = lora_b_weights.stride(0) lora_n_stride = lora_b_weights.stride(1) lora_k_stride = lora_b_weights.stride(2) - cm_stride = output_tensor.stride(0) + cm_stride = output_tensor.stride(0) cn_stride = output_tensor.stride(1) - grid = ( - MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, - ) + grid = (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) _lora_expand_slice_kernel[grid]( inputs, @@ -217,6 +219,7 @@ def _lora_expand_slice( ) return + def lora_expand_slice_fake( inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -231,6 +234,7 @@ def lora_expand_slice_fake( ) -> None: return + try: direct_register_custom_op( op_name="lora_expand_slice", @@ -241,4 +245,4 @@ def lora_expand_slice_fake( lora_expand_slice = torch.ops.vllm.lora_expand_slice except AttributeError: - lora_expand_slice = _lora_expand_slice \ No newline at end of file + lora_expand_slice = _lora_expand_slice diff --git a/vllm/lora/ops/v1/lora_expand_slices.py b/vllm/lora/ops/v1/lora_expand_slices.py index 1bd20865a3aca..e38ed5e8e2390 100644 --- a/vllm/lora/ops/v1/lora_expand_slices.py +++ b/vllm/lora/ops/v1/lora_expand_slices.py @@ -5,6 +5,7 @@ from vllm.utils import direct_register_custom_op + @triton.jit def _lora_expand_slices_kernel( input_ptr, @@ -37,7 +38,7 @@ def _lora_expand_slices_kernel( ): pid = tl.program_id(0) cta_s = pid // (MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS) - cta_l = (pid // (NUM_M_CTAS * NUM_N_CTAS)) % MAX_LORAS + cta_l = (pid // (NUM_M_CTAS * NUM_N_CTAS)) % MAX_LORAS cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS cta_m = pid % NUM_M_CTAS @@ -51,19 +52,20 @@ def _lora_expand_slices_kernel( lora_m_indices_start = tl.cast(0, tl.int32) else: lora_m_indices_start = tl.load(lora_seq_start_loc + cta_l - 1) - lora_m_size = tl.load(lora_seq_counts + cta_l) + lora_m_size = tl.load(lora_seq_counts + cta_l) - cta_m_offset = cta_m * BLOCK_M + cta_m_offset = cta_m * BLOCK_M if cta_m_offset >= lora_m_size: # early exit CTA return - cta_lora_seq_indices = lora_seq_indices + lora_m_indices_start + cta_m_offset + cta_lora_seq_indices = (lora_seq_indices + lora_m_indices_start + + cta_m_offset) cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) offset_k = tl.arange(0, BLOCK_K) - offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size rm = tl.load(cta_lora_seq_indices + offset_rm) a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride @@ -93,19 +95,20 @@ def _lora_expand_slices_kernel( a_ptr += BLOCK_K * xk_stride b_ptr += BLOCK_K * lora_k_stride - slice_offset = cta_s * N tiled_c = accumulator.to(lora_ptr.dtype.element_ty) offset_cm = tl.arange(0, BLOCK_M) - offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset + offset_cn = tl.arange(0, BLOCK_N) + cta_n * BLOCK_N + slice_offset c_ptr = out_ptr + rm[:, None] * cm_stride + offset_cn[None, :] * cn_stride - c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < (slice_offset + N)) + c_mask = (offset_cm[:, None] < cta_m_size) & (offset_cn[None, :] < + (slice_offset + N)) if ADD_INPUTS: tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) + @torch.inference_mode() def _lora_expand_slices( inputs: torch.Tensor, @@ -124,11 
+127,13 @@ def _lora_expand_slices( lora_b_weights (torch.Tensor): lora'a weight output_tensor (torch.Tensor): output tensor - token_lora_mapping_tensor: Each token's lora id as it appears in the A matrix. + token_lora_mapping_tensor: Each token's lora id as it appears in the + A matrix. - lora_seq_indices: sorted lora-token mapping. Tokens of the same lora appear next to eachother. - This is used so a thread block knows what tokens to put next to eachother when constructing a matrix block. - Essentially, + lora_seq_indices: sorted lora-token mapping. Tokens of the same lora + appear next to each other. This is used so a thread block knows + what tokens to put next to each other when constructing a matrix + block. Essentially, _, lora_seq_indices = torch.sort(token_lora_mapping, stable=True) lora_seq_counts: number of tokens per lora id. essentially, @@ -136,10 +141,12 @@ def _lora_expand_slices( sorted=False, return_counts=True) - lora_seq_start_loc: start index of each lora id in lora_seq_indices. essentially, + lora_seq_start_loc: start index of each lora id in lora_seq_indices. + essentially, lora_seq_start_loc = torch.cumsum(lora_seq_counts, dim = 0) - lora_ids : Set of lora ids in order according to lora_seq_counts, and lora_seq_indices + lora_ids : Set of lora ids in order according to lora_seq_counts, + and lora_seq_indices. lora_ids, lora_seq_counts = torch.unique(indices, sorted=False, return_counts=True) @@ -159,7 +166,7 @@ def _lora_expand_slices( assert inputs.is_contiguous() assert output_tensor.is_contiguous() - assert lora_b_weights.ndim == 4 # (nslices, lora_num, hidden-size, rank) + assert lora_b_weights.ndim == 4 # (nslices, lora_num, hidden-size, rank) assert lora_b_weights.is_contiguous() # TODO tuning this config @@ -184,34 +191,17 @@ def _lora_expand_slices( ]: CAST_TYPE = True - grid = ( - NUM_SLICES * MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, - ) + grid = (NUM_SLICES * MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, ) xm_stride = inputs.stride(0) xk_stride = inputs.stride(1) - l0_stride = lora_b_weights.stride(0) # slice stride - l1_stride = lora_b_weights.stride(1) # lora stride + l0_stride = lora_b_weights.stride(0) # slice stride + l1_stride = lora_b_weights.stride(1) # lora stride lora_n_stride = lora_b_weights.stride(2) lora_k_stride = lora_b_weights.stride(3) - cm_stride = output_tensor.stride(0) + cm_stride = output_tensor.stride(0) cn_stride = output_tensor.stride(1) - #print (f"lora seq indices : {lora_seq_indices.dtype} {lora_seq_indices}") - #print (f"lora seq counts : {lora_seq_counts.dtype} {lora_seq_counts}") - #print (f"lora seq start loc : {lora_seq_start_loc.dtype} {lora_seq_start_loc}") - #print (f"lora ids : {lora_ids.dtype} {lora_ids}") - #print (f"num loras : {NUM_LORAS}") - #print (f"num slices : {NUM_SLICES}") - #print (f"N : {N}") - #print (f"K : {K}") - #print (f"A : {inputs.dtype} {inputs.shape}") - #print (f"B : {lora_b_weights.dtype} {lora_b_weights.shape}") - #print (f"C : {output_tensor.dtype} {output_tensor.shape}") - #print (f"A m k strides : {xm_stride} {xk_stride}") - #print (f"B k n strides : {lora_k_stride} {lora_n_stride}") - #print (f"C m n strides : {cm_stride} {cn_stride}") - _lora_expand_slices_kernel[grid]( inputs, lora_b_weights, @@ -245,12 +235,14 @@ def _lora_expand_slices( try: - lora_expand_slices = torch.library.custom_op("lora::v1::lora_expand_slices", - _lora_expand_slices, - mutates_args=["output_tensor"]) + lora_expand_slices = torch.library.custom_op( + "lora::v1::lora_expand_slices", + _lora_expand_slices, + 
mutates_args=["output_tensor"]) except AttributeError: lora_expand_slices = _lora_expand_slices + def lora_expand_slices_fake( inputs: torch.Tensor, lora_b_weights: torch.Tensor, @@ -263,14 +255,15 @@ def lora_expand_slices_fake( ) -> None: return + try: direct_register_custom_op( op_name="lora_expand_slices", - op_func= _lora_expand_slices, + op_func=_lora_expand_slices, mutates_args=["output_tensor"], fake_impl=lora_expand_slices_fake, ) lora_expand_slices = torch.ops.vllm.lora_expand_slices except AttributeError: - lora_expand = _lora_expand_slices \ No newline at end of file + lora_expand = _lora_expand_slices diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py index 0057b066a58ae..7423abf7c7334 100644 --- a/vllm/lora/ops/v1/lora_shrink.py +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -7,64 +7,68 @@ @triton.jit def _lora_shrink_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - token_indices_sorted_by_lora_ids, - num_tokens_per_lora, - lora_token_start_loc, - lora_ids, - scaling, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - BLOCK_K : tl.constexpr, - EVEN_K : tl.constexpr, - SPLIT_K : tl.constexpr, - NUM_M_CTAS : tl.constexpr, - NUM_N_CTAS : tl.constexpr, - ): + input_ptr, + lora_ptr, + out_ptr, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + NUM_M_CTAS: tl.constexpr, + NUM_N_CTAS: tl.constexpr, +): pid = tl.program_id(0) - l = pid // (NUM_M_CTAS * NUM_N_CTAS) + lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS cta_m = pid % NUM_M_CTAS cta_sk = tl.program_id(1) - lora_id = tl.load(lora_ids + l) + lora_id = tl.load(lora_ids + lora_idx) if lora_id == -1: # early exit for the no-lora case. 
return # lora m indices offsets - lora_m_indices_start = tl.load(lora_token_start_loc + l) - lora_m_size = tl.load(num_tokens_per_lora + l) + lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) + lora_m_size = tl.load(num_tokens_per_lora + lora_idx) - cta_m_offset = cta_m * BLOCK_M + cta_m_offset = cta_m * BLOCK_M if cta_m_offset >= lora_m_size: # early exit CTA return - cta_lora_seq_indices = token_indices_sorted_by_lora_ids + lora_m_indices_start + cta_m_offset + cta_lora_seq_indices = (token_indices_sorted_by_lora_ids + + lora_m_indices_start + cta_m_offset) cta_m_size = min(BLOCK_M, lora_m_size - cta_m_offset) - offset_k = tl.max_contiguous(BLOCK_K * cta_sk + tl.arange(0, BLOCK_K), BLOCK_K) + offset_k = tl.max_contiguous(BLOCK_K * cta_sk + tl.arange(0, BLOCK_K), + BLOCK_K) - offset_rm = tl.arange(0, BLOCK_M) % cta_m_size + offset_rm = tl.arange(0, BLOCK_M) % cta_m_size rm = tl.load(cta_lora_seq_indices + offset_rm) a_ptr = input_ptr + rm[:, None] * xm_stride + offset_k[None, :] * xk_stride - offset_n = tl.max_contiguous((cta_n * BLOCK_N)+ tl.arange(0, BLOCK_N), BLOCK_N) + offset_n = tl.max_contiguous((cta_n * BLOCK_N) + tl.arange(0, BLOCK_N), + BLOCK_N) rn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - b_ptr = lora_ptr + lora_id * l0_stride + rn[None, :] * lora_n_stride + offset_k[:, None] * lora_k_stride + b_ptr = lora_ptr + lora_id * l0_stride + rn[ + None, :] * lora_n_stride + offset_k[:, None] * lora_k_stride acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) max_k = tl.cdiv(K, BLOCK_K * SPLIT_K) @@ -79,11 +83,10 @@ def _lora_shrink_kernel( a_mask = offset_k[None, :] < K a_tile = tl.load(a_ptr, mask=a_mask, other=0.0) - # TODO (varun) : When a_tile and b_tile are float16s the output is also float16. this can - # lead to infs in the output. + # TODO (varun) : When a_tile and b_tile are float16s the output is also + # float16. this can lead to infs in the output. acc += tl.dot(a_tile, b_tile) - a_ptr += BLOCK_K * SPLIT_K * xk_stride b_ptr += BLOCK_K * SPLIT_K * lora_k_stride offset_k += BLOCK_K * SPLIT_K @@ -98,15 +101,16 @@ def _lora_shrink_kernel( else: tl.atomic_add(c_ptr, acc, mask=c_mask) + @torch.inference_mode() def _lora_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras - lora_token_start_loc: torch.Tensor, # max-loras - lora_ids: torch.Tensor, # max-loras + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras scaling: float, ) -> None: """ @@ -114,13 +118,15 @@ def _lora_shrink( inputs (torch.Tensor): input tensor lora_a_weights (torch.Tensor): lora'a weight output_tensor (torch.Tensor): output tensor - token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix grouped by LoRA IDs. - num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens that are to be - processed by LoRA ID lora_ids[i] - lora_token_start_loc: A cumulative sum of num_tokens_per_lora. lora_token_start_loc[0] - is always 0 so that lora_token_start_loc[i], along with num_tokens_per_lora[i] - identifies the the region in token_indices_sorted_by_lora_ids that LoRA lora_ids[i] - should process. + token_indices_sorted_by_lora_ids: Row/Token indices from the A matrix + grouped by LoRA IDs. 
+ num_tokens_per_lora: num_tokens_per_lora[i] is the number of tokens + that are to be processed by LoRA ID lora_ids[i] + lora_token_start_loc: A cumulative sum of num_tokens_per_lora. + lora_token_start_loc[0] is always 0 so that lora_token_start_loc[i], + along with num_tokens_per_lora[i] identifies the the region in + token_indices_sorted_by_lora_ids that LoRA lora_ids[i] should + process. lora_ids: LoRA ids to process. add_inputs (bool, optional): Defaults to False, adds the final lora results to the output. @@ -155,7 +161,7 @@ def _lora_shrink( cn_stride = output_tensor.stride(1) # TODO tuning this config - M = inputs.size(0) # num tokens + M = inputs.size(0) # num tokens N = lora_a_weights.size(-2) K = lora_a_weights.size(-1) MAX_LORAS = lora_ids.size(0) @@ -200,6 +206,7 @@ def _lora_shrink( ) return + def lora_shrink_fake( inputs: torch.Tensor, lora_a_weights: torch.Tensor, @@ -209,9 +216,10 @@ def lora_shrink_fake( lora_token_start_loc: torch.Tensor, lora_ids: torch.Tensor, scaling: float, -)-> None: +) -> None: return + try: direct_register_custom_op( op_name="lora_shrink", @@ -222,4 +230,4 @@ def lora_shrink_fake( lora_shrink = torch.ops.vllm.lora_shrink except AttributeError: - lora_shrink = _lora_shrink \ No newline at end of file + lora_shrink = _lora_shrink diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py index e86201909896b..dd463dd68dbab 100644 --- a/vllm/lora/punica_wrapper/punica_base.py +++ b/vllm/lora/punica_wrapper/punica_base.py @@ -192,8 +192,10 @@ def update_base_metadata( self.device, long_lora_context, ) - self._token_lora_indices[:base_indices.size(0)].copy_(base_indices, non_blocking=True) - self._sampler_indices[:sampler_indices.size(0)].copy_(sampler_indices, non_blocking=True) + self._token_lora_indices[:base_indices.size(0)].copy_( + base_indices, non_blocking=True) + self._sampler_indices[:sampler_indices.size(0)].copy_( + sampler_indices, non_blocking=True) self._sampler_indices_padded[:sampler_indices_padded.size(0)].copy_( sampler_indices_padded, non_blocking=True) self._embeddings_indices[:embeddings_indices. @@ -329,8 +331,8 @@ def update_metadata( **kwargs): self.update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) + vocab_size, extra_vocab_size, + long_lora_context) if mapping.is_prefill: # Update metadata required for prefill-related operators. 
self._update_prefill_metada(self.token_lora_indices) diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index d8ca03dff8982..7f4d092179b1b 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -9,7 +9,7 @@ def get_punica_wrapper(*args, **kwargs) -> PunicaWrapperBase: if current_platform.is_cuda_alike(): # Lazy import to avoid ImportError if envs.VLLM_USE_V1: - from vllm.lora.punica_wrapper.v1_gpu import V1LoRAGPU + from vllm.lora.punica_wrapper.v1_gpu import V1LoRAGPU print_info_once("Using V1LoRAGPU.") return V1LoRAGPU(*args, **kwargs) else: diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py index 5214d905379d4..34a8dcee93ffa 100644 --- a/vllm/lora/punica_wrapper/v1_gpu.py +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union, final, List +from typing import TYPE_CHECKING, Optional, Tuple, Union, final, List from dataclasses import dataclass import torch @@ -10,7 +10,6 @@ from vllm.lora.ops.v1.lora_expand import lora_expand from vllm.lora.ops.v1.lora_shrink import lora_shrink from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice - #from vllm.lora.ops.v1.lora_expand_slices import lora_expand_slices from .punica_base import PunicaWrapperBase @@ -18,6 +17,7 @@ # avoid circuit import from vllm.lora.models import LongContextLoRAContext + @dataclass class V1KernelMeta: token_indices_sorted_by_lora_ids: torch.Tensor @@ -26,7 +26,8 @@ class V1KernelMeta: lora_token_start_loc: torch.Tensor @staticmethod - def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1KernelMeta": + def make(max_loras: int, max_num_tokens: int, + device: torch.device) -> "V1KernelMeta": token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, dtype=torch.int32, device=device) @@ -36,7 +37,7 @@ def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1Kernel # is a possibility. active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, - device=device) + device=device) # using running example, [3, 10, 5, 2] is a possibility. num_tokens_per_lora = torch.zeros(max_loras + 1, @@ -48,11 +49,12 @@ def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1Kernel # can be [0, 3, 13, 18, 20]. 
lora_token_start_loc = torch.zeros(max_loras + 2, dtype=torch.int32, - device=device) - return V1KernelMeta(token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, - active_lora_ids = active_lora_ids, - num_tokens_per_lora=num_tokens_per_lora, - lora_token_start_loc = lora_token_start_loc) + device=device) + return V1KernelMeta( + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc) def reset(self): self.active_lora_ids.fill_(-1) @@ -62,29 +64,34 @@ def reset(self): def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: num_tokens = token_lora_mapping.size(0) # token_indices_sorted_by_lora_ids - _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, stable=True) + _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, + stable=True) # start gpu transfer - self.token_indices_sorted_by_lora_ids[:num_tokens].copy_(token_indices_sorted_by_lora_ids, - non_blocking=True) - + self.token_indices_sorted_by_lora_ids[:num_tokens].copy_( + token_indices_sorted_by_lora_ids, non_blocking=True) + # active_lora_ids, num_tokens_per_lora lora_ids, num_tokens_per_lora = torch.unique(token_lora_mapping, - sorted=False, - return_counts=True) - self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, non_blocking=True) - self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_(num_tokens_per_lora, non_blocking=True) + sorted=False, + return_counts=True) + self.active_lora_ids[:lora_ids.size(0)].copy_(lora_ids, + non_blocking=True) + self.num_tokens_per_lora[:num_tokens_per_lora.size(0)].copy_( + num_tokens_per_lora, non_blocking=True) # lora_token_start_loc lora_token_start_loc = torch.cumsum(num_tokens_per_lora, dim=0) - self.lora_token_start_loc[1: 1 + lora_token_start_loc.size(0)].copy_(lora_token_start_loc, - non_blocking=True) + self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( + lora_token_start_loc, non_blocking=True) - def meta_args(self, num_tokens: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def meta_args( + self, num_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return (self.token_indices_sorted_by_lora_ids[:num_tokens], - self.num_tokens_per_lora, - self.lora_token_start_loc, + self.num_tokens_per_lora, self.lora_token_start_loc, self.active_lora_ids) + @final class V1LoRAGPU(PunicaWrapperBase): """ @@ -95,14 +102,17 @@ class V1LoRAGPU(PunicaWrapperBase): PunicaWrapperBase (_type_): _description_ """ - def __init__(self, max_num_batched_tokens: int, max_batches: int, device: Union[torch.device, str], **kwargs): PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) self.max_loras = kwargs['max_loras'] - self.token_mapping_v1_meta = V1KernelMeta.make(self.max_loras, max_num_batched_tokens, device=device) - self.prompt_mapping_v1_meta = V1KernelMeta.make(self.max_loras, max_batches, device=device) + self.token_mapping_v1_meta = V1KernelMeta.make(self.max_loras, + max_num_batched_tokens, + device=device) + self.prompt_mapping_v1_meta = V1KernelMeta.make(self.max_loras, + max_batches, + device=device) def update_metadata( self, @@ -119,14 +129,14 @@ def update_metadata( #if self.no_lora: # # no update required # return - num_tokens: int = len(mapping.index_mapping) + #num_tokens: int = len(mapping.index_mapping) self.token_mapping_v1_meta.reset() self.prompt_mapping_v1_meta.reset() 
self.update_base_metadata(mapping, lora_index_to_id, max_loras, - vocab_size, extra_vocab_size, - long_lora_context) + vocab_size, extra_vocab_size, + long_lora_context) self.token_mapping_v1_meta.prepare_tensors(self.token_lora_indices) self.prompt_mapping_v1_meta.prepare_tensors(self.sampler_indices) @@ -138,11 +148,8 @@ def _shrink( w_t_all: torch.Tensor, scale: float, ): - lora_shrink(x, - w_t_all, - y, - *self.token_mapping_v1_meta.meta_args(x.size(0)), - scale) + lora_shrink(x, w_t_all, y, + *self.token_mapping_v1_meta.meta_args(x.size(0)), scale) def _expand( self, @@ -169,10 +176,8 @@ def _expand_slice( return lora_expand_slice(x, w_t_all, y, - *self.token_mapping_v1_meta.meta_args(x.size(0)), - y_offset, - y_slice_size, - add_inputs) + *self.token_mapping_v1_meta.meta_args(x.size(0)), + y_offset, y_slice_size, add_inputs) def _apply_expand( self, @@ -264,8 +269,9 @@ def add_expand(self, y = y.view(-1, y.size(-1)) offset_left = offset_start if lora_bias_stacked is not None: - self._apply_bias(self._token_lora_indices[:x.size(0)], y, output_slices, - lora_bias_stacked) + assert isinstance(x, torch.Tensor) + self._apply_bias(self._token_lora_indices[:x.size(0)], y, + output_slices, lora_bias_stacked) for slice_idx in range(len(lora_b_stacked)): self._apply_expand( y, @@ -337,8 +343,8 @@ def add_lora_linear(self, assert len(lora_a_stacked) == len(lora_b_stacked) == len(output_slices) if lora_bias_stacked is not None: assert len(lora_bias_stacked) == len(output_slices) - y = self._apply_bias(self._token_lora_indices[:x.size(0)], y, output_slices, - lora_bias_stacked) + y = self._apply_bias(self._token_lora_indices[:x.size(0)], y, + output_slices, lora_bias_stacked) if buffer is None: r = lora_b_stacked[0].size(-1) @@ -392,8 +398,11 @@ def add_lora_logits(self, dtype=torch.float32, device=x.device) - lora_shrink(x, lora_a_stacked, buffer, *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) - lora_expand(buffer, lora_b_stacked, y, + lora_shrink(x, lora_a_stacked, buffer, + *self.prompt_mapping_v1_meta.meta_args(x.size(0)), scale) + lora_expand(buffer, + lora_b_stacked, + y, *self.prompt_mapping_v1_meta.meta_args(x.size(0)), add_inputs=True) diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 9f0f0a47168cb..7ab435954c5d5 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -167,10 +167,10 @@ def get_all_free_blocks(self) -> List[KVCacheBlock]: def generate_block_hash_extra_keys_for_mm( request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int) -> Tuple[Optional[List[Any]], int]: - """Generate extra keys related to MultiModal request for block hash computation. - For multi-modal inputs, the extra keys are (mm_hash, start_offset) that - indicate a mm input contained in the block and its starting offset in - the block tokens. + """Generate extra keys related to MultiModal request for block hash + computation. For multi-modal inputs, the extra keys are + (mm_hash, start_offset) that indicate a mm input contained in the + block and its starting offset in the block tokens. Args: request: The request object. @@ -231,6 +231,7 @@ def generate_block_hash_extra_keys_for_mm( break return extra_keys, curr_mm_idx + def generate_block_hash_extra_keys_for_lora( request: Request) -> Optional[List[int]]: """Generate extra keys related to LoRA for block hash computation. 
@@ -246,6 +247,7 @@ def generate_block_hash_extra_keys_for_lora( return None return [request.lora_request.lora_int_id] + def generate_block_hash_extra_keys( request: Request, start_token_idx: int, end_token_idx: int, start_mm_idx: int) -> Tuple[Optional[Tuple[Any, ...]], int]: @@ -262,8 +264,10 @@ def generate_block_hash_extra_keys( A tuple of extra keys and the next multi-modal index. """ mm_extra_keys: Optional[List[Any]] - mm_extra_keys, new_start_mm_idx = generate_block_hash_extra_keys_for_mm(request, start_token_idx, end_token_idx, start_mm_idx) - lora_extra_keys: Optional[List[int]] = generate_block_hash_extra_keys_for_lora(request) + mm_extra_keys, new_start_mm_idx = generate_block_hash_extra_keys_for_mm( + request, start_token_idx, end_token_idx, start_mm_idx) + lora_extra_keys: Optional[ + List[int]] = generate_block_hash_extra_keys_for_lora(request) extra_keys: List[Any] = [] if mm_extra_keys: @@ -276,6 +280,7 @@ def generate_block_hash_extra_keys( return tuple(extra_keys), new_start_mm_idx + def hash_block_tokens( parent_block_hash: Optional[int], curr_block_token_ids: Sequence[int], diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 22490bf5c550f..cb6a866c74906 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -118,7 +118,10 @@ def __init__( # The batch sizes in the config are in descending order. self.cudagraph_batch_sizes = list( reversed(self.vllm_config.compilation_config.capture_sizes)) - self.cudagraph_batch_sizes = [bs for bs in self.cudagraph_batch_sizes if bs <= self.scheduler_config.max_num_seqs] + self.cudagraph_batch_sizes = [ + bs for bs in self.cudagraph_batch_sizes + if bs <= self.scheduler_config.max_num_seqs + ] # Persistent buffers for CUDA graphs. self.input_ids = torch.zeros(self.max_num_tokens, @@ -737,7 +740,6 @@ def profile_run(self) -> None: with self.maybe_profile_with_lora(self.lora_config, num_scheduled_tokens): - print(f"running profile code ...") # Trigger compilation for general shape. hidden_states = self._dummy_run(self.model, self.max_num_tokens, dummy_kv_caches) @@ -765,8 +767,8 @@ def capture_model(self) -> None: with graph_capture(): for num_tokens in reversed(self.cudagraph_batch_sizes): print(f"running capture model for {num_tokens}...") - with self.maybe_capture_model_with_lora(self.lora_config, - num_tokens): + with self.maybe_capture_model_with_lora( + self.lora_config, num_tokens): for _ in range(self.vllm_config.compilation_config. 
cudagraph_num_of_warmups): self._dummy_run(self.model, num_tokens, self.kv_caches) diff --git a/vllm/v1/worker/lora_model_runner_mixin.py b/vllm/v1/worker/lora_model_runner_mixin.py index d7b379c85976f..d989019341043 100644 --- a/vllm/v1/worker/lora_model_runner_mixin.py +++ b/vllm/v1/worker/lora_model_runner_mixin.py @@ -3,7 +3,7 @@ """ from contextlib import contextmanager -from typing import Set, Tuple, Optional +from typing import Set, Tuple import numpy as np import torch.nn as nn @@ -127,9 +127,9 @@ def maybe_profile_with_lora(self, lora_config: LoRAConfig, # __exit__ code self.lora_manager.remove_all_adapters() - @contextmanager - def maybe_capture_model_with_lora(self, lora_config: LoRAConfig, batch_size: int): + def maybe_capture_model_with_lora(self, lora_config: LoRAConfig, + batch_size: int): if lora_config is None: yield else: @@ -139,8 +139,7 @@ def maybe_capture_model_with_lora(self, lora_config: LoRAConfig, batch_size: int prompt_lora_mapping = tuple([0] * batch_size) token_lora_mapping = tuple([0] * batch_size) - self._set_active_loras(prompt_lora_mapping, - token_lora_mapping, + self._set_active_loras(prompt_lora_mapping, token_lora_mapping, set()) yield From 83339bd875c8858d0f43d2f912d0c1682f855839 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 17:31:50 -0500 Subject: [PATCH 07/12] lora_expand opt changes --- vllm/lora/ops/v1/lora_expand.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py index 88db011e13182..61220b592be3d 100644 --- a/vllm/lora/ops/v1/lora_expand.py +++ b/vllm/lora/ops/v1/lora_expand.py @@ -107,10 +107,11 @@ def _lora_expand( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras + 1 - lora_token_start_loc: torch.Tensor, # max-loras + 2 - lora_ids: torch.Tensor, # max-loras + 1 + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + 1 + lora_token_start_loc: torch.Tensor, # max-loras + 2 + lora_ids: torch.Tensor, # max-loras + 1 add_inputs: bool = False, ) -> None: """ @@ -157,8 +158,8 @@ def _lora_expand( M = inputs.size(0) N = lora_b_weights.size(-2) K = lora_b_weights.size(-1) - BLOCK_M = 16 - BLOCK_N = 64 + BLOCK_M = 32 + BLOCK_N = 128 BLOCK_K = 16 NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA From bfc51e6720b48b02465b6a309d404315e5ec743e Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 17:35:10 -0500 Subject: [PATCH 08/12] lora_shrink opt changes --- vllm/lora/ops/v1/lora_shrink.py | 117 +++++++++++++++++++------------- 1 file changed, 69 insertions(+), 48 deletions(-) diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py index 7423abf7c7334..b1d623be5e054 100644 --- a/vllm/lora/ops/v1/lora_shrink.py +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -3,47 +3,48 @@ import triton.language as tl import math from vllm.utils import direct_register_custom_op +from vllm.lora.ops.bgmv_shrink import bgmv_shrink @triton.jit def _lora_shrink_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - token_indices_sorted_by_lora_ids, - num_tokens_per_lora, - lora_token_start_loc, - lora_ids, - scaling, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: 
tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, - NUM_M_CTAS: tl.constexpr, - NUM_N_CTAS: tl.constexpr, -): - - pid = tl.program_id(0) - lora_idx = pid // (NUM_M_CTAS * NUM_N_CTAS) - cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS - cta_m = pid % NUM_M_CTAS - cta_sk = tl.program_id(1) - + input_ptr, + lora_ptr, + out_ptr, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M : tl.constexpr, + BLOCK_N : tl.constexpr, + BLOCK_K : tl.constexpr, + EVEN_K : tl.constexpr, + SPLIT_K : tl.constexpr, + SMALL_BLOCK_M: tl.constexpr, + NUM_M_CTAS : tl.constexpr, + NUM_N_CTAS : tl.constexpr, + ): + lora_idx = tl.program_id(1) lora_id = tl.load(lora_ids + lora_idx) if lora_id == -1: # early exit for the no-lora case. return + pid = tl.program_id(0) + cta_sk = pid // (NUM_M_CTAS * NUM_N_CTAS) + cta_n = (pid // NUM_M_CTAS) % NUM_N_CTAS + cta_m = pid % NUM_M_CTAS + # lora m indices offsets lora_m_indices_start = tl.load(lora_token_start_loc + lora_idx) lora_m_size = tl.load(num_tokens_per_lora + lora_idx) @@ -83,9 +84,13 @@ def _lora_shrink_kernel( a_mask = offset_k[None, :] < K a_tile = tl.load(a_ptr, mask=a_mask, other=0.0) - # TODO (varun) : When a_tile and b_tile are float16s the output is also - # float16. this can lead to infs in the output. - acc += tl.dot(a_tile, b_tile) + # TODO (varun) : When a_tile and b_tile are float16s the output is also float16. this can + # lead to infs in the output. + if SMALL_BLOCK_M: + #acc += tl.sum(a_tile * b_tile.T) + acc += tl.sum(a_tile * b_tile.T, 1) + else: + acc += tl.dot(a_tile, b_tile) a_ptr += BLOCK_K * SPLIT_K * xk_stride b_ptr += BLOCK_K * SPLIT_K * lora_k_stride @@ -107,10 +112,11 @@ def _lora_shrink( inputs: torch.Tensor, lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras - lora_token_start_loc: torch.Tensor, # max-loras - lora_ids: torch.Tensor, # max-loras + token_lora_mapping: torch.Tensor, + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras scaling: float, ) -> None: """ @@ -132,6 +138,11 @@ def _lora_shrink( results to the output. 
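A note on the SPLIT_K dimension used by this kernel: when the K (hidden) dimension is split across SPLIT_K program instances, each instance computes a partial dot product over its own K-slice and the partials are combined with tl.atomic_add, so the partial sums only combine correctly if the accumulator starts at zero. A small sketch of the split-K idea in plain PyTorch (names are illustrative, assuming a zero-initialized float32 accumulator):

    import torch

    def split_k_matmul_reference(a: torch.Tensor,   # [M, K]
                                 b: torch.Tensor,   # [K, N]
                                 split_k: int) -> torch.Tensor:
        """Accumulate partial products over K-slices, as a split-K kernel would."""
        M, K = a.shape
        N = b.size(1)
        out = torch.zeros(M, N, dtype=torch.float32)  # must start at zero
        k_chunk = (K + split_k - 1) // split_k
        for s in range(split_k):
            lo, hi = s * k_chunk, min((s + 1) * k_chunk, K)
            # each split contributes its slice; the kernel does this with tl.atomic_add
            out += a[:, lo:hi].float() @ b[lo:hi, :].float()
        return out

    # sanity check against a plain matmul
    a = torch.randn(8, 64, dtype=torch.float16)
    b = torch.randn(64, 16, dtype=torch.float16)
    assert torch.allclose(split_k_matmul_reference(a, b, split_k=4),
                          a.float() @ b.float(), atol=1e-3)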
""" + M = inputs.size(0) # num tokens + if M <= 16: + # GemmV is better for smaller batchsizes + return bgmv_shrink(inputs, lora_a_weights, output_tensor, token_lora_mapping, scaling) + assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] assert lora_a_weights.dtype in [ @@ -161,23 +172,31 @@ def _lora_shrink( cn_stride = output_tensor.stride(1) # TODO tuning this config - M = inputs.size(0) # num tokens N = lora_a_weights.size(-2) K = lora_a_weights.size(-1) MAX_LORAS = lora_ids.size(0) + + BLOCK_M = 32 BLOCK_N = 16 - BLOCK_K = 16 - SPLIT_K = 64 + + if M < 128: + BLOCK_K = 256 + SPLIT_K = 64 + else: + BLOCK_K = 32 + SPLIT_K = 8 + + num_warps = 4 EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 - NUM_M_CTAS = math.ceil(M / BLOCK_M) # Each BLOCK_M is its own CTA - NUM_N_CTAS = math.ceil(N / BLOCK_N) + SMALL_BLOCK_M = BLOCK_M < 16 + NUM_M_CTAS = triton.cdiv(M, BLOCK_M) + NUM_N_CTAS = triton.cdiv(N, BLOCK_N) grid = ( - MAX_LORAS * NUM_M_CTAS * NUM_N_CTAS, - SPLIT_K, - ) - + SPLIT_K * NUM_M_CTAS * NUM_N_CTAS, + MAX_LORAS, + ) _lora_shrink_kernel[grid]( inputs, lora_a_weights, @@ -201,8 +220,10 @@ def _lora_shrink( BLOCK_K, EVEN_K, SPLIT_K, + SMALL_BLOCK_M, NUM_M_CTAS, NUM_N_CTAS, + num_warps=num_warps, ) return From 9b643c6018b4bc22623ad1bc0ab6e22be5f54e58 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 17:36:50 -0500 Subject: [PATCH 09/12] v1_gpu changes to pass in lora ids --- vllm/lora/punica_wrapper/v1_gpu.py | 37 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py index 34a8dcee93ffa..bee4e7925df6e 100644 --- a/vllm/lora/punica_wrapper/v1_gpu.py +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -20,20 +20,25 @@ @dataclass class V1KernelMeta: + token_lora_mapping: torch.Tensor token_indices_sorted_by_lora_ids: torch.Tensor active_lora_ids: torch.Tensor num_tokens_per_lora: torch.Tensor lora_token_start_loc: torch.Tensor @staticmethod - def make(max_loras: int, max_num_tokens: int, - device: torch.device) -> "V1KernelMeta": + def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1KernelMeta": + + token_lora_mapping = torch.empty(max_num_tokens, + dtype=torch.int32, + device=device) + token_indices_sorted_by_lora_ids = torch.empty(max_num_tokens, dtype=torch.int32, device=device) # +1 because "no-lora" is also a possibility - # example: let max_loras be 3, active_lora_ids of [-1, 0, 1, 2] + # example: let max_loras be 3, active_lora_ids of [-1, 0, 2, 1] # is a possibility. active_lora_ids = torch.empty(max_loras + 1, dtype=torch.int32, @@ -49,12 +54,12 @@ def make(max_loras: int, max_num_tokens: int, # can be [0, 3, 13, 18, 20]. 
lora_token_start_loc = torch.zeros(max_loras + 2, dtype=torch.int32, - device=device) - return V1KernelMeta( - token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, - active_lora_ids=active_lora_ids, - num_tokens_per_lora=num_tokens_per_lora, - lora_token_start_loc=lora_token_start_loc) + device=device) + return V1KernelMeta(token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids = active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc = lora_token_start_loc) def reset(self): self.active_lora_ids.fill_(-1) @@ -63,6 +68,10 @@ def reset(self): def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: num_tokens = token_lora_mapping.size(0) + + # copy token lora mapping + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, non_blocking=True) + # token_indices_sorted_by_lora_ids _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, stable=True) @@ -84,11 +93,11 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( lora_token_start_loc, non_blocking=True) - def meta_args( - self, num_tokens: int - ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return (self.token_indices_sorted_by_lora_ids[:num_tokens], - self.num_tokens_per_lora, self.lora_token_start_loc, + def meta_args(self, num_tokens: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return (self.token_lora_mapping[:num_tokens], + self.token_indices_sorted_by_lora_ids[:num_tokens], + self.num_tokens_per_lora, + self.lora_token_start_loc, self.active_lora_ids) From ac6e926611e44210d2ac688a71a7ba9fd48d4459 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 18:03:19 -0500 Subject: [PATCH 10/12] fix tests --- tests/lora/test_punica_sizes.py | 6 +++--- vllm/lora/ops/v1/lora_expand_slice.py | 1 + 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 12d5f2a857b66..716fe0a2eec32 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -440,12 +440,12 @@ def test_v1_shrink_expand( if op_type == "shrink": lora_shrink(inputs_tensor, lora_weights, our_out_tensor, - *v1_meta.meta_args, scaling) + *v1_meta.meta_args(inputs_tensor.size(0)), scaling) else: lora_expand(inputs_tensor, lora_weights, our_out_tensor, - *v1_meta.meta_args, + *v1_meta.meta_args(inputs_tensor.size(0)), add_inputs=True) ref_torch_groupgemm( @@ -520,7 +520,7 @@ def test_v1_expand_nslices( lora_expand_slice(inputs_tensor, lora_weights, our_outputs, - *v1_meta.meta_args, + *v1_meta.meta_args(inputs_tensor.size(0)), slice_offset, hidden_size, add_inputs=True) diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py index c64f95e514445..53eb8672f8eab 100644 --- a/vllm/lora/ops/v1/lora_expand_slice.py +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -109,6 +109,7 @@ def _lora_expand_slice( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, # inputs.size(0) token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) num_tokens_per_lora: torch.Tensor, # max-loras lora_token_start_loc: torch.Tensor, # max-loras From 97f21345c7e3f9ab29179c84bd6247d689dc1a1a Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 23:10:18 -0500 Subject: [PATCH 11/12] 
format --- benchmarks/benchmark_throughput.py | 4 +- tests/lora/conftest.py | 2 +- tests/lora/lora_torch_compile.py | 16 ++--- tests/lora/test_punica_sizes.py | 4 +- vllm/lora/ops/v1/lora_expand.py | 12 ++-- vllm/lora/ops/v1/lora_expand_slice.py | 6 +- vllm/lora/ops/v1/lora_expand_slices.py | 3 +- vllm/lora/ops/v1/lora_shrink.py | 80 ++++++++++----------- vllm/lora/punica_wrapper/punica_selector.py | 2 +- vllm/lora/punica_wrapper/v1_gpu.py | 29 ++++---- 10 files changed, 81 insertions(+), 77 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 10863749778fd..3e3a2ae46cf60 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -2,9 +2,9 @@ import argparse import dataclasses import json +import pickle import random import time -import pickle from functools import cache from typing import Dict, List, Optional, Tuple @@ -22,10 +22,10 @@ from vllm.lora.request import LoRARequest from vllm.lora.utils import get_adapter_absolute_path from vllm.multimodal import MultiModalDataDict +from vllm.outputs import RequestOutput from vllm.sampling_params import BeamSearchParams from vllm.transformers_utils.tokenizer import AnyTokenizer, get_lora_tokenizer from vllm.utils import FlexibleArgumentParser, merge_async_iterators -from vllm.outputs import RequestOutput SAMPLING_TEMPERATURE = 0.0 SAMPLING_TOP_P = 1.0 diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 1b58db7fa9f88..acc0f04ea5332 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -1,5 +1,6 @@ import tempfile from collections import OrderedDict +from contextlib import contextmanager from typing import Dict, List, TypedDict from unittest.mock import MagicMock, patch @@ -20,7 +21,6 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader import get_model -from contextlib import contextmanager class ContextIDInfo(TypedDict): diff --git a/tests/lora/lora_torch_compile.py b/tests/lora/lora_torch_compile.py index 2afd44e54a6ee..502da2807afed 100644 --- a/tests/lora/lora_torch_compile.py +++ b/tests/lora/lora_torch_compile.py @@ -2,26 +2,22 @@ from typing import Dict, List, Optional, Tuple import torch +from conftest import _dist_init +from utils import DummyLoRAManager from vllm.config import LoRAConfig +from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment) # yapf conflicts with isort for this block # yapf: disable -from vllm.lora.layers import (LoRAMapping, - BaseLayerWithLoRA, +from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, VocabParallelEmbeddingWithLoRA) +from vllm.lora.models import LoRALayerWeights, PackedLoRALayerWeights # yapf: enable from vllm.lora.punica_wrapper import get_punica_wrapper - from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.lora.models import (LoRALayerWeights, PackedLoRALayerWeights) - -from utils import DummyLoRAManager -from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, - init_distributed_environment) -from conftest import _dist_init - def get_random_id_to_index(num_loras: int, num_slots: int, diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 716fe0a2eec32..bcb0438e00887 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -13,12 +13,10 @@ from vllm.lora.ops.sgmv_expand import 
sgmv_expand from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink - from vllm.lora.ops.v1.lora_expand import lora_expand -from vllm.lora.ops.v1.lora_shrink import lora_shrink from vllm.lora.ops.v1.lora_expand_slice import lora_expand_slice +from vllm.lora.ops.v1.lora_shrink import lora_shrink from vllm.lora.punica_wrapper.v1_gpu import V1KernelMeta - from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py index 61220b592be3d..9b7e0c955967e 100644 --- a/vllm/lora/ops/v1/lora_expand.py +++ b/vllm/lora/ops/v1/lora_expand.py @@ -1,7 +1,9 @@ +import math + import torch import triton import triton.language as tl -import math + from vllm.utils import direct_register_custom_op @@ -108,10 +110,10 @@ def _lora_expand( lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras + 1 - lora_token_start_loc: torch.Tensor, # max-loras + 2 - lora_ids: torch.Tensor, # max-loras + 1 + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + 1 + lora_token_start_loc: torch.Tensor, # max-loras + 2 + lora_ids: torch.Tensor, # max-loras + 1 add_inputs: bool = False, ) -> None: """ diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py index 53eb8672f8eab..3618de5da4c7d 100644 --- a/vllm/lora/ops/v1/lora_expand_slice.py +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -1,7 +1,9 @@ +import math + import torch import triton import triton.language as tl -import math + from vllm.utils import direct_register_custom_op @@ -109,7 +111,7 @@ def _lora_expand_slice( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, - token_lora_mapping: torch.Tensor, # inputs.size(0) + token_lora_mapping: torch.Tensor, # inputs.size(0) token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) num_tokens_per_lora: torch.Tensor, # max-loras lora_token_start_loc: torch.Tensor, # max-loras diff --git a/vllm/lora/ops/v1/lora_expand_slices.py b/vllm/lora/ops/v1/lora_expand_slices.py index e38ed5e8e2390..e82c7ece71fd8 100644 --- a/vllm/lora/ops/v1/lora_expand_slices.py +++ b/vllm/lora/ops/v1/lora_expand_slices.py @@ -1,7 +1,8 @@ +import math + import torch import triton import triton.language as tl -import math from vllm.utils import direct_register_custom_op diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py index b1d623be5e054..af52ce74bea09 100644 --- a/vllm/lora/ops/v1/lora_shrink.py +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -1,39 +1,39 @@ import torch import triton import triton.language as tl -import math -from vllm.utils import direct_register_custom_op + from vllm.lora.ops.bgmv_shrink import bgmv_shrink +from vllm.utils import direct_register_custom_op @triton.jit def _lora_shrink_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - token_indices_sorted_by_lora_ids, - num_tokens_per_lora, - lora_token_start_loc, - lora_ids, - scaling, - xm_stride, - xk_stride, - l0_stride, - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_M : tl.constexpr, - BLOCK_N : tl.constexpr, - BLOCK_K : tl.constexpr, - EVEN_K : tl.constexpr, - SPLIT_K : tl.constexpr, - SMALL_BLOCK_M: tl.constexpr, - NUM_M_CTAS : tl.constexpr, - NUM_N_CTAS : 
tl.constexpr, - ): + input_ptr, + lora_ptr, + out_ptr, + N, + K, + token_indices_sorted_by_lora_ids, + num_tokens_per_lora, + lora_token_start_loc, + lora_ids, + scaling, + xm_stride, + xk_stride, + l0_stride, + lora_k_stride, + lora_n_stride, + cm_stride, + cn_stride, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SMALL_BLOCK_M: tl.constexpr, + NUM_M_CTAS: tl.constexpr, + NUM_N_CTAS: tl.constexpr, +): lora_idx = tl.program_id(1) lora_id = tl.load(lora_ids + lora_idx) if lora_id == -1: @@ -84,8 +84,8 @@ def _lora_shrink_kernel( a_mask = offset_k[None, :] < K a_tile = tl.load(a_ptr, mask=a_mask, other=0.0) - # TODO (varun) : When a_tile and b_tile are float16s the output is also float16. this can - # lead to infs in the output. + # TODO (varun) : When a_tile and b_tile are float16s the output is + # also float16. this can lead to infs in the output. if SMALL_BLOCK_M: #acc += tl.sum(a_tile * b_tile.T) acc += tl.sum(a_tile * b_tile.T, 1) @@ -113,10 +113,10 @@ def _lora_shrink( lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, token_lora_mapping: torch.Tensor, - token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) - num_tokens_per_lora: torch.Tensor, # max-loras - lora_token_start_loc: torch.Tensor, # max-loras - lora_ids: torch.Tensor, # max-loras + token_indices_sorted_by_lora_ids: torch.Tensor, # inputs.size(0) + num_tokens_per_lora: torch.Tensor, # max-loras + lora_token_start_loc: torch.Tensor, # max-loras + lora_ids: torch.Tensor, # max-loras scaling: float, ) -> None: """ @@ -138,10 +138,11 @@ def _lora_shrink( results to the output. """ - M = inputs.size(0) # num tokens + M = inputs.size(0) # num tokens if M <= 16: # GemmV is better for smaller batchsizes - return bgmv_shrink(inputs, lora_a_weights, output_tensor, token_lora_mapping, scaling) + return bgmv_shrink(inputs, lora_a_weights, output_tensor, + token_lora_mapping, scaling) assert inputs.dtype == lora_a_weights.dtype assert inputs.dtype in [torch.float16, torch.bfloat16] @@ -176,7 +177,6 @@ def _lora_shrink( K = lora_a_weights.size(-1) MAX_LORAS = lora_ids.size(0) - BLOCK_M = 32 BLOCK_N = 16 @@ -194,9 +194,9 @@ def _lora_shrink( NUM_N_CTAS = triton.cdiv(N, BLOCK_N) grid = ( - SPLIT_K * NUM_M_CTAS * NUM_N_CTAS, - MAX_LORAS, - ) + SPLIT_K * NUM_M_CTAS * NUM_N_CTAS, + MAX_LORAS, + ) _lora_shrink_kernel[grid]( inputs, lora_a_weights, diff --git a/vllm/lora/punica_wrapper/punica_selector.py b/vllm/lora/punica_wrapper/punica_selector.py index 7f4d092179b1b..8c102a15b638d 100644 --- a/vllm/lora/punica_wrapper/punica_selector.py +++ b/vllm/lora/punica_wrapper/punica_selector.py @@ -1,6 +1,6 @@ +import vllm.envs as envs from vllm.platforms import current_platform from vllm.utils import print_info_once -import vllm.envs as envs from .punica_base import PunicaWrapperBase diff --git a/vllm/lora/punica_wrapper/v1_gpu.py b/vllm/lora/punica_wrapper/v1_gpu.py index bee4e7925df6e..9e02afc61b3ac 100644 --- a/vllm/lora/punica_wrapper/v1_gpu.py +++ b/vllm/lora/punica_wrapper/v1_gpu.py @@ -1,5 +1,5 @@ -from typing import TYPE_CHECKING, Optional, Tuple, Union, final, List from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple, Union, final import torch @@ -27,7 +27,8 @@ class V1KernelMeta: lora_token_start_loc: torch.Tensor @staticmethod - def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1KernelMeta": + def make(max_loras: int, max_num_tokens: int, + device: torch.device) -> "V1KernelMeta": 
token_lora_mapping = torch.empty(max_num_tokens, dtype=torch.int32, @@ -54,12 +55,13 @@ def make(max_loras: int, max_num_tokens: int, device: torch.device) -> "V1Kernel # can be [0, 3, 13, 18, 20]. lora_token_start_loc = torch.zeros(max_loras + 2, dtype=torch.int32, - device=device) - return V1KernelMeta(token_lora_mapping=token_lora_mapping, - token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, - active_lora_ids = active_lora_ids, - num_tokens_per_lora=num_tokens_per_lora, - lora_token_start_loc = lora_token_start_loc) + device=device) + return V1KernelMeta( + token_lora_mapping=token_lora_mapping, + token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, + active_lora_ids=active_lora_ids, + num_tokens_per_lora=num_tokens_per_lora, + lora_token_start_loc=lora_token_start_loc) def reset(self): self.active_lora_ids.fill_(-1) @@ -70,7 +72,8 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: num_tokens = token_lora_mapping.size(0) # copy token lora mapping - self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, non_blocking=True) + self.token_lora_mapping[:num_tokens].copy_(token_lora_mapping, + non_blocking=True) # token_indices_sorted_by_lora_ids _, token_indices_sorted_by_lora_ids = torch.sort(token_lora_mapping, @@ -93,11 +96,13 @@ def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: self.lora_token_start_loc[1:1 + lora_token_start_loc.size(0)].copy_( lora_token_start_loc, non_blocking=True) - def meta_args(self, num_tokens: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + def meta_args( + self, num_tokens: int + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, + torch.Tensor]: return (self.token_lora_mapping[:num_tokens], self.token_indices_sorted_by_lora_ids[:num_tokens], - self.num_tokens_per_lora, - self.lora_token_start_loc, + self.num_tokens_per_lora, self.lora_token_start_loc, self.active_lora_ids) From 8ff67c5a19dc51e9bcb0b4d1bf21899233b926df Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sun, 29 Dec 2024 23:24:08 -0500 Subject: [PATCH 12/12] fix fake functions --- vllm/lora/ops/v1/lora_expand.py | 1 + vllm/lora/ops/v1/lora_expand_slice.py | 2 +- vllm/lora/ops/v1/lora_shrink.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/lora/ops/v1/lora_expand.py b/vllm/lora/ops/v1/lora_expand.py index 9b7e0c955967e..8b143c2eead5b 100644 --- a/vllm/lora/ops/v1/lora_expand.py +++ b/vllm/lora/ops/v1/lora_expand.py @@ -219,6 +219,7 @@ def lora_expand_fake( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, num_tokens_per_lora: torch.Tensor, lora_token_start_loc: torch.Tensor, diff --git a/vllm/lora/ops/v1/lora_expand_slice.py b/vllm/lora/ops/v1/lora_expand_slice.py index 3618de5da4c7d..b89d46512a4e7 100644 --- a/vllm/lora/ops/v1/lora_expand_slice.py +++ b/vllm/lora/ops/v1/lora_expand_slice.py @@ -222,11 +222,11 @@ def _lora_expand_slice( ) return - def lora_expand_slice_fake( inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, num_tokens_per_lora: torch.Tensor, lora_token_start_loc: torch.Tensor, diff --git a/vllm/lora/ops/v1/lora_shrink.py b/vllm/lora/ops/v1/lora_shrink.py index af52ce74bea09..bc3ecf036f9b9 100644 --- a/vllm/lora/ops/v1/lora_shrink.py +++ b/vllm/lora/ops/v1/lora_shrink.py @@ -227,11 +227,11 @@ def 
_lora_shrink( ) return - def lora_shrink_fake( inputs: torch.Tensor, lora_a_weights: torch.Tensor, output_tensor: torch.Tensor, + token_lora_mapping: torch.Tensor, token_indices_sorted_by_lora_ids: torch.Tensor, num_tokens_per_lora: torch.Tensor, lora_token_start_loc: torch.Tensor,