mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MTIA] Add set_device support (#128040)
Summary: Support set_device API in MTIA backend. Reviewed By: gnahzg Differential Revision: D58089498 Pull Request resolved: https://github.com/pytorch/pytorch/pull/128040 Approved by: https://github.com/gnahzg
This commit is contained in:
parent
30875953a4
commit
f843ccbb1a
|
|
@ -18,6 +18,7 @@ The MTIA backend is implemented out of the tree, only interfaces are be defined
|
|||
init
|
||||
is_available
|
||||
is_initialized
|
||||
set_device
|
||||
set_stream
|
||||
stream
|
||||
synchronize
|
||||
|
|
|
|||
|
|
@ -160,6 +160,18 @@ def set_stream(stream: Stream):
|
|||
torch._C._mtia_setCurrentStream(stream)
|
||||
|
||||
|
||||
def set_device(device: _device_t) -> None:
|
||||
r"""Set the current device.
|
||||
|
||||
Args:
|
||||
device (torch.device or int): selected device. This function is a no-op
|
||||
if this argument is negative.
|
||||
"""
|
||||
device = _get_device_index(device)
|
||||
if device >= 0:
|
||||
torch._C._accelerator_hooks_set_current_device(device)
|
||||
|
||||
|
||||
class device:
|
||||
r"""Context-manager that changes the selected device.
|
||||
|
||||
|
|
@ -257,6 +269,7 @@ __all__ = [
|
|||
"current_device",
|
||||
"current_stream",
|
||||
"default_stream",
|
||||
"set_device",
|
||||
"set_stream",
|
||||
"stream",
|
||||
"device",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user