Skip to content

Commit

Permalink
Fix race condition in state.rpc_tasks
Browse files (browse the repository at this point in the history)
PiperOrigin-RevId: 684534362
Change-Id: I706d8fe5bb66aaa176d2295ef84025e44c5a640a
  • Loading branch information
ukoxyz authored and copybara-github committed Oct 10, 2024
1 parent 3f085ae commit 48f6a8c
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions saxml/server/model_service_base.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -2367,12 +2367,12 @@ def _run_post_prefill_async(self, state, slots, tokens, scores, request_rpcs):

n_tasks = len(slots)
assert n_tasks == len(request_rpcs)
for i, slot in enumerate(slots):
state.rpc_tasks[slot] = request_rpcs[i]

def _postprocess():
for i, slot in enumerate(slots):
state.rpc_tasks[slot] = request_rpcs[i]
if state.method.streamable_output:

if state.method.streamable_output:
def _postprocess():
nonlocal tokens
nonlocal scores

Expand Down Expand Up @@ -2401,7 +2401,7 @@ def _postprocess():
'Error occurred: %s, error: %s', state.model_key, e
)

state.post_process_pool.run(_postprocess)
state.post_process_pool.run(_postprocess)

def _run_generation_loop(
self,
Expand Down Expand Up @@ -2493,23 +2493,28 @@ def _run_post_generate_async(
"""Runs the post-generate processing on a dedicated thread."""
if state.method.streamable_output:
sequences = sequences[:, 1:]
rpc_tasks = list(state.rpc_tasks)
slots = np.flatnonzero(done)
for slot in slots:
rpc_task = rpc_tasks[slot]
assert rpc_task is not None
rpc_task.release_device_resource()
state.rpc_tasks[slot] = None

def _postprocess():
# If any of the sequences in the batch is done, return the response
# and reset the cache slot.

for idx, slot in enumerate(np.flatnonzero(done)):
rpc_task = state.rpc_tasks[slot]
assert isinstance(rpc_task, utils.RpcQueueTask)
rpc_task.release_device_resource()
for idx, slot in enumerate(slots):
rpc_task = rpc_tasks[slot]
assert rpc_task is not None
rpc_task.aux['finished_results'].append((sequences[idx], scores[idx]))
if rpc_task.aux['slot_count'] > len(rpc_task.aux['finished_results']):
assert not state.method.streamable_output
continue
if rpc_task.rpc and rpc_task.rpc.should_cancel():
logging.info('request cancelled.')
rpc_task.done(utils.cancelled())
state.rpc_tasks[slot] = None
continue
# [num_samples, ...]
seqs = np.stack(
Expand Down Expand Up @@ -2545,7 +2550,6 @@ def _postprocess():
self._log_exception(
'Error occurred: %s, error: %s', state.model_key, e
)
state.rpc_tasks[slot] = None
else:
# send response back to generate
self._model_services[state.service_id].FillRPCResponse(
Expand All @@ -2557,7 +2561,6 @@ def _postprocess():
self._log_exception(
'Error occurred: %s, error: %s', state.model_key, e
)
state.rpc_tasks[slot] = None

state.post_process_pool.run(_postprocess)

Expand Down

0 comments on commit 48f6a8c

Please sign in to comment.