Revert "Move get accelerator to use build time flags when possible (#146098)"

This reverts commit 17302b4bc8.

Reverted https://github.com/pytorch/pytorch/pull/146098 on behalf of https://github.com/albanD due to Still fails with cuda build on a non-gpu machine ([comment](https://github.com/pytorch/pytorch/pull/146098#issuecomment-2707191770))
PyTorch MergeBot 2025-03-07 18:59:58 +00:00
parent 1239176fe7
commit b246cd7b82
13 changed files with 54 additions and 144 deletions

View File

@ -5,53 +5,38 @@
namespace at::accelerator {
std::optional<c10::DeviceType> getAccelerator(bool checked) {
// 1. Check PrivateUse1 backends
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
// Note that this check only covers hook registration and thus does NOT initialize
// the device or poison fork.
#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
device_type = k##device_name; \
TORCH_CHECK( \
!is_accelerator_detected, \
"Cannot have ", \
device_type.value(), \
" with other accelerators."); \
is_accelerator_detected = true; \
}
if (is_privateuse1_backend_registered()) {
// We explicitly allow PrivateUse1 and another device at the same time as we
// use this for testing. Whenever a PrivateUse1 device is registered, use it
// first.
return kPrivateUse1;
}
// 2. Check runtime backends
// This state is temporary; these runtime checks should be moved to compile time
// once they provide the new isBuilt API and we are sure they're never in the
// same binary as another accelerator.
#define DETECT_RUNTIME_ACCELERATOR(device_name) \
if (at::has##device_name()) { \
return k##device_name; \
}
DETECT_RUNTIME_ACCELERATOR(MTIA)
#undef DETECT_RUNTIME_ACCELERATOR
// 2. Check compile-time backends
std::optional<c10::DeviceType> device_type = std::nullopt;
#define DETECT_AND_ASSIGN_ACCELERATOR_COMP(device_name) \
if (at::detail::get##device_name##Hooks().isBuilt()) { \
TORCH_CHECK( \
!device_type.has_value(), \
"Cannot have both " #device_name " and ", \
device_type.value(), "."); \
device_type = k##device_name; \
}
DETECT_AND_ASSIGN_ACCELERATOR_COMP(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(XPU)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
DETECT_AND_ASSIGN_ACCELERATOR_COMP(HPU)
bool is_accelerator_detected = false;
DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
DETECT_AND_ASSIGN_ACCELERATOR(XPU)
DETECT_AND_ASSIGN_ACCELERATOR(HIP)
DETECT_AND_ASSIGN_ACCELERATOR(MPS)
DETECT_AND_ASSIGN_ACCELERATOR(HPU)
if (checked) {
TORCH_CHECK(
device_type, "Cannot access accelerator device when none is available.")
}
return device_type;
#undef DETECT_AND_ASSIGN_ACCELERATOR_COMP
#undef DETECT_AND_ASSIGN_ACCELERATOR
}
bool isAccelerator(c10::DeviceType device_type) {

View File

@ -33,8 +33,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
bool hasROCM() const override;
const at::cuda::NVRTC& nvrtc() const override;
DeviceIndex current_device() const override;
bool isBuilt() const override {return true;}
bool isAvailable() const override {return hasCUDA();}
bool hasPrimaryContext(DeviceIndex device_index) const override;
Allocator* getCUDADeviceAllocator() const override;
Allocator* getPinnedMemoryAllocator() const override;

View File

@ -20,23 +20,6 @@ struct TORCH_API AcceleratorHooksInterface {
// squelch -Werror=non-virtual-dtor
virtual ~AcceleratorHooksInterface() = default;
// Whether this backend was enabled at compilation time.
// This function should NEVER throw.
virtual bool isBuilt() const {
return false;
}
// Whether this backend can be used at runtime, meaning it was built,
// its runtime dependencies (e.g. the driver) are available, and at least one
// supported device can be used.
// This function should NEVER throw. This function should NOT initialize the context
// on any device (result of hasPrimaryContext below should not change).
// While it is acceptable for this function to poison fork, it is
// recommended to avoid doing so whenever possible.
virtual bool isAvailable() const {
return false;
}
// Whether the device at device_index is fully initialized or not.
virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;

View File

@ -54,12 +54,7 @@ struct MPSHooks : public at::MPSHooksInterface {
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
const override;
bool isBuilt() const override {
return true;
}
bool isAvailable() const override {
return hasMPS();
}
// Compatibility with Accelerator API
bool hasPrimaryContext(DeviceIndex device_index) const override {
// When MPS is available, it is always in use for the one device.
return true;

View File

@ -84,14 +84,9 @@ bool XPUHooks::isPinnedPtr(const void* data) const {
sycl::get_pointer_type(data, c10::xpu::get_device_context());
}
bool XPUHooks::isAvailable() const {
return at::xpu::is_available();
}
bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
// The default context is utilized for each device.
// So it always returns true if a device is available.
return isAvailable();
// The default context is utilized for each device. So it always returns true.
return true;
}
DeviceIndex XPUHooks::deviceCount() const {

View File

@ -19,11 +19,6 @@ struct XPUHooks : public at::XPUHooksInterface {
DeviceIndex current_device() const override;
void deviceSynchronize(DeviceIndex device_index) const override;
Allocator* getPinnedMemoryAllocator() const override;
bool isBuilt() const override {
return true;
}
bool isAvailable() const override;
bool isPinnedPtr(const void* data) const override;
bool hasPrimaryContext(DeviceIndex device_index) const override;
DeviceIndex deviceCount() const override;

View File

@ -152,19 +152,7 @@ us to use the current accelerator as the default device for relevant concepts su
Stream device_type, FSDP, etc.
As of today, accelerator devices are (in no particular order) :doc:`"CUDA" <cuda>`, :doc:`"MTIA" <mtia>`,
:doc:`"XPU" <xpu>`, :doc:`"MPS" <mps>`, "HPU", and PrivateUse1 (many device not in the PyTorch repo itself).
Many tools in the PyTorch ecosystem use fork to create subprocesses (for example dataloading
or intra-op parallelism), so it is important to delay as much as possible any
operation that would prevent further forks. This is especially important here as most accelerators' initialization has such an effect.
In practice, keep in mind that checking :func:`torch.accelerator.current_accelerator`
is a compile-time check by default and is thus always fork-safe.
On the contrary, passing the ``check_available=True`` flag to this function or calling
:func:`torch.accelerator.is_available()` will usually prevent later forks.
Some backends provide an experimental opt-in option to make the runtime availability
check fork-safe; for example, ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` can be used with the
CUDA device.
:doc:`"XPU" <xpu>`, and PrivateUse1 (many device not in the PyTorch repo itself).
.. autosummary::
:toctree: generated
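
The fork-safety guidance above can be illustrated with a small sketch. This is not code from this commit; it assumes the pre-revert behavior where torch.accelerator.current_accelerator() only performs a compile-time check (and is therefore fork-safe) while torch.accelerator.is_available() performs a runtime check that may poison a later fork, and it assumes a Unix "fork" start method.

import multiprocessing as mp

import torch

def child(conn):
    # A compile-time-only query stays safe in a forked child; runtime device
    # initialization performed in the parent before the fork would not.
    conn.send(str(torch.accelerator.current_accelerator()))

if __name__ == "__main__":
    # Fork-safe by default (pre-revert behavior): reads build-time information only.
    print(torch.accelerator.current_accelerator())
    # Runtime check; may initialize the driver and prevent safe forking afterwards.
    # torch.accelerator.is_available()
    parent_conn, child_conn = mp.Pipe()
    proc = mp.get_context("fork").Process(target=child, args=(child_conn,))
    proc.start()
    print(parent_conn.recv())
    proc.join()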

View File

@ -3597,9 +3597,8 @@ def fork_and_check_is_pinned():
def worker(conn):
try:
x = torch.randn(10)
x.is_pinned()
dev = torch.accelerator.current_accelerator()
x = torch.ones(10, device=dev)[0].item()
x.is_pinned(device="cuda")
x = torch.ones(10, device="cuda")[0].item()
conn.send(x)
except Exception as e:
conn.send(str(e))
@ -3619,7 +3618,7 @@ def fork_and_check_is_pinned():
x = torch.randn(10)
# check that is_pinned won't poison future fork
x.is_pinned()
x.is_pinned(device="cuda")
ret = fork_and_check_is_pinned()
print(ret)

View File

@ -2,7 +2,6 @@ r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
"""
from typing import Optional
from typing_extensions import deprecated
import torch
@ -35,9 +34,7 @@ def device_count() -> int:
def is_available() -> bool:
r"""Check if the current accelerator is available at runtime: it was build, all the
required drivers are available and at least one device is visible.
See :ref:`accelerator<accelerators>` for details.
r"""Check if there is an available :ref:`accelerator<accelerators>`.
Returns:
bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.
@ -46,47 +43,35 @@ def is_available() -> bool:
>>> assert torch.accelerator.is_available(), "No available accelerators detected."
"""
# Why not just check "device_count() > 0" like other is_available calls?
# Because devices like CUDA have a Python implementation of is_available that is
# non-poisoning, and some features like the DataLoader rely on it.
# So we are careful to delegate to the Python version of the accelerator here.
acc = current_accelerator()
if acc is None:
return False
mod = torch.get_device_module(acc)
return mod.is_available()
return device_count() > 0
def current_accelerator(check_available: bool = False) -> Optional[torch.device]:
r"""Return the device of the accelerator available at compilation time.
If no accelerator was available at compilation time, returns None.
See :ref:`accelerator<accelerators>` for details.
Args:
check_available (bool, optional): if True, also run a runtime check
(:func:`torch.accelerator.is_available`) on top of the compile-time
check.
Default: ``False``
def current_accelerator() -> torch.device:
r"""Return the device of the current :ref:`accelerator<accelerators>`.
Returns:
torch.device: return the current accelerator as :class:`torch.device`.
.. note:: The index of the returned :class:`torch.device` will be ``None``; use
:func:`torch.accelerator.current_device_index` to know the current index being used.
Also make sure to use :func:`torch.accelerator.is_available` to check whether an
accelerator is available. If there is no available accelerator, this function will raise an exception.
Example::
>>> # xdoctest:
>>> # If an accelerator is available, send the model to it
>>> model = torch.nn.Linear(2, 2)
>>> if (current_device := current_accelerator(check_available=True)) is not None:
>>> model.to(current_device)
>>> if torch.accelerator.is_available():
>>> current_device = torch.accelerator.current_accelerator()
>>> else:
>>> current_device = torch.device("cpu")
>>> if current_device.type == 'cuda':
>>> is_half_supported = torch.cuda.has_half
>>> elif current_device.type == 'xpu':
>>> is_half_supported = torch.xpu.get_device_properties().has_fp16
>>> elif current_device.type == 'cpu':
>>> is_half_supported = True
"""
if (acc := torch._C._accelerator_getAccelerator()) is not None:
if (not check_available) or (check_available and is_available()):
return acc
return None
return torch._C._accelerator_getAccelerator()
def current_device_index() -> int:
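
As a usage note on the hunk above: with the revert applied, current_accelerator() takes no argument and raises when no accelerator is available, so callers that previously relied on check_available=True have to guard the call themselves. A minimal sketch; maybe_current_accelerator is a hypothetical helper name, not part of the API.

from typing import Optional

import torch

def maybe_current_accelerator() -> Optional[torch.device]:
    # With the restored API, current_accelerator() raises if no accelerator is
    # available, so check availability first (note that this runtime check may
    # initialize the device and is not fork-safe).
    if torch.accelerator.is_available():
        return torch.accelerator.current_accelerator()
    return None

device = maybe_current_accelerator() or torch.device("cpu")
model = torch.nn.Linear(2, 2).to(device)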

View File

@ -11,10 +11,7 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
device = torch.device(device)
device_index: Optional[int] = None
if isinstance(device, torch.device):
acc = torch.accelerator.current_accelerator()
if acc is None:
raise RuntimeError("Accelerator expected")
if acc.type != device.type:
if torch.accelerator.current_accelerator().type != device.type:
raise ValueError(
f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
)

View File

@ -6,14 +6,9 @@ namespace torch::accelerator {
void initModule(PyObject* module) {
auto m = py::handle(module).cast<py::module>();
m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
// If no accelerator was available at compile time, return None.
auto acc = at::getAccelerator(false);
if (acc.has_value()) {
return acc.value();
} else {
return std::nullopt;
}
m.def("_accelerator_getAccelerator", []() {
// If no accelerator is currently available, raise an exception.
return c10::Device(at::getAccelerator(true).value());
});
m.def("_accelerator_deviceCount", []() {

View File

@ -1246,13 +1246,9 @@ def _save(
if (
config.save.use_pinned_memory_for_d2h
and (
acc := torch.accelerator.current_accelerator(
check_available=True
)
)
is not None
and acc.type == storage.device.type
and torch.accelerator.is_available()
and torch.accelerator.current_accelerator().type
== storage.device.type
):
new_storage = torch.empty(
num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True

View File

@ -672,8 +672,7 @@ class _BaseDataLoaderIter:
# memory allocation for MPS is fixed.
if (
self._pin_memory
and (acc := torch.accelerator.current_accelerator()) is not None
and acc.type == "mps"
and torch.accelerator.current_accelerator().type == "mps"
):
self._pin_memory = False
warn_msg = (