Allow MSCCL++ CommGroup to take PyTorch tensors in args
Obtain data_ptr and tensor_size accordingly.
aashaka committed Feb 1, 2024
1 parent 4eb0a08 commit d0249a3
Showing 2 changed files with 44 additions and 14 deletions.
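For context, a minimal usage sketch of what this change enables: passing a CUDA torch.Tensor straight into the channel-setup helpers instead of converting it to a CuPy array first. The CommGroup constructor arguments, the Transport import, and the make_connection call below are assumptions for illustration only; the commit itself only changes how tensors are introspected.

from mpi4py import MPI
import torch

from mscclpp.comm import CommGroup
from mscclpp import Transport  # assumed import path for the transport enum

# Assumed constructor / make_connection signatures -- check the mscclpp Python
# bindings for the exact API. The new part is that make_sm_channels (and the
# other helpers) now accept a torch.Tensor in place of a cp.ndarray.
group = CommGroup(MPI.COMM_WORLD)
remote_ranks = [r for r in range(MPI.COMM_WORLD.size) if r != group.my_rank]
connections = group.make_connection(remote_ranks, Transport.CudaIpc)

tensor = torch.zeros(1024, dtype=torch.float32, device="cuda")
channels = group.make_sm_channels(tensor, connections)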
55 changes: 41 additions & 14 deletions python/mscclpp/comm.py
@@ -21,6 +21,7 @@
 )
 import mpi4py
 import numpy as np
+import torch
 
 
 class CommGroup:
@@ -93,13 +94,22 @@ def make_connection(
         return connections
 
     def register_tensor_with_connections(
-        self, tensor: Type[cp.ndarray] or Type[np.ndarray], connections: dict[int, Connection]
+        self, tensor: Type[cp.ndarray] or Type[torch.Tensor] or Type[np.ndarray], connections: dict[int, Connection]
     ) -> dict[int, RegisteredMemory]:
         transport_flags = TransportFlags()
         for rank in connections:
             transport_flags |= connections[rank].transport()
-        data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
-        local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
+        data_ptr = (
+            tensor.data.ptr
+            if isinstance(tensor, cp.ndarray)
+            else tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor.ctypes.data
+        )
+        tensor_size = (
+            tensor.numel() * tensor.element_size()
+            if isinstance(tensor, torch.Tensor)
+            else tensor.size * tensor.itemsize
+        )
+        local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
         all_registered_memories = {}
         all_registered_memories[self.my_rank] = local_reg_memory
         future_memories = {}
@@ -122,28 +132,36 @@ def make_semaphore(
         self.communicator.setup()
         return semaphores
 
-    def make_sm_channels(self, tensor: cp.ndarray, connections: dict[int, Connection]) -> dict[int, SmChannel]:
+    def make_sm_channels(
+        self, tensor: torch.Tensor or cp.ndarray, connections: dict[int, Connection]
+    ) -> dict[int, SmChannel]:
         semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(tensor, connections)
         channels = {}
+        tensor_data_ptr = tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor.data.ptr
         for rank in connections:
-            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor.data.ptr)
+            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr)
         return channels
 
     def make_sm_channels_with_scratch(
-        self, tensor: cp.ndarray, scratchTensor: cp.ndarray, connections: dict[int, Connection]
+        self,
+        tensor: torch.Tensor or cp.ndarray,
+        scratchTensor: torch.Tensor or cp.ndarray,
+        connections: dict[int, Connection],
     ) -> dict[int, SmChannel]:
         semaphores = self.make_semaphore(connections, SmDevice2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
         channels = {}
+        tensor_data_ptr = tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor.data.ptr
+        scratch_data_ptr = (
+            scratchTensor.data_ptr() if isinstance(scratchTensor, torch.Tensor) else scratchTensor.data.ptr
+        )
         for rank in connections:
-            channels[rank] = SmChannel(
-                semaphores[rank], registered_memories[rank], tensor.data.ptr, scratchTensor.data.ptr
-            )
+            channels[rank] = SmChannel(semaphores[rank], registered_memories[rank], tensor_data_ptr, scratch_data_ptr)
         return channels
 
     def make_proxy_channels(
-        self, proxy_service: ProxyService, tensor: cp.ndarray, connections: dict[int, Connection]
+        self, proxy_service: ProxyService, tensor: torch.Tensor or cp.ndarray, connections: dict[int, Connection]
     ) -> dict[int, SmChannel]:
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(tensor, connections)
@@ -163,15 +181,24 @@ def make_proxy_channels(
     def make_proxy_channels_with_scratch(
         self,
         proxy_service: ProxyService,
-        tensor: cp.ndarray,
-        scratchTensor: cp.ndarray,
+        tensor: torch.Tensor or cp.ndarray,
+        scratchTensor: torch.Tensor or cp.ndarray,
         connections: dict[int, Connection],
     ) -> dict[int, SmChannel]:
         transport_flags = TransportFlags()
         for rank in connections:
             transport_flags |= connections[rank].transport()
-        data_ptr = tensor.data.ptr if isinstance(tensor, cp.ndarray) else tensor.ctypes.data
-        local_reg_memory = self.communicator.register_memory(data_ptr, tensor.size * tensor.itemsize, transport_flags)
+        data_ptr = (
+            tensor.data.ptr
+            if isinstance(tensor, cp.ndarray)
+            else tensor.data_ptr() if isinstance(tensor, torch.Tensor) else tensor.ctypes.data
+        )
+        tensor_size = (
+            tensor.numel() * tensor.element_size()
+            if isinstance(tensor, torch.Tensor)
+            else tensor.size * tensor.itemsize
+        )
+        local_reg_memory = self.communicator.register_memory(data_ptr, tensor_size, transport_flags)
 
         semaphores = self.make_semaphore(connections, Host2DeviceSemaphore)
         registered_memories = self.register_tensor_with_connections(scratchTensor, connections)
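The same pointer/size dispatch now appears in two places in this file. As a standalone illustration (a hypothetical helper, not part of this commit), the torch calls are the direct equivalents of the CuPy/NumPy attributes used before:

import numpy as np
import torch

try:
    import cupy as cp
except ImportError:  # CuPy is optional for this sketch
    cp = None


def ptr_and_size(tensor):
    # Hypothetical helper mirroring the dispatch added in this commit.
    if cp is not None and isinstance(tensor, cp.ndarray):
        return tensor.data.ptr, tensor.size * tensor.itemsize
    if isinstance(tensor, torch.Tensor):
        # data_ptr() and numel() * element_size() are the torch equivalents
        # of .data.ptr and .size * .itemsize on a CuPy/NumPy array.
        return tensor.data_ptr(), tensor.numel() * tensor.element_size()
    return tensor.ctypes.data, tensor.size * tensor.itemsize  # NumPy fallback


print(ptr_and_size(np.zeros(16, dtype=np.float32))[1])       # 64 bytes
print(ptr_and_size(torch.zeros(16, dtype=torch.float32))[1])  # 64 bytes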
3 changes: 3 additions & 0 deletions python/mscclpp/utils.py
@@ -6,6 +6,7 @@
 import struct
 import subprocess
 import tempfile
+import torch
 from typing import Type
 
 from cuda import cuda, nvrtc, cudart
@@ -145,6 +146,8 @@ def pack(*args):
             res += struct.pack("P", arg.ctypes.data)
         elif isinstance(arg, cp.ndarray):
            res += struct.pack("P", arg.data.ptr)
+        elif isinstance(arg, torch.Tensor):
+            res += struct.pack("P", arg.data_ptr())
         # use int to represent bool, which can avoid CUDA_ERROR_LAUNCH_OUT_OF_RESOURCES error
         elif isinstance(arg, bool):
             res += struct.pack("i", arg)
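To see what the new pack branch produces, here is a small self-contained check (not from the commit) showing that struct.pack("P", tensor.data_ptr()) serializes the tensor's address as a pointer-sized kernel argument, exactly like the existing NumPy and CuPy branches:

import struct

import numpy as np
import torch

arr = np.arange(8, dtype=np.float32)
ten = torch.arange(8, dtype=torch.float32)

# The same "P" (void*) packing that pack() applies per argument type.
packed_np = struct.pack("P", arr.ctypes.data)
packed_torch = struct.pack("P", ten.data_ptr())

assert struct.unpack("P", packed_torch)[0] == ten.data_ptr()
print(len(packed_np), len(packed_torch))  # both are pointer-sized (8 bytes on 64-bit)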
