pytorch/test/dynamo/test_pre_dispatch.py
Brian Hirsh 875f60399e pre_dispatch tracing: support autocast and no_grad/enable_grad ctx managers, add a pre_dispatch_eager dynamo backend (#103024)
This PR adds support for the `enable_grad`/`no_grad`/`autocast` context managers being properly traced in `pre_dispatch` tracing. This PR includes:
- I added a torch function mode, `ProxyTorchFunctionMode`, that runs during make_fx pre_dispatch tracing. It directly intercepts the torch ops that these context managers invoke and adds them to the current graph instead of executing them (a minimal sketch of what this captures follows this list)
- `enable_grad` and `no_grad` currently desugar into `torch._C._set_grad_enabled(bool)`, but that API isn't currently overridable by `__torch_function__`, so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a single equivalent like `set_autocast_enabled(state)`, so I ended up adding two new APIs: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend, so it seemed cleaner to just add a new API (which should probably only be used by tracing). Open to feedback here.
- I added a new dynamo backend, `torch.compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs, so it's useful to be able to exercise it in isolation. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`, reproduced below).
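
To illustrate the first two bullets, here is a minimal sketch (not part of the test file below) of what pre_dispatch tracing captures for a `no_grad` region. It assumes `make_fx(..., pre_dispatch=True)` from `torch.fx.experimental.proxy_tensor` is available at this commit; the exact node names in the printed graph may differ.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx


def f(a):
    b = a.sin()
    with torch.no_grad():
        c = b.cos()
    return b * c.sin()


# Trace at the pre_dispatch level; the ProxyTorchFunctionMode added in this PR
# intercepts the grad-mode switches issued by torch.no_grad().
gm = make_fx(f, pre_dispatch=True)(torch.randn(4))

# The printed graph is expected to contain torch._C._set_grad_enabled(False)
# and torch._C._set_grad_enabled(True) nodes around the cos() call, rather
# than those calls being executed at trace time.
print(gm.graph)
```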

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
2023-06-29 14:17:42 +00:00

77 lines
2.1 KiB
Python

# Owner(s): ["module: dynamo"]
import torch
import torch._dynamo
import torch._dynamo.test_case


class PreDispatchTests(torch._dynamo.test_case.TestCase):
    def test_no_grad_simple(self):
        def f(a):
            b = a.sin()
            with torch.no_grad():
                c = b.cos()
            return b * c.sin()

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)

    def test_enable_grad_and_no_grad(self):
        def f(a):
            b = a * 2
            with torch.no_grad():
                c = b * 3
                with torch.enable_grad():
                    d = c * 4
                e = d * 5
            return b + c + d + e

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)

    def test_autocast_simple(self):
        def f(a):
            b = a * 2
            with torch.amp.autocast(device_type="cpu"):
                c = torch.matmul(b, b)
            return b + c

        f_compiled = torch.compile(f, backend="pre_dispatch_eager")

        a_ref = torch.randn(4, device="cpu", requires_grad=True)
        a_test = a_ref.clone().detach().requires_grad_(True)

        out_ref = f(a_ref)
        out_test = f_compiled(a_test)
        self.assertEqual(out_ref, out_test)

        out_ref.sum().backward()
        out_test.sum().backward()
        self.assertEqual(a_ref.grad, a_test.grad)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()