[MTIA] Add set_device support (#128040)

Summary: Support set_device API in MTIA backend.

Reviewed By: gnahzg

Differential Revision: D58089498

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128040
Approved by: https://github.com/gnahzg
This commit is contained in:
Jun Luo 2024-06-10 23:42:50 +00:00 committed by PyTorch MergeBot
parent 30875953a4
commit f843ccbb1a
2 changed files with 14 additions and 0 deletions

View File

@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
init
is_available
is_initialized
set_device
set_stream
stream
synchronize

View File

@ -160,6 +160,18 @@ def set_stream(stream: Stream):
torch._C._mtia_setCurrentStream(stream)
def set_device(device: _device_t) -> None:
r"""Set the current device.
Args:
device (torch.device or int): selected device. This function is a no-op
if this argument is negative.
"""
device = _get_device_index(device)
if device >= 0:
torch._C._accelerator_hooks_set_current_device(device)
class device:
r"""Context-manager that changes the selected device.
@ -257,6 +269,7 @@ __all__ = [
"current_device",
"current_stream",
"default_stream",
"set_device",
"set_stream",
"stream",
"device",