diff --git a/aten/src/ATen/DeviceAccelerator.cpp b/aten/src/ATen/DeviceAccelerator.cpp index 4c4e711885086..8d4410f963830 100644 --- a/aten/src/ATen/DeviceAccelerator.cpp +++ b/aten/src/ATen/DeviceAccelerator.cpp @@ -1,6 +1,8 @@ #include #include -namespace at { +#include + +namespace at::accelerator { std::optional getAccelerator(bool checked) { #define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \ @@ -37,8 +39,8 @@ std::optional getAccelerator(bool checked) { #undef DETECT_AND_ASSIGN_ACCELERATOR } -bool isAccelerator(c10::DeviceType d) { - switch (d) { +bool isAccelerator(c10::DeviceType device_type) { + switch (device_type) { case at::kCUDA: case at::kMTIA: case at::kXPU: @@ -52,4 +54,50 @@ bool isAccelerator(c10::DeviceType d) { } } -} // namespace at +c10::DeviceIndex deviceCount() { + const auto device_type = getAccelerator(false); + if (!device_type.has_value()) { + return static_cast(0); + } + c10::impl::VirtualGuardImpl impl(device_type.value()); + return static_cast(impl.deviceCount()); +} + +void setDeviceIndex(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + impl.setDevice({device_type, device_index}); +} + +c10::DeviceIndex getDeviceIndex() { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + return static_cast(impl.getDevice().index()); +} + +void setCurrentStream(c10::Stream stream) { + const auto device_type = getAccelerator(true).value(); + TORCH_CHECK( + device_type == stream.device_type(), + "stream's device type ", + c10::DeviceTypeName(stream.device_type()), + " doesn't match the current accelerator ", + c10::DeviceTypeName(device_type)); + c10::impl::VirtualGuardImpl impl(device_type); + impl.exchangeStream(stream); +} + +c10::Stream getCurrentStream(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + return 
impl.getStream({device_type, device_index}); +} + +void synchronizeDevice(c10::DeviceIndex device_index) { + const auto device_type = getAccelerator(true).value(); + c10::impl::VirtualGuardImpl impl(device_type); + // impl.synchronizeDevice can be safely called from any device + impl.synchronizeDevice(device_index); +} + +} // namespace at::accelerator diff --git a/aten/src/ATen/DeviceAccelerator.h b/aten/src/ATen/DeviceAccelerator.h index 7840911bd6b25..b9de0209c75f2 100644 --- a/aten/src/ATen/DeviceAccelerator.h +++ b/aten/src/ATen/DeviceAccelerator.h @@ -6,6 +6,8 @@ #include #include +namespace at::accelerator { + // This file defines the top level Accelerator concept for PyTorch. // A device is an accelerator per the definition here if: // - It is mutually exclusive with all other accelerators @@ -15,13 +17,39 @@ // As of today, accelerator devices are (in no particular order): // CUDA, MTIA, XPU, HIP, MPS, PrivateUse1 -namespace at { - // Ensures that only one accelerator is available (at // compile time if possible) and return it. // When checked is true, the returned optional always has a value. TORCH_API std::optional getAccelerator(bool checked = false); -TORCH_API bool isAccelerator(c10::DeviceType d); +// Check if the given device type is an accelerator. +TORCH_API bool isAccelerator(c10::DeviceType device_type); + +// Return the number of devices available. Note that this is *REQUIRED* to +// not raise any exception. +TORCH_API c10::DeviceIndex deviceCount(); + +// Set the current device index to the given device index. +TORCH_API void setDeviceIndex(c10::DeviceIndex device_index); + +// Get the current device index. +TORCH_API c10::DeviceIndex getDeviceIndex(); +// Set the current stream to a given stream. Note that this API doesn't change +// the current device index. +TORCH_API void setCurrentStream(c10::Stream stream); + +// Get the current stream of the given device index. 
+TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index); + +// Wait (by blocking the calling thread) until all the work previously enqueued +// on the given device index has been completed. +TORCH_API void synchronizeDevice(c10::DeviceIndex device_index); + +} // namespace at::accelerator + +namespace at { +// Keep BC only +using at::accelerator::getAccelerator; +using at::accelerator::isAccelerator; } // namespace at diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp index 9684c7e6ed263..641385fc529bc 100644 --- a/torch/csrc/DeviceAccelerator.cpp +++ b/torch/csrc/DeviceAccelerator.cpp @@ -1,4 +1,3 @@ -#include #include #include @@ -13,69 +12,53 @@ void initModule(PyObject* module) { }); m.def("_accelerator_deviceCount", []() { - const auto device_type = at::getAccelerator(false); - if (!device_type.has_value()) { - return static_cast(0); - } - torch::utils::maybe_initialize_device(device_type.value()); - c10::impl::VirtualGuardImpl impl(device_type.value()); - return static_cast(impl.deviceCount()); + auto device_type = at::accelerator::getAccelerator(false); + torch::utils::maybe_initialize_device(device_type); + return at::accelerator::deviceCount(); }); m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) { - const auto device_type = at::getAccelerator(true).value(); // If device index is negative, no-op if (device_index < 0) { return; } + const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); - c10::impl::VirtualGuardImpl impl(device_type); - impl.setDevice({device_type, device_index}); + at::accelerator::setDeviceIndex(device_index); }); m.def("_accelerator_getDeviceIndex", []() { - const auto device_type = at::getAccelerator(true).value(); + const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); - c10::impl::VirtualGuardImpl impl(device_type); - return 
static_cast(impl.getDevice().index()); + return at::accelerator::getDeviceIndex(); }); m.def("_accelerator_setStream", [](c10::Stream stream) { - const auto device_type = at::getAccelerator(true).value(); - TORCH_CHECK( - device_type == stream.device_type(), - "stream's device type ", - c10::DeviceTypeName(stream.device_type()), - " doesn't match the current accelerator ", - c10::DeviceTypeName(device_type)); + const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); - c10::impl::VirtualGuardImpl impl(device_type); // Set the current device to the device of stream - if (impl.getDevice().index() != stream.device_index()) { - impl.setDevice(stream.device()); + if (at::accelerator::getDeviceIndex() != stream.device_index()) { + at::accelerator::setDeviceIndex(stream.device_index()); } - impl.exchangeStream(stream); + at::accelerator::setCurrentStream(stream); }); m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) { - const auto device_type = at::getAccelerator(true).value(); + const auto device_type = at::accelerator::getAccelerator(true).value(); torch::utils::maybe_initialize_device(device_type); - c10::impl::VirtualGuardImpl impl(device_type); - return impl.getStream({device_type, device_index}); + return at::accelerator::getCurrentStream(device_index); }); m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) { - const auto device_type = at::getAccelerator(true).value(); + const auto device_type = at::accelerator::getAccelerator(true).value(); if (torch::utils::is_device_lazy_init_supported(device_type) && !torch::utils::is_device_initialized(device_type)) { return; } torch::utils::maybe_initialize_device(device_type); - c10::impl::VirtualGuardImpl impl(device_type); - // impl.synchronizeDevice should can be safely called from any device { py::gil_scoped_release no_gil; - impl.synchronizeDevice(device_index); + at::accelerator::synchronizeDevice(device_index); } 
}); } diff --git a/torch/csrc/utils/device_lazy_init.h b/torch/csrc/utils/device_lazy_init.h index 1dcbb1bc97a3c..c717d4ef1804e 100644 --- a/torch/csrc/utils/device_lazy_init.h +++ b/torch/csrc/utils/device_lazy_init.h @@ -55,6 +55,14 @@ inline void maybe_initialize_device(const at::TensorOptions& options) { maybe_initialize_device(device); } +inline void maybe_initialize_device( + std::optional& device_type) { + if (!device_type.has_value()) { + return; + } + maybe_initialize_device(device_type.value()); +} + bool is_device_initialized(at::DeviceType device_type); } // namespace torch::utils