pytorch/test/dynamo/test_subclasses.py
PyTorch MergeBot 96c5be8bc4 Revert "Fakify leaf of FunctionalTensor (#107062)"
This reverts commit 3349725766.

Reverted https://github.com/pytorch/pytorch/pull/107062 on behalf of https://github.com/ydwu4 due to This appears to have broken the test TestDTensorCompile.test_dtensor_fullgraph. Probably a land race ([comment](https://github.com/pytorch/pytorch/pull/107062#issuecomment-1685447747))
2023-08-21 00:30:16 +00:00

136 lines
3.7 KiB
Python

# Owner(s): ["module: dynamo"]
import contextlib
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils.checkpoint
class MockSubclass(torch.Tensor):
    """Minimal ``torch.Tensor`` subclass used as a traceable test subclass.

    The ``__torch_function__`` override forwards each intercepted call
    directly to ``func`` with the supplied positional and keyword arguments.
    """

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # A missing kwargs mapping means "call without keyword arguments".
        call_kwargs = {} if kwargs is None else kwargs
        return func(*args, **call_kwargs)
@contextlib.contextmanager
def preserve_subclass_config():
    """Register ``MockSubclass`` as traceable for the duration of the context.

    On entry, the contents of
    ``torch._dynamo.config.traceable_tensor_subclasses`` are snapshotted and
    ``MockSubclass`` is added. On exit, whether normal or via an exception,
    the set is cleared and refilled from the snapshot in place.
    """

    def current_subclasses():
        # Look the config set up afresh on every access rather than holding
        # on to a single object for the whole context.
        return torch._dynamo.config.traceable_tensor_subclasses

    snapshot = set(current_subclasses())
    try:
        current_subclasses().add(MockSubclass)
        yield
    finally:
        # Restore in place so the config keeps the same set object.
        current_subclasses().clear()
        current_subclasses().update(snapshot)
class SubclassTests(torch._dynamo.test_case.TestCase):
    """Dynamo tests covering tensor subclasses and torch-function state.

    ``MockSubclass`` is registered in
    ``torch._dynamo.config.traceable_tensor_subclasses`` for the lifetime of
    the class via ``preserve_subclass_config``. The previous registry
    contents are restored when the class is torn down.
    """

    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        # ``_exit_stack`` is not created here; it is expected to be provided
        # by the base class's setUpClass. Entering the context on it ties the
        # MockSubclass registration to the class lifetime.
        cls._exit_stack.enter_context(preserve_subclass_config())

    @classmethod
    def tearDownClass(cls):
        # Unwind everything entered on the stack, including the registration
        # made in setUpClass. Then chain to the base class so its own
        # class-level teardown still runs. ExitStack.close() on an already
        # emptied stack does nothing, so a second close by the base is
        # harmless.
        cls._exit_stack.close()
        super().tearDownClass()

    def test_torch_function_state_graph_break(self):
        # The torch-function state disabled by the context manager must still
        # be disabled after resuming from the graph break inside it.
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        inp = torch.ones(2, 2)
        enabled, _ = fn(inp)
        self.assertFalse(enabled)

    def test_torch_function_state_tracing(self):
        # With fullgraph=True any graph break raises, so completing the call
        # is the assertion: the context manager must be traceable in-graph.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        inp = torch.ones(2, 2)
        fn(inp)

    def test_torch_function_state_guards(self):
        # Calling the same function once with subclass torch-function
        # disabled and once with it enabled must yield two compiled frames.
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        inp = torch.ones(2, 2)
        with torch._C.DisableTorchFunctionSubclass():
            fn(inp)
        fn(inp)
        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        # A registered traceable subclass constructed inside the compiled
        # region must be returned as an instance of that subclass.
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        inp = torch.ones(2, 2)
        res = fn(inp)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        # Same as test_return_subclass, but for a subclass defined locally
        # inside the test.
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        subclasses = torch._dynamo.config.traceable_tensor_subclasses
        subclasses.add(LocalSubclass)
        # Scope the registration to this test instead of leaking it into the
        # remaining tests of the class.
        self.addCleanup(subclasses.discard, LocalSubclass)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return LocalSubclass(torch.add(x, 1.0))

        inp = torch.ones(2, 2)
        res = fn(inp)
        self.assertIsInstance(res, LocalSubclass)

    def test_compile_with_fake_tensor(self):
        # Each new input kind triggers a fresh compile. The third call passes
        # a tensor produced by FakeTensorMode.
        x = torch.randn([3, 4])
        x2 = torch.randn([4, 3])
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def f(t):
            return torch.sin(t)

        f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

        f(x2)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)

        with torch._subclasses.fake_tensor.FakeTensorMode() as fake_mode:
            fake_tensor = fake_mode.from_tensor(x)
            f(fake_tensor)
            self.assertEqual(cnt.frame_count, 3)
            self.assertEqual(cnt.op_count, 3)
if __name__ == "__main__":
    # Executed directly (not imported): hand control to Dynamo's test runner.
    torch._dynamo.test_case.run_tests()