[RELAND] Add device-agnostic runtime Device/Stream C++ API (pytorch#138677)

# Motivation
This PR adds a device-agnostic C++ API for the accelerator runtime: querying the current accelerator and managing its devices and streams.

# Additional Context
This PR is a reland. The original was reverted because `torch.Event` didn't support the MPS backend; that has been fixed in pytorch#142468. The previous commit is pytorch@f84e533
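# Usage sketch
Not part of the diff: a minimal, illustrative C++ snippet showing how the new `at::accelerator` namespace is meant to be called. The `main` scaffold and the printed strings are hypothetical; the API calls themselves are the ones added below.

```cpp
#include <ATen/DeviceAccelerator.h>
#include <iostream>

int main() {
  // Unchecked query: returns an empty optional when no accelerator is present.
  const auto acc = at::accelerator::getAccelerator(/*checked=*/false);
  if (!acc.has_value() || at::accelerator::deviceCount() == 0) {
    std::cout << "no accelerator available\n";
    return 0;
  }
  std::cout << "accelerator: " << c10::DeviceTypeName(acc.value()) << '\n';

  // The same calls work for CUDA, XPU, MPS, etc.; no backend-specific headers.
  at::accelerator::setDeviceIndex(0);
  c10::Stream stream = at::accelerator::getCurrentStream(0);
  at::accelerator::setCurrentStream(stream);
  at::accelerator::synchronizeDevice(0);
  return 0;
}
```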

Pull Request resolved: pytorch#138677
Approved by: https://github.com/albanD, https://github.com/EikanWang
ghstack dependencies: pytorch#143171, pytorch#133572
guangyey authored and pytorchmergebot committed Dec 16, 2024
1 parent 45ac4eb commit 9706ada
Showing 4 changed files with 106 additions and 39 deletions.
56 changes: 52 additions & 4 deletions aten/src/ATen/DeviceAccelerator.cpp
@@ -1,6 +1,8 @@
 #include <ATen/Context.h>
 #include <ATen/DeviceAccelerator.h>
-namespace at {
+#include <c10/core/impl/VirtualGuardImpl.h>
+
+namespace at::accelerator {
 
 std::optional<c10::DeviceType> getAccelerator(bool checked) {
 #define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
@@ -37,8 +39,8 @@ std::optional<c10::DeviceType> getAccelerator(bool checked) {
 #undef DETECT_AND_ASSIGN_ACCELERATOR
 }
 
-bool isAccelerator(c10::DeviceType d) {
-  switch (d) {
+bool isAccelerator(c10::DeviceType device_type) {
+  switch (device_type) {
     case at::kCUDA:
     case at::kMTIA:
     case at::kXPU:
@@ -52,4 +54,50 @@ bool isAccelerator(c10::DeviceType d) {
   }
 }
 
-} // namespace at
+c10::DeviceIndex deviceCount() {
+  const auto device_type = getAccelerator(false);
+  if (!device_type.has_value()) {
+    return static_cast<c10::DeviceIndex>(0);
+  }
+  c10::impl::VirtualGuardImpl impl(device_type.value());
+  return static_cast<c10::DeviceIndex>(impl.deviceCount());
+}
+
+void setDeviceIndex(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  impl.setDevice({device_type, device_index});
+}
+
+c10::DeviceIndex getDeviceIndex() {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  return static_cast<c10::DeviceIndex>(impl.getDevice().index());
+}
+
+void setCurrentStream(c10::Stream stream) {
+  const auto device_type = getAccelerator(true).value();
+  TORCH_CHECK(
+      device_type == stream.device_type(),
+      "stream's device type ",
+      c10::DeviceTypeName(stream.device_type()),
+      " doesn't match the current accelerator ",
+      c10::DeviceTypeName(device_type));
+  c10::impl::VirtualGuardImpl impl(device_type);
+  impl.exchangeStream(stream);
+}
+
+c10::Stream getCurrentStream(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  return impl.getStream({device_type, device_index});
+}
+
+void synchronizeDevice(c10::DeviceIndex device_index) {
+  const auto device_type = getAccelerator(true).value();
+  c10::impl::VirtualGuardImpl impl(device_type);
+  // impl.synchronizeDevice can be safely called from any device
+  impl.synchronizeDevice(device_index);
+}
+
+} // namespace at::accelerator
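All of the new functions share one pattern: resolve the single active accelerator's `DeviceType`, construct a `c10::impl::VirtualGuardImpl` for it, and delegate to the virtual guard, so no backend-specific headers are needed. A hedged sketch of extending the pattern, assuming `getDefaultStream` on the virtual guard (it comes from c10's `DeviceGuardImplInterface`) and a hypothetical wrapper name that is not part of this PR:

```cpp
#include <ATen/DeviceAccelerator.h>
#include <c10/core/impl/VirtualGuardImpl.h>

namespace {
// Hypothetical: query the default (rather than current) stream, in the
// same resolve-then-delegate style as the functions above.
c10::Stream getDefaultStreamSketch(c10::DeviceIndex device_index) {
  const auto device_type = at::accelerator::getAccelerator(true).value();
  c10::impl::VirtualGuardImpl impl(device_type);
  return impl.getDefaultStream({device_type, device_index});
}
} // namespace
```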
34 changes: 31 additions & 3 deletions aten/src/ATen/DeviceAccelerator.h
@@ -6,6 +6,8 @@
 #include <ATen/detail/MTIAHooksInterface.h>
 #include <optional>
 
+namespace at::accelerator {
+
 // This file defines the top level Accelerator concept for PyTorch.
 // A device is an accelerator per the definition here if:
 // - It is mutually exclusive with all other accelerators
@@ -15,13 +17,39 @@
 // As of today, accelerator devices are (in no particular order):
 // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1
 
-namespace at {
-
 // Ensures that only one accelerator is available (at compile time if
 // possible) and returns it.
 // When checked is true, the returned optional always has a value.
 TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false);
 
-TORCH_API bool isAccelerator(c10::DeviceType d);
+// Check if the given device type is an accelerator.
+TORCH_API bool isAccelerator(c10::DeviceType device_type);
 
+// Return the number of devices available. Note that this is *REQUIRED* to
+// not raise any exception.
+TORCH_API c10::DeviceIndex deviceCount();
+
+// Set the current device index to the given device index.
+TORCH_API void setDeviceIndex(c10::DeviceIndex device_index);
+
+// Get the current device index.
+TORCH_API c10::DeviceIndex getDeviceIndex();
+
+// Set the current stream to the given stream. Note that this API doesn't
+// change the current device index.
+TORCH_API void setCurrentStream(c10::Stream stream);
+
+// Get the current stream of the given device index.
+TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index);
+
+// Wait (by blocking the calling thread) until all the work previously
+// enqueued on the given device index has been completed.
+TORCH_API void synchronizeDevice(c10::DeviceIndex device_index);
+
+} // namespace at::accelerator
+
+namespace at {
+// Keep BC only
+using at::accelerator::getAccelerator;
+using at::accelerator::isAccelerator;
 } // namespace at
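The trailing `using` declarations keep backward compatibility: code that spells these functions the old way, as `at::getAccelerator` / `at::isAccelerator`, still compiles unchanged. A minimal sketch (the `stillCompiles` call site is hypothetical, not from this PR):

```cpp
#include <ATen/DeviceAccelerator.h>

// Pre-existing call sites resolve through the using-declarations in
// namespace at, so both spellings name the same function.
bool stillCompiles() {
  auto acc = at::getAccelerator(false);               // pre-PR spelling
  auto same = at::accelerator::getAccelerator(false); // new spelling
  return acc == same && at::isAccelerator(at::kCUDA);
}
```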
47 changes: 15 additions & 32 deletions torch/csrc/DeviceAccelerator.cpp
@@ -1,4 +1,3 @@
-#include <c10/core/DeviceGuard.h>
 #include <torch/csrc/DeviceAccelerator.h>
 #include <torch/csrc/utils/device_lazy_init.h>
 
@@ -13,69 +12,53 @@ void initModule(PyObject* module) {
   });
 
   m.def("_accelerator_deviceCount", []() {
-    const auto device_type = at::getAccelerator(false);
-    if (!device_type.has_value()) {
-      return static_cast<c10::DeviceIndex>(0);
-    }
-    torch::utils::maybe_initialize_device(device_type.value());
-    c10::impl::VirtualGuardImpl impl(device_type.value());
-    return static_cast<c10::DeviceIndex>(impl.deviceCount());
+    auto device_type = at::accelerator::getAccelerator(false);
+    torch::utils::maybe_initialize_device(device_type);
+    return at::accelerator::deviceCount();
   });
 
   m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
-    const auto device_type = at::getAccelerator(true).value();
     // If device index is negative, no-op
     if (device_index < 0) {
      return;
     }
+    const auto device_type = at::accelerator::getAccelerator(true).value();
     torch::utils::maybe_initialize_device(device_type);
-    c10::impl::VirtualGuardImpl impl(device_type);
-    impl.setDevice({device_type, device_index});
+    at::accelerator::setDeviceIndex(device_index);
   });
 
   m.def("_accelerator_getDeviceIndex", []() {
-    const auto device_type = at::getAccelerator(true).value();
+    const auto device_type = at::accelerator::getAccelerator(true).value();
     torch::utils::maybe_initialize_device(device_type);
-    c10::impl::VirtualGuardImpl impl(device_type);
-    return static_cast<c10::DeviceIndex>(impl.getDevice().index());
+    return at::accelerator::getDeviceIndex();
   });
 
   m.def("_accelerator_setStream", [](c10::Stream stream) {
-    const auto device_type = at::getAccelerator(true).value();
-    TORCH_CHECK(
-        device_type == stream.device_type(),
-        "stream's device type ",
-        c10::DeviceTypeName(stream.device_type()),
-        " doesn't match the current accelerator ",
-        c10::DeviceTypeName(device_type));
+    const auto device_type = at::accelerator::getAccelerator(true).value();
     torch::utils::maybe_initialize_device(device_type);
-    c10::impl::VirtualGuardImpl impl(device_type);
     // Set the current device to the device of the stream
-    if (impl.getDevice().index() != stream.device_index()) {
-      impl.setDevice(stream.device());
+    if (at::accelerator::getDeviceIndex() != stream.device_index()) {
+      at::accelerator::setDeviceIndex(stream.device_index());
     }
-    impl.exchangeStream(stream);
+    at::accelerator::setCurrentStream(stream);
   });
 
   m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
-    const auto device_type = at::getAccelerator(true).value();
+    const auto device_type = at::accelerator::getAccelerator(true).value();
     torch::utils::maybe_initialize_device(device_type);
-    c10::impl::VirtualGuardImpl impl(device_type);
-    return impl.getStream({device_type, device_index});
+    return at::accelerator::getCurrentStream(device_index);
  });
 
   m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
-    const auto device_type = at::getAccelerator(true).value();
+    const auto device_type = at::accelerator::getAccelerator(true).value();
     if (torch::utils::is_device_lazy_init_supported(device_type) &&
         !torch::utils::is_device_initialized(device_type)) {
       return;
     }
     torch::utils::maybe_initialize_device(device_type);
-    c10::impl::VirtualGuardImpl impl(device_type);
-    // impl.synchronizeDevice should can be safely called from any device
     {
       py::gil_scoped_release no_gil;
-      impl.synchronizeDevice(device_index);
+      at::accelerator::synchronizeDevice(device_index);
     }
   });
 }
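Note the `py::gil_scoped_release` guard around the blocking synchronize: the GIL is dropped so other Python threads keep running while the device drains, and the binding returns early without initializing a lazily-initialized backend. A minimal pybind11 sketch of the same pattern (the module and function names are hypothetical, not from this PR):

```cpp
#include <ATen/DeviceAccelerator.h>
#include <pybind11/pybind11.h>

namespace py = pybind11;

// Hypothetical module illustrating the GIL-release pattern used above.
PYBIND11_MODULE(_accel_sketch, m) {
  m.def("sync", [](c10::DeviceIndex device_index) {
    py::gil_scoped_release no_gil; // blocking wait must not hold the GIL
    at::accelerator::synchronizeDevice(device_index);
  });
}
```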
8 changes: 8 additions & 0 deletions torch/csrc/utils/device_lazy_init.h
@@ -55,6 +55,14 @@ inline void maybe_initialize_device(const at::TensorOptions& options) {
   maybe_initialize_device(device);
 }
 
+inline void maybe_initialize_device(
+    std::optional<at::DeviceType>& device_type) {
+  if (!device_type.has_value()) {
+    return;
+  }
+  maybe_initialize_device(device_type.value());
+}
+
 bool is_device_initialized(at::DeviceType device_type);
 
 } // namespace torch::utils
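The new overload accepts the optional returned by `getAccelerator(false)` directly, turning the "no accelerator available" case into a silent no-op. A minimal sketch of assumed usage, mirroring the `_accelerator_deviceCount` binding above (the `countAcceleratorDevices` wrapper is hypothetical):

```cpp
#include <ATen/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

// Hypothetical caller: never throws, even when no accelerator is built in.
c10::DeviceIndex countAcceleratorDevices() {
  auto device_type = at::accelerator::getAccelerator(/*checked=*/false);
  torch::utils::maybe_initialize_device(device_type); // no-op if empty
  return at::accelerator::deviceCount(); // documented as non-throwing
}
```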
