Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Align Tokenizer in JetStream #40

Merged
merged 5 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/unit_tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@ jobs:
pip install pylint
pip install pyink
pip install -r requirements.txt
pip install -r benchmarks/requirements.in
- name: Typecheck the code with pytype
run: |
pytype --jobs auto --disable import-error --disable module-attr jetstream/
pytype --jobs auto --disable import-error --disable module-attr jetstream/ benchmarks/
- name: Analysing the code with pylint
run: |
pylint jetstream/ benchmarks/
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -337,15 +337,15 @@ python -m jetstream.tools.load_tester
### Test core modules
```
# Test JetStream core orchestrator
python -m jetstream.core.orchestrator_test
python -m jetstream.tests.core.test_orchestrator

# Test JetStream core server library
python -m jetstream.core.server_test
python -m jetstream.tests.core.test_server

# Test mock JetStream engine implementation
python -m jetstream.engine.mock_engine_test
python -m jetstream.tests.engine.test_mock_engine

# Test mock JetStream token utils
python -m jetstream.engine.utils_test
python -m jetstream.tests.engine.test_utils

```
65 changes: 33 additions & 32 deletions benchmarks/benchmark_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,15 +63,14 @@
import json
import random
import time
from typing import Any, AsyncGenerator, List, Optional
from typing import Any, AsyncGenerator, Optional

import grpc
from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc
from jetstream.engine.token_utils import load_vocab
import numpy as np
import tensorflow as tf
import tensorflow_text as tftxt
from tqdm.asyncio import tqdm
from tqdm.asyncio import tqdm # pytype: disable=pyi-error


@dataclass
Expand Down Expand Up @@ -103,9 +102,9 @@ class InputRequest:

@dataclass
class RequestFuncOutput:
input_request: InputRequest = None
generated_token_list: list[str] = None
generated_text: str = None
input_request: Optional[InputRequest] = None
generated_token_list: list[str] = []
generated_text: str = ""
success: bool = False
latency: float = 0
ttft: float = 0
Expand All @@ -129,18 +128,16 @@ def get_tokenizer(tokenizer_name: str) -> Any:
if tokenizer_name == "test":
return "test"
else:
with tf.io.gfile.GFile(tokenizer_name, "rb") as model_fp:
sp_model = model_fp.read()
sp_tokenizer = tftxt.SentencepieceTokenizer(
model=sp_model, add_bos=True, add_eos=False, reverse=False
)
return sp_tokenizer
# Use JetStream tokenizer util. It's using the sentencepiece wrapper in
# seqio library.
vocab = load_vocab(tokenizer_name)
return vocab.tokenizer


def load_sharegpt_dataset(
dataset_path: str,
conversation_starter: str,
) -> List[tuple[str]]:
) -> list[tuple[Any, Any]]:
# Load the dataset.
with open(dataset_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
Expand All @@ -163,7 +160,7 @@ def load_sharegpt_dataset(
return dataset


def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:
def load_openorca_dataset(dataset_path: str) -> list[tuple[Any, Any]]:
# Load the dataset.
with open(dataset_path, "r", encoding="utf-8") as f:
dataset = json.load(f)
Expand All @@ -176,9 +173,9 @@ def load_openorca_dataset(dataset_path: str) -> List[tuple[str]]:


def tokenize_dataset(
dataset: List[tuple[str]],
dataset: list[tuple[Any, Any, Any]],
tokenizer: Any,
) -> List[tuple[Any]]:
) -> list[tuple[str, Any, str, int, int, int]]:

n = len(dataset)

Expand All @@ -191,10 +188,10 @@ def tokenize_dataset(
outputs.append(output)
indices.append(idx)

prompt_token_ids = tokenizer.tokenize(
prompt_token_ids = tokenizer.encode(
prompts
) # adjust this code based on tokenizer method
outputs_token_ids = tokenizer.tokenize(
outputs_token_ids = tokenizer.encode(
outputs
) # adjust this code based on tokenizer method

Expand All @@ -215,8 +212,9 @@ def tokenize_dataset(


def filter_dataset(
tokenized_dataset: List[tuple[Any]], max_output_length: Optional[int] = None
) -> List[InputRequest]:
tokenized_dataset: list[tuple[str, Any, str, int, int, int]],
max_output_length: Optional[int] = None,
) -> list[InputRequest]:
if max_output_length is None:
print("In InputRequest, pass in actual output_length for each sample")
else:
Expand All @@ -226,7 +224,7 @@ def filter_dataset(
)

# Filter out too long sequences.
filtered_dataset: List[InputRequest] = []
filtered_dataset: list[InputRequest] = []
for (
prompt,
_,
Expand Down Expand Up @@ -255,12 +253,12 @@ def filter_dataset(


def sample_requests(
dataset: List[tuple[str]],
dataset: list[tuple[Any, Any]],
tokenizer: Any,
num_requests: int,
max_output_length: Optional[int] = None,
oversample_multiplier: float = 1.2,
) -> List[InputRequest]:
) -> list[InputRequest]:

# Original dataset size
n = len(dataset)
Expand Down Expand Up @@ -301,7 +299,7 @@ def sample_requests(


async def get_request(
input_requests: List[InputRequest],
input_requests: list[InputRequest],
request_rate: float,
) -> AsyncGenerator[InputRequest, None]:
input_requests = iter(input_requests)
Expand All @@ -318,8 +316,8 @@ async def get_request(


def calculate_metrics(
input_requests: List[InputRequest],
outputs: List[RequestFuncOutput],
input_requests: list[InputRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer: Any,
) -> BenchmarkMetrics:
Expand Down Expand Up @@ -371,16 +369,17 @@ async def grpc_async_request(
token_list = []
request_start_time = time.perf_counter()
response = stub.Decode(request)
async for token in response:
async for sample_list in response:
if ttft == 0:
ttft = time.perf_counter() - request_start_time
token_list.append(token.response[0])
token_list.extend(sample_list.response[0].token_ids)
latency = time.perf_counter() - request_start_time
return token_list, ttft, latency


async def send_request(
api_url: str,
tokenizer: Any,
input_request: InputRequest,
pbar: tqdm,
session_cache: str,
Expand All @@ -402,7 +401,8 @@ async def send_request(
output.ttft = ttft
output.latency = latency
output.generated_token_list = generated_token_list
output.generated_text = "".join(generated_token_list)
# generated_token_list is a list of token ids, decode it to generated_text.
output.generated_text = tokenizer.decode(generated_token_list)
output.success = True
if pbar:
pbar.update(1)
Expand All @@ -412,7 +412,7 @@ async def send_request(
async def benchmark(
api_url: str,
tokenizer: Any,
input_requests: List[InputRequest],
input_requests: list[InputRequest],
request_rate: float,
disable_tqdm: bool,
session_cache: str,
Expand All @@ -430,6 +430,7 @@ async def benchmark(
asyncio.create_task(
send_request(
api_url=api_url,
tokenizer=tokenizer,
input_request=request,
pbar=pbar,
session_cache=session_cache,
Expand All @@ -439,7 +440,7 @@ async def benchmark(
)
outputs = await asyncio.gather(*tasks)

if not disable_tqdm:
if not disable_tqdm and pbar:
pbar.close()

benchmark_duration = time.perf_counter() - benchmark_start_time
Expand Down
3 changes: 2 additions & 1 deletion benchmarks/requirements.in
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
nltk
evaluate
rouge-score
rouge-score
tqdm
11 changes: 8 additions & 3 deletions jetstream/core/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ class ActiveRequest:
# We keep prefill and decode information together in the same object so that
# there is less indirection about where this return channel is.
# The return channel returns a list of strings, one per sample for that query.
return_channel: async_multifuture.AsyncMultifuture[list[str]]
return_channel: async_multifuture.AsyncMultifuture[list[list[int]]]
# [num_samples,] which corresponds to whether each sample is complete for the
# requests.
complete: Optional[np.ndarray] = None
Expand All @@ -139,7 +139,7 @@ class ActiveRequest:
# Which generate step this was added at.
generate_timestep_added: Optional[int] = None

def enqueue_tokens(self, generated_tokens: list[str]):
def enqueue_tokens(self, generated_tokens: list[list[int]]):
"""Records information about the step.

Args:
Expand Down Expand Up @@ -662,4 +662,9 @@ async def Decode( # pylint: disable=invalid-overridden-method
# The DecodeResponse stream should consume all generated tokens in
# return_channel when complete signal is received. It should check if
# return_channel is empty to decide if it should exit the while loop.
yield jetstream_pb2.DecodeResponse(response=response)
repeated_token_ids = []
for token_ids in response:
repeated_token_ids.append(
jetstream_pb2.RepeatedTokenIds(token_ids=token_ids)
)
yield jetstream_pb2.DecodeResponse(response=repeated_token_ids)
8 changes: 6 additions & 2 deletions jetstream/core/proto/jetstream.proto
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ message DecodeRequest {
int32 max_tokens = 4;
}
message DecodeResponse {
// List of responses, one per sample.
repeated string response = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we still keep a str as option? The internal keep both text and token id.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we don't want to decode it to str (or piece) in jetstream, since it would have some off in the final result.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great! Thanks for making the changes!

// List of responses, one per sample. The list size depends on text generation strategy the engine used.
repeated RepeatedTokenIds response = 1;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add your deleted field number to the reserved list, it may messes up deserialization. Here is the reference: https://protobuf.dev/programming-guides/dos-donts/#reserve-tag-numbers

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not keep both and let the user choose whether she wants ids or string?

}
message RepeatedTokenIds {
// List of token ids, one list per sample. When speculative decoding is disabled, the list size should be 1; When speculative decoding is enabled, the list size should be >= 1.
repeated int32 token_ids = 1;
}
10 changes: 6 additions & 4 deletions jetstream/core/proto/jetstream_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05""\n\x0e\x44\x65\x63odeResponse\x12\x10\n\x08response\x18\x01 \x03(\t2]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
b'\n$jetstream/core/proto/jetstream.proto\x12\x0fjetstream_proto"e\n\rDecodeRequest\x12\x15\n\rsession_cache\x18\x01 \x01(\t\x12\x17\n\x0f\x61\x64\x64itional_text\x18\x02 \x01(\t\x12\x10\n\x08priority\x18\x03 \x01(\x05\x12\x12\n\nmax_tokens\x18\x04 \x01(\x05"E\n\x0e\x44\x65\x63odeResponse\x12\x33\n\x08response\x18\x01 \x03(\x0b\x32!.jetstream_proto.RepeatedTokenIds"%\n\x10RepeatedTokenIds\x12\x11\n\ttoken_ids\x18\x01 \x03(\x05\x32]\n\x0cOrchestrator\x12M\n\x06\x44\x65\x63ode\x12\x1e.jetstream_proto.DecodeRequest\x1a\x1f.jetstream_proto.DecodeResponse"\x00\x30\x01\x62\x06proto3'
)

_globals = globals()
Expand All @@ -41,7 +41,9 @@
_globals["_DECODEREQUEST"]._serialized_start = 57
_globals["_DECODEREQUEST"]._serialized_end = 158
_globals["_DECODERESPONSE"]._serialized_start = 160
_globals["_DECODERESPONSE"]._serialized_end = 194
_globals["_ORCHESTRATOR"]._serialized_start = 196
_globals["_ORCHESTRATOR"]._serialized_end = 289
_globals["_DECODERESPONSE"]._serialized_end = 229
_globals["_REPEATEDTOKENIDS"]._serialized_start = 231
_globals["_REPEATEDTOKENIDS"]._serialized_end = 268
_globals["_ORCHESTRATOR"]._serialized_start = 270
_globals["_ORCHESTRATOR"]._serialized_end = 363
# @@protoc_insertion_point(module_scope)
8 changes: 5 additions & 3 deletions jetstream/engine/mock_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,7 @@ def _encode(self, s: str) -> Sequence[int]:

def _decode(self, ids: np.ndarray):
"""Converts a numpy array into a string."""
# 'We use array methods, not python iterables so we don't
# implement this method in the mock vocab.
raise NotImplementedError
return "".join([chr(r) for r in list(ids)])

def _encode_tf(self, s: str) -> np.ndarray:
"""Converts a string into a numpy array."""
Expand All @@ -78,6 +76,10 @@ def _decode_tf(self, ids: np.ndarray) -> List[str]:
results = np.split(ids, ids.shape[0])
return ["".join([chr(r) for r in list(line[0])]) for line in results]

def decode(self, ids: np.ndarray):
"""Converts a numpy array into a string."""
return self._decode(ids)

def encode_tf(self, s: str) -> np.ndarray:
"""Converts a string into a numpy array."""
return self._encode_tf(s)
Expand Down
33 changes: 5 additions & 28 deletions jetstream/engine/token_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,6 @@
from jetstream.engine import mock_utils


def mix_decode(vocab: Vocabulary, tok_id: int):
"""
The IdToPiece and decode results differ for 344 tokens in Llama2.
Use the decode function to generate the correct strings for these 344 tokens.
If IdToPiece returns a hex string (e.g., '<0x0A>') for a token within these
344, utilize IdToPiece to convert it into a string, likely with a space
placeholder (' ') for the corresponding tokens.
"""
p_token = vocab.tokenizer.IdToPiece(tok_id)
# SentencePiece escapes the whitespace with a meta symbol "▁" (U+2581)
p_token = p_token.replace("▁", " ")
d_token = vocab.tokenizer.decode([tok_id])
return p_token if p_token.lstrip() == d_token else d_token


def take_nearest_length(lengths: list[int], length: int) -> int:
"""Gets the nearest length to the right in a set of lengths."""
pos = bisect_left(lengths, length)
Expand Down Expand Up @@ -127,7 +112,7 @@ def process_result_tokens(
vocab: Vocabulary,
complete: np.ndarray,
debug: bool = False,
) -> Tuple[List[str], np.ndarray]:
) -> Tuple[List[List[int]], np.ndarray]:
"""Processes a result tokens into a list of strings, handling multiple
samples.

Expand All @@ -141,7 +126,7 @@ def process_result_tokens(
debug: Whether to log step by step detokenisation.

Returns:
sample_return: List of strings, one per sample.
sample_return: List of tok_id list, one list per sample.
complete: Updated complete.
"""
# tokens: [samples, speculations]
Expand All @@ -162,7 +147,7 @@ def process_result_tokens(
)
sample_return = []
for idx in range(samples):
string_so_far = ""
tok_id_so_far = []
if not complete[idx].item():
for spec_idx in range(speculations):
tok_id = slot_tokens[idx, spec_idx].item()
Expand All @@ -178,16 +163,8 @@ def process_result_tokens(
complete[idx] = True
break
else:
try:
token = mix_decode(vocab, tok_id) # pytype: disable=attribute-error
except ValueError:
# This error only occurs when using tests where the vocab range is
# computed via addition and int->char is computed using chr(). Real
# models have vocab logits which are at max the size of the vocab.
logging.warning("%d exceeded vocab range", tok_id)
token = "<sampled_outside_vocab>"
string_so_far += token
sample_return.append(string_so_far)
tok_id_so_far.append(tok_id)
sample_return.append(tok_id_so_far)
if debug:
logging.info("Sampled return %s", str(sample_return))
return sample_return, complete
Expand Down
Loading
Loading