pytorch/test/inductor/test_utils.py
thenumberouscode c106ee8515 [FakeTensor] Supplement the relevant logic for converting conv1d to conv2d in meta_conv (#160408)
## Fixes https://github.com/pytorch/pytorch/issues/159462; also fixes #163569 and #163604

## Summary
The issue is caused by the wrong stride of the conv1d result computed by meta_conv:
4d5b3f2d5a/torch/_meta_registrations.py (L2453-L2471)

This wrong stride is then used by inductor to codegen the size/stride asserts:
4d5b3f2d5a/torch/_inductor/ir.py (L6152-L6163)
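A minimal repro sketch (hypothetical shapes, based on the linked issue; the assert only fires on backends that take the conv1d-to-conv2d path described below):

```python
import torch

# Hypothetical repro: compile a plain Conv1d. When the eager backend rewrites
# conv1d as conv2d (possibly switching to channels last), the eager output's
# strides can differ from what the old meta_conv computed, so the size/stride
# assert that inductor generates from the meta result fails at runtime.
conv = torch.nn.Conv1d(16, 33, kernel_size=3, stride=2)
x = torch.randn(20, 16, 50)

out = torch.compile(conv)(x)  # previously could trip the generated stride assert
print(out.shape, out.stride())
```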

## Reason
Why is the stride computed by meta_conv wrong? Because the corresponding backend converts conv1d to conv2d, changing the input tensor's size and memory_format (channels last), while meta_conv did not perform this transformation, so the two disagreed:
4d5b3f2d5a/aten/src/ATen/native/Convolution.cpp (L1502-L1510)
The fix adds the corresponding conversion logic to meta_conv; a sketch of the transformation follows.
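A rough sketch (not the actual PyTorch source) of the 1-D-to-2-D expansion that Convolution.cpp performs and that meta_conv now mirrors before computing output sizes and strides:

```python
import torch

# Hypothetical helper illustrating the transformation; the name and signature
# are made up for this sketch. A conv1d input (N, C, L) is viewed as
# (N, C, 1, L), and the single-element conv params are widened with identity
# values for the extra spatial dim. The real meta_conv must additionally pick
# the memory format (e.g. channels last) the backend would use, since that is
# what determines the output strides.
def conv1d_params_as_conv2d(input, weight, stride, padding, dilation):
    input_2d = input.unsqueeze(2)    # (N, C, L) -> (N, C, 1, L)
    weight_2d = weight.unsqueeze(2)  # (O, I, K) -> (O, I, 1, K)
    return input_2d, weight_2d, (1, *stride), (0, *padding), (1, *dilation)
```

After computing the conv2d output's sizes and strides under the right memory format, meta_conv squeezes the extra dimension back out, so the 1-D meta result matches what the backend actually produces.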

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160408
Approved by: https://github.com/eellison, https://github.com/jansel, https://github.com/mlazos
2025-09-26 15:45:02 +00:00

# Owner(s): ["module: inductor"]
import unittest

from sympy import Symbol, sympify

import torch
from torch._inductor.fx_utils import count_flops_fx, countable_fx
from torch._inductor.utils import get_device_tflops, sympy_str, sympy_subs
from torch._inductor.virtualized import V
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import run_tests, TestCase

class TestUtils(TestCase):
    def test_zip_schema(self):
        def foo(x: torch.Tensor) -> None:
            pass

        result = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"})
        schema = result._opoverload._schema
        g = torch.tensor([11, 2])

        # "x" passed as a keyword argument binds to the schema argument "x"
        found = False
        for arg, val in torch._library.utils.zip_schema(schema, [], {"x": g}):
            if arg.name == "x":
                found = True
        self.assertTrue(found)

        # "x" passed positionally binds to the same schema argument
        found = False
        for arg, val in torch._library.utils.zip_schema(schema, [g], {}):
            if arg.name == "x":
                found = True
        self.assertTrue(found)

    def testSympySubs(self):
        # integer and nonnegative attributes are preserved.
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # invalid replacement: properties do not match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # valid replacement since properties match.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # invalid replacement.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # the replaced expression can't be a string
        self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"})

        # the replaced expression can be a compound expression
        expr = Symbol("x")
        expr = abs(expr)
        self.assertEqual(expr.is_integer, None)
        self.assertEqual(expr.is_nonnegative, None)
        # replace abs(x) with y and propagate abs(x)'s sympy properties.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, None)
        self.assertEqual(result.is_nonnegative, None)

    def test_sympy_str(self):
        self.assertEqual(sympy_str(sympify("a+b+c")), "a + b + c")
        self.assertEqual(sympy_str(sympify("a*b+c")), "c + a * b")
        self.assertEqual(sympy_str(sympify("a+b*(c+d)")), "a + b * (c + d)")
        self.assertEqual(sympy_str(sympify("(a+b)*(c+d)")), "(a + b) * (c + d)")
        self.assertEqual(sympy_str(sympify("-a")), "-a")
        self.assertEqual(sympy_str(sympify("a-b")), "a - b")
        self.assertEqual(sympy_str(sympify("a+-b")), "a - b")

    def test_flops_fx(self):
        def create_fx_node(
            aten: torch._ops.OpOverloadPacket, args, kwargs
        ) -> tuple[torch.fx.Node, torch.fx.Node]:
            node1 = torch.fx.Node(
                graph=torch.fx.Graph(),
                name="",
                op="call_function",
                target=aten,
                args=args,
                kwargs=kwargs,
            )
            name: str = aten.overloads()[0]
            op_overload: torch._ops.OpOverload = getattr(aten, name)
            node2 = torch.fx.Node(
                graph=torch.fx.Graph(),
                name="",
                op="call_function",
                target=op_overload,
                args=args,
                kwargs=kwargs,
            )
            return node1, node2

        with V.set_fake_mode(
            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        ):
            trues = [
                (
                    torch.ops.aten.addmm,
                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.bmm,
                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
                    {},
                ),
                (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
                (
                    torch.ops.aten.convolution,
                    (
                        torch.Tensor(2, 2, 3),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                    ),
                    {},
                ),
                (
                    torch.ops.aten._convolution,
                    (
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                        False,
                        True,
                        False,
                    ),
                    {},
                ),
            ]
            # we don't support pointwise ops
            falses = [
                (
                    torch.ops.aten.add,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
                (
                    torch.ops.aten.mul,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
            ]
            for t, args, kwargs in trues:
                fx_node_1, fx_node_2 = create_fx_node(t, args, kwargs)
                self.assertTrue(
                    countable_fx(fx_node_1), f"Expected true {t}: {fx_node_1}"
                )
                self.assertTrue(
                    countable_fx(fx_node_2), f"Expected true {t}: {fx_node_2}"
                )
                self.assertNotEqual(count_flops_fx(fx_node_1), None)
                self.assertNotEqual(count_flops_fx(fx_node_2), None)
            for f, args, kwargs in falses:
                fx_node_1, fx_node_2 = create_fx_node(f, args, kwargs)
                self.assertFalse(
                    countable_fx(fx_node_1), f"Expected false {f}: {fx_node_1}"
                )
                self.assertFalse(
                    countable_fx(fx_node_2), f"Expected false {f}: {fx_node_2}"
                )

    @unittest.skipIf(not torch.cuda.is_available(), "skip if no device")
    @dtypes(torch.float16, torch.bfloat16, torch.float32)
    def test_get_device_tflops(self, dtype):
        ret = get_device_tflops(dtype)
        self.assertTrue(type(ret) == float)


instantiate_device_type_tests(TestUtils, globals())

if __name__ == "__main__":
    run_tests()