Allow torch.load under FakeTensorMode to load FakeTensors with correct devices (for plain Tensors) (#147786)

This only fixes _rebuild_tensor_v2 and _rebuild_tensor_v3

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147786
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki 2025-03-06 08:50:55 +00:00 committed by PyTorch MergeBot
parent 79aa17489c
commit bdcc1b579b
4 changed files with 53 additions and 13 deletions

View File

@ -72,6 +72,7 @@ from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.inductor_utils import GPU_TYPE
from torch.testing._internal.jit_utils import RUN_CUDA
from torch.testing._internal.two_tensor import TwoTensor
from torch.utils._mode_utils import no_dispatch
from torch.utils._python_dispatch import TorchDispatchMode
@ -1609,24 +1610,41 @@ class FakeTensorPropTest(TestCase):
self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())
@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_torch_load_with_fake_mode(self):
class TheModelClass(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.fc1 = torch.nn.Linear(5, 10)
model = torch.nn.Linear(5, 10)
sd = model.state_dict()
sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
def forward(self, x):
return self.fc1(x)
with TemporaryFileName() as state_dict_file:
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file)
torch.save(sd, state_dict_file)
fake_mode = FakeTensorMode()
with fake_mode:
torch.load(state_dict_file) # scenario 1
torch.load(state_dict_file, map_location="cpu") # scenario 2
sd_loaded = torch.load(state_dict_file)
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
sd_loaded = torch.load(state_dict_file, map_location="cuda")
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
for k in sd.keys():
sd[k] = sd[k].to('cuda')
with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals([TwoTensor]):
torch.save(sd, state_dict_file)
fake_mode = FakeTensorMode()
with fake_mode:
sd_loaded = torch.load(state_dict_file)
self.assertEqual(sd_loaded["weight"].device.type, "cuda")
self.assertEqual(sd_loaded["tt"].device.type, "cuda")
sd_loaded = torch.load(state_dict_file, map_location="cpu")
self.assertEqual(sd_loaded["weight"].device.type, "cpu")
self.assertEqual(sd_loaded["tt"].device.type, "cpu")
make_propagate_real_tensors_cls(FakeTensorPropTest)

View File

@ -203,6 +203,16 @@ def set_tensor_metadata(tensor, metadata):
torch._C._set_tensor_metadata(tensor, metadata) # type: ignore[attr-defined]
def _restore_device_fake_mode(tensor):
if torch._guards.detect_fake_mode(None) is not None:
if tensor.untyped_storage()._fake_device is not None:
device = _get_restore_location(tensor.untyped_storage()._fake_device)
if not isinstance(device, torch.device):
device = torch.device(device)
tensor.fake_device = torch.device(device)
return tensor
def _rebuild_tensor_v2(
storage,
storage_offset,
@ -221,6 +231,8 @@ def _rebuild_tensor_v2(
# general expectation is that backward_hooks is an empty
# OrderedDict. See Note [Don't serialize hooks]
tensor._backward_hooks = backward_hooks
tensor = _restore_device_fake_mode(tensor)
return tensor
@ -244,6 +256,7 @@ def _rebuild_tensor_v3(
if metadata:
set_tensor_metadata(t, metadata)
t._backward_hooks = backward_hooks
t = _restore_device_fake_mode(t)
return t

View File

@ -2000,8 +2000,15 @@ def _load(
# TODO: Once we decide to break serialization FC, we can
# stop wrapping with TypedStorage
if torch._guards.detect_fake_mode(None) is None:
wrap_storage = restore_location(storage, location)
else:
storage._fake_device = location
wrap_storage = storage
typed_storage = torch.storage.TypedStorage(
wrap_storage=restore_location(storage, location),
wrap_storage=wrap_storage,
dtype=dtype,
_internal=True,
)

View File

@ -43,7 +43,9 @@ class _StorageBase:
is_sparse: _bool = False
is_sparse_csr: _bool = False
device: torch.device
# Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
# Used when
# (1) stashing FakeTensor device onto storage in torch.serialization.skip_data
# (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode
_fake_device: _Optional[torch.device] = None
def __init__(self, *args, **kwargs):