pytorch/torch/amp
Isalia20 49f6cce736 [MPS] grad scaler (#150255)
Fixes #142397

The basic implementation is done. Remaining work:
- [x] Support tensors of different dtypes/devices in the TensorList
- [x] Fast path for grouping tensors for the foreach kernel
- [x] Tests
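The first two checklist items are about the unscale step. A TensorList of gradients may mix devices and dtypes, but a foreach kernel expects tensors that share both, so the list has to be bucketed first. The following is a pure-Python sketch of that idea, not the code in `grad_scaler.py`. `Grad` is a hypothetical stand-in for a tensor and `unscale_grouped` is a hypothetical helper. It buckets gradients by `(device, dtype)`, multiplies each bucket by the inverse scale, and records per device whether any non-finite value was seen.

```python
import math
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class Grad:
    """Hypothetical stand-in for a gradient tensor: device tag, dtype tag, values."""
    device: str
    dtype: str
    values: list = field(default_factory=list)


def unscale_grouped(grads, scale):
    """Bucket grads by (device, dtype), unscale each bucket, and return a
    per-device flag that is True if any non-finite value was found."""
    inv_scale = 1.0 / scale
    # device -> dtype -> grads, so every inner list is homogeneous.
    buckets = defaultdict(lambda: defaultdict(list))
    for g in grads:
        buckets[g.device][g.dtype].append(g)

    found_inf = {}
    for device, per_dtype in buckets.items():
        found_inf[device] = False
        for group in per_dtype.values():
            # A real implementation would issue one foreach call per group;
            # this sketch just loops over the group's members.
            for g in group:
                if any(not math.isfinite(v) for v in g.values):
                    found_inf[device] = True
                g.values = [v * inv_scale for v in g.values]
    return found_inf


# Usage: two dtypes on "mps", one overflowed gradient on "cpu".
grads = [
    Grad("mps", "float16", [2.0, 4.0]),
    Grad("mps", "float32", [8.0]),
    Grad("cpu", "float32", [float("inf")]),
]
flags = unscale_grouped(grads, scale=2.0)
```

Grouping up front means each homogeneous bucket can be handed to a single batched call instead of launching one kernel per tensor.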

Regarding tests, I found some GradScaler tests in `test/test_torch.py`, but I couldn't figure out the best way to enable them for the MPS device.

Removing `@onlyNativeDeviceTypes` enables the tests for MPS, but it also enables them for every other device type that is not a native device type. Alternatively, if I use:
`instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`

this enables many tests in that class for MPS that may not have been tested on MPS before. This part needs some clarification.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
2025-04-06 17:06:55 +00:00
__init__.py generalize custom_fwd&custom_bwd to be device-agnostic (#126531) 2024-05-25 06:48:16 +00:00
autocast_mode.py [MAIA] [Autocast] Enable autocast on MAIA device (#148511) 2025-03-18 03:46:22 +00:00
grad_scaler.py [MPS] grad scaler (#150255) 2025-04-06 17:06:55 +00:00
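`grad_scaler.py` implements dynamic loss scaling. The following is a simplified, pure-Python sketch of the scale-update rule, assuming the documented `GradScaler` defaults (`init_scale=65536.0`, `growth_factor=2.0`, `backoff_factor=0.5`, `growth_interval=2000`). `ScaleState` and `update_scale` are hypothetical names. The real update operates on device tensors, and the real `step()` also skips the optimizer update when a non-finite gradient was found.

```python
from dataclasses import dataclass


@dataclass
class ScaleState:
    """Hypothetical container for the dynamic loss-scale state."""
    scale: float = 2.0 ** 16   # matches the documented default init_scale
    growth_tracker: int = 0    # consecutive steps without inf/NaN gradients


def update_scale(state, found_inf, growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=2000):
    """Shrink the scale on overflow; grow it after `growth_interval`
    consecutive clean steps, then reset the counter."""
    if found_inf:
        state.scale *= backoff_factor
        state.growth_tracker = 0
    else:
        state.growth_tracker += 1
        if state.growth_tracker == growth_interval:
            state.scale *= growth_factor
            state.growth_tracker = 0
    return state


# Usage: one overflow halves the scale; three clean steps (with a short
# growth_interval for illustration) double it back.
state = ScaleState()
update_scale(state, found_inf=True)
for _ in range(3):
    update_scale(state, found_inf=False, growth_interval=3)
```

Backing off immediately but growing only after a run of clean steps keeps the scale close to the largest value that does not overflow.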