pytorch/test/inductor/test_utils.py
Laith Sakka 7c1d93944e Proper handling of arguments passed by in kwargs inside zip_schema (#137311)
if the function is

```func(a, b, c) ```
and is called as
```func(a=1, b=.., c=..)```
Before this change we did not iterate over a, b, and c, since they appear in kwargs; this diff fixes that issue.

This function is used in _inductor/ir.py to iterate over custom op arguments. When a custom pass
rewrites a call and passes the arguments as kwargs, those arguments were not being processed.
```
        for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
            handle_aliasing_and_mutation(info, arg)
```
Fix https://github.com/pytorch/pytorch/issues/137057

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137311
Approved by: https://github.com/zou3519
2024-10-04 21:50:31 +00:00

78 lines
2.6 KiB
Python

# Owner(s): ["module: inductor"]
from sympy import Symbol
import torch
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import sympy_subs
class TestUtils(TestCase):
    """Checks for torch._library.utils.zip_schema and torch._inductor.utils.sympy_subs."""

    def test_zip_schema(self):
        # zip_schema must yield the schema argument "x" both when the value is
        # supplied through kwargs and when it is supplied positionally.
        def foo(x: torch.Tensor) -> None:
            pass

        op = torch.library.custom_op("mylib::foo", foo, mutates_args={"x"})
        schema = op._opoverload._schema
        tensor = torch.tensor([11, 2])

        # Keyword call shape first, then the positional one.
        for call_args, call_kwargs in (([], {"x": tensor}), ([tensor], {})):
            seen_names = [
                info.name
                for info, _ in torch._library.utils.zip_schema(
                    schema, call_args, call_kwargs
                )
            ]
            self.assertTrue("x" in seen_names)

    def testSympySubs(self):
        # A symbol without assumptions, replaced by a string name: the result
        # is named "y" and reports no integer / nonnegative assumptions.
        plain = Symbol("x")
        swapped = sympy_subs(plain, {plain: "y"})
        self.assertEqual(swapped.name, "y")
        self.assertEqual(swapped.is_integer, None)
        self.assertEqual(swapped.is_nonnegative, None)

        # Integer and nonnegative attributes are preserved on the result.
        constrained = Symbol("x", integer=True, nonnegative=False)
        swapped = sympy_subs(constrained, {constrained: "y"})
        self.assertEqual(swapped.name, "y")
        self.assertEqual(swapped.is_integer, True)
        self.assertEqual(swapped.is_nonnegative, False)

        # Invalid replacement: the key lacks the integer assumption, so the
        # expression is expected to stay "x".
        int_sym = Symbol("x", integer=True)
        swapped = sympy_subs(int_sym, {Symbol("x"): Symbol("y")})
        self.assertEqual(swapped.name, "x")

        # Valid replacement, since the key's properties match the expression's.
        int_sym = Symbol("x", integer=True)
        swapped = sympy_subs(int_sym, {Symbol("x", integer=True): Symbol("y")})
        self.assertEqual(swapped.name, "y")

        # Invalid replacement: integer=None versus integer=False on the key.
        loose_sym = Symbol("x", integer=None)
        swapped = sympy_subs(loose_sym, {Symbol("x", integer=False): Symbol("y")})
        self.assertEqual(swapped.name, "x")

        # The replaced key can't be a string.
        with self.assertRaises(AssertionError):
            sympy_subs(loose_sym, {"x": "y"})

        # The replaced key can be an expression.
        magnitude = abs(Symbol("x"))
        self.assertEqual(magnitude.is_integer, None)
        self.assertEqual(magnitude.is_nonnegative, None)

        # Replace abs(x) with y; abs(x)'s sympy properties are propagated.
        swapped = sympy_subs(magnitude, {magnitude: Symbol("y")})
        self.assertEqual(swapped.name, "y")
        self.assertEqual(swapped.is_integer, None)
        self.assertEqual(swapped.is_nonnegative, None)
# Only invoke the test runner imported from torch._inductor.test_case when
# this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    run_tests()