pytorch/test/dynamo/test_nops.py
Michael Voznesensky 4cdc96fb4f Add hooks structure for passing around user provided hooks, add a new guard_failure_fn (#90371)
This PR introduces a new callback that can be passed to `torch._dynamo.optimize`: `guard_failure_fn`. Usage is shown in this PR and in the one stacked on top of it; the gist is that it emits the reason strings for failed guards alongside the code. This is useful for tests and debugging, as it allows far finer-grained assertions and control than the compile counter alone.

This is a resubmit of https://github.com/pytorch/pytorch/pull/90129

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90371
Approved by: https://github.com/ezyang
2022-12-07 17:51:53 +00:00

74 lines
1.5 KiB
Python

# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
from torch._dynamo import eval_frame
from torch._dynamo.hooks import Hooks
# Module-level constant read as a global by ``fn1`` and ``fn2``.
c: int = 10
def fn1(a, b):
    """Return ``a + b`` minus the module-level global ``c``."""
    return (a + b) - c
def fn2(a, b):
    """Accumulate ``a + b + c`` twice through a closure, then add ``y``.

    For numeric inputs the result is ``2 * (a + b + c) + 1``; with
    ``a=1, b=2`` and the module global ``c == 10`` that is ``27``.
    """
    x = 0
    y = 1
    def modify():
        # Rebinds fn2's ``x`` via ``nonlocal`` instead of creating a new local.
        nonlocal x
        x += a + b + c
    # NOTE(review): the closure + loop presumably exist to give the NOP
    # insertion pass cell variables and a backward jump to handle; keep this
    # structure rather than simplifying it.
    for _ in range(2):
        modify()
    return x + y
def fn3():
    """Generator that yields ``1`` then ``2`` and is then exhausted.

    Used by ``NopTests.test3`` to check generator iteration while running
    under ``with_debug_nops``.
    """
    yield 1
    yield 2
# Dynamo optimize-context that installs ``debug_insert_nops`` as the frame
# callback. It is applied both as a decorator on test methods and directly to
# a function object. Both ``Hooks`` slots are ``None``, i.e. no user hooks.
with_debug_nops = eval_frame._optimize_catch_errors(
    torch._dynamo.testing.debug_insert_nops,  # frame callback under test
    Hooks(None, None),  # no user-provided hooks
)
class NopTests(torch._dynamo.test_case.TestCase):
    """Check that code run under ``with_debug_nops`` still computes correct results.

    ``with_debug_nops`` installs ``torch._dynamo.testing.debug_insert_nops``
    as the frame callback. Judging by its name, it rewrites bytecode by
    inserting NOPs (TODO confirm). These tests cover plain arithmetic, a
    closure inside a loop, a generator, and a very long expression.
    """

    @with_debug_nops
    def test1(self):
        # Called twice; presumably the second call exercises reuse of the
        # already-transformed code rather than a fresh transform (TODO confirm).
        self.assertEqual(fn1(1, 2), -7)
        self.assertEqual(fn1(1, 2), -7)

    @with_debug_nops
    def test2(self):
        # 2 * (1 + 2 + 10) + 1 == 27; called twice for the same reason as test1.
        self.assertEqual(fn2(1, 2), 27)
        self.assertEqual(fn2(1, 2), 27)

    @with_debug_nops
    def test3(self):
        t = fn3()
        self.assertEqual(next(t), 1)
        self.assertEqual(next(t), 2)
        # The generator must be exhausted after its two yields.
        self.assertRaises(StopIteration, lambda: next(t))

    def test_extended_args(self):
        # 512 alternating ``a``/``b`` operands. An expression this long
        # presumably forces EXTENDED_ARG-prefixed instructions (e.g. for the
        # conditional's jump targets), which the NOP pass must preserve — TODO
        # confirm.
        too_many_adds = "+".join(["a", "b"] * 256)
        source = (
            f"lambda a, b: ({too_many_adds}+a if a.sum() > 0 else {too_many_adds} - b)"
        )
        # ``source`` is built entirely from constants above, so eval is safe here.
        fn = eval(source)
        a = torch.ones(1)
        b = torch.ones(1)
        fn = with_debug_nops(fn)
        # a.sum() > 0, so the first branch runs: 256 a's + 256 b's + a == 513.
        self.assertEqual(fn(a, b).sum(), 513)
if __name__ == "__main__":
    # Run through dynamo's test runner. Its module is already imported at the
    # top of the file as ``torch._dynamo.test_case``.
    torch._dynamo.test_case.run_tests()