This PR makes the `enable_grad`/`no_grad`/`autocast` context managers trace correctly under `pre_dispatch` tracing. It includes:
- I added a torch function mode that runs during make_fx pre_dispatch tracing, `ProxyTorchFunctionMode`. It directly intercepts the torch ops that run during the above context managers, and adds them to the current graph instead of executing them
- `enable_grad` and `no_grad` currently desugar into `torch._C.set_grad_enabled(bool)`, but this API isn't currently overridable by torch function, so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a nice equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new APIs: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. The context manager implementation calls several different state-changing functions, some of which depend on the backend, so I figured it would be cleaner to add a new API (one that should probably only be used by tracing). I'm open to feedback on this
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs, so it is useful to be able to exercise it on its own. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).
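To illustrate the interception mechanism the first bullet relies on, here is a minimal sketch of a torch function mode. `RecordingMode` is a hypothetical name and is not the `ProxyTorchFunctionMode` added in this PR: it only logs each intercepted call and then executes it, whereas the real mode adds the call to the graph instead. It assumes a PyTorch build that exposes `torch.overrides.TorchFunctionMode`.

```python
import torch
from torch.overrides import TorchFunctionMode


class RecordingMode(TorchFunctionMode):
    """Hypothetical illustration: log every call routed through torch function."""

    def __init__(self):
        super().__init__()
        self.names = []

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # Record the callable's name, then run it normally. A tracing mode
        # would emit a graph node here rather than executing eagerly.
        self.names.append(getattr(func, "__name__", str(func)))
        return func(*args, **kwargs)


x = torch.randn(2, 3)
with RecordingMode() as mode:
    y = torch.sin(x + 1)

print(mode.names)
```

Per the description above, grad-mode and autocast state changes are meant to reach a mode like this one once they are overridable; the sketch does not depend on that and only checks an ordinary op.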
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
Fixes a Meta-internal use case.
Repro:
```
import torch
import torch._dynamo

def fn(x):
    with torch.cuda.amp.autocast(False):
        x = torch.sin(x + 1)
    return x

x = torch.randn([2, 3])
ref = fn(x)
print(ref)
opt_fn = torch._dynamo.optimize(backend="inductor")(fn)
print(opt_fn(x))
```
Error:
```
Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 425, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 410, in transform
    tracer.run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 2010, in run
    super().run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 703, in run
    and self.step()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 663, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 385, in wrapper
    return inner_fn(self, inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1095, in CALL_FUNCTION
    self.call_function(fn, args, {})
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 554, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/torch.py", line 381, in call_function
    return AutocastModeVariable.create(target_values=args, kwargs=kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/ctx_manager.py", line 198, in create
    bound_args = inspect.signature(torch.autocast).bind(*target_values, **kwargs)
  File "/scratch/ybliang/work/env/lib/python3.9/inspect.py", line 3045, in bind
    return self._bind(args, kwargs)
  File "/scratch/ybliang/work/env/lib/python3.9/inspect.py", line 2984, in _bind
    raise TypeError(
TypeError: multiple values for argument 'device_type'

from user code:
  File "/scratch/ybliang/work/repos/debug/debug6.py", line 10, in fn
    with torch.cuda.amp.autocast(False):
```
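The `TypeError` at the bottom of the traceback can be reproduced without PyTorch. The sketch below is a guess at the mechanism, not the actual Dynamo code path: `autocast_stub` is a hypothetical stand-in whose parameter order is assumed to match `torch.autocast`, and it is assumed that the user's positional `False` is forwarded unchanged while `device_type` is also supplied by keyword.

```python
import inspect


def autocast_stub(device_type, dtype=None, enabled=True, cache_enabled=None):
    """Hypothetical stand-in; parameter order is assumed to match torch.autocast."""


sig = inspect.signature(autocast_stub)

# Assumed shape of the failing call: the positional `False` lands in the first
# slot (device_type), which is then supplied a second time by keyword.
try:
    sig.bind(False, device_type="cuda")
    error = None
except TypeError as exc:
    error = str(exc)

print(error)

# Binding the flag by name avoids the clash with device_type.
bound = sig.bind(device_type="cuda", enabled=False)
print(dict(bound.arguments))
```

Under these assumptions, mapping the wrapper's positional arguments onto named parameters before binding is one way to avoid the collision; whether the fix in this PR takes that exact form is not stated here.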
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101052
Approved by: https://github.com/anijain2305
This is a draft of a generic context manager. I expect there are scenarios I didn't anticipate, so I am posting it for discussion and to check whether this is the right direction.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98725
Approved by: https://github.com/jansel