diff --git a/jetstream/core/orchestrator.py b/jetstream/core/orchestrator.py
index 04faa285..720c71bb 100644
--- a/jetstream/core/orchestrator.py
+++ b/jetstream/core/orchestrator.py
@@ -201,6 +201,7 @@ class Driver:
   # This can be a list because we can pass it as an arg to generate and
   # detokenize threads. It is a list of tokens to be detokenized.
   _detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
+  _prefill_detokenize_backlogs: list[queue.Queue[engine_api.ResultTokens]] = []
   _generate_slots: list[queue.Queue[int]] = []
   _active_requests: list[queue.Queue[tuple[int, ActiveRequest]]] = []
 
@@ -316,6 +317,12 @@ def __init__(
         queue.Queue(8)
         for _ in self._generate_engines
     ]
+    self._prefill_detokenize_backlogs = [
+        # We don't let detokenization accumulate more than 8 steps to avoid
+        # synchronization issues.
+        queue.Queue(8)
+        for _ in self._prefill_engines
+    ]
 
     # A queue of integers representing available 'slots' in the decode
     # operation. I.e. potentially available rows in the batch and/or microbatch.
@@ -376,12 +383,23 @@ def __init__(
         )
         for idx in range(len(self._generate_engines))
     ]
+    self.prefill_detokenize_threads = [
+        JetThread(
+            target=functools.partial(
+                self._prefill_detokenize_thread,
+                idx,
+            ),
+            name=f"prefill-detokenize-{idx}",
+        )
+        for idx in range(len(self._prefill_engines))
+    ]
     self._all_threads = list(
         itertools.chain(
             self._prefill_threads,
             self._transfer_threads,
             self._generate_threads,
             self.detokenize_threads,
+            self.prefill_detokenize_threads,
        )
    )
    self.live = True
@@ -514,12 +532,12 @@ def _prefill_thread(self, idx: int):
           padded_tokens=padded_tokens,
           true_length=true_length,
       )
-
+      first_token.copy_to_host_async()
       request.prefill_result = prefill_result
 
       # put first token to detokenize queue
       request.complete = np.zeros((prefill_engine.samples_per_slot,), np.bool_)
-      my_detokenize_backlog = self._detokenize_backlogs[idx]
+      my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
       my_detokenize_backlog.put(
           (first_token, request, request_start_time), block=True
       )
@@ -736,34 +754,13 @@ def _detokenize_thread(self, idx: int):
       if data is None:
         break
       start_detokenize_time = time.time()
-      # prefill first token
-      if isinstance(data[0], engine_api.ResultTokens):
-        request_first_token, request, request_start_time = data
-        request_first_token = request_first_token.convert_to_numpy()
-
-        results, complete = token_utils.process_result_tokens(
-            tokenizer=tokenizer,
-            slot=0,  # always 0 as prefill only run 1 sample
-            slot_max_length=request.max_tokens,
-            result_tokens=request_first_token,
-            is_client_side_tokenization=request.is_client_side_tokenization,
-            complete=request.complete,
-        )
-        request.complete = complete
-        # Return some output samples.
-        request.enqueue_samples(results)
-
-        first_token_return_time = time.perf_counter()
-        logging.info(
-            "TTFT duration: %fms",
-            (first_token_return_time - request_start_time) * 1000,
-        )
       # generate step tokens
-      elif isinstance(data[1], engine_api.ResultTokens):
+      if isinstance(data[1], engine_api.ResultTokens):
         # We want to detokenize them.
         generate_timestep_added, result_tokens = data
         # Disable attribute error because pytype doesn't know this
         # is a result tokens, and we can't annotate the tuple.
+        result_tokens = jax.block_until_ready(result_tokens)
         result_tokens = result_tokens.convert_to_numpy()
 
         for slot, request in my_live_requests.items():
@@ -795,6 +792,44 @@ def _detokenize_thread(self, idx: int):
         slot, active_request = data
         my_live_requests[slot] = active_request
 
+  def _prefill_detokenize_thread(self, idx: int):
+    """Detokenize the prefill token and return it to the user."""
+    # We don't need to keep a my_live_requests list since the request is
+    # passed to transfer and then to generate. We can directly detokenize
+    # the first token from prefill and return it.
+    my_detokenize_backlog = self._prefill_detokenize_backlogs[idx]
+    my_prefill_engine = self._prefill_engines[idx]
+
+    metadata = my_prefill_engine.get_tokenizer()
+    tokenizer = my_prefill_engine.build_tokenizer(metadata)
+    while self.live:
+      data = my_detokenize_backlog.get(block=True)
+      if data is None:
+        break
+      # prefill first token
+      if isinstance(data[0], engine_api.ResultTokens):
+        request_first_token, request, request_start_time = data
+        request_first_token = jax.block_until_ready(request_first_token)
+        request_first_token = request_first_token.convert_to_numpy()
+
+        results, complete = token_utils.process_result_tokens(
+            tokenizer=tokenizer,
+            slot=0,  # always 0 as prefill only runs 1 sample
+            slot_max_length=request.max_tokens,
+            result_tokens=request_first_token,
+            is_client_side_tokenization=request.is_client_side_tokenization,
+            complete=request.complete,
+        )
+        request.complete = complete
+        # Return some output samples.
+        request.enqueue_samples(results)
+
+        first_token_return_time = time.perf_counter()
+        logging.info(
+            "TTFT duration: %fms",
+            (first_token_return_time - request_start_time) * 1000,
+        )
+
 
 class LLMOrchestrator(jetstream_pb2_grpc.OrchestratorServicer):
   """Coordinates a set of prefill and generate slices for LLM decoding."""
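For context, the new `_prefill_detokenize_thread` is a standard bounded producer/consumer pair: `_prefill_thread` produces `(first_token, request, request_start_time)` tuples into a `queue.Queue(8)`, and a dedicated thread consumes them, so first-token detokenization no longer waits behind generate-step work in `_detokenize_thread`. The standalone sketch below illustrates this pattern under simplified assumptions; it is not part of the patch, and every name in it (`prefill_worker`, `prefill_detokenize_worker`, the string tokens) is hypothetical.

import queue
import threading
import time


def prefill_worker(backlog: queue.Queue):
  """Producer: emits a (first_token, start_time) pair per finished prefill."""
  for i in range(3):
    time.sleep(0.01)  # stand-in for the actual prefill computation
    backlog.put((f"token-{i}", time.perf_counter()), block=True)
  backlog.put(None)  # sentinel, mirroring the `if data is None: break` exit


def prefill_detokenize_worker(backlog: queue.Queue):
  """Consumer: detokenizes first tokens and reports TTFT as they arrive."""
  while True:
    data = backlog.get(block=True)
    if data is None:
      break
    first_token, start_time = data
    # The real thread converts device arrays to numpy and runs the tokenizer
    # here; this sketch only measures and logs the elapsed time.
    ttft_ms = (time.perf_counter() - start_time) * 1000
    print(f"{first_token}: TTFT {ttft_ms:.3f}ms")


# Bounded to 8 entries like queue.Queue(8) in the diff, so a slow consumer
# applies backpressure to prefill instead of letting the backlog grow
# without limit.
backlog = queue.Queue(8)
consumer = threading.Thread(target=prefill_detokenize_worker, args=(backlog,))
consumer.start()
prefill_worker(backlog)
consumer.join()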