Before this PR, the added [test](https://github.com/pytorch/pytorch/pull/107604/files#diff-c618f2274b6b5ccc533c580549d2e552edbd9fc5ac0da1aa4b00338525c8f78dR224), which feeds FunctionalTensorWrapper inputs to a HigherOrderOperator, hit an assertion error at this [line of code](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1R1284). The key change in this PR is to this [check](https://github.com/pytorch/pytorch/pull/107604/files#diff-9f0663783bcd93e948e0491ef61b48123bdc9977bcc632fd707da578df13bfa1L1263):

```python
elif (
    isinstance(example_value, FakeTensor)
    and example_value.fake_mode is tx.fake_mode
):
```

The original intent of this check seems to be to handle the case where we want to wrap an fx proxy for an intermediate fake tensor produced by some tensor ops, with an example value provided (as is the case for HigherOrderOps [here](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/variables/higher_order_ops.py#L85)). A fakified FunctionalTensorWrapper(FakeTensor) always fails this check, because it is not itself a FakeTensor. This PR changes the condition to check whether the example value has already been fakified by tx.fake_mode.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107604
Approved by: https://github.com/zou3519
ghstack dependencies: #107569
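For illustration, here is a minimal sketch of the relaxed condition. This is not the actual PyTorch source: `is_fakified_by` is a hypothetical helper, and unwrapping the FunctionalTensorWrapper via `torch._from_functional_tensor` is one assumed way to express "already fakified by `tx.fake_mode`":

```python
import torch
from torch._subclasses.fake_tensor import FakeTensor


def is_fakified_by(example_value, fake_mode) -> bool:
    # Hypothetical helper, not the actual Dynamo code: accept an example
    # value that has already been fakified by this tracing context's mode.
    if isinstance(example_value, FakeTensor):
        # The old check stopped here, so anything that is not literally a
        # FakeTensor (e.g. a FunctionalTensorWrapper) was rejected.
        return example_value.fake_mode is fake_mode
    if torch._is_functional_tensor(example_value):
        # A FunctionalTensorWrapper is not a FakeTensor itself; unwrap it
        # and inspect the inner tensor instead.
        inner = torch._from_functional_tensor(example_value)
        return isinstance(inner, FakeTensor) and inner.fake_mode is fake_mode
    return False
```

Under this sketch, a fakified FunctionalTensorWrapper around a FakeTensor created by `tx.fake_mode` passes, which is the behavior change the PR describes.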
# Owner(s): ["module: dynamo"]
import contextlib
import functools

import torch

import torch._dynamo.test_case
import torch._dynamo.testing
import torch._functorch.config
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch._dynamo.testing import normalize_gm
from torch._functorch.aot_autograd import to_fun
from torch._higher_order_ops.wrap import wrap


class MockSubclass(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        return func(*args, **kwargs)


class EagerRecordGraphAndInputs:
    def __init__(self):
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm


@contextlib.contextmanager
def preserve_subclass_config():
    old_subclass_set = set(torch._dynamo.config.traceable_tensor_subclasses)
    try:
        torch._dynamo.config.traceable_tensor_subclasses.add(MockSubclass)
        yield
    finally:
        torch._dynamo.config.traceable_tensor_subclasses.clear()
        torch._dynamo.config.traceable_tensor_subclasses.update(old_subclass_set)


class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(preserve_subclass_config())

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def test_torch_function_state_graph_break(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)

        res = fn(input)

    def test_torch_function_state_guards(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)

        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)

        res = fn(input)

        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        torch._dynamo.config.traceable_tensor_subclasses.add(LocalSubclass)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return LocalSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)

        res = fn(input)
        self.assertIsInstance(res, LocalSubclass)

    def test_compile_with_fake_tensor(self):
        x = torch.randn([3, 4])
        x2 = torch.randn([4, 3])
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return torch.sin(x)

        f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 1)

        f(x2)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 2)

        with torch._subclasses.fake_tensor.FakeTensorMode() as fake_mode:
            fake_tensor = fake_mode.from_tensor(x)
            f(fake_tensor)

        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 3)
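
    # Compiles a mutating function three ways: eagerly, under
    # torch.func.functionalize, and under C++ functionalization with
    # functionalized inputs; each run is expected to recompile yet
    # produce the same graph.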
    def test_compile_with_functionalization(self):
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        add_ = l_x_.add_(1.0)
        relu_ = torch.relu_(l_x_); l_x_ = None
        add = add_ + relu_; add_ = relu_ = None
        return (add,)
"""
        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertEqual(actual, expected)

        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)

        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)

        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[2][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()
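
    # This test, added in https://github.com/pytorch/pytorch/pull/107604, feeds
    # FunctionalTensorWrapper inputs to a HigherOrderOperator; before that PR it
    # tripped the fake-mode assertion described in the PR summary above.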
    def test_compile_higher_order_with_functionalization(self):
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        expected_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_

        wrap_body_0 = self.wrap_body_0
        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        return (wrap,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_):
            add_ = l_x_.add_(1.0); l_x_ = None
            return add_
"""
        check_count_and_graph(1, 1, 1, expected_graph)

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(2, 2, 2, expected_graph)

        try:
            x = torch._to_functional_tensor(t_clone2, mirror_autograd_meta=True)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(3, 3, 3, expected_graph)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()