Skip to content

Commit

Permalink
Clean up startup logic to avoid using DEALER / ROUTER sockets
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-neuralmagic committed Dec 12, 2024
1 parent e16d63b commit ddbda1f
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 78 deletions.
7 changes: 1 addition & 6 deletions vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,6 @@ class RPCAbortRequest:
request_id: str


class RPCStartupRequest(Enum):
IS_SERVER_READY = 1


@dataclass
class RPCStartupResponse:
tracing_enabled: bool
Expand All @@ -120,8 +116,7 @@ class RPCUProfileRequest(Enum):
STOP_PROFILE = 2


RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCStartupRequest,
RPCUProfileRequest]
RPC_REQUEST_T = Union[RPCProcessRequest, RPCAbortRequest, RPCUProfileRequest]

REQUEST_OUTPUTS_T = Union[List[RequestOutput], RPCError]

Expand Down
57 changes: 20 additions & 37 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import (send_signed_async,
recv_signed_async)
Expand Down Expand Up @@ -137,7 +137,7 @@ def is_unsupported_config(engine_args: AsyncEngineArgs):

@contextmanager
def get_data_socket(self) -> Iterator[Socket]:
socket = self.context.socket(zmq.constants.DEALER)
socket = self.context.socket(zmq.constants.PULL)
try:
socket.connect(self.data_ipc_path)
yield socket
Expand Down Expand Up @@ -265,7 +265,7 @@ async def setup(self):

with self.get_data_socket() as socket:
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
response = await self._wait_for_server(socket)

self.tracing_flag = response.tracing_enabled

Expand All @@ -288,33 +288,7 @@ def _set_errored(self, e: BaseException):
logger.exception(repr(e))
if self._errored_with is None:
self._errored_with = e

@staticmethod
async def _send_get_data_rpc_request(request: RPCStartupRequest,
expected_type: Any,
error_message: str,
socket: Socket,
secret_key: bytes) -> Any:
"""Send an RPC request that is expecting data back."""

# Ping RPCServer with a request.
await send_signed_async(socket, secret_key, pickle.dumps(request))

# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")

# Await the data from the Server.
message = await recv_signed_async(socket, secret_key)
data = pickle.loads(message)

if isinstance(data, BaseException):
raise data
elif not isinstance(data, expected_type):
raise ValueError(error_message)

return data


@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
Expand Down Expand Up @@ -372,15 +346,24 @@ async def get_model_config(self) -> ModelConfig:
async def is_tracing_enabled(self) -> bool:
return self.tracing_flag

async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
async def _wait_for_server(self, socket: Socket) -> RPCStartupResponse:
"""Wait for the RPCServer to start up."""

return await self._send_get_data_rpc_request(
request=RPCStartupRequest.IS_SERVER_READY,
expected_type=RPCStartupResponse,
error_message="Unable to start RPC Server",
secret_key=self.secret_key,
socket=socket)
# Raise error if the server does not respond in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")

# Await the data from the Server.
message = await recv_signed_async(socket, self.secret_key)
data = pickle.loads(message)

if isinstance(data, BaseException):
raise data
elif not isinstance(data, RPCStartupResponse):
raise ValueError("RPCServer failed to start.")

return data

async def abort(self, request_id: str):
"""Send an ABORT_REQUEST signal to the RPC Server"""
Expand Down
50 changes: 15 additions & 35 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,9 @@
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import (send_signed, recv_signed,
check_signed, sign)
from vllm.engine.multiprocessing.ipc import (send_signed, recv_signed)

# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
Expand Down Expand Up @@ -95,7 +94,7 @@ def __init__(self,
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")

# IPC path for the data socket.
# IPC path for the data socket, over which the ready notification is sent.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"

# Error state.
Expand Down Expand Up @@ -135,8 +134,6 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,
def start(self):
try:
try:
logger.debug("Starting Startup Loop.")
self.run_startup_loop()
logger.debug("Starting Engine Loop.")
self.run_engine_loop()
except Exception as e:
Expand All @@ -154,42 +151,25 @@ def cleanup(self):
del self.engine

@contextmanager
def make_data_socket(
self) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
socket = self.ctx.socket(zmq.constants.ROUTER)
def make_push_socket(
self, path: str) -> Iterator[zmq.Socket]: # type: ignore[name-defined]
socket = self.ctx.socket(zmq.constants.PUSH)
try:
socket.bind(self.data_ipc_path)
socket.bind(path)
yield socket
finally:
socket.close(linger=0)

def run_startup_loop(self) -> None:
"""Startup loop for sending data from Engine -> Client."""

with self.make_data_socket() as socket:
response: Union[RPCStartupResponse, BaseException]
try:
identity, sig, message = socket.recv_multipart(copy=False)
if not check_signed(self.secret_key, sig, message.buffer):
raise ValueError("Message Signature is invalid.")
request: RPCStartupRequest = pickle.loads(message.buffer)

# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)

except Exception as e:
response = e

response_bytes = pickle.dumps(response)
sig = sign(self.secret_key, response_bytes)
socket.send_multipart((identity, sig, response_bytes), copy=False)
socket.close(linger=0)

def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""

# Alert that we are ready.
with self.make_push_socket(self.data_ipc_path) as socket:
response = RPCStartupResponse(
tracing_enabled=self.engine.is_tracing_enabled())
response_bytes = pickle.dumps(response)
send_signed(socket, self.secret_key, response_bytes)

while True:
if not self.engine.has_unfinished_requests():
# Poll until there is work to do.
Expand Down

0 comments on commit ddbda1f

Please sign in to comment.