Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto-tune single-node AllReduce #219

Merged
merged 7 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions python/benchmark/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -118,12 +118,9 @@ __forceinline__ __device__ void vectorSum(TYPE* dst, TYPE* src, size_t nElem) {
// AllReduce1
// -------------------------------------------

#ifndef READ_ONLY
#define READ_ONLY 0
#endif

extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks, size_t nelems) {
template <int READ_ONLY>
__device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, int rank, int nranks,
size_t nelems) {
const size_t chunkSize = nelems / nranks;
if (nranks == 1) return;
const int nPeer = nranks - 1;
Expand Down Expand Up @@ -211,13 +208,21 @@ extern "C" __global__ void __launch_bounds__(1024, 1)
}
}

// Kernel entry point for AllReduce1.
//
// The runtime `read_only` flag is mapped onto the compile-time READ_ONLY
// template argument of allreduce1_helper: a non-zero flag selects the
// <1> instantiation, zero selects the <0> instantiation. All other arguments
// are forwarded unchanged.
extern "C" __global__ void __launch_bounds__(1024, 1) allreduce1(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff,
                                                                  int rank, int nranks, size_t nelems, int read_only) {
  if (read_only == 0) {
    allreduce1_helper<0>(smChans, buff, rank, nranks, nelems);
    return;
  }
  allreduce1_helper<1>(smChans, buff, rank, nranks, nelems);
}

// -------------------------------------------
// AllReduce2
// -------------------------------------------

__device__ uint64_t globalFlag = 1;

extern "C" __global__ void __launch_bounds__(512, 1)
extern "C" __global__ void __launch_bounds__(1024, 1)
allreduce2(mscclpp::SmChannelDeviceHandle* smChans, TYPE* buff, TYPE* scratch, void* resultBuff, int rank,
int worldSize, size_t nelems) {
nelems = nelems / (sizeof(int) / sizeof(TYPE));
Expand Down
78 changes: 65 additions & 13 deletions python/benchmark/allreduce_bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,46 @@
raise RuntimeError("Unknown data type")


def plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups):
    """Plot MSCCLPP vs. NCCL bandwidth and speed-up, and save the chart as a PDF.

    Both bandwidth series are drawn against the primary (left) y-axis and the
    speed-up series against a secondary (right) y-axis. The x-axis is log2
    scaled with one tick per entry of ``sizes``, labelled through
    ``human_readable_size``. The chart is written to
    ``mscclpp_vs_nccl_comparison.pdf`` in the current working directory.

    Args:
        sizes: x-coordinates of the data points (data sizes), one per sample.
        mscclpp_algbw: values plotted as "MSCCLPP AlgBW", one per size.
        nccl_algbw: values plotted as "NCCL AlgBW", one per size.
        speed_ups: values plotted as "Speed Up", one per size.
    """
    # Imported here so matplotlib is only needed when a plot is requested.
    import matplotlib.pyplot as plt

    human_readable_sizes = [human_readable_size(size) for size in sizes]

    fig, ax1 = plt.subplots(figsize=(10, 6))
    try:
        # Plotting AlgBW for MSCCLPP and NCCL on the primary y-axis
        (line1,) = ax1.plot(sizes, mscclpp_algbw, marker="o", color="blue", label="MSCCLPP AlgBW")
        (line2,) = ax1.plot(sizes, nccl_algbw, marker="x", color="red", label="NCCL AlgBW")
        ax1.set_ylabel("AlgBW (GB/s)")
        ax1.set_xlabel("Data Size")

        # Logarithmic x-axis
        ax1.set_xscale("log", base=2)
        ax1.set_xticks(sizes)
        ax1.set_xticklabels(human_readable_sizes, rotation=45)

        # Adding secondary y-axis for Speed Up
        ax2 = ax1.twinx()
        (line3,) = ax2.plot(sizes, speed_ups, marker="^", color="green", label="Speed Up")
        ax2.set_ylabel("Speed Up (NCCL Time / MSCCLPP Time)", color="green")
        ax2.tick_params(axis="y", labelcolor="green")

        # Set the lower bound of the secondary y-axis to 0
        ax2.set_ylim(bottom=0)

        # One combined legend for the lines of both axes
        lines = [line1, line2, line3]
        labels = [line.get_label() for line in lines]
        ax1.legend(lines, labels, loc="upper left")

        # Setting title and grid; the node count is world size // GPUs per node
        ax1.set_title("MSCCLPP vs NCCL -- " + str(MPI.COMM_WORLD.size // N_GPUS_PER_NODE) + " Nodes")
        ax2.grid(True, which="both", ls="--")

        # Saving the plot
        fig.savefig("mscclpp_vs_nccl_comparison.pdf", format="pdf")
    finally:
        # Figures created through pyplot stay registered until explicitly
        # closed; release this one even if plotting or saving fails.
        plt.close(fig)


def human_readable_size(size, decimal_places=1):
for unit in ["B", "KiB", "MiB", "GiB", "TiB", "PiB"]:
if size < 1024.0 or unit == "PiB":
Expand Down Expand Up @@ -99,15 +139,12 @@ def run_benchmark(
memory_out = cp.zeros(nelem, dtype=data_type)
cp.cuda.runtime.deviceSynchronize()

proxy_service = None
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
if memory.nbytes < 2**20:
mscclpp_call = MscclppAllReduce2(mscclpp_group, memory, memory_out)
elif memory.nbytes < 2**29:
if memory.nbytes >= 2**20 and memory.nbytes <= 2**22:
read_only = 0
else:
read_only = 1
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory, read_only=read_only)
mscclpp_call = MscclppAllReduce1(mscclpp_group, memory)
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce3(mscclpp_group, memory, proxy_service)
Expand All @@ -117,14 +154,13 @@ def run_benchmark(
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce5(mscclpp_group, memory, memory_out, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
best_config = find_best_config(mscclpp_call, 100)
mscclpp_call.set_params(*best_config)
else:
proxy_service = ProxyService()
mscclpp_call = MscclppAllReduce4(mscclpp_group, memory, N_GPUS_PER_NODE, proxy_service)
proxy_service.start_proxy()
best_config = find_best_config(mscclpp_call, 20)
mscclpp_call.set_params(*best_config)

best_config = find_best_config(mscclpp_call, 20)
mscclpp_call.set_params(*best_config)

nccl_call = NcclAllReduce(nccl_op, memory)

Expand All @@ -145,6 +181,7 @@ def run_benchmark(
MPI.COMM_WORLD.barrier()
proxy_service.stop_proxy()

speed_up = nccl_time / mscclpp_time
if MPI.COMM_WORLD.rank == 0:
table.add_row(
[
Expand All @@ -155,12 +192,14 @@ def run_benchmark(
"{:.2f}".format(nccl_time),
"{:.2f}".format(nccl_algBw),
nccl_check,
"{:.2f}".format(nccl_time / mscclpp_time),
"{:.2f}".format(speed_up),
]
)
if MPI.COMM_WORLD.rank == 0:
print(".", end="", flush=True)

return memory.nbytes, mscclpp_algBw, nccl_algBw, speed_up


if __name__ == "__main__":
shm_comm = MPI.COMM_WORLD.Split_type(MPI.COMM_TYPE_SHARED, 0, MPI.INFO_NULL)
Expand Down Expand Up @@ -200,16 +239,29 @@ def run_benchmark(
"Speed Up",
]

for i in range(10, 28):
sizes = []
mscclpp_algbw = []
nccl_algbw = []
speed_ups = []
for i in range(10, 30):
if MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 1:
run_benchmark(mscclpp_group, nccl_comm, table, 100, 2**i)
nelems = 2**i
elif MPI.COMM_WORLD.size // N_GPUS_PER_NODE == 2:
run_benchmark(mscclpp_group, nccl_comm, table, 100, 3 * 2**i)
nelems = 3 * 2**i
else:
raise RuntimeError("Only support one node/two nodes communication")

size, mscclpp_algBw, nccl_algBw, speed_up = run_benchmark(mscclpp_group, nccl_comm, table, 100, nelems)
sizes.append(size)
mscclpp_algbw.append(mscclpp_algBw)
nccl_algbw.append(nccl_algBw)
speed_ups.append(speed_up)

if MPI.COMM_WORLD.rank == 0:
print()
print(table)

plot_graph(sizes, mscclpp_algbw, nccl_algbw, speed_ups)

mscclpp_group = None
nccl_comm = None
Loading
Loading