Summary: Change from `self` to `self.__class__()` in `_DecoratorContextManager` to ensure a new object is created every time a function is called recursively.
Fixes https://github.com/pytorch/pytorch/issues/44531
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44633
Reviewed By: agolynski
Differential Revision: D23783601
Pulled By: albanD
fbshipit-source-id: a818664dee7bdb061a40ede27ef99e9546fc80bb
import torch
import functools
import inspect
from typing import Any, Callable, TypeVar, cast

__all__ = ['no_grad', 'enable_grad', 'set_grad_enabled']


# Used for annotating the decorator usage of 'no_grad' and 'enable_grad'.
# See https://mypy.readthedocs.io/en/latest/generics.html#declaring-decorators
FuncType = Callable[..., Any]
F = TypeVar('F', bound=FuncType)


class _DecoratorContextManager:
    """Allow a context manager to be used as a decorator"""

    def __call__(self, func: F) -> F:
        if inspect.isgeneratorfunction(func):
            return self._wrap_generator(func)

        @functools.wraps(func)
        def decorate_context(*args, **kwargs):
            # Enter a fresh instance rather than re-entering `self`, so that
            # recursive calls do not share a single saved grad-mode state.
            with self.__class__():
                return func(*args, **kwargs)
        return cast(F, decorate_context)
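
    # Editorial sketch (not part of the original module): the fresh instance
    # above is what makes recursive use safe. Re-entering a single `self`
    # would let an inner call overwrite the outer call's saved `prev` mode
    # (see https://github.com/pytorch/pytorch/issues/44531). Assuming grad
    # is enabled on entry:
    #
    #     @torch.no_grad()
    #     def pow2(x, n):
    #         return x if n == 0 else pow2(x * x, n - 1)
    #
    #     x = torch.tensor([2.], requires_grad=True)
    #     y = pow2(x, 3)                   # grad disabled in every frame
    #     assert not y.requires_grad
    #     assert torch.is_grad_enabled()   # outer mode correctly restored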

    def _wrap_generator(self, func):
        """Wrap each generator invocation with the context manager"""
        @functools.wraps(func)
        def generator_context(*args, **kwargs):
            gen = func(*args, **kwargs)
            while True:
                try:
                    # Advance the generator under a fresh context instance,
                    # but yield outside it so the caller's grad mode is left
                    # untouched while the generator is suspended.
                    with self.__class__():
                        x = next(gen)
                    yield x
                except StopIteration:
                    break
        return generator_context
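
    # Editorial sketch: for a decorated generator, the manager is re-entered
    # around every next() call, so grad mode is only overridden while the
    # generator body is actually running:
    #
    #     @torch.no_grad()
    #     def doubled(xs):
    #         for x in xs:
    #             yield x * 2              # computed with grad disabled
    #
    #     x = torch.tensor([1.], requires_grad=True)
    #     ys = list(doubled([x]))
    #     assert not ys[0].requires_grad
    #     assert torch.is_grad_enabled()   # mode restored between yields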

    def __enter__(self) -> None:
        raise NotImplementedError

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        raise NotImplementedError


class no_grad(_DecoratorContextManager):
    r"""Context-manager that disables gradient calculation.

    Disabling gradient calculation is useful for inference, when you are sure
    that you will not call :meth:`Tensor.backward()`. It will reduce memory
    consumption for computations that would otherwise have `requires_grad=True`.

    In this mode, the result of every computation will have
    `requires_grad=False`, even when the inputs have `requires_grad=True`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> @torch.no_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> z = doubler(x)
        >>> z.requires_grad
        False
    """
    def __init__(self):
        if not torch._jit_internal.is_scripting():
            super().__init__()
        self.prev = False

    def __enter__(self):
        self.prev = torch.is_grad_enabled()
        torch.set_grad_enabled(False)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch.set_grad_enabled(self.prev)
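
# Editorial sketch: __enter__ saves the prior mode and __exit__ restores it,
# so no_grad nests correctly (assuming grad is enabled on entry):
#
#     with torch.no_grad():
#         with torch.no_grad():
#             assert not torch.is_grad_enabled()
#         assert not torch.is_grad_enabled()   # inner exit restored "off"
#     assert torch.is_grad_enabled()           # outer exit restored "on"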

class enable_grad(_DecoratorContextManager):
    r"""Context-manager that enables gradient calculation.

    Enables gradient calculation, if it has been disabled via :class:`~no_grad`
    or :class:`~set_grad_enabled`.

    This context manager is thread local; it will not affect computation
    in other threads.

    Also functions as a decorator. (Make sure to instantiate with parentheses.)

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> with torch.no_grad():
        ...     with torch.enable_grad():
        ...         y = x * 2
        >>> y.requires_grad
        True
        >>> y.backward()
        >>> x.grad
        tensor([2.])
        >>> @torch.enable_grad()
        ... def doubler(x):
        ...     return x * 2
        >>> with torch.no_grad():
        ...     z = doubler(x)
        >>> z.requires_grad
        True
    """
    def __enter__(self) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(True)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)


class set_grad_enabled(object):
    r"""Context-manager that sets gradient calculation to on or off.

    ``set_grad_enabled`` will enable or disable grads based on its argument :attr:`mode`.
    It can be used as a context-manager or as a function.

    This context manager is thread local; it will not affect computation
    in other threads.

    Arguments:
        mode (bool): Flag whether to enable grad (``True``) or disable
                     (``False``). This can be used to conditionally enable
                     gradients.

    Example::

        >>> x = torch.tensor([1.], requires_grad=True)
        >>> is_train = False
        >>> with torch.set_grad_enabled(is_train):
        ...     y = x * 2
        >>> y.requires_grad
        False
        >>> torch.set_grad_enabled(True)
        >>> y = x * 2
        >>> y.requires_grad
        True
        >>> torch.set_grad_enabled(False)
        >>> y = x * 2
        >>> y.requires_grad
        False
    """

    def __init__(self, mode: bool) -> None:
        self.prev = torch.is_grad_enabled()
        torch._C._set_grad_enabled(mode)

    def __enter__(self) -> None:
        pass

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        torch._C._set_grad_enabled(self.prev)
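
# Editorial note (sketch, not part of the original module): __init__ flips the
# mode immediately and __enter__ is a no-op, so the mode changes as soon as
# the object is constructed, even before the `with` block is entered:
#
#     cm = torch.set_grad_enabled(False)   # grad is already disabled here
#     with cm:
#         ...                              # still disabled
#     # __exit__ restores the mode that was saved at construction time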