
add vllm to lm
ixaxaar committed Feb 24, 2024
1 parent 09f51b8 commit 7be87a7
Showing 3 changed files with 203 additions and 5 deletions.
4 changes: 2 additions & 2 deletions geniusrise_text/base/api.py
@@ -303,11 +303,11 @@ def listen(
        )

        def sequential_locker():
-           if not self.concurrent_queries:
+           if self.concurrent_queries:
                sequential_lock.acquire()

        def sequential_unlocker():
-           if not self.concurrent_queries:
+           if self.concurrent_queries:
                sequential_lock.release()

        def CORS():
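
For context, a minimal sketch of the serialization pattern these helpers implement, assuming `sequential_lock` is a `threading.Lock` owned by the listener and `run_model` is a hypothetical stand-in for the real handler:

import threading

sequential_lock = threading.Lock()

def run_model(payload):
    # hypothetical model call standing in for the real pipeline
    return {"echo": payload}

def handle_request(payload, serialize: bool = True):
    # Take the lock only when requests must run one at a time,
    # mirroring the concurrent_queries check in the helpers above.
    if serialize:
        sequential_lock.acquire()
    try:
        return run_model(payload)
    finally:
        if serialize:
            sequential_lock.release()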
12 changes: 10 additions & 2 deletions geniusrise_text/classification/api.py
@@ -17,6 +17,7 @@
from typing import Any, Dict

import cherrypy
+import numpy as np
import torch
from geniusrise import BatchInput, BatchOutput, State
from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline
@@ -123,8 +124,15 @@ def classify(self) -> Dict[str, Any]:
        logits = outputs.logits if hasattr(outputs, "logits") else outputs[0]
        if next(self.model.parameters()).is_cuda:
            logits = logits.cpu()
-       softmax = torch.nn.functional.softmax(logits, dim=-1)
-       scores = softmax.numpy().tolist()  # Convert scores to list
+
+       # Handling a single number output
+       if logits.numel() == 1:
+           logits = outputs.logits.cpu().detach().numpy()
+           scores = 1 / (1 + np.exp(-logits)).flatten()
+           return {"input": text, "label_scores": scores.tolist()}
+       else:
+           softmax = torch.nn.functional.softmax(logits, dim=-1)
+           scores = softmax.numpy().tolist()

        id_to_label = dict(enumerate(self.model.config.id2label.values()))  # type: ignore
        label_scores = {id_to_label[label_id]: score for label_id, score in enumerate(scores[0])}
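
For reference, a minimal, model-free sketch of the two scoring paths added above (the helper name `scores_from_logits` and the dummy tensors are illustrative, not part of the codebase):

import numpy as np
import torch

def scores_from_logits(logits: torch.Tensor) -> list:
    # Single-logit heads (e.g. a one-output binary/regression-style head): squash with a sigmoid.
    if logits.numel() == 1:
        arr = logits.detach().cpu().numpy()
        return (1 / (1 + np.exp(-arr))).flatten().tolist()
    # Multi-class heads: softmax over the class dimension, as before.
    return torch.nn.functional.softmax(logits, dim=-1).tolist()

print(scores_from_logits(torch.tensor([[0.3]])))       # roughly [0.574]
print(scores_from_logits(torch.tensor([[1.0, 2.0]])))  # roughly [[0.269, 0.731]]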
192 changes: 191 additions & 1 deletion geniusrise_text/language_model/bulk.py
@@ -27,6 +27,7 @@
from geniusrise import BatchInput, BatchOutput, State
from pyarrow import feather
from pyarrow import parquet as pq
+from vllm import LLM, SamplingParams

from geniusrise_text.base import TextBulk

@@ -164,7 +165,8 @@ def load_dataset(self, dataset_path: str, max_length: int = 512, **kwargs) -> Op

        self.max_length = max_length

-       self.label_to_id = self.model.config.label2id if self.model and self.model.config.label2id else {}  # type: ignore
+       if hasattr(self, "tokenizer") and self.tokenizer is not None:
+           self.label_to_id = self.model.config.label2id if self.model and self.model.config.label2id else {}  # type: ignore

        try:
            self.log.info(f"Loading dataset from {dataset_path}")
@@ -349,6 +351,194 @@ def complete(
        self._save_completions(completions, prompts, output_path)
        self.done()

    def complete_vllm(
        self,
        model_name: str,
        use_cuda: bool = False,
        precision: str = "float16",
        quantization: int = 0,
        device_map: str | Dict | None = "auto",
        # VLLM params
        vllm_tokenizer_mode: str = "auto",
        vllm_download_dir: Optional[str] = None,
        vllm_load_format: str = "auto",
        vllm_seed: int = 42,
        vllm_max_model_len: int = 1024,
        vllm_enforce_eager: bool = False,
        vllm_max_context_len_to_capture: int = 8192,
        vllm_block_size: int = 16,
        vllm_gpu_memory_utilization: float = 0.90,
        vllm_swap_space: int = 4,
        vllm_sliding_window: Optional[int] = None,
        vllm_pipeline_parallel_size: int = 1,
        vllm_tensor_parallel_size: int = 1,
        vllm_worker_use_ray: bool = False,
        vllm_max_parallel_loading_workers: Optional[int] = None,
        vllm_disable_custom_all_reduce: bool = False,
        vllm_max_num_batched_tokens: Optional[int] = None,
        vllm_max_num_seqs: int = 64,
        vllm_max_paddings: int = 512,
        vllm_max_lora_rank: Optional[int] = None,
        vllm_max_loras: Optional[int] = None,
        vllm_max_cpu_loras: Optional[int] = None,
        vllm_lora_extra_vocab_size: int = 0,
        vllm_placement_group: Optional[dict] = None,
        vllm_log_stats: bool = False,
        # Generate params
        notification_email: Optional[str] = None,
        batch_size: int = 32,
        **kwargs: Any,
    ) -> None:
"""
Performs bulk text generation using the Versatile Language Learning Model (VLLM) with specified parameters
for fine-tuning model behavior, including quantization and parallel processing settings. This method is designed
to process large datasets efficiently by leveraging VLLM capabilities for generating high-quality text completions
based on provided prompts.
Args:
model_name (str): The name or path of the VLLM model to use for text generation.
use_cuda (bool): Flag indicating whether to use CUDA for GPU acceleration.
precision (str): Precision of computations, can be "float16", "bfloat16", etc.
quantization (int): Level of quantization for model weights, 0 for none.
device_map (str | Dict | None): Specific device(s) to use for model inference.
vllm_tokenizer_mode (str): Mode of the tokenizer ("auto", "fast", or "slow").
vllm_download_dir (Optional[str]): Directory to download and load the model and tokenizer.
vllm_load_format (str): Format to load the model, e.g., "auto", "pt".
vllm_seed (int): Seed for random number generation.
vllm_max_model_len (int): Maximum sequence length the model can handle.
vllm_enforce_eager (bool): Enforce eager execution instead of using optimization techniques.
vllm_max_context_len_to_capture (int): Maximum context length for CUDA graph capture.
vllm_block_size (int): Block size for caching mechanism.
vllm_gpu_memory_utilization (float): Fraction of GPU memory to use.
vllm_swap_space (int): Amount of swap space to use in GiB.
vllm_sliding_window (Optional[int]): Size of the sliding window for processing.
vllm_pipeline_parallel_size (int): Number of pipeline parallel groups.
vllm_tensor_parallel_size (int): Number of tensor parallel groups.
vllm_worker_use_ray (bool): Whether to use Ray for model workers.
vllm_max_parallel_loading_workers (Optional[int]): Maximum number of workers for parallel loading.
vllm_disable_custom_all_reduce (bool): Disable custom all-reduce kernel and fall back to NCCL.
vllm_max_num_batched_tokens (Optional[int]): Maximum number of tokens to be processed in a single iteration.
vllm_max_num_seqs (int): Maximum number of sequences to be processed in a single iteration.
vllm_max_paddings (int): Maximum number of paddings to be added to a batch.
vllm_max_lora_rank (Optional[int]): Maximum rank for LoRA adjustments.
vllm_max_loras (Optional[int]): Maximum number of LoRA adjustments.
vllm_max_cpu_loras (Optional[int]): Maximum number of LoRA adjustments stored on CPU.
vllm_lora_extra_vocab_size (int): Additional vocabulary size for LoRA.
vllm_placement_group (Optional[dict]): Ray placement group for distributed execution.
vllm_log_stats (bool): Whether to log statistics during model operation.
notification_email (Optional[str]): Email to send notifications upon completion.
batch_size (int): Number of prompts to process in each batch for efficient memory usage.
**kwargs: Additional keyword arguments for generation settings like temperature, top_p, etc.
This method automates the loading of large datasets, generation of text completions, and saving results,
facilitating efficient and scalable text generation tasks.
"""
if ":" in model_name:
model_revision = model_name.split(":")[1]
tokenizer_revision = model_name.split(":")[1]
model_name = model_name.split(":")[0]
tokenizer_name = model_name
else:
model_revision = None
tokenizer_revision = None
tokenizer_name = model_name

self.model_name = model_name
self.tokenizer_name = tokenizer_name
self.model_revision = model_revision
self.tokenizer_revision = tokenizer_revision
self.use_cuda = use_cuda
self.precision = precision
self.quantization = quantization
self.device_map = device_map
self.notification_email = notification_email

        self.model: LLM = self.load_models_vllm(
            model=model_name,
            tokenizer=tokenizer_name,
            tokenizer_mode=vllm_tokenizer_mode,
            trust_remote_code=True,
            download_dir=vllm_download_dir,
            load_format=vllm_load_format,
            dtype=self._get_torch_dtype(precision),
            seed=vllm_seed,
            revision=model_revision,
            tokenizer_revision=tokenizer_revision,
            max_model_len=vllm_max_model_len,
            quantization=(None if quantization == 0 else f"{quantization}-bit"),
            enforce_eager=vllm_enforce_eager,
            max_context_len_to_capture=vllm_max_context_len_to_capture,
            block_size=vllm_block_size,
            gpu_memory_utilization=vllm_gpu_memory_utilization,
            swap_space=vllm_swap_space,
            cache_dtype="auto",
            sliding_window=vllm_sliding_window,
            pipeline_parallel_size=vllm_pipeline_parallel_size,
            tensor_parallel_size=vllm_tensor_parallel_size,
            worker_use_ray=vllm_worker_use_ray,
            max_parallel_loading_workers=vllm_max_parallel_loading_workers,
            disable_custom_all_reduce=vllm_disable_custom_all_reduce,
            max_num_batched_tokens=vllm_max_num_batched_tokens,
            max_num_seqs=vllm_max_num_seqs,
            max_paddings=vllm_max_paddings,
            device="cuda" if use_cuda else "cpu",
            max_lora_rank=vllm_max_lora_rank,
            max_loras=vllm_max_loras,
            max_cpu_loras=vllm_max_cpu_loras,
            lora_dtype=self._get_torch_dtype(precision),
            lora_extra_vocab_size=vllm_lora_extra_vocab_size,
            placement_group=vllm_placement_group,  # type: ignore
            log_stats=vllm_log_stats,
            batched_inference=True,
        )

        generation_args = {k.replace("generation_", ""): v for k, v in kwargs.items() if "generation_" in k}
        self.generation_args = generation_args
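        # Illustrative example: generation_temperature=0.7 and generation_max_tokens=256
        # yield generation_args == {"temperature": 0.7, "max_tokens": 256}, consumed by SamplingParams below.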

        dataset_path = self.input.input_folder
        output_path = self.output.output_folder

        # Load dataset
        _dataset = self.load_dataset(dataset_path)
        if _dataset is None:
            self.log.error("Failed to load dataset.")
            return
        dataset = _dataset["text"]

        for i in range(0, len(dataset), batch_size):
            batch = dataset[i : i + batch_size]

            outputs = self.model.generate(
                prompts=batch,
                sampling_params=SamplingParams(
                    n=generation_args.get("n", 1),
                    best_of=generation_args.get("best_of", None),
                    presence_penalty=generation_args.get("presence_penalty", 0.0),
                    frequency_penalty=generation_args.get("frequency_penalty", 0.0),
                    repetition_penalty=generation_args.get("repetition_penalty", 1.0),
                    temperature=generation_args.get("temperature", 1.0),
                    top_p=generation_args.get("top_p", 1.0),
                    top_k=generation_args.get("top_k", -1),
                    min_p=generation_args.get("min_p", 0.0),
                    use_beam_search=generation_args.get("use_beam_search", False),
                    length_penalty=generation_args.get("length_penalty", 1.0),
                    early_stopping=generation_args.get("early_stopping", False),
                    stop=generation_args.get("stop", None),
                    stop_token_ids=generation_args.get("stop_token_ids", None),
                    include_stop_str_in_output=generation_args.get("include_stop_str_in_output", False),
                    ignore_eos=generation_args.get("ignore_eos", False),
                    max_tokens=generation_args.get("max_tokens", 16),
                    logprobs=generation_args.get("logprobs", None),
                    prompt_logprobs=generation_args.get("prompt_logprobs", None),
                    skip_special_tokens=generation_args.get("skip_special_tokens", True),
                    spaces_between_special_tokens=generation_args.get("spaces_between_special_tokens", True),
                    logits_processors=generation_args.get("logits_processors", None),
                ),
            )
            completions = [" ".join(t.text for t in o.outputs) for o in outputs]
            self._save_completions(completions, batch, output_path)
        self.done()

    def _save_completions(self, completions: List[str], prompts: List[str], output_path: str) -> None:
        """
        Saves the generated completions to the specified output path.
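A minimal usage sketch of the new bulk path (hedged: `lm` stands for an already-constructed bulk language-model instance from this module, and the model id and generation values below are placeholders):

# lm is an instance of the bulk language-model class defined in bulk.py,
# already wired to its BatchInput/BatchOutput/State (construction omitted).
lm.complete_vllm(
    model_name="mistralai/Mistral-7B-v0.1:main",  # placeholder "name:revision" identifier
    use_cuda=True,
    precision="float16",
    vllm_max_model_len=2048,
    batch_size=32,
    # generation_* kwargs lose their prefix and feed SamplingParams
    generation_temperature=0.7,
    generation_top_p=0.95,
    generation_max_tokens=256,
)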
