Teach dynamo about grad. Pull Request resolved: https://github.com/pytorch/pytorch/pull/102264. Approved by: https://github.com/zou3519
# NOTE: We allow Dynamo to see this file (via torch/_dynamo/skipfiles.py) so that it can
# trace through `grad`.
# Currently, we can't allow Dynamo to see `eager_transforms.py` as that breaks a lot of things,
# and there isn't a mechanism to selectively expose only some functions (e.g. grad) from a file
# to Dynamo.
from torch._functorch.eager_transforms import grad_impl, exposed_in, Callable, argnums_t
import functools


@exposed_in("torch.func")
def grad(func: Callable, argnums: argnums_t = 0, has_aux: bool = False) -> Callable:
"""``grad`` operator helps computing gradients of ``func`` with respect to the
|
|
input(s) specified by ``argnums``. This operator can be nested to
|
|
compute higher-order gradients.
|
|
|
|
Args:
|
|
func (Callable): A Python function that takes one or more arguments.
|
|
Must return a single-element Tensor. If specified ``has_aux`` equals ``True``,
|
|
function can return a tuple of single-element Tensor and other auxiliary objects:
|
|
``(output, aux)``.
|
|
argnums (int or Tuple[int]): Specifies arguments to compute gradients with respect to.
|
|
``argnums`` can be single integer or tuple of integers. Default: 0.
|
|
has_aux (bool): Flag indicating that ``func`` returns a tensor and other
|
|
auxiliary objects: ``(output, aux)``. Default: False.
|
|
|
|
Returns:
|
|
Function to compute gradients with respect to its inputs. By default, the output of
|
|
the function is the gradient tensor(s) with respect to the first argument.
|
|
If specified ``has_aux`` equals ``True``, tuple of gradients and output auxiliary objects
|
|
is returned. If ``argnums`` is a tuple of integers, a tuple of output gradients with
|
|
respect to each ``argnums`` value is returned.
|
|
|
|
    Example of using ``grad``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> x = torch.randn([])
        >>> cos_x = grad(lambda x: torch.sin(x))(x)
        >>> assert torch.allclose(cos_x, x.cos())
        >>>
        >>> # Second-order gradients
        >>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
        >>> assert torch.allclose(neg_sin_x, -x.sin())

    When composed with ``vmap``, ``grad`` can be used to compute per-sample-gradients:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad, vmap
        >>> batch_size, feature_size = 3, 5
        >>>
        >>> def model(weights, feature_vec):
        >>>     # Very simple linear model with activation
        >>>     assert feature_vec.dim() == 1
        >>>     return feature_vec.dot(weights).relu()
        >>>
        >>> def compute_loss(weights, example, target):
        >>>     y = model(weights, example)
        >>>     return ((y - target) ** 2).mean()  # MSELoss
        >>>
        >>> weights = torch.randn(feature_size, requires_grad=True)
        >>> examples = torch.randn(batch_size, feature_size)
        >>> targets = torch.randn(batch_size)
        >>> inputs = (weights, examples, targets)
        >>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
    Example of using ``grad`` with ``has_aux`` and ``argnums``:

        >>> # xdoctest: +SKIP
        >>> from torch.func import grad
        >>> def my_loss_func(y, y_pred):
        >>>     loss_per_sample = (0.5 * y_pred - y) ** 2
        >>>     loss = loss_per_sample.mean()
        >>>     return loss, (y_pred, loss_per_sample)
        >>>
        >>> fn = grad(my_loss_func, argnums=(0, 1), has_aux=True)
        >>> y_true = torch.rand(4)
        >>> y_preds = torch.rand(4, requires_grad=True)
        >>> out = fn(y_true, y_preds)
        >>> # > output is ((grads w.r.t y_true, grads w.r.t y_preds), (y_pred, loss_per_sample))
    .. note::
        Using PyTorch ``torch.no_grad`` together with ``grad``.

        Case 1: Using ``torch.no_grad`` inside a function:

            >>> # xdoctest: +SKIP
            >>> def f(x):
            >>>     with torch.no_grad():
            >>>         c = x ** 2
            >>>     return x - c

        In this case, ``grad(f)(x)`` will respect the inner ``torch.no_grad``.

        Case 2: Using ``grad`` inside the ``torch.no_grad`` context manager:

            >>> # xdoctest: +SKIP
            >>> with torch.no_grad():
            >>>     grad(f)(x)

        In this case, ``grad`` will respect the inner ``torch.no_grad``, but not the
        outer one. This is because ``grad`` is a "function transform": its result
        should not depend on the result of a context manager outside of ``f``.

    """
    # Thin wrapper that delegates to the functorch implementation; keeping it here
    # lets Dynamo trace through `grad` (see the NOTE at the top of this file).
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return grad_impl(func, argnums, has_aux, args, kwargs)
    return wrapper
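
# Illustrative sketch (not part of the upstream module): because Dynamo is allowed to
# see this file, `torch.compile` can trace through the thin `grad` wrapper defined above.
# A minimal usage sketch, assuming a PyTorch build where `torch.compile` is available
# and supports tracing `torch.func.grad` (which this change is meant to enable):
#
#     import torch
#     from torch.func import grad
#
#     @torch.compile
#     def compiled_cos(x):
#         # Dynamo traces through the `grad` wrapper in this file rather than
#         # falling back to eager mode.
#         return grad(torch.sin)(x)
#
#     x = torch.randn([])
#     assert torch.allclose(compiled_cos(x), x.cos())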