pytorch/torch/amp
Shangdi Yu a47bb4a393 Fix autocast for non-strict export (#137495)
Summary:

Add testing for autocast and set_grad nodes for export_for_training. In export_for_training, we do not wrap the autocast and set_grad nodes into HOPs, but we should still have the set_grad_enabled/autocast nodes in the graph.
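
A minimal sketch of exercising this behavior (not taken from the PR's test suite; the module `M`, the tensor shapes, and the autocast settings are made-up illustrations):

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
            y = x @ x
        with torch.no_grad():
            return y + 1


# export_for_training should keep plain _enter_autocast/_exit_autocast and
# _set_grad_enabled nodes in the graph rather than HOP-wrapped subgraphs.
ep = torch.export.export_for_training(M(), (torch.randn(4, 4),))
print(ep.graph_module.graph)
```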

Add support for autocast in non-strict export. Previously, the `_enter_autocast` and `_exit_autocast` nodes did not show up in the export graph when `strict=False` was used.
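
For illustration, a hedged example (assuming the post-fix behavior; the module and shapes are made up) of exporting with `strict=False` and inspecting the recorded operations:

```python
import torch


class M(torch.nn.Module):
    def forward(self, x):
        with torch.amp.autocast(device_type="cpu", dtype=torch.bfloat16):
            return x @ x


ep = torch.export.export(M(), (torch.randn(4, 4),), strict=False)
targets = [
    str(n.target)
    for n in ep.graph_module.graph.nodes
    if n.op == "call_function"
]
# After this change the autocast region is visible in the traced graph
# (as _enter_autocast/_exit_autocast nodes or a HOP-wrapped subgraph,
# depending on the export path) rather than silently dropped.
print(targets)
```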

- In autocast's `__enter__` and `__exit__` methods, we dispatch to `PreDispatchTorchFunctionMode.__torch_function__`. If `PreDispatchTorchFunctionMode` is on the function mode stack, the call stack looks like the one below. This is mostly the same call stack as in strict mode, except that strict mode enters [here](https://www.internalfb.com/code/fbsource/[0d4f1135cacdb26c6e01d5dce1ce52a15d61ee48]/xplat/caffe2/torch/_dynamo/variables/ctx_manager.py?lines=806).
```
- torch.amp.autocast.__enter__()'s torch.overrides.handle_torch_function
- torch.fx.experimental.proxy_tensor.TorchFunctionMetadataMode.__torch_function__
- torch.amp._enter_autocast()'s torch.overrides.handle_torch_function
- PreDispatchTorchFunctionMode.__torch_function__
```
- In `PreDispatchTorchFunctionMode.__torch_function__`, we create the autocast nodes.
- To match the strict-mode behavior, we make the input node of the `_exit_autocast` node the corresponding `_enter_autocast` node. This requires us to maintain a stack in `PreDispatchTorchFunctionMode`; see the sketch after this list.
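
A simplified sketch of that stack-pairing idea. This is not the actual `PreDispatchTorchFunctionMode`; the `AutocastNodeRecorder` class and the direct use of `torch.fx.Graph` are illustrative stand-ins for the real proxy-tracing machinery:

```python
import torch
from torch.overrides import TorchFunctionMode


class AutocastNodeRecorder(TorchFunctionMode):
    """Records paired _enter_autocast/_exit_autocast nodes into an FX graph."""

    def __init__(self, graph: torch.fx.Graph):
        super().__init__()
        self.graph = graph
        self.enter_stack = []  # _enter_autocast nodes awaiting their exit

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.amp._enter_autocast:
            # Record the enter node and remember it for the matching exit.
            self.enter_stack.append(self.graph.call_function(func, tuple(args)))
        elif func is torch.amp._exit_autocast:
            # Mirror strict mode: the exit node takes the corresponding
            # enter node as its input.
            enter_node = self.enter_stack.pop()
            self.graph.call_function(func, (enter_node,))
        # The current mode is popped during dispatch, so this does not recurse.
        return func(*args, **kwargs)
```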

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_autocast
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test:test_export -- -r test_export_with_set_grad
```

Differential Revision: D64016023

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137495
Approved by: https://github.com/bdhirsh
| File | Last commit | Date |
| --- | --- | --- |
| __init__.py | generalize custom_fwd&custom_bwd to be device-agnostic (#126531) | 2024-05-25 06:48:16 +00:00 |
| autocast_mode.py | Fix autocast for non-strict export (#137495) | 2024-10-16 17:39:00 +00:00 |
| grad_scaler.py | Add support to GradScaler for respecting an already set grad_scale value (#123429) | 2024-06-27 22:40:54 +00:00 |