
Commit

Merge pull request #10 from WukLab/zijian-dev
Zijian's development of resuming a request after its API call completes: the API-returned tokens are processed with multiple single_query_cached_kv_attention calls. A few places are currently written for testing only and will need changes to run real models later.
yiying-zhang authored Aug 15, 2023
2 parents 8c0efd1 + 5eb4f55 commit 8c7d6a4
Showing 19 changed files with 567 additions and 101 deletions.
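For orientation, the resume flow described in the commit message condenses to the loop below. This sketch is distilled from examples/test_pause.py in this diff rather than additional committed code; the model name and the hard-coded reply token ids are test placeholders.

from vllm import EngineArgs, LLMEngine, SamplingParams, utils

# Build an engine and submit one request that pauses at the API-call stop string.
engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))
stop = [utils.get_api_stop_string()]
engine.add_request("0", "The president of the United States is",
                   SamplingParams(temperature=0.0, stop=stop))

while engine.has_unfinished_requests():
    for out in engine.step():
        if out.paused:
            # A sequence hit the API stop string: call the external API and
            # hand the returned token ids back so generation can resume.
            response = {}
            for rid, sid in out.paused:
                _ = out.outputs[rid].text  # text the external API would receive
                response[sid] = [10]       # test placeholder for the tokenized reply
            engine.new_resume_request(out.request_id, response)
        elif out.finished:
            print(out)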
3 changes: 3 additions & 0 deletions .gitignore
@@ -173,3 +173,6 @@ cython_debug/

# Sphinx documentation
_build/

ShareGPT_V3_unfiltered_cleaned_split.json
*.nsys-rep
21 changes: 21 additions & 0 deletions analysis.py
@@ -0,0 +1,21 @@
import pstats
from pstats import SortKey

p = pstats.Stats('prof_time.log')
# p.strip_dirs().sort_stats('cumtime').print_stats('opt.py', 30)
# p.strip_dirs().sort_stats('tottime').print_stats('opt.py:', 30)
p.strip_dirs().sort_stats('cumtime').print_stats(30)
# p.strip_dirs().sort_stats('tottime').print_stats(30)


# p.strip_dirs().sort_stats('cumtime').print_callees('step')
p.strip_dirs().sort_stats('cumtime').print_stats('llama.py')
# p.strip_dirs().sort_stats('cumtime').print_stats('layernorm.py')
# p.strip_dirs().sort_stats('cumtime').print_callees('forward')
# p.strip_dirs().sort_stats('cumtime').print_stats('attention.py')
p.strip_dirs().sort_stats('cumtime').print_stats('sampler.py')
# p.strip_dirs().sort_stats('tottime').print_stats('sampler.py')
# p.strip_dirs().sort_stats('cumtime').print_callees('_prune_hidden_states')
# p.strip_dirs().sort_stats('cumtime').print_callees('_sample_from_generation_tokens')
# p.strip_dirs().sort_stats('cumtime').print_callees('_sample_optimized')
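analysis.py above reads a profile dump named prof_time.log. A minimal sketch of producing that file with the standard-library cProfile module; the profiled workload here is a placeholder, not something this diff defines:

import cProfile

def workload():
    # Placeholder for the code being profiled, e.g. a benchmark or engine loop.
    pass

cProfile.run('workload()', 'prof_time.log')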

3 changes: 2 additions & 1 deletion benchmarks/benchmark_throughput.py
@@ -183,7 +183,8 @@ def main(args: argparse.Namespace):
prompt_len + output_len
for _, prompt_len, output_len in requests
)
print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
print(f"Elapsed time: {elapsed_time}s\n"
f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
f"{total_num_tokens / elapsed_time:.2f} tokens/s")


24 changes: 12 additions & 12 deletions examples/llm_engine_example.py
@@ -10,18 +10,18 @@ def main(args: argparse.Namespace):

# Test the following prompts.
test_prompts = [
("A robot may not injure a human being", SamplingParams()),
("To be or not to be,",
SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
("What is the meaning of life?",
SamplingParams(n=2,
best_of=5,
temperature=0.8,
top_p=0.95,
frequency_penalty=0.1)),
("It is only with the heart that one can see rightly",
SamplingParams(n=3, best_of=3, use_beam_search=True,
temperature=0.0)),
# ("A robot may not injure a human being", SamplingParams()),
("The president of the United States is",
SamplingParams(temperature=0.8, presence_penalty=0.2)),
# ("What is the meaning of life?",
# SamplingParams(n=2,
# best_of=5,
# temperature=0.8,
# top_p=0.95,
# frequency_penalty=0.1)),
# ("It is only with the heart that one can see rightly",
# SamplingParams(n=3, best_of=3, use_beam_search=True,
# temperature=0.0)),
]

# Run the engine by calling `engine.step()` manually.
9 changes: 8 additions & 1 deletion examples/offline_inference.py
@@ -1,5 +1,10 @@
from vllm import LLM, SamplingParams

API_START_TOKEN = "<TOOLFORMER_API_START>"
API_RESPONSE_TOKEN = "<TOOLFORMER_API_RESPONSE>"
API_END_TOKEN = "<TOOLFORMER_API_END>"
QUERY_DELIMITER = "("

# Sample prompts.
prompts = [
"Hello, my name is",
@@ -8,7 +13,9 @@
"The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95)


# Create an LLM.
llm = LLM(model="facebook/opt-125m")
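The Toolformer-style delimiter constants added at the top of offline_inference.py are not referenced in the rest of this hunk. As a sketch only (not part of the diff, with the parsing rules assumed), they could be used to extract a pending API call from generated text:

def extract_api_call(text: str):
    # Return (api_name, query) for the first complete API call in `text`, or None.
    start = text.find(API_START_TOKEN)
    if start == -1:
        return None
    body = text[start + len(API_START_TOKEN):]
    end = body.find(API_END_TOKEN)
    if end == -1:
        return None  # the call has not finished generating yet
    call = body[:end].strip()
    name, _, query = call.partition(QUERY_DELIMITER)
    return name.strip(), query.rstrip(")").strip()

# extract_api_call("x <TOOLFORMER_API_START>Calculator(1+2)<TOOLFORMER_API_END> y")
# returns ("Calculator", "1+2")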
57 changes: 57 additions & 0 deletions examples/test_pause.py
@@ -0,0 +1,57 @@
import argparse

from vllm import EngineArgs, LLMEngine, SamplingParams, utils

def api_call(text: str) -> str:
    # Dummy external API call used for testing; always returns a fixed string.
    return " a "

def main(args: argparse.Namespace):
# Parse the CLI argument and initialize the engine.
engine_args = EngineArgs.from_cli_args(args)
engine = LLMEngine.from_engine_args(engine_args)
stop = [utils.get_api_stop_string()]
# Test the following prompts.
test_prompts = [
("The president of the United States is",
         SamplingParams(temperature=0.0, presence_penalty=0.2, stop=stop)),
]

# Run the engine by calling `engine.step()` manually.
request_id = 0
# To test iteration-level scheduling, we add one request at each step.
for prompt, sampling_params in test_prompts:
engine.add_request(str(request_id), prompt, sampling_params)
request_id += 1

request_outputs = engine.step()

for request_output in request_outputs:
print(request_output)
        if request_output.paused:
            response = {}
            for rid, sid in request_output.paused:
                # Query the (dummy) external API with the text generated so far.
                ret = api_call(request_output.outputs[rid].text)
                # Test-only placeholder: feed back a hard-coded token id instead
                # of tokenizing `ret`.
                response[sid] = [10]
            engine.new_resume_request(request_output.request_id, response)
print(engine.scheduler.running[0].seqs[0].data)

# for _ in range(2):
# request_outputs = engine.step()
# for request_output in request_outputs:
# print(request_output)

while True:
request_outputs = engine.step()
for request_output in request_outputs:
if request_output.finished:
print(request_output)
if not engine.has_unfinished_requests():
break

if __name__ == '__main__':
parser = argparse.ArgumentParser(
description='Demo on using the LLMEngine class directly')
parser = EngineArgs.add_cli_args(parser)
args = parser.parse_args()
main(args)
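As the commit message notes, the hard-coded response[sid] = [10] in the demo above is a test placeholder. A hedged sketch of what the non-test version of that step might look like, assuming the engine exposes a tokenizer handle (engine.tokenizer is an assumption about engine internals, not something this diff adds):

def build_resume_response(engine, request_output):
    # Map each paused sequence id to the token ids of its API reply.
    response = {}
    for rid, sid in request_output.paused:
        reply_text = api_call(request_output.outputs[rid].text)
        response[sid] = engine.tokenizer.encode(reply_text)  # assumed tokenizer handle
    return response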
5 changes: 5 additions & 0 deletions notice
@@ -0,0 +1,5 @@
If the ray package raises this error:
AttributeError: 'NoneType' object has no attribute 'fs'
do:
pip install ray==2.5.1

125 changes: 120 additions & 5 deletions tests/kernels/test_attention.py
@@ -7,7 +7,7 @@

from vllm import attention_ops

MAX_SEQ_LEN = 4096
MAX_SEQ_LEN = 16
TEST_SEED = 0


@@ -126,7 +126,8 @@ def ref_multi_query_cached_kv_attention(

# Create attention mask
attn_mask = torch.triu(torch.ones(query_len, context_len),
diagonal=context_len - query_len + 1) * -1e5
diagonal=context_len - query_len + 1)
attn_mask *= torch.finfo(dtype).min
attn_mask = attn_mask.to(dtype=dtype, device='cuda')

keys = []
@@ -245,7 +246,6 @@ def run_single_query_cached_kv_attention(
# We should use a relaxed tolerance for the test.
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


@torch.inference_mode()
def run_multi_query_kv_attention(
num_seqs: int,
@@ -291,13 +291,106 @@ def run_multi_query_kv_attention(
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)

# NOTE: This is the driver for testing the correctness of emulating
# multi-query cached-KV attention with repeated single_query_cached_kv_attention
# calls. It is a test harness only and cannot be used as a kernel by itself.
@torch.inference_mode()
def run_multi_query_cached_kv_attention(
num_seqs: int,
num_heads: int,
head_size: int,
block_size: int,
num_blocks: int,
dtype: torch.dtype,
) -> None:
seq_lens = random.sample(range(1, MAX_SEQ_LEN), num_seqs)
num_tokens = sum(seq_lens)
scale = float(1.0 / (head_size**0.5))
qkv = torch.empty(num_tokens,
3,
num_heads,
head_size,
dtype=dtype,
device='cuda')
qkv.uniform_(-1e-3, 1e-3)
query, key, value = qkv.unbind(dim=1)
cu_seq_lens = [0]
for seq_len in seq_lens:
cu_seq_lens.append(cu_seq_lens[-1] + seq_len)

# prepare caches
x = 16 // torch.tensor([], dtype=dtype).element_size()
key_block_shape = (num_heads, head_size // x, block_size, x)
key_cache = torch.empty(size=(num_blocks, *key_block_shape),
dtype=dtype,
device='cuda')
key_cache.uniform_(-1e-3, 1e-3)
value_block_shape = (num_heads, head_size, block_size)
value_cache = torch.empty(size=(num_blocks, *value_block_shape),
dtype=dtype,
device='cuda')
value_cache.uniform_(-1e-3, 1e-3)

    # One context length per token: token i of a sequence attends to i + 1 positions.
    context_lens = [i + 1 for l in seq_lens for i in range(l)]
    # One context length per sequence, for the multi-query reference computation.
    multi_ctx_lens = list(seq_lens)
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device='cuda')
    multi_ctx_lens = torch.tensor(multi_ctx_lens, dtype=torch.int, device='cuda')
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size

    # NOTE: Each sequence's physical blocks are randomly generated here. This does
    # not affect the correctness check, because all tokens of a sequence share the
    # same block table and the reference computation reads the same cache. See
    # https://github.com/WukLab/vLLM/blob/zijian-dev/vllm/worker/worker.py#L238-L240
    # for the real flow.
block_tables = []
for l in seq_lens:
block_table = [
random.randint(0, num_blocks - 1)
for _ in range(max_num_blocks_per_seq)
]
for _ in range(l):
block_tables.append(block_table)

block_tables = torch.tensor(block_tables, dtype=torch.int, device='cuda')
head_mapping = torch.arange(num_heads, dtype=torch.int32, device="cuda")
output = torch.empty(num_tokens,
num_heads,
head_size,
dtype=dtype,
device='cuda')
attention_ops.single_query_cached_kv_attention(
output,
query,
key_cache,
value_cache,
head_mapping,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
None, # ALiBi slopes.
)

ref_output = ref_multi_query_cached_kv_attention(
cu_seq_lens,
query,
key_cache,
value_cache,
block_tables,
multi_ctx_lens,
dtype
)
assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5)


def test_single_query_cached_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8, 16, 32]:
for head_size in [64, 80, 96, 112, 128, 256]:
print(f'Testing single_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
@@ -310,7 +403,6 @@ def test_single_query_cached_kv_attention() -> None:
dtype=dtype,
)


def test_multi_query_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
@@ -324,3 +416,26 @@ def test_multi_query_kv_attention() -> None:
head_size=head_size,
dtype=dtype,
)

def test_multi_query_cached_kv_attention() -> None:
torch.random.manual_seed(TEST_SEED)
torch.cuda.manual_seed(TEST_SEED)
for dtype in [torch.half, torch.bfloat16, torch.float]:
for block_size in [8]:
for head_size in [64]:
print(f'Testing multi_query_cached_kv_attention with '
f'dtype={dtype}, block_size={block_size}, '
f'head_size={head_size}')
                # Small test configuration; exercises the multi-query driver above.
                run_multi_query_cached_kv_attention(
                    num_seqs=3,
                    num_heads=3,
                    head_size=head_size,
                    block_size=block_size,
                    num_blocks=1024,
                    dtype=dtype,
                )

if __name__ == "__main__":
test_single_query_cached_kv_attention()
# test_multi_query_kv_attention()
test_multi_query_cached_kv_attention()
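For reference, the per-token expansion that run_multi_query_cached_kv_attention builds before invoking the single-query kernel can be illustrated in isolation (toy values, not part of the diff):

seq_lens = [3, 2]                       # two sequences of 3 and 2 tokens
block_table_per_seq = [[7, 9], [4, 2]]  # toy physical block ids per sequence

# Token i of a sequence attends to i + 1 positions, and every token of a
# sequence reuses that sequence's block table.
context_lens = [i + 1 for l in seq_lens for i in range(l)]
block_tables = [bt for bt, l in zip(block_table_per_seq, seq_lens) for _ in range(l)]

print(context_lens)   # [1, 2, 3, 1, 2]
print(block_tables)   # [[7, 9], [7, 9], [7, 9], [4, 2], [4, 2]]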
