Skip to content

Commit

Permalink
Merge branch 'main' into chhwang/cq-warning
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Nov 14, 2023
2 parents 6d0dda0 + 4cdb100 commit 8abfb76
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 7 deletions.
10 changes: 5 additions & 5 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ void register_core(nb::module_& m) {
nb::arg("nRanks"))
.def("create_unique_id", &TcpBootstrap::createUniqueId)
.def("get_unique_id", &TcpBootstrap::getUniqueId)
.def("initialize", (void (TcpBootstrap::*)(UniqueId, int64_t)) & TcpBootstrap::initialize, nb::arg("uniqueId"),
nb::arg("timeoutSec") = 30)
.def("initialize", (void (TcpBootstrap::*)(const std::string&, int64_t)) & TcpBootstrap::initialize,
nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);
.def("initialize", static_cast<void (TcpBootstrap::*)(UniqueId, int64_t)>(&TcpBootstrap::initialize),
nb::call_guard<nb::gil_scoped_release>(), nb::arg("uniqueId"), nb::arg("timeoutSec") = 30)
.def("initialize", static_cast<void (TcpBootstrap::*)(const std::string&, int64_t)>(&TcpBootstrap::initialize),
nb::call_guard<nb::gil_scoped_release>(), nb::arg("ifIpPortTrio"), nb::arg("timeoutSec") = 30);

nb::enum_<Transport>(m, "Transport")
.value("Unknown", Transport::Unknown)
Expand Down Expand Up @@ -120,7 +120,7 @@ void register_core(nb::module_& m) {
self->updateAndSync(dst, dstOffset, (uint64_t*)src, newValue);
},
nb::arg("dst"), nb::arg("dstOffset"), nb::arg("src"), nb::arg("newValue"))
.def("flush", &Connection::flush, nb::arg("timeoutUsec") = (int64_t)3e7)
.def("flush", &Connection::flush, nb::call_guard<nb::gil_scoped_release>(), nb::arg("timeoutUsec") = (int64_t)3e7)
.def("transport", &Connection::transport)
.def("remote_transport", &Connection::remoteTransport);

Expand Down
3 changes: 2 additions & 1 deletion python/mscclpp/semaphore_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ void register_semaphore(nb::module_& m) {
.def("connection", &Host2HostSemaphore::connection)
.def("signal", &Host2HostSemaphore::signal)
.def("poll", &Host2HostSemaphore::poll)
.def("wait", &Host2HostSemaphore::wait, nb::arg("max_spin_count") = 10000000);
.def("wait", &Host2HostSemaphore::wait, nb::call_guard<nb::gil_scoped_release>(),
nb::arg("max_spin_count") = 10000000);

nb::class_<SmDevice2DeviceSemaphore> smDevice2DeviceSemaphore(m, "SmDevice2DeviceSemaphore");
smDevice2DeviceSemaphore
Expand Down
82 changes: 81 additions & 1 deletion python/test/test_mscclpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,22 @@

from concurrent.futures import ThreadPoolExecutor
import time
import threading

import cupy as cp
import numpy as np
import netifaces as ni
import pytest

from mscclpp import Fifo, Host2DeviceSemaphore, Host2HostSemaphore, ProxyService, SmDevice2DeviceSemaphore, Transport
from mscclpp import (
TcpBootstrap,
Fifo,
Host2DeviceSemaphore,
Host2HostSemaphore,
ProxyService,
SmDevice2DeviceSemaphore,
Transport,
)
from ._cpp import _ext
from .mscclpp_group import MscclppGroup
from .mscclpp_mpi import MpiGroup, parametrize_mpi_groups, mpi_group
Expand Down Expand Up @@ -63,6 +72,50 @@ def test_group_with_ip(mpi_group: MpiGroup, ifIpPortTrio: str):
assert np.array_equal(memory, memory_expected)


@parametrize_mpi_groups(2, 4, 8, 16)
def test_bootstrap_init_gil_release(mpi_group: MpiGroup):
bootstrap = TcpBootstrap.create(mpi_group.comm.rank, mpi_group.comm.size)
uniq_id = None
if mpi_group.comm.rank == 0:
# similar to NCCL's unique id
uniq_id = bootstrap.create_unique_id()
uniq_id_global = mpi_group.comm.bcast(uniq_id, 0)

if mpi_group.comm.rank == 0:
# rank 0 never initializes the bootstrap, making other ranks block
pass
else:
check_list = []

def check_target():
check_list.append("this thread could run.")

def init_target():
try:
# expected to raise a timeout after 3 seconds
bootstrap.initialize(uniq_id_global, 3)
except:
pass

init_thread = threading.Thread(target=init_target)
check_thread = threading.Thread(target=check_target)
init_thread.start()

time.sleep(0.1)

# check that the check thread is not blocked
s = time.time()
check_thread.start()
check_thread.join()
e = time.time()
assert e - s < 0.1
assert len(check_list) == 1

init_thread.join()

mpi_group.comm.barrier()


def create_and_connect(mpi_group: MpiGroup, transport: str):
if transport == "NVLink" and all_ranks_on_the_same_node(mpi_group) is False:
pytest.skip("cannot use nvlink for cross node")
Expand Down Expand Up @@ -186,6 +239,33 @@ def test_h2h_semaphores(mpi_group: MpiGroup):
group.barrier()


@parametrize_mpi_groups(2, 4, 8, 16)
def test_h2h_semaphores_gil_release(mpi_group: MpiGroup):
group, connections = create_and_connect(mpi_group, "IB")

semaphores = group.make_semaphore(connections, Host2HostSemaphore)

def target_wait(sems, conns):
for rank in conns:
sems[rank].wait(-1)

def target_signal(sems, conns):
# sleep 1 sec to let target_wait() starts a bit earlier
time.sleep(1)
# if wait() doesn't release GIL, this will block forever
for rank in conns:
sems[rank].signal()

wait_thread = threading.Thread(target=target_wait, args=(semaphores, connections))
signal_thread = threading.Thread(target=target_signal, args=(semaphores, connections))
wait_thread.start()
signal_thread.start()
signal_thread.join()
wait_thread.join()

group.barrier()


class MscclppKernel:
def __init__(
self,
Expand Down

0 comments on commit 8abfb76

Please sign in to comment.