Skip to content

Commit

Permalink
Put continuous batching RPC operations to async.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 610560319
Change-Id: I860c00cc62baa9a20fa644c3651033e8ec1aeebc
  • Loading branch information
bignamehyp authored and copybara-github committed Feb 27, 2024
1 parent 8f452e6 commit a2d3c42
Showing 1 changed file with 51 additions and 19 deletions.
70 changes: 51 additions & 19 deletions saxml/server/model_service_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ class ContinuousBatchingState:
prefill_queue: queue.SimpleQueue[PrefillRequest]
prefill_thread: threading.Thread
generate_thread: threading.Thread
post_process_pool: utils.ThreadPool
# When a new request is prefilled, prefill_thread uses this conditional
# variable to signal the availability to generate_thread.
generate_cv: threading.Condition
Expand Down Expand Up @@ -336,6 +337,14 @@ def __init__(
self.generate_thread = threading.Thread(
target=generate_fn, args=(self,), daemon=True
)

# By constraining num_thread=1, we can ensure the operations on
# state.rpc_tasks are also serializable because utils.ThreadPool
# runs in FIFO order.
self.post_process_pool = utils.ThreadPool(
num_threads=1,
thread_name_prefix='model_service_runner_post_processer',
)
self.generate_cv = threading.Condition()
self.available_slots = queue.SimpleQueue()
for i in range(num_cache_slots):
Expand Down Expand Up @@ -2106,7 +2115,7 @@ def _run_prefill_insert_loop(

state.decoded_tokens[slot][0] = np.array(tokens.addressable_data(0))
state.scores[slot] = np.array(scores.addressable_data(0))
state.rpc_tasks[slot] = request.rpc_task
self._post_prefill_async(state, slot, request.rpc_task)
state.steps[slot] = 1

# Must set slots_in_use in the end.
Expand All @@ -2118,6 +2127,12 @@ def _run_prefill_insert_loop(
state.pending_insert = False
state.generate_cv.notify()

def _post_prefill_async(self, state, slot, request_rpc):
def _postprocess():
state.rpc_tasks[slot] = request_rpc

state.post_process_pool.run(_postprocess)

def _run_generation_loop(
self,
state: ContinuousBatchingState,
Expand Down Expand Up @@ -2164,34 +2179,51 @@ def _run_generation_loop(
if not np.any(done):
continue

# If any of the sequences in the batch is done, return the response
# and reset the cache slot.
sequences = state.decoded_tokens[done]
output_strings = method_obj.detokenize(sequences)
done_slots = np.flatnonzero(done)

for idx, slot in enumerate(done_slots):
outputs = [output_strings[idx]], [state.scores[slot]]
self._model_services[state.service_id].FillRPCResponse(
state.model_method, outputs, state.rpc_tasks[slot].response
)
try:
state.rpc_tasks[slot].done(utils.ok())
except Exception as e: # pylint: disable=broad-except
self._log_exception('Error occurred: %s, error: %s', model_key, e)
state.rpc_tasks[slot] = None

self._run_post_generate_async(
state,
copy.deepcopy(state.decoded_tokens[done]),
copy.deepcopy(state.scores[done]),
done,
)
# Reset the cache slot state.
state.decoded_tokens[done] = 0
state.scores[done] = 0.0
state.steps[done] = 0
state.slots_in_use[done] = 0

# Release the slots.
for slot in done_slots:
for slot in np.flatnonzero(done):
logging.info('Releasing slot %d.', slot)
state.available_slots.put(slot)

def _run_post_generate_async(
    self,
    state: ContinuousBatchingState,
    sequences: np.ndarray,
    scores: np.ndarray,
    done: np.ndarray,
) -> None:
  """Runs the post-generate processing on a dedicated thread.

  Detokenizes the finished sequences and completes their RPCs off the
  generation loop by enqueueing the work on state.post_process_pool.
  Because that pool is single-threaded and FIFO, these writes to
  state.rpc_tasks stay serialized with _post_prefill_async, even though
  the generation loop frees the cache slots before this work runs.

  Args:
    state: Continuous batching state for the model method being served.
    sequences: Decoded token rows for the finished slots. The caller passes
      a deep copy, so the generation loop may reset its own buffers
      immediately after this call returns.
    scores: Per-row scores aligned with `sequences` (also a caller-made
      copy).
    done: Boolean mask over cache slots; True marks a slot whose sequence
      finished this step. `scores`/`sequences` row i corresponds to the
      i-th True slot (np.flatnonzero order).
  """
  def _postprocess():
    # If any of the sequences in the batch is done, return the response
    # and reset the cache slot.
    # NOTE(review): assumes `state.method` exposes detokenize — the
    # synchronous path used a local `method_obj`; confirm the attribute
    # exists on ContinuousBatchingState.
    output_strings = state.method.detokenize(sequences)

    for idx, slot in enumerate(np.flatnonzero(done)):
      # Tuple of ([string], [score]) — the shape FillRPCResponse expects.
      outputs = [output_strings[idx]], [scores[idx]]
      self._model_services[state.service_id].FillRPCResponse(
          state.model_method, outputs, state.rpc_tasks[slot].response
      )
      try:
        state.rpc_tasks[slot].done(utils.ok())
      except Exception as e:  # pylint: disable=broad-except
        # Best-effort completion: log and keep finishing the other slots.
        self._log_exception(
            'Error occurred: %s, error: %s', state.model_key, e
        )
      # Clear the task so the slot can be reassigned by a later prefill
      # (ordered after this write by the single-threaded pool).
      state.rpc_tasks[slot] = None

  state.post_process_pool.run(_postprocess)

def _run_secondary_worker_loop(self):
"""Runs the processing loop in secondary hosts in a multi-host setup."""
while True:
Expand Down

0 comments on commit a2d3c42

Please sign in to comment.