pytorch/test/dynamo/test_subclasses.py
ydwu4 cbcd551045 Fix torch.compile FunctionalTensor inputs for higherOrderOps (#107604)
Before this PR, the added [test](https://github.com/pytorch/pytorch/pull/107604/files#diff-c618f2274b6b5ccc533c580549d2e552edbd9fc5ac0da1aa4b00338525c8f78dR224), which feeds FunctionalTensorWrapper inputs to a higherOrderOperator, failed with an assertion error at this line of [code](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1R1284).

The key change in this PR is to this [check](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1L1263):
```python
        elif (
            isinstance(example_value, FakeTensor)
            and example_value.fake_mode is tx.fake_mode
        ):
```
The original intention of this check seems to be to handle the case where we want to wrap an fx proxy for an intermediate fake tensor that is produced by some tensor ops, with an example value provided (as is the case for higherOrderOps [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/higher_order_ops.py#L85)). A fakified FunctionalTensorWrapper(FakeTensor) always fails this check, because the wrapper itself is not a FakeTensor. This PR changes the check to ask whether the example value has already been fakified by tx.fake_mode.
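
For illustration, a minimal sketch of the revised predicate. The helper `is_fakified_by` below is hypothetical, not the PR's actual code; it only expresses the "already fakified by `tx.fake_mode`" idea using existing unwrapping utilities:
```python
import torch
from torch._subclasses.fake_tensor import FakeTensor


def is_fakified_by(example_value, fake_mode):
    # Hypothetical helper: a FunctionalTensorWrapper is not itself a
    # FakeTensor, so unwrap it first to reach the underlying tensor.
    if torch._is_functional_tensor(example_value):
        example_value = torch._from_functional_tensor(example_value)
    # Accept any tensor fakified by the tracing context's fake mode,
    # not just bare FakeTensor instances.
    return isinstance(example_value, FakeTensor) and example_value.fake_mode is fake_mode
```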

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107604
Approved by: https://github.com/zou3519
ghstack dependencies: #107569
2023-08-23 02:42:18 +00:00

# Owner(s): ["module: dynamo"]
import contextlib
import functools

import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._functorch.aot_autograd import to_fun
from torch._higher_order_ops.wrap import wrap


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class EagerRecordGraphAndInputs:
    def __init__(self):
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm


@contextlib.contextmanager
def preserve_subclass_config():
    old_subclass_set = set(torch._dynamo.config.traceable_tensor_subclasses)
    try:
        torch._dynamo.config.traceable_tensor_subclasses.add(MockSubclass)
        yield
    finally:
        torch._dynamo.config.traceable_tensor_subclasses.clear()
        torch._dynamo.config.traceable_tensor_subclasses.update(old_subclass_set)


class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(preserve_subclass_config())

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def test_torch_function_state_graph_break(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res = fn(input)

    def test_torch_function_state_guards(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)
        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)
        res = fn(input)
        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)
        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        torch._dynamo.config.traceable_tensor_subclasses.add(LocalSubclass)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return LocalSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)
        res = fn(input)
        self.assertIsInstance(res, LocalSubclass)

    def test_compile_with_fake_tensor(self):
        x = torch.randn([3, 4])
        x2 = torch.randn([4, 3])
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return torch.sin(x)

        f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

        f(x2)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)

        # Running with a FakeTensor input triggers another recompilation.
        with torch._subclasses.fake_tensor.FakeTensorMode() as fake_mode:
            fake_tensor = fake_mode.from_tensor(x)
            f(fake_tensor)
            self.assertEqual(cnt.frame_count, 3)
            self.assertEqual(cnt.op_count, 3)

    def test_compile_with_functionalization(self):
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        add_ = l_x_.add_(1.0)
        relu_ = torch.relu_(l_x_); l_x_ = None
        add = add_ + relu_; add_ = relu_ = None
        return (add,)
"""
        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertEqual(actual, expected)

        # Same function, but now the input is a functional tensor created by
        # torch.func.functionalize; this recompiles into the same graph.
        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        # Manually enable functionalization and wrap the inputs, mimicking
        # what AOTAutograd does.
        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        # Views of functional tensors cannot be safely fakified.
        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()

    def test_compile_higher_order_with_functionalization(self):
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        expected_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        wrap_body_0 = self.wrap_body_0
        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        return (wrap,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_):
            add_ = l_x_.add_(1.0); l_x_ = None
            return add_
"""
        check_count_and_graph(1, 1, 1, expected_graph)

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(2, 2, 2, expected_graph)

        try:
            x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(3, 3, 3, expected_graph)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()