Reference amp tutorial (recipe) from core amp docs (#44725)

Summary:
https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html is live. The core amp docs should reference it.

Also, I fixed some typos in the `zero_grad` docs that we ignored when git was behaving weirdly during ngimel's merge of https://github.com/pytorch/pytorch/pull/44423.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/44725

Reviewed By: mruberry

Differential Revision: D23723807

Pulled By: ngimel

fbshipit-source-id: ca0b76365f8ca908bd978e3b38bf81857fa6c2a3
Michael Carilli 2020-09-16 11:29:55 -07:00 committed by Facebook GitHub Bot
parent a011b86115
commit 3e6bb5233f
4 changed files with 16 additions and 11 deletions


@@ -14,7 +14,8 @@ are much faster in ``float16``. Other ops, like reductions, often require the dy
 range of ``float32``. Mixed precision tries to match each op to its appropriate datatype.
 Ordinarily, "automatic mixed precision training" uses :class:`torch.cuda.amp.autocast` and
-:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`.
+:class:`torch.cuda.amp.GradScaler` together, as shown in the :ref:`Automatic Mixed Precision examples<amp-examples>`
+and `Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_.
 However, :class:`autocast` and :class:`GradScaler` are modular, and may be used separately if desired.
 .. contents:: :local:
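
As context for the doc text above: ``autocast`` on its own wraps the forward pass so float16-eligible ops run in reduced precision while precision-sensitive ops stay in float32. A minimal sketch (not part of this commit), assuming a CUDA build of PyTorch; `model`, `data`, and `target` are placeholder names:

```python
import torch

# Toy model and inputs on the GPU (hypothetical placeholders, not from the docs being edited).
model = torch.nn.Linear(8, 4).cuda()
data = torch.randn(16, 8, device="cuda")
target = torch.randn(16, 4, device="cuda")

# Inside the autocast region, ops that are safe in float16 (e.g. linear layers) run in
# float16; ops that need the range of float32 (e.g. reductions in the loss) keep float32.
with torch.cuda.amp.autocast():
    output = model(data)
    loss = torch.nn.functional.mse_loss(output, target)
```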


@@ -19,6 +19,10 @@ gradients by minimizing gradient underflow, as explained :ref:`here<gradient-sca
 :class:`torch.cuda.amp.autocast` and :class:`torch.cuda.amp.GradScaler` are modular.
 In the samples below, each is used as its individual documentation suggests.
+(Samples here are illustrative. See the
+`Automatic Mixed Precision recipe <https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html>`_
+for a runnable walkthrough.)
 .. contents:: :local:
 Typical Mixed Precision Training
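
As a rough sketch of the "Typical Mixed Precision Training" pattern this note documents (illustrative only, not part of the diff; the toy model, optimizer, and synthetic data are placeholders):

```python
import torch

# Placeholder model, optimizer, and loss; the structure of the loop is the point.
model = torch.nn.Linear(8, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    inputs = torch.randn(16, 8, device="cuda")
    targets = torch.randn(16, 4, device="cuda")

    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)

    # scale() multiplies the loss so small float16 gradients don't underflow to zero.
    scaler.scale(loss).backward()
    # step() unscales the gradients first and skips optimizer.step() if they contain inf/NaN.
    scaler.step(optimizer)
    # update() adapts the scale factor for the next iteration.
    scaler.update()
```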


@@ -1315,11 +1315,11 @@ class Module:
     def zero_grad(self, set_to_none: bool = False) -> None:
         r"""Sets gradients of all model parameters to zero. See similar function
-        under `torch.optimizer` for more contexts.
+        under :class:`torch.optim.Optimizer` for more context.
         Arguments:
-            set_to_none (bool): instead of setting to zero, set the grad to None.
-                See :meth:`torch.optim.optimizer.zero_grad` for details.
+            set_to_none (bool): instead of setting to zero, set the grads to None.
+                See :meth:`torch.optim.Optimizer.zero_grad` for details.
         """
         if getattr(self, '_is_replica', False):
             warnings.warn(
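
To illustrate the two behaviors the corrected docstring points at, a small sketch (runs on CPU; the tiny `Linear` module is a made-up example):

```python
import torch

net = torch.nn.Linear(4, 2)
net(torch.randn(3, 4)).sum().backward()

net.zero_grad(set_to_none=False)   # gradients become zero-filled tensors
print(net.weight.grad)             # tensor of zeros, same shape as the weight

net(torch.randn(3, 4)).sum().backward()
net.zero_grad(set_to_none=True)    # gradients are released
print(net.weight.grad)             # None
```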


@@ -165,18 +165,18 @@ class Optimizer(object):
         self.__setstate__({'state': state, 'param_groups': param_groups})
     def zero_grad(self, set_to_none: bool = False):
-        r"""Set the gradients of all optimized :class:`torch.Tensor` s to zero.
+        r"""Sets the gradients of all optimized :class:`torch.Tensor` s to zero.
         Arguments:
-            set_to_none (bool): instead of setting to zero, set the grad to None.
+            set_to_none (bool): instead of setting to zero, set the grads to None.
                 This is will in general have lower memory footprint, and can modestly improve performance.
                 However, it changes certain behaviors. For example:
-                1. When user tries to access the gradient value and perform manual ops on it.
-                A None attribute or a Tensor full of 0s will be different.
-                2. If the user requests `zero_grad(set_to_none=True)` followed by a backward pass, `.grad` s
+                1. When the user tries to access a gradient and perform manual ops on it,
+                a None attribute or a Tensor full of 0s will behave differently.
+                2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s
                 are guaranteed to be None for params that did not receive a gradient.
-                3. `torch.optim` optimizers have a different behavior if the gradient is 0 or None
-                (in one case it does the step with a gradient of 0 and in the other it skip
+                3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None
+                (in one case it does the step with a gradient of 0 and in the other it skips
                 the step altogether).
         """
         for group in self.param_groups:
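
Point 3 above (stepping with a zero gradient vs. skipping a None gradient) is easiest to see with an optimizer that does work even when the gradient is zero, e.g. SGD with weight decay. A minimal sketch, assuming that toy setup:

```python
import torch

param = torch.nn.Parameter(torch.ones(2))
opt = torch.optim.SGD([param], lr=0.1, weight_decay=0.1)
param.sum().backward()                # populate param.grad

opt.zero_grad(set_to_none=False)      # param.grad is now a tensor of zeros
opt.step()                            # step still runs: weight decay shrinks param

opt.zero_grad(set_to_none=True)       # param.grad is now None
opt.step()                            # param is skipped entirely; no update at all
```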