Turn on capture_scalar_outputs when fullgraph=True (#163121)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/163121
Approved by: https://github.com/laithsakka
Author: bobrenjc93
Date: 2025-09-17 21:11:37 -07:00
Committed by: PyTorch MergeBot
Parent: bb7c9a2d41
Commit: 7dcb568c8f
6 changed files with 17 additions and 26 deletions
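In short: `torch.compile(..., fullgraph=True)` now implies `capture_scalar_outputs`, so data-dependent calls such as `Tensor.item()` are traced into the graph instead of raising an Unsupported graph-break error. A minimal sketch of the user-visible change (a standalone repro mirroring the new test added below, not part of the diff):

    import torch

    @torch.compile(fullgraph=True, backend="eager")
    def f(a):
        # Before this commit: raised Unsupported ("Tensor.item() call with
        # capture_scalar_outputs=False") unless the config flag was set.
        # After: a.item() is captured and tracing proceeds in one graph.
        return torch.randn(5) * a.item()

    f(torch.tensor(2.0))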

test/dynamo/test_error_messages.py

@@ -91,29 +91,6 @@ from user code:
         return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
         )
 
-    def test_data_dependent_operator(self):
-        def fn(x):
-            return x.item()
-
-        self.assertExpectedInlineMunged(
-            Unsupported,
-            lambda: torch.compile(fn, backend="eager", fullgraph=True)(
-                torch.Tensor([1])
-            ),
-            """\
-Unsupported Tensor.item() call with capture_scalar_outputs=False
-  Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
-  Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.
-  Developer debug context: call_method TensorVariable() item () {}
- For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html
-
-from user code:
-  File "test_error_messages.py", line N, in fn
-    return x.item()""",
-        )
-
     def test_data_dependent_operator2(self):
         def fn(x):
             return torch.equal(x, x)
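The deleted test pinned the old error message for `.item()` under fullgraph=True; since fullgraph now implies scalar-output capture, that combination can no longer occur. The graph break itself still exists on the non-fullgraph path. A small sketch, assuming `capture_scalar_outputs` is left at its default of False:

    import torch

    def fn(x):
        return x.item()

    # Without fullgraph=True and with capture_scalar_outputs=False (the
    # default), .item() still causes a graph break, but Dynamo falls back to
    # eager around it instead of raising, because fullgraph is not requested.
    torch.compile(fn, backend="eager")(torch.tensor([1.0]))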

test/dynamo/test_misc.py

@@ -13354,6 +13354,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         y = torch.tensor(5)
         f(x, y)
 
+    def test_full_graph_capture_scalar_outputs(self):
+        @torch.compile(fullgraph=True)
+        def foo(a):
+            return torch.randn(5) * a.item()
+
+        # We expect this to no longer raise
+        foo(torch.tensor(2.0))
+
     def test_dynamic_float_scalar_tensor_coersion(self):
         # Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
         class Foo:
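The new test exercises the headline behavior directly. For non-fullgraph compilation, the pre-existing opt-ins from the (now-removed) hint text remain available; a sketch of the explicit route:

    import torch

    # Explicit opt-in without fullgraph=True: set the config flag directly
    # (or export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 in the environment).
    torch._dynamo.config.capture_scalar_outputs = True

    @torch.compile(backend="eager")
    def foo(a):
        return torch.randn(5) * a.item()

    foo(torch.tensor(2.0))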

test/functorch/test_control_flow.py

@@ -2007,7 +2007,6 @@ def forward(self, pred_1, x_1):
     # Fails with: AssertionError: scan is not an OpOverload
     @unittest.skipIf(not SM70OrLater, "triton")
     @requires_cuda
-    @unittest.expectedFailure
     def test_scan_associative_scan(self):
         combine_mode = "generic"
         compile_mode_scan = "compile"
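Removing `@unittest.expectedFailure` is required rather than cosmetic: a decorated test that starts passing is reported as an unexpected success, and since Python 3.4 that makes `TestResult.wasSuccessful()` return False. A minimal illustration:

    import unittest

    class Demo(unittest.TestCase):
        @unittest.expectedFailure
        def test_now_passing(self):
            self.assertTrue(True)  # passes -> reported as an unexpected success

    # An unexpected success makes TestResult.wasSuccessful() return False, so a
    # test that this commit fixes must also have the marker removed.
    unittest.main()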

torch/_dynamo/output_graph.py

@@ -431,6 +431,7 @@ class OutputGraph(OutputGraphGuardsState):
         f_code: CodeType,
         torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
         package: Optional["CompilePackage"],
+        one_graph: bool = False,
     ) -> None:
         super().__init__(
             local_scope,
@@ -487,7 +488,8 @@ class OutputGraph(OutputGraphGuardsState):
             # TrackedFake instances may have their metadata changed throughout
             # the program execution.
             tracked_fakes=self.tracked_fakes,
-            allow_scalar_outputs=config.capture_scalar_outputs,
+            # We want to allow scalar outputs when fullgraph=True
+            allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
             allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
             prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
             co_fields=self.co_fields,
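The plumbing: `OutputGraph` gains a `one_graph` parameter (Dynamo's internal name for `fullgraph=True`) and ORs it with the config flag when building the shape-environment settings; the next hunk forwards the flag from `InstructionTranslator`, which already receives `one_graph`. A reduced sketch of the pattern, with `FakeSettings` as a hypothetical stand-in for the real settings object:

    from dataclasses import dataclass

    @dataclass
    class FakeSettings:  # hypothetical stand-in, not the real Dynamo type
        allow_scalar_outputs: bool

    def make_settings(one_graph: bool, capture_scalar_outputs: bool) -> FakeSettings:
        # fullgraph (one_graph) implies scalar-output capture, per the hunk above
        return FakeSettings(allow_scalar_outputs=one_graph or capture_scalar_outputs)

    assert make_settings(True, False).allow_scalar_outputs       # fullgraph wins
    assert not make_settings(False, False).allow_scalar_outputs  # default unchanged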

torch/_dynamo/symbolic_convert.py

@@ -3847,6 +3847,7 @@ class InstructionTranslator(InstructionTranslatorBase):
                 global_scope=f_globals,
                 f_code=f_code,
                 torch_function_mode_stack=torch_function_mode_stack,
+                one_graph=one_graph,
                 package=package,
             ),
             instructions=instructions,

torch/_dynamo/variables/tensor.py
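The runtime check in `method_item` (hunk below) now consults the current translator's `one_graph` flag, so the graph break fires only when neither fullgraph nor the config flag enables scalar capture. The decision reduces to a two-input predicate; a sketch (a plain function, not the real API):

    def should_graph_break_on_item(one_graph: bool, capture_scalar_outputs: bool) -> bool:
        # item() is a graph break only when neither fullgraph=True nor the
        # capture_scalar_outputs config permits capturing scalar outputs
        return not one_graph and not capture_scalar_outputs

    assert not should_graph_break_on_item(one_graph=True, capture_scalar_outputs=False)
    assert should_graph_break_on_item(one_graph=False, capture_scalar_outputs=False)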

@@ -999,7 +999,11 @@ class TensorVariable(VariableTracker):
             return DataPtrVariable(self)
 
     def method_item(self, *args, **kwargs):
-        if not config.capture_scalar_outputs:
+        from ..symbolic_convert import InstructionTranslator
+
+        tx = InstructionTranslator.current_tx()
+        # We enable capture_scalar_outputs by default when fullgraph=True.
+        if not tx.one_graph and not config.capture_scalar_outputs:
             self._warn_capture_scalar_outputs()
             unimplemented_v2(
                 gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False",