Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)

Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
This commit is contained in:
Mikayla Gawarecki 2024-05-23 22:32:29 +00:00 committed by PyTorch MergeBot
parent 4ff9113e3d
commit 87f79af24d
4 changed files with 73 additions and 1 deletions

View File

@ -7,12 +7,18 @@ import tempfile
import types import types
import unittest import unittest
from typing import Union from typing import Union
from unittest.mock import patch
import torch import torch
import torch.testing._internal.common_utils as common import torch.testing._internal.common_utils as common
import torch.utils.cpp_extension import torch.utils.cpp_extension
from torch.testing._internal.common_utils import IS_ARM64, skipIfTorchDynamo, TEST_CUDA from torch.testing._internal.common_utils import (
IS_ARM64,
skipIfTorchDynamo,
TemporaryFileName,
TEST_CUDA,
)
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
@ -572,6 +578,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertEqual(z_cpu, z[0]) self.assertEqual(z_cpu, z[0])
self.assertEqual(z_cpu, z[1]) self.assertEqual(z_cpu, z[1])
def test_open_device_numpy_serialization_map_location(self):
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()
default_protocol = torch.serialization.DEFAULT_PROTOCOL
# This is a hack to test serialization through numpy
with patch.object(torch._C, "_has_storage", return_value=False):
x = torch.randn(2, 3)
x_foo = x.to(device)
sd = {"x": x_foo}
rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
self.assertTrue(
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
)
with TemporaryFileName() as f:
torch.save(sd, f)
sd_loaded = torch.load(f, map_location="cpu")
self.assertTrue(sd_loaded["x"].is_cpu)
if __name__ == "__main__": if __name__ == "__main__":
common.run_tests() common.run_tests()

View File

@ -4403,6 +4403,22 @@ class TestSubclassSerialization(TestCase):
with self.assertRaisesRegex(TypeError, "'str' object is not callable"): with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
loaded = torch.load(f, weights_only=True) loaded = torch.load(f, weights_only=True)
@unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
def test_tensor_subclass_map_location(self):
t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
sd = {'t': t}
with TemporaryFileName() as f:
torch.save(sd, f)
sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
# make sure map_location is not propagated over multiple torch.load calls
sd_loaded = torch.load(f)
self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
instantiate_device_type_tests(TestBothSerialization, globals()) instantiate_device_type_tests(TestBothSerialization, globals())

View File

@ -2,6 +2,7 @@ import copyreg
import functools import functools
import logging import logging
import sys import sys
import threading
import traceback import traceback
import warnings import warnings
from collections import defaultdict from collections import defaultdict
@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
return kwargs["async"] return kwargs["async"]
_thread_local_state = threading.local()
def _get_restore_location(device):
"""Return the map_location location.
Used for rebuild functions where the tensor device is distinct from the storage
"""
map_location = getattr(_thread_local_state, "map_location", None)
if map_location is None:
return device
else:
if isinstance(map_location, dict):
return map_location.get(device, device)
elif isinstance(map_location, (str, torch.device)):
return map_location
else:
assert callable(map_location)
raise RuntimeError(
"Callable map_location not supported with _rebuild_wrapper_subclass "
"or _rebuild_device_tensor_from_numpy"
)
# Note [Don't serialize hooks] # Note [Don't serialize hooks]
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Since time immemorial, we have serialized the backward hooks associated with # Since time immemorial, we have serialized the backward hooks associated with
@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad): def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
device = _get_restore_location(device)
tensor = torch.from_numpy(data).to(dtype=dtype, device=device) tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
tensor.requires_grad = requires_grad tensor.requires_grad = requires_grad
return tensor return tensor
@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
def _rebuild_wrapper_subclass( def _rebuild_wrapper_subclass(
cls, dtype, size, stride, storage_offset, layout, device, requires_grad cls, dtype, size, stride, storage_offset, layout, device, requires_grad
): ):
device = _get_restore_location(device)
return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined] return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls, cls,
size, size,

View File

@ -1453,7 +1453,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall
unpickler = UnpicklerWrapper(data_file, **pickle_load_args) unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
unpickler.persistent_load = persistent_load unpickler.persistent_load = persistent_load
# Needed for tensors where storage device and rebuild tensor device are
# not connected (wrapper subclasses and tensors rebuilt using numpy)
torch._utils._thread_local_state.map_location = map_location
result = unpickler.load() result = unpickler.load()
del torch._utils._thread_local_state.map_location
torch._utils._validate_loaded_sparse_tensors() torch._utils._validate_loaded_sparse_tensors()
torch._C._log_api_usage_metadata( torch._C._log_api_usage_metadata(