mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Do not run CUDA lazy init if it is triggered with fake mode on. (#122636)
Partially fixes https://github.com/pytorch/pytorch/issues/122109 Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/122636 Approved by: https://github.com/zou3519
This commit is contained in:
parent
dd3f2cb53a
commit
268b0cc714
|
|
@ -1162,6 +1162,14 @@ class FakeTensorOperatorInvariants(TestCase):
|
|||
self.assertTrue("output[0]" not in str(e))
|
||||
self.assertTrue("found mismatched tensor metadata for output[6]: Devices cpu and cuda:0 are not equal!" in str(e))
|
||||
|
||||
# IMPORTANT!!! Always run even if CUDA is not available
|
||||
def test_fake_cuda_no_init(self):
|
||||
with FakeTensorMode():
|
||||
torch.empty(10, device='cuda')
|
||||
torch.ones(10, device='cuda')
|
||||
torch.zeros(10, device='cuda')
|
||||
torch.rand(10, device='cuda')
|
||||
|
||||
@skipIfRocm
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_conv_c1_backward(self):
|
||||
|
|
|
|||
|
|
@ -454,9 +454,12 @@ class FakeTensor(torch.Tensor):
|
|||
in ["cuda", "hpu", "xpu", torch._C._get_privateuse1_backend_name()]
|
||||
and device.index is None
|
||||
):
|
||||
device = torch.device(
|
||||
f"{device.type}:{getattr(torch, device.type).current_device()}"
|
||||
)
|
||||
if getattr(torch, device.type).is_initialized():
|
||||
device = torch.device(
|
||||
f"{device.type}:{getattr(torch, device.type).current_device()}"
|
||||
)
|
||||
else:
|
||||
device = torch.device(f"{device.type}:0")
|
||||
self.fake_device = device # type: ignore[attr-defined]
|
||||
self.fake_mode = fake_mode # type: ignore[attr-defined]
|
||||
self.constant = constant # type: ignore[attr-defined]
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
#include <c10/core/impl/TorchDispatchModeTLS.h>
|
||||
#include <torch/csrc/utils/device_lazy_init.h>
|
||||
|
||||
#include <torch/csrc/Exceptions.h>
|
||||
|
|
@ -21,6 +22,12 @@ void device_lazy_init(at::DeviceType device_type) {
|
|||
return;
|
||||
}
|
||||
|
||||
auto maybe_mode = c10::impl::TorchDispatchModeTLS::get_mode(
|
||||
c10::impl::TorchDispatchModeKey::FAKE);
|
||||
if (maybe_mode) {
|
||||
return;
|
||||
}
|
||||
|
||||
std::string module_name = "torch." + at::DeviceTypeName(device_type, true);
|
||||
auto module = THPObjectPtr(PyImport_ImportModule(module_name.c_str()));
|
||||
if (!module) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user