Refactor the busy loop for continuous batching
The previous implementation of continuous batching on the host side is a busy loop that polls for incoming requests, which can starve other host-side operations. We refactor it into an implementation with two threads per model method (one for prefill/insert, the other for generation).

PiperOrigin-RevId: 608734099
Change-Id: I6c07e51a1db23243d2b48c3ad9323a27aef2e453
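
For reference, the per-method thread structure introduced here can be summarized by the minimal sketch below. The MethodState class, the start helper, and the elided bodies are illustrative placeholders rather than the actual saxml API; the real implementation also coordinates device computation, secondary-host synchronization, and RPC bookkeeping.

# Minimal sketch of the two-thread-per-method pattern (illustrative only).
import queue
import threading


class MethodState:
  """Per-model-method state shared by the prefill and generate threads."""

  def __init__(self, num_cache_slots: int):
    self.prefill_queue: queue.SimpleQueue = queue.SimpleQueue()
    self.available_slots: queue.SimpleQueue = queue.SimpleQueue()
    for slot in range(num_cache_slots):
      self.available_slots.put(slot)
    self.generate_cv = threading.Condition()
    self.pending_insert = False
    self.slots_in_use = [False] * num_cache_slots

  def prefill_insert_loop(self):
    while True:
      request = self.prefill_queue.get()  # Block until a request arrives.
      slot = self.available_slots.get()   # Block until a cache slot frees up.
      self.pending_insert = True
      with self.generate_cv:
        # ... prefill `request` and insert its prefix cache into `slot` ...
        self.slots_in_use[slot] = True
        # Wake the generate thread only once there is nothing more to insert
        # right now (no queued request or no free slot).
        if self.prefill_queue.empty() or self.available_slots.empty():
          self.pending_insert = False
          self.generate_cv.notify()

  def generate_loop(self):
    while True:
      with self.generate_cv:
        while not any(self.slots_in_use) or self.pending_insert:
          self.generate_cv.wait()  # Releases the lock while waiting.
        # ... run one generation step over all occupied slots; for every
        # finished sequence, return the response, mark its slot unused, and
        # put the slot back on available_slots ...


def start(state: MethodState) -> None:
  """Starts the two daemon threads for one model method."""
  threading.Thread(target=state.prefill_insert_loop, daemon=True).start()
  threading.Thread(target=state.generate_loop, daemon=True).start()
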
changlan authored and copybara-github committed Feb 20, 2024
1 parent f817103 commit ab36fc0
Showing 1 changed file with 142 additions and 64 deletions.
206 changes: 142 additions & 64 deletions saxml/server/model_service_base.py
@@ -264,7 +264,7 @@ class PrefillRequest:

@dataclasses.dataclass
class ContinuousBatchingState:
"""Continuous batching state."""
"""Continuous batching state of a model method."""

method: servable_model.ServableMethod
model_key: str
@@ -275,8 +275,18 @@ class ContinuousBatchingState:

# Prefill requests are enqueued here after getting successfully pre-processed
# on the primary host. When cache slots become available, dequeue requests
# from this queue in the generation loop for further prefill and generation.
# from this queue in the prefill loop.
prefill_queue: queue.SimpleQueue[PrefillRequest]
prefill_thread: threading.Thread
generate_thread: threading.Thread
# When a new request is prefilled, prefill_thread uses this condition
# variable to signal availability to generate_thread.
generate_cv: threading.Condition
# prefill_thread blocks when available_slots is empty, until generate_thread
# releases a cache slot after it finishes a request.
available_slots: queue.SimpleQueue[int]
# Set by prefill_thread when there are pending requests and slots to insert.
pending_insert: bool
# RPC tasks for requests being actively processed. One entry per cache slot.
# Entries are filled in when a request leaves prefill_queue and freed when
# the lm.generate RPC is fulfilled.
@@ -306,6 +316,8 @@ def __init__(
service_id: str,
num_cache_slots: int,
max_decode_step: int,
prefill_fn: Callable[..., None],
generate_fn: Callable[..., None],
):
self.method = method
self.model_key = model_key
@@ -317,6 +329,19 @@ def __init__(
self.prefill_queue: queue.SimpleQueue[PrefillRequest] = queue.SimpleQueue()
self.rpc_tasks = [None] * num_cache_slots

# TODO(changlan): Properly handle shutdown
self.prefill_thread = threading.Thread(
target=prefill_fn, args=(self,), daemon=True
)
self.generate_thread = threading.Thread(
target=generate_fn, args=(self,), daemon=True
)
self.generate_cv = threading.Condition()
self.available_slots = queue.SimpleQueue()
for i in range(num_cache_slots):
self.available_slots.put(i)
self.pending_insert = False

self.decoded_tokens = np.zeros(
(num_cache_slots, max_decode_step), dtype=np.int32
)
@@ -1347,12 +1372,12 @@ def __init__(
)

# For continuous batching
self._token_batching_state: Dict[str, ContinuousBatchingState] = {}
self._generate_thread = threading.Thread(
target=self._run_generation_loop,
daemon=True, # TODO(changlan): Handle shutdown properly
name='model_service_runner_generation',
)
# Indexed by (model_key, method_name)
self._continuous_batching_state: Dict[
tuple[str, str], ContinuousBatchingState
] = dict()
# Global lock for device computation
self._device_compute_mutex: threading.Lock = threading.Lock()
else:
primary_id_str = self._spmd_backend.receive_via_device()
primary_host = int(primary_id_str)
@@ -1471,7 +1496,6 @@ def start(self) -> None:
self._aio_thread.start()
if self._is_primary:
self._keep_warm_thread.start()
self._generate_thread.start()

def on_initial_models_load_completion(self) -> None:
"""Callback to invoke after all models are loaded."""
@@ -1603,16 +1627,7 @@ def preprocess_rpc_tasks(
max_live_batches=method.num_cache_slots * 2,
)

self._token_batching_state[model_key] = ContinuousBatchingState(
method,
model_key=model_key,
model_method=method_name,
service_id=service_id,
num_cache_slots=method.num_cache_slots,
max_decode_step=method.max_decode_steps,
)

self._loaded_models.load(
model = self._loaded_models.load(
model_key,
model_path,
checkpoint_path,
@@ -1622,6 +1637,22 @@
register_methods,
)

for method_name, method_obj in model.methods.items():
if method_obj.continuous_batching:
state = ContinuousBatchingState(
method_obj,
model_key=model_key,
model_method=method_name,
service_id=method_obj.service_id(),
num_cache_slots=method_obj.num_cache_slots,
max_decode_step=method_obj.max_decode_steps,
prefill_fn=self._run_prefill_insert_loop,
generate_fn=self._run_generation_loop,
)
self._continuous_batching_state[(model_key, method_name)] = state
state.prefill_thread.start()
state.generate_thread.start()

def _save_model(self, model_key: str, checkpoint_path: str):
"""Saves a model checkpoint."""
if not self._loaded_models.contains(model_key):
@@ -1997,7 +2028,14 @@ def _run_primary_worker_loop(self):
try:
batch.wait_for_ready()
assert len(batch.rpc_tasks) == 1
state = self._token_batching_state[batch.method.model_key]

key = (batch.method.model_key, batch.method.model_method)
assert key in self._continuous_batching_state, (
f'Model method {key} not found in continuous batching state.'
f' Available keys: {self._continuous_batching_state.keys()}'
)

state = self._continuous_batching_state[key]

state.prefill_queue.put(
PrefillRequest(
@@ -2015,27 +2053,33 @@ def _run_primary_worker_loop(self):
batch.finish()
self._log_exception('Unknown method: %s', batch.method.name)

def _run_generation_loop(self):
"""Host loop for running continuous generation on the primary host."""
logging.info('Running generation loop')
def _run_prefill_insert_loop(
self,
state: ContinuousBatchingState,
):
"""Host loop for running prefills on the primary host."""
model_key = state.model_key
model_method = state.model_method
logging.info(
'Running prefill insert loop for model %s method %s',
model_key,
model_method,
)
method_obj = state.method
while True:
for model_key in self._loaded_models.get_models():
if model_key not in self._token_batching_state:
continue
state = self._token_batching_state[model_key]
assert model_key == state.model_key
model = self._loaded_models.get_model(state.model_key)
method_obj = model.method(state.model_method)

# If there are prefill requests and empty cache slots, run prefills
# and insert them to the cache first
while not (state.slots_in_use.all() or state.prefill_queue.empty()):
request = state.prefill_queue.get()
prefill_dequeue_time = time.time()

slot = np.argmin(state.slots_in_use)
state.slots_in_use[slot] = 1

# Block if there are no prefill requests.
request = state.prefill_queue.get()
# Block if there is no available cache slot.
slot = state.available_slots.get()
prefill_dequeue_time = time.time()

# Take an unused slot.
logging.info('Taking slot %d', slot)

# Atomic mutation. Safe to set outside the generate_cv guard.
state.pending_insert = True
with state.generate_cv:
with self._device_compute_mutex:
self._inform_secondary_hosts(
MethodName.PREFILL_INSERT,
state.model_key,
@@ -2046,46 +2090,74 @@ def _run_generation_loop(self):

assert request.preprocessed_inputs is not None
assert request.rpc_task is not None

scores, tokens, prefix_cache = method_obj.prefill(
inputs=request.preprocessed_inputs,
)
state.update_stats(
prefill_wait_time=prefill_dequeue_time - request.enqueue_time,
insert_wait_time=time.time() - prefill_dequeue_time,
)

method_obj.insert(prefix_cache, slot)
state.decoded_tokens[slot][0] = np.array(tokens.addressable_data(0))
state.scores[slot] = np.array(scores.addressable_data(0))
state.rpc_tasks[slot] = request.rpc_task
state.steps[slot] = 1

# Skip generation if none of the cache slots are in use.
if not state.slots_in_use.any():
continue
state.decoded_tokens[slot][0] = np.array(tokens.addressable_data(0))
state.scores[slot] = np.array(scores.addressable_data(0))
state.rpc_tasks[slot] = request.rpc_task
state.steps[slot] = 1

# Must set slots_in_use last.
state.slots_in_use[slot] = 1

# Don't wake up the generate thread yet if there are both more pending
# requests and empty slots to insert into.
if state.prefill_queue.empty() or state.available_slots.empty():
state.pending_insert = False
state.generate_cv.notify()

def _run_generation_loop(
self,
state: ContinuousBatchingState,
):
"""Host loop for running continuous generation on the primary host."""
model_key = state.model_key
model_method = state.model_method
logging.info(
'Running generation loop for model %s method %s',
model_key,
model_method,
)
method_obj = state.method
while True:
with state.generate_cv:
while not state.slots_in_use.any() or state.pending_insert:
# Release the lock and wait.
state.generate_cv.wait()

# Perform generation.
start_ts = time.time()
self._inform_secondary_hosts(
MethodName.GENERATE,
state.model_key,
state.model_method,
skip_host_sync=True,
)
token_batch = (
state.decoded_tokens[
np.arange(state.num_cache_slots), state.steps - 1
]
* state.slots_in_use
)
token_batch = method_obj.input_to_device_for_continuous_batching(
token_batch,
InputShapeInfo(batch_size=state.num_cache_slots),
)
res = method_obj.generate(token_batch)
scores, tokens, done = method_obj.output_to_host(
res, unpadded_batch_size=state.num_cache_slots
)
state.generate_step_time.add(time.time() - start_ts)
with self._device_compute_mutex:
start_ts = time.time()
self._inform_secondary_hosts(
MethodName.GENERATE,
state.model_key,
state.model_method,
skip_host_sync=True,
)
token_batch = method_obj.input_to_device_for_continuous_batching(
token_batch,
InputShapeInfo(batch_size=state.num_cache_slots),
)
res = method_obj.generate(token_batch)
scores, tokens, done = method_obj.output_to_host(
res, unpadded_batch_size=state.num_cache_slots
)
state.generate_step_time.add(time.time() - start_ts)

state.decoded_tokens[np.arange(state.num_cache_slots), state.steps] = (
tokens
@@ -2099,7 +2171,8 @@ def _run_generation_loop(self):
if not np.any(done):
continue

# If any of the sequences in the batch is done
# If any of the sequences in the batch is done, return the response
# and reset the cache slot.
sequences = state.decoded_tokens[done]
output_strings = method_obj.detokenize(sequences)
done_slots = np.flatnonzero(done)
@@ -2121,6 +2194,11 @@ def _run_generation_loop(self):
state.steps[done] = 0
state.slots_in_use[done] = 0

# Release the slots.
for slot in done_slots:
logging.info('Releasing slot %d.', slot)
state.available_slots.put(slot)

def _run_secondary_worker_loop(self):
"""Runs the processing loop in secondary hosts in a multi-host setup."""
while True: