Commit

formatting

robertgshaw2-redhat committed Nov 17, 2024
1 parent 25fff99 commit 5a87b99

Showing 5 changed files with 86 additions and 91 deletions.
2 changes: 1 addition & 1 deletion vllm/v1/attention/backends/pallas.py
@@ -6,7 +6,7 @@
import torch_xla

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionType)


class PallasAttentionBackend(AttentionBackend):
2 changes: 1 addition & 1 deletion vllm/v1/engine/async_llm.py
@@ -9,8 +9,8 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer import AnyTokenizer
2 changes: 1 addition & 1 deletion vllm/v1/engine/llm_engine.py
@@ -8,9 +8,9 @@
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.platforms import current_platform
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.usage.usage_lib import UsageContext
163 changes: 79 additions & 84 deletions vllm/v1/worker/tpu_model_runner.py
@@ -1,7 +1,6 @@
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from unittest.mock import patch

import numpy as np
import torch
@@ -15,8 +14,7 @@
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalDataDict
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cdiv,
is_pin_memory_available)
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_pin_memory_available
from vllm.v1.attention.backends.pallas import (PallasAttentionBackend,
PallasAttentionMetadata)
from vllm.v1.outputs import ModelRunnerOutput
@@ -33,7 +31,8 @@


@dataclass
class PrefillData:
class PrefillInputData:

request_ids: List
prompt_lens: List
token_ids: List
@@ -46,11 +45,12 @@ def zipped(self):


@dataclass
class DecodeData:
class DecodeInputData:

num_decodes: int
token_ids: torch.Tensor
position_ids: torch.Tensor
attn_metadata: PallasAttentionMetadata
token_ids: Optional[torch.Tensor] = None
position_ids: Optional[torch.Tensor] = None
attn_metadata: PallasAttentionMetadata = None
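# Illustrative usage of the Optional defaults above (this is the empty case
# that _prepare_decode_inputs returns when nothing is scheduled for decode):
# only num_decodes is required, and callers check it before touching tensors.
_empty_decode_example = DecodeInputData(num_decodes=0)
assert _empty_decode_example.token_ids is None
assert _empty_decode_example.attn_metadata is None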


class TPUModelRunner:
@@ -200,33 +200,13 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
req_state = self.requests[req_id]
self.input_batch.add_request(req_state, None)

def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0

num_reqs = self.input_batch.num_reqs
num_decodes = self.input_batch.num_decodes
num_prefills = self.input_batch.num_prefills

assert num_decodes + num_prefills > 0

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)

# Assert Decodes Are Decodes.
if idx < num_decodes:
assert num_tokens == 1

######################### PREFILLS #########################
# Prefills run separately, each with shape [1, padded_prompt_len],
# due to lack of variable length flashattention.
#
# Due to static shapes, prefills are padded to the nearest power
# of two, such that we can avoid recompilation.
def _prepare_prefill_inputs(
self,
num_scheduled_tokens: List[int],
) -> PrefillInputData:
# Prefills run separately, each with shape [1, prompt_len],
# due to lack of variable length flashattention, so we
# create a list that will be used in execute_model()

prefill_request_ids = []
prefill_prompt_lens = []
@@ -236,6 +216,8 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# DECODES are the first num_decodes REQUESTS.
# PREFILLS are the next num_reqs - num_decodes REQUESTS.
num_reqs = self.input_batch.num_reqs
num_decodes = self.input_batch.num_decodes
for idx in range(num_decodes, num_reqs):
prefill_request_ids.append(self.input_batch.req_ids[idx])

@@ -246,19 +228,17 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
assert padded_prompt_len <= self.max_model_len

# TOKEN_IDS.
prefill_token_ids.append(
torch.from_numpy(
self.input_batch.token_ids_cpu[idx:idx +
1, :padded_prompt_len]).to(
self.device))
token_ids = torch.from_numpy(self.input_batch.token_ids_cpu[
idx, :padded_prompt_len].reshape(-1, 1))
prefill_token_ids.append(token_ids.to(self.device))

# POSITIONS.
positions = self.prefill_positions[:, :padded_prompt_len]
prefill_position_ids.append(positions.to(self.device))

# SLOT_MAPPING.
# The "slot" is the "physical index" of a token in the KV cache.
# We look up the block_idx in the block table (logical <> physical map)
# Look up the block_idx in the block table (logical<>physical map)
# to compute this.
block_numbers = self.input_batch.block_table_cpu_tensor[
idx, positions // self.block_size].reshape(1, -1)
@@ -278,25 +258,25 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
context_lens=None,
))

prefill_data = PrefillData(
return PrefillInputData(
request_ids=prefill_request_ids,
prompt_lens=prefill_prompt_lens,
token_ids=prefill_token_ids,
position_ids=prefill_position_ids,
attn_metadata=prefill_attn_metadata,
)

if num_decodes == 0:
return prefill_data, None
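# Sketch of the prompt-length padding relied on above (the helper name and the
# minimum of 16 are assumptions for illustration; the original comment states
# that prefills are padded to the nearest power of two to avoid recompilation).
def _get_padded_prompt_len_sketch(prompt_len: int, min_len: int = 16) -> int:
    # Round up to the next power of two so repeated prefills of similar
    # lengths reuse the same compiled graph.
    padded = min_len
    while padded < prompt_len:
        padded *= 2
    return padded

# e.g. 1 -> 16, 100 -> 128, 128 -> 128; the loop above still asserts
# padded_prompt_len <= self.max_model_len.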

######################### DECODES #########################
def _prepare_decode_inputs(self, num_decodes: int) -> DecodeInputData:
# Decodes run as one single padded batch with shape [batch, 1]
#
# We need to set _PAD_SLOT_ID for the padding tokens in the
# slot_mapping, such that the attention KV cache insertion
# logic knows to ignore those indices. Otherwise, the
# padding data can be dummy since we have a causal mask.

if num_decodes == 0:
return DecodeInputData(num_decodes=0)

# PAD FOR STATIC SHAPES.
padded_batch_size = _get_padded_batch_size(num_decodes)

@@ -316,7 +296,7 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):

# SLOT_MAPPING [batch, 1]
# The "slot" is the "physical index" of a token in the KV cache.
# We look up the block_idx in the block table (logical <> physical map)
# Look up the block_idx in the block table (logical<>physical map)
# to compute this.
block_number = torch.gather(
input=self.input_batch.block_table_cpu_tensor,
@@ -337,17 +317,41 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
context_lens = (positions.reshape(-1) + 1)

# CPU<>TPU sync happens here.
decode_data = DecodeData(num_decodes=num_decodes,
token_ids=token_ids.to(self.device),
position_ids=positions.to(self.device),
attn_metadata=PallasAttentionMetadata(
is_prompt=False,
slot_mapping=slot_mapping.to(self.device),
block_tables=block_table.to(self.device),
context_lens=context_lens.to(self.device),
))

return prefill_data, decode_data
return DecodeInputData(num_decodes=num_decodes,
token_ids=token_ids.to(self.device),
position_ids=positions.to(self.device),
attn_metadata=PallasAttentionMetadata(
is_prompt=False,
slot_mapping=slot_mapping.to(self.device),
block_tables=block_table.to(self.device),
context_lens=context_lens.to(self.device),
))
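# Illustrative sketch of the slot arithmetic behind both slot-mapping blocks
# above: a token at position p maps to logical block p // block_size, the block
# table turns that into a physical block number, and the slot is typically
# block * block_size + p % block_size. The concrete numbers here (block_size,
# table contents, the -1 pad value standing in for _PAD_SLOT_ID) are made up
# purely for illustration.
example_block_size = 16
example_block_table = torch.tensor([7, 2, 5])      # logical -> physical blocks
example_positions = torch.arange(20)
example_slots = (example_block_table[example_positions // example_block_size]
                 * example_block_size
                 + example_positions % example_block_size)
# Positions 0..15 land in physical block 7 -> slots 112..127; position 16 -> 32.

# In the padded decode batch, rows past num_decodes keep the pad slot id so the
# KV-cache insertion ignores them.
example_decode_slots = torch.full((8, 1), -1, dtype=torch.int64)
example_decode_slots[:3, 0] = example_slots[:3]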

def _prepare_inputs(
self, scheduler_output: "SchedulerOutput"
) -> Tuple[PrefillInputData, Optional[DecodeInputData]]:

total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
assert total_num_scheduled_tokens > 0

num_reqs = self.input_batch.num_reqs
num_decodes = self.input_batch.num_decodes

# Get the number of scheduled tokens for each request.
# TODO: The Python loop can be slow. Optimize.
num_scheduled_tokens = []
for idx, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
num_scheduled_tokens.append(num_tokens)

# Assert Decodes Are Decodes.
if idx < num_decodes:
assert num_tokens == 1

return (
self._prepare_prefill_inputs(num_scheduled_tokens),
self._prepare_decode_inputs(num_decodes),
)
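# Worked example of the decode-first layout the two helpers above rely on: with
# num_reqs = 5 and num_decodes = 2, rows 0-1 of the input batch are decodes
# (exactly one scheduled token each, as asserted above) and rows 2-4 are
# prefills, so execute_model fills sampled_token_ids[:2] from the decode batch
# and sampled_token_ids[2 + idx] from prefill number idx.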

def _prepare_sampling(
self,
@@ -373,28 +377,29 @@ def execute_model(
num_reqs = self.input_batch.num_reqs
sampled_token_ids = torch.empty(num_reqs, dtype=torch.int32)

########## DECODES ##########
num_decodes = 0
if decode_data:
num_decodes = decode_data.num_decodes
######################### DECODES #########################
# Decodes run as one single padded batch with shape [batch, 1]
if decode_data.num_decodes > 0:

# FORWARD.
selected_token_ids = self.model(decode_data.token_ids,
decode_data.position_ids,
decode_data.attn_metadata,
self.kv_caches,
is_prompt=False)

# NOTE: TPU<>CPU sync happens here.
# It is important to call .cpu() first to avoid compilation on hotpath.
token_ids = selected_token_ids.cpu()[:num_decodes]
# We need to call .cpu() first to avoid recompilation.
token_ids = selected_token_ids.cpu()[:decode_data.num_decodes]
sampled_token_ids_list = token_ids.tolist()
sampled_token_ids[:num_decodes] = token_ids
sampled_token_ids[:decode_data.num_decodes] = token_ids

# UPDATE REQUEST STATE.
for i, req_id in enumerate(
self.input_batch.req_ids[:decode_data.num_decodes]):
req_state = self.requests[req_id]

# NO CHUNKED PREFILL
# TODO: ASSERT NO CHUNKED PREFILL.
assert scheduler_output.num_scheduled_tokens[req_id] == 1
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])
@@ -404,34 +409,33 @@ def execute_model(
self.input_batch.token_ids_cpu[i, seq_len] = token_id
req_state.output_token_ids.append(token_id)

########## PREFILLS ##########
######################### PREFILLS #########################
for idx, (req_id, prompt_len, token_ids, position_ids,
attn_metadata) in enumerate(prefill_data.zipped()):

# [padded_prompt_len]
# FORWARD.
selected_token_ids = self.model(token_ids,
position_ids,
attn_metadata,
self.kv_caches,
is_prompt=True)

# NOTE: TPU<>CPU sync happens here.
# It is important to call .cpu() first to avoid compilation on hotpath.
# We need to call .cpu() first to avoid recompilation.
token_id = selected_token_ids.cpu()[prompt_len - 1].item()
sampled_token_ids[num_decodes + idx] = token_id
sampled_token_ids[decode_data.num_decodes + idx] = token_id
req_state = self.requests[req_id]

# TODO: prefix caching.
if req_state.num_computed_tokens > 0:
breakpoint()
# TODO: ASSERT NO PREFIX CACHING.
assert req_state.num_computed_tokens == 0
seq_len = (req_state.num_computed_tokens +
scheduler_output.num_scheduled_tokens[req_id])

# TODO: chunked prefill.
# TODO: ASSERT NO CHUNKED PREFILL.
assert seq_len == req_state.num_tokens
assert prompt_len == seq_len

# Append the sampled token to the output token ids.
# UPDATE REQUEST STATE.
req_idx = self.input_batch.req_id_to_index[req_id]
self.input_batch.token_ids_cpu[req_idx, seq_len] = token_id
req_state.output_token_ids.append(token_id)
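# Schematic of the ".cpu() first" pattern used in both branches above, with a
# plain CPU tensor standing in for the XLA output: the model always returns the
# static padded shape, it is copied to the host once, and the request-count-
# dependent slice is taken on the host copy so the compiled TPU graph never
# sees a dynamic shape.
example_padded_output = torch.randint(0, 32000, (16,), dtype=torch.int32)
example_num_decodes = 5
example_token_ids = example_padded_output.cpu()[:example_num_decodes]
example_sampled = example_token_ids.tolist()    # plain ints for request state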
@@ -606,15 +610,6 @@ def initialize_kv_cache(self, num_blocks: int) -> None:
device=self.device),
))

def _get_padded_batch_size(self, batch_size: int) -> Optional[int]:
# The GMM Pallas kernel requires num_tokens * topk to be a multiple of 16.
# To meet this requirement in the simplest way, we set the minimal batch
# size to 8 (== MIN_BATCH_SIZE).
if batch_size <= 8:
return 8
else:
return ((batch_size + 15) // 16) * 16
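# Worked example of the rounding above (the call in _prepare_decode_inputs now
# appears to use a module-level _get_padded_batch_size with the same
# arithmetic): 1..8 -> 8, 9..16 -> 16, 17..32 -> 32, 33..48 -> 48, keeping
# num_tokens * topk a multiple of 16 for the GMM Pallas kernel.
assert [_get_padded_batch_size(b) for b in (1, 8, 9, 16, 17, 33)] == \
       [8, 8, 16, 16, 32, 48]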


@dataclass
class CachedRequestState:
8 changes: 4 additions & 4 deletions vllm/v1/worker/tpu_worker.py
@@ -8,10 +8,10 @@
import torch_xla.runtime as xr

import vllm.envs as envs
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment)
from vllm.logger import init_logger
from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.model_executor import set_random_seed
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size
from vllm.v1.outputs import ModelRunnerOutput
@@ -114,9 +114,9 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
# intermediate activations.
m = xm.get_memory_info(self.device)
total_tpu_memory = m["bytes_limit"]
peak_memory = m["peak_bytes_used"] # Weights + intermediate activations.
logger.debug("Peak Used: %sGB",
peak_memory // 1024 // 1024 // 1024)
peak_memory = m[
"peak_bytes_used"] # Weights + intermediate activations.
logger.debug("Peak Used: %sGB", peak_memory // 1024 // 1024 // 1024)
logger.debug("Total Memory: %sGB",
total_tpu_memory // 1024 // 1024 // 1024)
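# The GB figures in these debug logs are floor-divided GiB; for example a
# peak_bytes_used of 13_019_119_616 (about 12.1 GiB) is logged as
# "Peak Used: 12GB", since 13_019_119_616 // 1024 // 1024 // 1024 == 12.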
