Introduce a device-agnostic runtime API design (#132204)
# Motivation

According to [[RFC] A device-agnostic Python runtime API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/128403), this PR introduces a device-agnostic runtime API design. I personally prefer the **Simple Version** APIs, which no longer accept the device type as an input argument: we leverage `getAccelerator` to fetch the current accelerator, and the APIs remain flexible enough to extend to scenarios with multiple accelerator types. The design does **NOT** break the previous design philosophies. I also believe the namespace `torch.accelerator` is better: it lets users know that the APIs they are calling run on an accelerator rather than on the CPU, which is important. Meanwhile, we can follow a simple API design principle:

1. Device-agnostic APIs should be placed under the `torch.accelerator` namespace and should not accept a `device_type` optional parameter.
2. Device-specific APIs should be placed under device-specific submodules.
3. APIs required by both CPU and accelerators should be placed under the `torch` namespace and accept a `device_type` optional parameter.

Also, I list the pros and cons of the **Simple Version** here:

Pros:
- `torch.accelerator.foo` will have the same input arguments as `torch.xxx.foo`, bringing a better user experience;
- more concise, making it easier for developers to write device-agnostic code.

Cons:
- no obvious drawbacks.

# Additional Context

I list the new APIs here:

```python
torch.accelerator.is_available() -> bool:
torch.accelerator.current_accelerator() -> torch.device:
torch.accelerator.device_count() -> int:
torch.accelerator.current_device_idx() -> int:
torch.accelerator.set_device_idx(device: Union[torch.device, str, int, None]) -> None:
torch.accelerator.current_stream(device: Union[torch.device, str, int, None]) -> torch.Stream:
torch.accelerator.set_stream(stream: torch.Stream) -> None:
torch.accelerator.synchronize(device: Union[torch.device, str, int, None]) -> None:
```

Following the discussion with Alban, we decided to rename `set_device` to `set_device_idx` and `current_device` to `current_device_idx` to be more explicit, and we will submit another PR to support device and stream context managers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132204
Approved by: https://github.com/EikanWang, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/albanD
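To make the intended usage concrete, here is a minimal device-agnostic sketch built on the APIs listed above (illustrative only, not part of this PR's diff; it assumes a PyTorch build where `torch.accelerator` is present):

```python
import torch

# Choose the current accelerator when one is available, otherwise fall back to CPU.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator()  # e.g. torch.device("cuda") or torch.device("xpu")
else:
    device = torch.device("cpu")

x = torch.randn(4, 4, device=device)
y = (x @ x).sum()

# Device-agnostic synchronization: no torch.cuda / torch.xpu branching required.
if device.type != "cpu":
    torch.accelerator.synchronize()
print(y.item())
```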
This commit is contained in:
parent 1152726feb
commit 40c098f731
```diff
@@ -216,6 +216,15 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplInterface {
     C10_HIP_CHECK(hipEventSynchronize(hip_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    int orig_device{-1};
+    C10_HIP_CHECK(hipGetDevice(&orig_device));
+    C10_HIP_CHECK(hipSetDevice(device_index));
+    C10_HIP_CHECK(hipDeviceSynchronize());
+    C10_HIP_CHECK(hipSetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(
       const c10::DataPtr& data_ptr,
       const Stream& stream) const override {
```
```diff
@@ -111,6 +111,8 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
 
   bool queryEvent(void* event) const override;
 
+  void synchronizeDevice(const DeviceIndex device_index) const override;
+
 };
 
 /// A variant of OptionalDeviceGuard that is specialized for MPS.
```
```diff
@@ -42,4 +42,8 @@ bool MPSGuardImpl::queryEvent(void* event) const {
   return mps_event->query();
 }
 
+void MPSGuardImpl::synchronizeDevice(const DeviceIndex device_index) const {
+  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace at::mps
```
```diff
@@ -795,6 +795,7 @@ libtorch_python_xpu_sources = [
 
 libtorch_python_core_sources = [
     "torch/csrc/DataLoader.cpp",
+    "torch/csrc/DeviceAccelerator.cpp",
     "torch/csrc/Device.cpp",
     "torch/csrc/Dtype.cpp",
     "torch/csrc/DynamicTypes.cpp",
```
```diff
@@ -212,6 +212,15 @@ struct C10_API DeviceGuardImplInterface {
     TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
   }
 
+  /**
+   * Wait (by blocking the calling thread) until all the work previously
+   * enqueued on the device has been completed.
+   */
+  virtual void synchronizeDevice(const DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(
+        false, "Backend doesn't support synchronizing all streams on device.");
+  }
+
   /**
    * Ensure the caching allocator (if any) is aware that the given DataPtr is
    * being used on the given stream, and that it should thus avoid recycling the
```
```diff
@@ -96,6 +96,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
     return impl_->synchronizeEvent(event);
   }
 
+  void synchronizeDevice(const DeviceIndex device_index) const override {
+    return impl_->synchronizeDevice(device_index);
+  }
+
  private:
   const DeviceGuardImplInterface* impl_ = nullptr;
 };
```
```diff
@@ -219,6 +219,19 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    DeviceIndex orig_device{-1};
+    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
+    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
+    }
+    C10_CUDA_CHECK(cudaDeviceSynchronize());
+    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     CUDAStream cuda_stream{stream};
```
```diff
@@ -163,6 +163,14 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     xpu_event->wait_and_throw();
   }
 
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kXPU);
+    }
+    c10::xpu::syncStreamsOnDevice(device_index);
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     const XPUStream xpu_stream{stream};
```
docs/source/accelerator.rst (new file, 17 lines)

```rst
torch.accelerator
===================================
.. automodule:: torch.accelerator
.. currentmodule:: torch.accelerator

.. autosummary::
    :toctree: generated
    :nosignatures:

    device_count
    is_available
    current_accelerator
    set_device_idx
    current_device_idx
    set_stream
    current_stream
    synchronize
```
```diff
@@ -64,6 +64,7 @@ Features described in this documentation are classified by release status:
     torch.amp <amp>
     torch.autograd <autograd>
     torch.library <library>
+    accelerator
     cpu
     cuda
     torch.cuda.memory <torch_cuda_memory>
```
```diff
@@ -2183,6 +2183,15 @@ def _set_worker_pids(
 def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
 def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails
 
+# Defined in torch/csrc/DeviceAccelerator.cpp
+def _accelerator_getAccelerator() -> _device: ...
+def _accelerator_deviceCount() -> _int: ...
+def _accelerator_setDeviceIndex(device_index: _int) -> None: ...
+def _accelerator_getDeviceIndex() -> _int: ...
+def _accelerator_setStream(Stream) -> None: ...
+def _accelerator_getStream(device_index: _int) -> Stream: ...
+def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
+
 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
     def push_scope(self, scope_name: str) -> None: ...
```
```diff
@@ -2092,6 +2092,7 @@ from torch import (
     __config__ as __config__,
     __future__ as __future__,
     _awaits as _awaits,
+    accelerator as accelerator,
     autograd as autograd,
     backends as backends,
     cpu as cpu,
```
torch/accelerator/__init__.py (new file, 145 lines)

```python
r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""

import torch

from ._utils import _device_t, _get_device_index


def device_count() -> int:
    r"""Return the number of current :ref:`accelerator<accelerators>` available.

    Returns:
        int: the number of the current :ref:`accelerator<accelerators>` available.
            If there is no available accelerator, return 0.
    """
    return torch._C._accelerator_deviceCount()


def is_available() -> bool:
    r"""Check if there is an available :ref:`accelerator<accelerators>`.

    Returns:
        bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.

    Example::

        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
    """
    return device_count() > 0


def current_accelerator() -> torch.device:
    r"""Return the device of the current :ref:`accelerator<accelerators>`.

    Returns:
        torch.device: return the current accelerator as :class:`torch.device`.

    .. note:: The index of the returned :class:`torch.device` will be ``None``, please use
        :func:`torch.accelerator.current_device_idx` to know the current index being used.
        And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
        accelerator. If there is no available accelerator, this function will raise an exception.

    Example::

        >>> # xdoctest:
        >>> if torch.accelerator.is_available():
        >>>     current_device = torch.accelerator.current_accelerator()
        >>> else:
        >>>     current_device = torch.device("cpu")
        >>> if current_device.type == 'cuda':
        >>>     is_half_supported = torch.cuda.has_half
        >>> elif current_device.type == 'xpu':
        >>>     is_half_supported = torch.xpu.get_device_properties().has_fp16
        >>> elif current_device.type == 'cpu':
        >>>     is_half_supported = True
    """
    return torch._C._accelerator_getAccelerator()


def current_device_idx() -> int:
    r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.

    Returns:
        int: the index of a currently selected device.
    """
    return torch._C._accelerator_getDeviceIndex()


def set_device_idx(device: _device_t, /) -> None:
    r"""Set the current device index to a given device.

    Args:
        device (:class:`torch.device`, str, int): a given device that must match the current
            :ref:`accelerator<accelerators>` device type.

    .. note:: This function is a no-op if this device index is negative.
    """
    device_index = _get_device_index(device)
    torch._C._accelerator_setDeviceIndex(device_index)


def current_stream(device: _device_t = None, /) -> torch.Stream:
    r"""Return the currently selected stream for a given device.

    Args:
        device (:class:`torch.device`, str, int, optional): a given device that must match the current
            :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_idx` by default.

    Returns:
        torch.Stream: the currently selected stream for a given device.
    """
    device_index = _get_device_index(device, True)
    return torch._C._accelerator_getStream(device_index)


def set_stream(stream: torch.Stream) -> None:
    r"""Set the current stream to a given stream.

    Args:
        stream (torch.Stream): a given stream that must match the current :ref:`accelerator<accelerators>` device type.

    .. note:: This function will set the current device index to the device index of the given stream.
    """
    torch._C._accelerator_setStream(stream)


def synchronize(device: _device_t = None, /) -> None:
    r"""Wait for all kernels in all streams on the given device to complete.

    Args:
        device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
            the current :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_idx` by default.

    .. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
        >>> start_event = torch.Event(enable_timing=True)
        >>> end_event = torch.Event(enable_timing=True)
        >>> start_event.record()
        >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator())
        >>> sum = torch.sum(tensor)
        >>> end_event.record()
        >>> torch.accelerator.synchronize()
        >>> elapsed_time_ms = start_event.elapsed_time(end_event)
    """
    device_index = _get_device_index(device, True)
    torch._C._accelerator_synchronizeDevice(device_index)


__all__ = [
    "current_accelerator",
    "current_device_idx",
    "current_stream",
    "device_count",
    "is_available",
    "set_device_idx",
    "set_stream",
    "synchronize",
]
```
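As a quick illustration of how the functions above compose (a hedged sketch, assuming at least one accelerator device is visible; not taken from the PR's diff):

```python
import torch

if torch.accelerator.is_available():
    # Bind subsequent work to device 0 of the current accelerator.
    torch.accelerator.set_device_idx(0)
    assert torch.accelerator.current_device_idx() == 0

    # current_stream() defaults to the current device index.
    stream = torch.accelerator.current_stream()

    # set_stream() also switches the current device index to the stream's device.
    torch.accelerator.set_stream(stream)

    # Block until all work queued on device 0 has finished.
    torch.accelerator.synchronize(0)
```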
torch/accelerator/_utils.py (new file, 28 lines)

```python
from typing import Optional, Union

import torch
from torch import device as _device


_device_t = Union[_device, str, int, None]


def _get_device_index(device: _device_t, optional: bool = False) -> int:
    if isinstance(device, int):
        return device
    if isinstance(device, str):
        device = torch.device(device)
    device_index: Optional[int] = None
    if isinstance(device, torch.device):
        if torch.accelerator.current_accelerator() != device.type:
            raise ValueError(
                f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
            )
        device_index = device.index
    if device_index is None:
        if not optional:
            raise ValueError(
                f"Expected a torch.device with a specified index or an integer, but got: {device}"
            )
        return torch.accelerator.current_device_idx()
    return device_index
```
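For reference, the helper normalizes the accepted argument forms as follows (an illustrative sketch, assuming the current accelerator is CUDA; `_get_device_index` is a private helper, not public API):

```python
import torch
from torch.accelerator._utils import _get_device_index

# Assumes torch.accelerator.current_accelerator() reports a CUDA device here.
_get_device_index(1)                         # -> 1: plain ints pass through unchanged
_get_device_index("cuda:0")                  # -> 0: strings are parsed via torch.device
_get_device_index(torch.device("cuda", 2))   # -> 2: explicit index is returned
_get_device_index(None, optional=True)       # -> current_device_idx() as the fallback
_get_device_index("xpu:0")                   # raises ValueError: device type mismatch
```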
torch/csrc/DeviceAccelerator.cpp (new file, 82 lines)

```cpp
#include <c10/core/DeviceGuard.h>
#include <torch/csrc/DeviceAccelerator.h>
#include <torch/csrc/utils/device_lazy_init.h>

namespace torch::accelerator {

void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  m.def("_accelerator_getAccelerator", []() {
    // If no accelerator is currently available, raise an exception.
    return c10::Device(at::getAccelerator(true).value());
  });

  m.def("_accelerator_deviceCount", []() {
    const auto device_type = at::getAccelerator(false);
    if (!device_type.has_value()) {
      return static_cast<c10::DeviceIndex>(0);
    }
    torch::utils::maybe_initialize_device(device_type.value());
    c10::impl::VirtualGuardImpl impl(device_type.value());
    return static_cast<c10::DeviceIndex>(impl.deviceCount());
  });

  m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
    const auto device_type = at::getAccelerator(true).value();
    // If device index is negative, no-op
    if (device_index < 0) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    c10::impl::VirtualGuardImpl impl(device_type);
    impl.setDevice({device_type, device_index});
  });

  m.def("_accelerator_getDeviceIndex", []() {
    const auto device_type = at::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    c10::impl::VirtualGuardImpl impl(device_type);
    return static_cast<c10::DeviceIndex>(impl.getDevice().index());
  });

  m.def("_accelerator_setStream", [](c10::Stream stream) {
    const auto device_type = at::getAccelerator(true).value();
    TORCH_CHECK(
        device_type == stream.device_type(),
        "stream's device type ",
        c10::DeviceTypeName(stream.device_type()),
        " doesn't match the current accelerator ",
        c10::DeviceTypeName(device_type));
    torch::utils::maybe_initialize_device(device_type);
    c10::impl::VirtualGuardImpl impl(device_type);
    // Set the current device to the device of stream
    if (impl.getDevice().index() != stream.device_index()) {
      impl.setDevice(stream.device());
    }
    impl.exchangeStream(stream);
  });

  m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
    const auto device_type = at::getAccelerator(true).value();
    torch::utils::maybe_initialize_device(device_type);
    c10::impl::VirtualGuardImpl impl(device_type);
    return impl.getStream({device_type, device_index});
  });

  m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
    const auto device_type = at::getAccelerator(true).value();
    if (!torch::utils::is_device_initialized(device_type)) {
      return;
    }
    torch::utils::maybe_initialize_device(device_type);
    c10::impl::VirtualGuardImpl impl(device_type);
    // impl.synchronizeDevice can be safely called from any device
    {
      py::gil_scoped_release no_gil;
      impl.synchronizeDevice(device_index);
    }
  });
}

} // namespace torch::accelerator
```
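One detail worth noting in `_accelerator_synchronizeDevice` above is the `py::gil_scoped_release`: the device-wide sync blocks without holding the GIL, so other Python threads keep running. A hedged illustration (assumes an accelerator is available and initialized):

```python
import threading
import torch

results = []

def cpu_worker():
    # Free to run while the main thread blocks inside synchronize(),
    # because the binding releases the GIL around the device-wide sync.
    results.append(sum(range(1_000_000)))

if torch.accelerator.is_available():
    torch.randn(1, device=torch.accelerator.current_accelerator())  # initialize the device
    t = threading.Thread(target=cpu_worker)
    t.start()
    torch.accelerator.synchronize()  # blocks, but does not stall cpu_worker
    t.join()
```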
torch/csrc/DeviceAccelerator.h (new file, 8 lines)

```cpp
#include <ATen/DeviceAccelerator.h>
#include <torch/csrc/utils/pybind.h>

namespace torch::accelerator {

void initModule(PyObject* module);

} // namespace torch::accelerator
```
```diff
@@ -42,6 +42,7 @@
 #include <ATen/ThreadLocalPythonObjects.h>
 #include <torch/csrc/DataLoader.h>
 #include <torch/csrc/Device.h>
+#include <torch/csrc/DeviceAccelerator.h>
 #include <torch/csrc/Dtype.h>
 #include <torch/csrc/DynamicTypes.h>
 #include <torch/csrc/Event.h>
```
```diff
@@ -1733,6 +1734,7 @@ PyObject* initModule() {
 #endif
   torch::mtia::initModule(module);
   torch::cpu::initModule(module);
+  torch::accelerator::initModule(module);
   torch::instruction_counter::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
```