diff --git a/test/dynamo/test_error_messages.py b/test/dynamo/test_error_messages.py
index 26e8176473e..172ced2a58a 100644
--- a/test/dynamo/test_error_messages.py
+++ b/test/dynamo/test_error_messages.py
@@ -48,27 +48,6 @@ class GenericCtxMgr:
 
 
 class ErrorMessagesTest(LoggingTestCase):
-    def test_dynamic_shape_operator(self):
-        def fn():
-            return torch.nonzero(torch.rand([10, 10]))
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
-            """\
-Dynamic shape operator
-  Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
-  Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
-
-  Developer debug context: aten.nonzero.default
-
- For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html
-
-from user code:
-  File "test_error_messages.py", line N, in fn
-    return torch.nonzero(torch.rand([10, 10]))""",
-        )
-
     def test_dynamic_shape_operator_no_meta_kernel(self):
         def fn():
             return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index dac76803798..648848420a1 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -13362,6 +13362,18 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         # We expect to no longer raise here
         foo(torch.tensor(2.0))
 
+    def test_full_graph_capture_dynamic_output_shape_ops(self):
+        def fn(x):
+            nz = torch.nonzero(x)
+            squared = nz * nz
+            sliced = torch.ops.aten.slice.Tensor(squared, dim=1, start=-2, end=None)
+            view = sliced.unsqueeze(dim=0)
+            return view.squeeze(dim=0)
+
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        # We expect to no longer raise here
+        torch.compile(fn, fullgraph=True)(*example_inputs)
+
     def test_dynamic_float_scalar_tensor_coersion(self):
         # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
         class Foo:
diff --git a/torch/__init__.py b/torch/__init__.py
index dee45cfe422..08dee062435 100644
--- a/torch/__init__.py
+++ b/torch/__init__.py
@@ -2507,7 +2507,8 @@ def compile(
        fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
            in the function that it will optimize. If True, then we require that the entire function be
            capturable into a single graph. If this is not possible (that is, if there are graph breaks),
-           then this will raise an error.
+           then this will raise an error. This also opts into unbacked semantics; notably, it turns on
+           capture_scalar_outputs and capture_dynamic_output_shape_ops by default.
        dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
            to generate a kernel that is as dynamic as possible to avoid recompilations when
            sizes change. This may not always work as some operations/optimizations will
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 12b88c0d6b1..8a162942fe7 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -488,9 +488,10 @@ class OutputGraph(OutputGraphGuardsState):
                 # TrackedFake instances may have its metadata changed throughout
                 # the program execution.
                 tracked_fakes=self.tracked_fakes,
-                # We want to allow scalar outputs when fullgraph=True
+                # We want to allow scalar outputs and dynamic output shape ops when fullgraph=True
                 allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
-                allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
+                allow_dynamic_output_shape_ops=one_graph
+                or config.capture_dynamic_output_shape_ops,
                 prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
                 co_fields=self.co_fields,
             )
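
Note: a minimal usage sketch of the behavior this patch targets (the function and shapes below are illustrative, not taken from the test suite). With this change, fullgraph=True opts into capture_dynamic_output_shape_ops, so a data-dependent-shape op such as torch.nonzero should compile into a single graph without setting the config flag manually:

    import torch

    def fn(x):
        # torch.nonzero has a data-dependent output shape; without
        # capture_dynamic_output_shape_ops this used to graph break,
        # which is an error under fullgraph=True.
        nz = torch.nonzero(x)
        return nz * 2

    # Expected to trace into one graph on a build that includes this change.
    out = torch.compile(fn, fullgraph=True)(torch.randn(4, 4))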