diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py
index 54f8862499a..6aedc7d281f 100644
--- a/test/test_cpp_extensions_open_device_registration.py
+++ b/test/test_cpp_extensions_open_device_registration.py
@@ -7,12 +7,18 @@ import tempfile
 import types
 import unittest
 from typing import Union
+from unittest.mock import patch
 
 import torch
 import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
-from torch.testing._internal.common_utils import IS_ARM64, skipIfTorchDynamo, TEST_CUDA
+from torch.testing._internal.common_utils import (
+    IS_ARM64,
+    skipIfTorchDynamo,
+    TemporaryFileName,
+    TEST_CUDA,
+)
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
 
 
@@ -572,6 +578,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         self.assertEqual(z_cpu, z[0])
         self.assertEqual(z_cpu, z[1])
 
+    def test_open_device_numpy_serialization_map_location(self):
+        torch.utils.rename_privateuse1_backend("foo")
+        device = self.module.custom_device()
+        default_protocol = torch.serialization.DEFAULT_PROTOCOL
+        # This is a hack to test serialization through numpy
+        with patch.object(torch._C, "_has_storage", return_value=False):
+            x = torch.randn(2, 3)
+            x_foo = x.to(device)
+            sd = {"x": x_foo}
+            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
+            self.assertTrue(
+                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
+            )
+        with TemporaryFileName() as f:
+            torch.save(sd, f)
+            sd_loaded = torch.load(f, map_location="cpu")
+            self.assertTrue(sd_loaded["x"].is_cpu)
+
 
 if __name__ == "__main__":
     common.run_tests()
diff --git a/test/test_serialization.py b/test/test_serialization.py
index 1be1b06ab78..e83cafd3f3d 100644
--- a/test/test_serialization.py
+++ b/test/test_serialization.py
@@ -4403,6 +4403,22 @@ class TestSubclassSerialization(TestCase):
         with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
             loaded = torch.load(f, weights_only=True)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
+    def test_tensor_subclass_map_location(self):
+        t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
+        sd = {'t': t}
+
+        with TemporaryFileName() as f:
+            torch.save(sd, f)
+            sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
+            # make sure map_location is not propagated over multiple torch.load calls
+            sd_loaded = torch.load(f)
+            self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
+            self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
+            self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
 
 
 instantiate_device_type_tests(TestBothSerialization, globals())
diff --git a/torch/_utils.py b/torch/_utils.py
index 1bb726252de..8f866f9cf80 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -2,6 +2,7 @@ import copyreg
 import functools
 import logging
 import sys
+import threading
 import traceback
 import warnings
 from collections import defaultdict
@@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
     return kwargs["async"]
 
 
+_thread_local_state = threading.local()
+
+
+def _get_restore_location(device):
+    """Return the map_location location.
+
+    Used for rebuild functions where the tensor device is distinct from the storage
+    """
+
+    map_location = getattr(_thread_local_state, "map_location", None)
+    if map_location is None:
+        return device
+    else:
+        if isinstance(map_location, dict):
+            return map_location.get(device, device)
+        elif isinstance(map_location, (str, torch.device)):
+            return map_location
+        else:
+            assert callable(map_location)
+            raise RuntimeError(
+                "Callable map_location not supported with _rebuild_wrapper_subclass "
+                "or _rebuild_device_tensor_from_numpy"
+            )
+
+
 # Note [Don't serialize hooks]
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Since time immemorial, we have serialized the backward hooks associated with
@@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
 
 
 def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
+    device = _get_restore_location(device)
     tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
     tensor.requires_grad = requires_grad
     return tensor
@@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
 def _rebuild_wrapper_subclass(
     cls, dtype, size, stride, storage_offset, layout, device, requires_grad
 ):
+    device = _get_restore_location(device)
     return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
         size,
diff --git a/torch/serialization.py b/torch/serialization.py
index a7703b9964d..e4ad1f7e9c6 100644
--- a/torch/serialization.py
+++ b/torch/serialization.py
@@ -1453,7 +1453,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall
 
     unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
     unpickler.persistent_load = persistent_load
+    # Needed for tensors where storage device and rebuild tensor device are
+    # not connected (wrapper subclasses and tensors rebuilt using numpy)
+    torch._utils._thread_local_state.map_location = map_location
     result = unpickler.load()
+    del torch._utils._thread_local_state.map_location
 
     torch._utils._validate_loaded_sparse_tensors()
     torch._C._log_api_usage_metadata(
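For context, a minimal usage sketch of the behavior this patch enables, outside the test suite. It assumes a CUDA-enabled build and that the `TwoTensor` wrapper subclass is importable from `torch.testing._internal.two_tensor` (the same subclass used by `test_tensor_subclass_map_location` above); the file name is arbitrary.

```python
import torch
from torch.testing._internal.two_tensor import TwoTensor  # assumed import path

# Save a wrapper subclass; its inner tensors live on CPU at save time.
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
torch.save({"t": t}, "two_tensor.pt")

# With this patch, _rebuild_wrapper_subclass consults the thread-local
# map_location (via _get_restore_location), so the outer subclass and its
# inner tensors are rebuilt on the requested device instead of the device
# recorded at save time.
loaded = torch.load("two_tensor.pt", map_location="cuda:0")
print(loaded["t"].device, loaded["t"].a.device, loaded["t"].b.device)

# The thread-local map_location is deleted after unpickling, so a later
# torch.load without map_location is unaffected and rebuilds on CPU.
loaded_cpu = torch.load("two_tensor.pt")
print(loaded_cpu["t"].device)
```

Note that a callable `map_location` is deliberately not supported on these rebuild paths: `_get_restore_location` raises a `RuntimeError` for it, since the callable form operates on storages and these tensors are rebuilt without one.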