mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Turn on capture_dynamic_output_shape_ops when fullgraph=True (#163123)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163123 Approved by: https://github.com/laithsakka ghstack dependencies: #163121
This commit is contained in:
parent
7dcb568c8f
commit
ed3438ff13
|
|
@ -48,27 +48,6 @@ class GenericCtxMgr:
|
|||
|
||||
|
||||
class ErrorMessagesTest(LoggingTestCase):
|
||||
def test_dynamic_shape_operator(self):
|
||||
def fn():
|
||||
return torch.nonzero(torch.rand([10, 10]))
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(),
|
||||
"""\
|
||||
Dynamic shape operator
|
||||
Explanation: Operator `aten.nonzero.default`'s output shape depends on input Tensor data.
|
||||
Hint: Enable tracing of dynamic shape operators with `torch._dynamo.config.capture_dynamic_output_shape_ops = True`
|
||||
|
||||
Developer debug context: aten.nonzero.default
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0036.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
return torch.nonzero(torch.rand([10, 10]))""",
|
||||
)
|
||||
|
||||
def test_dynamic_shape_operator_no_meta_kernel(self):
|
||||
def fn():
|
||||
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))
|
||||
|
|
|
|||
|
|
@ -13362,6 +13362,18 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
|||
# We expect to no longer raise here
|
||||
foo(torch.tensor(2.0))
|
||||
|
||||
def test_full_graph_capture_dynamic_output_shape_ops(self):
|
||||
def fn(x):
|
||||
nz = torch.nonzero(x)
|
||||
squared = nz * nz
|
||||
sliced = torch.ops.aten.slice.Tensor(squared, dim=1, start=-2, end=None)
|
||||
view = sliced.unsqueeze(dim=0)
|
||||
return view.squeeze(dim=0)
|
||||
|
||||
example_inputs = (torch.randn(1, 1, 1, 1),)
|
||||
# we expect to no longer raise here
|
||||
torch.compile(fn, fullgraph=True)(*example_inputs)
|
||||
|
||||
def test_dynamic_float_scalar_tensor_coersion(self):
|
||||
# Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
|
||||
class Foo:
|
||||
|
|
|
|||
|
|
@ -2507,7 +2507,8 @@ def compile(
|
|||
fullgraph (bool): If False (default), torch.compile attempts to discover compilable regions
|
||||
in the function that it will optimize. If True, then we require that the entire function be
|
||||
capturable into a single graph. If this is not possible (that is, if there are graph breaks),
|
||||
then this will raise an error.
|
||||
then this will raise an error. This also opts into unbacked semantics, notably it will turn on
|
||||
capture_scalar_outputs and capture_dynamic_output_shape_ops on by default.
|
||||
dynamic (bool or None): Use dynamic shape tracing. When this is True, we will up-front attempt
|
||||
to generate a kernel that is as dynamic as possible to avoid recompilations when
|
||||
sizes change. This may not always work as some operations/optimizations will
|
||||
|
|
|
|||
|
|
@ -488,9 +488,10 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
# TrackedFake instances may have its metadata changed throughout
|
||||
# the program execution.
|
||||
tracked_fakes=self.tracked_fakes,
|
||||
# We want to allow scalar outputs when fullgraph=True
|
||||
# We want to allow capture scalar outputs and allow_dynamic_output_shape_ops when fullgraph=True
|
||||
allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
|
||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||
allow_dynamic_output_shape_ops=one_graph
|
||||
or config.capture_dynamic_output_shape_ops,
|
||||
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
|
||||
co_fields=self.co_fields,
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user