[devx] Fix invalid symbol definition emitted in fx_graph_runnable.py (#166529)

Summary: When emitting symbolic variable definition in fx_graph_runnable.py, we need to check if a SymNode is actually an expression, so that we won't generate something like "s27*s53**2 = 36".

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166529
Approved by: https://github.com/mlazos
ghstack dependencies: #166432
This commit is contained in:
Bin Bao 2025-10-29 08:15:04 -07:00 committed by PyTorch MergeBot
parent 08b0a8f11a
commit 2df2c316e2
2 changed files with 35 additions and 7 deletions

View File

@ -382,6 +382,21 @@ class FxGraphRunnableTest(TestCase):
torch.compile(f)(x)
self._exec_and_verify_payload()
@torch._dynamo.config.patch(assume_static_by_default=False)
def test_dynamic_expression(self):
"""
Test not emitting something like "s27*s53**2 = 36"
"""
def f(x):
return torch.ops.aten._adaptive_avg_pool2d(
x, (6, 6)
), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))
x = torch.randn(2, 4, 16, 16)
torch.compile(f)(x)
self._exec_and_verify_payload()
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -35,6 +35,8 @@ from tempfile import TemporaryFile
from typing import Any, IO, Optional, TYPE_CHECKING, Union
from typing_extensions import Unpack
import sympy
try:
from triton.runtime.autotuner import Autotuner, Heuristics
@ -441,22 +443,33 @@ isolate_fails_code_str = None
# Extract symbolic variables from the same arguments
# pyrefly: ignore [unbound-name]
if isinstance(arg, torch.SymInt):
sym_name = str(arg.node)
if arg.node.hint is not None:
used_syms[sym_name] = arg.node.hint
if (
isinstance(arg, torch.SymInt)
# By checking sympy.Symbol, we are excluding any symbolic expressions.
# TODO: we may need to solve expressions to extract symbol definitions.
and isinstance(arg.node.expr, sympy.Symbol)
and arg.node.hint is not None
):
used_syms[str(arg.node)] = arg.node.hint
# pyrefly: ignore [unbound-name]
elif isinstance(arg, torch.Tensor):
# Extract symbolic variables from tensor shapes and strides
for dim in arg.shape:
# pyrefly: ignore [unbound-name]
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
if (
isinstance(dim, torch.SymInt)
and isinstance(dim.node.expr, sympy.Symbol)
and dim.node.hint is not None
):
used_syms[str(dim.node)] = dim.node.hint
for stride in arg.stride():
# pyrefly: ignore [unbound-name]
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
if (
isinstance(stride, torch.SymInt)
and isinstance(stride.node.expr, sympy.Symbol)
and stride.node.hint is not None
):
used_syms[str(stride.node)] = stride.node.hint
# Add symbolic variable definitions to the top of the generated code
if used_syms:
hint_lines = "\n".join(