mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit 347ace4c7a.
Reverted https://github.com/pytorch/pytorch/pull/149697 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it seems to fail on ROCm ([comment](https://github.com/pytorch/pytorch/pull/149697#issuecomment-3020006655))
194 lines
6.8 KiB
Python
194 lines
6.8 KiB
Python
# Owner(s): ["module: inductor"]
|
|
|
|
from sympy import Symbol, sympify
|
|
|
|
import torch
|
|
from torch._inductor.fx_utils import count_flops_fx, countable_fx
|
|
from torch._inductor.test_case import run_tests, TestCase
|
|
from torch._inductor.utils import sympy_str, sympy_subs
|
|
from torch._inductor.virtualized import V
|
|
|
|
|
|
class TestUtils(TestCase):
    """Tests for torch._library and torch._inductor helper utilities."""

    def test_zip_schema(self):
        """zip_schema yields the schema argument ``x`` whether it is supplied
        by keyword or positionally."""

        def foo(x: torch.Tensor) -> None:
            pass

        result = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"})
        schema = result._opoverload._schema
        g = torch.tensor([11, 2])

        def zipped_arg_names(args, kwargs):
            # Drain the zip_schema generator completely (so any error raised
            # while zipping still surfaces) and collect the names of the
            # schema arguments it yields.
            return [
                arg.name
                for arg, _ in torch._library.utils.zip_schema(schema, args, kwargs)
            ]

        # ``x`` supplied as a keyword argument.
        self.assertIn("x", zipped_arg_names([], {"x": g}))

        # ``x`` supplied positionally.
        self.assertIn("x", zipped_arg_names([g], {}))

    def testSympySubs(self):
        """sympy_subs only replaces a key that matches the expression exactly
        (assumptions included), gives string replacements the assumptions of
        the replaced symbol, and rejects plain-string keys."""
        # A string replacement inherits the integer/nonnegative assumptions of
        # the symbol it replaces (here none are set).
        expr = Symbol("x")
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

        # Same, but with explicit assumptions that must carry over to "y".
        expr = Symbol("x", integer=True, nonnegative=False)
        result = sympy_subs(expr, {expr: "y"})
        self.assertEqual(result.name, "y")
        self.assertEqual(result.is_integer, True)
        self.assertEqual(result.is_nonnegative, False)

        # No replacement: the key has no assumptions, expr is integer=True.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x"): Symbol("y")})
        self.assertEqual(result.name, "x")

        # Replacement happens because the key's assumptions match expr's.
        expr = Symbol("x", integer=True)
        result = sympy_subs(expr, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(result.name, "y")

        # No replacement: integer=None does not match integer=False.
        expr = Symbol("x", integer=None)
        result = sympy_subs(expr, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(result.name, "x")

        # The key being replaced can't be a plain string.
        self.assertRaises(AssertionError, sympy_subs, expr, {"x": "y"})

        # The key being replaced can be a compound expression such as abs(x).
        expr = abs(Symbol("x"))
        self.assertIsNone(expr.is_integer)
        self.assertIsNone(expr.is_nonnegative)
        # Replace abs(x) with y; the result carries the same (unset)
        # integer/nonnegative assumptions that abs(x) has.
        result = sympy_subs(expr, {expr: Symbol("y")})
        self.assertEqual(result.name, "y")
        self.assertIsNone(result.is_integer)
        self.assertIsNone(result.is_nonnegative)

    def test_sympy_str(self):
        """sympy_str prints sums and products with spaced operators and folds
        an added negation into subtraction."""
        cases = [
            ("a+b+c", "a + b + c"),
            ("a*b+c", "c + a * b"),
            ("a+b*(c+d)", "a + b * (c + d)"),
            ("(a+b)*(c+d)", "(a + b) * (c + d)"),
            ("-a", "-a"),
            ("a-b", "a - b"),
            ("a+-b", "a - b"),
        ]
        for text, expected in cases:
            self.assertEqual(sympy_str(sympify(text)), expected)

    def test_flops_fx(self):
        """countable_fx / count_flops_fx accept matmul and convolution ops and
        reject pointwise ops, for both OpOverloadPacket and OpOverload
        targets."""

        def create_fx_node(
            aten: torch._ops.OpOverloadPacket, args, kwargs
        ) -> tuple[torch.fx.Node, torch.fx.Node]:
            """Build two standalone call_function nodes with the same args:
            one targeting the overload packet ``aten`` itself and one
            targeting the first overload listed by ``aten.overloads()``."""

            def make_node(target) -> torch.fx.Node:
                # Each node lives in its own throwaway graph.
                return torch.fx.Node(
                    graph=torch.fx.Graph(),
                    name="",
                    op="call_function",
                    target=target,
                    args=args,
                    kwargs=kwargs,
                )

            packet_node = make_node(aten)
            overload_name: str = aten.overloads()[0]
            op_overload: torch._ops.OpOverload = getattr(aten, overload_name)
            return packet_node, make_node(op_overload)

        with V.set_fake_mode(
            torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)
        ):
            # Ops expected to be countable: matmul and convolution variants.
            trues = [
                (
                    torch.ops.aten.addmm,
                    (torch.Tensor(4, 4), torch.Tensor(4, 5), torch.Tensor(5, 4)),
                    {},
                ),
                (
                    torch.ops.aten.bmm,
                    (torch.Tensor(10, 4, 5), torch.Tensor(10, 5, 4)),
                    {},
                ),
                (torch.ops.aten.mm, (torch.Tensor(2, 3), torch.Tensor(3, 2)), {}),
                (
                    torch.ops.aten.convolution,
                    (
                        torch.Tensor(2, 3, 3),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1, 1),
                        (0, 0),
                        (1, 1),
                        True,
                        (0, 0),
                        1,
                    ),
                    {},
                ),
                (
                    torch.ops.aten._convolution,
                    (
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2, 2, 2),
                        torch.Tensor(2),
                        (1,),
                        (0,),
                        (1,),
                        True,
                        (0,),
                        1,
                        False,
                        True,
                        False,
                    ),
                    {},
                ),
            ]
            # Pointwise ops are not supported by the flop counter.
            falses = [
                (
                    torch.ops.aten.add,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
                (
                    torch.ops.aten.mul,
                    (torch.Tensor(1, 2, 3), torch.Tensor(1, 2, 3)),
                    {},
                ),
            ]

            for op, args, kwargs in trues:
                nodes = create_fx_node(op, args, kwargs)
                for node in nodes:
                    self.assertTrue(countable_fx(node), f"Expected true {op}: {node}")
                for node in nodes:
                    self.assertIsNotNone(count_flops_fx(node))

            for op, args, kwargs in falses:
                for node in create_fx_node(op, args, kwargs):
                    self.assertFalse(
                        countable_fx(node), f"Expected false {op}: {node}"
                    )
|
|
|
|
|
|
if __name__ == "__main__":
    # Run the suite only when this module is executed directly as a script;
    # importing it (e.g. from a test collector) does not trigger the runner.
    run_tests()
|