From 95ab1088efb0007098ee05bece1b3e9de21807ff Mon Sep 17 00:00:00 2001 From: Ziyue Yang Date: Thu, 24 Oct 2024 23:04:11 +0800 Subject: [PATCH] Fix in-place all-gather input buffer in executor_test (#372) --- python/test/executor_test.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/python/test/executor_test.py b/python/test/executor_test.py index 5fd59f2bb..5dd41a2c3 100644 --- a/python/test/executor_test.py +++ b/python/test/executor_test.py @@ -77,6 +77,13 @@ def dtype_to_mscclpp_dtype(dtype): raise ValueError(f"Unknown data type: {dtype}") +def determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name): + if "allgather" in execution_plan_name and in_place: + return recvbuf + else: + return sendbuf + + def determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name): if "allgather" in execution_plan_name: return recvbuf @@ -126,9 +133,9 @@ def main( executor_func = lambda stream: executor.execute( MPI.COMM_WORLD.rank, - sendbuf.data.ptr, + determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).data.ptr, - sendbuf.nbytes, + determine_input_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, determine_result_buf(sendbuf, recvbuf, in_place, execution_plan_name).nbytes, dtype_to_mscclpp_dtype(dtype), execution_plan,