mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Turn on capture_scalar_outputs when fullgraph=True (#163121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163121 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
bb7c9a2d41
commit
7dcb568c8f
|
|
@ -91,29 +91,6 @@ from user code:
|
|||
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
|
||||
)
|
||||
|
||||
def test_data_dependent_operator(self):
|
||||
def fn(x):
|
||||
return x.item()
|
||||
|
||||
self.assertExpectedInlineMunged(
|
||||
Unsupported,
|
||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
||||
torch.Tensor([1])
|
||||
),
|
||||
"""\
|
||||
Unsupported Tensor.item() call with capture_scalar_outputs=False
|
||||
Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
|
||||
Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.
|
||||
|
||||
Developer debug context: call_method TensorVariable() item () {}
|
||||
|
||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html
|
||||
|
||||
from user code:
|
||||
File "test_error_messages.py", line N, in fn
|
||||
return x.item()""",
|
||||
)
|
||||
|
||||
def test_data_dependent_operator2(self):
|
||||
def fn(x):
|
||||
return torch.equal(x, x)
|
||||
|
|
|
|||
|
|
@ -13354,6 +13354,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
|||
y = torch.tensor(5)
|
||||
f(x, y)
|
||||
|
||||
def test_full_graph_capture_scalar_outputs(self):
|
||||
@torch.compile(fullgraph=True)
|
||||
def foo(a):
|
||||
return torch.randn(5) * a.item()
|
||||
|
||||
# We expect to no longer raise here
|
||||
foo(torch.tensor(2.0))
|
||||
|
||||
def test_dynamic_float_scalar_tensor_coersion(self):
|
||||
# Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
|
||||
class Foo:
|
||||
|
|
|
|||
|
|
@ -2007,7 +2007,6 @@ def forward(self, pred_1, x_1):
|
|||
# Fails with: AssertionError: scan is not an OpOverload
|
||||
@unittest.skipIf(not SM70OrLater, "triton")
|
||||
@requires_cuda
|
||||
@unittest.expectedFailure
|
||||
def test_scan_associative_scan(self):
|
||||
combine_mode = "generic"
|
||||
compile_mode_scan = "compile"
|
||||
|
|
|
|||
|
|
@ -431,6 +431,7 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
f_code: CodeType,
|
||||
torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
|
||||
package: Optional["CompilePackage"],
|
||||
one_graph: bool = False,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
local_scope,
|
||||
|
|
@ -487,7 +488,8 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
# TrackedFake instances may have its metadata changed throughout
|
||||
# the program execution.
|
||||
tracked_fakes=self.tracked_fakes,
|
||||
allow_scalar_outputs=config.capture_scalar_outputs,
|
||||
# We want to allow scalar outputs when fullgraph=True
|
||||
allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
|
||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
|
||||
co_fields=self.co_fields,
|
||||
|
|
|
|||
|
|
@ -3847,6 +3847,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
global_scope=f_globals,
|
||||
f_code=f_code,
|
||||
torch_function_mode_stack=torch_function_mode_stack,
|
||||
one_graph=one_graph,
|
||||
package=package,
|
||||
),
|
||||
instructions=instructions,
|
||||
|
|
|
|||
|
|
@ -999,7 +999,11 @@ class TensorVariable(VariableTracker):
|
|||
return DataPtrVariable(self)
|
||||
|
||||
def method_item(self, *args, **kwargs):
|
||||
if not config.capture_scalar_outputs:
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
# We enable capture_scalar_outputs when full_graph=True by default.
|
||||
if not tx.one_graph and not config.capture_scalar_outputs:
|
||||
self._warn_capture_scalar_outputs()
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user