diff --git a/docs/source/nn.rst b/docs/source/nn.rst
index c02b3204573..0f8c89c6d26 100644
--- a/docs/source/nn.rst
+++ b/docs/source/nn.rst
@@ -373,6 +373,8 @@ Utility functions to clip parameter gradients.
     clip_grad_norm_
     clip_grad_norm
     clip_grad_value_
+    get_total_norm
+    clip_grads_with_norm_
 
 Utility functions to flatten and unflatten Module parameters to and
 from a single vector.
diff --git a/test/test_nn.py b/test/test_nn.py
index ae8a8fd9037..ec88ebe27c4 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -22,7 +22,7 @@ import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.utils.rnn as rnn_utils
-from torch.nn.utils import clip_grad_norm_, clip_grad_value_
+from torch.nn.utils import clip_grad_norm_, clip_grad_value_, clip_grads_with_norm_, get_total_norm
 from torch.nn.utils import parameters_to_vector, vector_to_parameters
 from torch.nn.utils.fusion import fuse_conv_bn_weights
 from torch.nn.utils.fusion import fuse_linear_bn_weights
@@ -12820,6 +12820,20 @@ if __name__ == '__main__':
         self.assertLessEqual(norm_after, norm_before)
         compare_scaling(grads)
 
+        # decomposed APIs should behave as expected
+        grads = torch.arange(1., 101, device=device).view(10, 10), torch.ones(10, device=device).div(1000)
+        for p, g in zip(l.parameters(), grads):
+            p._grad = g.clone().view_as(p)
+        norm_before = compute_norm(norm_type)
+        grads = [p.grad for p in l.parameters()]
+        total_norm = get_total_norm(grads, norm_type=norm_type, foreach=foreach)
+        clip_grads_with_norm_(l.parameters(), max_norm, total_norm, foreach=foreach)
+        norm_after = compute_norm(norm_type)
+        self.assertEqual(total_norm, norm_before)
+        self.assertEqual(norm_after, max_norm)
+        self.assertLessEqual(norm_after, norm_before)
+        compare_scaling(grads)
+
         # Small gradients should be left unchanged
         grads = torch.rand(10, 10, device=device).div(10000), torch.ones(10, device=device).div(500)
         for p, g in zip(l.parameters(), grads):
diff --git a/torch/nn/utils/__init__.py b/torch/nn/utils/__init__.py
index e4dcc773691..5af9ed93e92 100644
--- a/torch/nn/utils/__init__.py
+++ b/torch/nn/utils/__init__.py
@@ -1,5 +1,11 @@
 from . import parametrizations, rnn, stateless
-from .clip_grad import clip_grad_norm, clip_grad_norm_, clip_grad_value_
+from .clip_grad import (
+    _clip_grads_with_norm_ as clip_grads_with_norm_,
+    _get_total_norm as get_total_norm,
+    clip_grad_norm,
+    clip_grad_norm_,
+    clip_grad_value_,
+)
 from .convert_parameters import parameters_to_vector, vector_to_parameters
 from .fusion import (
     fuse_conv_bn_eval,
@@ -19,6 +25,7 @@ from .weight_norm import remove_weight_norm, weight_norm
 __all__ = [
     "clip_grad_norm",
     "clip_grad_norm_",
+    "clip_grads_with_norm_",
     "clip_grad_value_",
     "convert_conv2d_weight_memory_format",
     "convert_conv3d_weight_memory_format",
@@ -26,6 +33,7 @@ __all__ = [
     "fuse_conv_bn_weights",
     "fuse_linear_bn_eval",
     "fuse_linear_bn_weights",
+    "get_total_norm",
     "parameters_to_vector",
     "parametrizations",
     "remove_spectral_norm",
diff --git a/torch/nn/utils/clip_grad.py b/torch/nn/utils/clip_grad.py
index ea895b9c959..c51fb273bc1 100644
--- a/torch/nn/utils/clip_grad.py
+++ b/torch/nn/utils/clip_grad.py
@@ -13,7 +13,11 @@ from torch.utils._foreach_utils import (
 )
 
 
-__all__ = ["clip_grad_norm_", "clip_grad_norm", "clip_grad_value_"]
+__all__ = [
+    "clip_grad_norm_",
+    "clip_grad_norm",
+    "clip_grad_value_",
+]
 
 _tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
 
@@ -33,6 +37,141 @@ def _no_grad(func):
     return _no_grad_wrapper
 
 
+@_no_grad
+def _get_total_norm(
+    tensors: _tensor_or_tensors,
+    norm_type: float = 2.0,
+    error_if_nonfinite: bool = False,
+    foreach: Optional[bool] = None,
+) -> torch.Tensor:
+    r"""Compute the norm of an iterable of tensors.
+
+    The norm is computed over the norms of the individual tensors, as if the norms of
+    the individual tensors were concatenated into a single vector.
+
+    Args:
+        tensors (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will be normalized
+        norm_type (float): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+        error_if_nonfinite (bool): if True, an error is thrown if the total
+            norm of :attr:`tensors` is ``nan``, ``inf``, or ``-inf``.
+            Default: ``False``
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        Total norm of the tensors (viewed as a single vector).
+    """
+    if isinstance(tensors, torch.Tensor):
+        tensors = [tensors]
+    else:
+        tensors = list(tensors)
+    norm_type = float(norm_type)
+    if len(tensors) == 0:
+        return torch.tensor(0.0)
+    first_device = tensors[0].device
+    grouped_tensors: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [tensors]  # type: ignore[list-item]
+    )  # type: ignore[assignment]
+
+    norms: List[Tensor] = []
+    for (device, _), ([device_tensors], _) in grouped_tensors.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_tensors, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            norms.extend(torch._foreach_norm(device_tensors, norm_type))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            norms.extend(
+                [torch.linalg.vector_norm(g, norm_type) for g in device_tensors]
+            )
+
+    total_norm = torch.linalg.vector_norm(
+        torch.stack([norm.to(first_device) for norm in norms]), norm_type
+    )
+
+    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+        raise RuntimeError(
+            f"The total norm of order {norm_type} for gradients from "
+            "`parameters` is non-finite, so it cannot be clipped. To disable "
+            "this error and scale the gradients by the non-finite norm anyway, "
+            "set `error_if_nonfinite=False`"
+        )
+    return total_norm
+
+
+@_no_grad
+def _clip_grads_with_norm_(
+    parameters: _tensor_or_tensors,
+    max_norm: float,
+    total_norm: torch.Tensor,
+    foreach: Optional[bool] = None,
+) -> None:
+    r"""Scale the gradients of an iterable of parameters given a pre-calculated total norm and desired max norm.
+
+    The gradients will be scaled by the following calculation
+
+    .. math::
+        grad = grad * \frac{max\_norm}{total\_norm + 1e-6}
+
+    Gradients are modified in-place.
+
+    This function is equivalent to :func:`torch.nn.utils.clip_grad_norm_` with a pre-calculated
+    total norm.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor that will have gradients normalized
+        max_norm (float): max norm of the gradients
+        total_norm (Tensor): total norm of the gradients to use for clipping
+        foreach (bool): use the faster foreach-based implementation.
+            If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently
+            fall back to the slow implementation for other device types.
+            Default: ``None``
+
+    Returns:
+        None
+    """
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    grads = [p.grad for p in parameters if p.grad is not None]
+    max_norm = float(max_norm)
+    if len(grads) == 0:
+        return
+    grouped_grads: Dict[
+        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
+    ] = _group_tensors_by_device_and_dtype(
+        [grads]
+    )  # type: ignore[assignment]
+
+    clip_coef = max_norm / (total_norm + 1e-6)
+    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+    # when the gradients do not reside in CPU memory.
+    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
+        if (foreach is None and _has_foreach_support(device_grads, device)) or (
+            foreach and _device_has_foreach_support(device)
+        ):
+            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
+        elif foreach:
+            raise RuntimeError(
+                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
+            )
+        else:
+            clip_coef_clamped_device = clip_coef_clamped.to(device)
+            for g in device_grads:
+                g.mul_(clip_coef_clamped_device)
+
+
 @_no_grad
 def clip_grad_norm_(
     parameters: _tensor_or_tensors,
@@ -47,6 +186,9 @@ def clip_grad_norm_(
     as if the norms of the individual gradients were concatenated into a single vector.
     Gradients are modified in-place.
 
+    This function is equivalent to :func:`torch.nn.utils.get_total_norm` followed by
+    :func:`torch.nn.utils.clip_grads_with_norm_` with the ``total_norm`` returned by ``get_total_norm``.
+
     Args:
         parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
             single Tensor that will have gradients normalized
@@ -66,61 +208,12 @@
     """
     if isinstance(parameters, torch.Tensor):
         parameters = [parameters]
+    else:
+        # prevent generators from being exhausted
+        parameters = list(parameters)
     grads = [p.grad for p in parameters if p.grad is not None]
-    max_norm = float(max_norm)
-    norm_type = float(norm_type)
-    if len(grads) == 0:
-        return torch.tensor(0.0)
-    first_device = grads[0].device
-    grouped_grads: Dict[
-        Tuple[torch.device, torch.dtype], Tuple[List[List[Tensor]], List[int]]
-    ] = _group_tensors_by_device_and_dtype(
-        [grads]
-    )  # type: ignore[assignment]
-
-    norms: List[Tensor] = []
-    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
-        if (foreach is None and _has_foreach_support(device_grads, device)) or (
-            foreach and _device_has_foreach_support(device)
-        ):
-            norms.extend(torch._foreach_norm(device_grads, norm_type))
-        elif foreach:
-            raise RuntimeError(
-                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
-            )
-        else:
-            norms.extend([torch.linalg.vector_norm(g, norm_type) for g in device_grads])
-
-    total_norm = torch.linalg.vector_norm(
-        torch.stack([norm.to(first_device) for norm in norms]), norm_type
-    )
-
-    if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
-        raise RuntimeError(
-            f"The total norm of order {norm_type} for gradients from "
-            "`parameters` is non-finite, so it cannot be clipped. To disable "
-            "this error and scale the gradients by the non-finite norm anyway, "
-            "set `error_if_nonfinite=False`"
-        )
-    clip_coef = max_norm / (total_norm + 1e-6)
-    # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
-    # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
-    # when the gradients do not reside in CPU memory.
-    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
-    for (device, _), ([device_grads], _) in grouped_grads.items():  # type: ignore[assignment]
-        if (foreach is None and _has_foreach_support(device_grads, device)) or (
-            foreach and _device_has_foreach_support(device)
-        ):
-            torch._foreach_mul_(device_grads, clip_coef_clamped.to(device))
-        elif foreach:
-            raise RuntimeError(
-                f"foreach=True was passed, but can't use the foreach API on {device.type} tensors"
-            )
-        else:
-            clip_coef_clamped_device = clip_coef_clamped.to(device)
-            for g in device_grads:
-                g.mul_(clip_coef_clamped_device)
-
+    total_norm = _get_total_norm(grads, norm_type, error_if_nonfinite, foreach)
+    _clip_grads_with_norm_(parameters, max_norm, total_norm, foreach)
     return total_norm
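A minimal usage sketch (not part of the patch) showing how the two new public APIs compose; the `nn.Linear` model, tensor shapes, and `max_norm` value are illustrative only. Per the changes to `torch/nn/utils/clip_grad.py`, `clip_grad_norm_` is now exactly `get_total_norm` followed by `clip_grads_with_norm_`.

```python
import torch
import torch.nn as nn
from torch.nn.utils import clip_grads_with_norm_, get_total_norm

# Illustrative model and backward pass to populate .grad fields.
model = nn.Linear(10, 10)
model(torch.randn(4, 10)).sum().backward()

max_norm = 1.0
grads = [p.grad for p in model.parameters() if p.grad is not None]

# Step 1: reduce the per-tensor norms to a single total norm (2-norm by default).
total_norm = get_total_norm(grads, norm_type=2.0)

# Step 2: scale gradients in-place by max_norm / (total_norm + 1e-6),
# clamped so gradients are only ever scaled down, never up.
clip_grads_with_norm_(model.parameters(), max_norm, total_norm)

# Equivalent single call (returns the same total_norm):
# total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
```

Splitting the steps lets a caller supply a total norm computed elsewhere (for example, aggregated across sharded gradients) while reusing the same clipping step.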