mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Turn on capture_scalar_outputs when fullgraph=True (#163121)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163121 Approved by: https://github.com/laithsakka
This commit is contained in:
parent
bb7c9a2d41
commit
7dcb568c8f
|
|
@ -91,29 +91,6 @@ from user code:
|
||||||
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
|
return torch.linalg.lstsq(torch.rand(10, 10), torch.rand(10, 10))""",
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_data_dependent_operator(self):
|
|
||||||
def fn(x):
|
|
||||||
return x.item()
|
|
||||||
|
|
||||||
self.assertExpectedInlineMunged(
|
|
||||||
Unsupported,
|
|
||||||
lambda: torch.compile(fn, backend="eager", fullgraph=True)(
|
|
||||||
torch.Tensor([1])
|
|
||||||
),
|
|
||||||
"""\
|
|
||||||
Unsupported Tensor.item() call with capture_scalar_outputs=False
|
|
||||||
Explanation: Dynamo does not support tracing `Tensor.item()` with config.capture_scalar_outputs=False.
|
|
||||||
Hint: Set `torch._dynamo.config.capture_scalar_outputs = True` or `export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1` to include these operations in the captured graph.
|
|
||||||
|
|
||||||
Developer debug context: call_method TensorVariable() item () {}
|
|
||||||
|
|
||||||
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0124.html
|
|
||||||
|
|
||||||
from user code:
|
|
||||||
File "test_error_messages.py", line N, in fn
|
|
||||||
return x.item()""",
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_data_dependent_operator2(self):
|
def test_data_dependent_operator2(self):
|
||||||
def fn(x):
|
def fn(x):
|
||||||
return torch.equal(x, x)
|
return torch.equal(x, x)
|
||||||
|
|
|
||||||
|
|
@ -13354,6 +13354,14 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
||||||
y = torch.tensor(5)
|
y = torch.tensor(5)
|
||||||
f(x, y)
|
f(x, y)
|
||||||
|
|
||||||
|
def test_full_graph_capture_scalar_outputs(self):
|
||||||
|
@torch.compile(fullgraph=True)
|
||||||
|
def foo(a):
|
||||||
|
return torch.randn(5) * a.item()
|
||||||
|
|
||||||
|
# We expect to no longer raise here
|
||||||
|
foo(torch.tensor(2.0))
|
||||||
|
|
||||||
def test_dynamic_float_scalar_tensor_coersion(self):
|
def test_dynamic_float_scalar_tensor_coersion(self):
|
||||||
# Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
|
# Minified version of https://github.com/pytorch/pytorch/issues/158376#issuecomment-3079591367
|
||||||
class Foo:
|
class Foo:
|
||||||
|
|
|
||||||
|
|
@ -2007,7 +2007,6 @@ def forward(self, pred_1, x_1):
|
||||||
# Fails with: AssertionError: scan is not an OpOverload
|
# Fails with: AssertionError: scan is not an OpOverload
|
||||||
@unittest.skipIf(not SM70OrLater, "triton")
|
@unittest.skipIf(not SM70OrLater, "triton")
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
@unittest.expectedFailure
|
|
||||||
def test_scan_associative_scan(self):
|
def test_scan_associative_scan(self):
|
||||||
combine_mode = "generic"
|
combine_mode = "generic"
|
||||||
compile_mode_scan = "compile"
|
compile_mode_scan = "compile"
|
||||||
|
|
|
||||||
|
|
@ -431,6 +431,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||||
f_code: CodeType,
|
f_code: CodeType,
|
||||||
torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
|
torch_function_mode_stack: list[torch.overrides.TorchFunctionMode],
|
||||||
package: Optional["CompilePackage"],
|
package: Optional["CompilePackage"],
|
||||||
|
one_graph: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
local_scope,
|
local_scope,
|
||||||
|
|
@ -487,7 +488,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||||
# TrackedFake instances may have its metadata changed throughout
|
# TrackedFake instances may have its metadata changed throughout
|
||||||
# the program execution.
|
# the program execution.
|
||||||
tracked_fakes=self.tracked_fakes,
|
tracked_fakes=self.tracked_fakes,
|
||||||
allow_scalar_outputs=config.capture_scalar_outputs,
|
# We want to allow scalar outputs when fullgraph=True
|
||||||
|
allow_scalar_outputs=one_graph or config.capture_scalar_outputs,
|
||||||
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
|
||||||
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
|
prefer_deferred_runtime_asserts_over_guards=config.prefer_deferred_runtime_asserts_over_guards,
|
||||||
co_fields=self.co_fields,
|
co_fields=self.co_fields,
|
||||||
|
|
|
||||||
|
|
@ -3847,6 +3847,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||||
global_scope=f_globals,
|
global_scope=f_globals,
|
||||||
f_code=f_code,
|
f_code=f_code,
|
||||||
torch_function_mode_stack=torch_function_mode_stack,
|
torch_function_mode_stack=torch_function_mode_stack,
|
||||||
|
one_graph=one_graph,
|
||||||
package=package,
|
package=package,
|
||||||
),
|
),
|
||||||
instructions=instructions,
|
instructions=instructions,
|
||||||
|
|
|
||||||
|
|
@ -999,7 +999,11 @@ class TensorVariable(VariableTracker):
|
||||||
return DataPtrVariable(self)
|
return DataPtrVariable(self)
|
||||||
|
|
||||||
def method_item(self, *args, **kwargs):
|
def method_item(self, *args, **kwargs):
|
||||||
if not config.capture_scalar_outputs:
|
from ..symbolic_convert import InstructionTranslator
|
||||||
|
|
||||||
|
tx = InstructionTranslator.current_tx()
|
||||||
|
# We enable capture_scalar_outputs when full_graph=True by default.
|
||||||
|
if not tx.one_graph and not config.capture_scalar_outputs:
|
||||||
self._warn_capture_scalar_outputs()
|
self._warn_capture_scalar_outputs()
|
||||||
unimplemented_v2(
|
unimplemented_v2(
|
||||||
gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False",
|
gb_type="Unsupported Tensor.item() call with capture_scalar_outputs=False",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user