mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[devx] Fix invalid symbol definition emitted in fx_graph_runnable.py (#166529)
Summary: When emitting symbolic variable definition in fx_graph_runnable.py, we need to check if a SymNode is actually an expression, so that we won't generate something like "s27*s53**2 = 36". Pull Request resolved: https://github.com/pytorch/pytorch/pull/166529 Approved by: https://github.com/mlazos ghstack dependencies: #166432
This commit is contained in:
parent
08b0a8f11a
commit
2df2c316e2
|
|
@ -382,6 +382,21 @@ class FxGraphRunnableTest(TestCase):
|
|||
torch.compile(f)(x)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
@torch._dynamo.config.patch(assume_static_by_default=False)
|
||||
def test_dynamic_expression(self):
|
||||
"""
|
||||
Test not emitting something like "s27*s53**2 = 36"
|
||||
"""
|
||||
|
||||
def f(x):
|
||||
return torch.ops.aten._adaptive_avg_pool2d(
|
||||
x, (6, 6)
|
||||
), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))
|
||||
|
||||
x = torch.randn(2, 4, 16, 16)
|
||||
torch.compile(f)(x)
|
||||
self._exec_and_verify_payload()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -35,6 +35,8 @@ from tempfile import TemporaryFile
|
|||
from typing import Any, IO, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import Unpack
|
||||
|
||||
import sympy
|
||||
|
||||
|
||||
try:
|
||||
from triton.runtime.autotuner import Autotuner, Heuristics
|
||||
|
|
@ -441,22 +443,33 @@ isolate_fails_code_str = None
|
|||
|
||||
# Extract symbolic variables from the same arguments
|
||||
# pyrefly: ignore [unbound-name]
|
||||
if isinstance(arg, torch.SymInt):
|
||||
sym_name = str(arg.node)
|
||||
if arg.node.hint is not None:
|
||||
used_syms[sym_name] = arg.node.hint
|
||||
if (
|
||||
isinstance(arg, torch.SymInt)
|
||||
# By checking sympy.Symbol, we are excluding any symbolic expressions.
|
||||
# TODO: we may need to solve expressions to extract symbol definitions.
|
||||
and isinstance(arg.node.expr, sympy.Symbol)
|
||||
and arg.node.hint is not None
|
||||
):
|
||||
used_syms[str(arg.node)] = arg.node.hint
|
||||
# pyrefly: ignore [unbound-name]
|
||||
elif isinstance(arg, torch.Tensor):
|
||||
# Extract symbolic variables from tensor shapes and strides
|
||||
for dim in arg.shape:
|
||||
# pyrefly: ignore [unbound-name]
|
||||
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
|
||||
if (
|
||||
isinstance(dim, torch.SymInt)
|
||||
and isinstance(dim.node.expr, sympy.Symbol)
|
||||
and dim.node.hint is not None
|
||||
):
|
||||
used_syms[str(dim.node)] = dim.node.hint
|
||||
for stride in arg.stride():
|
||||
# pyrefly: ignore [unbound-name]
|
||||
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
|
||||
if (
|
||||
isinstance(stride, torch.SymInt)
|
||||
and isinstance(stride.node.expr, sympy.Symbol)
|
||||
and stride.node.hint is not None
|
||||
):
|
||||
used_syms[str(stride.node)] = stride.node.hint
|
||||
|
||||
# Add symbolic variable definitions to the top of the generated code
|
||||
if used_syms:
|
||||
hint_lines = "\n".join(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user