Do not run CUDA lazy init if it is triggered with fake mode on. (#122636)

Partially fixes https://github.com/pytorch/pytorch/issues/122109

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122636
Approved by: https://github.com/zou3519
This commit is contained in:
Edward Z. Yang 2024-03-25 13:50:56 -07:00 committed by PyTorch MergeBot
parent dd3f2cb53a
commit 268b0cc714
3 changed files with 21 additions and 3 deletions

View File

@@ -1162,6 +1162,14 @@ class FakeTensorOperatorInvariants(TestCase):
self.assertTrue("output[0]" not in str(e))
self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
# IMPORTANT!!! Always run even if CUDA is not available
def test_fake_cuda_no_init(self):
with FakeTensorMode():
torch.empty(10, device='cuda')
torch.ones(10, device='cuda')
torch.zeros(10, device='cuda')
torch.rand(10, device='cuda')
@skipIfRocm
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_conv_c1_backward(self):

View File

@@ -454,9 +454,12 @@ class FakeTensor(torch.Tensor):
in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
and device.index is None
):
device = torch.device(
f"{device.type}:{getattr(torch, device.type).current_device()}"
)
if getattr(torch, device.type).is_initialized():
device = torch.device(
f"{device.type}:{getattr(torch, device.type).current_device()}"
)
else:
device = torch.device(f"{device.type}:0")
self.fake_device = device # type: ignore[attr-defined]
self.fake_mode = fake_mode # type: ignore[attr-defined]
self.constant = constant # type: ignore[attr-defined]

View File

@@ -1,3 +1,4 @@
#include <c10/core/impl/TorchDispatchModeTLS.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/Exceptions.h>
@@ -21,6 +22,12 @@ void device_lazy_init(at::DeviceType device_type) {
return;
}
auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(
c10::impl::TorchDispatchModeKey::FAKE);
if (maybe_mode) {
return;
}
std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
if (!module) {