diff --git a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
index 6be1aed915e..4ec944034be 100644
--- a/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
+++ b/aten/src/ATen/hip/impl/HIPGuardImplMasqueradingAsCUDA.h
@@ -216,6 +216,15 @@ struct HIPGuardImplMasqueradingAsCUDA final : public c10::impl::DeviceGuardImplI
     C10_HIP_CHECK(hipEventSynchronize(hip_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    int orig_device{-1};
+    C10_HIP_CHECK(hipGetDevice(&orig_device));
+    C10_HIP_CHECK(hipSetDevice(device_index));
+    C10_HIP_CHECK(hipDeviceSynchronize());
+    C10_HIP_CHECK(hipSetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(
       const c10::DataPtr& data_ptr,
       const Stream& stream) const override {
diff --git a/aten/src/ATen/mps/MPSGuardImpl.h b/aten/src/ATen/mps/MPSGuardImpl.h
index cb50df2faea..6132cd8055e 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.h
+++ b/aten/src/ATen/mps/MPSGuardImpl.h
@@ -111,6 +111,8 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
 
   bool queryEvent(void* event) const override;
 
+  void synchronizeDevice(const DeviceIndex device_index) const override;
+
 };
 
 /// A variant of OptionalDeviceGuard that is specialized for MPS.
diff --git a/aten/src/ATen/mps/MPSGuardImpl.mm b/aten/src/ATen/mps/MPSGuardImpl.mm
index f832516c5da..a3dea4cd7c4 100644
--- a/aten/src/ATen/mps/MPSGuardImpl.mm
+++ b/aten/src/ATen/mps/MPSGuardImpl.mm
@@ -42,4 +42,8 @@ bool MPSGuardImpl::queryEvent(void* event) const {
   return mps_event->query();
 }
 
+void MPSGuardImpl::synchronizeDevice(const DeviceIndex device_index) const {
+  at::mps::getDefaultMPSStream()->synchronize(SyncType::COMMIT_AND_WAIT);
+}
+
 } // namespace at::mps
diff --git a/build_variables.bzl b/build_variables.bzl
index 56f7bb6cf5a..36154e93e50 100644
--- a/build_variables.bzl
+++ b/build_variables.bzl
@@ -795,6 +795,7 @@ libtorch_python_xpu_sources = [
 
 libtorch_python_core_sources = [
     "torch/csrc/DataLoader.cpp",
+    "torch/csrc/DeviceAccelerator.cpp",
     "torch/csrc/Device.cpp",
     "torch/csrc/Dtype.cpp",
     "torch/csrc/DynamicTypes.cpp",
diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h
index a9b9b1219df..f145db0d234 100644
--- a/c10/core/impl/DeviceGuardImplInterface.h
+++ b/c10/core/impl/DeviceGuardImplInterface.h
@@ -212,6 +212,15 @@ struct C10_API DeviceGuardImplInterface {
     TORCH_CHECK(false, "Backend doesn't support synchronizing events.");
   }
 
+  /**
+   * Wait (by blocking the calling thread) until all the work previously
+   * enqueued on the device has been completed.
+   */
+  virtual void synchronizeDevice(const DeviceIndex /*device_index*/) const {
+    TORCH_CHECK(
+        false, "Backend doesn't support synchronizing all streams on device.");
+  }
+
   /**
    * Ensure the caching allocator (if any) is aware that the given DataPtr is
    * being used on the given stream, and that it should thus avoid recycling the
diff --git a/c10/core/impl/VirtualGuardImpl.h b/c10/core/impl/VirtualGuardImpl.h
index 1d26eef0c9e..cdee2aa1a64 100644
--- a/c10/core/impl/VirtualGuardImpl.h
+++ b/c10/core/impl/VirtualGuardImpl.h
@@ -96,6 +96,10 @@ class VirtualGuardImpl final : public DeviceGuardImplInterface {
     return impl_->synchronizeEvent(event);
   }
 
+  void synchronizeDevice(const DeviceIndex device_index) const override {
+    return impl_->synchronizeDevice(device_index);
+  }
+
  private:
   const DeviceGuardImplInterface* impl_ = nullptr;
 };
diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h
index 1ef2fcb2c08..dd81dcf51fd 100644
--- a/c10/cuda/impl/CUDAGuardImpl.h
+++ b/c10/cuda/impl/CUDAGuardImpl.h
@@ -219,6 +219,19 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     C10_CUDA_CHECK(cudaEventSynchronize(cuda_event));
   }
 
+  // Note: synchronizeDevice can be safely called from any device
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    DeviceIndex orig_device{-1};
+    C10_CUDA_CHECK(c10::cuda::GetDevice(&orig_device));
+    C10_CUDA_CHECK(c10::cuda::SetDevice(device_index));
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kCUDA);
+    }
+    C10_CUDA_CHECK(cudaDeviceSynchronize());
+    C10_CUDA_CHECK(c10::cuda::SetDevice(orig_device));
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     CUDAStream cuda_stream{stream};
diff --git a/c10/xpu/impl/XPUGuardImpl.h b/c10/xpu/impl/XPUGuardImpl.h
index 6213eccd2b2..5cb60a6a850 100644
--- a/c10/xpu/impl/XPUGuardImpl.h
+++ b/c10/xpu/impl/XPUGuardImpl.h
@@ -163,6 +163,14 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
     xpu_event->wait_and_throw();
   }
 
+  void synchronizeDevice(const c10::DeviceIndex device_index) const override {
+    const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
+    if (C10_UNLIKELY(interp)) {
+      (*interp)->trace_gpu_device_synchronization(c10::kXPU);
+    }
+    c10::xpu::syncStreamsOnDevice(device_index);
+  }
+
   void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
       const override {
     const XPUStream xpu_stream{stream};
diff --git a/docs/source/accelerator.rst b/docs/source/accelerator.rst
new file mode 100644
index 00000000000..6e4d7a541ee
--- /dev/null
+++ b/docs/source/accelerator.rst
@@ -0,0 +1,17 @@
+torch.accelerator
+===================================
+.. automodule:: torch.accelerator
+.. currentmodule:: torch.accelerator
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    device_count
+    is_available
+    current_accelerator
+    set_device_idx
+    current_device_idx
+    set_stream
+    current_stream
+    synchronize
diff --git a/docs/source/index.rst b/docs/source/index.rst
index 773e6420429..61325ff0ba8 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -64,6 +64,7 @@ Features described in this documentation are classified by release status:
    torch.amp
    torch.autograd
    torch.library
+   accelerator
    cpu
    cuda
    torch.cuda.memory
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index 4fec36e8e65..930e4be2420 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -2183,6 +2183,15 @@ def _set_worker_pids(
 def _remove_worker_pids(loader_id: _int) -> None: ...  # THPModule_removeWorkerPIDs
 def _error_if_any_worker_fails() -> None: ...  # THPModule_errorIfAnyWorkerFails
 
+# Defined in torch/csrc/DeviceAccelerator.cpp
+def _accelerator_getAccelerator() -> _device: ...
+def _accelerator_deviceCount() -> _int: ...
+def _accelerator_setDeviceIndex(device_index: _int) -> None: ...
+def _accelerator_getDeviceIndex() -> _int: ...
+def _accelerator_setStream(stream: Stream) -> None: ...
+def _accelerator_getStream(device_index: _int) -> Stream: ...
+def _accelerator_synchronizeDevice(device_index: _int) -> None: ...
+
 # Defined in torch/csrc/jit/python/python_tracer.cpp
 class TracingState:
     def push_scope(self, scope_name: str) -> None: ...
diff --git a/torch/__init__.py b/torch/__init__.py
index 5ff3c610abf..144af1f508e 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2092,6 +2092,7 @@ from torch import (
     __config__ as __config__,
     __future__ as __future__,
     _awaits as _awaits,
+    accelerator as accelerator,
     autograd as autograd,
     backends as backends,
     cpu as cpu,
diff --git a/torch/accelerator/__init__.py b/torch/accelerator/__init__.py
new file mode 100644
index 00000000000..f4d7593175b
--- /dev/null
+++ b/torch/accelerator/__init__.py
@@ -0,0 +1,145 @@
+r"""
+This package introduces support for the current :ref:`accelerator` in Python.
+"""
+
+import torch
+
+from ._utils import _device_t, _get_device_index
+
+
+def device_count() -> int:
+    r"""Return the number of devices available for the current :ref:`accelerator`.
+
+    Returns:
+        int: the number of devices available for the current :ref:`accelerator`.
+            If no accelerator is available, return 0.
+    """
+    return torch._C._accelerator_deviceCount()
+
+
+def is_available() -> bool:
+    r"""Check if there is an available :ref:`accelerator`.
+
+    Returns:
+        bool: A boolean indicating if there is an available :ref:`accelerator`.
+
+    Example::
+
+        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
+    """
+    return device_count() > 0
+
+
+def current_accelerator() -> torch.device:
+    r"""Return the device of the current :ref:`accelerator`.
+
+    Returns:
+        torch.device: the current accelerator as a :class:`torch.device`.
+
+    .. note:: The index of the returned :class:`torch.device` will be ``None``; use
+        :func:`torch.accelerator.current_device_idx` to query the index of the device currently in use.
+        Make sure to call :func:`torch.accelerator.is_available` first to check whether an available
+        accelerator exists. If there is no available accelerator, this function will raise an exception.
+
+    Example::
+
+        >>> # xdoctest:
+        >>> if torch.accelerator.is_available():
+        >>>     current_device = torch.accelerator.current_accelerator()
+        >>> else:
+        >>>     current_device = torch.device("cpu")
+        >>> if current_device.type == 'cuda':
+        >>>     is_half_supported = torch.cuda.has_half
+        >>> elif current_device.type == 'xpu':
+        >>>     is_half_supported = torch.xpu.get_device_properties().has_fp16
+        >>> elif current_device.type == 'cpu':
+        >>>     is_half_supported = True
+    """
+    return torch._C._accelerator_getAccelerator()
+
+
+def current_device_idx() -> int:
+    r"""Return the index of the currently selected device for the current :ref:`accelerator`.
+
+    Returns:
+        int: the index of the currently selected device.
+    """
+    return torch._C._accelerator_getDeviceIndex()
+
+
+def set_device_idx(device: _device_t, /) -> None:
+    r"""Set the current device index to a given device.
+
+    Args:
+        device (:class:`torch.device`, str, int): a given device that must match the current
+            :ref:`accelerator` device type.
+
+    .. note:: This function is a no-op if the given device index is negative.
+    """
+    device_index = _get_device_index(device)
+    torch._C._accelerator_setDeviceIndex(device_index)
+
+
+def current_stream(device: _device_t = None, /) -> torch.Stream:
+    r"""Return the currently selected stream for a given device.
+
+    Args:
+        device (:class:`torch.device`, str, int, optional): a given device that must match the current
+            :ref:`accelerator` device type. If not given,
+            use :func:`torch.accelerator.current_device_idx` by default.
+
+    Returns:
+        torch.Stream: the currently selected stream for a given device.
+    """
+    device_index = _get_device_index(device, True)
+    return torch._C._accelerator_getStream(device_index)
+
+
+def set_stream(stream: torch.Stream) -> None:
+    r"""Set the current stream to a given stream.
+
+    Args:
+        stream (torch.Stream): a given stream that must match the current :ref:`accelerator` device type.
+
+    .. note:: This function will set the current device index to the device index of the given stream.
+    """
+    torch._C._accelerator_setStream(stream)
+
+
+def synchronize(device: _device_t = None, /) -> None:
+    r"""Wait for all kernels in all streams on the given device to complete.
+
+    Args:
+        device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
+            the current :ref:`accelerator` device type. If not given,
+            use :func:`torch.accelerator.current_device_idx` by default.
+
+    .. note:: This function is a no-op if the current :ref:`accelerator` is not initialized.
+
+    Example::
+
+        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
+        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
+        >>> start_event = torch.Event(enable_timing=True)
+        >>> end_event = torch.Event(enable_timing=True)
+        >>> start_event.record()
+        >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator())
+        >>> sum = torch.sum(tensor)
+        >>> end_event.record()
+        >>> torch.accelerator.synchronize()
+        >>> elapsed_time_ms = start_event.elapsed_time(end_event)
+    """
+    device_index = _get_device_index(device, True)
+    torch._C._accelerator_synchronizeDevice(device_index)
+
+
+__all__ = [
+    "current_accelerator",
+    "current_device_idx",
+    "current_stream",
+    "device_count",
+    "is_available",
+    "set_device_idx",
+    "set_stream",
+    "synchronize",
+]
diff --git a/torch/accelerator/_utils.py b/torch/accelerator/_utils.py
new file mode 100644
index 00000000000..abaa00c44b5
--- /dev/null
+++ b/torch/accelerator/_utils.py
@@ -0,0 +1,28 @@
+from typing import Optional, Union
+
+import torch
+from torch import device as _device
+
+
+_device_t = Union[_device, str, int, None]
+
+
+def _get_device_index(device: _device_t, optional: bool = False) -> int:
+    if isinstance(device, int):
+        return device
+    if isinstance(device, str):
+        device = torch.device(device)
+    device_index: Optional[int] = None
+    if isinstance(device, torch.device):
+        if torch.accelerator.current_accelerator().type != device.type:
+            raise ValueError(
+                f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
+            )
+        device_index = device.index
+    if device_index is None:
+        if not optional:
+            raise ValueError(
+                f"Expected a torch.device with a specified index or an integer, but got: {device}"
+            )
+        return torch.accelerator.current_device_idx()
+    return device_index
diff --git a/torch/csrc/DeviceAccelerator.cpp b/torch/csrc/DeviceAccelerator.cpp
new file mode 100644
index 00000000000..67bd30acbf4
--- /dev/null
+++ b/torch/csrc/DeviceAccelerator.cpp
@@ -0,0 +1,82 @@
+#include
+#include
+#include
+
+namespace torch::accelerator {
+
+void initModule(PyObject* module) {
+  auto m = py::handle(module).cast<py::module>();
+
+  m.def("_accelerator_getAccelerator", []() {
+    // If no accelerator is currently available, raise an exception.
+    return c10::Device(at::getAccelerator(true).value());
+  });
+
+  m.def("_accelerator_deviceCount", []() {
+    const auto device_type = at::getAccelerator(false);
+    if (!device_type.has_value()) {
+      return static_cast<c10::DeviceIndex>(0);
+    }
+    torch::utils::maybe_initialize_device(device_type.value());
+    c10::impl::VirtualGuardImpl impl(device_type.value());
+    return static_cast<c10::DeviceIndex>(impl.deviceCount());
+  });
+
+  m.def("_accelerator_setDeviceIndex", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::getAccelerator(true).value();
+    // If device index is negative, no-op
+    if (device_index < 0) {
+      return;
+    }
+    torch::utils::maybe_initialize_device(device_type);
+    c10::impl::VirtualGuardImpl impl(device_type);
+    impl.setDevice({device_type, device_index});
+  });
+
+  m.def("_accelerator_getDeviceIndex", []() {
+    const auto device_type = at::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    c10::impl::VirtualGuardImpl impl(device_type);
+    return static_cast<c10::DeviceIndex>(impl.getDevice().index());
+  });
+
+  m.def("_accelerator_setStream", [](c10::Stream stream) {
+    const auto device_type = at::getAccelerator(true).value();
+    TORCH_CHECK(
+        device_type == stream.device_type(),
+        "stream's device type ",
+        c10::DeviceTypeName(stream.device_type()),
+        " doesn't match the current accelerator ",
+        c10::DeviceTypeName(device_type));
+    torch::utils::maybe_initialize_device(device_type);
+    c10::impl::VirtualGuardImpl impl(device_type);
+    // Set the current device to the device of the given stream
+    if (impl.getDevice().index() != stream.device_index()) {
+      impl.setDevice(stream.device());
+    }
+    impl.exchangeStream(stream);
+  });
+
+  m.def("_accelerator_getStream", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::getAccelerator(true).value();
+    torch::utils::maybe_initialize_device(device_type);
+    c10::impl::VirtualGuardImpl impl(device_type);
+    return impl.getStream({device_type, device_index});
+  });
+
+  m.def("_accelerator_synchronizeDevice", [](c10::DeviceIndex device_index) {
+    const auto device_type = at::getAccelerator(true).value();
+    if (!torch::utils::is_device_initialized(device_type)) {
+      return;
+    }
+    torch::utils::maybe_initialize_device(device_type);
+    c10::impl::VirtualGuardImpl impl(device_type);
+    // impl.synchronizeDevice can be safely called from any device
+    {
+      py::gil_scoped_release no_gil;
+      impl.synchronizeDevice(device_index);
+    }
+  });
+}
+
+} // namespace torch::accelerator
diff --git a/torch/csrc/DeviceAccelerator.h b/torch/csrc/DeviceAccelerator.h
new file mode 100644
index 00000000000..87b20e4576f
--- /dev/null
+++ b/torch/csrc/DeviceAccelerator.h
@@ -0,0 +1,8 @@
+#include
+#include
+
+namespace torch::accelerator {
+
+void initModule(PyObject* module);
+
+} // namespace torch::accelerator
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 416e5b5d72b..e11294418a3 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -42,6 +42,7 @@
 #include
 #include
 #include
+#include <torch/csrc/DeviceAccelerator.h>
 #include
 #include
 #include
@@ -1733,6 +1734,7 @@ PyObject* initModule() {
 #endif
   torch::mtia::initModule(module);
   torch::cpu::initModule(module);
+  torch::accelerator::initModule(module);
   torch::instruction_counter::initModule(module);
   torch::initVerboseBindings(module);
   ASSERT_TRUE(THPStorage_init(module));
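
Usage sketch (not part of the diff): a minimal end-to-end example of the Python API added above, pieced together from the docstrings in torch/accelerator/__init__.py. It assumes a PyTorch build that contains this patch; whether an accelerator is actually detected depends on the backend (CUDA, XPU, or MPS) the build ships with, so the example falls back to CPU when none is available.

    import torch

    # Pick the current accelerator, or fall back to CPU if none is available.
    if torch.accelerator.is_available():
        device = torch.accelerator.current_accelerator()
    else:
        device = torch.device("cpu")
    print(f"device: {device}, accelerator device count: {torch.accelerator.device_count()}")

    if torch.accelerator.is_available():
        # Select device index 0 and inspect the stream associated with it.
        torch.accelerator.set_device_idx(0)
        print("current index:", torch.accelerator.current_device_idx())
        print("current stream:", torch.accelerator.current_stream())

        # Time a small kernel; synchronize() blocks the calling thread until
        # all work queued on the current device has completed.
        start = torch.Event(enable_timing=True)
        end = torch.Event(enable_timing=True)
        start.record()
        total = torch.randn(100, device=device).sum()
        end.record()
        torch.accelerator.synchronize()
        print(f"elapsed: {start.elapsed_time(end):.3f} ms, sum: {total.item()}")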