Allow MSCCL++ CommGroup to take PyTorch tensors in args #255

Merged 2 commits on Feb 7, 2024
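This PR lets `CommGroup` methods accept `torch.Tensor` arguments in addition to cupy and numpy arrays. A minimal usage sketch, assuming a `CommGroup` instance `group` and a connection dict `conns` have already been set up (their construction is not shown in this diff):

```python
import torch

# Allocate a CUDA buffer with PyTorch instead of cupy.
tensor = torch.zeros(1 << 20, dtype=torch.float16, device="cuda")

# With this change, the tensor can be handed directly to CommGroup helpers;
# `group` (CommGroup) and `conns` (dict[int, Connection]) are assumed to exist.
channels = group.make_sm_channels(tensor, conns)
```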
38 changes: 29 additions & 9 deletions python/mscclpp/comm.py
@@ -23,6 +23,8 @@
import mpi4py
import numpy as np

from mscclpp.utils import is_torch_tensor


class CommGroup:
def __init__(
@@ -108,8 +110,15 @@ def register_tensor_with_connections(
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
all_registered_memories = {}
all_registered_memories[self.my_rank] = local_reg_memory
future_memories = {}
@@ -136,20 +145,24 @@ def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(tensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor.data.ptr)
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
return channels

def make_sm_channels_with_scratch(
self, tensor: cp.ndarray, scratchTensor: cp.ndarray, connections: dict[int, Connection]
self,
tensor: cp.ndarray,
scratchTensor: cp.ndarray,
connections: dict[int, Connection],
) -> dict[int, SmChannel]:
semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
channels = {}
tensor_data_ptr = tensor.data_ptr() if is_torch_tensor(tensor) else tensor.data.ptr
scratch_data_ptr = scratchTensor.data_ptr() if is_torch_tensor(scratchTensor) else scratchTensor.data.ptr
for rank in connections:
channels[rank] = SmChannel(
semaphores[rank], registered_memories[rank], tensor.data.ptr, scratchTensor.data.ptr
)
channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
return channels

def make_proxy_channels(
@@ -180,8 +193,15 @@ def make_proxy_channels_with_scratch(
transport_flags = TransportFlags()
for rank in connections:
transport_flags |= connections[rank].transport()
data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
data_ptr = (
tensor.data.ptr
if isinstance(tensor, cp.ndarray)
else tensor.data_ptr() if is_torch_tensor(tensor) else tensor.ctypes.data
)
tensor_size = (
tensor.numel() * tensor.element_size() if is_torch_tensor(tensor) else tensor.size * tensor.itemsize
)
local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)

semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
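For reference, the pointer/size logic added above branches on the tensor type. An equivalent standalone sketch of that dispatch (the helper name `_ptr_and_nbytes` is illustrative, not part of the PR):

```python
import cupy as cp
import numpy as np
import torch


def _ptr_and_nbytes(tensor):
    """Return (data pointer, size in bytes) for a cupy, torch, or numpy tensor."""
    if isinstance(tensor, cp.ndarray):
        return tensor.data.ptr, tensor.size * tensor.itemsize
    if isinstance(tensor, torch.Tensor):
        return tensor.data_ptr(), tensor.numel() * tensor.element_size()
    # numpy fallback: host pointer and byte count
    return tensor.ctypes.data, tensor.size * tensor.itemsize
```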
17 changes: 16 additions & 1 deletion python/mscclpp/utils.py
@@ -6,12 +6,21 @@
import struct
import subprocess
import tempfile
from typing import Type
from typing import Any, Type

from cuda import cuda, nvrtc, cudart
import cupy as cp
import numpy as np

try:
import torch

_use_torch = True
torchTensor = torch.Tensor
except ImportError:
_use_torch = False
torchTensor = Type[Any]


def _check_cuda_errors(result):
if result[0].value:
@@ -145,6 +154,8 @@ def pack(*args):
res += struct.pack("P", arg.ctypes.data)
elif isinstance(arg, cp.ndarray):
res += struct.pack("P", arg.data.ptr)
elif is_torch_tensor(arg):
res += struct.pack("P", arg.data_ptr())
# use int to represent bool, which can avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES error
elif isinstance(arg, bool):
res += struct.pack("i", arg)
@@ -153,3 +164,7 @@
else:
raise RuntimeError(f"Unsupported type: {type(arg)}")
return res


def is_torch_tensor(tensor: Any) -> bool:
return _use_torch and isinstance(tensor, torchTensor)
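A minimal sketch of the new utils behavior, assuming torch is installed and a CUDA device is available:

```python
import torch
from mscclpp.utils import is_torch_tensor, pack

x = torch.empty(256, dtype=torch.float32, device="cuda")

# is_torch_tensor returns False when torch cannot be imported, so callers
# can probe tensor types without a hard torch dependency.
assert is_torch_tensor(x)

# pack() now serializes the torch tensor's device pointer ("P") alongside
# other kernel arguments, the same way it does for cupy and numpy arrays.
arg_blob = pack(x, True)
```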