Add ray disaggregated serving support (#87)
* add ray disaggregated serving support

* function fix

* fix lint error

* refactor parameter

* add ActiveRequest annotation in function
FanhaiLu1 authored May 23, 2024
1 parent eaf0d6e commit e19a790
Showing 3 changed files with 28 additions and 8 deletions.
1 change: 1 addition & 0 deletions jetstream/core/config_lib.py
@@ -38,6 +38,7 @@ class ServerConfig:
   prefill_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
   generate_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
   interleaved_engine_create_fns: Tuple[CreateEngineFn, ...] = ()
+  is_ray_backend: bool = False
 
 
 @dataclasses.dataclass
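The new field defaults to `False`, so existing configurations are unchanged. A minimal usage sketch (assuming, as the surrounding lines suggest, that the remaining `ServerConfig` fields all have defaults):

```python
# Minimal sketch: opting into the Ray backend via the new dataclass field.
from jetstream.core import config_lib

cfg = config_lib.ServerConfig(is_ray_backend=True)
assert cfg.is_ray_backend  # engine create-fns keep their () defaults
```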
34 changes: 26 additions & 8 deletions jetstream/core/orchestrator.py
@@ -223,6 +223,7 @@ def __init__(
       interleaved_mode: bool = False,
       jax_padding: bool = True,
       metrics_collector: JetstreamMetricsCollector | None = None,
+      is_ray_backend: bool = False,
   ):
     if prefill_engines is None:
       prefill_engines = []
@@ -374,6 +375,7 @@ def __init__(
         )
       )
     self.live = True
+    self._is_ray_backend = is_ray_backend
     # Start all threads
     for t in self._all_threads:
       t.start()
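The hunk above records the flag before the worker threads start. A stdlib-only sketch of that bring-up pattern — set all shared state first, then start every thread — using stand-in names rather than JetStream's actual thread targets:

```python
import threading

def worker(idx: int):
    # Stand-in for a prefill/transfer/generate loop.
    print(f"worker {idx} running")

# Shared state (like _is_ray_backend above) is assigned before any thread runs.
all_threads = [
    threading.Thread(target=worker, args=(i,), daemon=True) for i in range(3)
]
for t in all_threads:
    t.start()
for t in all_threads:
    t.join()
```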
@@ -508,6 +510,29 @@ def _prefill_thread(self, idx: int):
       del prefill_result
       del request
 
+  def _jax_transfer_prefill_result(
+      self, new_request: ActiveRequest, target_idx: int
+  ):
+    new_request.prefill_result = jax.device_put(
+        new_request.prefill_result,
+        self._generate_engines[target_idx].get_prefix_destination_sharding(),
+    )
+    # Block here so we don't block on the generate thread that steps.
+    jax.block_until_ready(new_request.prefill_result)
+
+  def _ray_transfer_prefill_result(
+      self, new_request: ActiveRequest, target_idx: int
+  ):
+    self._generate_engines[target_idx].transfer(new_request.prefill_result)
+
+  def _transfer_prefill_result(
+      self, new_request: ActiveRequest, target_idx: int
+  ):
+    if self._is_ray_backend:
+      self._ray_transfer_prefill_result(new_request, target_idx)
+    else:
+      self._jax_transfer_prefill_result(new_request, target_idx)
+
   def _transfer_thread(self, idx: int):
     """Transfers the kv cache on an active request to the least full
     generate backlog."""
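The three new methods split the prefix hand-off by backend: the JAX path copies the prefill result onto the generate engine's destination sharding and blocks until the copy lands, while the Ray path delegates to the engine's own `transfer`. A runnable sketch of the JAX path in isolation; the array and sharding below are placeholders (the real sharding comes from `get_prefix_destination_sharding()`), and only `jax.device_put` / `jax.block_until_ready` are actual APIs used by this commit:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Placeholder destination: replicate across all local devices.
mesh = Mesh(jax.devices(), axis_names=("x",))
dest = NamedSharding(mesh, PartitionSpec())

prefix = jnp.zeros((8, 128))           # stand-in for a prefill KV-cache prefix
prefix = jax.device_put(prefix, dest)  # asynchronous copy to the target slice
jax.block_until_ready(prefix)          # wait here, not on the generate thread
```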
@@ -531,14 +556,7 @@ def _transfer_thread(self, idx: int):
           target_idx,
       )
       # Transfer the info to the relevant generate slice.
-      new_request.prefill_result = jax.device_put(
-          new_request.prefill_result,
-          self._generate_engines[
-              target_idx
-          ].get_prefix_destination_sharding(),
-      )
-      # Block here so we don't block on the generate thread that steps.
-      jax.block_until_ready(new_request.prefill_result)
+      self._transfer_prefill_result(new_request, target_idx)
       # Place the request on the correct generate backlog and block if full.
       self._generate_backlogs[target_idx].put(new_request, block=True)
       logging.info(
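After the dispatch call, the request is queued onto the chosen generate backlog with a blocking put, so a full backlog applies backpressure to the transfer thread. A stdlib-only sketch of that hand-off; the least-full selection shown here is hypothetical, since the real choice of `target_idx` happens earlier in the thread, outside this hunk:

```python
import queue

generate_backlogs = [queue.Queue(maxsize=4) for _ in range(2)]

# Hypothetical least-full selection for target_idx.
target_idx = min(
    range(len(generate_backlogs)), key=lambda i: generate_backlogs[i].qsize()
)
# Blocks when the chosen backlog is full, applying backpressure upstream.
generate_backlogs[target_idx].put("new_request", block=True)
```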
1 change: 1 addition & 0 deletions jetstream/core/server_lib.py
@@ -146,6 +146,7 @@ def run(
       interleaved_mode=interleaved_mode,
       jax_padding=jax_padding,
       metrics_collector=metrics_collector,
+      is_ray_backend=config.is_ray_backend,
   )
   # We default threads to the total number of concurrent allowed decodes,
   # to make sure we can fully saturate the model. Set default minimum to 64.
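Taken together, the three files thread the flag from `ServerConfig` through the driver into the per-request dispatch. A self-contained restatement of that control flow with hypothetical stub classes (not JetStream's actual types):

```python
from dataclasses import dataclass

@dataclass
class StubConfig:          # stand-in for config_lib.ServerConfig
    is_ray_backend: bool = False

class StubDriver:          # stand-in for the orchestrator driver
    def __init__(self, is_ray_backend: bool = False):
        self._is_ray_backend = is_ray_backend

    def transfer_prefill_result(self, prefill_result: str) -> str:
        if self._is_ray_backend:
            return f"ray transfer of {prefill_result}"  # engine.transfer(...)
        return f"jax device_put of {prefill_result}"    # device_put + block

config = StubConfig(is_ray_backend=True)
driver = StubDriver(is_ray_backend=config.is_ray_backend)
print(driver.transfer_prefill_result("kv prefix"))
```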
