Merge branch 'vllm-project:main' into v0.2.4-rocm
iAmir97 authored Dec 14, 2023
2 parents 9d09d96 + 05bdf4e commit 24fc0b2
Showing 16 changed files with 398 additions and 455 deletions.
14 changes: 1 addition & 13 deletions Dockerfile
@@ -41,14 +41,6 @@ ENV NVCC_THREADS=$nvcc_threads

RUN python3 setup.py build_ext --inplace

# Build the megablocks library as wheel because it doesn't publish pre-built wheels.
# https://github.com/stanford-futuredata/megablocks/commit/5897cd6f254b7b3edf7a708a3a3314ecb54b6f78
RUN apt-get install -y git && \
git clone https://github.com/stanford-futuredata/megablocks.git && \
cd megablocks && \
git checkout 5897cd6f254b7b3edf7a708a3a3314ecb54b6f78 && \
MAX_JOBS=8 NVCC_THREADS=8 python3 setup.py bdist_wheel

# image to run unit testing suite
FROM dev AS test

@@ -85,12 +77,8 @@ FROM vllm-base AS vllm-openai
RUN --mount=type=cache,target=/root/.cache/pip \
pip install accelerate

COPY vllm vllm
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY --from=build /workspace/megablocks/dist/*.whl /tmp/
RUN --mount=type=cache,target=/root/.cache/pip \
pip install /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl && \
rm /tmp/megablocks-0.5.0-cp310-cp310-linux_x86_64.whl
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

4 changes: 0 additions & 4 deletions README.md
@@ -72,10 +72,6 @@ Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/get
```bash
pip install vllm
```
**NOTE:** The Mixtral model additionally requires `megablocks` which can be installed with pip or [from source](https://github.com/stanford-futuredata/megablocks):
```bash
pip install megablocks
```

## Getting Started

7 changes: 4 additions & 3 deletions docs/source/getting_started/installation.rst
@@ -20,7 +20,7 @@ You can install vLLM using pip:

.. code-block:: console

    $ # (Optional) Create a new conda environment.
    $ conda create -n myenv python=3.8 -y
    $ conda create -n myenv python=3.9 -y
    $ conda activate myenv

    $ # Install vLLM with CUDA 12.1.
@@ -34,8 +34,9 @@ You can install vLLM using pip:

.. code-block:: console

    $ # Install vLLM with CUDA 11.8.
    $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`).
    $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl
    $ export VLLM_VERSION=0.2.4
    $ export PYTHON_VERSION=39
    $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl

    $ # Re-install PyTorch with CUDA 11.8.
    $ pip uninstall torch -y
12 changes: 9 additions & 3 deletions docs/source/models/supported_models.rst
@@ -73,6 +73,9 @@ If your model uses one of the above model architectures, you can seamlessly run
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.

.. note::
    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.

.. tip::
    The easiest way to check if your model is supported is to run the program below:

@@ -84,18 +87,21 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr

        output = llm.generate("Hello, my name is")
        print(output)

To use model from www.modelscope.cn
If vLLM successfully generates text, it indicates that your model is supported.

.. tip::
    To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:

    .. code-block:: shell

        $ export VLLM_USE_MODELSCOPE=True

    And use with :code:`trust_remote_code=True`.

    .. code-block:: python

        from vllm import LLM

        llm = LLM(model=..., revision=..., trust_remote_code=True)  # Name or path of your model
        output = llm.generate("Hello, my name is")
        print(output)

If vLLM successfully generates text, it indicates that your model is supported.
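
The ROCm note above caps Mistral and Mixtral context lengths at 4096 tokens. Below is a minimal sketch of staying under that cap when constructing the engine; the model name and the `max_model_len` argument are my assumptions, not something this commit adds.

```python
from vllm import LLM, SamplingParams

# Hypothetical usage: cap the context length so Mistral stays within the
# 4096 tokens supported by the ROCm build (max_model_len is assumed to be
# accepted by this vLLM version).
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1", max_model_len=4096)
outputs = llm.generate("Hello, my name is", SamplingParams(max_tokens=64))
print(outputs[0].outputs[0].text)
```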
1 change: 1 addition & 0 deletions requirements-dev.txt
@@ -1,5 +1,6 @@
# formatting
yapf==0.32.0
toml==0.10.2
ruff==0.1.5

# type checking
34 changes: 23 additions & 11 deletions tests/conftest.py
@@ -1,3 +1,4 @@
import os
from typing import List, Optional, Tuple

import pytest
@@ -7,21 +8,32 @@
from vllm import LLM, SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

_TEST_PROMPTS = [
    "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
    "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
    "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
    "Describe the basic components of a neural network and how it can be trained.",
    "Write a short story about a robot that dreams for the first time.",
    "Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
    "Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
    "Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
]
_TEST_PROMPTS = ["prompts/example.txt"]
_LONG_PROMPTS = ["prompts/summary.txt"]


def _read_prompts(filename: str) -> List[str]:
    # Read every prompt in the file, one prompt per line.
    with open(filename, "r") as f:
        prompts = f.readlines()
    return prompts


@pytest.fixture
def example_prompts() -> List[str]:
    return _TEST_PROMPTS
    prompts = []
    for filename in _TEST_PROMPTS:
        prompts += _read_prompts(os.path.join("tests", filename))
    return prompts


@pytest.fixture
def example_long_prompts() -> List[str]:
    prompts = []
    for filename in _LONG_PROMPTS:
        prompts += _read_prompts(os.path.join("tests", filename))
    return prompts


_STR_DTYPE_TO_TORCH_DTYPE = {
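
A hypothetical smoke test, not part of this commit, showing how the new file-backed `example_prompts` fixture from conftest.py might be consumed; the model name is an illustrative choice.

```python
from vllm import LLM, SamplingParams


def test_example_prompts_smoke(example_prompts):
    # example_prompts is now read from tests/prompts/example.txt by conftest.py.
    llm = LLM(model="facebook/opt-125m")  # small model, illustrative only
    outputs = llm.generate(example_prompts, SamplingParams(max_tokens=16))
    assert len(outputs) == len(example_prompts)
```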
37 changes: 37 additions & 0 deletions tests/models/test_mistral.py
@@ -0,0 +1,37 @@
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
Run `pytest tests/models/test_mistral.py --forked`.
"""
import pytest

MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128])
def test_models(
    hf_runner,
    vllm_runner,
    example_long_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    hf_model = hf_runner(model, dtype=dtype)
    hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
    del hf_model

    vllm_model = vllm_runner(model, dtype=dtype)
    vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
    del vllm_model

    for i in range(len(example_long_prompts)):
        hf_output_ids, hf_output_str = hf_outputs[i]
        vllm_output_ids, vllm_output_str = vllm_outputs[i]
        assert hf_output_str == vllm_output_str, (
            f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
        assert hf_output_ids == vllm_output_ids, (
            f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
8 changes: 8 additions & 0 deletions tests/prompts/example.txt
@@ -0,0 +1,8 @@
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
Describe the basic components of a neural network and how it can be trained.
Write a short story about a robot that dreams for the first time.
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
1 change: 1 addition & 0 deletions tests/prompts/summary.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion vllm/__init__.py
@@ -8,7 +8,7 @@
from vllm.outputs import CompletionOutput, RequestOutput
from vllm.sampling_params import SamplingParams

__version__ = "0.2.4"
__version__ = "0.2.5"

__all__ = [
    "LLM",
16 changes: 9 additions & 7 deletions vllm/config.py
@@ -120,14 +120,16 @@ def _verify_load_format(self) -> None:
        if load_format == "auto":
            load_format = "pt"

        # FIXME(woosuk): This is a temporary hack. Support safetensor weights.
        # TODO: Remove this check once HF updates the pt weights of Mixtral.
        architectures = getattr(self.hf_config, "architectures", [])
        if "MixtralForCausalLM" in architectures and load_format != "pt":
            logger.info(
                "Currently, only 'pt' format is supported for Mixtral. "
                "Changing the format to 'pt'. This may re-download the "
                "weights if you have downloaded the safetensor weights.")
            load_format = "pt"
        if "MixtralForCausalLM" in architectures:
            if load_format == "pt":
                raise ValueError(
                    "Currently, the 'pt' format is not supported for Mixtral. "
                    "Please use the 'safetensors' format instead. ")
            elif load_format == "auto":
                # Do not fall back to pt weights.
                load_format = "safetensors"

        self.load_format = load_format

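
A standalone sketch of the Mixtral branch added above; it mirrors the diff's logic without importing `vllm.config`, and the function name is mine, not the project's.

```python
def resolve_mixtral_load_format(load_format: str) -> str:
    # Mirrors the new check: 'pt' is rejected outright, and 'auto' resolves
    # to safetensors instead of falling back to pt weights.
    if load_format == "pt":
        raise ValueError(
            "Currently, the 'pt' format is not supported for Mixtral. "
            "Please use the 'safetensors' format instead.")
    if load_format == "auto":
        return "safetensors"
    return load_format


print(resolve_mixtral_load_format("auto"))         # -> safetensors
print(resolve_mixtral_load_format("safetensors"))  # -> safetensors
```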
13 changes: 7 additions & 6 deletions vllm/engine/async_llm_engine.py
@@ -2,7 +2,7 @@
import time
from functools import partial
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                    Union)
                    Union, AsyncIterator)

from vllm.config import ModelConfig
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -401,11 +401,12 @@ async def add_request(
        return stream

    async def generate(
            self,
            prompt: Optional[str],
            sampling_params: SamplingParams,
            request_id: str,
            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
        self,
        prompt: Optional[str],
        sampling_params: SamplingParams,
        request_id: str,
        prompt_token_ids: Optional[List[int]] = None
    ) -> AsyncIterator[RequestOutput]:
        """Generate outputs for a request.

        Generate outputs for a request. This method is a coroutine. It adds the
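
Since `generate` now returns an `AsyncIterator[RequestOutput]`, callers consume it with `async for`. Below is a minimal consumption sketch, assuming an `AsyncLLMEngine` instance named `engine` that is built elsewhere; only the signature and return type come from this diff.

```python
import asyncio

from vllm import SamplingParams


async def stream_one_request(engine) -> None:
    # `engine` is assumed to be an AsyncLLMEngine, e.g. created via
    # AsyncLLMEngine.from_engine_args(...); only the consumption pattern of
    # the new AsyncIterator return type is shown here.
    results = engine.generate("Hello, my name is",
                              SamplingParams(max_tokens=32),
                              request_id="req-0")
    async for request_output in results:
        # Each item is a RequestOutput containing the tokens generated so far.
        print(request_output.outputs[0].text)

# asyncio.run(stream_one_request(engine))  # requires a constructed engine
```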
62 changes: 5 additions & 57 deletions vllm/model_executor/model_loader.py
@@ -7,54 +7,9 @@
from transformers import PretrainedConfig

from vllm.config import ModelConfig
from vllm.model_executor.models import *
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.weight_utils import (get_quant_config,
                                              initialize_dummy_weights)
from vllm.utils import is_hip
from vllm.logger import init_logger

logger = init_logger(__name__)

# TODO(woosuk): Lazy-load the model classes.
_MODEL_REGISTRY = {
    "AquilaModel": AquilaForCausalLM,
    "AquilaForCausalLM": AquilaForCausalLM,  # AquilaChat2
    "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
    "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
    "BloomForCausalLM": BloomForCausalLM,
    "ChatGLMModel": ChatGLMForCausalLM,
    "ChatGLMForConditionalGeneration": ChatGLMForCausalLM,
    "FalconForCausalLM": FalconForCausalLM,
    "GPT2LMHeadModel": GPT2LMHeadModel,
    "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
    "GPTJForCausalLM": GPTJForCausalLM,
    "GPTNeoXForCausalLM": GPTNeoXForCausalLM,
    "InternLMForCausalLM": InternLMForCausalLM,
    "LlamaForCausalLM": LlamaForCausalLM,
    "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
    "MistralForCausalLM": MistralForCausalLM,
    "MixtralForCausalLM": MixtralForCausalLM,
    # transformers's mpt class has lower case
    "MptForCausalLM": MPTForCausalLM,
    "MPTForCausalLM": MPTForCausalLM,
    "OPTForCausalLM": OPTForCausalLM,
    "PhiForCausalLM": PhiForCausalLM,
    "QWenLMHeadModel": QWenLMHeadModel,
    "RWForCausalLM": FalconForCausalLM,
    "YiForCausalLM": YiForCausalLM,
}

# Models to be disabled in ROCm
_ROCM_UNSUPPORTED_MODELS = []
if is_hip():
    for rocm_model in _ROCM_UNSUPPORTED_MODELS:
        del _MODEL_REGISTRY[rocm_model]

# Models partially supported in ROCm
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
    "MistralForCausalLM":
    "Sliding window attention is not supported in ROCm's flash attention",
}


@contextlib.contextmanager
@@ -69,19 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
    architectures = getattr(config, "architectures", [])
    for arch in architectures:
        if arch in _MODEL_REGISTRY:
            if is_hip() and arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
                logger.warning(
                    f"{arch} is not fully supported in ROCm. Reason: "
                    f"{_ROCM_PARTIALLY_SUPPORTED_MODELS[arch]}")
            return _MODEL_REGISTRY[arch]
        elif arch in _ROCM_UNSUPPORTED_MODELS:
            raise ValueError(
                f"Model architecture {arch} is not supported by ROCm for now. \n"
                f"Supported architectures {list(_MODEL_REGISTRY.keys())}")
        model_cls = ModelRegistry.load_model_cls(arch)
        if model_cls is not None:
            return model_cls
    raise ValueError(
        f"Model architectures {architectures} are not supported for now. "
        f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
        f"Supported architectures: {ModelRegistry.get_supported_archs()}")


def get_model(model_config: ModelConfig) -> nn.Module:
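
With the hard-coded `_MODEL_REGISTRY` gone, architecture lookup goes through `ModelRegistry`, which this diff imports from `vllm.model_executor.models`. A brief sketch of querying it the same way the new `_get_model_architecture` does:

```python
from vllm.model_executor.models import ModelRegistry

# List the architectures the registry knows about, as the new error message does.
print(ModelRegistry.get_supported_archs())

# Resolve one architecture name to its model class; None means unsupported,
# which _get_model_architecture turns into a ValueError.
model_cls = ModelRegistry.load_model_cls("MistralForCausalLM")
print(model_cls)
```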
