Fix map_location for wrapper subclass and device tensors that go through numpy (#126728)
Fixes https://github.com/pytorch/pytorch/issues/124418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126728
Approved by: https://github.com/albanD
commit 87f79af24d (parent 4ff9113e3d)
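For context, a minimal sketch of the user-facing behavior this change targets, mirroring the test added to TestSubclassSerialization below. TwoTensor is the wrapper-subclass helper from torch.testing._internal.two_tensor, the checkpoint file name is illustrative, and a CUDA device is assumed:

import torch
from torch.testing._internal.two_tensor import TwoTensor  # test-only wrapper subclass

t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
torch.save({"t": t}, "ckpt.pt")

# Before this fix, the wrapper subclass and its inner tensors could come back on
# their originally saved device even when map_location asked for a different one.
sd = torch.load("ckpt.pt", map_location=torch.device("cuda:0"))
print(sd["t"].device, sd["t"].a.device, sd["t"].b.device)  # all cuda:0 after the fix

# map_location does not stick across calls: a plain load restores to CPU.
print(torch.load("ckpt.pt")["t"].device)  # cpu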
test/test_cpp_extensions_open_device_registration.py
@@ -7,12 +7,18 @@ import tempfile
 import types
 import unittest
 from typing import Union
+from unittest.mock import patch
 
 import torch
 
 import torch.testing._internal.common_utils as common
 import torch.utils.cpp_extension
-from torch.testing._internal.common_utils import IS_ARM64, skipIfTorchDynamo, TEST_CUDA
+from torch.testing._internal.common_utils import (
+    IS_ARM64,
+    skipIfTorchDynamo,
+    TemporaryFileName,
+    TEST_CUDA,
+)
 from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
 
 
@@ -572,6 +578,24 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         self.assertEqual(z_cpu, z[0])
         self.assertEqual(z_cpu, z[1])
 
+    def test_open_device_numpy_serialization_map_location(self):
+        torch.utils.rename_privateuse1_backend("foo")
+        device = self.module.custom_device()
+        default_protocol = torch.serialization.DEFAULT_PROTOCOL
+        # This is a hack to test serialization through numpy
+        with patch.object(torch._C, "_has_storage", return_value=False):
+            x = torch.randn(2, 3)
+            x_foo = x.to(device)
+            sd = {"x": x_foo}
+            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
+            self.assertTrue(
+                rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
+            )
+            with TemporaryFileName() as f:
+                torch.save(sd, f)
+                sd_loaded = torch.load(f, map_location="cpu")
+                self.assertTrue(sd_loaded["x"].is_cpu)
+
 
 if __name__ == "__main__":
     common.run_tests()
test/test_serialization.py
@@ -4403,6 +4403,22 @@ class TestSubclassSerialization(TestCase):
             with self.assertRaisesRegex(TypeError, "'str' object is not callable"):
                 loaded = torch.load(f, weights_only=True)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "map_location loads to cuda")
+    def test_tensor_subclass_map_location(self):
+        t = TwoTensor(torch.randn(2, 3), torch.randn(2, 3))
+        sd = {'t': t}
+
+        with TemporaryFileName() as f:
+            torch.save(sd, f)
+            sd_loaded = torch.load(f, map_location=torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].device == torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].a.device == torch.device('cuda:0'))
+            self.assertTrue(sd_loaded['t'].b.device == torch.device('cuda:0'))
+            # make sure map_location is not propagated over multiple torch.load calls
+            sd_loaded = torch.load(f)
+            self.assertTrue(sd_loaded['t'].device == torch.device('cpu'))
+            self.assertTrue(sd_loaded['t'].a.device == torch.device('cpu'))
+            self.assertTrue(sd_loaded['t'].b.device == torch.device('cpu'))
 
 
 instantiate_device_type_tests(TestBothSerialization, globals())
torch/_utils.py
@@ -2,6 +2,7 @@ import copyreg
 import functools
 import logging
 import sys
+import threading
 import traceback
 import warnings
 from collections import defaultdict
@@ -108,6 +109,31 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
     return kwargs["async"]
 
 
+_thread_local_state = threading.local()
+
+
+def _get_restore_location(device):
+    """Return the map_location location.
+
+    Used for rebuild functions where the tensor device is distinct from the storage
+    """
+
+    map_location = getattr(_thread_local_state, "map_location", None)
+    if map_location is None:
+        return device
+    else:
+        if isinstance(map_location, dict):
+            return map_location.get(device, device)
+        elif isinstance(map_location, (str, torch.device)):
+            return map_location
+        else:
+            assert callable(map_location)
+            raise RuntimeError(
+                "Callable map_location not supported with _rebuild_wrapper_subclass "
+                "or _rebuild_device_tensor_from_numpy"
+            )
+
+
 # Note [Don't serialize hooks]
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 # Since time immemorial, we have serialized the backward hooks associated with
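A standalone sketch (not the PyTorch source) of the resolution rule the new _get_restore_location helper applies; the device strings below are illustrative:

import torch

def resolve_restore_location(device, map_location):
    # None: keep the device recorded at save time.
    if map_location is None:
        return device
    # dict: remap only devices that appear as keys.
    if isinstance(map_location, dict):
        return map_location.get(device, device)
    # str / torch.device: everything goes to that single target.
    if isinstance(map_location, (str, torch.device)):
        return map_location
    # Callables are rejected on these rebuild paths, as in the diff above.
    raise RuntimeError("Callable map_location not supported for this rebuild path")

print(resolve_restore_location("foo:0", None))                # foo:0
print(resolve_restore_location("foo:0", "cpu"))               # cpu
print(resolve_restore_location("cuda:0", {"cuda:0": "cpu"}))  # cpu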
@@ -303,6 +329,7 @@ def _rebuild_nested_tensor(buffer, sizes, strides, storage_offsets):
 
 
 def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
+    device = _get_restore_location(device)
     tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
     tensor.requires_grad = requires_grad
     return tensor
@@ -321,6 +348,7 @@ def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
 def _rebuild_wrapper_subclass(
     cls, dtype, size, stride, storage_offset, layout, device, requires_grad
 ):
+    device = _get_restore_location(device)
     return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
         cls,
         size,
torch/serialization.py
@@ -1453,7 +1453,11 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', overall
 
     unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
     unpickler.persistent_load = persistent_load
+    # Needed for tensors where storage device and rebuild tensor device are
+    # not connected (wrapper subclasses and tensors rebuilt using numpy)
+    torch._utils._thread_local_state.map_location = map_location
     result = unpickler.load()
+    del torch._utils._thread_local_state.map_location
 
     torch._utils._validate_loaded_sparse_tensors()
     torch._C._log_api_usage_metadata(
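Finally, a minimal sketch of the thread-local handoff used above: torch.load stashes map_location in torch._utils._thread_local_state before unpickling so that rebuild helpers, which never receive map_location as an argument, can read it back, while loads running in other threads are unaffected. The names below are illustrative, not the PyTorch internals:

import threading

_state = threading.local()

def load(payload, map_location=None):
    _state.map_location = map_location   # visible only to this thread
    try:
        return rebuild(payload)          # unpickling would invoke rebuild helpers here
    finally:
        del _state.map_location          # do not leak into later load() calls

def rebuild(payload):
    restore_to = getattr(_state, "map_location", None)
    return payload, restore_to

print(load("x", map_location="cpu"))  # ('x', 'cpu')
print(load("x"))                      # ('x', None) - earlier map_location not propagated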