torch.mtia module for MTIA device backend (#123612)
The MTIA device now has its own module in PyTorch.
torch.mtia exposes the following APIs, similar to other backends; lazy initialization (lazy_init) is also supported.
```
__all__ = [
"init",
"is_available",
"synchronize",
"device_count",
"current_device",
"current_stream",
"default_stream",
"set_stream",
"stream",
"device",
]
```
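For orientation, a minimal usage sketch (hedged: it assumes a build where the out-of-tree MTIA extension is loaded, so `is_available()` can return `True`):
```
import torch

# Minimal sketch, assuming the out-of-tree MTIA extension is loaded.
if torch.mtia.is_available():
    torch.mtia.init()                 # explicit init; lazy_init covers this path too
    print(torch.mtia.device_count())  # number of visible MTIA devices
    with torch.mtia.device(0):        # select device index 0 within this scope
        s = torch.mtia.default_stream()
        with torch.mtia.stream(s):    # enqueue subsequent work on stream s
            pass
        torch.mtia.synchronize()      # wait for all queued work to finish
```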
------------
For device management, we expand AcceleratorHooksInterface to support generic device management, so it can be used from both C++ and Python.
```
def _accelerator_hooks_device_count() -> _int: ...
def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
def _accelerator_hooks_get_current_device() -> _int : ...
def _accelerator_hooks_exchange_device(device_index: _int) -> _int : ...
def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int : ...
```
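For illustration, a hedged sketch of how these private bindings behave from Python: they dispatch to the hooks of whichever accelerator backend is currently active, and the device-count and current-device queries fall back to -1 when no accelerator is present.
```
import torch

# Illustrative only: private torch._C bindings backed by AcceleratorHooksInterface.
n = torch._C._accelerator_hooks_device_count()               # -1 if no accelerator is active
if n > 0:
    prev = torch._C._accelerator_hooks_exchange_device(0)    # switch to device 0, keep old index
    cur = torch._C._accelerator_hooks_get_current_device()   # now 0
    torch._C._accelerator_hooks_set_current_device(prev)     # restore the previous device
```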
---------
Add a get_device_module API to retrieve the device module for a given device type.
```
def get_device_module(device: Optional[Union[torch.device, str]] = None)
```
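A short sketch of the intended behavior, derived from the implementation in the diff below (the no-argument form uses the current accelerator and falls back to `torch.cpu`):
```
import torch

# get_device_module maps a device (or device string) to its torch.<backend> module.
assert torch.get_device_module("cuda") is torch.cuda
assert torch.get_device_module(torch.device("xpu")) is torch.xpu
mod = torch.get_device_module()  # module for the current accelerator, torch.cpu if none
```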
---------
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123612
Approved by: https://github.com/albanD
ghstack dependencies: #123611
This commit is contained in: parent 36af9c0d7d, commit 73744a2c00
```
@@ -69,6 +69,8 @@ class TORCH_API Context {
       return at::detail::getMPSHooks();
     } else if (device_type == at::kPrivateUse1) {
       return at::detail::getPrivateUse1Hooks();
+    } else if (device_type == at::kMTIA) {
+      return at::detail::getMTIAHooks();
     } else {
       AT_ERROR(
           c10::DeviceTypeName(device_type), " device type not an accelerator.");
```
```
@@ -156,6 +158,9 @@ class TORCH_API Context {
   void lazyInitXPU() {
     c10::call_once(thx_init, [&] { detail::getXPUHooks().initXPU(); });
   }
+  void lazyInitMTIA() {
+    c10::call_once(th_mtia_init, [&] { detail::getMTIAHooks().initMTIA(); });
+  }
   void lazyInitPrivateUse1() {
     c10::call_once(thp_init, [&] {
       if (isPrivateUse1HooksRegistered()) {
```
```
@@ -349,6 +354,7 @@ class TORCH_API Context {
   c10::once_flag thc_init;
   c10::once_flag thh_init;
   c10::once_flag thx_init;
+  c10::once_flag th_mtia_init;
   c10::once_flag thp_init;
   bool enabled_cudnn = true;
   bool deterministic_cudnn = false;
```
```
@@ -10,6 +10,9 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
 #define CHECK_NO_PU1 \
   TORCH_CHECK(!is_privateuse1_backend_registered(), "Cannot have both CUDA and PrivateUse1");

+#define CHECK_NO_MTIA \
+  TORCH_CHECK(!at::hasMTIA(), "Cannot have MTIA with other devices");
+
   if (is_privateuse1_backend_registered()) {
     // We explicitly allow PrivateUse1 and another device at the same time
     // as we use this for testing.
```
```
@@ -17,7 +20,12 @@ C10_API std::optional<DeviceType> getAccelerator(bool checked) {
     return kPrivateUse1;
   } else if (at::hasCUDA()) {
     CHECK_NO_PU1
+    CHECK_NO_MTIA
     return kCUDA;
+  } else if (at::hasMTIA()) {
+    CHECK_NO_CUDA
+    CHECK_NO_PU1
+    return kMTIA;
   } else {
     TORCH_CHECK(!checked, "Cannot access accelerator device when none is available.")
     return std::nullopt;
```
```
@@ -1,7 +1,7 @@
 #pragma once

 #include <c10/core/Device.h>
+#include <c10/core/Stream.h>

 namespace at {

 // AcceleratorHooksInterface is a shared interface provided by all
```
```
@@ -16,6 +16,29 @@ struct TORCH_API AcceleratorHooksInterface {

   // Whether the device at device_index is fully initialized or not.
   virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

+  virtual DeviceIndex deviceCount() const {
+    return 0;
+  }
+
+  virtual void setCurrentDevice(DeviceIndex device) const {
+    TORCH_CHECK(false, "Backend doesn't support setCurrentDevice()");
+  }
+
+  virtual DeviceIndex getCurrentDevice() const {
+    TORCH_CHECK(false, "Backend doesn't support getCurrentDevice()");
+    return -1;
+  }
+
+  virtual DeviceIndex exchangeDevice(DeviceIndex device) const {
+    TORCH_CHECK(false, "Backend doesn't support exchangeDevice()");
+    return -1;
+  }
+
+  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const {
+    TORCH_CHECK(false, "Backend doesn't support maybeExchangeDevice()");
+    return -1;
+  }
 };

 } // namespace at
```
```
@@ -8,19 +8,22 @@
 namespace at {
 namespace detail {

-const MTIAHooksInterface &getMTIAHooks() {
-  static MTIAHooksInterface* MTIA_hooks = nullptr;
+const MTIAHooksInterface& getMTIAHooks() {
+  static std::unique_ptr<MTIAHooksInterface> mtia_hooks = nullptr;
   static c10::once_flag once;
   c10::call_once(once, [] {
-    MTIA_hooks =
-        MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{}).release();
-    if (!MTIA_hooks) {
-      MTIA_hooks = new MTIAHooksInterface();
+    mtia_hooks = MTIAHooksRegistry()->Create("MTIAHooks", MTIAHooksArgs{});
+    if (!mtia_hooks) {
+      mtia_hooks = std::make_unique<MTIAHooksInterface>();
     }
   });
-  return *MTIA_hooks;
+  return *mtia_hooks;
 }

+bool isMTIAHooksBuilt() {
+  return MTIAHooksRegistry()->Has("MTIAHooks");
+}
+
 } // namespace detail

 C10_DEFINE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs)
```
```
@@ -1,7 +1,9 @@
 #pragma once

+#include <c10/core/Device.h>
 #include <c10/util/Exception.h>

+#include <c10/core/Stream.h>
 #include <c10/util/Registry.h>

 #include <ATen/detail/AcceleratorHooksInterface.h>
```
```
@@ -20,33 +22,72 @@ constexpr const char* MTIA_HELP =
     "to use some MTIA's functionality without MTIA extension included.";

 struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
+// this fails the implementation if MTIAHooks functions are called, but
+// MTIA backend is not present.
+#define FAIL_MTIAHOOKS_FUNC(func) \
+  TORCH_CHECK(false, "Cannot execute ", func, "() without MTIA backend.");
+
   virtual ~MTIAHooksInterface() override = default;

   virtual void initMTIA() const {
-    TORCH_CHECK(
-        false,
-        "Cannot initialize MTIA without MTIA Extension for PyTorch.",
-        MTIA_HELP);
+    // Avoid logging here, since MTIA needs init devices first then it will know
+    // how many devices are available. Make it as no-op if mtia extension is not
+    // dynamically loaded.
+    return;
   }

   virtual bool hasMTIA() const {
     return false;
   }

+  virtual DeviceIndex deviceCount() const override {
+    return 0;
+  }
+
+  virtual void deviceSynchronize(c10::DeviceIndex device_index) const {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+  }
+
   virtual std::string showConfig() const {
-    TORCH_CHECK(
-        false,
-        "Cannot query detailed MTIA version without MTIA Extension for PyTorch.",
-        MTIA_HELP);
+    FAIL_MTIAHOOKS_FUNC(__func__);
   }

   virtual bool hasPrimaryContext(DeviceIndex device_index) const override {
-    TORCH_CHECK(
-        false,
-        "Cannot check MTIA primary context without MTIA Extension for PyTorch.",
-        MTIA_HELP);
+    return false;
   }
+
+  virtual void setCurrentDevice(DeviceIndex device) const override {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+  }
+
+  virtual DeviceIndex getCurrentDevice() const override {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return -1;
+  }
+
+  virtual DeviceIndex exchangeDevice(DeviceIndex device) const override {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return -1;
+  }
+
+  virtual DeviceIndex maybeExchangeDevice(DeviceIndex device) const override {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return -1;
+  }
+
+  virtual c10::Stream getCurrentStream(DeviceIndex device) const {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
+  }
+
+  virtual c10::Stream getDefaultStream(DeviceIndex device) const {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+    return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
+  }
+
+  virtual void setCurrentStream(const c10::Stream& stream) const {
+    FAIL_MTIAHOOKS_FUNC(__func__);
+  }
 };

 struct TORCH_API MTIAHooksArgs {};
```
```
@@ -57,5 +98,6 @@ C10_DECLARE_REGISTRY(MTIAHooksRegistry, MTIAHooksInterface, MTIAHooksArgs);

 namespace detail {
 TORCH_API const MTIAHooksInterface& getMTIAHooks();
+TORCH_API bool isMTIAHooksBuilt();
 } // namespace detail
 } // namespace at
```
```
@@ -822,6 +822,7 @@ libtorch_python_core_sources = [
     "torch/csrc/dynamo/init.cpp",
     "torch/csrc/functorch/init.cpp",
     "torch/csrc/mps/Module.cpp",
+    "torch/csrc/mtia/Module.cpp",
     "torch/csrc/inductor/aoti_runner/pybind.cpp",
     "torch/csrc/jit/backends/backend_init.cpp",
     "torch/csrc/jit/python/init.cpp",
```
```
@@ -69,6 +69,7 @@ Features described in this documentation are classified by release status:
    torch.cuda.memory <torch_cuda_memory>
    mps
    xpu
+   mtia
    meta
    torch.backends <backends>
    torch.export <export>
```
docs/source/mtia.rst (new file, 34 lines):
```
torch.mtia
===================================

The MTIA backend is implemented out of the tree, only interfaces are be defined here.

.. automodule:: torch.mtia
.. currentmodule:: torch.mtia

.. autosummary::
    :toctree: generated
    :nosignatures:

    StreamContext
    current_device
    current_stream
    default_stream
    device_count
    init
    is_available
    is_initialized
    set_stream
    stream
    synchronize
    device
    DeferredMtiaCallError

Streams and events
------------------
.. autosummary::
    :toctree: generated
    :nosignatures:

    Event
    Stream
```
```
@@ -684,6 +684,7 @@ Utilities
     set_float32_matmul_precision
     get_float32_matmul_precision
     set_warn_always
+    get_device_module
     is_warn_always_enabled
     vmap
     _assert
```
```
@@ -1719,6 +1719,24 @@ _TensorBase = TensorBase
 # Defined in torch/csrc/multiprocessing/init.cpp
 def _multiprocessing_init() -> None: ...

+# Defined in torch/csrc/Module.cpp
+def _accelerator_hooks_device_count() -> _int: ...
+def _accelerator_hooks_set_current_device(device_index: _int) -> None: ...
+def _accelerator_hooks_get_current_device() -> _int: ...
+def _accelerator_hooks_exchange_device(device_index: _int) -> _int: ...
+def _accelerator_hooks_maybe_exchange_device(device_index: _int) -> _int: ...
+def _get_accelerator(check: _bool = False) -> _device: ...
+
+# Defined in torch/csrc/mtia/Module.cpp
+def _mtia_init() -> None: ...
+def _mtia_isBuilt() -> _bool: ...
+def _mtia_isInBadFork() -> _bool: ...
+def _mtia_deviceSynchronize() -> None: ...
+def _mtia_getCurrentStream(device: _int) -> Stream: ...
+def _mtia_setCurrentStream(stream: Stream) -> None: ...
+def _mtia_getDefaultStream(device: _int) -> Stream: ...
+
 # Defined in torch/csrc/mps/Module.cpp
 def _mps_deviceSynchronize() -> None: ...
 def _mps_get_default_generator() -> Generator: ...
```
```
@@ -24,6 +24,7 @@ class DeviceType(Enum):
     FPGA = ...
     MAIA = ...
     XLA = ...
+    MTIA = ...
     MPS = ...
     HPU = ...
     Meta = ...
```
```
@@ -58,6 +58,7 @@ __all__ = [
     'SymBool', 'sym_not', 'unravel_index',
     'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
     'export', 'autocast', 'cond', 'GradScaler',
+    'get_device_module',
 ]

 ################################################################################
```
```
@@ -1579,6 +1580,7 @@ from torch import cuda as cuda
 from torch import cpu as cpu
 from torch import mps as mps
 from torch import xpu as xpu
+from torch import mtia as mtia
 from torch import autograd as autograd
 from torch.autograd import (
     no_grad as no_grad,
```
```
@@ -2016,6 +2018,27 @@ else:

         raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

+
+def get_device_module(device: Optional[Union[torch.device, str]] = None):
+    """
+    Returns the module associated with a given device(e.g., torch.device('cuda'), "mtia:0", "xpu", ...).
+    If no device is given, return the module for the current accelerator or CPU if none is present.
+    """
+    if isinstance(device, torch.device):
+        device_module_name = device.type
+    elif isinstance(device, str):
+        device_module_name = torch.device(device).type
+    elif device is None:
+        # Using default accelerator type. If no accelerator is available, it automatically returns CPU device.
+        device_module_name = torch._C._get_accelerator().type
+    else:
+        raise RuntimeError(f"Invalid value of device '{device}', expect torch.device, str, or None")
+    device_module = getattr(torch, device_module_name, None)
+    if device_module is None:
+        raise RuntimeError(
+            f"Device '{device_module_name}' does not have a corresponding module registered as 'torch.{device_module_name}'."
+        )
+    return device_module
+
+
 def _constrain_as_value(symbol, min: Optional[builtins.int] = None, max: Optional[builtins.int] = None):
     """
```
```
@@ -713,6 +713,8 @@ def _get_available_device_type():
         return "cuda"
     if hasattr(torch, "xpu") and torch.xpu.is_available():  # type: ignore[attr-defined]
         return "xpu"
+    if hasattr(torch, "mtia") and torch.mtia.is_available():
+        return "mtia"
     custom_backend_name = torch._C._get_privateuse1_backend_name()
     custom_device_mod = getattr(torch, custom_backend_name, None)
     if custom_device_mod and custom_device_mod.is_available():
```
```
@@ -727,6 +729,8 @@ def _get_device_attr(get_member):
         return get_member(torch.cuda)
     if device_type and device_type.lower() == "xpu":
         return get_member(torch.xpu)  # type: ignore[attr-defined]
+    if device_type and device_type.lower() == "mtia":
+        return get_member(torch.mtia)
     if device_type == torch._C._get_privateuse1_backend_name():
         return get_member(getattr(torch, device_type))
     # add more available device types here
```
```
@@ -1,3 +1,4 @@
+#include <ATen/DeviceAccelerator.h>
 #include <c10/util/Optional.h>
 #include <fmt/core.h>
 #include <sys/types.h>
```
```
@@ -16,10 +17,12 @@
 #include <ATen/Parallel.h>
 #include <ATen/Utils.h>
 #include <ATen/core/Vitals.h>
+#include <ATen/detail/AcceleratorHooksInterface.h>
 #include <ATen/dlpack.h>
 #include <ATen/native/ConvUtils.h>
 #include <ATen/native/ForeachUtils.h>
 #include <ATen/native/Normalization.h>
+#include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/util/AbortHandler.h>
 #include <c10/util/Backtrace.h>
```
```
@@ -72,6 +75,7 @@
 #include <torch/csrc/lazy/python/init.h>
 #include <torch/csrc/monitor/python_init.h>
 #include <torch/csrc/mps/Module.h>
+#include <torch/csrc/mtia/Module.h>
 #include <torch/csrc/multiprocessing/init.h>
 #include <torch/csrc/onnx/init.h>
 #include <torch/csrc/profiler/python/init.h>
```
```
@@ -1641,6 +1645,7 @@ PyObject* initModule() {
 #ifdef USE_XPU
   torch::xpu::initModule(module);
 #endif
+  torch::mtia::initModule(module);
   torch::cpu::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
```
```
@@ -1975,6 +1980,70 @@ Call this whenever a new thread is created in order to propagate values from
         return at::impl::ThreadLocalPythonObjects::get_state().contains(key);
       });

+  py_module.def("_accelerator_hooks_device_count", []() {
+    auto device_type = at::getAccelerator();
+    if (device_type.has_value()) {
+      return at::globalContext()
+          .getAcceleratorHooksInterface(device_type.value())
+          .deviceCount();
+    }
+    return c10::DeviceIndex(-1);
+  });
+
+  py_module.def(
+      "_accelerator_hooks_set_current_device",
+      [](c10::DeviceIndex device_index) {
+        auto device_type = at::getAccelerator();
+        if (device_type.has_value()) {
+          at::globalContext()
+              .getAcceleratorHooksInterface(device_type.value())
+              .setCurrentDevice(device_index);
+        }
+      });
+
+  py_module.def("_accelerator_hooks_get_current_device", []() {
+    auto device_type = at::getAccelerator();
+    if (device_type.has_value()) {
+      return at::globalContext()
+          .getAcceleratorHooksInterface(device_type.value())
+          .getCurrentDevice();
+    }
+    return c10::DeviceIndex(-1);
+  });
+
+  py_module.def(
+      "_accelerator_hooks_exchange_device", [](c10::DeviceIndex device_index) {
+        auto device_type = at::getAccelerator();
+        if (device_type.has_value()) {
+          return at::globalContext()
+              .getAcceleratorHooksInterface(device_type.value())
+              .exchangeDevice(device_index);
+        }
+        return c10::DeviceIndex(-1);
+      });
+
+  py_module.def(
+      "_accelerator_hooks_maybe_exchange_device",
+      [](c10::DeviceIndex device_index) {
+        auto device_type = at::getAccelerator();
+        if (device_type.has_value()) {
+          return at::globalContext()
+              .getAcceleratorHooksInterface(device_type.value())
+              .maybeExchangeDevice(device_index);
+        }
+        return c10::DeviceIndex(-1);
+      });
+
+  py_module.def(
+      "_get_accelerator",
+      [](c10::optional<bool> check = c10::nullopt) {
+        return c10::Device(
+            at::getAccelerator(check.value_or(false))
+                .value_or(c10::DeviceType::CPU),
+            -1);
+      },
+      py::arg("check") = nullptr);
+
 #ifdef USE_CUDA
   PyObject* has_cuda = Py_True;
 #else
```
torch/csrc/mtia/Module.cpp (new file, 81 lines):
```
#include <ATen/ATen.h>
#include <c10/util/CallOnce.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>

#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#ifndef WIN32
#include <pthread.h>
#endif

namespace torch {
namespace mtia {

static bool in_bad_fork = false; // True for children forked after mtia init

#ifndef WIN32
// Called in the forked child if mtia has already been initialized
static void forked_child() {
  in_bad_fork = true;
  torch::utils::set_requires_device_init(at::kMTIA, true);
}
#endif

// Should be called before the first mtia call.
// Note: This is distinct from initExtension because a stub mtia implementation
// has some working functions (e.g. device_count) but cannot fully initialize.
static void poison_fork() {
#ifndef WIN32
  static c10::once_flag flag;
  c10::call_once(flag, [] { pthread_atfork(nullptr, nullptr, forked_child); });
#endif
}

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  m.def("_mtia_init", []() {
    TORCH_INTERNAL_ASSERT(!in_bad_fork); // Handled at python level
    poison_fork();
    at::globalContext().lazyInitMTIA();
  });

  m.def("_mtia_isBuilt", []() {
    // Check if the MTIAHooks class has been registered with the registry.
    return at::detail::isMTIAHooksBuilt();
  });

  m.def("_mtia_isInBadFork", []() { return in_bad_fork; });

  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getCurrentStream(device_index);
  });

  m.def("_mtia_deviceSynchronize", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    at::detail::getMTIAHooks().deviceSynchronize(
        at::detail::getMTIAHooks().getCurrentDevice());
  });

  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    return at::detail::getMTIAHooks().getDefaultStream(device_index);
  });

  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
    torch::utils::device_lazy_init(at::kMTIA);
    auto device = at::detail::getMTIAHooks().getCurrentDevice();
    if (device != stream.device_index()) {
      at::detail::getMTIAHooks().setCurrentDevice(stream.device_index());
    }
    at::detail::getMTIAHooks().setCurrentStream(stream);
  });
}

} // namespace mtia
} // namespace torch
```
torch/csrc/mtia/Module.h (new file, 12 lines):
```
#pragma once

#include <torch/csrc/python_headers.h>

namespace torch {
namespace mtia {

// PyMethodDef* python_functions();
void initModule(PyObject* module);

} // namespace mtia
} // namespace torch
```
```
@@ -194,6 +194,12 @@ struct type_caster<c10::Stream> {
   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
   PYBIND11_TYPE_CASTER(c10::Stream, _("torch.Stream"));

+  // PYBIND11_TYPE_CASTER defines a member field called value. Since c10::Stream
+  // cannot be default-initialized, we provide this constructor to explicitly
+  // initialize that field. The value doesn't matter as it will be overwritten
+  // after a successful call to load.
+  type_caster() : value(c10::Stream::DEFAULT, c10::Device(c10::kCPU, 0)) {}
+
   bool load(handle src, bool) {
     PyObject* obj = src.ptr();
     if (THPStream_Check(obj)) {
```
torch/mtia/__init__.py (new file, 262 lines):
```
r"""
This package enables an interface for accessing MTIA backend in python
"""

import threading
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch

from torch.types import Device

from .. import device as _device, Tensor
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from ._utils import _get_device_index

_device_t = Union[_device, str, int, None]

# torch.mtia.Event/Stream is alias of torch.Event/Stream
Event = torch.Event
Stream = torch.Stream

_initialized = False
_queued_calls: List[
    Tuple[Callable[[], None], List[str]]
] = []  # don't invoke these until initialization occurs
_tls = threading.local()
_initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker()


def init():
    _lazy_init()


def is_initialized():
    r"""Return whether PyTorch's MTIA state has been initialized."""
    return _initialized and not _is_in_bad_fork()


def _is_in_bad_fork() -> bool:
    return torch._C._mtia_isInBadFork()


def _lazy_init() -> None:
    global _initialized, _queued_calls
    if is_initialized() or hasattr(_tls, "is_initializing"):
        return
    with _initialization_lock:
        # We be double-checked locking, boys! This is OK because
        # the above test was GIL protected anyway. The inner test
        # is for when a thread blocked on some other thread which was
        # doing the initialization; when they get the lock, they will
        # find there is nothing left to do.
        if is_initialized():
            return
        # It is important to prevent other threads from entering _lazy_init
        # immediately, while we are still guaranteed to have the GIL, because some
        # of the C calls we make below will release the GIL
        if _is_in_bad_fork():
            raise RuntimeError(
                "Cannot re-initialize MTIA in forked subprocess. To use MTIA with "
                "multiprocessing, you must use the 'spawn' start method"
            )
        if not _is_compiled():
            raise AssertionError("Torch not compiled with MTIA enabled")

        torch._C._mtia_init()
        # Some of the queued calls may reentrantly call _lazy_init();
        # we need to just return without initializing in that case.
        # However, we must not let any *other* threads in!
        _tls.is_initializing = True

        for calls in _lazy_seed_tracker.get_calls():
            if calls:
                _queued_calls.append(calls)

        try:
            for queued_call, orig_traceback in _queued_calls:
                try:
                    queued_call()
                except Exception as e:
                    msg = (
                        f"MTIA call failed lazily at initialization with error: {str(e)}\n\n"
                        f"MTIA call was originally invoked at:\n\n{''.join(orig_traceback)}"
                    )
                    raise DeferredMtiaCallError(msg) from e
        finally:
            delattr(_tls, "is_initializing")
        _initialized = True


class DeferredMtiaCallError(Exception):
    pass


def _is_compiled() -> bool:
    r"""Return true if compiled with MTIA support."""
    return torch._C._mtia_isBuilt()


def is_available() -> bool:
    r"""Return true if MTIA device is available"""
    if not _is_compiled():
        return False
    # MTIA has to init devices first to know if there is any devices available.
    return device_count() > 0


def synchronize() -> None:
    r"""Waits for all jobs in all streams on a MTIA device to complete."""
    return torch._C._mtia_deviceSynchronize()


def device_count() -> int:
    r"""Return the number of MTIA devices available."""
    return torch._C._accelerator_hooks_device_count()


def current_device() -> int:
    r"""Return the index of a currently selected device."""
    return torch._C._accelerator_hooks_get_current_device()


def current_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the currently selected :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the currently selected :class:`Stream` for the current device, given
            by :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getCurrentStream(_get_device_index(device, optional=True))


def default_stream(device: Optional[_device_t] = None) -> Stream:
    r"""Return the default :class:`Stream` for a given device.

    Args:
        device (torch.device or int, optional): selected device. Returns
            the default :class:`Stream` for the current device, given by
            :func:`~torch.mtia.current_device`, if :attr:`device` is ``None``
            (default).
    """
    return torch._C._mtia_getDefaultStream(_get_device_index(device, optional=True))


def set_stream(stream: Stream):
    r"""Set the current stream.This is a wrapper API to set the stream.
        Usage of this function is discouraged in favor of the ``stream``
        context manager.

    Args:
        stream (Stream): selected stream. This function is a no-op
            if this argument is ``None``.
    """
    if stream is None:
        return
    torch._C._mtia_setCurrentStream(stream)


class device:
    r"""Context-manager that changes the selected device.

    Args:
        device (torch.device or int): device index to select. It's a no-op if
            this argument is a negative integer or ``None``.
    """

    def __init__(self, device: Any):
        self.idx = _get_device_index(device, optional=True)
        self.prev_idx = -1

    def __enter__(self):
        self.prev_idx = torch._C._accelerator_hooks_maybe_exchange_device(self.idx)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        self.idx = torch._C._accelerator_hooks_maybe_exchange_device(self.prev_idx)
        return False


class StreamContext:
    r"""Context-manager that selects a given stream.

    All MTIA kernels queued within its context will be enqueued on a selected
    stream.

    Args:
        Stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    .. note:: Streams are per-device.
    """

    cur_stream: Optional["torch.mtia.Stream"]

    def __init__(self, stream: Optional["torch.mtia.Stream"]):
        self.stream = stream
        self.idx = _get_device_index(None, True)
        if not torch.jit.is_scripting():
            if self.idx is None:
                self.idx = -1

        self.src_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )
        self.dst_prev_stream = (
            None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
        )

    def __enter__(self):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # Return if stream is None or MTIA device not available
        if cur_stream is None or self.idx == -1:
            return
        self.src_prev_stream = torch.mtia.current_stream(None)

        # If the stream is not on the current device, then
        # set the current stream on the device
        if self.src_prev_stream.device != cur_stream.device:
            with device(cur_stream.device):
                self.dst_prev_stream = torch.mtia.current_stream(cur_stream.device)
        torch.mtia.set_stream(cur_stream)

    def __exit__(self, type: Any, value: Any, traceback: Any):
        # Local cur_stream variable for type refinement
        cur_stream = self.stream
        # If stream is None or no MTIA device available, return
        if cur_stream is None or self.idx == -1:
            return

        # Reset the stream on the original device
        # and destination device
        if self.src_prev_stream.device != cur_stream.device:  # type: ignore[union-attr]
            torch.mtia.set_stream(self.dst_prev_stream)  # type: ignore[arg-type]
        torch.mtia.set_stream(self.src_prev_stream)  # type: ignore[arg-type]


def stream(stream: Optional["torch.mtia.Stream"]) -> StreamContext:
    r"""Wrap around the Context-manager StreamContext that selects a given stream.

    Arguments:
        stream (Stream): selected stream. This manager is a no-op if it's
            ``None``.
    ..Note:: In eager mode stream is of type Stream class while in JIT it doesn't support torch.mtia.stream
    """
    return StreamContext(stream)


__all__ = [
    "init",
    "is_available",
    "is_initialized",
    "synchronize",
    "device_count",
    "current_device",
    "current_stream",
    "default_stream",
    "set_stream",
    "stream",
    "device",
]
```
torch/mtia/_utils.py (new file, 38 lines):
```
from typing import Any

import torch

# The _get_device_index has been moved to torch.utils._get_device_index
from torch._utils import _get_device_index as _torch_get_device_index


def _get_device_index(
    device: Any, optional: bool = False, allow_cpu: bool = False
) -> int:
    r"""Get the device index from :attr:`device`, which can be a torch.device object, a Python integer, or ``None``.

    If :attr:`device` is a torch.device object, returns the device index if it
    is a MTIA device. Note that for a MTIA device without a specified index,
    i.e., ``torch.device('mtia')``, this will return the current default MTIA
    device if :attr:`optional` is ``True``. If :attr:`allow_cpu` is ``True``,
    CPU devices will be accepted and ``-1`` will be returned in this case.

    If :attr:`device` is a Python integer, it is returned as is.

    If :attr:`device` is ``None``, this will return the current default MTIA
    device if :attr:`optional` is ``True``.
    """
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    if isinstance(device, torch.device):
        if allow_cpu:
            if device.type not in ["mtia", "cpu"]:
                raise ValueError(f"Expected a mtia or cpu device, but got: {device}")
        elif device.type != "mtia":
            raise ValueError(f"Expected a mtia device, but got: {device}")
    if not torch.jit.is_scripting():
        if isinstance(device, torch.mtia.device):
            return device.idx
    return _torch_get_device_index(device, optional, allow_cpu)
```
```
@@ -283,6 +283,7 @@ def get_ignored_functions() -> Set[Callable]:
         torch.use_deterministic_algorithms,
         torch.is_deterministic_algorithms_warn_only_enabled,
         torch.set_deterministic_debug_mode,
+        torch.get_device_module,
         torch.get_deterministic_debug_mode,
         torch.set_float32_matmul_precision,
         torch.get_float32_matmul_precision,
```