`torch.autocast` with `xla` backend has been restricted to `torch.bfloat16`. This shouldn't be the case anymore.
This works with `xla::cast( ..., type=f16)`
```
IR {
%0 = f32[] prim::Constant(), xla_shape=f32[], value=1
%1 = f32[3,2]{1,0} aten::expand(%0), xla_shape=f32[3,2]{1,0}, size=(3, 2), dynamic_dims=(0, 0)
%2 = f16[3,2]{1,0} xla::cast(%1), xla_shape=f16[3,2]{1,0}, type=f16, dtype=Half, stype=Float
%3 = f32[] prim::Constant(), xla_shape=f32[], value=1
%4 = f32[2,3]{1,0} aten::expand(%3), xla_shape=f32[2,3]{1,0}, size=(2, 3), dynamic_dims=(0, 0)
%5 = f16[2,3]{1,0} xla::cast(%4), xla_shape=f16[2,3]{1,0}, type=f16, dtype=Half, stype=Float
%6 = f16[2,2]{1,0} aten::mm(%5, %2), xla_shape=f16[2,2]{1,0}, ROOT=0
}
```
This will allow PyTorch/XLA to extend its autocast implementation to use `xla` backend for `float16` type as well.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109554
Approved by: https://github.com/JackCaoG, https://github.com/bdhirsh
**Summary**
Fix the https://github.com/pytorch/pytorch/issues/100565 by allowing float32 data type when Autocast CPU is disabled. Current behavior is:
- When autocast is disabled and user passes in float data type, it works well.
- When autocast is enabled and user passes in float data type, a warn message throws `UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.` to disable autocast automatically
**TestPlan**
```
python -u -m pytest -s -v test_autocast.py -k test_autocast_disabled_with_fp32_dtype
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107348
Approved by: https://github.com/jgong5, https://github.com/Neilblaze, https://github.com/albanD
This PR adds support for `enable_grad`/`no_grad`/`autocast` context managers getting properly traced in `pre_dispatch` tracing. The stuff in this PR includes:
- I added a torch function mode that runs during make_fx pre_dispatch tracing, `ProxyTorchFunctionMode`. It directly intercepts the torch ops that run during the above context managers, and adds them to the current graph instead of executing them
- `enable_grad` and `no_grad` currently desugar into `torch._C.set_grad_enabled(bool)`, but this API isn't currently overrideable by torch function so I added the ability to interpose there
- the `torch.amp` context managers don't currently have a nice equivalent, like `set_autocast_enabled(state)`, so I ended up adding two new API's: `torch.amp._set_autocast_enabled` and `torch.amp._set_autocast_disabled`. If you look at how the context manager is implemented, it ends up calling several different state-changing functions, some of which depend on the backend - so I figured that it would be cleaner just to add a new API (that should probably only be used by tracing) - but open to feedback
- I added a new dynamo backend, `compile(backend="pre_dispatch_eager")`. When pre_dispatch tracing becomes always-on in inductor, it will be another potential surface for bugs. I also added a test file for it (`test/dynamo/test_pre_dispatch.py`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103024
Approved by: https://github.com/ezyang
As part of this, a new `AutocastIPU` dispatch key has been added.
There's an existing PR, #85043, to make `Autocast` a proper per-backend functionality key, but it ran into issues with layering with other functionality keys and went stale.
This has been tested in the out-of-tree IPU PyTorch backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103890
Approved by: https://github.com/albanD
Fixes #ISSUE_NUMBER
1、optimize the func name of AMP in custom device module,use `torch.foo.set_autocast_enable` instead of `torch.foo.set_autocast_foo_enable`.
2、In AMP with custom device,use `custom_device_mod.set_autocast_enable` instead of `getattr(custom_device_mod, "set_autocast_enable"`, because we have check that `custom_device_mod` hasattr `set_autocast_enable` before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98052
Approved by: https://github.com/bdhirsh
I am trying to use bfloat16 AMP on a range of devices, using the `enabled` argument to actually enable/disable AMP, like this:
```python
with torch.cuda.amp.autocast(enabled=use_amp, dtype=torch.bfloat16):
```
However, this raises a RuntimeError even if enabled=False.
```
File "/venv/lib/python3.8/site-packages/torch/amp/autocast_mode.py", line 221, in __init__
raise RuntimeError('Current CUDA Device does not support bfloat16. Please switch dtype to float16.')
RuntimeError: Current CUDA Device does not support bfloat16. Please switch dtype to float16.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96097
Approved by: https://github.com/ngimel, https://github.com/kit1980
Fixes #ISSUE_NUMBER
1、add amp support for custom backend
2、optimize the file `backend_registration.py`, and rename it with `custom_backend_registration.py`. And then we would register other funcs for custom backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh
Fixes #ISSUE_NUMBER
1、add amp support for custom backend
2、optimize the file `backend_registration.py`, and rename it with `custom_backend_registration.py`. And then we would register other funcs for custom backend.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96188
Approved by: https://github.com/bdhirsh