Autograd doc cleanup (#118500)
I don't think we'll realistically go through deprecation for these now, since there are a couple of uses of each online. So document them appropriately.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118500
Approved by: https://github.com/soulitzer
This commit is contained in:
parent fc5cde7579
commit a40be5f4dc
@@ -33,6 +33,9 @@ for detailed steps on how to use this API.
    forward_ad.dual_level
    forward_ad.make_dual
    forward_ad.unpack_dual
    forward_ad.enter_dual_level
    forward_ad.exit_dual_level
    forward_ad.UnpackedDualTensor
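The entries above document the ``torch.autograd.forward_ad`` API. As context (not part of the diff), a minimal sketch of how these pieces fit together:

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)

with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = torch.sin(dual)
    # unpack_dual returns a named tuple (primal, tangent); the tangent is the JVP.
    out_primal, jvp = fwAD.unpack_dual(out)
    # For sin, the JVP is cos(primal) * tangent.
    assert torch.allclose(jvp, torch.cos(primal) * tangent)
```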

.. _functional-api:
@@ -209,6 +212,27 @@ When creating a new :class:`Function`, the following methods are available to ``ctx``.
    function.FunctionCtx.save_for_backward
    function.FunctionCtx.set_materialize_grads

Custom Function utilities
^^^^^^^^^^^^^^^^^^^^^^^^^

Decorator for backward method.

.. autosummary::
    :toctree: generated
    :nosignatures:

    function.once_differentiable

Base custom :class:`Function` used to build PyTorch utilities

.. autosummary::
    :toctree: generated
    :nosignatures:

    function.BackwardCFunction
    function.InplaceFunction
    function.NestedIOFunction
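As context for the new "Custom Function utilities" section (not part of the diff), a minimal sketch of ``once_differentiable`` decorating a custom ``Function``'s backward:

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable

class Exp(Function):
    @staticmethod
    def forward(ctx, x):
        y = x.exp()
        ctx.save_for_backward(y)
        return y

    @staticmethod
    @once_differentiable  # marks this backward as not itself differentiable
    def backward(ctx, grad_out):
        (y,) = ctx.saved_tensors
        return grad_out * y

x = torch.randn(3, requires_grad=True)
Exp.apply(x).sum().backward()
```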

.. _grad-check:

Numerical gradient checking
@@ -224,6 +248,7 @@ Numerical gradient checking

    gradcheck
    gradgradcheck
    GradcheckError

.. Just to reset the base path for the rest of this file
.. currentmodule:: torch.autograd
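As context for the gradient-checking entries (not part of the diff), a minimal sketch: ``gradcheck`` compares analytical gradients against finite differences, so double-precision inputs are recommended.

```python
import torch
from torch.autograd import gradcheck, gradgradcheck

# Double precision keeps the finite-difference comparison numerically stable.
x = torch.randn(4, dtype=torch.double, requires_grad=True)

assert gradcheck(torch.sin, (x,))      # first-order check
assert gradgradcheck(torch.sin, (x,))  # second-order check
```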
@@ -249,6 +274,14 @@ and vtune profiler based using
    profiler.profile.key_averages
    profiler.profile.self_cpu_time_total
    profiler.profile.total_average
    profiler.parse_nvprof_trace
    profiler.EnforceUnique
    profiler.KinetoStepTracker
    profiler.record_function
    profiler_util.Interval
    profiler_util.Kernel
    profiler_util.MemRecordsAcc
    profiler_util.StringTable

.. autoclass:: torch.autograd.profiler.emit_nvtx
.. autoclass:: torch.autograd.profiler.emit_itt
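As context for the profiler entries documented above (not part of the diff), a minimal usage sketch of ``profiler.profile``, ``profiler.record_function``, and ``key_averages``:

```python
import torch
from torch.autograd import profiler

x = torch.randn(64, 64)

with profiler.profile() as prof:
    # record_function labels a region so it shows up as a named event.
    with profiler.record_function("matmul_block"):
        y = x @ x

# key_averages() aggregates events by name; sort by self CPU time.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=5))
```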
@@ -260,13 +293,20 @@ and vtune profiler based using

    profiler.load_nvprof

Anomaly detection
Debugging and anomaly detection
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

.. autoclass:: detect_anomaly

.. autoclass:: set_detect_anomaly

.. autosummary::
    :toctree: generated
    :nosignatures:

    grad_mode.set_multithreading_enabled

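As context for the "Debugging and anomaly detection" section above (not part of the diff), a minimal sketch of ``detect_anomaly`` and ``set_detect_anomaly``:

```python
import torch

# detect_anomaly enables NaN checks and records forward-pass tracebacks for
# operations that later fail in backward; it is intended for debugging only.
with torch.autograd.detect_anomaly():
    x = torch.randn(3, requires_grad=True)
    y = (x * 2).sum()
    y.backward()

# set_detect_anomaly toggles the same behavior globally.
torch.autograd.set_detect_anomaly(False)
```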

Autograd graph
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -286,6 +326,7 @@ enabled and at least one of the inputs required gradients), or ``None`` otherwise.
    graph.Node.next_functions
    graph.Node.register_hook
    graph.Node.register_prehook
    graph.increment_version

Some operations need intermediary results to be saved during the forward pass
in order to execute the backward pass.
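As context for the ``graph.Node`` entries (not part of the diff), a minimal sketch of walking ``next_functions`` and registering a Node hook:

```python
import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# y.grad_fn is a graph.Node; next_functions walks its inputs in the graph.
print(y.grad_fn)                 # e.g. <SumBackward0 object ...>
print(y.grad_fn.next_functions)  # e.g. ((<MulBackward0 object ...>, 0),)

# A Node hook receives (grad_inputs, grad_outputs) and may return new grad_inputs.
def node_hook(grad_inputs, grad_outputs):
    return grad_inputs

handle = y.grad_fn.register_hook(node_hook)
y.backward()
handle.remove()
```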
@@ -91,9 +91,6 @@ templates_path = ["_templates"]
coverage_ignore_functions = [
    # torch
    "typename",
    # torch.autograd
    "register_py_tensor_class_for_device",
    "variable",
    # torch.cuda
    "check_error",
    "cudart",
@@ -390,20 +387,6 @@ coverage_ignore_functions = [
    "weight_dtype",
    "weight_is_quantized",
    "weight_is_statically_quantized",
    # torch.autograd.forward_ad
    "enter_dual_level",
    "exit_dual_level",
    # torch.autograd.function
    "once_differentiable",
    "traceable",
    # torch.autograd.gradcheck
    "get_analytical_jacobian",
    "get_numerical_jacobian",
    "get_numerical_jacobian_wrt_specific_input",
    # torch.autograd.graph
    "increment_version",
    # torch.autograd.profiler
    "parse_nvprof_trace",
    # torch.backends.cudnn.rnn
    "get_cudnn_mode",
    "init_dropout_state",
@@ -2530,40 +2513,6 @@ coverage_ignore_classes = [
    "QuantWrapper",
    # torch.ao.quantization.utils
    "MatchAllNode",
    # torch.autograd.forward_ad
    "UnpackedDualTensor",
    # torch.autograd.function
    "BackwardCFunction",
    "Function",
    "FunctionCtx",
    "FunctionMeta",
    "InplaceFunction",
    "NestedIOFunction",
    # torch.autograd.grad_mode
    "inference_mode",
    "set_grad_enabled",
    "set_multithreading_enabled",
    # torch.autograd.gradcheck
    "GradcheckError",
    # torch.autograd.profiler
    "EnforceUnique",
    "KinetoStepTracker",
    "profile",
    "record_function",
    # torch.autograd.profiler_legacy
    "profile",
    # torch.autograd.profiler_util
    "EventList",
    "FormattedTimesMixin",
    "FunctionEvent",
    "FunctionEventAvg",
    "Interval",
    "Kernel",
    "MemRecordsAcc",
    "StringTable",
    # torch.autograd.variable
    "Variable",
    "VariableMeta",
    # torch.backends.cudnn.rnn
    "Unserializable",
    # torch.cuda.amp.grad_scaler
@@ -274,9 +274,9 @@ Examples::

    no_grad
    enable_grad
    set_grad_enabled
    autograd.grad_mode.set_grad_enabled
    is_grad_enabled
    inference_mode
    autograd.grad_mode.inference_mode
    is_inference_mode_enabled
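As context for the grad-mode entries listed above (not part of the diff), a minimal sketch of the context managers and the query functions:

```python
import torch

x = torch.ones(2, requires_grad=True)

with torch.no_grad():
    y = x * 2                                  # y.requires_grad is False

with torch.enable_grad():
    z = x * 3                                  # tracking re-enabled inside the block

torch.set_grad_enabled(False)                  # also usable as a plain function
print(torch.is_grad_enabled())                 # False
torch.set_grad_enabled(True)

with torch.inference_mode():
    w = x * 4                                  # inference tensors, no autograd tracking
    print(torch.is_inference_mode_enabled())   # True
```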

Math operations
@@ -35,20 +35,6 @@
        "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
        "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
    },
    "torch.autograd": [
        "NestedIOFunction",
        "detect_anomaly",
        "enable_grad",
        "grad",
        "gradcheck",
        "gradgradcheck",
        "inference_mode",
        "no_grad",
        "set_detect_anomaly",
        "set_grad_enabled",
        "set_multithreading_enabled",
        "variable"
    ],
    "torch.backends": [
        "contextmanager"
    ],
@@ -279,7 +279,14 @@ class _HookMixin:


class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
    r"""
    This class is used for internal autograd work. Do not use.
    """

    def apply(self, *args):
        r"""
        Apply method used when executing this Node during the backward
        """
        # _forward_cls is defined by derived class
        # The user should define either backward or vjp but never both.
        backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
@@ -294,6 +301,9 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
        return user_fn(self, *args)

    def apply_jvp(self, *args):
        r"""
        Apply method used when executing forward mode AD during the forward
        """
        # _forward_cls is defined by derived class
        return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]
@@ -378,10 +388,10 @@ class _SingleLevelFunction(

    Either:

    1. Override forward with the signature forward(ctx, *args, **kwargs).
    1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
       ``setup_context`` is not overridden. Setting up the ctx for backward
       happens inside the ``forward``.
    2. Override forward with the signature forward(*args, **kwargs) and
    2. Override forward with the signature ``forward(*args, **kwargs)`` and
       override ``setup_context``. Setting up the ctx for backward happens
       inside ``setup_context`` (as opposed to inside the ``forward``)
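The docstring above describes the two ways to define a custom ``Function``. As context (not part of the diff), a minimal sketch of the second pattern, where ``forward`` takes no ``ctx`` and ``setup_context`` stashes what backward needs:

```python
import torch
from torch.autograd import Function

class Square(Function):
    # Pattern 2: forward has no ctx argument.
    @staticmethod
    def forward(x):
        return x ** 2

    @staticmethod
    def setup_context(ctx, inputs, output):
        (x,) = inputs
        ctx.save_for_backward(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
```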
@@ -639,6 +649,11 @@ def traceable(fn_cls):


class InplaceFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """

    def __init__(self, inplace=False):
        super().__init__()
        self.inplace = inplace
@@ -754,6 +769,10 @@ _map_tensor_data = _nested_map(


class NestedIOFunction(Function):
    r"""
    This class is here only for backward compatibility reasons.
    Use :class:`Function` instead of this for any new use case.
    """

    # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
    # superclass (Function) but are instance methods here, which mypy reports as incompatible.
@@ -774,6 +793,9 @@ class NestedIOFunction(Function):
        return result

    def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
        r"""
        Shared backward utility.
        """
        nested_gradients = _unflatten(gradients, self._nested_output)
        result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
        return tuple(_iter_None_tensors(result))
@@ -781,6 +803,9 @@ class NestedIOFunction(Function):
    __call__ = _do_forward

    def forward(self, *args: Any) -> Any:  # type: ignore[override]
        r"""
        Shared forward utility.
        """
        nested_tensors = _map_tensor_data(self._nested_input)
        result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
        del self._nested_input
@@ -788,22 +813,40 @@ class NestedIOFunction(Function):
        return tuple(_iter_tensors(result))

    def save_for_backward(self, *args: Any) -> None:
        r"""
        See :meth:`Function.save_for_backward`.
        """
        self.to_save = tuple(_iter_tensors(args))
        self._to_save_nested = args

    @property
    def saved_tensors(self):
        r"""
        See :meth:`Function.saved_tensors`.
        """
        flat_tensors = super().saved_tensors  # type: ignore[misc]
        return _unflatten(flat_tensors, self._to_save_nested)

    def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_dirty`.
        """
        self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))

    def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
        r"""
        See :meth:`Function.mark_non_differentiable`.
        """
        self.non_differentiable = tuple(_iter_tensors((args, kwargs)))

    def forward_extended(self, *input: Any) -> None:
        r"""
        User defined forward.
        """
        raise NotImplementedError

    def backward_extended(self, *grad_output: Any) -> None:
        r"""
        User defined backward.
        """
        raise NotImplementedError
@@ -196,6 +196,9 @@ class set_grad_enabled(_DecoratorContextManager):
        torch._C._set_grad_enabled(self.prev)

    def clone(self) -> "set_grad_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
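As context for the ``clone`` methods added here and in the hunks below (not part of the diff), a minimal sketch of why they exist: when a ``_DecoratorContextManager`` subclass is used as a decorator, each call re-enters a fresh copy produced by ``clone()``, so no state is shared between calls. Illustrated with ``no_grad``, which shares the same base class:

```python
import torch

# Using a grad-mode manager as a decorator; every call gets a fresh clone.
@torch.no_grad()
def run_inference(x):
    return x * 2

x = torch.ones(2, requires_grad=True)
print(run_inference(x).requires_grad)  # False inside the decorated call
print(torch.is_grad_enabled())         # unchanged outside the call
```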
@@ -272,6 +275,9 @@ class inference_mode(_DecoratorContextManager):
        self._inference_mode_context.__exit__(exc_type, exc_value, traceback)

    def clone(self) -> "inference_mode":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
@@ -315,6 +321,9 @@ class set_multithreading_enabled(_DecoratorContextManager):
        torch._C._set_multithreading_enabled(self.prev)

    def clone(self) -> "set_multithreading_enabled":
        r"""
        Create a copy of this class
        """
        return self.__class__(self.mode)
@@ -876,6 +876,9 @@ class EnforceUnique:
        self.seen = set()

    def see(self, *key):
        r"""
        Observe a key and raise an error if it is seen multiple times.
        """
        if key in self.seen:
            raise RuntimeError("duplicate key: " + str(key))
        self.seen.add(key)
@@ -962,23 +965,27 @@ class KinetoStepTracker:

    We fix this by adding a layer of abstraction before calling step()
    to the kineto library. The idea is to maintain steps per requester in a dict:
    ```
    {
        "ProfilerStep": 100,  # triggered by profiler step() call
        "Optimizer1Step": 100,  # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
        "Optimizer2Step": 100,
    }
    ```

    .. code-block::

        {
            "ProfilerStep": 100,  # triggered by profiler step() call
            "Optimizer1Step": 100,  # Optimizer 1 or 2 are just examples, could be SGD, Adam etc
            "Optimizer2Step": 100,
        }

    To figure out the global step count just take the max of dict values (100).

    If one of the count increments the max will go up.
    ```
    {
        "ProfilerStep": 100,
        "Optimizer1Step": 101,  # Optimizer1 got incremented first say
        "Optimizer2Step": 100,
    }
    ```

    .. code-block::

        {
            "ProfilerStep": 100,
            "Optimizer1Step": 101,  # Optimizer1 got incremented first say
            "Optimizer2Step": 100,
        }

    Then global step count is 101
    We only call the kineto step() function when global count increments.
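As context for the docstring above (not part of the diff), a minimal usage sketch; it assumes the ``increment_step`` classmethod that accompanies ``init_step_count``/``erase_step_count`` in this class:

```python
from torch.autograd.profiler import KinetoStepTracker

# Hypothetical requesters; each keeps its own count, the global step is the max.
KinetoStepTracker.init_step_count("Optimizer1Step")
KinetoStepTracker.init_step_count("Optimizer2Step")

KinetoStepTracker.increment_step("Optimizer1Step")  # max increases -> kineto step() fires
KinetoStepTracker.increment_step("Optimizer2Step")  # max unchanged -> no extra step()
print(KinetoStepTracker.current_step())

KinetoStepTracker.erase_step_count("Optimizer1Step")
KinetoStepTracker.erase_step_count("Optimizer2Step")
```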
@@ -991,10 +998,16 @@ class KinetoStepTracker:

    @classmethod
    def init_step_count(cls, requester: str):
        r"""
        Initialize for a given requester.
        """
        cls._step_dict[requester] = cls._current_step

    @classmethod
    def erase_step_count(cls, requester: str) -> bool:
        r"""
        Remove a given requester.
        """
        return cls._step_dict.pop(requester, None) is not None

    @classmethod
@@ -1023,4 +1036,7 @@ class KinetoStepTracker:

    @classmethod
    def current_step(cls) -> int:
        r"""
        Get the latest step for any requester
        """
        return cls._current_step
@@ -423,6 +423,9 @@ class Interval:
        self.end = end

    def elapsed_us(self):
        r"""
        Returns the length of the interval
        """
        return self.end - self.start
@@ -781,6 +784,9 @@ class MemRecordsAcc:
        self._start_uses, self._indices = zip(*tmp)  # type: ignore[assignment]

    def in_interval(self, start_us, end_us):
        r"""
        Return all records in the given interval
        """
        start_idx = bisect.bisect_left(self._start_uses, start_us)
        end_idx = bisect.bisect_right(self._start_uses, end_us)
        for i in range(start_idx, end_idx):
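As context for ``in_interval`` (not part of the diff), a small standalone sketch of the same bisect pattern over hypothetical sorted start times:

```python
import bisect

# Hypothetical sorted start times (microseconds) standing in for profiler records.
start_uses = [10, 25, 40, 55, 70]

def in_interval(start_us, end_us):
    # Binary-search both ends of the query window, as MemRecordsAcc does.
    lo = bisect.bisect_left(start_uses, start_us)
    hi = bisect.bisect_right(start_uses, end_us)
    return start_uses[lo:hi]

print(in_interval(20, 60))  # [25, 40, 55]
```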