diff --git a/docs/source/mtia.rst b/docs/source/mtia.rst index f2f5b5195dcbe6..b729c061b2a681 100644 --- a/docs/source/mtia.rst +++ b/docs/source/mtia.rst @@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined init is_available is_initialized + set_device set_stream stream synchronize diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index b68a25bdb61b96..f9554a9bcb277f 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -160,6 +160,18 @@ def set_stream(stream: Stream): torch._C._mtia_setCurrentStream(stream) +def set_device(device: _device_t) -> None: + r"""Set the current device. + + Args: + device (torch.device or int): selected device. This function is a no-op + if this argument is negative. + """ + device = _get_device_index(device) + if device >= 0: + torch._C._accelerator_hooks_set_current_device(device) + + class device: r"""Context-manager that changes the selected device. @@ -257,6 +269,7 @@ def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext: "current_device", "current_stream", "default_stream", + "set_device", "set_stream", "stream", "device",