mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[devx] Fix invalid symbol definition emitted in fx_graph_runnable.py (#166529)
Summary: When emitting symbolic variable definition in fx_graph_runnable.py, we need to check if a SymNode is actually an expression, so that we won't generate something like "s27*s53**2 = 36". Pull Request resolved: https://github.com/pytorch/pytorch/pull/166529 Approved by: https://github.com/mlazos ghstack dependencies: #166432
This commit is contained in:
parent
08b0a8f11a
commit
2df2c316e2
|
|
@ -382,6 +382,21 @@ class FxGraphRunnableTest(TestCase):
|
||||||
torch.compile(f)(x)
|
torch.compile(f)(x)
|
||||||
self._exec_and_verify_payload()
|
self._exec_and_verify_payload()
|
||||||
|
|
||||||
|
@torch._dynamo.config.patch(assume_static_by_default=False)
|
||||||
|
def test_dynamic_expression(self):
|
||||||
|
"""
|
||||||
|
Test not emitting something like "s27*s53**2 = 36"
|
||||||
|
"""
|
||||||
|
|
||||||
|
def f(x):
|
||||||
|
return torch.ops.aten._adaptive_avg_pool2d(
|
||||||
|
x, (6, 6)
|
||||||
|
), torch.ops.aten._adaptive_avg_pool2d(x + 1, (2, 5))
|
||||||
|
|
||||||
|
x = torch.randn(2, 4, 16, 16)
|
||||||
|
torch.compile(f)(x)
|
||||||
|
self._exec_and_verify_payload()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from torch._dynamo.test_case import run_tests
|
from torch._dynamo.test_case import run_tests
|
||||||
|
|
|
||||||
|
|
@ -35,6 +35,8 @@ from tempfile import TemporaryFile
|
||||||
from typing import Any, IO, Optional, TYPE_CHECKING, Union
|
from typing import Any, IO, Optional, TYPE_CHECKING, Union
|
||||||
from typing_extensions import Unpack
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
|
import sympy
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from triton.runtime.autotuner import Autotuner, Heuristics
|
from triton.runtime.autotuner import Autotuner, Heuristics
|
||||||
|
|
@ -441,22 +443,33 @@ isolate_fails_code_str = None
|
||||||
|
|
||||||
# Extract symbolic variables from the same arguments
|
# Extract symbolic variables from the same arguments
|
||||||
# pyrefly: ignore [unbound-name]
|
# pyrefly: ignore [unbound-name]
|
||||||
if isinstance(arg, torch.SymInt):
|
if (
|
||||||
sym_name = str(arg.node)
|
isinstance(arg, torch.SymInt)
|
||||||
if arg.node.hint is not None:
|
# By checking sympy.Symbol, we are excluding any symbolic expressions.
|
||||||
used_syms[sym_name] = arg.node.hint
|
# TODO: we may need to solve expressions to extract symbol definitions.
|
||||||
|
and isinstance(arg.node.expr, sympy.Symbol)
|
||||||
|
and arg.node.hint is not None
|
||||||
|
):
|
||||||
|
used_syms[str(arg.node)] = arg.node.hint
|
||||||
# pyrefly: ignore [unbound-name]
|
# pyrefly: ignore [unbound-name]
|
||||||
elif isinstance(arg, torch.Tensor):
|
elif isinstance(arg, torch.Tensor):
|
||||||
# Extract symbolic variables from tensor shapes and strides
|
# Extract symbolic variables from tensor shapes and strides
|
||||||
for dim in arg.shape:
|
for dim in arg.shape:
|
||||||
# pyrefly: ignore [unbound-name]
|
# pyrefly: ignore [unbound-name]
|
||||||
if isinstance(dim, torch.SymInt) and dim.node.hint is not None:
|
if (
|
||||||
|
isinstance(dim, torch.SymInt)
|
||||||
|
and isinstance(dim.node.expr, sympy.Symbol)
|
||||||
|
and dim.node.hint is not None
|
||||||
|
):
|
||||||
used_syms[str(dim.node)] = dim.node.hint
|
used_syms[str(dim.node)] = dim.node.hint
|
||||||
for stride in arg.stride():
|
for stride in arg.stride():
|
||||||
# pyrefly: ignore [unbound-name]
|
# pyrefly: ignore [unbound-name]
|
||||||
if isinstance(stride, torch.SymInt) and stride.node.hint is not None:
|
if (
|
||||||
|
isinstance(stride, torch.SymInt)
|
||||||
|
and isinstance(stride.node.expr, sympy.Symbol)
|
||||||
|
and stride.node.hint is not None
|
||||||
|
):
|
||||||
used_syms[str(stride.node)] = stride.node.hint
|
used_syms[str(stride.node)] = stride.node.hint
|
||||||
|
|
||||||
# Add symbolic variable definitions to the top of the generated code
|
# Add symbolic variable definitions to the top of the generated code
|
||||||
if used_syms:
|
if used_syms:
|
||||||
hint_lines = "\n".join(
|
hint_lines = "\n".join(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user