From 740374d456a638df98ffbc7d9dab328752330e62 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Wed, 24 Jul 2024 17:37:12 -0700 Subject: [PATCH] [core][distributed] fix zmq hang (#6759) --- vllm/connections.py | 4 +- .../device_communicators/shm_broadcast.py | 60 +++++++------------ 2 files changed, 23 insertions(+), 41 deletions(-) diff --git a/vllm/connections.py b/vllm/connections.py index 65d44176e2464..e785a0b3ebd74 100644 --- a/vllm/connections.py +++ b/vllm/connections.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Mapping, Optional +from typing import Mapping, MutableMapping, Optional from urllib.parse import urlparse import aiohttp @@ -40,7 +40,7 @@ def _validate_http_url(self, url: str): raise ValueError("Invalid HTTP URL: A valid HTTP URL " "must have scheme 'http' or 'https'.") - def _headers(self, **extras: str) -> Mapping[str, str]: + def _headers(self, **extras: str) -> MutableMapping[str, str]: return {"User-Agent": f"vLLM/{VLLM_VERSION}", **extras} def get_response( diff --git a/vllm/distributed/device_communicators/shm_broadcast.py b/vllm/distributed/device_communicators/shm_broadcast.py index 75d84c7a71bc3..d4847542688c0 100644 --- a/vllm/distributed/device_communicators/shm_broadcast.py +++ b/vllm/distributed/device_communicators/shm_broadcast.py @@ -9,7 +9,7 @@ import torch import torch.distributed as dist from torch.distributed import ProcessGroup -from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore import vllm.envs as envs from vllm.logger import init_logger @@ -153,9 +153,7 @@ class Handle: buffer: Optional[ShmRingBuffer] = None local_subscribe_port: Optional[int] = None - local_sync_port: Optional[int] = None remote_subscribe_port: Optional[int] = None - remote_sync_port: Optional[int] = None class MessageQueue: @@ -189,38 +187,36 @@ def __init__( self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks) - self.local_socket = context.socket(PUB) + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) local_subscribe_port = get_open_port() self.local_socket.bind(f"tcp://*:{local_subscribe_port}") - self.local_sync_socket = context.socket(REP) - local_sync_port = get_open_port() - self.local_sync_socket.bind(f"tcp://*:{local_sync_port}") self.current_idx = 0 else: self.buffer = None # type: ignore local_subscribe_port = None - local_sync_port = None self.local_socket = None - self.local_sync_socket = None self.current_idx = -1 if n_remote_reader > 0: # for remote readers, we will: # create a publish-subscribe socket to communicate large data - self.remote_socket = context.socket(PUB) + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) remote_subscribe_port = get_open_port() self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}") - self.remote_sync_socket = context.socket(REP) - remote_sync_port = get_open_port() - self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}") else: remote_subscribe_port = None - remote_sync_port = None self.remote_socket = None - self.remote_sync_socket = None self._is_writer = True self._is_local_reader = False @@ -233,9 +229,7 @@ def __init__( local_reader_ranks=local_reader_ranks, buffer=self.buffer, local_subscribe_port=local_subscribe_port, - local_sync_port=local_sync_port, remote_subscribe_port=remote_subscribe_port, - remote_sync_port=remote_sync_port, ) logger.info("vLLM message queue communication handle: %s", self.handle) @@ -264,12 +258,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self.local_socket.connect( f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}") - self.local_sync_socket = context.socket(REQ) - self.local_sync_socket.connect( - f"tcp://{handle.connect_ip}:{handle.local_sync_port}") - self.remote_socket = None - self.remote_sync_socket = None else: self.buffer = None # type: ignore self.current_idx = -1 @@ -278,17 +267,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue": self._is_remote_reader = True self.local_socket = None - self.local_sync_socket = None self.remote_socket = context.socket(SUB) self.remote_socket.setsockopt_string(SUBSCRIBE, "") self.remote_socket.connect( f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}") - self.remote_sync_socket = context.socket(REQ) - self.remote_sync_socket.connect( - f"tcp://{handle.connect_ip}:{handle.remote_sync_port}") - return self def wait_until_ready(self): @@ -300,29 +284,27 @@ def wait_until_ready(self): # local readers for i in range(self.n_local_reader): - recv = self.local_sync_socket.recv() - assert recv == b"READY" - self.local_sync_socket.send(b"READY") + # wait for subscription messages from all local readers + self.local_socket.recv() if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working self.local_socket.send(b"READY") # remote readers for i in range(self.n_remote_reader): - recv = self.remote_sync_socket.recv() - assert recv == b"READY" - self.remote_sync_socket.send(b"READY") + # wait for subscription messages from all remote readers + self.remote_socket.recv() if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working self.remote_socket.send(b"READY") elif self._is_local_reader: - self.local_sync_socket.send(b"READY") - recv = self.local_sync_socket.recv() - assert recv == b"READY" + # wait for the writer to send a message recv = self.local_socket.recv() assert recv == b"READY" elif self._is_remote_reader: - self.remote_sync_socket.send(b"READY") - recv = self.remote_sync_socket.recv() - assert recv == b"READY" + # wait for the writer to send a message recv = self.remote_socket.recv() assert recv == b"READY"