Fixes #142397

Basic implementation is done. What's left:

- [x] Different dtype/device tensors in the TensorList
- [x] Fast path for grouping the foreach kernel
- [x] Tests

Regarding tests, I found some tests for GradScaler in `test/test_torch.py`, but I couldn't figure out the best way to enable them for the MPS device. Removing `@onlyNativeDeviceTypes` enables the tests for MPS, but it also enables them for all other devices that are not among the native device types. If I instead use `instantiate_device_type_tests(TestTorchDeviceType, globals(), allow_mps=True)`, that enables many tests in that class for MPS which were not(?) being tested before. This part needs some clarification.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150255
Approved by: https://github.com/malfet

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
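For context, here is a minimal sketch of the usage this change targets, assuming `torch.amp.GradScaler("mps")` and MPS autocast are available once the PR lands; the model, optimizer, and shapes are placeholders:

```python
import torch

# Assumption: an MPS-enabled build, with GradScaler accepting "mps" after this change.
device = "mps"
model = torch.nn.Linear(16, 4).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.amp.GradScaler(device)

for _ in range(3):
    optimizer.zero_grad(set_to_none=True)
    x = torch.randn(8, 16, device=device)
    with torch.autocast(device_type=device, dtype=torch.float16):
        loss = model(x).square().mean()
    scaler.scale(loss).backward()  # scale the loss so low-precision grads don't underflow
    scaler.step(optimizer)         # unscales gradients and skips the step on inf/NaN
    scaler.update()                # adjusts the scale factor for the next iteration
```

And a hedged sketch of the two test-enabling options discussed above, using PyTorch's device-type test helpers; `TestFoo` and `test_gradscaler_smoke` are placeholder names rather than the actual tests in `test/test_torch.py`:

```python
import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyNativeDeviceTypes,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class TestFoo(TestCase):
    # Option 1: removing @onlyNativeDeviceTypes would enable this test for MPS,
    # but also for every other registered non-native device type.
    @onlyNativeDeviceTypes
    def test_gradscaler_smoke(self, device):
        scaler = torch.amp.GradScaler(torch.device(device).type)
        self.assertIsInstance(scaler, torch.amp.GradScaler)


# Option 2: passing allow_mps=True generates MPS variants for the tests in the
# whole class, which is broader than enabling just the GradScaler tests.
instantiate_device_type_tests(TestFoo, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()
```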
File listing:

- _multi_tensor
- __init__.py
- _adafactor.py
- _functional.py
- adadelta.py
- adagrad.py
- adam.py
- adamax.py
- adamw.py
- asgd.py
- lbfgs.py
- lr_scheduler.py
- nadam.py
- optimizer.py
- radam.py
- rmsprop.py
- rprop.py
- sgd.py
- sparse_adam.py
- swa_utils.py