From 268b0cc71422573b48adb8fb9bc3fe52d1855bf5 Mon Sep 17 00:00:00 2001
From: "Edward Z. Yang" <ezyang@meta.com>
Date: Mon, 25 Mar 2024 13:50:56 -0700
Subject: [PATCH] Do not run CUDA lazy init if it is triggered with fake mode
 on. (#122636)

Partially fixes https://github.com/pytorch/pytorch/issues/122109

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122636
Approved by: https://github.com/zou3519
---
 test/test_fake_tensor.py              | 8 ++++++++
 torch/_subclasses/fake_tensor.py      | 9 ++++++---
 torch/csrc/utils/device_lazy_init.cpp | 7 +++++++
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index ed91effb704..851e94668c3 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -1162,6 +1162,14 @@ class FakeTensorOperatorInvariants(TestCase):
             self.assertTrue("output[0]" not in str(e))
             self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
 
+    # IMPORTANT!!! Always run even if CUDA is not available
+    def test_fake_cuda_no_init(self):
+        with FakeTensorMode():
+            torch.empty(10, device='cuda')
+            torch.ones(10, device='cuda')
+            torch.zeros(10, device='cuda')
+            torch.rand(10, device='cuda')
+
     @skipIfRocm
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_conv_c1_backward(self):
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index a49320313ff..a6a9addccfb 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -454,9 +454,12 @@ class FakeTensor(torch.Tensor):
             in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
             and device.index is None
         ):
-            device = torch.device(
-                f"{device.type}:{getattr(torch, device.type).current_device()}"
-            )
+            if getattr(torch, device.type).is_initialized():
+                device = torch.device(
+                    f"{device.type}:{getattr(torch, device.type).current_device()}"
+                )
+            else:
+                device = torch.device(f"{device.type}:0")
         self.fake_device = device  # type: ignore[attr-defined]
         self.fake_mode = fake_mode  # type: ignore[attr-defined]
         self.constant = constant  # type: ignore[attr-defined]
diff --git a/torch/csrc/utils/device_lazy_init.cpp b/torch/csrc/utils/device_lazy_init.cpp
index f3b93e1da50..355fd426b14 100644
--- a/torch/csrc/utils/device_lazy_init.cpp
+++ b/torch/csrc/utils/device_lazy_init.cpp
@@ -1,3 +1,4 @@
+#include <c10/core/impl/TorchDispatchModeTLS.h>
 #include <torch/csrc/utils/device_lazy_init.h>
 
 #include <torch/csrc/utils/object_ptr.h>
@@ -21,6 +22,12 @@ void device_lazy_init(at::DeviceType device_type) {
     return;
   }
 
+  auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(
+      c10::impl::TorchDispatchModeKey::FAKE);
+  if (maybe_mode) {
+    return;
+  }
+
   std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
   auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
   if (!module) {