Skip to content

Commit

Permalink
[core][distributed] fix zmq hang (vllm-project#6759)
Browse files (browse the repository at this point in the history)
(cherry picked from commit 740374d)
  • Loading branch information
youkaichao authored and russellb committed Sep 18, 2024
1 parent a1ba240 commit a7c521b
Showing 1 changed file with 21 additions and 39 deletions.
60 changes: 21 additions & 39 deletions vllm/distributed/device_communicators/shm_broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import PUB, REP, REQ, SUB, SUBSCRIBE, Context # type: ignore
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore

import vllm.envs as envs
from vllm.logger import init_logger
Expand Down Expand Up @@ -145,9 +145,7 @@ class Handle:

buffer: Optional[ShmRingBuffer] = None
local_subscribe_port: Optional[int] = None
local_sync_port: Optional[int] = None
remote_subscribe_port: Optional[int] = None
remote_sync_port: Optional[int] = None


class MessageQueue:
Expand Down Expand Up @@ -181,38 +179,36 @@ def __init__(
self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
max_chunks)

self.local_socket = context.socket(PUB)
# XPUB is very similar to PUB,
# except that it can receive subscription messages
# to confirm the number of subscribers
self.local_socket = context.socket(XPUB)
# Set the verbose option so that we receive every subscription
# message; otherwise, we will only receive the first subscription.
# See http://api.zeromq.org/3-3:zmq-setsockopt for more details.
self.local_socket.setsockopt(XPUB_VERBOSE, True)
local_subscribe_port = get_open_port()
self.local_socket.bind(f"tcp://*:{local_subscribe_port}")

self.local_sync_socket = context.socket(REP)
local_sync_port = get_open_port()
self.local_sync_socket.bind(f"tcp://*:{local_sync_port}")
self.current_idx = 0

else:
self.buffer = None # type: ignore
local_subscribe_port = None
local_sync_port = None
self.local_socket = None
self.local_sync_socket = None
self.current_idx = -1

if n_remote_reader > 0:
# for remote readers, we will:
# create a publish-subscribe socket to communicate large data
self.remote_socket = context.socket(PUB)
self.remote_socket = context.socket(XPUB)
self.remote_socket.setsockopt(XPUB_VERBOSE, True)
remote_subscribe_port = get_open_port()
self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")

self.remote_sync_socket = context.socket(REP)
remote_sync_port = get_open_port()
self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
else:
remote_subscribe_port = None
remote_sync_port = None
self.remote_socket = None
self.remote_sync_socket = None

self._is_writer = True
self._is_local_reader = False
Expand All @@ -225,9 +221,7 @@ def __init__(
local_reader_ranks=local_reader_ranks,
buffer=self.buffer,
local_subscribe_port=local_subscribe_port,
local_sync_port=local_sync_port,
remote_subscribe_port=remote_subscribe_port,
remote_sync_port=remote_sync_port,
)

def export_handle(self) -> Handle:
Expand All @@ -254,12 +248,7 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self.local_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_subscribe_port}")

self.local_sync_socket = context.socket(REQ)
self.local_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.local_sync_port}")

self.remote_socket = None
self.remote_sync_socket = None
else:
self.buffer = None # type: ignore
self.current_idx = -1
Expand All @@ -268,17 +257,12 @@ def create_from_handle(handle: Handle, rank) -> "MessageQueue":
self._is_remote_reader = True

self.local_socket = None
self.local_sync_socket = None

self.remote_socket = context.socket(SUB)
self.remote_socket.setsockopt_string(SUBSCRIBE, "")
self.remote_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

self.remote_sync_socket = context.socket(REQ)
self.remote_sync_socket.connect(
f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")

return self

def wait_until_ready(self):
Expand All @@ -290,29 +274,27 @@ def wait_until_ready(self):

# local readers
for i in range(self.n_local_reader):
recv = self.local_sync_socket.recv()
assert recv == b"READY"
self.local_sync_socket.send(b"READY")
# wait for subscription messages from all local readers
self.local_socket.recv()
if self.n_local_reader > 0:
# send a message to all local readers
# to make sure the publish channel is working
self.local_socket.send(b"READY")

# remote readers
for i in range(self.n_remote_reader):
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
self.remote_sync_socket.send(b"READY")
# wait for subscription messages from all remote readers
self.remote_socket.recv()
if self.n_remote_reader > 0:
# send a message to all remote readers
# to make sure the publish channel is working
self.remote_socket.send(b"READY")
elif self._is_local_reader:
self.local_sync_socket.send(b"READY")
recv = self.local_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.local_socket.recv()
assert recv == b"READY"
elif self._is_remote_reader:
self.remote_sync_socket.send(b"READY")
recv = self.remote_sync_socket.recv()
assert recv == b"READY"
# wait for the writer to send a message
recv = self.remote_socket.recv()
assert recv == b"READY"

Expand Down

0 comments on commit a7c521b

Please sign in to comment.