Turn on capture_dynamic_output_shape_ops when fullgraph=True (#163123)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163123
Approved by: https://github.com/laithsakka
ghstack dependencies: #163121
bobrenjc93 2025-09-18 11:47:04 -07:00 committed by PyTorch MergeBot
parent 7dcb568c8f
commit ed3438ff13
4 changed files with 17 additions and 24 deletions
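In short: operators such as aten.nonzero produce outputs whose shape depends on the values in the input tensor, so Dynamo previously raised a "Dynamic shape operator" graph break for them unless torch._dynamo.config.capture_dynamic_output_shape_ops = True was set explicitly. With this change, fullgraph=True implies that flag. A minimal sketch of the new behavior, mirroring the expecttest removed below:

    import torch

    def fn():
        # nonzero's output shape depends on the input data, so it cannot be
        # captured with purely static shapes.
        return torch.nonzero(torch.rand([10, 10]))

    # Previously this required setting
    #   torch._dynamo.config.capture_dynamic_output_shape_ops = True
    # by hand; with this change, fullgraph=True turns the flag on, so the call
    # below compiles into a single graph instead of raising Unsupported.
    torch.compile(fn, fullgraph=True)()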

@@ -48,27 +48,6 @@ class GenericCtxMgr:
 class ErrorMessagesTest(LoggingTestCase):
-    def test_dynamic_shape_operator(self):
-        def fn():
-            return torch.nonzero(torch.rand([10, 10]))
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
-            """\
-Dynamic shape operator
-Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
-Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
-Developer debug context: aten.nonzero.default
-For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html
-from user code:
-  File "test_error_messages.py", line N, in fn
-    return torch.nonzero(torch.rand([10, 10]))""",
-        )
-
     def test_dynamic_shape_operator_no_meta_kernel(self):
         def fn():
             return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))

@@ -13362,6 +13362,18 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         # We expect to no longer raise here
         foo(torch.tensor(2.0))
 
+    def test_full_graph_capture_dynamic_output_shape_ops(self):
+        def fn(x):
+            nz = torch.nonzero(x)
+            squared = nz * nz
+            sliced = torch.ops.aten.slice.Tensor(squared, dim=1, start=-2, end=None)
+            view = sliced.unsqueeze(dim=0)
+            return view.squeeze(dim=0)
+
+        example_inputs = (torch.randn(1, 1, 1, 1),)
+        # we expect to no longer raise here
+        torch.compile(fn, fullgraph=True)(*example_inputs)
+
     def test_dynamic_float_scalar_tensor_coersion(self):
         # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
         class Foo:

@@ -2507,7 +2507,8 @@ def compile(
     fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
         in the function that it will optimize. If True, then we require that the entire function be
         capturable into a single graph. If this is not possible (that is, if there are graph breaks),
-        then this will raise an error.
+        then this will raise an error. This also opts into unbacked semantics; notably, it will turn on
+        capture_scalar_outputs and capture_dynamic_output_shape_ops by default.
     dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
         to generate a kernel that is as dynamic as possible to avoid recompilations when
         sizes change. This may not always work as some operations/optimizations will

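For the capture_scalar_outputs part (enabled for fullgraph=True by the dependent PR #163121), a hedged sketch of what that flag permits; the toy function below is illustrative and not code from this PR:

    import torch

    def fn(x):
        # .item() produces a data-dependent Python int (an unbacked SymInt
        # when capture_scalar_outputs is enabled), which is then used as a size.
        n = (x > 0).sum().item()
        return torch.zeros(n)

    # With fullgraph=True, capture_scalar_outputs is implied, so .item() is
    # traced into the graph rather than triggering a graph break.
    torch.compile(fn, fullgraph=True)(torch.randn(8))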
@@ -488,9 +488,10 @@ class OutputGraph(OutputGraphGuardsState):
             # TrackedFake instances may have its metadata changed throughout
             # the program execution.
             tracked_fakes=self.tracked_fakes,
-            # We want to allow scalar outputs when fullgraph=True
+            # We want to allow capture scalar outputs and allow_dynamic_output_shape_ops when fullgraph=True
             allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
-            allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
+            allow_dynamic_output_shape_ops=one_graph
+            or config.capture_dynamic_output_shape_ops,
             prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
             co_fields=self.co_fields,
         )
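Here one_graph corresponds to fullgraph=True, so the ShapeEnv allows dynamic-output-shape ops either when fullgraph is requested or when the config flag is set. Outside of fullgraph mode the pre-existing opt-in still applies; a small sketch of that path (not new behavior from this PR):

    import torch
    import torch._dynamo

    # Without fullgraph=True, the config flag remains the way to opt in.
    torch._dynamo.config.capture_dynamic_output_shape_ops = True

    @torch.compile(backend="eager")
    def fn(x):
        return torch.nonzero(x)

    fn(torch.rand(4, 4))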