Merge pull request #3 from geniusrise/feat/vllm
Add VLLM support for LM and CHAT
ixaxaar authored Feb 24, 2024
2 parents a965a86 + 7be87a7 commit 573049b
Showing 8 changed files with 1,005 additions and 40 deletions.
142 changes: 112 additions & 30 deletions geniusrise_text/base/api.py
@@ -27,16 +27,6 @@
sequential_lock = threading.Lock()


def sequential_tool():
    with sequential_lock:
        # Yield to signal that the request can proceed
        yield


# Register the custom tool
cherrypy.tools.sequential = cherrypy.Tool("before_handler", sequential_tool)


class TextAPI(TextBulk):
    """
    A class representing a Hugging Face API for generating text using a pre-trained language model.
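Note on the removed tool: a CherryPy `Tool` simply calls the callable it wraps, and calling a generator function only creates a generator object; the body never executes until the generator is iterated. The old `sequential_tool` therefore never actually acquired the lock, which is presumably why this PR replaces it with the explicit `sequential_locker`/`sequential_unlocker` pair further down. A minimal sketch of the pitfall (illustrative, not from the PR):

```python
# Calling a generator function does NOT run its body, so a
# before_handler generator never takes the lock it wraps.
import threading

lock = threading.Lock()

def sequential_tool():
    with lock:
        yield  # body only runs once the generator is iterated

g = sequential_tool()   # nothing has executed yet
print(lock.locked())    # False: the lock was never acquired
next(g)                 # only now does "with lock:" run
print(lock.locked())    # True
```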
@@ -148,6 +138,7 @@ def validate_password(self, realm, username, password):
    def listen(
        self,
        model_name: str,
        # Huggingface params
        model_class: str = "AutoModelForCausalLM",
        tokenizer_class: str = "AutoTokenizer",
        use_cuda: bool = False,
@@ -159,6 +150,35 @@ def listen(
        compile: bool = False,
        awq_enabled: bool = False,
        flash_attention: bool = False,
        concurrent_queries: bool = False,
        use_vllm: bool = False,
        # VLLM params
        vllm_tokenizer_mode: str = "auto",
        vllm_download_dir: Optional[str] = None,
        vllm_load_format: str = "auto",
        vllm_seed: int = 42,
        vllm_max_model_len: int = 1024,
        vllm_enforce_eager: bool = False,
        vllm_max_context_len_to_capture: int = 8192,
        vllm_block_size: int = 16,
        vllm_gpu_memory_utilization: float = 0.90,
        vllm_swap_space: int = 4,
        vllm_sliding_window: Optional[int] = None,
        vllm_pipeline_parallel_size: int = 1,
        vllm_tensor_parallel_size: int = 1,
        vllm_worker_use_ray: bool = False,
        vllm_max_parallel_loading_workers: Optional[int] = None,
        vllm_disable_custom_all_reduce: bool = False,
        vllm_max_num_batched_tokens: Optional[int] = None,
        vllm_max_num_seqs: int = 64,
        vllm_max_paddings: int = 512,
        vllm_max_lora_rank: Optional[int] = None,
        vllm_max_loras: Optional[int] = None,
        vllm_max_cpu_loras: Optional[int] = None,
        vllm_lora_extra_vocab_size: int = 0,
        vllm_placement_group: Optional[dict] = None,
        vllm_log_stats: bool = False,
        # Server params
        endpoint: str = "*",
        port: int = 3000,
        cors_domain: str = "http://localhost:3000",
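A hedged usage sketch of the new flags (assuming an already-constructed `TextAPI` instance; `precision` and several other parameters sit in an elided part of the signature, and every value below is illustrative, not from the PR):

```python
# Illustrative only: `api` is assumed to be a constructed TextAPI;
# the model id and tuning values are placeholders.
api.listen(
    model_name="mistralai/Mistral-7B-Instruct-v0.2",
    use_cuda=True,
    precision="float16",               # mapped via _get_torch_dtype below
    use_vllm=True,                     # take the load_models_vllm branch
    vllm_max_model_len=4096,
    vllm_gpu_memory_utilization=0.90,
    concurrent_queries=True,           # serialize requests on one GPU
    port=3000,
)
```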
@@ -182,6 +202,7 @@ def listen(
            compile (bool, optional): Whether to compile the model before inference. Defaults to False.
            awq_enabled (bool): Whether to use AWQ for model optimization. Default is False.
            flash_attention (bool): Whether to use flash attention 2. Default is False.
            concurrent_queries (bool): Whether to serialize inference requests so that only one runs at a time (usually needed on a single-GPU system). Default is False.
            endpoint (str, optional): The endpoint to listen on. Defaults to "*".
            port (int, optional): The port to listen on. Defaults to 3000.
            cors_domain (str, optional): The domain to allow CORS requests from. Defaults to "http://localhost:3000".
@@ -198,8 +219,11 @@ def listen(
        self.device_map = device_map
        self.max_memory = max_memory
        self.torchscript = torchscript
        self.flash_attention = flash_attention
        self.awq_enabled = awq_enabled
        self.flash_attention = flash_attention
        self.use_vllm = use_vllm
        self.concurrent_queries = concurrent_queries

        self.model_args = model_args
        self.username = username
        self.password = password
@@ -219,24 +243,72 @@ def listen(
        self.tokenizer_name = tokenizer_name
        self.tokenizer_revision = tokenizer_revision

        self.model, self.tokenizer = self.load_models(
            model_name=self.model_name,
            tokenizer_name=self.tokenizer_name,
            model_revision=self.model_revision,
            tokenizer_revision=self.tokenizer_revision,
            model_class=self.model_class,
            tokenizer_class=self.tokenizer_class,
            use_cuda=self.use_cuda,
            precision=self.precision,
            quantization=self.quantization,
            device_map=self.device_map,
            max_memory=self.max_memory,
            torchscript=self.torchscript,
            awq_enabled=self.awq_enabled,
            flash_attention=self.flash_attention,
            compile=compile,
            **self.model_args,
        )
        if use_vllm:
            self.model = self.load_models_vllm(
                model=model_name,
                tokenizer=tokenizer_name,
                tokenizer_mode=vllm_tokenizer_mode,
                trust_remote_code=True,
                download_dir=vllm_download_dir,
                load_format=vllm_load_format,
                dtype=self._get_torch_dtype(precision),
                seed=vllm_seed,
                revision=model_revision,
                tokenizer_revision=tokenizer_revision,
                max_model_len=vllm_max_model_len,
                quantization=(None if quantization == 0 else f"{quantization}-bit"),
                enforce_eager=vllm_enforce_eager,
                max_context_len_to_capture=vllm_max_context_len_to_capture,
                block_size=vllm_block_size,
                gpu_memory_utilization=vllm_gpu_memory_utilization,
                swap_space=vllm_swap_space,
                cache_dtype="auto",
                sliding_window=vllm_sliding_window,
                pipeline_parallel_size=vllm_pipeline_parallel_size,
                tensor_parallel_size=vllm_tensor_parallel_size,
                worker_use_ray=vllm_worker_use_ray,
                max_parallel_loading_workers=vllm_max_parallel_loading_workers,
                disable_custom_all_reduce=vllm_disable_custom_all_reduce,
                max_num_batched_tokens=vllm_max_num_batched_tokens,
                max_num_seqs=vllm_max_num_seqs,
                max_paddings=vllm_max_paddings,
                device="cuda" if use_cuda else "cpu",
                max_lora_rank=vllm_max_lora_rank,
                max_loras=vllm_max_loras,
                max_cpu_loras=vllm_max_cpu_loras,
                lora_dtype=self._get_torch_dtype(precision),
                lora_extra_vocab_size=vllm_lora_extra_vocab_size,
                placement_group=vllm_placement_group,  # type: ignore
                log_stats=vllm_log_stats,
                batched_inference=False,
            )
        else:
            self.model, self.tokenizer = self.load_models(
                model_name=self.model_name,
                tokenizer_name=self.tokenizer_name,
                model_revision=self.model_revision,
                tokenizer_revision=self.tokenizer_revision,
                model_class=self.model_class,
                tokenizer_class=self.tokenizer_class,
                use_cuda=self.use_cuda,
                precision=self.precision,
                quantization=self.quantization,
                device_map=self.device_map,
                max_memory=self.max_memory,
                torchscript=self.torchscript,
                awq_enabled=self.awq_enabled,
                flash_attention=self.flash_attention,
                compile=compile,
                **self.model_args,
            )
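Two details in the vLLM branch are worth spelling out. `_get_torch_dtype` is defined elsewhere in the codebase, not in this diff; a plausible sketch of what it does, stated purely as an assumption:

```python
# Assumed behavior of _get_torch_dtype (the real helper is not shown
# in this PR): map the string `precision` to a torch dtype.
import torch

def _get_torch_dtype(precision: str) -> torch.dtype:
    return {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }.get(precision.lower(), torch.float32)
```

The `quantization` mapping may also deserve a second look: vLLM's `quantization` argument normally expects a method name such as "awq" or "gptq" rather than a bit-width string, so the `f"{quantization}-bit"` value passed here is likely a placeholder.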

        def sequential_locker():
            if self.concurrent_queries:
                sequential_lock.acquire()

        def sequential_unlocker():
            if self.concurrent_queries:
                sequential_lock.release()
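This acquire/release pair relies on CherryPy running the `before_finalize` hook after every request; if a handler raises, the release hook may not run and the lock could in principle be left held. (Note also that the flag name reads inverted: `concurrent_queries=True` is what switches serialization on.) A more defensive sketch, not what the PR does, wraps the handler body itself:

```python
# Exception-safe alternative: the context manager releases the lock on
# success and on error alike, instead of pairing two separate hooks.
import threading

sequential_lock = threading.Lock()

def serialized(handler):
    def wrapper(*args, **kwargs):
        with sequential_lock:
            return handler(*args, **kwargs)
    return wrapper
```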

        def CORS():
            cherrypy.response.headers["Access-Control-Allow-Origin"] = cors_domain
@@ -277,6 +349,8 @@ def CORS():
            # Configure basic authentication
            conf = {
                "/": {
                    "tools.sequential_locker.on": True,
                    "tools.sequential_unlocker.on": True,
                    "tools.auth_basic.on": True,
                    "tools.auth_basic.realm": "geniusrise",
                    "tools.auth_basic.checkpassword": self.validate_password,
@@ -285,11 +359,19 @@
            }
        else:
            # Configuration without authentication
            conf = {"/": {"tools.CORS.on": True}}
            conf = {
                "/": {
                    "tools.sequential_locker.on": True,
                    "tools.sequential_unlocker.on": True,
                    "tools.CORS.on": True,
                }
            }
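Both branches enable the new tools through `tools.<name>.on` config keys, which CherryPy resolves against attributes registered on `cherrypy.tools` (done just below). A minimal, self-contained illustration of that pattern, not taken from the PR:

```python
# Register a callable as a CherryPy tool, then enable it per mount
# point with "tools.<name>.on" in the config dict.
import cherrypy

def stamp():
    cherrypy.response.headers["X-Stamp"] = "geniusrise"

cherrypy.tools.stamp = cherrypy.Tool("before_finalize", stamp)

class Hello:
    @cherrypy.expose
    def index(self):
        return "hello"

if __name__ == "__main__":
    cherrypy.quickstart(Hello(), "/", {"/": {"tools.stamp.on": True}})
```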

        cherrypy.tools.sequential_locker = cherrypy.Tool("before_handler", sequential_locker)
        cherrypy.tools.CORS = cherrypy.Tool("before_handler", CORS)
        cherrypy.tree.mount(self, "/api/v1/", conf)
        cherrypy.tools.CORS = cherrypy.Tool("before_finalize", CORS)
        cherrypy.tools.sequential_unlocker = cherrypy.Tool("before_finalize", sequential_unlocker)
        cherrypy.engine.start()
        cherrypy.engine.block()
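Once `cherrypy.engine.start()` is up, the API is mounted at `/api/v1/`. A hypothetical client call (the route past `/api/v1/` is not visible in this diff and is an assumption):

```python
# Hypothetical request against the mounted API; the "complete" route
# name and JSON body are assumptions, only the /api/v1/ prefix is shown.
import requests

resp = requests.post(
    "http://localhost:3000/api/v1/complete",
    json={"prompt": "Hello", "max_new_tokens": 64},
    auth=("user", "password"),  # only needed when auth is configured
    timeout=300,
)
print(resp.json())
```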

[Diffs for the remaining 7 changed files are not rendered in this capture.]