diff --git a/c10/core/impl/DeviceGuardImplInterface.cpp b/c10/core/impl/DeviceGuardImplInterface.cpp
index 015bcd3e64f..1fb78aa443e 100644
--- a/c10/core/impl/DeviceGuardImplInterface.cpp
+++ b/c10/core/impl/DeviceGuardImplInterface.cpp
@@ -1,4 +1,5 @@
 #include <c10/core/impl/DeviceGuardImplInterface.h>
+#include <memory>
 #include <array>

 namespace c10::impl {
@@ -14,4 +15,27 @@ DeviceGuardImplRegistrar::DeviceGuardImplRegistrar(
   device_guard_impl_registry[static_cast<size_t>(type)].store(impl);
 }

+namespace {
+thread_local std::unique_ptr<DeviceGuardImplInterface> tls_fake_device_guard =
+    nullptr;
+}
+
+void ensureCUDADeviceGuardSet() {
+  constexpr auto cuda_idx = static_cast<size_t>(DeviceType::CUDA);
+
+  const DeviceGuardImplInterface* p =
+      device_guard_impl_registry[cuda_idx].load();
+
+  // A non-null `p` indicates that CUDA is already available.
+  if (p == nullptr || (p && p->deviceCount() == 0)) {
+    // In the following cases, we override the CUDA guard interface with a
+    // no-op device guard.
+    // 1. p == nullptr: getting a CUDA device guard on a CPU-only build.
+    // 2. p->deviceCount() == 0: CUDA build enabled, but no CUDA devices are
+    //    available.
+    tls_fake_device_guard = std::make_unique<NoOpDeviceGuardImpl<DeviceType::CUDA>>();
+    device_guard_impl_registry[cuda_idx].store(tls_fake_device_guard.get());
+  }
+}
+
 } // namespace c10::impl
diff --git a/c10/core/impl/DeviceGuardImplInterface.h b/c10/core/impl/DeviceGuardImplInterface.h
index 523e9ad9f45..fc8c367f75e 100644
--- a/c10/core/impl/DeviceGuardImplInterface.h
+++ b/c10/core/impl/DeviceGuardImplInterface.h
@@ -6,6 +6,7 @@
 #include

 // Just for C10_ANONYMOUS_VARIABLE
+#include
 #include <c10/util/Registry.h>

 #include <atomic>
@@ -251,7 +252,7 @@ struct C10_API DeviceGuardImplInterface {
 // for devices that don't actually have a concept of device index. Prominent
 // examples are CPU and Meta.
 template <DeviceType D>
-struct NoOpDeviceGuardImpl final : public DeviceGuardImplInterface {
+struct NoOpDeviceGuardImpl : public DeviceGuardImplInterface {
   NoOpDeviceGuardImpl() = default;
   DeviceType type() const override {
     return D;
@@ -371,5 +372,7 @@ inline bool hasDeviceGuardImpl(DeviceType type) {
   return device_guard_impl_registry[static_cast<size_t>(type)].load();
 }

+void C10_API ensureCUDADeviceGuardSet();
+
 } // namespace impl
 } // namespace c10
diff --git a/test/export/test_export_opinfo.py b/test/export/test_export_opinfo.py
index 35d8b2895bd..24e2f71ff43 100644
--- a/test/export/test_export_opinfo.py
+++ b/test/export/test_export_opinfo.py
@@ -3,6 +3,7 @@
 # flake8: noqa

 import itertools
+import unittest

 import torch
 from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
@@ -50,17 +51,11 @@ fake_export_failures = {
     xfail("masked.std"),
     xfail("masked.sum"),
     xfail("masked.var"),
-    xfail("nn.functional.grid_sample"),
     xfail("to_sparse"),
     # cannot xfail as it is passing for cpu-only build
+    skip("nn.functional.grid_sample"),
     skip("nn.functional.conv2d"),
     skip("nn.functional.scaled_dot_product_attention"),
-    # following are failing due to OptionalDeviceGuard
-    xfail("__getitem__"),
-    xfail("nn.functional.batch_norm"),
-    xfail("nn.functional.instance_norm"),
-    xfail("nn.functional.multi_margin_loss"),
-    xfail("nonzero"),
 }

 fake_decomposition_failures = {
@@ -128,9 +123,52 @@ class TestExportOpInfo(TestCase):
     def test_fake_export(self, device, dtype, op):
         _test_export_helper(self, dtype, op)

+    @unittest.skipIf(not torch.backends.cuda.is_built(), "requires CUDA build")
+    def test_preserve_original_behavior(self):
+        def cuda_calls_behavior_unchanged():
+            cpu_x = torch.randn(2)
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                cuda_x = cpu_x.to("cuda")

-only_for = "cpu"
-instantiate_device_type_tests(TestExportOpInfo, globals(), only_for=only_for)
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.randn(2, device="cuda")
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.get_device_capability()
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.set_device(1)
+
+            with self.assertRaisesRegex(
+                RuntimeError, "Found no NVIDIA driver on your system."
+            ):
+                torch.cuda.current_device()
+
+            self.assertEqual(torch.cuda.is_available(), False)
+            self.assertEqual(torch.cuda.device_count(), 0)
+
+        cuda_calls_behavior_unchanged()
+
+        cpu_x = torch.randn(2)
+        with FakeTensorMode(allow_non_fake_inputs=True) as mode:
+            cuda_x = mode.from_tensor(cpu_x)
+            cuda_x.fake_device = torch.device("cuda")
+            cuda_y = cuda_x + cuda_x
+            self.assertEqual(cuda_y.device.type, "cuda")
+
+        # should fail again after exiting the fake mode, with the identical error message
+        cuda_calls_behavior_unchanged()
+
+
+instantiate_device_type_tests(TestExportOpInfo, globals(), only_for="cpu")


 if __name__ == "__main__":
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index e55137c3d2b..83cacaf69de 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -1379,6 +1379,7 @@ def _get_linalg_preferred_backend() -> _LinalgBackend: ...
 def _set_linalg_preferred_backend(arg: _LinalgBackend): ...
 def _get_fp32_precision_getter(backend: str, op: str) -> str: ...
 def _set_fp32_precision_setter(backend: str, op: str, value: str) -> str: ...
+def _ensureCUDADeviceGuardSet() -> None: ...

 class _LinalgBackend:
     Default: _LinalgBackend
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 5767f6a1d0c..6b55abcef00 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -1387,6 +1387,12 @@ class FakeTensorMode(TorchDispatchMode):
         # See NOTE: [torch.tensor, lift_fresh, and device movement]
         prev_only_lift_cpu_tensors = torch._C._only_lift_cpu_tensors()
         torch._C._set_only_lift_cpu_tensors(True)
+
+        # In the case of a CPU-only build, or when no CUDA device is available,
+        # we patch the CUDA device guard to use NoOpDeviceGuardImpl.
+        # This enables us to trace over CUDA kernels under FakeTensorMode.
+        torch._C._ensureCUDADeviceGuardSet()
+
         maybe_prev_fake_mode = torch._C._unset_dispatch_mode(self._mode_key)
         if self is not maybe_prev_fake_mode:
             self.enter_stack.append(
@@ -1397,6 +1403,7 @@
             # no-op (still need to re-set the fake mode though since we unset it)
             torch._C._set_dispatch_mode(self)
             self.enter_stack.append((False, None, prev_only_lift_cpu_tensors))
+
         return self

     def __exit__(
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index ac2b03d2651..d040e16ba52 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -26,6 +26,7 @@
 #include
 #include
 #include
+#include <c10/core/impl/DeviceGuardImplInterface.h>
 #include
 #include
 #include
@@ -1550,6 +1551,15 @@ static PyObject* THPModule_are_vmap_fallback_warnings_enabled(
   END_HANDLE_TH_ERRORS
 }

+static PyObject* THCPModule_ensureCUDADeviceGuardSet(
+    PyObject* self,
+    PyObject* noargs) {
+  HANDLE_TH_ERRORS
+  c10::impl::ensureCUDADeviceGuardSet();
+  Py_RETURN_NONE;
+  END_HANDLE_TH_ERRORS
+}
+
 static std::initializer_list<PyMethodDef> TorchMethods = {
     {"_initExtension", THPModule_initExtension, METH_O, nullptr},
     {"_autograd_init", THPAutograd_initExtension, METH_NOARGS, nullptr},
@@ -1845,7 +1855,13 @@
     (PyCFunction)(void (*)())THPModule_has_torch_function_variadic,
     METH_FASTCALL,
     nullptr},
-    {nullptr, nullptr, 0, nullptr}};
+    {"_ensureCUDADeviceGuardSet",
+     THCPModule_ensureCUDADeviceGuardSet,
+     METH_NOARGS,
+     nullptr},
+    {nullptr, nullptr, 0, nullptr}
+
+};

 #ifdef USE_CUDA
 // NOLINTBEGIN(misc-use-internal-linkage)