Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49097

RFC: https://github.com/pytorch/rfcs/pull/11

This PR adds the basic logic to handle forward gradients as dual Tensors. It contains the following:
- A mechanism to save dual state on a Tensor and clear it up when the dual level ends
- C++ and Python user-facing APIs
- An updated view system that is able to track both forward and backward views

The current PR has the following limitations:
- Extensive tests are in the next PR in the stack, as formulas are needed to write full tests.
- Only the manual formulas have been audited; no other formula is actually implemented here (they are in the next PR in the stack).
- Only level 0 is allowed for now. It was discussed and agreed that more levels are not needed for the first version of this PR.
- We could save one ViewInfo creation when both the forward and backward views have the same base, by adding a boolean flag to the DifferentiableViewMeta and extra logic in the `as_view` method. This is left out to keep this PR concise.
- We could skip tracking forward views if the base has a forward grad, by adding extra logic in the `as_view` method. This is left out to keep this PR concise.

Reading guide:
- Updated view handling in [gen_variable_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-f6553cec68caeaea36f6c8b14ff76a6d39dfd774e0ea9ef2f76e8d81fd9af5df), [VariableTypeUtils.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-ec71cfa45954dece1236c661d170e6341879c5be637f4abf52e826d61b40695a), [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285) (skip the code below "[Forward Grad View]" for now), [variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-1604bcd0e4350ed99ec45e437cee7ac9ebe337392c9ea16a236247aeeb35b02bR266-R542) and [custom_function.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-dd85f452082b5bb6612bbc12adb496f8827defa228509f7b493de1d517522d5d). This introduces the new ViewInfo that holds the view information shared between the forward and backward views, updates the differentiable view meta to use it, and updates the `as_view` function to handle both forward and backward views.
- New forward grad class that handles storing gradients and tracking at each level: [forward_grad.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c6c5b9ab2d7e5dde4102495faa1b6bbbfc23aa3e47deb7359c0bfe1eb004c0cb), [forward_grad.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-de2ab54ade7312701850d71a119a4f4ee4b9fc5a9c42a467cdd4e73c033531dd) and [build_variables.bzl](https://github.com/pytorch/pytorch/pull/49097/files#diff-dfdfa2efb17beddfd9094524f95351fd197db6c8857e96b436fb599870359325). EDIT: These files also contain the new flag to globally disable forward AD, which allows us to reduce performance issues while this is in development.
- Lowest-level API and binding between Tensor and AutogradMeta in [TensorBody.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-7554853205392fa743357bf845ecc350a974ec049383248c12daaf2f4de04911), [TensorImpl.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-052bd9150ef8e09289ddf644b5a6830ede49207201cd41728f6d7cc6d9cead94), [TensorImpl.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-a15aae4cf23da44970db7cece62ff981265575c798c62f7b52d87c8809dfe2e1) and the rest of [variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-60e3bfe444e89efc7149f25b38e472710525984789934ab83f1bd5671b8ff285R557-R677)
- API to access the forward primal, which needs to be a differentiable function (and so appears in native_functions.yaml): [native_functions.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-2f3dbd85efb9b5172f2264eedd3be47dd765e6ab7cc8bf3ade5e62c28ae35991), [NamedRegistrations.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-69bd3bea510c9b64e1633fa18c3ea63d4b8348dbad3a78ad9de844ab3e43dc1d), [VariableMethodsStub.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-23f5fcb737a2b289811fe0f4b65aef775e7c824b2e629ecd343df51405cd434f), [derivatives.yaml](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_python_functions.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-e4c2f99a2404e98c3586e07425da73008f36b1bada790648a7297af141d37f8c), [gen_trace_type.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-54e0b976027bf8debefb959ff360b89ae93466970c843365b1b3a03806d868ce), [TraceTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-f34636741ad4a23d018e0c289bc750c3bad887b45660e1d6eaf440d234a78fbf) and [part of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R198-R243)
- C++ API: [autograd.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-349028fbe8291a965a7a263c323b208fe071c35c66179ee997ef84fa81aa4b1e), [autograd.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-a3fe908d67dfec16a1fcde300de68b0701bf68b88db7451f29f2bee255cf30c9)
- Python binding: [init.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-c58a67c85191c22c9b3bb439117d8053edfd9dea839fa010cf967d404c3c630d)
- Python API: [forward_ad.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a4efad4ba18fffdfb264c21e5475997a24a743089a899f8ec1a5ff962c6738d9), [autograd/__init__.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-743abcafd32ad0e69f39ac5a91df4197b7e1921c135cacee7ef6dc829a8a7af8)
- C++ and Python printing: [Formatting.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-881dba501e71662e2e4818b4b016f739b344c8aed2f5edc6b871eda47a2aced0), [_tensor_str.py](https://github.com/pytorch/pytorch/pull/49097/files#diff-a7911f8d5e73adbff914d99fd7818ace2a7030b6a3748abe06ec6fc6e3df9cc3)
- Utilities for formulas and updated manual functions that respect the new view system as well as forward grad: [FunctionsManual.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-6378bb6dc81a64dab676d61731341fa5d1088418f32a1473a33a0ccfc2357dc1), [FunctionsManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-4adbd88239afcd60e8198aab65d4f5e43b62314e34b80551e997a1ea503adea5), [rest of VariableTypeManual.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-6e19a1bce8cbdba8714b6e2c794a76bc0864b64a49cfa757cb0b5afdc937d1a4R264-R433)
- Ensure SavedVariable saves the forward grad properly: [saved_variable.h](https://github.com/pytorch/pytorch/pull/49097/files#diff-c1b8039d776241abe177d5aa99b79dd9489a9b3e529da8ab24c2e386c1238ae2), [saved_variable.cpp](https://github.com/pytorch/pytorch/pull/49097/files#diff-cc9fba479b5beae06b2eea2e390d17796e0341c5b037a20b5bcaccbb0c341030)

Test Plan: Imported from OSS

Reviewed By: mrshenli

Differential Revision: D25607503

Pulled By: albanD

fbshipit-source-id: f1396290de1d75760f3d380c43cdd56e86fa6099
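For orientation, the sketch below shows how the Python API introduced by this PR (the `forward_ad.py` file further down) is intended to be used. The function `fn` here is only a placeholder; since most forward-mode formulas land in the next PR of the stack, this illustrates the API shape rather than guaranteeing that any particular op propagates a tangent at this commit.

```python
import torch
import torch.autograd.forward_ad as fwAD

def fn(t):
    # Placeholder computation standing in for a real user function/model.
    return t

x = torch.randn(3)   # primal input
v = torch.randn(3)   # tangent, i.e. the vector in the Jacobian-vector product

with fwAD.dual_level():                  # enters dual level 0 (the only level allowed for now)
    inp = fwAD.make_dual(x, v)           # dual Tensor: shares memory with x, carries tangent v
    out = fn(inp)
    primal, jvp = fwAD.unpack_dual(out)  # jvp is J @ v once the formulas for fn's ops exist
```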
117 lines
3.9 KiB
Python
import torch
from .grad_mode import _DecoratorContextManager

from typing import Any

# TODO(alband): Once most of the formulas are implemented, these functions need to be added
# to the main doc to make them fully "public".

# Global variable used to make the python API simpler to use
_current_level = -1

def enter_dual_level():
    r"""Function that can be used to enter a new forward grad level.
    This level can be used to make and unpack dual Tensors to compute
    forward gradients.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    new_level = torch._C._enter_dual_level()
    if new_level != _current_level + 1:
        raise RuntimeError("Entering a new forward AD level but the current level "
                           "is not valid. Make sure you did not modify it directly.")
    _current_level = new_level
    return new_level

def exit_dual_level(*, level=None):
    r"""Function that can be used to exit a forward grad level.
    This function deletes all the gradients associated with this
    level. Only deleting the latest entered level is allowed.

    This function also updates the current level that is used by default
    by the other functions in this API.
    """
    global _current_level
    if level is None:
        level = _current_level
    if level != _current_level:
        raise RuntimeError("Trying to exit a forward AD level that was not the last one "
                           "that was created. This is not supported.")
    torch._C._exit_dual_level(level=level)
    _current_level = level - 1
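
# These two functions are the low-level entry points for managing forward AD
# levels; they are meant to be used as a pair, which is exactly what the
# `dual_level` context manager defined below does. A manual equivalent would
# look roughly like:
#
#     level = enter_dual_level()
#     try:
#         ...  # create duals with make_dual(...) and run the computation
#     finally:
#         exit_dual_level(level=level)
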
def make_dual(tensor, tangent, *, level=None):
    r"""Function that creates a "dual object" that can be used to compute forward AD gradients
    based on the given Tensor and its tangent. It returns a new Tensor that shares memory with
    :attr:`tensor`, and the :attr:`tangent` is used as-is.

    This function is backward differentiable.

    Given a function `f` whose Jacobian is `J`, it allows one to compute the Jacobian-vector
    product (`jvp`) between `J` and a given vector `v` as follows.

    Example::
        >>> inp = make_dual(x, v)
        >>> out = f(inp)
        >>> y, jvp = unpack_dual(out)

    """
    if level is None:
        level = _current_level

    if level < 0:
        raise RuntimeError("Trying to create a dual Tensor for forward AD but no level "
                           "exists, make sure to enter_dual_level() first.")

    return torch.make_dual(tensor, tangent, level=level)

def unpack_dual(tensor, *, level=None):
    r"""Function that unpacks a "dual object" to recover two plain tensors, one representing
    the primal and the other the tangent (both are views of :attr:`tensor`). Neither of these
    tensors is a dual tensor of level :attr:`level`.

    This function is backward differentiable.
    """
    if level is None:
        level = _current_level

    if level < 0:
        return tensor, None

    return torch.unpack_dual(tensor, level=level)
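
# Usage sketch for `unpack_dual` (with `x` and `v` arbitrary tensors of the same
# shape): inside an active level it recovers the primal and the tangent attached
# by `make_dual`; outside any level it simply returns ``(tensor, None)``, matching
# the early return above.
#
#     with dual_level():
#         dual = make_dual(x, v)
#         primal, tangent = unpack_dual(dual)
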
class dual_level(_DecoratorContextManager):
    r"""Context-manager that controls the current forward AD level. It
    appropriately enters and exits the dual level.

    This context manager also updates the current level that is used by default
    by the other functions in this API.

    Example::

        >>> x = torch.tensor([1])
        >>> x_t = torch.tensor([1])
        >>> with dual_level():
        ...     inp = make_dual(x, x_t)
        ...     # Do computations with inp
        ...     out = your_fn(inp)
        ...     _, grad = unpack_dual(out)
        >>> grad is None
        False
        >>> # After exiting the level, the grad is deleted
        >>> _, grad_after = unpack_dual(out)
        >>> grad_after is None
        True

    """
    def __init__(self):
        super().__init__()

    def __enter__(self):
        return enter_dual_level()

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        exit_dual_level()