Commit 0dcd3c1

Merge branch 'vllm-project:main' into image_bench-on-refactor

lk-chen authored Nov 5, 2024
2 parents 06c36be + cd34029 commit 0dcd3c1
Showing 23 changed files with 456 additions and 443 deletions.
7 changes: 0 additions & 7 deletions Dockerfile.tpu
@@ -9,12 +9,6 @@ RUN apt-get update && apt-get install -y \
     git \
     ffmpeg libsm6 libxext6 libgl1
 
-# Install the TPU and Pallas dependencies.
-RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
-RUN --mount=type=cache,target=/root/.cache/pip \
-    python3 -m pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
 # Build vLLM.
 COPY . .
 ARG GIT_REPO_CHECK=0
@@ -25,7 +19,6 @@ ENV VLLM_TARGET_DEVICE="tpu"
 RUN --mount=type=cache,target=/root/.cache/pip \
     --mount=type=bind,source=.git,target=.git \
     python3 -m pip install \
-        'cmake>=3.26' ninja packaging 'setuptools-scm>=8' wheel jinja2 \
         -r requirements-tpu.txt
 RUN python3 setup.py develop
 
57 changes: 5 additions & 52 deletions docs/source/getting_started/tpu-installation.rst
@@ -119,27 +119,19 @@ Uninstall the existing `torch` and `torch_xla` packages:
 
    pip uninstall torch torch-xla -y
 
-Install `torch` and `torch_xla`
+Install build dependencies:
 
 .. code-block:: bash
 
-   pip install --pre torch==2.6.0.dev20241028+cpu torchvision==0.20.0.dev20241028+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
-   pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
+   pip install -r requirements-tpu.txt
+   sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
 
-Install JAX and Pallas:
+Run the setup script:
 
 .. code-block:: bash
 
-   pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-   pip install jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
-Install other build dependencies:
-
-.. code-block:: bash
-
-   pip install -r requirements-tpu.txt
-   VLLM_TARGET_DEVICE="tpu" python setup.py develop
-   sudo apt-get install libopenblas-base libopenmpi-dev libomp-dev
+   VLLM_TARGET_DEVICE="tpu" python setup.py develop
 
 Provision Cloud TPUs with GKE
 -----------------------------
@@ -168,45 +160,6 @@ Run the Docker image with the following command:
    $ # Make sure to add `--privileged --net host --shm-size=16G`.
    $ docker run --privileged --net host --shm-size=16G -it vllm-tpu
 
-.. _build_from_source_tpu:
-
-Build from source
------------------
-
-You can also build and install the TPU backend from source.
-
-First, install the dependencies:
-
-.. code-block:: console
-
-   $ # (Recommended) Create a new conda environment.
-   $ conda create -n myenv python=3.10 -y
-   $ conda activate myenv
-
-   $ # Clean up the existing torch and torch-xla packages.
-   $ pip uninstall torch torch-xla -y
-
-   $ # Install PyTorch and PyTorch XLA.
-   $ export DATE="20241017"
-   $ export TORCH_VERSION="2.6.0"
-   $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
-   $ pip install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-${TORCH_VERSION}.dev${DATE}-cp310-cp310-linux_x86_64.whl
-
-   $ # Install JAX and Pallas.
-   $ pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
-   $ pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
-
-   $ # Install other build dependencies.
-   $ pip install -r requirements-tpu.txt
-
-Next, build vLLM from source. This will only take a few seconds:
-
-.. code-block:: console
-
-   $ VLLM_TARGET_DEVICE="tpu" python setup.py develop
-
 .. note::
 
    Since TPU relies on XLA which requires static shapes, vLLM bucketizes the possible input shapes and compiles an XLA graph for each different shape.
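The note retained above is the key constraint behind the TPU backend's compilation strategy. As a rough illustration only (the bucket sizes and padding value below are assumptions for this sketch, not vLLM's actual configuration), shape bucketing for a static-shape compiler such as XLA amounts to padding every input up to the nearest of a small, fixed set of lengths:

# Minimal sketch of shape bucketing for a static-shape compiler such as XLA.
# The bucket sizes and pad token are illustrative, not vLLM's real settings.
from typing import List

BUCKETS = [16, 32, 64, 128, 256, 512, 1024, 2048]  # assumed padded lengths
PAD_TOKEN_ID = 0


def pick_bucket(length: int) -> int:
    """Return the smallest bucket that fits `length`."""
    for size in BUCKETS:
        if length <= size:
            return size
    raise ValueError(f"sequence of length {length} exceeds the largest bucket")


def pad_to_bucket(token_ids: List[int]) -> List[int]:
    """Pad token IDs so the compiler only ever sees len(BUCKETS) shapes."""
    target = pick_bucket(len(token_ids))
    return token_ids + [PAD_TOKEN_ID] * (target - len(token_ids))


# A 100-token prompt is padded to the 128 bucket, so the graph compiled for
# shape (128,) can be reused for any prompt of length 65-128.
print(len(pad_to_bucket(list(range(100)))))  # -> 128

With a handful of buckets, only a handful of graphs are ever compiled per kernel, no matter how many distinct prompt lengths arrive.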
20 changes: 18 additions & 2 deletions requirements-tpu.txt
@@ -2,6 +2,22 @@
 -r requirements-common.txt
 
 # Dependencies for TPU
-# Currently, the TPU backend uses a nightly version of PyTorch XLA.
-# You can install the dependencies in Dockerfile.tpu.
+cmake>=3.26
+ninja
+packaging
+setuptools-scm>=8
+wheel
+jinja2
 ray[default]
+
+# Install torch_xla
+--pre
+--extra-index-url https://download.pytorch.org/whl/nightly/cpu
+--find-links https://storage.googleapis.com/libtpu-releases/index.html
+--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
+--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
+torch==2.6.0.dev20241028+cpu
+torchvision==0.20.0.dev20241028+cpu
+torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl
+jaxlib==0.4.32.dev20240829
+jax==0.4.32.dev20240829
57 changes: 26 additions & 31 deletions tests/core/utils.py
@@ -4,6 +4,7 @@
 from typing import Tuple
 
 from vllm import SamplingParams
+from vllm.inputs import EncoderDecoderInputs, token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Logprob, Sequence, SequenceGroup
 
@@ -27,10 +28,7 @@ def create_dummy_prompt(
     prompt_tokens = list(range(prompt_length))
     prompt_str = " ".join([str(t) for t in prompt_tokens])
     prompt = Sequence(int(request_id),
-                      inputs={
-                          "prompt": prompt_str,
-                          "prompt_token_ids": prompt_tokens,
-                      },
+                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
                       block_size=block_size)
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[prompt],
@@ -63,23 +61,21 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_tokens = list(reversed(list(range(encoder_prompt_length))))
     encoder_prompt_str = " ".join([str(t) for t in encoder_prompt_tokens])
 
-    inputs = {
-        "prompt": decoder_prompt_str,
-        "prompt_token_ids": decoder_prompt_tokens,
-        "encoder_prompt": encoder_prompt_str,
-        "encoder_prompt_token_ids": encoder_prompt_tokens,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(decoder_prompt_tokens,
+                                prompt=decoder_prompt_str),
+        "encoder": token_inputs(encoder_prompt_tokens,
+                                prompt=encoder_prompt_str),
     }
 
     decoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=True)
+                              inputs=inputs["decoder"],
+                              block_size=block_size)
 
     encoder_prompt = Sequence(int(request_id),
-                              inputs=inputs,
-                              block_size=block_size,
-                              from_decoder_prompt=False)
+                              inputs=inputs["encoder"],
+                              block_size=block_size)
 
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
                               sampling_params=SamplingParams(best_of=best_of),
@@ -108,7 +104,7 @@ def create_seq_group(
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         seq = Sequence(
             seq_id=seq_id_start + seq_id_offset,
-            inputs={"prompt_token_ids": prompt_token_ids},
+            inputs=token_inputs(prompt_token_ids),
             block_size=16,
         )
 
@@ -143,21 +139,19 @@ def create_seq_group_encoder_decoder(
 
     prompt_token_ids = [0] * seq_prompt_len
 
-    inputs = {
-        "prompt": "",
-        "prompt_token_ids": prompt_token_ids,
-        "encoder_prompt": "",
-        "encoder_prompt_token_ids": prompt_token_ids,
-        "multi_modal_data": None,
+    inputs: EncoderDecoderInputs = {
+        "decoder": token_inputs(prompt_token_ids),
+        "encoder": token_inputs(prompt_token_ids),
     }
 
     seqs = []
     for seq_id_offset, output_len in enumerate(seq_output_lens):
         # Construct decoder input sequences
-        seq = Sequence(seq_id=seq_id_start + seq_id_offset,
-                       inputs=inputs,
-                       block_size=16,
-                       from_decoder_prompt=True)
+        seq = Sequence(
+            seq_id=seq_id_start + seq_id_offset,
+            inputs=inputs["decoder"],
+            block_size=16,
+        )
 
         for i in range(output_len):
             seq.append_token_id(
@@ -167,10 +161,11 @@
         seqs.append(seq)
 
     # Encoder input sequence
-    encoder_seq = Sequence(seq_id=seq_id_start + len(seq_output_lens),
-                           inputs=inputs,
-                           block_size=16,
-                           from_decoder_prompt=False)
+    encoder_seq = Sequence(
+        seq_id=seq_id_start + len(seq_output_lens),
+        inputs=inputs["encoder"],
+        block_size=16,
+    )
 
     return SequenceGroup(request_id=request_id,
                          seqs=seqs,
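The recurring change in these tests replaces ad-hoc input dicts with the `token_inputs` helper from `vllm.inputs`. As a simplified sketch of the idea only (the real helper carries additional optional fields, and this is not its actual implementation):

# Simplified sketch of what a token_inputs-style helper returns; the real
# vllm.inputs.token_inputs has more optional fields (multi-modal data, etc.).
from typing import List, Optional, TypedDict


class TokenInputs(TypedDict, total=False):
    prompt_token_ids: List[int]  # decoded-from or raw token IDs
    prompt: Optional[str]        # original text, if known


def token_inputs(prompt_token_ids: List[int],
                 prompt: Optional[str] = None) -> TokenInputs:
    """Bundle token IDs (and optionally the source text) for a Sequence."""
    inputs = TokenInputs(prompt_token_ids=prompt_token_ids)
    if prompt is not None:
        inputs["prompt"] = prompt
    return inputs


# Usage mirroring the updated tests: decoder and encoder prompts become
# separate entries instead of one flat dict with encoder_* keys.
decoder = token_inputs([1, 2, 3], prompt="1 2 3")
encoder = token_inputs([3, 2, 1], prompt="3 2 1")
enc_dec = {"decoder": decoder, "encoder": encoder}

Splitting encoder and decoder prompts into two such entries is what lets the updated tests drop the old `from_decoder_prompt` flag on `Sequence`.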
3 changes: 2 additions & 1 deletion tests/engine/output_processor/test_stop_checker.py
@@ -4,6 +4,7 @@
 from transformers import PreTrainedTokenizer
 
 from vllm.engine.output_processor.stop_checker import StopChecker
+from vllm.inputs import token_inputs
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import Logprob, Sequence, SequenceStatus
 
@@ -15,7 +16,7 @@ def sequence_with_eos(text: str, eos_token: str,
     """
     seq = Sequence(
         seq_id=0,
-        inputs={"prompt_token_ids": []},
+        inputs=token_inputs([]),
         block_size=16,
         eos_token_id=eos_token_id,
     )
7 changes: 3 additions & 4 deletions tests/test_cache_block_hashing.py
@@ -6,6 +6,7 @@
 
 import pytest
 
+from vllm.inputs import token_inputs
 from vllm.lora.request import LoRARequest
 from vllm.sequence import Sequence
 from vllm.transformers_utils.tokenizer_group import TokenizerGroup
@@ -70,10 +71,8 @@ def test_auto_prefix_caching(model: str, block_size: int, max_num_seqs: int,
             hashes[-1].append([])
             prompt_token_ids = tokenizer.encode(prompt)
             seq = Sequence(seq_id,
-                           inputs={
-                               "prompt": prompt,
-                               "prompt_token_ids": prompt_token_ids,
-                           },
+                           inputs=token_inputs(prompt_token_ids,
+                                               prompt=prompt),
                            block_size=block_size,
                            eos_token_id=tokenizer.tokenizer.eos_token_id,
                            lora_request=lora_request)
6 changes: 2 additions & 4 deletions tests/tokenization/test_detokenize.py
@@ -3,6 +3,7 @@
 import pytest
 from transformers import AutoTokenizer
 
+from vllm.inputs import token_inputs
 from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup
 from vllm.transformers_utils.detokenizer import (Detokenizer,
                                                  detokenize_incrementally)
@@ -169,10 +170,7 @@ def create_sequence(prompt_token_ids=None):
     prompt_token_ids = prompt_token_ids or [1]
     return Sequence(
         seq_id=0,
-        inputs={
-            "prompt": "<s>",
-            "prompt_token_ids": prompt_token_ids,
-        },
+        inputs=token_inputs(prompt_token_ids, prompt="<s>"),
         block_size=16,
     )
 
15 changes: 5 additions & 10 deletions vllm/distributed/device_communicators/shm_broadcast.py
@@ -1,3 +1,4 @@
+import os
 import pickle
 import time
 from contextlib import contextmanager
@@ -18,12 +19,6 @@
 
 VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
 
-# time to wait if the queue is full or empty
-# if we sleep for too short, it will consume too much CPU
-# if we sleep for too long, it will slow down the writer/reader
-# 0.1 us is a good balance
-RINGBUFFER_SLEEP_INTERVAL = 1e-7
-
 logger = init_logger(__name__)
 
 
@@ -333,8 +328,8 @@ def acquire_write(self):
                 # if this block is not ready to write,
                 # we need to wait until it is read by all readers
 
-                # wait for a while
-                time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+                # Release the processor to other threads
+                os.sched_yield()
 
                 # if we wait for a long time, we should warn the user
                 if (time.monotonic() - start_time >
@@ -387,8 +382,8 @@ def acquire_read(self):
                 # if this block is not ready,
                 # we need to wait until it is written
 
-                # wait for a while
-                time.sleep(RINGBUFFER_SLEEP_INTERVAL)
+                # Release the processor to other threads
+                os.sched_yield()
 
                 # if we wait for a long time, we should warn the user
                 if (time.monotonic() - start_time >
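The shm_broadcast change swaps a fixed 0.1 µs sleep for `os.sched_yield()` in the ring buffer's busy-wait loops, letting the OS scheduler decide when the waiter resumes instead of guessing a sleep interval. A self-contained sketch of that polling pattern, using an illustrative event flag and warning interval rather than vLLM's real `MessageQueue` internals (and assuming a POSIX system where `os.sched_yield` is available), might look like:

# Minimal sketch of a yield-based busy-wait, mirroring the acquire_write/
# acquire_read loops above. `is_ready` and WARNING_INTERVAL are illustrative
# stand-ins, not vLLM's real MessageQueue API.
import os
import threading
import time

WARNING_INTERVAL = 60  # seconds, analogous to VLLM_RINGBUFFER_WARNING_INTERVAL


def wait_until_ready(is_ready: threading.Event) -> None:
    start_time = time.monotonic()
    n_warning = 1
    while not is_ready.is_set():
        # Yield the processor to other threads instead of sleeping for a
        # fixed interval: the waiter resumes as soon as the scheduler comes
        # back to it, rather than after an arbitrary sleep quantum.
        os.sched_yield()

        # If we wait for a long time, warn the user (same idea as the
        # ring buffer's warning above).
        if time.monotonic() - start_time > WARNING_INTERVAL * n_warning:
            print(f"still waiting after {WARNING_INTERVAL * n_warning} s")
            n_warning += 1


if __name__ == "__main__":
    flag = threading.Event()
    threading.Timer(0.1, flag.set).start()  # writer "publishes" after 100 ms
    wait_until_ready(flag)                  # reader spins, yielding each pass
    print("block became ready")

Yielding keeps the loop responsive without the CPU cost of spinning flat out or the added latency of a hard-coded sleep.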