mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary:
This PR implements the gradient scaling API that mruberry, jjsjann123, ngimel, zdevito, gchanan and I have been discussing. Relevant issue/RFC: https://github.com/pytorch/pytorch/issues/25081.
Volume-wise, this PR is mostly documentation and tests. The Python API (found entirely in `torch/cuda/amp/amp_scaler.py`) is lightweight. The exposed functions are intended to make the implementation and control flow of gradient scaling convenient, intuitive, and performant.
The API is probably easiest to digest by looking at the documentation and examples. `docs/source/amp.rst` is the homepage for the Automatic Mixed Precision package. `docs/source/notes/amp_examples.rst` includes several examples demonstrating common but not-immediately-obvious use cases. Examples are backed by tests in `test_cuda.py` (and thankfully the tests pass :P).
Two small utility kernels have been added in `native/cuda/AmpKernels.cu` to improve performance and avoid host-device synchronizations wherever possible.
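For readers who want the control flow in miniature before opening the docs, here is a rough pure-Python sketch of dynamic gradient scaling in general: scale the loss, unscale the gradients, skip the step if any gradient is inf/NaN, then adjust the scale. This is not the code in this PR. `ToyGradScaler`, its method names, its default values, and the plain SGD update are all hypothetical, chosen only for illustration; the real API and its semantics are described in `docs/source/amp.rst`.

```python
import math


class ToyGradScaler:
    """Illustrative dynamic loss scaler (hypothetical, not the PR's implementation)."""

    def __init__(self, init_scale=2.0 ** 16, growth_factor=2.0,
                 backoff_factor=0.5, growth_interval=3):
        # All defaults here are assumptions picked for the example.
        self.scale_value = init_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self._good_steps = 0      # consecutive steps without inf/NaN
        self._found_inf = False   # result of the most recent step()

    def scale(self, loss):
        """Multiply the loss so small gradients survive reduced precision."""
        return loss * self.scale_value

    def step(self, params, scaled_grads, lr):
        """Unscale the gradients and apply a plain SGD update in place.

        The update is skipped if any unscaled gradient is inf or NaN.
        Returns True if the update was applied.
        """
        inv_scale = 1.0 / self.scale_value
        grads = [g * inv_scale for g in scaled_grads]
        self._found_inf = any(not math.isfinite(g) for g in grads)
        if self._found_inf:
            return False
        for i, g in enumerate(grads):
            params[i] -= lr * g
        return True

    def update(self):
        """Back off after an overflow; grow after enough clean steps."""
        if self._found_inf:
            self.scale_value *= self.backoff_factor
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps == self.growth_interval:
                self.scale_value *= self.growth_factor
                self._good_steps = 0
        self._found_inf = False


# Usage: one clean step, then one overflowing step that gets skipped.
scaler = ToyGradScaler(init_scale=8.0, growth_interval=2)
params = [1.0]
scaler.step(params, [16.0], lr=0.1)          # unscaled grad is 2.0
scaler.update()
scaler.step(params, [float("inf")], lr=0.1)  # skipped, params untouched
scaler.update()                              # scale backs off from 8.0 to 4.0
```

The sketch checks for inf/NaN on the host for simplicity, which is exactly the kind of synchronization the utility kernels above are meant to avoid on the GPU.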
Existing optimizers, both in the wild and in PyTorch core, do not need to change to use the scaling API.
However, the API was also designed to establish a contract between user scripts and optimizers such that writers of _new_ custom optimizers have the control points they need to implement fast, optionally sync-free updates. User scripts that obey the scaling API can drop such custom optimizers in and reap performance benefits without having to change anything aside from the optimizer constructor itself. [I know what the contract with custom optimizers should be](
| Name | | |
|---|---|---|
| _static | ||
| _templates | ||
| community | ||
| notes | ||
| org/pytorch | ||
| scripts | ||
| __config__.rst | ||
| amp.rst | ||
| autograd.rst | ||
| bottleneck.rst | ||
| checkpoint.rst | ||
| conf.py | ||
| cpp_extension.rst | ||
| cuda_deterministic_backward.rst | ||
| cuda_deterministic.rst | ||
| cuda.rst | ||
| cudnn_deterministic.rst | ||
| cudnn_persistent_rnn.rst | ||
| data.rst | ||
| distributed.rst | ||
| distributions.rst | ||
| dlpack.rst | ||
| hub.rst | ||
| index.rst | ||
| jit_builtin_functions.rst | ||
| jit_language_reference.rst | ||
| jit_python_reference.rst | ||
| jit_unsupported.rst | ||
| jit.rst | ||
| math-quantizer-equation.png | ||
| model_zoo.rst | ||
| multiprocessing.rst | ||
| name_inference.rst | ||
| named_tensor.rst | ||
| nn.functional.rst | ||
| nn.init.rst | ||
| nn.rst | ||
| onnx.rst | ||
| optim.rst | ||
| packages.rst | ||
| quantization.rst | ||
| random.rst | ||
| rpc.rst | ||
| sparse.rst | ||
| storage.rst | ||
| tensor_attributes.rst | ||
| tensorboard.rst | ||
| tensors.rst | ||
| torch.rst | ||
| type_info.rst | ||