Allow torch.load under FakeTensorMode to load FakeTensors with correct devices (for plain Tensors) (#147786)
This only fixes _rebuild_tensor_v2 and _rebuild_tensor_v3.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147786
Approved by: https://github.com/albanD
This commit is contained in:
parent
79aa17489c
commit
bdcc1b579b
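A minimal sketch of the behavior this change enables (not part of the commit; it mirrors the new test below, assumes a CUDA build, and the checkpoint path is illustrative):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Save a real CUDA state_dict, then load it under FakeTensorMode.
torch.save(torch.nn.Linear(5, 10).cuda().state_dict(), "sd.pt")

with FakeTensorMode():
    # Tensors come back as FakeTensors carrying the device the
    # checkpoint was saved on...
    sd = torch.load("sd.pt")
    assert sd["weight"].device.type == "cuda"
    # ...unless a map_location overrides it.
    sd = torch.load("sd.pt", map_location="cpu")
    assert sd["weight"].device.type == "cpu"
```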
test/test_fake_tensor.py

```diff
@@ -72,6 +72,7 @@ from torch.testing._internal.custom_op_db import custom_op_db
 from torch.testing._internal.inductor_utils import GPU_TYPE
 from torch.testing._internal.jit_utils import RUN_CUDA
+from torch.testing._internal.two_tensor import TwoTensor
 from torch.utils._mode_utils import no_dispatch
 from torch.utils._python_dispatch import TorchDispatchMode
```
```diff
@@ -1609,24 +1610,41 @@ class FakeTensorPropTest(TestCase):
         self.assertEqual(fake_r.T.is_contiguous(), r.T.is_contiguous())

     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_torch_load_with_fake_mode(self):
-        class TheModelClass(torch.nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.fc1 = torch.nn.Linear(5, 10)
-
-            def forward(self, x):
-                return self.fc1(x)
-
-        with TemporaryFileName() as state_dict_file:
-            # Create state_dict to be loaded later
-            model = TheModelClass()
-            torch.save(model.state_dict(), state_dict_file)
+        model = torch.nn.Linear(5, 10)
+        sd = model.state_dict()
+        sd['tt'] = TwoTensor(torch.randn(2), torch.randn(2))
+
+        with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals(
+            [TwoTensor]
+        ):
+            torch.save(sd, state_dict_file)

             fake_mode = FakeTensorMode()
             with fake_mode:
-                torch.load(state_dict_file)  # scenario 1
-                torch.load(state_dict_file, map_location="cpu")  # scenario 2
+                sd_loaded = torch.load(state_dict_file)
+                self.assertEqual(sd_loaded["weight"].device.type, "cpu")
+                self.assertEqual(sd_loaded["tt"].device.type, "cpu")
+                sd_loaded = torch.load(state_dict_file, map_location="cuda")
+                self.assertEqual(sd_loaded["weight"].device.type, "cuda")
+                self.assertEqual(sd_loaded["tt"].device.type, "cuda")
+
+        for k in sd.keys():
+            sd[k] = sd[k].to('cuda')
+
+        with TemporaryFileName() as state_dict_file, torch.serialization.safe_globals(
+            [TwoTensor]
+        ):
+            torch.save(sd, state_dict_file)
+
+            fake_mode = FakeTensorMode()
+            with fake_mode:
+                sd_loaded = torch.load(state_dict_file)
+                self.assertEqual(sd_loaded["weight"].device.type, "cuda")
+                self.assertEqual(sd_loaded["tt"].device.type, "cuda")
+                sd_loaded = torch.load(state_dict_file, map_location="cpu")
+                self.assertEqual(sd_loaded["weight"].device.type, "cpu")
+                self.assertEqual(sd_loaded["tt"].device.type, "cpu")


 make_propagate_real_tensors_cls(FakeTensorPropTest)
```
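One note on the test: `torch.serialization.safe_globals` appears because `torch.load` now defaults to `weights_only=True`, whose unpickler only admits allowlisted classes, so a tensor subclass like `TwoTensor` must be registered explicitly. A standalone sketch of that piece (the file path is illustrative):

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor

torch.save({"tt": TwoTensor(torch.randn(2), torch.randn(2))}, "tt.pt")

# Allowlist the subclass for the weights_only unpickler, then load.
with torch.serialization.safe_globals([TwoTensor]):
    loaded = torch.load("tt.pt")
```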
torch/_utils.py
```diff
@@ -203,6 +203,16 @@ def set_tensor_metadata(tensor, metadata):
     torch._C._set_tensor_metadata(tensor, metadata)  # type: ignore[attr-defined]


+def _restore_device_fake_mode(tensor):
+    if torch._guards.detect_fake_mode(None) is not None:
+        if tensor.untyped_storage()._fake_device is not None:
+            device = _get_restore_location(tensor.untyped_storage()._fake_device)
+            if not isinstance(device, torch.device):
+                device = torch.device(device)
+            tensor.fake_device = torch.device(device)
+    return tensor
+
+
 def _rebuild_tensor_v2(
     storage,
     storage_offset,
```
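The guard used above, `torch._guards.detect_fake_mode`, returns the active `FakeTensorMode` instance (or `None`) rather than a boolean. A quick illustration, assuming current behavior:

```python
import torch
from torch._guards import detect_fake_mode
from torch._subclasses.fake_tensor import FakeTensorMode

assert detect_fake_mode(None) is None  # no fake mode active
with FakeTensorMode() as mode:
    # Inside the context, the ambient mode is found on the dispatch stack.
    assert detect_fake_mode(None) is mode
```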
```diff
@@ -221,6 +231,8 @@ def _rebuild_tensor_v2(
     # general expectation is that backward_hooks is an empty
     # OrderedDict. See Note [Don't serialize hooks]
     tensor._backward_hooks = backward_hooks
+
+    tensor = _restore_device_fake_mode(tensor)
     return tensor
```
```diff
@@ -244,6 +256,7 @@ def _rebuild_tensor_v3(
     if metadata:
         set_tensor_metadata(t, metadata)
     t._backward_hooks = backward_hooks
+    t = _restore_device_fake_mode(t)
     return t
```
torch/serialization.py
```diff
@@ -2000,8 +2000,15 @@ def _load(

         # TODO: Once we decide to break serialization FC, we can
         # stop wrapping with TypedStorage
+        if torch._guards.detect_fake_mode(None) is None:
+            wrap_storage = restore_location(storage, location)
+        else:
+            storage._fake_device = location
+            wrap_storage = storage
+
         typed_storage = torch.storage.TypedStorage(
-            wrap_storage=restore_location(storage, location),
+            wrap_storage=wrap_storage,
             dtype=dtype,
             _internal=True,
         )
```
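Read together with the `_restore_device_fake_mode` hunk above, the hand-off is two-phase: `_load` skips the real `restore_location` under fake mode and stashes the tagged location on the storage; `_rebuild_tensor_v2`/`_v3` then read it back and set `fake_device` on the rebuilt tensor. A condensed, hypothetical restatement (function names here are illustrative, not PyTorch API):

```python
import torch

def resolve_storage(storage, location, restore_location, fake_mode_active):
    # Phase 1 (in _load): under fake mode, don't materialize bytes on the
    # target device; just remember where the tensor claims to live.
    if fake_mode_active:
        storage._fake_device = location
        return storage
    return restore_location(storage, location)

def finish_rebuild(tensor, storage, fake_mode_active):
    # Phase 2 (in _rebuild_tensor_v2/_v3): surface the stashed device so
    # the FakeTensor machinery reports it as the tensor's device.
    fake_device = getattr(storage, "_fake_device", None)
    if fake_mode_active and fake_device is not None:
        tensor.fake_device = torch.device(fake_device)
    return tensor
```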
torch/storage.py
```diff
@@ -43,7 +43,9 @@ class _StorageBase:
     is_sparse: _bool = False
     is_sparse_csr: _bool = False
     device: torch.device
-    # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True)
+    # Used when
+    # (1) stashing FakeTensor device onto storage in torch.serialization.skip_data
+    # (2) stashing device onto storage to propagate to FakeTensor when torch.load under FakeTensorMode
     _fake_device: _Optional[torch.device] = None

     def __init__(self, *args, **kwargs):
```
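The comment's first bullet refers to `torch.serialization.skip_data`, which writes a checkpoint's structure without tensor bytes (and is where a FakeTensor's device first got stashed on its storage). A small usage sketch, assuming a PyTorch recent enough to have `skip_data`; the file path is illustrative:

```python
import torch

sd = torch.nn.Linear(5, 10).state_dict()

# Write the checkpoint layout only; storage bytes are skipped, with space
# reserved in the file so they can be written later.
with torch.serialization.skip_data():
    torch.save(sd, "layout_only.pt")
```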