Skip to content

Commit

Permalink
[serve] refactor router (ray-project#48252)
Browse files Browse the repository at this point in the history
## Why are these changes needed?

depends on ray-project#48232

Pull out some router init parameters and router init logic into
`create_router`.

---------

Signed-off-by: Cindy Zhang <[email protected]>
  • Loading branch information
zcin authored Oct 29, 2024
1 parent d0c6c60 commit b71f6b7
Show file tree
Hide file tree
Showing 7 changed files with 179 additions and 139 deletions.
67 changes: 66 additions & 1 deletion python/ray/serve/_private/default_impl.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,34 @@
import asyncio
from typing import Callable, Optional

import ray
from ray._raylet import GcsClient
from ray.actor import ActorHandle
from ray.serve._private.cluster_node_info_cache import (
ClusterNodeInfoCache,
DefaultClusterNodeInfoCache,
)
from ray.serve._private.common import DeploymentID
from ray.serve._private.constants import (
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE,
RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
)
from ray.serve._private.deployment_scheduler import (
DefaultDeploymentScheduler,
DeploymentScheduler,
)
from ray.serve._private.grpc_util import gRPCServer
from ray.serve._private.utils import get_head_node_id
from ray.serve._private.replica_scheduler import (
ActorReplicaWrapper,
PowerOfTwoChoicesReplicaScheduler,
)
from ray.serve._private.router import Router
from ray.serve._private.utils import (
get_head_node_id,
inside_ray_client_context,
resolve_request_args,
)

# NOTE: Please read carefully before changing!
#
Expand Down Expand Up @@ -50,6 +67,54 @@ def create_init_handle_options(**kwargs):
return _InitHandleOptions.create(**kwargs)


def create_router(
controller_handle: ActorHandle,
deployment_id: DeploymentID,
handle_id: str,
node_id: str,
actor_id: str,
availability_zone: Optional[str],
event_loop: asyncio.BaseEventLoop,
handle_options,
):
is_inside_ray_client_context = inside_ray_client_context()

replica_scheduler = PowerOfTwoChoicesReplicaScheduler(
event_loop,
deployment_id,
handle_options._source,
handle_options._prefer_local_routing,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
node_id,
actor_id,
ray.get_runtime_context().current_actor
if ray.get_runtime_context().get_actor_id()
else None,
availability_zone,
# Streaming ObjectRefGenerators are not supported in Ray Client
use_replica_queue_len_cache=(
not is_inside_ray_client_context and RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE
),
create_replica_wrapper_func=lambda r: ActorReplicaWrapper(r),
)

return Router(
controller_handle=controller_handle,
deployment_id=deployment_id,
handle_id=handle_id,
self_actor_id=actor_id,
handle_source=handle_options._source,
event_loop=event_loop,
replica_scheduler=replica_scheduler,
# Streaming ObjectRefGenerators are not supported in Ray Client
enable_strict_max_ongoing_requests=(
not is_inside_ray_client_context
and RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS
),
resolve_request_args_func=resolve_request_args,
)


def add_grpc_address(grpc_server: gRPCServer, server_address: str):
"""Helper function to add a address to gRPC server."""
grpc_server.add_insecure_port(server_address)
1 change: 1 addition & 0 deletions python/ray/serve/_private/replica_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@
ReplicaScheduler,
)
from ray.serve._private.replica_scheduler.replica_wrapper import ( # noqa: F401
ActorReplicaWrapper,
ReplicaWrapper,
)
34 changes: 21 additions & 13 deletions python/ray/serve/_private/replica_scheduler/pow_2_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DeploymentHandleSource,
DeploymentID,
ReplicaID,
ReplicaQueueLengthInfo,
RequestMetadata,
RunningReplicaInfo,
)
Expand Down Expand Up @@ -245,6 +246,26 @@ def create_replica_wrapper(
) -> ReplicaWrapper:
return self._create_replica_wrapper_func(replica_info)

def on_replica_actor_died(self, replica_id: ReplicaID):
"""Drop replica from replica set so it's not considered for future requests."""
self._replicas.pop(replica_id, None)
self._replica_id_set.discard(replica_id)
for id_set in self._colocated_replica_ids.values():
id_set.discard(replica_id)

def on_replica_actor_unavailable(self, replica_id: ReplicaID):
"""Invalidate cache entry so active probing is required for the next request."""
self._replica_queue_len_cache.invalidate_key(replica_id)

def on_new_queue_len_info(
self, replica_id: ReplicaID, queue_len_info: ReplicaQueueLengthInfo
):
"""Update queue length cache with new info received from replica."""
if self._use_replica_queue_len_cache:
self._replica_queue_len_cache.update(
replica_id, queue_len_info.num_ongoing_requests
)

def update_replicas(self, replicas: List[ReplicaWrapper]):
"""Update the set of available replicas to be considered for scheduling.
Expand Down Expand Up @@ -473,19 +494,6 @@ async def choose_two_replicas_with_backoff(
self.num_scheduling_tasks_in_backoff
)

def on_replica_actor_died(self, replica_id: ReplicaID):
"""Drop replica from replica set so it's not considered for future requests."""

self._replicas.pop(replica_id, None)
self._replica_id_set.discard(replica_id)
for id_set in self._colocated_replica_ids.values():
id_set.discard(replica_id)

def on_replica_actor_unavailable(self, replica_id: ReplicaID):
"""Invalidate cache entry so active probing is required for the next request."""

self._replica_queue_len_cache.invalidate_key(replica_id)

async def _probe_queue_lens(
self,
replicas: List[ReplicaWrapper],
Expand Down
125 changes: 14 additions & 111 deletions python/ray/serve/_private/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
from collections import defaultdict
from contextlib import contextmanager
from functools import partial
from typing import Any, DefaultDict, Dict, List, Optional, Tuple, Union
from typing import Any, Coroutine, DefaultDict, List, Optional, Tuple, Union

import ray
from ray.actor import ActorHandle
from ray.exceptions import ActorDiedError, ActorUnavailableError, RayError
from ray.serve._private.common import (
Expand All @@ -21,22 +20,14 @@
from ray.serve._private.constants import (
HANDLE_METRIC_PUSH_INTERVAL_S,
RAY_SERVE_COLLECT_AUTOSCALING_METRICS_ON_HANDLE,
RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE,
RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS,
RAY_SERVE_HANDLE_AUTOSCALING_METRIC_RECORD_PERIOD_S,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
SERVE_LOGGER_NAME,
)
from ray.serve._private.long_poll import LongPollClient, LongPollNamespace
from ray.serve._private.metrics_utils import InMemoryMetricsStore, MetricsPusher
from ray.serve._private.replica_result import ReplicaResult
from ray.serve._private.replica_scheduler import (
PendingRequest,
PowerOfTwoChoicesReplicaScheduler,
ReplicaScheduler,
)
from ray.serve._private.replica_scheduler.replica_wrapper import ActorReplicaWrapper
from ray.serve._private.utils import inside_ray_client_context
from ray.serve._private.replica_scheduler import PendingRequest, ReplicaScheduler
from ray.serve._private.utils import resolve_request_args
from ray.serve.config import AutoscalingConfig
from ray.serve.exceptions import BackPressureError
from ray.util import metrics
Expand Down Expand Up @@ -69,7 +60,6 @@ def __init__(
self._self_actor_id = self_actor_id
self._handle_source = handle_source
self._controller_handle = controller_handle
self._self_actor_id = self_actor_id

# Exported metrics
self.num_router_requests = router_requests_counter
Expand Down Expand Up @@ -324,16 +314,12 @@ def __init__(
controller_handle: ActorHandle,
deployment_id: DeploymentID,
handle_id: str,
self_node_id: str,
self_actor_id: str,
self_availability_zone: Optional[str],
handle_source: DeploymentHandleSource,
event_loop: asyncio.BaseEventLoop = None,
_prefer_local_node_routing: bool = False,
enable_queue_len_cache: bool = RAY_SERVE_ENABLE_QUEUE_LENGTH_CACHE,
enable_strict_max_ongoing_requests: bool = RAY_SERVE_ENABLE_STRICT_MAX_ONGOING_REQUESTS, # noqa: E501
*,
replica_scheduler: Optional[ReplicaScheduler] = None,
event_loop: asyncio.BaseEventLoop,
replica_scheduler: Optional[ReplicaScheduler],
enable_strict_max_ongoing_requests: bool,
resolve_request_args_func: Coroutine = resolve_request_args,
):
"""Used to assign requests to downstream replicas for a deployment.
Expand All @@ -343,37 +329,11 @@ def __init__(

self._event_loop = event_loop
self.deployment_id = deployment_id

if inside_ray_client_context():
# Streaming ObjectRefGenerators are not supported in Ray Client, so we need
# to override the behavior.
self._enable_queue_len_cache = False
self._enable_strict_max_ongoing_requests = False
else:
self._enable_queue_len_cache = enable_queue_len_cache
self._enable_strict_max_ongoing_requests = (
enable_strict_max_ongoing_requests
)

replica_wrapper_cls = ActorReplicaWrapper
if replica_scheduler is None:
replica_scheduler = PowerOfTwoChoicesReplicaScheduler(
self._event_loop,
deployment_id,
handle_source,
_prefer_local_node_routing,
RAY_SERVE_PROXY_PREFER_LOCAL_AZ_ROUTING,
self_node_id,
self_actor_id,
ray.get_runtime_context().current_actor
if ray.get_runtime_context().get_actor_id()
else None,
self_availability_zone,
use_replica_queue_len_cache=enable_queue_len_cache,
create_replica_wrapper_func=lambda r: replica_wrapper_cls(r),
)
self._enable_strict_max_ongoing_requests = enable_strict_max_ongoing_requests

self._replica_scheduler: ReplicaScheduler = replica_scheduler
self._resolve_request_args = resolve_request_args_func

# Flipped to `True` once the router has received a non-empty
# replica set at least once.
self.running_replicas_populated: bool = False
Expand Down Expand Up @@ -443,62 +403,6 @@ def update_deployment_config(self, deployment_config: DeploymentConfig):
curr_num_replicas=len(self._replica_scheduler.curr_replicas),
)

async def _resolve_deployment_responses(
self, request_args: Tuple[Any], request_kwargs: Dict[str, Any]
) -> Tuple[Tuple[Any], Dict[str, Any]]:
"""Replaces top-level `DeploymentResponse` objects with resolved object refs.
This enables composition without explicitly calling `_to_object_ref`.
"""
from ray.serve.handle import DeploymentResponse, DeploymentResponseGenerator

generator_not_supported_message = (
"Streaming deployment handle results cannot be passed to "
"downstream handle calls. If you have a use case requiring "
"this feature, please file a feature request on GitHub."
)

new_args = [None for _ in range(len(request_args))]
new_kwargs = {}

arg_tasks = []
response_indices = []
for i, obj in enumerate(request_args):
if isinstance(obj, DeploymentResponseGenerator):
raise RuntimeError(generator_not_supported_message)
elif isinstance(obj, DeploymentResponse):
# Launch async task to convert DeploymentResponse to an object ref, and
# keep track of the argument index in the original `request_args`
response_indices.append(i)
arg_tasks.append(asyncio.create_task(obj._to_object_ref()))
else:
new_args[i] = obj

kwarg_tasks = []
response_keys = []
for k, obj in request_kwargs.items():
if isinstance(obj, DeploymentResponseGenerator):
raise RuntimeError(generator_not_supported_message)
elif isinstance(obj, DeploymentResponse):
# Launch async task to convert DeploymentResponse to an object ref, and
# keep track of the corresponding key in the original `request_kwargs`
response_keys.append(k)
kwarg_tasks.append(asyncio.create_task(obj._to_object_ref()))
else:
new_kwargs[k] = obj

# Gather `DeploymentResponse` object refs concurrently.
arg_obj_refs = await asyncio.gather(*arg_tasks)
kwarg_obj_refs = await asyncio.gather(*kwarg_tasks)

# Update new args and new kwargs with resolved object refs
for index, obj_ref in zip(response_indices, arg_obj_refs):
new_args[index] = obj_ref
new_kwargs.update((zip(response_keys, kwarg_obj_refs)))

# Return new args and new kwargs
return new_args, new_kwargs

def _process_finished_request(
self, replica_id: ReplicaID, result: Union[Any, RayError]
):
Expand Down Expand Up @@ -546,10 +450,9 @@ async def schedule_and_send_request(
replica_result,
queue_len_info,
) = await replica.send_request_with_rejection(pr)
if self._enable_queue_len_cache:
self._replica_scheduler.replica_queue_len_cache.update(
replica.replica_id, queue_len_info.num_ongoing_requests
)
self._replica_scheduler.on_new_queue_len_info(
replica.replica_id, queue_len_info
)
if queue_len_info.accepted:
return replica_result, replica.replica_id
except asyncio.CancelledError:
Expand Down Expand Up @@ -604,7 +507,7 @@ async def assign_request(

ref = None
try:
request_args, request_kwargs = await self._resolve_deployment_responses(
request_args, request_kwargs = await self._resolve_request_args(
request_args, request_kwargs
)
ref, replica_id = await self.schedule_and_send_request(
Expand Down
Loading

0 comments on commit b71f6b7

Please sign in to comment.