mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
## Semantic
The semantics are:
(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` makes `torch.save` skip writing storages (while still reserving space for them in the checkpoint).
```python
import torch
import torch.nn as nn
sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```
(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if FakeTensors are passed to `torch.save`, the pickler treats them as "materialized": space is reserved in the checkpoint for the associated storage bytes, and when loading, the type will be Tensor instead of FakeTensor.
```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()
    with torch.serialization.skip_data(materialize_fake_tensors=True):
        torch.save(sd, 'bla.pt')
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.],
# [0., 0., 0.]], device='cuda:0', dtype=torch.float16)), ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```
## Follow Ups
- [ ] `torch.load` semantics for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)
Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134504
Approved by: https://github.com/albanD
585 lines
22 KiB
Python
585 lines
22 KiB
Python
# Owner(s): ["module: cpp-extensions"]
|
|
|
|
import _codecs
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import types
|
|
import unittest
|
|
from typing import Union
|
|
from unittest.mock import patch
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch.testing._internal.common_utils as common
|
|
import torch.utils.cpp_extension
|
|
from torch.serialization import safe_globals
|
|
from torch.testing._internal.common_utils import (
|
|
IS_ARM64,
|
|
skipIfTorchDynamo,
|
|
TemporaryFileName,
|
|
TEST_CUDA,
|
|
TEST_XPU,
|
|
)
|
|
from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME
|
|
|
|
|
|
TEST_CUDA = TEST_CUDA and CUDA_HOME is not None
|
|
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
|
|
|
|
|
|
def remove_build_path():
    """Delete the default torch cpp-extension build root so extensions rebuild.

    Does nothing on Windows (the extensions build folder is not wiped there)
    or when the build root does not exist. Deletion errors are ignored.
    """
    if sys.platform == "win32":
        # Not wiping extensions build folder because Windows
        return

    build_root = torch.utils.cpp_extension.get_default_build_root()
    if not os.path.exists(build_root):
        return
    shutil.rmtree(build_root, ignore_errors=True)
|
|
|
|
|
|
def generate_faked_module():
    """Build a stand-in ``torch.foo`` runtime module for the ``foo`` backend.

    The returned module (named ``"foo"``) reports exactly one device whose
    current index is 0, which is always available and already initialized.
    ``get_rng_state`` hands back a fresh 4x4 tensor allocated on the ``foo``
    device, while ``set_rng_state`` and ``_lazy_init`` are no-ops.
    """

    def device_count() -> int:
        return 1

    def get_rng_state(device: Union[int, str, torch.device] = "foo") -> torch.Tensor:
        # The "state" is just a newly allocated tensor on our custom device.
        return torch.empty(4, 4, device="foo")

    def set_rng_state(
        new_state: torch.Tensor, device: Union[int, str, torch.device] = "foo"
    ) -> None:
        pass

    def is_available():
        return True

    def current_device():
        return 0

    # Create a new module object that fakes torch.foo dynamically, then
    # install the device-module API on it in a fixed order.
    fake_module = types.ModuleType("foo")
    members = {
        "device_count": device_count,
        "get_rng_state": get_rng_state,
        "set_rng_state": set_rng_state,
        "is_available": is_available,
        "current_device": current_device,
        "_lazy_init": lambda: None,
        "is_initialized": lambda: True,
    }
    for attr_name, attr_value in members.items():
        setattr(fake_module, attr_name, attr_value)

    return fake_module
|
|
|
|
|
|
@unittest.skipIf(IS_ARM64, "Does not work on arm")
@unittest.skipIf(TEST_XPU, "XPU does not support cppextension currently")
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestCppExtensionOpenRgistration(common.TestCase):
    """Tests Open Device Registration with C++ extensions.

    ``setUpClass`` JIT-builds ``cpp_extensions/open_registration_extension.cpp``,
    renames the PrivateUse1 backend to ``foo``, generates the ``foo`` tensor,
    storage and module methods, and registers a faked ``torch.foo`` runtime
    module. Each test then exercises that backend (tensors, storages,
    generators, serialization, CPU fallbacks) through the compiled extension
    stored in ``module``.
    """

    # The compiled extension module; populated once per class by setUpClass.
    module = None

    def setUp(self):
        super().setUp()

        # Check the extension before touching the working directory. When
        # setUp fails, tearDown is not run, so a failure after the chdir
        # below would leave the process in a different cwd.
        assert self.module is not None

        # cpp extensions use relative paths. Those paths are relative to
        # this file, so we'll change the working directory temporarily
        self.old_working_dir = os.getcwd()
        os.chdir(os.path.dirname(os.path.abspath(__file__)))

    def tearDown(self):
        super().tearDown()

        # return the working directory (see setUp)
        os.chdir(self.old_working_dir)

    @classmethod
    def setUpClass(cls):
        remove_build_path()

        # setUp's chdir has not happened yet at class-setup time, so resolve
        # the extension sources relative to this file rather than the cwd.
        test_dir = os.path.dirname(os.path.abspath(__file__))
        cls.module = torch.utils.cpp_extension.load(
            name="custom_device_extension",
            sources=[
                os.path.join(
                    test_dir, "cpp_extensions", "open_registration_extension.cpp"
                ),
            ],
            extra_include_paths=[os.path.join(test_dir, "cpp_extensions")],
            extra_cflags=["-g"],
            verbose=True,
        )

        # register torch.foo module and foo device to torch
        torch.utils.rename_privateuse1_backend("foo")
        torch.utils.generate_methods_for_privateuse1_backend(for_storage=True)
        torch._register_device_module("foo", generate_faked_module())

    def test_base_device_registration(self):
        self.assertFalse(self.module.custom_add_called())
        # create a tensor using our custom device object
        device = self.module.custom_device()
        x = torch.empty(4, 4, device=device)
        y = torch.empty(4, 4, device=device)
        # Check that our device is correct.
        self.assertEqual(x.device, device)
        self.assertFalse(x.is_cpu)
        self.assertFalse(self.module.custom_add_called())
        # calls out custom add kernel, registered to the dispatcher
        z = x + y
        # check that it was called
        self.assertTrue(self.module.custom_add_called())
        z_cpu = z.to(device="cpu")
        # Check that our cross-device copy correctly copied the data to cpu
        self.assertTrue(z_cpu.is_cpu)
        self.assertFalse(z.is_cpu)
        self.assertEqual(z.device, device)
        self.assertEqual(z, z_cpu)

    def test_common_registration(self):
        # check unsupported device and duplicated registration
        with self.assertRaisesRegex(RuntimeError, "Expected one of cpu"):
            torch._register_device_module("dev", generate_faked_module())
        with self.assertRaisesRegex(RuntimeError, "The runtime module of"):
            torch._register_device_module("foo", generate_faked_module())

        # backend name can be renamed to the same name multiple times
        torch.utils.rename_privateuse1_backend("foo")

        # backend name can't be renamed multiple times to different names.
        with self.assertRaisesRegex(
            RuntimeError, "torch.register_privateuse1_backend()"
        ):
            torch.utils.rename_privateuse1_backend("dev")

        # generator tensor and module can be registered only once
        with self.assertRaisesRegex(RuntimeError, "The custom device module of"):
            torch.utils.generate_methods_for_privateuse1_backend()

        # check whether torch.foo have been registered correctly
        self.assertEqual(
            torch.utils.backend_registration._get_custom_mod_func("device_count")(),
            1,
        )
        with self.assertRaisesRegex(RuntimeError, "Try to call torch.foo"):
            torch.utils.backend_registration._get_custom_mod_func("func_name_")

        # check attributes after registered
        self.assertTrue(hasattr(torch.Tensor, "is_foo"))
        self.assertTrue(hasattr(torch.Tensor, "foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.TypedStorage, "foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "is_foo"))
        self.assertTrue(hasattr(torch.UntypedStorage, "foo"))
        self.assertTrue(hasattr(torch.nn.Module, "foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "is_foo"))
        self.assertTrue(hasattr(torch.nn.utils.rnn.PackedSequence, "foo"))

    def test_open_device_generator_registration_and_hooks(self):
        device = self.module.custom_device()
        # None of our CPU operations should call the custom add function.
        self.assertFalse(self.module.custom_add_called())

        # check generator registered before using
        with self.assertRaisesRegex(
            RuntimeError,
            "Please register a generator to the PrivateUse1 dispatch key",
        ):
            torch.Generator(device=device)

        self.module.register_generator_first()
        gen = torch.Generator(device=device)
        self.assertEqual(gen.device, device)

        # generator can be registered only once
        with self.assertRaisesRegex(
            RuntimeError,
            "Only can register a generator to the PrivateUse1 dispatch key once",
        ):
            self.module.register_generator_second()

        if not self.module.is_register_hook():
            self.module.register_hook()
        default_gen = self.module.default_generator(0)
        self.assertEqual(
            default_gen.device.type, torch._C._get_privateuse1_backend_name()
        )

    def test_open_device_dispatchstub(self):
        # test kernels could be reused by privateuse1 backend through dispatchstub
        input_data = torch.randn(2, 2, 3, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        output_data = torch.abs(input_data)
        foo_output_data = torch.abs(foo_input_data)
        self.assertEqual(output_data, foo_output_data.cpu())

        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        # The strided view [:, :, 0:6:2] already has the result shape
        # (2, 2, 3), so the output operand's will-resize flag is False in
        # TensorIterator.
        torch.abs(input_data, out=output_data[:, :, 0:6:2])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:2])
        self.assertEqual(output_data, foo_output_data.cpu())

        # The view [:, :, 0:6:3] has shape (2, 2, 2), so the output operand's
        # will-resize flag is True in TensorIterator, and the output is
        # converted to a contiguous tensor in TensorIterator.
        output_data = torch.randn(2, 2, 6, dtype=torch.float32, device="cpu")
        foo_input_data = input_data.to("foo")
        foo_output_data = output_data.to("foo")
        torch.abs(input_data, out=output_data[:, :, 0:6:3])
        torch.abs(foo_input_data, out=foo_output_data[:, :, 0:6:3])
        self.assertEqual(output_data, foo_output_data.cpu())

    def test_open_device_quantized(self):
        input_data = torch.randn(3, 4, 5, dtype=torch.float32, device="cpu").to("foo")
        quantized_tensor = torch.quantize_per_tensor(input_data, 0.1, 10, torch.qint8)
        self.assertEqual(quantized_tensor.device, torch.device("foo:0"))
        self.assertEqual(quantized_tensor.dtype, torch.qint8)

    def test_open_device_random(self):
        # check if torch.foo have implemented get_rng_state
        with torch.random.fork_rng(device_type="foo"):
            pass

    def test_open_device_tensor(self):
        device = self.module.custom_device()

        # check whether print tensor.type() meets the expectation
        dtypes = {
            torch.bool: "torch.foo.BoolTensor",
            torch.double: "torch.foo.DoubleTensor",
            torch.float32: "torch.foo.FloatTensor",
            torch.half: "torch.foo.HalfTensor",
            torch.int32: "torch.foo.IntTensor",
            torch.int64: "torch.foo.LongTensor",
            torch.int8: "torch.foo.CharTensor",
            torch.short: "torch.foo.ShortTensor",
            torch.uint8: "torch.foo.ByteTensor",
        }
        for tt, dt in dtypes.items():
            test_tensor = torch.empty(4, 4, dtype=tt, device=device)
            self.assertEqual(test_tensor.type(), dt)

        # check whether the attributes and methods of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        self.assertFalse(x.is_foo)

        x = x.foo(torch.device("foo"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(x.is_foo)

        # test different device type input
        y = torch.empty(4, 4)
        self.assertFalse(y.is_foo)

        y = y.foo(torch.device("foo:0"))
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(y.is_foo)

        # test different device type input
        z = torch.empty(4, 4)
        self.assertFalse(z.is_foo)

        z = z.foo(0)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z.is_foo)

    def test_open_device_packed_sequence(self):
        # The generated PackedSequence.foo() moves the sequence to the backend.
        a = torch.rand(5, 3)
        b = torch.tensor([1, 1, 1, 1, 1])
        input = torch.nn.utils.rnn.PackedSequence(a, b)
        self.assertFalse(input.is_foo)
        input_foo = input.foo()
        self.assertTrue(input_foo.is_foo)

    def test_open_device_storage(self):
        # check whether the attributes and methods for storage of the corresponding custom backend are generated correctly
        x = torch.empty(4, 4)
        z1 = x.storage()
        self.assertFalse(z1.is_foo)

        z1 = z1.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(torch.device("cpu"))

        z1 = z1.cpu()
        self.assertFalse(self.module.custom_add_called())
        self.assertFalse(z1.is_foo)

        z1 = z1.foo(device="foo:0", non_blocking=False)
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z1.is_foo)

        with self.assertRaisesRegex(RuntimeError, "Invalid device"):
            z1.foo(device="cuda:0", non_blocking=False)

        # check UntypedStorage
        y = torch.empty(4, 4)
        z2 = y.untyped_storage()
        self.assertFalse(z2.is_foo)

        z2 = z2.foo()
        self.assertFalse(self.module.custom_add_called())
        self.assertTrue(z2.is_foo)

        # check custom StorageImpl create
        self.module.custom_storage_registry()

        z3 = y.untyped_storage()
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3.foo()
        self.assertTrue(self.module.custom_storageImpl_called())
        # A second query right after a True read returns False, so the
        # extension presumably clears the flag when it is read — TODO confirm.
        self.assertFalse(self.module.custom_storageImpl_called())

        z3 = z3[0:3]
        self.assertTrue(self.module.custom_storageImpl_called())

    @skipIfTorchDynamo("unsupported aten.is_pinned.default")
    def test_open_device_storage_pin_memory(self):
        # Check if the pin_memory is functioning properly on custom device
        cpu_tensor = torch.empty(3)
        self.assertFalse(cpu_tensor.is_foo)
        self.assertFalse(cpu_tensor.is_pinned("foo"))

        cpu_tensor_pin = cpu_tensor.pin_memory("foo")
        self.assertTrue(cpu_tensor_pin.is_pinned("foo"))

        # Test storage pin_memory and is_pin
        cpu_storage = cpu_tensor.storage()
        # We implement a dummy pin_memory of no practical significance
        # for custom device. Once tensor.pin_memory() has been called,
        # then tensor.is_pinned() will always return true no matter
        # what tensor it's called on.
        self.assertTrue(cpu_storage.is_pinned("foo"))

        cpu_storage_pinned = cpu_storage.pin_memory("foo")
        self.assertTrue(cpu_storage_pinned.is_pinned("foo"))

        # Test untyped storage pin_memory and is_pin
        cpu_tensor = torch.randn([3, 2, 1, 4])
        cpu_untyped_storage = cpu_tensor.untyped_storage()
        self.assertTrue(cpu_untyped_storage.is_pinned("foo"))

        cpu_untyped_storage_pinned = cpu_untyped_storage.pin_memory("foo")
        self.assertTrue(cpu_untyped_storage_pinned.is_pinned("foo"))

    @unittest.skip(
        "Temporarily disable due to the tiny differences between clang++ and g++ in defining static variable in inline function"
    )
    def test_open_device_serialization(self):
        self.module.set_custom_device_index(-1)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo")

        self.module.set_custom_device_index(0)
        storage = torch.UntypedStorage(4, device=torch.device("foo"))
        self.assertEqual(torch.serialization.location_tag(storage), "foo:0")

        cpu_storage = torch.empty(4, 4).storage()
        foo_storage = torch.serialization.default_restore_location(cpu_storage, "foo:0")
        self.assertTrue(foo_storage.is_foo)

        # test tensor MetaData serialization
        x = torch.empty(4, 4).long()
        y = x.foo()
        self.assertFalse(self.module.check_backend_meta(y))
        self.module.custom_set_backend_meta(y)
        self.assertTrue(self.module.check_backend_meta(y))

        self.module.custom_serialization_registry()
        with tempfile.TemporaryDirectory() as tmpdir:
            path = os.path.join(tmpdir, "data.pt")
            torch.save(y, path)
            z1 = torch.load(path)
            # loads correctly onto the foo backend device
            self.assertTrue(z1.is_foo)
            # loads BackendMeta data correctly
            self.assertTrue(self.module.check_backend_meta(z1))

            # cross-backend
            z2 = torch.load(path, map_location="cpu")
            # loads correctly onto the cpu backend device
            self.assertFalse(z2.is_foo)
            # loads BackendMeta data correctly
            self.assertFalse(self.module.check_backend_meta(z2))

    def test_open_device_storage_resize(self):
        cpu_tensor = torch.randn([8])
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertEqual(foo_storage.size(), 8)

        # Only register tensor resize_ function.
        foo_tensor.resize_(8)
        self.assertEqual(foo_storage.size(), 8)

        with self.assertRaisesRegex(TypeError, "Overflow"):
            foo_tensor.resize_(8**29)

    def test_open_device_storage_type(self):
        # test cpu float storage
        cpu_tensor = torch.randn([8]).float()
        cpu_storage = cpu_tensor.storage()
        self.assertEqual(cpu_storage.type(), "torch.FloatStorage")

        # test custom float storage before defining FloatStorage
        foo_tensor = cpu_tensor.foo()
        foo_storage = foo_tensor.storage()
        self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")

        class CustomFloatStorage:
            @property
            def __module__(self):
                return "torch." + torch._C._get_privateuse1_backend_name()

            @property
            def __name__(self):
                return "FloatStorage"

        # test custom float storage after defining FloatStorage
        try:
            torch.foo.FloatStorage = CustomFloatStorage()
            self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")

            # test custom int storage after defining FloatStorage
            foo_tensor2 = torch.randn([8]).int().foo()
            foo_storage2 = foo_tensor2.storage()
            self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
        finally:
            torch.foo.FloatStorage = None

    def test_open_device_faketensor(self):
        # FakeTensorMode.push() is deprecated (it only warns and returns a new
        # instance); entering the mode directly is the supported spelling.
        with torch._subclasses.fake_tensor.FakeTensorMode():
            a = torch.empty(1, device="foo")
            b = torch.empty(1, device="foo:0")
            result = a + b
        # Mixing "foo" and "foo:0" inputs still yields a tensor on the backend.
        self.assertEqual(result.device.type, "foo")

    def test_open_device_named_tensor(self):
        torch.empty([2, 3, 4, 5], device="foo", names=["N", "C", "H", "W"])

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    def test_compile_autograd_function_returns_self(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = self.module.custom_autograd_fn_returns_self(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(self.module.custom_autograd_fn_returns_self)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    # Not an open registration test - this file is just very convenient
    # for testing torch.compile on custom C++ operators
    @skipIfTorchDynamo("Temporary disabled due to torch._ops.OpOverloadPacket")
    def test_compile_autograd_function_aliasing(self):
        x_ref = torch.randn(4, requires_grad=True)
        out_ref = torch.ops._test_funcs.custom_autograd_fn_aliasing(x_ref)
        out_ref.sum().backward()

        x_test = x_ref.clone().detach().requires_grad_(True)
        f_compiled = torch.compile(torch.ops._test_funcs.custom_autograd_fn_aliasing)
        out_test = f_compiled(x_test)
        out_test.sum().backward()

        self.assertEqual(out_ref, out_test)
        self.assertEqual(x_ref.grad, x_test.grad)

    def test_open_device_scalar_type_fallback(self):
        z_cpu = torch.Tensor([[0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2]]).to(torch.int64)
        z = torch.triu_indices(3, 3, device="foo")
        self.assertEqual(z_cpu, z)

    def test_open_device_tensor_type_fallback(self):
        # create tensors located in custom device
        x = torch.Tensor([[1, 2, 3], [2, 3, 4]]).to("foo")
        y = torch.Tensor([1, 0, 2]).to("foo")
        # create result tensor located in cpu
        z_cpu = torch.Tensor([[0, 2, 1], [1, 3, 2]])
        # Check that our device is correct.
        device = self.module.custom_device()
        self.assertEqual(x.device, device)
        self.assertFalse(x.is_cpu)

        # call sub op, which will fallback to cpu
        z = torch.sub(x, y)
        self.assertEqual(z_cpu, z)

        # call index op, which will fallback to cpu
        z_cpu = torch.Tensor([3, 1])
        y = torch.Tensor([1, 0]).long().to("foo")
        z = x[y, y]
        self.assertEqual(z_cpu, z)

    def test_open_device_tensorlist_type_fallback(self):
        # create tensors located in custom device
        v_foo = torch.Tensor([1, 2, 3]).to("foo")
        # create result tensor located in cpu
        z_cpu = torch.Tensor([2, 4, 6])
        # create tensorlist for foreach_add op
        x = (v_foo, v_foo)
        y = (v_foo, v_foo)
        # Check that our device is correct.
        device = self.module.custom_device()
        self.assertEqual(v_foo.device, device)
        self.assertFalse(v_foo.is_cpu)

        # call _foreach_add op, which will fallback to cpu
        z = torch._foreach_add(x, y)
        self.assertEqual(z_cpu, z[0])
        self.assertEqual(z_cpu, z[1])

        # call _fused_adamw_ with undefined tensor.
        self.module.fallback_with_undefined_tensor()

    def test_open_device_numpy_serialization(self):
        torch.utils.rename_privateuse1_backend("foo")
        device = self.module.custom_device()
        default_protocol = torch.serialization.DEFAULT_PROTOCOL
        # This is a hack to test serialization through numpy: with
        # _has_storage patched to False, the tensor is reduced via
        # _rebuild_device_tensor_from_numpy (asserted below).
        with patch.object(torch._C, "_has_storage", return_value=False):
            x = torch.randn(2, 3)
            x_foo = x.to(device)
            sd = {"x": x_foo}
            rebuild_func = x_foo._reduce_ex_internal(default_protocol)[0]
            self.assertIs(rebuild_func, torch._utils._rebuild_device_tensor_from_numpy)

            # numpy >= 1.25 exposes the float32 dtype class as
            # np.dtypes.Float32DType. Compare versions numerically: a plain
            # string comparison would e.g. treat "1.9.0" as >= "1.25.0".
            if np.lib.NumpyVersion(np.__version__) < "1.25.0":
                float32_dtype_cls = type(np.dtype(np.float32))
            else:
                float32_dtype_cls = np.dtypes.Float32DType

            # Test map_location
            with TemporaryFileName() as f:
                torch.save(sd, f)
                with safe_globals(
                    [
                        np.core.multiarray._reconstruct,
                        np.ndarray,
                        np.dtype,
                        _codecs.encode,
                        float32_dtype_cls,
                    ]
                ):
                    sd_loaded = torch.load(f, map_location="cpu")
                self.assertTrue(sd_loaded["x"].is_cpu)

            # Test metadata_only
            with TemporaryFileName() as f:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Cannot serialize tensors on backends with no storage under skip_data context manager",
                ):
                    with torch.serialization.skip_data():
                        torch.save(sd, f)
|
|
|
|
|
# Entry point when this file is executed directly: run its tests through
# PyTorch's common test runner.
if __name__ == "__main__":
    common.run_tests()
|