Autograd doc cleanup (#118500)

I don't think we'll realistically go through deprecation for these now, since there are a couple of uses of each online. So document them appropriately.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118500
Approved by: https://github.com/soulitzer
albanD 2024-01-29 21:51:28 +00:00 committed by PyTorch MergeBot
parent fc5cde7579
commit a40be5f4dc
8 changed files with 134 additions and 84 deletions

View File

@ -33,6 +33,9 @@ for detailed steps on how to use this API.
forward_ad.dual_level
forward_ad.make_dual
forward_ad.unpack_dual
forward_ad.enter_dual_level
forward_ad.exit_dual_level
forward_ad.UnpackedDualTensor
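A minimal sketch of how these entry points fit together (the shapes and the ``sin`` op are illustrative only):

.. code-block:: python

    import torch
    import torch.autograd.forward_ad as fwAD

    primal = torch.randn(3)
    tangent = torch.randn(3)
    with fwAD.dual_level():
        dual = fwAD.make_dual(primal, tangent)
        out = dual.sin()
        # unpack_dual returns an UnpackedDualTensor(primal, tangent) named tuple
        p, t = fwAD.unpack_dual(out)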
.. _functional-api:
@ -209,6 +212,27 @@ When creating a new :class:`Function`, the following methods are available to `c
function.FunctionCtx.save_for_backward
function.FunctionCtx.set_materialize_grads
Custom Function utilities
^^^^^^^^^^^^^^^^^^^^^^^^^
Decorator for the backward method.
.. autosummary::
:toctree: generated
:nosignatures:
function.once_differentiable
Base custom :class:`Function` classes used to build PyTorch utilities.
.. autosummary::
:toctree: generated
:nosignatures:
function.BackwardCFunction
function.InplaceFunction
function.NestedIOFunction
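As a rough usage sketch, ``once_differentiable`` wraps a custom ``backward`` so that attempting to differentiate through it a second time raises an error (the ``Exp`` class below is purely illustrative):

.. code-block:: python

    import torch
    from torch.autograd.function import once_differentiable

    class Exp(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            y = x.exp()
            ctx.save_for_backward(y)
            return y

        @staticmethod
        @once_differentiable
        def backward(ctx, grad_output):
            # Runs with grad disabled; double backward through this op will error out.
            (y,) = ctx.saved_tensors
            return grad_output * y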
.. _grad-check:
Numerical gradient checking
@ -224,6 +248,7 @@ Numerical gradient checking
gradcheck
gradgradcheck
GradcheckError
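A small usage sketch (double precision is recommended so the default tolerances hold):

.. code-block:: python

    import torch
    from torch.autograd import gradcheck, gradgradcheck

    x = torch.randn(4, dtype=torch.double, requires_grad=True)
    assert gradcheck(torch.sin, (x,))
    assert gradgradcheck(torch.sin, (x,))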
.. Just to reset the base path for the rest of this file
.. currentmodule:: torch.autograd
@ -249,6 +274,14 @@ and vtune profiler based using
profiler.profile.key_averages
profiler.profile.self_cpu_time_total
profiler.profile.total_average
profiler.parse_nvprof_trace
profiler.EnforceUnique
profiler.KinetoStepTracker
profiler.record_function
profiler_util.Interval
profiler_util.Kernel
profiler_util.MemRecordsAcc
profiler_util.StringTable
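For example, ``record_function`` labels a region inside a ``profiler.profile`` trace; a rough sketch (the label and sizes are arbitrary):

.. code-block:: python

    import torch
    from torch.autograd import profiler

    x = torch.randn(128, 128)
    with profiler.profile() as prof:
        with profiler.record_function("my_matmul"):
            y = x @ x
    print(prof.key_averages().table(sort_by="self_cpu_time_total"))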
.. autoclass:: torch.autograd.profiler.emit_nvtx
.. autoclass:: torch.autograd.profiler.emit_itt
@ -260,13 +293,20 @@ and vtune profiler based using
profiler.load_nvprof
Anomaly detection
Debugging and anomaly detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: detect_anomaly
.. autoclass:: set_detect_anomaly
.. autosummary::
:toctree: generated
:nosignatures:
grad_mode.set_multithreading_enabled
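A brief sketch of turning anomaly detection on (either scoped or globally):

.. code-block:: python

    import torch

    with torch.autograd.detect_anomaly():
        x = torch.randn(3, requires_grad=True)
        out = (x * 2).sum()
        out.backward()  # a NaN produced during backward would raise, with the forward traceback

    torch.autograd.set_detect_anomaly(True)  # global switch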
Autograd graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@ -286,6 +326,7 @@ enabled and at least one of the inputs required gradients), or ``None`` otherwis
graph.Node.next_functions
graph.Node.register_hook
graph.Node.register_prehook
graph.increment_version
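For example, the ``grad_fn`` of a non-leaf tensor is a ``graph.Node`` and hooks can be registered on it directly; a small sketch:

.. code-block:: python

    import torch

    x = torch.randn(3, requires_grad=True)
    y = x.sin()
    node = y.grad_fn  # a torch.autograd.graph.Node
    print(node.name(), node.next_functions)

    # Prehooks see the grad_outputs before the Node runs; hooks see both sides afterwards.
    node.register_prehook(lambda grad_outputs: None)
    node.register_hook(lambda grad_inputs, grad_outputs: None)
    y.sum().backward()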
Some operations need intermediary results to be saved during the forward pass
in order to execute the backward pass.

View File

@ -91,9 +91,6 @@ templates_path = ["_templates"]
coverage_ignore_functions = [
# torch
"typename",
# torch.autograd
"register_py_tensor_class_for_device",
"variable",
# torch.cuda
"check_error",
"cudart",
@ -390,20 +387,6 @@ coverage_ignore_functions = [
"weight_dtype",
"weight_is_quantized",
"weight_is_statically_quantized",
# torch.autograd.forward_ad
"enter_dual_level",
"exit_dual_level",
# torch.autograd.function
"once_differentiable",
"traceable",
# torch.autograd.gradcheck
"get_analytical_jacobian",
"get_numerical_jacobian",
"get_numerical_jacobian_wrt_specific_input",
# torch.autograd.graph
"increment_version",
# torch.autograd.profiler
"parse_nvprof_trace",
# torch.backends.cudnn.rnn
"get_cudnn_mode",
"init_dropout_state",
@ -2530,40 +2513,6 @@ coverage_ignore_classes = [
"QuantWrapper",
# torch.ao.quantization.utils
"MatchAllNode",
# torch.autograd.forward_ad
"UnpackedDualTensor",
# torch.autograd.function
"BackwardCFunction",
"Function",
"FunctionCtx",
"FunctionMeta",
"InplaceFunction",
"NestedIOFunction",
# torch.autograd.grad_mode
"inference_mode",
"set_grad_enabled",
"set_multithreading_enabled",
# torch.autograd.gradcheck
"GradcheckError",
# torch.autograd.profiler
"EnforceUnique",
"KinetoStepTracker",
"profile",
"record_function",
# torch.autograd.profiler_legacy
"profile",
# torch.autograd.profiler_util
"EventList",
"FormattedTimesMixin",
"FunctionEvent",
"FunctionEventAvg",
"Interval",
"Kernel",
"MemRecordsAcc",
"StringTable",
# torch.autograd.variable
"Variable",
"VariableMeta",
# torch.backends.cudnn.rnn
"Unserializable",
# torch.cuda.amp.grad_scaler

View File

@ -274,9 +274,9 @@ Examples::
no_grad
enable_grad
set_grad_enabled
autograd.grad_mode.set_grad_enabled
is_grad_enabled
inference_mode
autograd.grad_mode.inference_mode
is_inference_mode_enabled
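A short sketch of how these toggles interact:

.. code-block:: python

    import torch

    with torch.set_grad_enabled(False):
        assert not torch.is_grad_enabled()

    with torch.inference_mode():
        assert torch.is_inference_mode_enabled()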
Math operations

View File

@ -35,20 +35,6 @@
"torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
"torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
},
"torch.autograd": [
"NestedIOFunction",
"detect_anomaly",
"enable_grad",
"grad",
"gradcheck",
"gradgradcheck",
"inference_mode",
"no_grad",
"set_detect_anomaly",
"set_grad_enabled",
"set_multithreading_enabled",
"variable"
],
"torch.backends": [
"contextmanager"
],

View File

@ -279,7 +279,14 @@ class _HookMixin:
class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
r"""
This class is used for internal autograd work. Do not use.
"""
def apply(self, *args):
r"""
Apply method used when executing this Node during the backward pass.
"""
# _forward_cls is defined by derived class
# The user should define either backward or vjp but never both.
backward_fn = self._forward_cls.backward # type: ignore[attr-defined]
@ -294,6 +301,9 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
return user_fn(self, *args)
def apply_jvp(self, *args):
r"""
Apply method used when executing forward-mode AD during the forward pass.
"""
# _forward_cls is defined by derived class
return self._forward_cls.jvp(self, *args) # type: ignore[attr-defined]
@ -378,10 +388,10 @@ class _SingleLevelFunction(
Either:
1. Override forward with the signature forward(ctx, *args, **kwargs).
1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
``setup_context`` is not overridden. Setting up the ctx for backward
happens inside the ``forward``.
2. Override forward with the signature forward(*args, **kwargs) and
2. Override forward with the signature ``forward(*args, **kwargs)`` and
override ``setup_context``. Setting up the ctx for backward happens
inside ``setup_context`` (as opposed to inside the ``forward``)
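A minimal sketch of option 2 above (the ``Square`` class is illustrative only):

.. code-block:: python

    import torch

    class Square(torch.autograd.Function):
        @staticmethod
        def forward(x):
            # Option 2: forward takes no ctx argument.
            return x ** 2

        @staticmethod
        def setup_context(ctx, inputs, output):
            (x,) = inputs
            ctx.save_for_backward(x)

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            return 2 * x * grad_output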
@ -639,6 +649,11 @@ def traceable(fn_cls):
class InplaceFunction(Function):
r"""
This class is here only for backward compatibility reasons.
Use :class:`Function` instead of this for any new use case.
"""
def __init__(self, inplace=False):
super().__init__()
self.inplace = inplace
@ -754,6 +769,10 @@ _map_tensor_data = _nested_map(
class NestedIOFunction(Function):
r"""
This class is here only for backward compatibility reasons.
Use :class:`Function` instead of this for any new use case.
"""
# The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
# superclass (Function) but are instance methods here, which mypy reports as incompatible.
@ -774,6 +793,9 @@ class NestedIOFunction(Function):
return result
def backward(self, *gradients: Any) -> Any: # type: ignore[override]
r"""
Shared backward utility.
"""
nested_gradients = _unflatten(gradients, self._nested_output)
result = self.backward_extended(*nested_gradients) # type: ignore[func-returns-value]
return tuple(_iter_None_tensors(result))
@ -781,6 +803,9 @@ class NestedIOFunction(Function):
__call__ = _do_forward
def forward(self, *args: Any) -> Any: # type: ignore[override]
r"""
Shared forward utility.
"""
nested_tensors = _map_tensor_data(self._nested_input)
result = self.forward_extended(*nested_tensors) # type: ignore[func-returns-value]
del self._nested_input
@ -788,22 +813,40 @@ class NestedIOFunction(Function):
return tuple(_iter_tensors(result))
def save_for_backward(self, *args: Any) -> None:
r"""
See :meth:`Function.save_for_backward`.
"""
self.to_save = tuple(_iter_tensors(args))
self._to_save_nested = args
@property
def saved_tensors(self):
r"""
See :meth:`Function.saved_tensors`.
"""
flat_tensors = super().saved_tensors # type: ignore[misc]
return _unflatten(flat_tensors, self._to_save_nested)
def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
r"""
See :meth:`Function.mark_dirty`.
"""
self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
r"""
See :meth:`Function.mark_non_differentiable`.
"""
self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
def forward_extended(self, *input: Any) -> None:
r"""
User defined forward.
"""
raise NotImplementedError
def backward_extended(self, *grad_output: Any) -> None:
r"""
User defined backward.
"""
raise NotImplementedError

View File

@ -196,6 +196,9 @@ class set_grad_enabled(_DecoratorContextManager):
torch._C._set_grad_enabled(self.prev)
def clone(self) -> "set_grad_enabled":
r"""
Create a copy of this instance.
"""
return self.__class__(self.mode)
@ -272,6 +275,9 @@ class inference_mode(_DecoratorContextManager):
self._inference_mode_context.__exit__(exc_type, exc_value, traceback)
def clone(self) -> "inference_mode":
r"""
Create a copy of this instance.
"""
return self.__class__(self.mode)
@ -315,6 +321,9 @@ class set_multithreading_enabled(_DecoratorContextManager):
torch._C._set_multithreading_enabled(self.prev)
def clone(self) -> "set_multithreading_enabled":
r"""
Create a copy of this instance.
"""
return self.__class__(self.mode)
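``clone`` is used internally when these context managers are applied as decorators, so that each decorated call gets a fresh instance; a small sketch of decorator usage:

.. code-block:: python

    import torch

    @torch.inference_mode()
    def predict(x):
        return x * 2

    y = predict(torch.ones(3, requires_grad=True))
    assert not y.requires_grad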

View File

@ -876,6 +876,9 @@ class EnforceUnique:
self.seen = set()
def see(self, *key):
r"""
Observe a key and raise an error if it is seen multiple times.
"""
if key in self.seen:
raise RuntimeError("duplicate key: " + str(key))
self.seen.add(key)
@ -962,23 +965,27 @@ class KinetoStepTracker:
We fix this by adding a layer of abstraction before calling the kineto library's step().
The idea is to maintain steps per requester in a dict:
```
{
"ProfilerStep": 100, # triggered by profiler step() call
"Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
"Optimizer2Step": 100,
}
```
.. code-block::
{
"ProfilerStep": 100, # triggered by profiler step() call
"Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
"Optimizer2Step": 100,
}
To figure out the global step count, just take the max of the dict values (100).
If one of the counts increments, the max will go up.
```
{
"ProfilerStep": 100,
"Optimizer1Step": 101, # Optimizer1 got incremented first say
"Optimizer2Step": 100,
}
```
.. code-block::
{
"ProfilerStep": 100,
"Optimizer1Step": 101, # Optimizer1 got incremented first say
"Optimizer2Step": 100,
}
Then the global step count is 101.
We only call the kineto step() function when the global count increments.
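A tiny illustration of the bookkeeping described above (the requester names are the same hypothetical ones as in the tables):

.. code-block:: python

    step_dict = {"ProfilerStep": 100, "Optimizer1Step": 101, "Optimizer2Step": 100}
    global_step = max(step_dict.values())  # 101; kineto step() fires only when this grows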
@ -991,10 +998,16 @@ class KinetoStepTracker:
@classmethod
def init_step_count(cls, requester: str):
r"""
Initialize for a given requester.
"""
cls._step_dict[requester] = cls._current_step
@classmethod
def erase_step_count(cls, requester: str) -> bool:
r"""
Remove a given requester.
"""
return cls._step_dict.pop(requester, None) is not None
@classmethod
@ -1023,4 +1036,7 @@ class KinetoStepTracker:
@classmethod
def current_step(cls) -> int:
r"""
Get the latest step for any requester.
"""
return cls._current_step

View File

@ -423,6 +423,9 @@ class Interval:
self.end = end
def elapsed_us(self):
r"""
Return the length of the interval in microseconds.
"""
return self.end - self.start
@ -781,6 +784,9 @@ class MemRecordsAcc:
self._start_uses, self._indices = zip(*tmp) # type: ignore[assignment]
def in_interval(self, start_us, end_us):
r"""
Return all records whose start time falls within the given interval.
"""
start_idx = bisect.bisect_left(self._start_uses, start_us)
end_idx = bisect.bisect_right(self._start_uses, end_us)
for i in range(start_idx, end_idx):