# Owner(s): ["module: dynamo"] import functools import itertools import unittest import torch import torch._dynamo.test_case import torch._dynamo.testing import torch._functorch.config import torch.utils._pytree as pytree import torch.utils.checkpoint from torch._dynamo.testing import normalize_gm from torch._higher_order_ops.wrap import wrap from torch.fx.experimental.symbolic_shapes import ( DimDynamic, ShapeEnv, StatelessSymbolicContext, ) from torch.nested._internal.nested_tensor import ( jagged_from_list, jagged_from_tensor_and_lengths, nested_view_from_values_offsets, NestedTensor, ) from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, subtest, ) from torch.testing._internal.inductor_utils import HAS_CUDA from torch.testing._internal.two_tensor import TwoTensor def traceable_subclass(c): return torch._dynamo.config.patch("traceable_tensor_subclasses", {c}) requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda") compile_full_eager = torch.compile(backend="eager", fullgraph=True) class BaseTorchFunction(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return super().__torch_function__(func, types, args, kwargs) class MockSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs) class AttrSubclass(torch.Tensor): x: int = 10 size: int = 10 @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} return func(*args, **kwargs) class DummyNDim(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func == torch.Tensor.ndim.__get__: return 10 return super().__torch_function__(func, types, args, kwargs) class WrapperSubclass: def __init__(self, tensor): self.tensor = tensor @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} args = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, args) kwargs = pytree.tree_map_only(WrapperSubclass, lambda x: x.tensor, kwargs) return func(*args, **kwargs) class SigmoidToExpSubclass(torch.Tensor): @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): if kwargs is None: kwargs = {} if func == torch.Tensor.sigmoid: return super().__torch_function__(torch.Tensor.exp, types, args, kwargs) return super().__torch_function__(func, types, args, kwargs) # Wrapper subclass with two inner tensors: data and scale # data has same shape as outer, and scale has single dim size class ScaledTensor(torch.Tensor): def __new__( cls, data: torch.Tensor, scale: torch.Tensor, *, constant: int = 0, ): return torch.Tensor._make_wrapper_subclass( cls, data.size(), strides=data.stride(), storage_offset=data.storage_offset(), dtype=data.dtype, layout=data.layout, requires_grad=data.requires_grad, device=data.device, ) def __init__(self, data: torch.Tensor, scale: torch.Tensor, constant: int = 0): self._data = data self._scale = scale self._constant = constant def __tensor_flatten__(self): ctx = {"_constant": self._constant} return ["_data", "_scale"], ctx @staticmethod def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride): assert len(inner_tensors) == 2 return ScaledTensor( inner_tensors["_data"], inner_tensors["_scale"], constant=metadata["_constant"], ) @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): 
        scaled_tensor = args[0]
        out = func(scaled_tensor._data, *args[1:], **kwargs)
        return ScaledTensor(out, scaled_tensor._scale, constant=scaled_tensor._constant)

    def __repr__(self):
        return f"{self._data.__repr__()}\n{self._scale.__repr__()}"


def func(a):
    return a.sin()


class EagerRecordGraphAndInputs:
    def __init__(self):
        self.graphs = []
        self.example_inputs = []

    def __call__(self, gm: torch.fx.GraphModule, example_inputs):
        self.graphs.append(gm)
        self.example_inputs.append(example_inputs)
        return gm


GLOBAL_TEST_SUBCLASSES = {
    MockSubclass,
    DummyNDim,
    SigmoidToExpSubclass,
    BaseTorchFunction,
}


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1


class SubclassTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
        super().setUpClass()
        cls._exit_stack.enter_context(
            torch._dynamo.config.patch(
                "traceable_tensor_subclasses", GLOBAL_TEST_SUBCLASSES
            )
        )

    @classmethod
    def tearDownClass(cls):
        cls._exit_stack.close()

    def test_no_call_to_new(self):
        class BadNewTorchFunction(torch.Tensor):
            def __new__(cls, *args, **kwargs):
                raise RuntimeError("Oops!")

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {BadNewTorchFunction}
        ):

            @torch.compile(backend="eager", fullgraph=True)
            def fn(x):
                return torch.add(x, 1)

            input = torch.ones(2, 2).as_subclass(BadNewTorchFunction)

            res = fn(input)
            self.assertIsInstance(res, BadNewTorchFunction)

    def test_base_torch_function_tracing(self):
        def fn(x):
            return torch.add(x, 1)

        input = torch.ones(2, 2).as_subclass(BaseTorchFunction)
        out = fn(input)
        out_opt = compile_full_eager(fn)(input)
        self.assertIsInstance(out, BaseTorchFunction)
        self.assertEqual(out, out_opt)

    def test_torch_function_state_graph_break(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch._dynamo.graph_break()
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_nested(self):
        @torch.compile(backend="eager")
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                with torch._C.DisableTorchFunctionSubclass():
                    x = x + 1
                # Should reset to the outer state (disabled) after exiting ctx manager
                return torch._C._is_torch_function_enabled(), torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res, _ = fn(input)
        self.assertFalse(res)

    def test_torch_function_state_tracing(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            with torch._C.DisableTorchFunctionSubclass():
                torch.add(x, 1.0)

        input = torch.ones(2, 2)
        res = fn(input)

    def test_torch_function_state_guards(self):
        cnt = torch._dynamo.testing.CompileCounter()

        @torch.compile(backend=cnt, fullgraph=True)
        def fn(x):
            torch.add(x, 1.0)

        input = torch.ones(2, 2)

        with torch._C.DisableTorchFunctionSubclass():
            res = fn(input)

        res = fn(input)

        self.assertEqual(cnt.frame_count, 2)

    def test_return_subclass(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return MockSubclass(torch.add(x, 1.0))

        input = torch.ones(2, 2)
        res = fn(input)
        self.assertIsInstance(res, MockSubclass)
    def test_return_as_subclass(self):
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return torch.add(x, 1.0).as_subclass(MockSubclass)

        input = torch.ones(2, 2)
        res = fn(input)
        self.assertIsInstance(res, MockSubclass)

    def test_return_local_subclass(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        with torch._dynamo.config.patch("traceable_tensor_subclasses", {LocalSubclass}):

            @torch.compile(backend="eager", fullgraph=True)
            def fn(x):
                return LocalSubclass(torch.add(x, 1.0))

            input = torch.ones(2, 2)
            res = fn(input)
            self.assertIsInstance(res, LocalSubclass)

    @parametrize(
        "comparison",
        [
            subtest(isinstance, "isinstance"),
            subtest(lambda instance, type_: type(instance) == type_, "equality"),
            subtest(lambda instance, type_: type(instance) is type_, "identity"),
        ],
    )
    @parametrize(
        "input_type",
        [
            subtest(torch.Tensor, "tensor"),
            subtest(DummyNDim, "subclass"),
        ],
    )
    def test_type_check(self, comparison, input_type):
        with torch._dynamo.config.patch("traceable_tensor_subclasses", {DummyNDim}):

            def fn(x):
                if comparison(x, DummyNDim):
                    return torch.ones(1, 1)
                else:
                    return torch.zeros(2, 2)

            input = torch.ones(2, 2).as_subclass(input_type)
            exp_res = fn(input)
            act_res = torch.compile(backend="eager", fullgraph=True)(fn)(input)
            self.assertEqual(exp_res, act_res)

    def test_torch_function_call_on_method(self):
        x = torch.ones(2, 2)
        y = torch.ones(2, 2)
        z = torch.ones(2, 2)
        wrapped = x.as_subclass(SigmoidToExpSubclass)
        wrapped2 = y.as_subclass(SigmoidToExpSubclass)

        def fn(w):
            return w.sigmoid()

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped2)
        res_exp2 = z.exp()

        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, res_exp2)

    def test_user_overidden_method_unsupported(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            def sigmoid(self):
                return None

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            x.sigmoid()

        msg = (
            "Accessing overridden method/attribute sigmoid on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_user_overidden_attr_unsupported(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            ndim = 10

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.ndim

        msg = (
            "Accessing overridden method/attribute ndim on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_user_overidden_property_unsupported(self):
        class LocalSubclass(torch.Tensor):
            def __init__(self):
                self._ndim = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

            @property
            def ndim(self):
                return self._ndim

            @ndim.setter
            def ndim(self, value):
                self._ndim = value
        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.ndim

        msg = (
            "Accessing overridden method/attribute ndim on a tensor"
            " subclass with a __torch_function__ override is not supported"
        )
        with torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {LocalSubclass}
        ), self.assertRaisesRegex(torch._dynamo.exc.Unsupported, msg):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

    def test_overridden_method_guarding(self):
        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        @torch.compile(backend="eager")
        def fn(x):
            return x.sigmoid()

        with torch._dynamo.config.patch(
            error_on_recompile=True, traceable_tensor_subclasses={LocalSubclass}
        ):
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)
            fn(x)
            x = torch.ones(2, 2).as_subclass(LocalSubclass)
            fn(x)

        with torch._dynamo.config.patch(
            traceable_tensor_subclasses={LocalSubclass}
        ), self.assertRaisesRegex(
            TypeError,
            "'bool' object is not callable",
        ):
            LocalSubclass.sigmoid = False
            fn(x)

    def test_torch_function_call_on_attr(self):
        x = torch.ones(2, 2)
        wrapped = x.as_subclass(DummyNDim)

        def fn(w):
            return w.ndim + torch.ones(2)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)

        self.assertEqual(res_exp, res_act)
        self.assertEqual(res_exp, torch.ones(2) + 10)

    def test_torch_function_wrapper_class(self):
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)

        self.assertEqual(res_exp, res_act)

    def test_torch_function_wrapper_class_with_kwargs(self):
        x = torch.ones(2, 2)
        wrapped = WrapperSubclass(x)

        def fn(w):
            return torch.add(w, 1.0, alpha=2.0)

        fn_opt = compile_full_eager(fn)

        res_exp = fn(wrapped)
        res_act = fn_opt(wrapped)

        self.assertEqual(res_exp, res_act)

    def test_tensor_subclass_custom_attr(self):
        class AttrSubclass(torch.Tensor):
            x: int = 10

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return super().__torch_function__(func, types, args, kwargs)

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            return x.x + torch.ones(2, 2)

        with traceable_subclass(AttrSubclass):
            input = torch.ones(2, 2).as_subclass(AttrSubclass)
            fn_opt = compile_full_eager(fn)

            res_exp = fn(input)
            res_act = fn_opt(input)
            self.assertEqual(res_exp, res_act)

    def test_compile_with_fake_tensor_dynamic_dim(self):
        x = torch.randn([3, 4])

        def f(x):
            return torch.sin(x)

        def test_dynamic_dim(f, x, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()
            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            x1 = torch.rand_like(x)
            f(x)
            f(torch.randn([4, 3]))
            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                x_fake = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                    ),
                )
                x1_fake = fake_mode.from_tensor(
                    x1,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[dim_dynamic for i in range(x.dim())]
                    ),
                )
                opt_f(x_fake)
                opt_f(x1_fake)

            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        test_dynamic_dim(f, x, DimDynamic.DYNAMIC, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.DUCK, 1, 1)
        test_dynamic_dim(f, x, DimDynamic.STATIC, 1, 1)

    def test_compile_with_fake_tensor_automatic_dynamic(self):
        def f(x):
            return torch.sin(x)

        def test_automatic_dynamic(f, inps, dim_dynamic, exp_frame_count, exp_op_count):
            torch._dynamo.reset()
            cnt = torch._dynamo.testing.CompileCounter()
            opt_f = torch.compile(f, backend=cnt, fullgraph=True)

            shape_env = ShapeEnv()
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                for inp in inps:
                    fake_inp = fake_mode.from_tensor(
                        inp,
                        symbolic_context=StatelessSymbolicContext(
                            [dim_dynamic for i in range(x.dim())]
                        ),
                    )
                    opt_f(fake_inp)
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)

        x = torch.randn([3, 4])
        y = torch.randn([4, 5])
        z = torch.randn([5, 6])
        a = torch.randn([3, 5])
        b = torch.randn([4, 4])

        # When inputs' DimDynamic is DYNAMIC or DUCK, the inputs to opt_f will be
        # tensors with SymInt sizes. Dynamo will treat the input as dynamic
        # automatically and will only compile once.
        for dim_dynamic in [DimDynamic.DYNAMIC, DimDynamic.DUCK]:
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 1, 1)
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 1, 1)

        for dim_dynamic in [DimDynamic.STATIC]:
            # Recompiles once: dims 0 and 1 both become dynamic on the second call.
            test_automatic_dynamic(f, [x, y, z], dim_dynamic, 2, 2)
            # Recompiles twice: first dim 1 becomes dynamic, then dim 0 becomes dynamic.
            test_automatic_dynamic(f, [x, a, z], dim_dynamic, 3, 3)
            # Recompiles twice: first dim 0 becomes dynamic, then dim 1 becomes dynamic.
            test_automatic_dynamic(f, [x, b, z], dim_dynamic, 3, 3)

    def test_compile_with_functionalization(self):
        x = torch.randn([3, 4])
        x_clone = x.clone()
        x_clone2 = x.clone()
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return x.add_(1.0) + torch.nn.functional.relu_(x)

        f_out = f(x)
        self.assertEqual(cnt.frame_count, 1)
        self.assertEqual(cnt.op_count, 3)
        self.assertEqual(len(backend.graphs), 1)
        self.assertEqual(len(backend.example_inputs), 1)

        expected = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        add_ = l_x_.add_(1.0)
        relu_ = torch.relu_(l_x_); l_x_ = None
        add = add_ + relu_; add_ = relu_ = None
        return (add,)
"""
        actual = normalize_gm(backend.graphs[0].print_readable(print_output=False))
        self.assertEqual(actual, expected)

        ff = torch.func.functionalize(f)
        ff_out = ff(x_clone)
        self.assertEqual(cnt.frame_count, 2)
        self.assertEqual(cnt.op_count, 6)
        self.assertEqual(len(backend.graphs), 2)
        self.assertEqual(len(backend.example_inputs), 2)
        actual = normalize_gm(backend.graphs[1].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        # Cannot re-use the version from AOTAutograd, since that uses python
        # functional tensors.
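        # Descriptive note (sketch): the helpers below wrap each input into a C++
        # functional tensor and mirror its autograd metadata before invoking the
        # compiled fn, then tear functionalization back down in a finally block.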
        def to_fun(x):
            x_functional = torch._to_functional_tensor(x)
            torch._mirror_autograd_meta_to(x, x_functional)
            return x_functional

        def aot_f_wrapper(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                torch._enable_functionalization(reapply_views=False)
                try:
                    func_args = pytree.tree_map(to_fun, args)
                    func_kwargs = pytree.tree_map(to_fun, kwargs)
                    return func(*func_args, **func_kwargs)
                finally:
                    torch._disable_functionalization()

            return wrapper

        aot_ff = aot_f_wrapper(f)
        aot_ff_out = aot_ff(x_clone2)
        self.assertEqual(cnt.frame_count, 3)
        self.assertEqual(cnt.op_count, 9)
        self.assertEqual(len(backend.graphs), 3)
        self.assertEqual(len(backend.example_inputs), 3)
        actual = normalize_gm(backend.graphs[2].print_readable(print_output=False))
        self.assertEqual(actual, expected)
        self.assertTrue(torch._is_functional_tensor(backend.example_inputs[1][0]))

        self.assertEqual(f_out, ff_out)
        self.assertEqual(f_out, aot_ff_out)

        try:
            torch._enable_functionalization(reapply_views=False)
            xf = pytree.tree_map(to_fun, x)
            x_view = xf.t()
            with self.assertRaisesRegex(RuntimeError, "Cannot safely fakify a view"):
                f(x_view)
        finally:
            torch._disable_functionalization()

    def test_compile_higher_order_with_functionalization(self):
        backend = EagerRecordGraphAndInputs()
        cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            return wrap(lambda x: x.add_(1.0), x)

        def check_count_and_graph(
            exp_frame_count, exp_op_count, exp_n_graph, exp_graph
        ):
            self.assertEqual(cnt.frame_count, exp_frame_count)
            self.assertEqual(cnt.op_count, exp_op_count)
            self.assertEqual(len(backend.graphs), exp_n_graph)
            actual = normalize_gm(
                backend.graphs[exp_n_graph - 1].print_readable(print_output=False)
            )
            self.assertExpectedInline(actual, exp_graph)

        t = torch.randn([3, 4])
        t_clone = t.clone()
        t_clone2 = t.clone()
        f(t)

        expected_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self, L_x_ : torch.Tensor):
        l_x_ = L_x_
        wrap_body_0 = self.wrap_body_0
        wrap = torch._higher_order_ops.wrap.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
        getitem = wrap[0]; wrap = None
        return (getitem,)

    class GraphModule(torch.nn.Module):
        def forward(self, l_x_):
            add_ = l_x_.add_(1.0); l_x_ = None
            return (add_,)
"""
        check_count_and_graph(1, 2, 1, expected_graph)

        ff = torch.func.functionalize(f)
        ff_out = ff(t_clone)
        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(2, 4, 2, expected_graph)

        try:
            x = torch._to_functional_tensor(t_clone2)
            torch._mirror_autograd_meta_to(t_clone2, x)
            torch._enable_functionalization(reapply_views=False)
            aot_f_out = f(x)
        finally:
            torch._disable_functionalization()

        # frame count and op count are incremented due to re-compilation
        check_count_and_graph(3, 6, 3, expected_graph)

    def test_has_torch_function(self):
        class MyTensor:
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                if func is torch.max:
                    return torch.tensor(123)
                return func(*args, **kwargs)

        class LocalSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}
                return func(*args, **kwargs)

        def fn(x):
            return torch.overrides.has_torch_function_unary(
                x
            ), torch.overrides.has_torch_function_variadic(x)

        for test_class in [MyTensor, LocalSubclass]:
            x = test_class()
            ref0 = fn(x)
            ref1 = fn(4)
            opt_fn = torch._dynamo.optimize("eager")(fn)
            res0 = opt_fn(x)
            res1 = opt_fn(4)
            self.assertEqual(ref0, res0)
            self.assertEqual(ref1, res1)
    def test_wrapper_subclass_guards_on_inner_tensor(self):
        # Holds an inner tensor that has a distinct shape from the outer wrapper tensor.
        # Also adds additional guards on the inner tensor's sizes.
        # When the first input to an op has an inner dim-0 size > 3, we insert an extra add node.
        class DoubleSizeMaybeAddGeThreeTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, inner):
                # Double the outer-most dimension
                outer_shape = (inner.shape[0] * 2,) + inner.shape[1:]
                return torch.Tensor._make_wrapper_subclass(
                    # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
                    # Calling the overload that has kwargs causes us to go down the first overload path,
                    # which will **always** specialize sizes.
                    # We should probably eventually fix this so that the first overload can just handle dynamic shapes.
                    cls,
                    outer_shape,
                    inner.stride(),
                    None,
                    None,
                    inner.dtype,
                    inner.layout,
                    inner.device,
                    False,
                    inner.requires_grad,
                )

            def __init__(self, inner):
                self.inner_elem = inner

            def __tensor_flatten__(self):
                return ["inner_elem"], None

            @staticmethod
            def __tensor_unflatten__(inner_tensors, _, outer_size, outer_stride):
                return DoubleSizeMaybeAddGeThreeTensor(inner_tensors["inner_elem"])

            def __repr__(self):
                return f"DoubleSizeMaybeAddGeThreeTensor({repr(self.inner_elem)})"

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                args_inner = torch.utils._pytree.tree_map_only(
                    DoubleSizeMaybeAddGeThreeTensor, lambda x: x.inner_elem, args
                )
                out_inner = func(*args_inner, **kwargs)

                # Add guards on the inner tensor's sizes
                if args_inner[0].shape[0] > 3:
                    out_inner += 2

                return DoubleSizeMaybeAddGeThreeTensor(out_inner)

        curr_var_to_val = None
        curr_var_to_sources = None
        guards = None

        def backend(gm, args):
            context = torch._guards.TracingContext.get()

            # Grab info on sources and guards from the shapeenv
            nonlocal curr_var_to_val
            nonlocal curr_var_to_sources
            nonlocal guards

            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]
            curr_var_to_val = {
                str(k): v for k, v in context.fake_mode.shape_env.var_to_val.items()
            }
            curr_var_to_sources = {
                str(k): v[0].name()
                for k, v in context.fake_mode.shape_env.var_to_sources.items()
            }
            return gm

        @torch.compile(backend=backend)
        def fn(x):
            if x.shape[0] < 10:
                return torch.mul(x, x)
            else:
                return torch.div(x, x)

        inp = torch.ones(4, 4)

        x = DoubleSizeMaybeAddGeThreeTensor(inp)
        torch._dynamo.mark_dynamic(x, 0)
        res = fn(x)

        # During fakeifying, we end up allocating a separate symint
        # for the outer and inner tensor (in this test, s0 is unused).
        expected_var_to_val = {
            "s0": 8,
            "s1": 4,
        }
        expected_var_to_sources = {
            "s0": "L['x'].size()[0]",
            "s1": "L['x'].inner_elem.size()[0]",
        }
        self.assertEqual(curr_var_to_val, expected_var_to_val)
        self.assertEqual(curr_var_to_sources, expected_var_to_sources)
        self.assertExpectedInline(
            "\n".join(guards),
            """\
Eq(2*s1, s0)
2*s1 < 10
s1 > 3""",
        )

    def test_wrapper_subclass_with_same_sized_inner_tensor(self):
        # shouldn't recompile for different sizes when dynamic=True
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(7))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=True))

        # should recompile for different data size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # avoid recompile using manual mark_dynamic() for different data size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(6))
        # NB: mark_dynamic() on outer tensor should translate to inner tensors of the same size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(3, 5), torch.randn(6))
        self.assertFalse(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_wrapper_subclass_with_differently_sized_inner_tensor(self):
        # should recompile for different scale size when dynamic=False
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

        # still recompiles using manual mark_dynamic() on outer for different scale size
        sub1 = ScaledTensor(torch.randn(2, 4), torch.randn(3))
        # NB: mark_dynamic() on outer tensor doesn't translate to inner tensors of different size
        torch._dynamo.mark_dynamic(sub1, 0)
        torch._dynamo.mark_dynamic(sub1, 1)
        sub2 = ScaledTensor(torch.randn(2, 4), torch.randn(5))
        self.assertTrue(_recompiles_for_inputs(func, (sub1,), (sub2,), dynamic=False))

    def test_torch_dispatch_subclass_guard_recompile(self):
        x = torch.ones(2, 2)
        x_two = TwoTensor(x.clone(), x.clone())

        def fn(w):
            return torch.add(w, 1.0)

        fn_opt = torch.compile(backend="eager")(fn)

        ref = fn(x_two)
        res = fn_opt(x_two)
        self.assertEqual(ref, res)

        # ensure no recompilation on same input type
        with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
            fn_opt(TwoTensor(x + 1, x + 2))

        # recompile!
        ref = fn(x)
        res = fn_opt(x)
        self.assertEqual(ref, res)

    def test_torch_function_subclass_survives_into_aot_autograd(self):
        # If you have a tensor subclass that relies on dispatch into the same op
        # without unwrapping and calling torch._C.DisableTorchFunctionSubclass(),
        # the torch function-ness will survive into AOTAutograd. Today, NestedTensor
        # actually relies on this behavior! Because that torch function logic
        # runs during AOTAutograd, this test checks that no logic below relies on
        # torch function behavior that gets unexpectedly disabled after we
        # redispatch from the subclass's torch function.
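        # The SubTensor below is a minimal sketch of that pattern: __torch_function__
        # redispatches under DisableTorchFunctionSubclass, while __torch_dispatch__
        # unwraps the inner tensor, runs the op, and rewraps the outputs.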
        class SubTensor(torch.Tensor):
            @staticmethod
            def __new__(cls, t):
                return torch.Tensor._make_wrapper_subclass(
                    cls,
                    t.shape,
                    t.stride(),
                    t.storage_offset(),
                    torch.contiguous_format,
                    t.dtype,
                    torch.strided,
                    t.device,
                    False,
                    t.requires_grad,
                    "sizes",
                    False,
                    False,
                    None,
                )

            def __init__(self, t):
                super().__init__()
                self._t = t

            def __tensor_flatten__(self):
                return ["_t"], {}

            @staticmethod
            def __tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride):
                t = inner_tensors["_t"]
                return SubTensor(t)

            def __repr__(self):
                return f"SubTensor({self._t})"

            @classmethod
            def __torch_function__(cls, func, types, args=(), kwargs=None):
                if kwargs is None:
                    kwargs = {}

                with torch._C.DisableTorchFunctionSubclass():
                    return func(*args, **kwargs)

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                kwargs = {} if kwargs is None else kwargs
                new_args = pytree.tree_map_only(SubTensor, lambda s: s._t, args)
                output = func(*new_args, **kwargs)

                output = pytree.tree_map_only(
                    torch.Tensor, lambda t: SubTensor(t), output
                )
                return output

        @torch.compile(dynamic=True)
        def f(x):
            return x.unflatten(-1, [2, 5])

        s = SubTensor(torch.randn(3, 10))
        f(s)

    def test_recompile_with_symbool_inputs(self):
        def f(pred: bool):
            if pred:
                return torch.ones([3, 4])
            else:
                return torch.ones([4, 3])

        def test_recompilation(
            f, x, sizes, exp_graphs, exp_frame_count, exp_shape_env_guards
        ):
            torch._dynamo.reset()
            shape_env = ShapeEnv()
            backend = torch._dynamo.testing.EagerAndRecordGraphs()
            cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
            f_cond = torch.compile(f, backend=cnt, fullgraph=True)
            with torch._subclasses.fake_tensor.FakeTensorMode(
                shape_env=shape_env
            ) as fake_mode:
                fake_inp = fake_mode.from_tensor(
                    x,
                    symbolic_context=StatelessSymbolicContext(
                        dynamic_sizes=[DimDynamic.DYNAMIC for i in range(x.dim())]
                    ),
                )
                for i, size in enumerate(sizes):
                    pred = fake_inp.size(0) == size
                    f_cond(pred)
                    actual = normalize_gm(
                        backend.graphs[exp_frame_count[i] - 1].print_readable(
                            print_output=False
                        )
                    )
                    actual_guard_str = [str(guard.expr) for guard in shape_env.guards]
                    self.assertExpectedInline(actual, exp_graphs[i])
                    self.assertEqual(cnt.frame_count, exp_frame_count[i])
                    self.assertEqual(actual_guard_str, exp_shape_env_guards[i])

        true_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        ones = torch.ones([3, 4])
        return (ones,)
"""
        false_graph = """\
class GraphModule(torch.nn.Module):
    def forward(self):
        ones = torch.ones([4, 3])
        return (ones,)
"""
        test_recompilation(
            f,
            torch.randn([3, 4]),
            [3, 3, 4, 5],
            exp_graphs=[true_graph, true_graph, false_graph, false_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
                ["Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)"],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                ],
                [
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 4)), (0, True)), 1)",
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                ],
            ],
        )

        test_recompilation(
            f,
            torch.randn([3, 4]),
            [4, 5, 3, 3],
            exp_graphs=[false_graph, false_graph, true_graph, true_graph],
            exp_frame_count=[1, 1, 2, 2],
            exp_shape_env_guards=[
                [],
                # s0 is specialized and guarded in outer shape_env when dynamo checks the guards
                ["Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)"],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
                [
                    "Ne(Piecewise((1, Eq(s0, 5)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                    "Eq(Piecewise((1, Eq(s0, 3)), (0, True)), 1)",
                ],
            ],
        )

    def test_wrapper_subclass_dynamo_attribute_access_on_intermediate(self):
        def f(x_subclass):
            tmp_subclass = torch.add(x, 1)
            return torch.mul(tmp_subclass._scale, tmp_subclass._constant)

        x = ScaledTensor(torch.randn(2, 4), torch.randn(3), constant=2)
        out_ref = f(x)
        out_test = torch.compile(f, backend="aot_eager", fullgraph=True)(x)
        self.assertEqual(out_ref, out_test)

    def test_support_bases(self):
        import abc

        import torch.fx._symbolic_trace

        class Meta(abc.ABCMeta, torch.fx._symbolic_trace.ProxyableClassMeta):
            def __new__(cls, name, bases, dct):
                x = super().__new__(cls, name, bases, dct)
                x.attr = 100
                return x

        class Multistreamable(abc.ABC):  # noqa: B024
            pass

        class Foo(Multistreamable, metaclass=Meta):
            pass

        @torch.compile(backend="eager", fullgraph=True)
        def f(x):
            typ = type(Foo())
            typ.__bases__
            return typ.__bases__

        self.assertEqual(f(torch.randn(1)), (Multistreamable,))

    @parametrize("dynamic", [False, True])
    def test_subclass_views(self, dynamic):
        def _get_views(t):
            # Note that any closed-over SymInts will be symbolicized during fake-ification.
            yield t.narrow(dim=-1, start=3, length=8)
            yield t.split(5, -1)
            yield t.split_with_sizes([9, 6], -1)
            yield t.unsqueeze(-1).expand(4, 15, 10)
            yield t.select(-1, 6)
            yield t[2:3, 5:9]

        def f(x):
            return x * 2

        compiled_f = torch.compile(
            f, backend="aot_eager", fullgraph=True, dynamic=dynamic
        )

        # Take a view of a subclass to pass as input.
        t = TwoTensor(torch.randn(4, 15), torch.randn(4, 15))
        for view in _get_views(t):
            out_ref = f(view)
            out_test = compiled_f(view)
            self.assertEqual(out_ref, out_test)


instantiate_parametrized_tests(SubclassTests)


class TestNestedTensor(torch._dynamo.test_case.TestCase):
    def _get_jagged_tensor(self, nested_size, offsets, requires_grad=True):
        # Makes a jagged tensor with N constituent tensors with size
        # as specified ((S0, S1, S2), D)
        D = nested_size[1]
        out = []
        for s in nested_size[0]:
            out.append(
                torch.randn(s, D, requires_grad=requires_grad, dtype=torch.float64)
            )
        return jagged_from_list(out, offsets)

    def _get_nc_jagged_tensor(self, inner_dim, starts, lengths, requires_grad=True):
        # Makes a jagged tensor with N constituent tensors with size
        # as specified ((S0, S1, S2), D)
        max_dim = (starts + lengths).max()
        values_tensor = torch.randn(
            starts.shape[0],
            max_dim.item(),
            inner_dim,
            requires_grad=requires_grad,
            dtype=torch.float64,
        )
        return jagged_from_tensor_and_lengths(values_tensor, starts, lengths)

    def _check_recompiles(self, fn, inputs1, inputs2, expected_recompiles):
        actual_recompiles = _recompiles_for_inputs(fn, inputs1, inputs2)
        self.assertEqual(actual_recompiles, expected_recompiles)

    def test_unary_does_not_recompile(self):
        nt1, _ = self._get_jagged_tensor(((2, 3, 4), 3), None)
        nt2, _ = self._get_jagged_tensor(((3, 4, 5, 6), 4), None)
        self._check_recompiles(lambda nt1: nt1.sin(), (nt1,), (nt2,), False)

    def test_binary_does_not_recompile(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
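        # Illustration (assumption): with inner dim 5, the hinted NT shape is
        # (3, j0, 5), so duck sizing assigns distinct symbols to the batch and
        # last dims instead of reusing one symbol for both.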
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, offsets = self._get_jagged_tensor(((3, 4, 5), 4), None)
        nt4, _ = self._get_jagged_tensor(((3, 4, 5), 4), offsets)
        self._check_recompiles(binary, (nt1, nt2), (nt3, nt4), False)

    def test_binary_recompiles(self):
        def binary(nt1, nt2):
            if nt1.shape == nt2.shape:
                return nt1 + nt2
            else:
                return nt1.sin()

        # Binary recompiles because singleton ints no longer match
        nt1, offsets = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 4), 5), offsets)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        self._check_recompiles(binary, (nt1, nt2), (nt1, nt3), True)

    # TODO: cannot parametrize this test class with device for some reason
    def _test_autograd(self, backend):
        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64)
        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64)
        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64)

        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        # TODO: Switch to public API when it exists
        nt2, _ = jagged_from_list([a, b, c], nt.offsets())

        def fn1(nt1, nt2):
            return (nt1 + nt2).sin().cos()

        compiled_f = torch.compile(fn1, fullgraph=True, backend=backend, dynamic=True)
        out = compiled_f(nt, nt2)
        out_buffer = out.values()
        ga, gb, gc = torch.autograd.grad(out_buffer.sum(), (a, b, c))

        out_ref = fn1(nt, nt2)
        out_buffer_ref = out_ref.values()
        ga_ref, gb_ref, gc_ref = torch.autograd.grad(out_buffer_ref.sum(), (a, b, c))

        self.assertTrue(torch.allclose(ga, ga_ref))
        self.assertTrue(torch.allclose(gb, gb_ref))
        self.assertTrue(torch.allclose(gc, gc_ref))

    def test_basic_autograd(self):
        self._test_autograd("aot_eager")

    @requires_cuda
    def test_basic_autograd_inductor(self):
        self._test_autograd("inductor")

    def test_subclass_with_mutation_in_graph(self):
        # In this graph, we have an in-graph mutation, i.e. a mutation that is allowed
        # to remain in the graph. Normally this is allowed, but it's not allowed if
        # the graph handles subclasses at all.
        # Whether the mutation is allowed or not allowed in the graph alters the number
        # of outputs from the forward graph. Previously, a bug in this handling meant
        # that sometimes the expected number and actual number of outputs from the
        # joint graph did not match, causing assertion failures.
        def fn(x, y):
            z = x.sin()
            y.sin_()
            return z.cos(), y.cos()

        fn_c = torch.compile(fn, backend="inductor")

        values = [torch.rand((i, 8), requires_grad=True) for i in range(1, 6)]
        values_copy = [x.detach().clone().requires_grad_(True) for x in values]

        nt, offsets = jagged_from_list(values, None)
        nt_copy, offsets = jagged_from_list(values_copy, offsets)
        y = torch.rand((4, 8))
        y_copy = y.clone()

        ret = fn_c(nt, y)[0]
        ref = fn(nt_copy, y_copy)[0]

        self.assertEqual(ret.values(), ref.values())

        ret.values().sum().backward()
        ref.values().sum().backward()
        for ref_v, res_v in zip(values_copy, values):
            self.assertEqual(ref_v.grad, res_v.grad)

    def test_unbind(self):
        # NB: If we have shape e.g. (3, j0, 3), duck sizing will give us (s0, s1, s0).
        # This causes a recompile later on when it realizes the batch and last dim
        # should not always be equal. To avoid that, we use (3, j0, 5) here.
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)
        nt2, _ = self._get_jagged_tensor(((2, 3, 5), 2), None)
        nt3, _ = self._get_jagged_tensor(((2, 3, 4, 5), 3), None)

        def fn(x):
            return x.unbind()

        compiled_f = torch.compile(fn, fullgraph=True, backend="eager", dynamic=True)
        out = compiled_f(nt)

        out_ref = fn(nt)

        # correctness
        self.assertEqual(len(out), len(out_ref))
        for x, x_ref in zip(out, out_ref):
            self.assertTrue(torch.allclose(x, x_ref))

        # We specialize on the length of offsets, i.e. (1) we recompile if the
        # length of the offsets is different. (2) we don't recompile if the
        # length of the offsets is the same, even if the sizes of the constituent
        # tensors are different.
        self._check_recompiles(fn, (nt,), (nt2,), False)
        self._check_recompiles(fn, (nt,), (nt3,), True)

    def test_inline_nested_tensor_from_jagged(self):
        nt, _ = self._get_jagged_tensor(((2, 3, 4), 5), None)

        def fn(x):
            return torch.nested.nested_tensor_from_jagged(x.values() * 2, x.offsets())

        torch.compile(fn, fullgraph=True, backend="aot_eager")(nt)

    def _get_views(self):
        # Test all cases with both an NT base and a dense base
        # Subclass -> Subclass
        # Dense -> Subclass
        for base_is_nt in [False, True]:
            # There are three cases to consider here based on the logic in
            # meta_utils.py
            #
            # (1) basic case:
            # view is not a leaf and has the same requires grad as its base
            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)
            x = x.clone() if base_is_nt else x
            self.assertEqual(x.is_leaf, False)
            yield x.unsqueeze(-1)

            # (2) leaf view case:
            # the view has to be a leaf (w/ requires_grad True or requires_grad False)
            # base w/ requires_grad True or requires_grad False
            for requires_grad_1, requires_grad_2 in itertools.product(
                [True, False], repeat=2
            ):
                x, _ = self._get_jagged_tensor(
                    ((2, 3, 4), 3), None, requires_grad=requires_grad_1
                )
                x = x.clone() if base_is_nt else x
                with torch.no_grad():
                    x_view = x.unsqueeze(-1)
                    # The issue is this doesn't quite work
                    x_view.requires_grad_(requires_grad_2)
                yield x_view

            # (3) obscure case:
            # view is not a leaf (implies requires_grad True)
            # base w/ requires_grad False
            x, _ = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=False)
            x = x.clone() if base_is_nt else x
            # intermediate leaf view
            with torch.no_grad():
                x_view = x.unsqueeze(-1)
            x_view.requires_grad_(True)
            x_view_view = x_view.unsqueeze(-1)
            yield x_view_view

        # Subclass -> Dense
        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        yield x.values()

        # Dense -> Subclass -> Dense -> Subclass
        values = torch.randn(10, 5)
        offsets = torch.tensor([0, 3, 6, 10])
        offsets2 = offsets.clone().detach()
        yield nested_view_from_values_offsets(
            nested_view_from_values_offsets(values, offsets).values(), offsets
        )

    def _input_view_test(self, nt_view):
        def fn(x):
            return x.sin()

        out_ref = fn(nt_view)
        torch._dynamo.reset()
        compile_fn = torch.compile(
            fn, fullgraph=True, backend="aot_eager", dynamic=True
        )
        out = compile_fn(nt_view)

        # Check metadata and values are correct
        self.assertTrue(out.size() == out_ref.size())
        self.assertTrue(out.stride() == out_ref.stride())
        if out.is_nested:
            self.assertTrue(torch.allclose(out.values(), out_ref.values()))
        else:
            self.assertTrue(torch.allclose(out, out_ref))

        # Check that no upper/lower bound guards are incurred
        def backend(gm, args):
            context = torch._guards.TracingContext.get()
            guards = [str(g.expr) for g in context.fake_mode.shape_env.guards]

            # varies based on the type of view
            guard_str = "\n".join(guards)
            if isinstance(nt_view._base, NestedTensor):
                self.assertExpectedInline(guard_str, """Eq(s3 - 1, s0)""")
            else:
                self.assertExpectedInline(guard_str, """""")
            return gm

        torch._dynamo.reset()
        compile_fn = torch.compile(fn, fullgraph=True, backend=backend, dynamic=True)
        out = compile_fn(nt_view)

    def test_inputs_to_compiled_fn_are_views(self):
        for nt_view in self._get_views():
            self._input_view_test(nt_view)

    def test_subclass_gives_static_shapes_when_dynamic_false(self):
        def check_graph(gm, *args):
            first_node_example_val = next(iter(gm.graph.nodes)).meta["example_value"]
            # We compiled with dynamic=False, expect no SymInt sizes on our placeholders
            self.assertTrue(
                all(isinstance(x, int) for x in first_node_example_val.shape)
            )
            return gm

        @torch.compile(backend=check_graph, dynamic=False)
        def f(x):
            return x + 1

        x_inner = torch.ones(4)
        x = TwoTensor(x_inner, x_inner)
        x_view = x.view(2, 2)
        out = f(x_view)

    # NJT1 -> Dense -> NJT2 -> Dense view
    # During view replay, the Dense -> NJT2 part will construct an intermediate,
    # symbolically-sized NJT that is immediately deconstructed to return the final dense
    # view. To construct this intermediate properly, we need the associated nested int
    # to be symbolic. This view is expected to fail compilation until symbolic nested ints
    # are cached onto fake offsets to solve this problem.
    @unittest.expectedFailure
    def test_subclass_dense_subclass_dense_view(self):
        x = self._get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone()
        offsets2 = x.offsets().clone().detach()
        nt_view = nested_view_from_values_offsets(x.values(), offsets2).values()
        self._input_view_test(nt_view)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()