Revert "Move get accelerator to use build time flags when possible (#146098)"
This reverts commit 17302b4bc8.
Reverted https://github.com/pytorch/pytorch/pull/146098 on behalf of https://github.com/albanD due to Still fails with cuda build on a non-gpu machine ([comment](https://github.com/pytorch/pytorch/pull/146098#issuecomment-2707191770))
This commit is contained in:
parent 1239176fe7
commit b246cd7b82
@ -5,53 +5,38 @@
 namespace at::accelerator {

 std::optional<c10::DeviceType> getAccelerator(bool checked) {
-  // 1. Check PrivateUse1 backends
-  // We explicitly allow PrivateUse1 and another device at the same time as we
-  // use this for testing. Whenever a PrivateUse1 device is registered, use it
-  // first.
-  // Note that this check is only for hook registration and thus is NOT initializing
-  // the device or poisoning fork.
+#define DETECT_AND_ASSIGN_ACCELERATOR(device_name) \
+  if (at::has##device_name()) {                    \
+    device_type = k##device_name;                  \
+    TORCH_CHECK(                                   \
+        !is_accelerator_detected,                  \
+        "Cannot have ",                            \
+        device_type.value(),                       \
+        " with other accelerators.");              \
+    is_accelerator_detected = true;                \
+  }
+
   if (is_privateuse1_backend_registered()) {
+    // We explicitly allow PrivateUse1 and another device at the same time as we
+    // use this for testing. Whenever a PrivateUse1 device is registered, use it
+    // first.
     return kPrivateUse1;
   }

-  // 2. Check runtime backends
-  // This state is temporary, these runtime checks should be moved to compile-time
-  // once they provide the new isBuilt API and we are sure they're never in the
-  // same binary as another accelerator.
-#define DETECT_RUNTIME_ACCELERATOR(device_name) \
-  if (at::has##device_name()) {                 \
-    return k##device_name;                      \
-  }
-
-  DETECT_RUNTIME_ACCELERATOR(MTIA)
-
-#undef DETECT_RUNTIME_ACCELERATOR
-
-  // 2. Check compile-time backends
   std::optional<c10::DeviceType> device_type = std::nullopt;
-
-#define DETECT_AND_ASSIGN_ACCELERATOR_COMP(device_name)  \
-  if (at::detail::get##device_name##Hooks().isBuilt()) { \
-    TORCH_CHECK(                                         \
-        !device_type.has_value(),                        \
-        "Cannot have both " #device_name " and ",        \
-        device_type.value(), ".");                       \
-    device_type = k##device_name;                        \
-  }
-
-  DETECT_AND_ASSIGN_ACCELERATOR_COMP(CUDA)
-  DETECT_AND_ASSIGN_ACCELERATOR_COMP(XPU)
-  DETECT_AND_ASSIGN_ACCELERATOR_COMP(HIP)
-  DETECT_AND_ASSIGN_ACCELERATOR_COMP(MPS)
-  DETECT_AND_ASSIGN_ACCELERATOR_COMP(HPU)
+  bool is_accelerator_detected = false;
+  DETECT_AND_ASSIGN_ACCELERATOR(CUDA)
+  DETECT_AND_ASSIGN_ACCELERATOR(MTIA)
+  DETECT_AND_ASSIGN_ACCELERATOR(XPU)
+  DETECT_AND_ASSIGN_ACCELERATOR(HIP)
+  DETECT_AND_ASSIGN_ACCELERATOR(MPS)
+  DETECT_AND_ASSIGN_ACCELERATOR(HPU)
   if (checked) {
     TORCH_CHECK(
         device_type, "Cannot access accelerator device when none is available.")
   }
   return device_type;
-
-#undef DETECT_AND_ASSIGN_ACCELERATOR_COMP
+#undef DETECT_AND_ASSIGN_ACCELERATOR
 }

 bool isAccelerator(c10::DeviceType device_type) {
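For orientation, a minimal Python-side sketch of what this restored runtime detection drives (a hypothetical session; the outcome depends on how the binary was built and which devices are actually visible):

    import torch

    # With the restored code, torch.accelerator.current_accelerator() goes through the
    # runtime at::has<Device>() detection above and raises if nothing is detected,
    # so guard it with the availability check.
    if torch.accelerator.is_available():
        acc = torch.accelerator.current_accelerator()
        print(f"Detected accelerator: {acc}")  # e.g. device(type='cuda')
    else:
        print("No accelerator detected; falling back to CPU")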
@ -33,8 +33,6 @@ struct CUDAHooks : public at::CUDAHooksInterface {
   bool hasROCM() const override;
   const at::cuda::NVRTC& nvrtc() const override;
   DeviceIndex current_device() const override;
-  bool isBuilt() const override {return true;}
-  bool isAvailable() const override {return hasCUDA();}
   bool hasPrimaryContext(DeviceIndex device_index) const override;
   Allocator* getCUDADeviceAllocator() const override;
   Allocator* getPinnedMemoryAllocator() const override;
@ -20,23 +20,6 @@ struct TORCH_API AcceleratorHooksInterface {
   // squelch -Werror=non-virtual-dtor
   virtual ~AcceleratorHooksInterface() = default;

-  // Whether this backend was enabled at compilation time.
-  // This function should NEVER throw.
-  virtual bool isBuilt() const {
-    return false;
-  }
-
-  // Whether this backend can be used at runtime, meaning it was built,
-  // its runtime dependencies are available (driver) and at least one
-  // supported device can be used.
-  // This function should NEVER throw. This function should NOT initialize the context
-  // on any device (result of hasPrimaryContext below should not change).
-  // While it is acceptable for this function to poison fork, it is
-  // recommended to avoid doing so whenever possible.
-  virtual bool isAvailable() const {
-    return false;
-  }
-
   // Whether the device at device_index is fully initialized or not.
   virtual bool hasPrimaryContext(DeviceIndex device_index) const = 0;
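The removed isBuilt()/isAvailable() pair mirrors a distinction that already exists on the Python side. A small sketch of that distinction, assuming a CUDA build and using the existing public checks torch.backends.cuda.is_built() and torch.cuda.is_available():

    import torch

    # "Built": the backend was compiled into this binary. This never throws and does
    # not touch the driver, so it is safe to call before forking subprocesses.
    cuda_built = torch.backends.cuda.is_built()

    # "Available": the backend is built, the driver is usable, and at least one
    # device is visible. Depending on the backend this may initialize the driver
    # (and therefore can poison a later fork).
    cuda_available = torch.cuda.is_available()

    print(f"built={cuda_built} available={cuda_available}")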
@ -54,12 +54,7 @@ struct MPSHooks : public at::MPSHooksInterface {
   double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
       const override;

-  bool isBuilt() const override {
-    return true;
-  }
-  bool isAvailable() const override {
-    return hasMPS();
-  }
   // Compatibility with Accelerator API
   bool hasPrimaryContext(DeviceIndex device_index) const override {
     // When MPS is available, it is always in use for the one device.
     return true;
@ -84,14 +84,9 @@ bool XPUHooks::isPinnedPtr(const void* data) const {
       sycl::get_pointer_type(data, c10::xpu::get_device_context());
 }

-bool XPUHooks::isAvailable() const {
-  return at::xpu::is_available();
-}
-
 bool XPUHooks::hasPrimaryContext(DeviceIndex device_index) const {
-  // The default context is utilized for each device.
-  // So it always returns true if a device is available.
-  return isAvailable();
+  // The default context is utilized for each device. So it always returns true.
+  return true;
 }

 DeviceIndex XPUHooks::deviceCount() const {
@ -19,11 +19,6 @@ struct XPUHooks : public at::XPUHooksInterface {
   DeviceIndex current_device() const override;
   void deviceSynchronize(DeviceIndex device_index) const override;
   Allocator* getPinnedMemoryAllocator() const override;
-
-  bool isBuilt() const override {
-    return true;
-  }
-  bool isAvailable() const override;
   bool isPinnedPtr(const void* data) const override;
   bool hasPrimaryContext(DeviceIndex device_index) const override;
   DeviceIndex deviceCount() const override;
@ -152,19 +152,7 @@ us to use the current accelerator as the default device for relevant concepts su
 Stream device_type, FSDP, etc.

 As of today, accelerator devices are (in no particular order) :doc:`"CUDA" <cuda>`, :doc:`"MTIA" <mtia>`,
-:doc:`"XPU" <xpu>`, :doc:`"MPS" <mps>`, "HPU", and PrivateUse1 (many device not in the PyTorch repo itself).
-
-Many tools in the PyTorch Ecosystem use fork to create subprocesses (for example dataloading
-or intra-op parallelism), it is thus important to delay as much as possible any
-operation that would prevent further forks. This is especially important here as most accelerator's initialization has such effect.
-In practice, you should keep in mind that checking :func:`torch.accelerator.current_accelerator`
-is a compile-time check by default, it is thus always fork-safe.
-On the contrary, passing the ``check_available=True`` flag to this function or calling
-:func:`torch.accelerator.is_available()` will usually prevent later fork.
-
-Some backends provide an experimental opt-in option to make the runtime availability
-check fork-safe. When using the CUDA device ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` can be
-used for example.
+:doc:`"XPU" <xpu>`, and PrivateUse1 (many device not in the PyTorch repo itself).

 .. autosummary::
     :toctree: generated
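The removed paragraphs are about fork safety. A small sketch of the pattern they describe, assuming a CUDA build; ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` is the opt-in named in the removed text and must be set before the availability check runs:

    import multiprocessing as mp
    import os

    # Opt in to the NVML-based availability check so torch.cuda.is_available()
    # does not create a CUDA context in the parent process.
    os.environ["PYTORCH_NVML_BASED_CUDA_CHECK"] = "1"

    import torch  # noqa: E402


    def child() -> None:
        # Because the parent never initialized CUDA, the forked child can still do so.
        x = torch.ones(4, device="cuda" if torch.cuda.is_available() else "cpu")
        print("child device:", x.device)


    if __name__ == "__main__":
        # Fork-safe availability query in the parent (NVML-based, no CUDA context).
        print("parent sees CUDA:", torch.cuda.is_available())
        p = mp.get_context("fork").Process(target=child)  # fork is POSIX-only
        p.start()
        p.join()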
@ -3597,9 +3597,8 @@ def fork_and_check_is_pinned():
     def worker(conn):
         try:
             x = torch.randn(10)
-            x.is_pinned()
-            dev = torch.accelerator.current_accelerator()
-            x = torch.ones(10, device=dev)[0].item()
+            x.is_pinned(device="cuda")
+            x = torch.ones(10, device="cuda")[0].item()
             conn.send(x)
         except Exception as e:
             conn.send(str(e))
@ -3619,7 +3618,7 @@ def fork_and_check_is_pinned():

     x = torch.randn(10)
     # check that is_pinned won't poison future fork
-    x.is_pinned()
+    x.is_pinned(device="cuda")
     ret = fork_and_check_is_pinned()
     print(ret)

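The test above relies on an is_pinned() query alone not initializing CUDA. A minimal way to observe that directly (hypothetical session, assuming a CUDA build; the surrounding test verifies the same property via a fork):

    import torch

    x = torch.randn(10)
    print(torch.cuda.is_initialized())  # False: nothing has touched CUDA yet

    # Querying pinned-ness goes through the pinned-memory allocator check; if it
    # initialized the device it would "poison" any later fork().
    x.is_pinned(device="cuda")
    print(torch.cuda.is_initialized())  # should remain False if the query did not
                                        # initialize the device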
@ -2,7 +2,6 @@ r"""
|
|||
This package introduces support for the current :ref:`accelerator<accelerators>` in python.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import torch
|
||||
|
|
@ -35,9 +34,7 @@ def device_count() -> int:


 def is_available() -> bool:
-    r"""Check if the current accelerator is available at runtime: it was build, all the
-    required drivers are available and at least one device is visible.
-    See :ref:`accelerator<accelerators>` for details.
+    r"""Check if there is an available :ref:`accelerator<accelerators>`.

     Returns:
         bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.
@ -46,47 +43,35 @@ def is_available() -> bool:

         >>> assert torch.accelerator.is_available() "No available accelerators detected."
     """
-    # Why not just check "device_count() > 0" like other is_available call?
-    # Because device like CUDA have a python implementation of is_available that is
-    # non-poisoning and some features like Dataloader rely on it.
-    # So we are careful to delegate to the Python version of the accelerator here
-    acc = current_accelerator()
-    if acc is None:
-        return False
-
-    mod = torch.get_device_module(acc)
-    return mod.is_available()
+    return device_count() > 0


-def current_accelerator(check_available: bool = False) -> Optional[torch.device]:
-    r"""Return the device of the accelerator available at compilation time.
-    If no accelerator were available at compilation time, returns None.
-    See :ref:`accelerator<accelerators>` for details.
-
-    Args:
-        check_available (bool, optional): if True, will also do a runtime check to see
-            if the device :func:`torch.accelerator.is_available` on top of the compile-time
-            check.
-            Default: ``False``
+def current_accelerator() -> torch.device:
+    r"""Return the device of the current :ref:`accelerator<accelerators>`.

     Returns:
         torch.device: return the current accelerator as :class:`torch.device`.

     .. note:: The index of the returned :class:`torch.device` will be ``None``, please use
         :func:`torch.accelerator.current_device_index` to know the current index being used.
+        And ensure to use :func:`torch.accelerator.is_available` to check if there is an available
+        accelerator. If there is no available accelerator, this function will raise an exception.

     Example::

         >>> # xdoctest:
-        >>> # If an accelerator is available, sent the model to it
-        >>> model = torch.nn.Linear(2, 2)
-        >>> if (current_device := current_accelerator(check_available=True)) is not None:
-        >>>     model.to(current_device)
+        >>> if torch.accelerator.is_available():
+        >>>     current_device = torch.accelerator.current_accelerator()
+        >>> else:
+        >>>     current_device = torch.device("cpu")
+        >>> if current_device.type == 'cuda':
+        >>>     is_half_supported = torch.cuda.has_half
+        >>> elif current_device.type == 'xpu':
+        >>>     is_half_supported = torch.xpu.get_device_properties().has_fp16
+        >>> elif current_device.type == 'cpu':
+        >>>     is_half_supported = True
     """
-    if (acc := torch._C._accelerator_getAccelerator()) is not None:
-        if (not check_available) or (check_available and is_available()):
-            return acc
-    return None
+    return torch._C._accelerator_getAccelerator()


 def current_device_index() -> int:
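A short usage sketch of the two functions as restored here; torch.get_device_module is used only for illustration (it is the same helper the reverted implementation relied on), and the guard is required because the restored current_accelerator() raises when nothing is available:

    import torch

    if torch.accelerator.is_available():
        # Raises if no accelerator is available, hence the guard above.
        acc = torch.accelerator.current_accelerator()
        # Look up the per-backend module (torch.cuda, torch.xpu, ...) for
        # backend-specific queries.
        mod = torch.get_device_module(acc)
        print(f"accelerator: {acc}, backend module: {mod.__name__}")
    else:
        print("running on CPU")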
@ -11,10 +11,7 @@ def _get_device_index(device: _device_t, optional: bool = False) -> int:
         device = torch.device(device)
     device_index: Optional[int] = None
     if isinstance(device, torch.device):
-        acc = torch.accelerator.current_accelerator()
-        if acc is None:
-            raise RuntimeError("Accelerator expected")
-        if acc.type != device.type:
+        if torch.accelerator.current_accelerator().type != device.type:
             raise ValueError(
                 f"{device.type} doesn't match the current accelerator {torch.accelerator.current_accelerator()}."
             )
@ -6,14 +6,9 @@ namespace torch::accelerator {
 void initModule(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();

-  m.def("_accelerator_getAccelerator", []() -> std::optional<c10::Device> {
-    // If no accelerator was available at compile time, return None.
-    auto acc = at::getAccelerator(false);
-    if (acc.has_value()) {
-      return acc.value();
-    } else {
-      return std::nullopt;
-    }
+  m.def("_accelerator_getAccelerator", []() {
+    // If no accelerator is currently available, raise an exception.
+    return c10::Device(at::getAccelerator(true).value());
   });

   m.def("_accelerator_deviceCount", []() {
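Because the restored binding calls at::getAccelerator(true), the Python-visible effect on an accelerator-less machine is an exception rather than None. A hedged sketch of defensive usage (checking torch.accelerator.is_available() first, as in the restored docstring, is the documented alternative):

    import torch

    try:
        # With the restored binding this raises (a RuntimeError from TORCH_CHECK)
        # when no accelerator was detected, instead of returning None.
        device = torch.accelerator.current_accelerator()
    except RuntimeError:
        device = torch.device("cpu")

    x = torch.zeros(4, device=device)
    print(x.device)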
@ -1246,13 +1246,9 @@ def _save(

             if (
                 config.save.use_pinned_memory_for_d2h
-                and (
-                    acc := torch.accelerator.current_accelerator(
-                        check_available=True
-                    )
-                )
-                is not None
-                and acc.type == storage.device.type
+                and torch.accelerator.is_available()
+                and torch.accelerator.current_accelerator().type
+                == storage.device.type
             ):
                 new_storage = torch.empty(
                     num_bytes, dtype=torch.uint8, device="cpu", pin_memory=True
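For context on the guarded branch, a minimal sketch of the pinned-memory device-to-host copy it enables (assuming a CUDA device is present; the tensor names are illustrative, not from torch.serialization):

    import torch

    src = torch.randn(1024, device="cuda")  # data living on the accelerator
    # Pinned (page-locked) host memory allows a faster, potentially asynchronous
    # device-to-host copy, which is what the branch above is buying during save.
    dst = torch.empty(src.numel(), dtype=src.dtype, device="cpu", pin_memory=True)
    dst.copy_(src, non_blocking=True)
    torch.cuda.synchronize()  # make sure the async copy has finished
    print(dst[:4])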
@ -672,8 +672,7 @@ class _BaseDataLoaderIter:
             # memory allocation for MPS is fixed.
             if (
                 self._pin_memory
-                and (acc := torch.accelerator.current_accelerator()) is not None
-                and acc.type == "mps"
+                and torch.accelerator.current_accelerator().type == "mps"
             ):
                 self._pin_memory = False
                 warn_msg = (
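A small sketch of the user-facing behaviour guarded here (hypothetical dataset; on an MPS machine the iterator simply turns pin_memory back off and emits the warning built below):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(32, 3), torch.randint(0, 2, (32,)))
    # On MPS, pinning host memory is a no-op because memory allocation is fixed,
    # so requesting pin_memory=True is silently disabled with a warning.
    loader = DataLoader(dataset, batch_size=8, pin_memory=True)

    for xb, yb in loader:
        print(xb.shape, yb.shape)
        break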