Return NoOpDeviceGuardImpl in place of CudaDeviceGuard when the device is not available or in a cpu-only build (#160532)
Summary:
This supports exporting a cuda model on a CPU-only machine under fake tensor mode.
Users commonly need to move sample inputs to the cuda device with a .to("cuda:0") or .to("cuda") call; this diff makes that pattern work.
I expect the following pattern to work:
```
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = module.to("cuda:0")
    cuda_sample_inputs = tuple([x.to("cuda:0") for x in sample_inputs])
    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
```
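For concreteness, a minimal self-contained sketch of that pattern. The toy `Small` module, its input shape, and the variable names are illustrative only and are not part of this diff; on a CPU-only machine the `.to("cuda:0")` calls succeed only because of the no-op CUDA guard installed by this change.
```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode


class Small(torch.nn.Module):  # hypothetical toy module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return torch.relu(self.linear(x))


# On a CPU-only machine, moving the module and inputs to cuda only works
# under fake tensor mode, which is what this PR enables.
with FakeTensorMode(allow_non_fake_inputs=True):
    cuda_module = Small().to("cuda:0")
    cuda_sample_inputs = (torch.randn(2, 4).to("cuda:0"),)
    with torch.no_grad():
        ep = torch.export.export(cuda_module, cuda_sample_inputs)
print(ep)
```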
Test Plan:
CI
Rollback Plan:
Differential Revision: D80181887
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160532
Approved by: https://github.com/henryoier, https://github.com/ezyang
This commit is contained in:
parent 0925c644ed
commit a956c4ab1c
@@ -1,4 +1,5 @@
 #include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <c10/core/impl/FakeGuardImpl.h>
 #include <array>
 
 namespace c10::impl {
@@ -14,4 +15,27 @@ DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
   device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
 }
 
+namespace {
+thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
+    nullptr;
+}
+
+void ensureCUDADeviceGuardSet() {
+  constexpr auto cuda_idx = static_cast<std::size_t>(DeviceType::CUDA);
+
+  const DeviceGuardImplInterface* p =
+      device_guard_impl_registry[cuda_idx].load();
+
+  // A non-null `ptr` indicates that CUDA is already available.
+  if (p == nullptr || (p && p->deviceCount() == 0)) {
+    // In following cases, we override CUDA guard interface with a no-op
+    // device guard.
+    // 1. p == nullptr; Trying to get a cuda device guard on a cpu-only build.
+    // 2. p->deviceCount() == 0; cuda build enabled, but no cuda devices
+    //    available.
+    tls_fake_device_guard = std::make_unique<FakeGuardImpl<DeviceType::CUDA>>();
+    device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get());
+  }
+}
+
 } // namespace c10::impl
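The hunk above adds the override itself: a thread-local FakeGuardImpl is registered for CUDA when either no CUDA guard exists or zero devices are visible. As a conceptual aid only (this is not code from the diff), the condition corresponds roughly to the following Python-level predicate; `needs_noop_cuda_guard` is a hypothetical name:
```
import torch


def needs_noop_cuda_guard() -> bool:
    # Case 1: CPU-only build -- no CUDA device guard is registered at all.
    if not torch.backends.cuda.is_built():
        return True
    # Case 2: CUDA build, but no devices are visible (e.g. no NVIDIA driver);
    # device_count() reports 0 rather than raising in this situation.
    return torch.cuda.device_count() == 0
```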
@@ -6,6 +6,7 @@
 #include <c10/util/Exception.h>
 
 // Just for C10_ANONYMOUS_VARIABLE
+#include <c10/core/impl/TorchDispatchModeTLS.h>
 #include <c10/util/Registry.h>
 
 #include <array>
@@ -251,7 +252,7 @@ struct C10_API DeviceGuardImplInterface {
 // for devices that don't actually have a concept of device index. Prominent
 // examples are CPU and Meta.
 template <DeviceType D>
-struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
+struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface {
   NoOpDeviceGuardImpl() = default;
   DeviceType type() const override {
     return D;
@@ -371,5 +372,7 @@ inline bool hasDeviceGuardImpl(DeviceType type) {
   return device_guard_impl_registry[static_cast<size_t>(type)].load();
 }
 
+void C10_API ensureCUDADeviceGuardSet();
+
 } // namespace impl
 } // namespace c10
@@ -3,6 +3,7 @@
 # flake8: noqa
 
 import itertools
+import unittest
 
 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -50,17 +51,11 @@ fake_export_failures = {
     xfail("masked.std"),
     xfail("masked.sum"),
     xfail("masked.var"),
-    xfail("nn.functional.grid_sample"),
     xfail("to_sparse"),
     # cannot xfail as it is passing for cpu-only build
+    skip("nn.functional.grid_sample"),
     skip("nn.functional.conv2d"),
     skip("nn.functional.scaled_dot_product_attention"),
-    # following are failing due to OptionalDeviceGuard
-    xfail("__getitem__"),
-    xfail("nn.functional.batch_norm"),
-    xfail("nn.functional.instance_norm"),
-    xfail("nn.functional.multi_margin_loss"),
-    xfail("nonzero"),
 }
 
 fake_decomposition_failures = {
@@ -128,9 +123,52 @@ class TestExportOpInfo(TestCase):
     def test_fake_export(self, device, dtype, op):
         _test_export_helper(self, dtype, op)
 
+    @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build")
+    def test_preserve_original_behavior(self):
+        def cuda_calls_behavior_unchanged():
+            cpu_x = torch.randn(2)
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                cuda_x = cpu_x.to("cuda")
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.randn(2, device="cuda")
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.get_device_capability()
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.set_device(1)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.current_device()
+
+            self.assertEqual(torch.cuda.is_available(), False)
+            self.assertEqual(torch.cuda.device_count(), 0)
+
+        cuda_calls_behavior_unchanged()
+
+        cpu_x = torch.randn(2)
+        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
+            cuda_x = mode.from_tensor(cpu_x)
+            cuda_x.fake_device = torch.device("cuda")
+            cuda_y = cuda_x + cuda_x
+            self.assertEqual(cuda_y.device.type, "cuda")
+
+        # should fail again after exiting the fake mode, with the identical error message
+        cuda_calls_behavior_unchanged()
+
 
-only_for = "cpu"
-instantiate_device_type_tests(TestExportOpInfo, globals(), only_for=only_for)
+instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")
 
 
 if __name__ == "__main__":
@@ -1379,6 +1379,7 @@ def _get_linalg_preferred_backend() -> _LinalgBackend: ...
 def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
 def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
 def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
+def _ensureCUDADeviceGuardSet() -> None: ...
 
 class _LinalgBackend:
     Default: _LinalgBackend
@@ -1387,6 +1387,12 @@ class FakeTensorMode(TorchDispatchMode):
         # See NOTE: [torch.tensor, lift_fresh, and device movement]
         prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
         torch._C._set_only_lift_cpu_tensors(True)
+
+        # In the case of CPU-only build or cuda device unavailable,
+        # we patch the cuda device guard to use NoOpDeviceGuardImpl.
+        # This enables us to trace over cuda kernels under FakeTensorMode.
+        torch._C._ensureCUDADeviceGuardSet()
+
         maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
         if self is not maybe_prev_fake_mode:
             self.enter_stack.append(
@@ -1397,6 +1403,7 @@ class FakeTensorMode(TorchDispatchMode):
             # no-op (still need to re-set the fake mode though since we unset it)
             torch._C._set_dispatch_mode(self)
             self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
 
         return self
 
     def __exit__(
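With the `__enter__` change above, entering a `FakeTensorMode` installs the no-op CUDA guard when needed, so fake tensors can carry a cuda device on machines with no usable GPU. A short sketch of the observable effect, condensed from the new test in this diff:
```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

cpu_x = torch.randn(2)
with FakeTensorMode(allow_non_fake_inputs=True) as mode:
    cuda_x = mode.from_tensor(cpu_x)
    cuda_x.fake_device = torch.device("cuda")  # pretend the tensor lives on cuda
    cuda_y = cuda_x + cuda_x
    assert cuda_y.device.type == "cuda"

# Outside the fake mode, real cuda calls still fail exactly as before
# (e.g. "Found no NVIDIA driver on your system." on a driver-less CUDA build).
assert not torch.cuda.is_available()
```
Exiting the mode leaves the real CUDA error behavior untouched, which is exactly what the new `test_preserve_original_behavior` test checks.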
@@ -26,6 +26,7 @@
 #include <ATen/native/Normalization.h>
 #include <c10/core/Device.h>
 #include <c10/core/DispatchKeySet.h>
+#include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <c10/util/AbortHandler.h>
 #include <c10/util/Backtrace.h>
 #include <c10/util/Logging.h>
@@ -1550,6 +1551,15 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }
 
+static PyObject* THCPModule_ensureCUDADeviceGuardSet(
+    PyObject* self,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  c10::impl::ensureCUDADeviceGuardSet();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static std::initializer_list<PyMethodDef> TorchMethods = {
     {"_initExtension", THPModule_initExtension, METH_O, nullptr},
     {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
@@ -1845,7 +1855,13 @@ static std::initializer_list<PyMethodDef> TorchMethods = {
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
-    {nullptr, nullptr, 0, nullptr}};
+    {"_ensureCUDADeviceGuardSet",
+     THCPModule_ensureCUDADeviceGuardSet,
+     METH_NOARGS,
+     nullptr},
+    {nullptr, nullptr, 0, nullptr}
+
+};
 
 #ifdef USE_CUDA
 // NOLINTBEGIN(misc-use-internal-linkage)