[Feature] Launch Long Live Servers and Multiple Client Groups (dmlc#3688)

* enable launching multiple client groups sequentially

* enable launching client groups simultaneously

* refine docstring

* revert unnecessary change

* [DOC] add doc for long live server

* refine

* refine doc

* refine doc
Rhett-Ying authored Feb 9, 2022
1 parent 738e831 commit fcd8ed9
Showing 4 changed files with 157 additions and 24 deletions.
47 changes: 46 additions & 1 deletion docs/source/guide/distributed-tools.rst
@@ -50,7 +50,7 @@ Below shows an example of launching a distributed training job in a cluster.
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1 --num_workers 4"
"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"
The configuration file *ip_config.txt* contains the IP addresses of the machines in a cluster.
A typical example of *ip_config.txt* is as follows:
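One hypothetical layout (placeholder addresses; each line holds either a bare IP or an
``ip port`` pair, as parsed by ``tools/launch.py``):

.. code:: none

    172.31.19.1
    172.31.23.205
    172.31.29.175
    172.31.16.98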
@@ -75,3 +75,48 @@ The launch script creates a specified number of training jobs (``--num_trainers``
In addition, a user needs to specify the number of sampler processes for each trainer
(``--num_samplers``). The number of sampler processes has to match the number of worker
processes specified in :func:`~dgl.distributed.initialize`.
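For example, with ``--num_samplers 4`` in the launch command, the training script would
initialize DGL roughly as follows (a minimal sketch; the ``ip_config`` and ``num_workers``
arguments mirror the ``initialize`` signature shown further down in this diff):

.. code:: python

    import dgl

    # num_workers must equal the --num_samplers value passed to tools/launch.py
    dgl.distributed.initialize('ip_config.txt', num_workers=4)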

It is common for users to try different models or training configurations against the
same graph data. To avoid repeatedly loading the same graph data, DGL allows users to
launch a persistent graph server that is shared across multiple training jobs. A
persistent graph server stays alive even after all training workers have finished and
exited. Below is an example of launching a persistent graph server.

We first launch the graph server together with the first group of training workers.

.. code:: none
python3 tools/launch.py \
--workspace ~graphsage/ \
--num_trainers 2 \
--num_samplers 4 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
--keep_alive \
--server_name long_live \
"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"
Pay attention to the ``--keep_alive`` option, which indicates that the server should
stay alive after the workers have finished. ``--server_name`` is the name given to the
server, which will be referred to when launching new training jobs.

Then, launch another group of distributed training jobs and connect to the existing persistent server.

.. code:: none
python3 tools/launch.py \
--workspace ~graphsage/ \
--num_trainers 2 \
--num_samplers 4 \
--num_servers 1 \
--part_config data/ogb-product.json \
--ip_config ip_config.txt \
--server_name long_live \
"python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 5 --batch-size 1000 --lr 0.1"
.. note::

    All the arguments for ``launch.py`` should be kept the same as in the previous launch.
    The following arguments of the training script must also stay the same: ``--graph-name``
    and ``--ip_config``. The remaining arguments, such as ``--num-epochs``, ``--batch-size``
    and ``--lr``, are free to change, as illustrated below.
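For instance, a later client group could reuse every launcher argument above and change
only the training hyperparameters in the quoted command (an illustrative variation with
arbitrary values):

.. code:: none

    "python3 code/train_dist.py --graph-name ogb-product --ip_config ip_config.txt --num-epochs 10 --batch-size 2000 --lr 0.05"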
4 changes: 3 additions & 1 deletion python/dgl/distributed/dist_context.py
@@ -228,7 +228,7 @@ def initialize(ip_config, num_servers=1, num_workers=0,
formats = os.environ.get('DGL_GRAPH_FORMAT', 'csc').split(',')
formats = [f.strip() for f in formats]
rpc.reset()
keep_alive = os.environ.get('DGL_KEEP_ALIVE') is not None
keep_alive = bool(int(os.environ.get('DGL_KEEP_ALIVE', 0)))
serv = DistGraphServer(int(os.environ.get('DGL_SERVER_ID')),
os.environ.get('DGL_IP_CONFIG'),
int(os.environ.get('DGL_NUM_SERVER')),
@@ -322,6 +322,8 @@ def exit_client():
needs to call `exit_client` before calling `initialize` again.
"""
# Only client with rank_0 will send shutdown request to servers.
print("Client[{}] in group[{}] is exiting...".format(
rpc.get_rank(), rpc.get_group_id()))
finalize_worker() # finalize workers should be earlier than barrier, and non-blocking
# collect data such as DistTensor before exit
gc.collect()
2 changes: 1 addition & 1 deletion python/dgl/distributed/rpc.py
@@ -15,7 +15,7 @@

__all__ = ['set_rank', 'get_rank', 'Request', 'Response', 'register_service', \
'create_sender', 'create_receiver', 'finalize_sender', 'finalize_receiver', \
'receiver_wait', 'connect_receiver', 'read_ip_config', \
'receiver_wait', 'connect_receiver', 'read_ip_config', 'get_group_id', \
'get_num_machines', 'set_num_machines', 'get_machine_id', 'set_machine_id', \
'send_request', 'recv_request', 'send_response', 'recv_response', 'remote_call', \
'send_request_to_machine', 'remote_call_to_machine', 'fast_pull', \
128 changes: 107 additions & 21 deletions tools/launch.py
@@ -14,8 +14,6 @@
from threading import Thread
from typing import Optional

DEFAULT_PORT = 30050

def cleanup_proc(get_all_remote_pids, conn):
'''This process tries to clean up the remote training tasks.
'''
@@ -271,6 +269,7 @@ def construct_dgl_server_env_vars(
ip_config: str,
num_servers: int,
graph_format: str,
keep_alive: bool,
pythonpath: Optional[str] = "",
) -> str:
"""Constructs the DGL server-specific env vars string that are required for DGL code to behave in the correct
@@ -287,6 +286,8 @@
Relative path to workspace.
num_servers:
graph_format:
keep_alive:
Whether to keep the server alive when clients exit
pythonpath: Optional. If given, this will pass this as PYTHONPATH.
Returns:
@@ -302,6 +303,7 @@
"DGL_IP_CONFIG={DGL_IP_CONFIG} "
"DGL_NUM_SERVER={DGL_NUM_SERVER} "
"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
"DGL_KEEP_ALIVE={DGL_KEEP_ALIVE} "
"{suffix_optional_envvars}"
)
suffix_optional_envvars = ""
@@ -316,6 +318,7 @@
DGL_IP_CONFIG=ip_config,
DGL_NUM_SERVER=num_servers,
DGL_GRAPH_FORMAT=graph_format,
DGL_KEEP_ALIVE=int(keep_alive),
suffix_optional_envvars=suffix_optional_envvars,
)

@@ -328,6 +331,7 @@ def construct_dgl_client_env_vars(
num_servers: int,
graph_format: str,
num_omp_threads: int,
group_id: int,
pythonpath: Optional[str] = "",
) -> str:
"""Constructs the DGL client-specific env vars string that are required for DGL code to behave in the correct
@@ -344,6 +348,8 @@
num_servers:
graph_format:
num_omp_threads:
group_id:
Used by client processes to indicate which group they belong to.
pythonpath: Optional. If given, this will pass this as PYTHONPATH.
Returns:
@@ -360,6 +366,7 @@
"DGL_NUM_SERVER={DGL_NUM_SERVER} "
"DGL_GRAPH_FORMAT={DGL_GRAPH_FORMAT} "
"OMP_NUM_THREADS={OMP_NUM_THREADS} "
"DGL_GROUP_ID={DGL_GROUP_ID} "
"{suffix_optional_envvars}"
)
# append optional additional env-vars
@@ -376,6 +383,7 @@
DGL_NUM_SERVER=num_servers,
DGL_GRAPH_FORMAT=graph_format,
OMP_NUM_THREADS=num_omp_threads,
DGL_GROUP_ID=group_id,
suffix_optional_envvars=suffix_optional_envvars,
)

@@ -424,6 +432,72 @@ def wrap_cmd_with_extra_envvars(cmd: str, env_vars: list) -> str:
env_vars = " ".join(env_vars)
return wrap_cmd_with_local_envvars(cmd, env_vars)


g_monitor_file = None
g_group_id = 0

def has_alive_servers(args):
"""Check whether there exists alive servers.
For each group of long live servers, a monitor file named
'dgl_dist_monitor_{args.server_name}' is created under '/tmp/' directory.
We check the existence of this monitor file to determine whether to
launch new servers or utilize the existing alive ones. If there
exist alive servers, we obtain availale group ID from the monitor
file which could be used in current client groups.
Returns
-------
bool
Indicates whether there are alive servers.
"""
if args.server_name is None:
return False
global g_monitor_file
global g_group_id
monitor_file = '/tmp/dgl_dist_monitor_' + args.server_name
from filelock import FileLock
lock = FileLock(monitor_file + '.lock')
with lock:
next_group_id = None
ret = os.path.exists(monitor_file)
if ret:
print("Monitor file for alive servers already exist: {}.".format(monitor_file))
lines = [line.rstrip('\n') for line in open(monitor_file)]
g_group_id = int(lines[0])
next_group_id = g_group_id + 1
if not ret and args.keep_alive:
next_group_id = 1
print("Monitor file for alive servers is created: {}.".format(monitor_file))
g_monitor_file = monitor_file
if next_group_id is not None:
with open(monitor_file, 'w') as f:
f.write(str(next_group_id))
return ret
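# Illustrative lifecycle of the monitor file, assuming --server_name long_live
# (a walkthrough of has_alive_servers() above, not additional committed code):
#   * 1st launch with --keep_alive: /tmp/dgl_dist_monitor_long_live does not exist, so
#     has_alive_servers() returns False, writes group id "1" into the file, and the
#     launcher starts the servers plus the first client group (g_group_id stays 0).
#   * 2nd launch with --server_name long_live: the file exists, g_group_id is read as 1,
#     the file is bumped to "2", and only a new client group is started against the
#     already running servers.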


def clean_alive_servers():
"""Remove keep alive related files"""
global g_monitor_file
try:
if g_monitor_file is not None:
os.remove(g_monitor_file)
os.remove(g_monitor_file + '.lock')
print("Monitor file for alive servers is removed: {}.".format(g_monitor_file))
except:
print("Failed to delete monitor file for alive servers: {}.".format(g_monitor_file))

def get_available_port(ip):
"""Get available port with specified ip."""
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
for port in range(1234, 65535):
try:
sock.connect((ip, port))
except:
return port
raise RuntimeError("Failed to get available port for ip~{}".format(ip))

def submit_jobs(args, udf_command):
"""Submit distributed jobs (server and client processes) via ssh"""
hosts = []
@@ -441,7 +515,7 @@ def submit_jobs(args, udf_command):
hosts.append((ip, port))
elif len(result) == 1:
ip = result[0]
port = DEFAULT_PORT
port = get_available_port(ip)
hosts.append((ip, port))
else:
raise RuntimeError("Format error of ip_config.")
@@ -457,23 +531,27 @@

tot_num_clients = args.num_trainers * (1 + args.num_samplers) * len(hosts)
# launch server tasks
server_env_vars = construct_dgl_server_env_vars(
num_samplers=args.num_samplers,
num_server_threads=args.num_server_threads,
tot_num_clients=tot_num_clients,
part_config=args.part_config,
ip_config=args.ip_config,
num_servers=args.num_servers,
graph_format=args.graph_format,
pythonpath=os.environ.get("PYTHONPATH", ""),
)
for i in range(len(hosts) * server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)]
server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
if not has_alive_servers(args):
server_env_vars = construct_dgl_server_env_vars(
num_samplers=args.num_samplers,
num_server_threads=args.num_server_threads,
tot_num_clients=tot_num_clients,
part_config=args.part_config,
ip_config=args.ip_config,
num_servers=args.num_servers,
graph_format=args.graph_format,
keep_alive=args.keep_alive,
pythonpath=os.environ.get("PYTHONPATH", ""),
)
for i in range(len(hosts) * server_count_per_machine):
ip, _ = hosts[int(i / server_count_per_machine)]
server_env_vars_cur = f"{server_env_vars} DGL_SERVER_ID={i}"
cmd = wrap_cmd_with_local_envvars(udf_command, server_env_vars_cur)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
cmd = 'cd ' + str(args.workspace) + '; ' + cmd
thread_list.append(execute_remote(cmd, ip, args.ssh_port, username=args.ssh_username))
else:
print(f"Use running server {args.server_name}.")

# launch client tasks
client_env_vars = construct_dgl_client_env_vars(
@@ -484,6 +562,7 @@
num_servers=args.num_servers,
graph_format=args.graph_format,
num_omp_threads=os.environ.get("OMP_NUM_THREADS", str(args.num_omp_threads)),
group_id=g_group_id,
pythonpath=os.environ.get("PYTHONPATH", ""),
)

@@ -496,7 +575,7 @@
num_nodes=len(hosts),
node_rank=node_id,
master_addr=hosts[0][0],
master_port=1234,
master_port=get_available_port(hosts[0][0]),
)
cmd = wrap_cmd_with_local_envvars(torch_dist_udf_command, client_env_vars)
cmd = wrap_cmd_with_extra_envvars(cmd, args.extra_envs) if len(args.extra_envs) > 0 else cmd
Expand All @@ -513,6 +592,7 @@ def signal_handler(signal, frame):
logging.info('Stop launcher')
# We need to tell the cleanup process to kill remote training jobs.
conn2.send('cleanup')
clean_alive_servers()
sys.exit(0)
signal.signal(signal.SIGINT, signal_handler)

@@ -560,7 +640,13 @@ def main():
help='Extra environment parameters need to be set. For example, \
you can set the LD_LIBRARY_PATH and NCCL_DEBUG by adding: \
--extra_envs LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH NCCL_DEBUG=INFO ')
parser.add_argument('--keep_alive', action='store_true', help='Servers keep alive when clients exit')
parser.add_argument('--server_name', type=str,
help='Used to check whether there exist alive servers')
args, udf_command = parser.parse_known_args()
if args.keep_alive:
assert args.server_name is not None, "Server name is required if '--keep_alive' is enabled."
print("Servers will keep alive even clients exit...")
assert len(udf_command) == 1, 'Please provide user command line.'
assert args.num_trainers is not None and args.num_trainers > 0, \
'--num_trainers must be a positive number.'
