pytorch/test/test_content_store.py
Edward Z. Yang ce1ad1c143 Add load_storage (#100519)
This adds a new operator debugprims::load_storage which does the unusual thing of loading a tensor from disk (via ContentStoreReader). This will be used in a later PR to implement delta debugging in the minifier, even when the repro is too big to fit in memory. The way it works is that you specify the name of the tensor you want to load, along with enough metadata to reconstruct it if the store isn't available. If there is an active content store, we read and return the tensor from that store; otherwise we use `rand_strided` to create it.
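
For orientation, here is a minimal sketch of the intended usage, based on the description above and on the test below (which exercises the operator as `torch.ops.debugprims.load_tensor`); the tensor name, shape, and device are illustrative only:

```python
import tempfile

import torch
from torch._prims.debug_prims import load_tensor_reader
from torch.utils._content_store import ContentStoreWriter

with tempfile.TemporaryDirectory() as loc:
    # Persist a tensor under a name so it can be reloaded later.
    ContentStoreWriter(loc).write_tensor("x", torch.randn(4))

    # With an active content store, the operator reads the saved data back
    # (into a fresh storage, so it never aliases the original tensor).
    with load_tensor_reader(loc):
        x = torch.ops.debugprims.load_tensor.default(
            "x", (4,), (1,), dtype=torch.float32, device="cpu"
        )

# Without an active store, only the metadata is honored: the result is a
# random tensor with the requested size/stride/dtype/device.
y = torch.ops.debugprims.load_tensor.default(
    "x", (4,), (1,), dtype=torch.float32, device="cpu"
)
```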

I needed some infra improvements to do this:

* `custom_op` now supports factory functions. Factory functions have to be registered specially via `impl_factory`.
* I modified `clone_input` to also support dtype conversion, which I use to change the dtype of a loaded tensor if necessary.
* ContentStore needs to work with a device argument so that we torch.load directly onto the correct device. This is needed for fake tensor support. (A sketch of how these pieces fit together follows this list.)
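
Putting these pieces together, the load-or-fabricate behavior might look roughly like the sketch below. This is an illustration of the described semantics, not the actual implementation: `maybe_load_tensor` is a hypothetical helper, the "active reader" lookup is replaced by an explicit argument, and a plain `.to()` stands in for the reader's device handling and the `clone_input` dtype conversion mentioned above.

```python
from typing import Optional

import torch
from torch.utils._content_store import ContentStoreReader


def maybe_load_tensor(  # hypothetical helper, not the real operator
    name: str,
    size: tuple,
    stride: tuple,
    *,
    dtype: torch.dtype,
    device: str,
    reader: Optional[ContentStoreReader] = None,  # stands in for the active reader
) -> torch.Tensor:
    if reader is not None:
        # Store available: load the saved tensor by name. In the PR, the
        # reader torch.loads directly onto the requested device and
        # clone_input handles any dtype change; .to() is a simplification.
        return reader.read_tensor(name).to(device=device, dtype=dtype)
    # No store: fabricate a tensor with the requested metadata but random
    # data, in the spirit of rand_strided (assumes a floating-point dtype).
    return torch.empty_strided(size, stride, dtype=dtype, device=device).normal_()
```

Because the no-store branch takes no tensor inputs, the real operator has to be registered as a factory function, which is what the `impl_factory` support above is for.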

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100519
Approved by: https://github.com/zou3519, https://github.com/anijain2305
2023-05-05 05:25:03 +00:00


# Owner(s): ["oncall: pt2"]
import tempfile

import torch
from torch._prims.debug_prims import load_tensor_reader
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.multiprocessing.reductions import StorageWeakRef
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import run_tests, skipIfRocm, TestCase
from torch.utils._content_store import (
    ContentStoreReader,
    ContentStoreWriter,
    hash_storage,
)


class TestContentStore(TestCase):
    def test_basic(self, device):
        # setup test data
        x = torch.randn(4, device=device)
        y = torch.randn(6, device=device)
        z = x.view(2, 2)
        # start writing
        with tempfile.TemporaryDirectory() as loc:
            writer = ContentStoreWriter(loc)
            writer.write_tensor("x", x)
            writer.write_tensor("y", y)
            writer.write_tensor("z", z)
            # do some mutation that is VC (version counter) untracked
            x.data.add_(1)
            writer.write_tensor("x2", x)
            writer.write_tensor("y2", y)
            writer.write_tensor("z2", z)
            del writer

            reader = ContentStoreReader(loc)
            n_x = reader.read_tensor("x")
            n_y = reader.read_tensor("y")
            n_z = reader.read_tensor("z")
            self.assertEqual(n_x + 1, x)
            self.assertEqual(n_y, y)
            self.assertEqual(n_z + 1, z)
            # x and z were views of the same storage, so their reloaded
            # counterparts should share a storage too
            self.assertEqual(
                StorageWeakRef(n_x.untyped_storage()),
                StorageWeakRef(n_z.untyped_storage()),
            )
            n_x2 = reader.read_tensor("x2")
            n_y2 = reader.read_tensor("y2")
            n_z2 = reader.read_tensor("z2")
            self.assertEqual(n_x2, x)
            self.assertEqual(n_y2, y)
            self.assertEqual(n_z2, z)
            # y was unchanged, so its second write dedupes to the same storage
            self.assertEqual(
                StorageWeakRef(n_y2.untyped_storage()),
                StorageWeakRef(n_y.untyped_storage()),
            )

    def test_scalar(self, device):
        # Should not raise an error
        hash_storage(torch.tensor(2, device=device).untyped_storage())

    @skipIfRocm
    def test_load_tensor(self, device):
        with tempfile.TemporaryDirectory() as loc:
            writer = ContentStoreWriter(loc)
            x = torch.randn(4, device=device)

            def same_meta_as_x(t):
                self.assertEqual(t.size(), x.size())
                self.assertEqual(t.stride(), x.stride())
                self.assertEqual(t.dtype, x.dtype)
                self.assertEqual(t.device, x.device)

            writer.write_tensor("x", x)

            with load_tensor_reader(loc):
                x2 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x2)
                x3 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float32, device=device
                )
                self.assertEqual(x, x3)
                # Must not alias!
                self.assertNotEqual(
                    StorageWeakRef(x.untyped_storage()),
                    StorageWeakRef(x2.untyped_storage()),
                )
                self.assertNotEqual(
                    StorageWeakRef(x2.untyped_storage()),
                    StorageWeakRef(x3.untyped_storage()),
                )
                # Check fake tensor mode works too
                with FakeTensorMode():
                    x4 = torch.ops.debugprims.load_tensor.default(
                        "x", (4,), (1,), dtype=torch.float32, device=device
                    )
                    self.assertIsInstance(x4, FakeTensor)
                    same_meta_as_x(x4)
                # Check fp64 works
                x5 = torch.ops.debugprims.load_tensor.default(
                    "x", (4,), (1,), dtype=torch.float64, device=device
                )
                self.assertEqual(x5.float(), x)
                self.assertEqual(x5.dtype, torch.float64)

            # Without an active reader, load_tensor falls back to producing
            # a random tensor with the requested metadata
            x6 = torch.ops.debugprims.load_tensor.default(
                "x", (4,), (1,), dtype=torch.float32, device=device
            )
            same_meta_as_x(x6)


instantiate_device_type_tests(TestContentStore, globals())

if __name__ == "__main__":
    run_tests()